mlx_graphs.utils.transformations.get_isolated_nodes_mask

mlx_graphs.utils.transformations.get_isolated_nodes_mask(edge_index: mlx.core.array, num_nodes: int, complement: bool = True) -> mlx.core.array

Returns a boolean mask in which isolated nodes are set to False, so that they can be filtered out if needed.

Parameters:
  • edge_index (array) – Edge index of the graph, of shape [2, num_edges], used to determine which nodes are isolated.

  • num_nodes (int) – Number of nodes of the graph.

  • complement (bool) – Whether to filter isolated or non-isolated nodes. Defaults to True.

Return type:

array

Returns:

A boolean mask of size num_nodes where True means the node isn’t isolated and False means it is.

Example:

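import mlx.core as mx
from mlx_graphs.utils.transformations import get_isolated_nodes_mask

# Nodes 0 and 2 are connected to each other; node 1 appears in no edge.
edge_index = mx.array([[0, 2, 0], [2, 0, 0]])
mask = get_isolated_nodes_mask(edge_index, 3)
# mx.array([0, 2])
mask = get_isolated_nodes_mask(edge_index, 3, complement=False)
# mx.array([1])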
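
As a rough illustration of the semantics described above, the following dependency-free sketch builds the mask with plain Python lists instead of mlx arrays. The function name isolated_nodes_mask_sketch is hypothetical and this is not the mlx_graphs implementation. It also assumes that a self-loop counts as an edge and that complement=False simply inverts the mask, neither of which is confirmed by the documentation above.

```python
def isolated_nodes_mask_sketch(edge_index, num_nodes, complement=True):
    """Hypothetical pure-Python stand-in, NOT the mlx_graphs implementation.

    edge_index is a pair of equal-length lists: [sources, destinations].
    """
    src, dst = edge_index
    # A node is treated as connected if it appears as a source or a destination.
    # Assumption: self-loops such as (0, 0) count as edges here.
    connected = [False] * num_nodes
    for node in list(src) + list(dst):
        connected[node] = True
    if complement:
        # True marks nodes that are not isolated.
        return connected
    # Assumption: complement=False inverts the mask, so True marks isolated nodes.
    return [not is_connected for is_connected in connected]


edge_index = [[0, 2, 0], [2, 0, 0]]

mask = isolated_nodes_mask_sketch(edge_index, 3)
print(mask)  # [True, False, True]

# Indices of the nodes selected by each mask.
kept = [i for i, keep in enumerate(mask) if keep]
print(kept)  # [0, 2]

inverse = isolated_nodes_mask_sketch(edge_index, 3, complement=False)
isolated = [i for i, keep in enumerate(inverse) if keep]
print(isolated)  # [1]
```

Selecting the indices where the mask is True gives nodes 0 and 2, and the inverted mask selects node 1, which lines up with the values shown in the example.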