mlx_graphs.data.data.HeteroGraphData

class mlx_graphs.data.data.HeteroGraphData(edge_index_dict: dict[tuple[str, str, str], mlx.core.array], node_features_dict: dict[str, mlx.core.array] | None = None, edge_features_dict: dict[str, mlx.core.array] | None = None, graph_features: mlx.core.array | None = None, node_labels_dict: dict[str, mlx.core.array] | None = None, edge_labels_dict: dict[str, mlx.core.array] | None = None, edge_labels_index_dict: dict[str, mlx.core.array] | None = None, graph_labels: mlx.core.array | None = None, **kwargs) → None

Represents a heterogeneous graph, i.e., a graph with multiple node and edge types.

Parameters:
  • edge_index_dict (dict[tuple[str, str, str], array]) – A dictionary mapping edge types, given as (source node type, relation, destination node type) tuples, to their edge indices. Each edge index is a 2D array of shape [2, num_edges], where the first row contains the source node indices and the second row contains the destination node indices.

  • node_features_dict (Optional[dict[str, array]]) – A dictionary mapping node types to their corresponding node feature matrices. Each node feature matrix has shape [num_nodes, num_features].

  • edge_features_dict (Optional[dict[str, array]]) – A dictionary mapping edge types to their corresponding edge feature matrices. Each edge feature matrix has shape [num_edges, num_features].

  • graph_features (Optional[array]) – A 1D array containing graph-level features.

  • node_labels_dict (Optional[dict[str, array]]) – A dictionary mapping node types to their corresponding node label arrays. Each node label array has shape [num_nodes].

  • edge_labels_dict (Optional[dict[str, array]]) – A dictionary mapping edge types to their corresponding edge label arrays. Each edge label array has shape [num_edges].

  • edge_labels_index_dict (Optional[dict[str, array]]) – A dictionary mapping edge types to their corresponding edge label index arrays. The edge label indices indicate the edges for which labels are available.

  • graph_labels (Optional[array]) – A 1D array containing graph-level labels.

  • **kwargs – Additional keyword arguments to store custom attributes.

Example:

import mlx.core as mx
from mlx_graphs.data.data import HeteroGraphData

edge_index_dict = {
    ("user", "rates", "movie"): mx.array([[0, 1], [0, 1]]),
    ("movie", "rev_rates", "user"): mx.array([[0, 1], [0, 1]]),
}
node_features_dict = {
    "user": mx.array([[0.2], [0.8]]),
    "movie": mx.array([[0.5], [0.3]]),
}
data = HeteroGraphData(edge_index_dict, node_features_dict)
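The example relies on two layout conventions: each edge type key is a (source node type, relation, destination node type) tuple, and each edge index has shape [2, num_edges], with source indices in row 0 and destination indices in row 1. The sketch below checks those conventions on plain Python lists instead of mx.array values, so it runs without MLX. `check_edge_indices` is a hypothetical helper, not part of mlx_graphs, and it assumes the node count of each type can be read from the first dimension of that type's node feature matrix.

```python
# Hypothetical helper, not part of mlx_graphs: validates the
# edge_index_dict layout using plain nested lists.
EdgeType = tuple[str, str, str]


def check_edge_indices(
    edge_index_dict: dict[EdgeType, list[list[int]]],
    num_nodes: dict[str, int],
) -> None:
    for (src_type, relation, dst_type), edge_index in edge_index_dict.items():
        # Shape [2, num_edges]: row 0 = sources, row 1 = destinations.
        if len(edge_index) != 2 or len(edge_index[0]) != len(edge_index[1]):
            raise ValueError(f"{relation!r}: edge index must have shape [2, num_edges]")
        # Source indices refer to nodes of the source node type.
        for i in edge_index[0]:
            if not 0 <= i < num_nodes[src_type]:
                raise ValueError(f"{relation!r}: source index {i} out of range for {src_type!r}")
        # Destination indices refer to nodes of the destination node type.
        for j in edge_index[1]:
            if not 0 <= j < num_nodes[dst_type]:
                raise ValueError(f"{relation!r}: destination index {j} out of range for {dst_type!r}")


# Same data as the example above, as plain lists.
edge_index_dict = {
    ("user", "rates", "movie"): [[0, 1], [0, 1]],
    ("movie", "rev_rates", "user"): [[0, 1], [0, 1]],
}
node_features_dict = {"user": [[0.2], [0.8]], "movie": [[0.5], [0.3]]}

# Assumption: nodes per type = first dimension of each feature matrix.
num_nodes = {t: len(x) for t, x in node_features_dict.items()}
check_edge_indices(edge_index_dict, num_nodes)  # passes silently
```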

Methods

has_isolated_nodes(node_type)

Returns a boolean indicating whether the graph has isolated nodes of the given type, i.e., nodes that are not linked to any other node.

has_self_loops()

Returns a boolean indicating whether any edge type in the graph contains self-loops.

is_directed()

Returns a boolean indicating whether the graph is directed.

is_undirected()

Returns a boolean indicating whether all edge types in the graph are undirected.

to_dict()

Converts the HeteroGraphData object to a dictionary.

Attributes

num_edge_classes

Returns a dictionary of the number of edge classes for each edge type in the current graph.

num_edge_features

Returns a dictionary of the number of edge features for each edge type.

num_edges

Returns a dictionary of the number of edges for each edge type in the graph.

num_graph_features

Returns the number of graph features.

num_node_classes

Returns a dictionary of the number of node classes for each node type in the current graph.

num_node_features

Returns a dictionary of the number of node features for each node type.

num_nodes

has_isolated_nodes(node_type: str) → bool

Returns a boolean indicating whether the graph has isolated nodes of the given type, i.e., nodes that are not linked to any other node.

Parameters:
  • node_type (str) – The node type to check for isolated nodes.

Return type:

bool
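A node type can appear as the source in some edge types and as the destination in others, so an isolation check for one type has to look across all of them. The sketch below shows one way to do this on plain lists. `isolated_nodes` is a hypothetical helper, and it assumes a node is isolated when it appears in no edge of any edge type, counting a self-loop as an edge; the library's handling of self-loops may differ.

```python
EdgeType = tuple[str, str, str]


def isolated_nodes(
    edge_index_dict: dict[EdgeType, list[list[int]]],
    num_nodes: dict[str, int],
    node_type: str,
) -> set[int]:
    """Hypothetical helper: indices of `node_type` nodes touched by no edge."""
    connected: set[int] = set()
    for (src_type, _, dst_type), (src, dst) in edge_index_dict.items():
        # A node counts as connected if it is an endpoint on either side.
        if src_type == node_type:
            connected.update(src)
        if dst_type == node_type:
            connected.update(dst)
    return set(range(num_nodes[node_type])) - connected


edge_index_dict = {
    ("user", "rates", "movie"): [[0, 1], [0, 0]],
    ("movie", "rev_rates", "user"): [[0, 0], [0, 1]],
}
num_nodes = {"user": 3, "movie": 2}

# has_isolated_nodes(node_type) corresponds to a non-empty result here.
print(isolated_nodes(edge_index_dict, num_nodes, "user"))   # {2}
print(isolated_nodes(edge_index_dict, num_nodes, "movie"))  # {1}
```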

has_self_loops() → bool

Returns a boolean indicating whether any edge type in the graph contains self-loops.

Return type:

bool
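Because has_self_loops() takes no edge type, it reports on the graph as a whole. A self-loop is an edge whose source and destination are the same node, which can only happen in an edge type whose source and destination node types match: in ("user", "rates", "movie"), index 0 on both sides refers to two different nodes. The sketch below assumes that reading. `contains_self_loops` is a hypothetical helper, and the library's implementation may not filter edge types this way.

```python
EdgeType = tuple[str, str, str]


def contains_self_loops(edge_index_dict: dict[EdgeType, list[list[int]]]) -> bool:
    """Hypothetical helper: True if any same-type edge connects a node to itself."""
    for (src_type, _, dst_type), (src, dst) in edge_index_dict.items():
        if src_type != dst_type:
            # Equal indices here refer to nodes of different types.
            continue
        if any(u == v for u, v in zip(src, dst)):
            return True
    return False


bipartite = {("user", "rates", "movie"): [[0, 1], [0, 1]]}
social = {("user", "follows", "user"): [[0, 1], [1, 1]]}
print(contains_self_loops(bipartite))  # False
print(contains_self_loops(social))     # True
```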

is_directed() → bool

Returns a boolean indicating whether the graph is directed.

Return type:

bool

is_undirected() → bool

Returns a boolean indicating whether all edge types in the graph are undirected.

Return type:

bool
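One common definition of undirectedness, and the one assumed here, works per edge type: an edge index is undirected when every edge (u, v) has a matching reverse edge (v, u) in the same index. Under that assumption, is_undirected() holds when every edge type passes the check, and is_directed() is its negation. All helpers below are hypothetical; note that a reverse relation stored under a separate edge type, such as ("movie", "rev_rates", "user") in the class example, is not counted by this definition.

```python
EdgeType = tuple[str, str, str]


def edge_index_is_undirected(edge_index: list[list[int]]) -> bool:
    # Undirected: the set of (u, v) pairs is closed under reversal.
    pairs = set(zip(edge_index[0], edge_index[1]))
    return all((v, u) in pairs for u, v in pairs)


def graph_is_undirected(edge_index_dict: dict[EdgeType, list[list[int]]]) -> bool:
    """Hypothetical helper: True if every edge type is undirected on its own."""
    return all(edge_index_is_undirected(ei) for ei in edge_index_dict.values())


def graph_is_directed(edge_index_dict: dict[EdgeType, list[list[int]]]) -> bool:
    # Assumed to be the negation of graph_is_undirected.
    return not graph_is_undirected(edge_index_dict)


friends = {("user", "friends", "user"): [[0, 1], [1, 0]]}
follows = {("user", "follows", "user"): [[0], [1]]}
print(graph_is_undirected(friends))  # True
print(graph_is_directed(follows))    # True
```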

property num_edge_classes: dict[str, int]

Returns a dictionary of the number of edge classes for each edge type in the current graph.

property num_edge_features: dict[str, int]

Returns a dictionary of the number of edge features for each edge type.

property num_edges: dict[str, int]

Returns a dictionary of the number of edges for each edge type in the graph.
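Both edge-side counts follow from array shapes: the number of edges of a type is the second dimension of its [2, num_edges] edge index, and the number of edge features is the second dimension of its [num_edges, num_features] feature matrix. The sketch below derives both from plain nested lists. `edge_counts` is a hypothetical helper, and keying edge features by the edge-type tuple (as edge_index_dict does) is an assumption, since the signature annotates edge_features_dict with str keys.

```python
EdgeType = tuple[str, str, str]


def edge_counts(
    edge_index_dict: dict[EdgeType, list[list[int]]],
    edge_features_dict: dict[EdgeType, list[list[float]]],
) -> tuple[dict[EdgeType, int], dict[EdgeType, int]]:
    """Hypothetical helper mirroring num_edges and num_edge_features."""
    # Length of row 0 of the [2, num_edges] index = number of edges.
    num_edges = {et: len(ei[0]) for et, ei in edge_index_dict.items()}
    # Columns of the [num_edges, num_features] matrix = number of features.
    num_edge_features = {et: len(x[0]) for et, x in edge_features_dict.items()}
    return num_edges, num_edge_features


rates = ("user", "rates", "movie")
num_edges, num_edge_features = edge_counts(
    {rates: [[0, 1, 1], [0, 0, 1]]},
    {rates: [[4.0, 0.1], [3.5, 0.7], [5.0, 0.2]]},  # two features per edge
)
print(num_edges)          # {('user', 'rates', 'movie'): 3}
print(num_edge_features)  # {('user', 'rates', 'movie'): 2}
```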

property num_graph_features: int

Returns the number of graph features.

property num_node_classes: dict[str, int] | None

Returns a dictionary of the number of node classes for each node type in the current graph.

property num_node_features: dict[str, int]

Returns a dictionary of the number of node features for each node type.
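On the node side, num_node_features can be read from the second dimension of each [num_nodes, num_features] matrix. For num_node_classes, the sketch assumes labels are integer class ids and counts the distinct values per node type; the library might instead use max(label) + 1, which differs when some class ids are absent. `node_counts` is hypothetical, and it returns None for the classes when no labels are given, mirroring the `dict[str, int] | None` annotation of num_node_classes.

```python
from __future__ import annotations


def node_counts(
    node_features_dict: dict[str, list[list[float]]],
    node_labels_dict: dict[str, list[int]] | None = None,
) -> tuple[dict[str, int], dict[str, int] | None]:
    """Hypothetical helper mirroring num_node_features and num_node_classes."""
    # Columns of each [num_nodes, num_features] matrix.
    num_node_features = {t: len(x[0]) for t, x in node_features_dict.items()}
    if node_labels_dict is None:
        return num_node_features, None
    # Assumption: number of classes = number of distinct integer labels.
    num_node_classes = {t: len(set(y)) for t, y in node_labels_dict.items()}
    return num_node_features, num_node_classes


features, classes = node_counts(
    {"user": [[0.2, 1.0], [0.8, 0.0]], "movie": [[0.5], [0.3]]},
    {"user": [0, 1], "movie": [2, 2]},
)
print(features)  # {'user': 2, 'movie': 1}
print(classes)   # {'user': 2, 'movie': 1}
```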

to_dict() → dict

Converts the HeteroGraphData object to a dictionary.

Return type:

dict

Returns:

A dictionary with all attributes of the HeteroGraphData object.
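The constructor stores **kwargs as custom attributes, and to_dict() returns all attributes of the object. The sketch below imitates that pairing with a small stand-in class. `TinyHeteroData` is hypothetical, and the assumption that to_dict() mirrors the instance's attribute dictionary (vars(self)) is not confirmed by this page.

```python
from __future__ import annotations

from typing import Any


class TinyHeteroData:
    """Hypothetical stand-in: stores named and custom attributes."""

    def __init__(self, edge_index_dict: dict, **kwargs: Any) -> None:
        self.edge_index_dict = edge_index_dict
        # Custom attributes passed as keyword arguments become attributes.
        for name, value in kwargs.items():
            setattr(self, name, value)

    def to_dict(self) -> dict[str, Any]:
        # Shallow copy of the instance attributes, keyed by attribute name.
        return dict(vars(self))


data = TinyHeteroData({("user", "rates", "movie"): [[0], [0]]}, split="train")
print(sorted(data.to_dict()))  # ['edge_index_dict', 'split']
```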