mlx_graphs.data.data.HeteroGraphData
- class mlx_graphs.data.data.HeteroGraphData(edge_index_dict: dict[tuple[str, str, str], mlx.core.array], node_features_dict: dict[str, mlx.core.array] | None = None, edge_features_dict: dict[str, mlx.core.array] | None = None, graph_features: mlx.core.array | None = None, node_labels_dict: dict[str, mlx.core.array] | None = None, edge_labels_dict: dict[str, mlx.core.array] | None = None, edge_labels_index_dict: dict[str, mlx.core.array] | None = None, graph_labels: mlx.core.array | None = None, **kwargs) -> None
Represents a heterogeneous graph, i.e., a graph with multiple node types and edge types.
- Parameters:
  - edge_index_dict (dict[tuple[str, str, str], array]) – A dictionary mapping edge types to their edge indices. Each edge type is a (source node type, relation, destination node type) triple, such as ("user", "rates", "movie"). Each edge index is a 2D array of shape [2, num_edges]: the first row contains the source node indices and the second row contains the destination node indices.
  - node_features_dict (Optional[dict[str, array]]) – A dictionary mapping node types to their node feature matrices. Each node feature matrix has shape [num_nodes, num_features].
  - edge_features_dict (Optional[dict[str, array]]) – A dictionary mapping edge types to their edge feature matrices. Each edge feature matrix has shape [num_edges, num_features].
  - graph_features (Optional[array]) – A 1D array of graph-level features.
  - node_labels_dict (Optional[dict[str, array]]) – A dictionary mapping node types to their node label arrays. Each node label array has shape [num_nodes].
  - edge_labels_dict (Optional[dict[str, array]]) – A dictionary mapping edge types to their edge label arrays. Each edge label array has shape [num_edges].
  - edge_labels_index_dict (Optional[dict[str, array]]) – A dictionary mapping edge types to their edge label index arrays, which indicate the edges for which labels are available.
  - graph_labels (Optional[array]) – A 1D array of graph-level labels.
  - **kwargs – Additional keyword arguments, stored as custom attributes.
Example:
    import mlx.core as mx
    from mlx_graphs.data.data import HeteroGraphData

    edge_index_dict = {
        ("user", "rates", "movie"): mx.array([[0, 1], [0, 1]]),
        ("movie", "rev_rates", "user"): mx.array([[0, 1], [0, 1]]),
    }
    node_features_dict = {
        "user": mx.array([[0.2], [0.8]]),
        "movie": mx.array([[0.5], [0.3]]),
    }
    data = HeteroGraphData(edge_index_dict, node_features_dict)
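The edge_index_dict layout can be checked without MLX. In this minimal sketch, nested Python lists stand in for mx.array, and edges_per_type is a hypothetical helper (not part of mlx_graphs) that counts the columns of each [2, num_edges] edge index, which is what the num_edges attribute is described as reporting.

```python
# Hypothetical sketch: nested lists stand in for mx.array edge indices.
EdgeType = tuple[str, str, str]  # (source node type, relation, destination node type)


def edges_per_type(
    edge_index_dict: dict[EdgeType, list[list[int]]],
) -> dict[EdgeType, int]:
    """Count the edges of each edge type from its [2, num_edges] index."""
    counts: dict[EdgeType, int] = {}
    for edge_type, (src, dst) in edge_index_dict.items():
        # Row 0 holds source indices and row 1 destination indices,
        # so both rows must have exactly one entry per edge.
        if len(src) != len(dst):
            raise ValueError(f"{edge_type}: source and destination rows differ in length")
        counts[edge_type] = len(src)
    return counts


edge_index_dict = {
    ("user", "rates", "movie"): [[0, 1], [0, 1]],
    ("movie", "rev_rates", "user"): [[0, 1], [0, 1]],
}
print(edges_per_type(edge_index_dict))
# {('user', 'rates', 'movie'): 2, ('movie', 'rev_rates', 'user'): 2}
```

With real MLX arrays, the same count is the second dimension of the index, edge_index.shape[1].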
Methods
has_isolated_nodes(node_type) – Returns a boolean indicating whether the graph has isolated nodes of the given type (i.e., nodes that don't have a link to any other nodes).
has_self_loops() – Returns a boolean indicating whether the graph contains self loops.
is_directed() – Returns a boolean indicating whether all edges are directed.
is_undirected() – Returns a boolean indicating whether all edge types in the graph are undirected.
to_dict() – Converts the HeteroGraphData object to a dictionary.
Attributes
num_edge_classes – A dictionary of the number of edge classes for each edge type in the current graph.
num_edge_features – A dictionary of the number of edge features for each edge type.
num_edges – A dictionary of the number of edges for each edge type in the graph.
num_graph_features – The number of graph features.
num_node_classes – A dictionary of the number of node classes for each node type in the current graph.
num_node_features – A dictionary of the number of node features for each node type.
num_nodes
- has_isolated_nodes(node_type: str) -> bool
Returns a boolean indicating whether the graph has isolated nodes of the given type (i.e., nodes that don't have a link to any other nodes).
- Return type: bool
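In a heterogeneous graph, isolation has to be decided per node type: a node of type T can appear as a source in edge types whose source type is T, and as a destination in edge types whose destination type is T. The sketch below applies that reading, with plain lists in place of mx.array. The helper name and its num_nodes parameter (the node count for that type, e.g. the number of rows in its feature matrix) are hypothetical, and the sketch treats a node whose only edge is a self loop as connected, which the library may not.

```python
# Hypothetical sketch of an isolated-node check for one node type.
EdgeType = tuple[str, str, str]


def has_isolated_nodes_sketch(
    edge_index_dict: dict[EdgeType, list[list[int]]],
    node_type: str,
    num_nodes: int,
) -> bool:
    """Return True if some node of `node_type` appears in no edge at all."""
    connected: set[int] = set()
    for (src_type, _relation, dst_type), (src, dst) in edge_index_dict.items():
        if src_type == node_type:
            connected.update(src)  # nodes of this type used as sources
        if dst_type == node_type:
            connected.update(dst)  # nodes of this type used as destinations
    # Assumption: a node whose only edge is a self loop counts as connected.
    return any(i not in connected for i in range(num_nodes))


graph = {("user", "rates", "movie"): [[0, 1], [0, 0]]}
print(has_isolated_nodes_sketch(graph, "movie", num_nodes=2))  # True: movie 1 has no edge
print(has_isolated_nodes_sketch(graph, "user", num_nodes=2))   # False
```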
- has_self_loops() -> bool
Returns a boolean indicating whether the graph contains self loops.
- Return type: bool
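In a relation such as ("user", "rates", "movie"), the source and destination indices refer to different node sets, so a pair (i, i) there links two different nodes rather than forming a loop. The sketch below, a hypothetical helper using plain lists in place of mx.array, therefore inspects only edge types whose source and destination node types coincide. This is an assumed reading of has_self_loops; the library's check may treat such pairs differently.

```python
# Hypothetical sketch: self loops only make sense within a single node type.
EdgeType = tuple[str, str, str]


def has_self_loops_sketch(edge_index_dict: dict[EdgeType, list[list[int]]]) -> bool:
    """Return True if any same-type edge type contains an edge (i, i)."""
    for (src_type, _relation, dst_type), (src, dst) in edge_index_dict.items():
        if src_type != dst_type:
            continue  # indices refer to different node sets; (i, i) is not a loop
        if any(s == d for s, d in zip(src, dst)):
            return True
    return False


print(has_self_loops_sketch({("user", "rates", "movie"): [[0, 1], [0, 1]]}))   # False
print(has_self_loops_sketch({("user", "follows", "user"): [[0, 1], [1, 1]]}))  # True: 1 -> 1
```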
- is_directed() -> bool
Returns a boolean indicating whether all edges are directed.
- Return type: bool
- is_undirected() -> bool
Returns a boolean indicating whether all edge types in the graph are undirected.
- Return type: bool
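The page does not define how directedness is decided across edge types. One plausible reading, assumed here rather than taken from mlx_graphs, is to place all node types in a single global index space and call the graph undirected when every edge (u, v) has a reverse edge (v, u), which may come from a separate reverse relation such as ("movie", "rev_rates", "user") in the class example. The helper and its num_nodes_dict parameter are hypothetical; plain lists stand in for mx.array.

```python
# Hypothetical sketch: undirectedness via a merged, offset node index space.
EdgeType = tuple[str, str, str]


def is_undirected_sketch(
    edge_index_dict: dict[EdgeType, list[list[int]]],
    num_nodes_dict: dict[str, int],
) -> bool:
    """Return True if every edge has its reverse somewhere in the graph."""
    # Give each node type its own contiguous block of global node ids.
    offsets: dict[str, int] = {}
    start = 0
    for node_type, count in num_nodes_dict.items():
        offsets[node_type] = start
        start += count
    # Collect every edge as a (global source id, global destination id) pair.
    edges: set[tuple[int, int]] = set()
    for (src_type, _relation, dst_type), (src, dst) in edge_index_dict.items():
        for s, d in zip(src, dst):
            edges.add((offsets[src_type] + s, offsets[dst_type] + d))
    return all((d, s) in edges for s, d in edges)


example = {
    ("user", "rates", "movie"): [[0, 1], [0, 1]],
    ("movie", "rev_rates", "user"): [[0, 1], [0, 1]],
}
print(is_undirected_sketch(example, {"user": 2, "movie": 2}))  # True
```

Under this reading, the two-relation class example is undirected, while the ("user", "rates", "movie") relation on its own is not.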
- property num_edge_classes: dict[str, int]
Returns a dictionary of the number of edge classes for each edge type in the current graph.
- property num_edge_features: dict[str, int]
Returns a dictionary of the number of edge features for each edge type.
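Feature counts follow from the shapes in the parameter list: each feature matrix is [num_items, num_features], so the count is its number of columns (features.shape[1] for an mx.array). This hypothetical helper shows that with nested lists and accepts both node-type keys and edge-type keys; returning 0 for a matrix with no rows is an assumption of the sketch.

```python
# Hypothetical sketch: feature counts are the column counts of each matrix.
from collections.abc import Hashable


def num_features_per_type(
    features_dict: dict[Hashable, list[list[float]]],
) -> dict[Hashable, int]:
    """Map each node or edge type to the width of its feature matrix."""
    # Assumption: an empty matrix (no rows) is reported as 0 features.
    return {key: len(matrix[0]) if matrix else 0 for key, matrix in features_dict.items()}


print(num_features_per_type({"user": [[0.2], [0.8]], "movie": [[0.5], [0.3]]}))
# {'user': 1, 'movie': 1}
print(num_features_per_type({("user", "rates", "movie"): [[1.0, 0.5], [0.3, 0.2]]}))
# {('user', 'rates', 'movie'): 2}
```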
- property num_node_classes: dict[str, int] | None
Returns a dictionary of the number of node classes for each node type in the current graph.
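A minimal sketch of deriving per-type class counts from node_labels_dict, assuming the number of classes is the number of distinct label values and that None is returned when no labels were supplied (the dict[str, int] | None type suggests a None case, but the page does not say when it occurs). The helper is hypothetical and uses plain lists in place of mx.array; the same logic would apply to edge_labels_dict and num_edge_classes.

```python
# Hypothetical sketch: count distinct label values for each type.
from typing import Optional


def num_classes_per_type(
    labels_dict: Optional[dict[str, list[int]]],
) -> Optional[dict[str, int]]:
    """Return {type: number of distinct labels}, or None without labels."""
    if labels_dict is None:
        return None  # assumed source of the `| None` in the property type
    return {key: len(set(labels)) for key, labels in labels_dict.items()}


print(num_classes_per_type({"user": [0, 1, 1, 2], "movie": [3, 3]}))  # {'user': 3, 'movie': 1}
print(num_classes_per_type(None))  # None
```

Counting distinct values and taking max(label) + 1 disagree when label ids do not start at 0 or have gaps: for the "movie" labels [3, 3], the sketch reports 1 class, whereas max + 1 would report 4.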