mlx_graphs.datasets.Dataset

class mlx_graphs.datasets.Dataset(name: str, base_dir: str | None = None, pre_transform: Callable | None = None, transform: Callable | None = None)

Bases: BaseDataset

A dataset class for graph data. Child classes must implement the download and process methods. By default, the save and load methods save and load only the processed self.graphs attribute; override them to store or load additional processed attributes.

Graph data within the dataset should be stored in self.graphs as a list[GraphData]. The creation and preprocessing of this list of graphs is typically done within the overridden process method.
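The contract described above can be sketched without the library. The following is a minimal toy model, not the mlx_graphs implementation: ToyGraph, ToyDataset, SquareGraphs, the JSON file names, and the "load if a processed file exists" check are all hypothetical stand-ins chosen for illustration. It shows a base class that requires download and process, persists only self.graphs by default, and a child class that overrides save and load to keep one extra processed attribute.

```python
import json
import os
import tempfile
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass


@dataclass
class ToyGraph:
    """Hypothetical stand-in for GraphData: an edge list plus a label."""
    edges: list
    label: int


class ToyDataset(ABC):
    """Toy base class mirroring the documented contract (not the real class)."""

    def __init__(self, name, base_dir):
        self.name = name
        self.processed_path = os.path.join(base_dir, name, "processed")
        self.graphs = []
        if os.path.exists(self._file("graphs.json")):
            self.load()      # reuse previously processed data
        else:
            self.download()  # child class fetches raw files
            self.process()   # child class fills self.graphs
            self.save()      # persist the processed result

    def _file(self, filename):
        return os.path.join(self.processed_path, filename)

    @abstractmethod
    def download(self): ...

    @abstractmethod
    def process(self): ...

    def save(self):
        # Default behaviour: only self.graphs is stored.
        os.makedirs(self.processed_path, exist_ok=True)
        with open(self._file("graphs.json"), "w") as f:
            json.dump([asdict(g) for g in self.graphs], f)

    def load(self):
        with open(self._file("graphs.json")) as f:
            self.graphs = [ToyGraph(**d) for d in json.load(f)]


class SquareGraphs(ToyDataset):
    """Child class that also persists an extra processed attribute."""

    def download(self):
        pass  # nothing to fetch in this toy example

    def process(self):
        square = [[0, 1], [1, 2], [2, 3], [3, 0]]
        self.graphs = [ToyGraph(edges=square, label=i % 2) for i in range(3)]
        self.label_names = ["even", "odd"]

    def save(self):
        super().save()
        with open(self._file("label_names.json"), "w") as f:
            json.dump(self.label_names, f)

    def load(self):
        super().load()
        with open(self._file("label_names.json")) as f:
            self.label_names = json.load(f)


base_dir = tempfile.mkdtemp()
first = SquareGraphs("squares", base_dir)   # downloads, processes, saves
second = SquareGraphs("squares", base_dir)  # loads from disk instead
```

Without the overridden save and load, the second instance in this sketch would come back with its graphs but without label_names, which is the situation the note about overriding those methods is meant to avoid.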

Parameters:
  • name (str) – name of the dataset

  • base_dir (Optional[str]) – Directory in which to store dataset files. Defaults to the local directory .mlx_graphs_data/.

  • pre_transform (Optional[Callable]) – A function/transform that takes in a GraphData object and returns a transformed version. The transformation is applied before the first access.

  • transform (Optional[Callable]) – A function/transform that takes in a GraphData object and returns a transformed version. The transformation is applied before every access, i.e., during the __getitem__ call. By default, no transformation is applied.
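The difference between pre_transform and transform is mostly one of timing, which a small toy model may make concrete. This is a sketch under assumptions, not the library's code: ToyGraph and ToyDataset are hypothetical stand-ins, and pre_transform is applied here at construction time as one simple way to run it "before the first access".

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyGraph:
    """Hypothetical stand-in for GraphData."""
    num_nodes: int
    scale: float = 1.0


class ToyDataset:
    """Toy container illustrating when each transform runs."""

    def __init__(self, graphs, pre_transform=None, transform=None):
        # pre_transform runs once per graph; its result is what gets stored.
        if pre_transform is not None:
            graphs = [pre_transform(g) for g in graphs]
        self.graphs = graphs
        self.transform = transform

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        graph = self.graphs[idx]
        # transform runs on every access; the stored graph is left unchanged.
        return graph if self.transform is None else self.transform(graph)


calls = {"pre": 0, "each": 0}


def add_node(g):
    calls["pre"] += 1
    return replace(g, num_nodes=g.num_nodes + 1)


def double_scale(g):
    calls["each"] += 1
    return replace(g, scale=g.scale * 2)


ds = ToyDataset(
    [ToyGraph(3), ToyGraph(5)],
    pre_transform=add_node,
    transform=double_scale,
)
a = ds[0]  # transform applied to the stored, pre-transformed graph
b = ds[0]  # applied again, from the same stored graph
```

In this sketch add_node is called once per graph regardless of how often items are read, while double_scale is called on each ds[0]. That makes pre_transform the natural place for expensive, deterministic preprocessing and transform the place for cheap or randomized per-access changes.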

Methods

Attributes

name

Name of the dataset

num_edge_classes

Returns the number of edge classes to predict.

num_edge_features

Returns the number of edge features.

num_graph_classes

Returns the number of graph classes to predict.

num_graph_features

Returns the number of graph features.

num_graphs

Returns the number of graphs in the dataset.

num_items

Returns the number of items in the dataset.

num_node_classes

Returns the number of node classes to predict.

num_node_features

Returns the number of node features.

processed_path

The path where processed files are stored.

raw_path

The path where raw files are stored.

property num_edge_classes: int | dict[str, int]

Returns the number of edge classes to predict.

property num_edge_features: int

Returns the number of edge features.

property num_graph_classes: int | dict[str, int]

Returns the number of graph classes to predict.

property num_graph_features: int

Returns the number of graph features.

property num_graphs: int

Returns the number of graphs in the dataset.

property num_node_classes: int | dict[str, int]

Returns the number of node classes to predict.

property num_node_features: int

Returns the number of node features.

property processed_path: str

The path where processed files are stored. Defaults to <base_dir>/<name>/processed
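As a rough illustration of the <base_dir>/<name>/processed layout, the helper below builds the path from its two parts. The function default_processed_path and the dataset name "my_dataset" are hypothetical and not part of mlx_graphs. The fallback directory is the documented default for base_dir.

```python
import os


def default_processed_path(name, base_dir=None):
    """Hypothetical helper: compose <base_dir>/<name>/processed."""
    # Fall back to the documented default directory when base_dir is omitted.
    if base_dir is None:
        base_dir = ".mlx_graphs_data/"
    return os.path.join(base_dir, name, "processed")


default_location = default_processed_path("my_dataset")
custom_location = default_processed_path("my_dataset", base_dir="data")
```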