mlx_graphs.datasets.HeteroDataset

class mlx_graphs.datasets.HeteroDataset(name: str, base_dir: str | None = None, pre_transform: Callable | None = None, transform: Callable | None = None)

Bases: BaseDataset

A dataset class for handling heterogeneous graph data.

Parameters:
  • name (str) – Name of the dataset.

  • base_dir (Optional[str]) – Directory in which to store the dataset files. Defaults to the local directory .mlx_graphs_data/.

  • pre_transform (Optional[Callable]) – A function/transform that takes in a HeteroGraphData object and returns a transformed version. The transformation is applied before the first access.

  • transform (Optional[Callable]) – A function/transform that takes in a HeteroGraphData object and returns a transformed version. The transformation is applied before every access, i.e., during the __getitem__ call. By default, no transformation is applied.
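The difference in timing between pre_transform and transform can be illustrated with a small stand-in class. This is only a sketch: ToyDataset, add_flag, and scale_features are hypothetical names, plain dictionaries stand in for HeteroGraphData objects, and applying pre_transform at construction time is an assumption rather than the library's confirmed behavior.

```python
from typing import Any, Callable, Optional


class ToyDataset:
    """Hypothetical stand-in used for illustration; not the mlx_graphs class."""

    def __init__(
        self,
        items: list[Any],
        pre_transform: Optional[Callable] = None,
        transform: Optional[Callable] = None,
    ):
        self.transform = transform
        # Assumption: pre_transform runs once per item, before any access.
        if pre_transform is not None:
            items = [pre_transform(item) for item in items]
        self._items = items

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, idx: int) -> Any:
        item = self._items[idx]
        # transform runs again on every access.
        if self.transform is not None:
            item = self.transform(item)
        return item


# Counters that record how often each callable is invoked.
calls = {"pre": 0, "per_access": 0}


def add_flag(graph: dict) -> dict:
    calls["pre"] += 1
    return {**graph, "preprocessed": True}


def scale_features(graph: dict) -> dict:
    calls["per_access"] += 1
    return {**graph, "x": [2 * v for v in graph["x"]]}


dataset = ToyDataset(
    [{"x": [1, 2]}, {"x": [3]}],
    pre_transform=add_flag,
    transform=scale_features,
)
first = dataset[0]
again = dataset[0]  # triggers scale_features a second time
```

In this sketch, add_flag is called once per stored item, while scale_features is called once per __getitem__ call, so expensive one-off preprocessing is better placed in pre_transform.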

Attributes

name

Name of the dataset

num_edge_classes

Returns a dictionary of the number of edge classes for each edge type.

num_edge_features

Returns a dictionary of the number of edge features for each edge type.

num_edges

Returns a dictionary of the number of edges for each edge type.

num_graph_features

Returns the number of graph features.

num_items

Returns the number of items in the dataset.

num_node_classes

Returns a dictionary of the number of node classes for each node type.

num_node_features

Returns a dictionary of the number of node features for each node type.

num_nodes

Returns a dictionary of the number of nodes for each node type.

processed_path

The path where processed files are stored.

raw_path

The path where raw files are stored.

property num_edge_classes: dict[Any, int]

Returns a dictionary of the number of edge classes for each edge type.

property num_edge_features: dict[str, int]

Returns a dictionary of the number of edge features for each edge type.

property num_edges: dict[Any, int]

Returns a dictionary of the number of edges for each edge type.

property num_graph_features: int

Returns the number of graph features.

property num_node_classes: dict[str, int] | None

Returns a dictionary of the number of node classes for each node type, or None.

property num_node_features: dict[str, int]

Returns a dictionary of the number of node features for each node type.

property num_nodes: dict[str, int]

Returns a dictionary of the number of nodes for each node type.
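To make the shape of these per-type dictionaries concrete, the following sketch derives comparable counts from a toy heterogeneous graph stored in plain Python containers. The container layout, the node type names, and the (source, relation, destination) tuple used as an edge-type key are all assumptions made for illustration; the library's own storage and key format may differ.

```python
# Hypothetical toy graph: one feature matrix per node type (rows are nodes).
node_features = {
    "author": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],  # 3 nodes, 2 features each
    "paper": [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],     # 2 nodes, 3 features each
}

# Hypothetical edge storage: [source indices, destination indices] per edge type.
edge_index = {
    ("author", "writes", "paper"): [[0, 1, 2], [0, 0, 1]],  # 3 edges
}

# One count per node type, keyed by the node type name.
num_nodes = {ntype: len(x) for ntype, x in node_features.items()}
num_node_features = {ntype: len(x[0]) for ntype, x in node_features.items()}

# One count per edge type, keyed by the edge-type tuple.
num_edges = {etype: len(src) for etype, (src, dst) in edge_index.items()}
```

Each result has one entry per type, which is why the node-level properties are annotated with str keys while the edge-level counts use the more general Any.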