mlx_graphs.datasets.Dataset

class mlx_graphs.datasets.Dataset(name: str, base_dir: str | None = None, pre_transform: Callable | None = None, transform: Callable | None = None)

Bases: ABC

Base dataset class. The download and process methods must be implemented by child classes. By default, the save and load methods save and load only the processed self.graphs attribute. Override them if you need to store or load additional processed attributes.

Graph data within the dataset should be stored in self.graphs as a List[GraphData]. This list is typically created and preprocessed inside the overridden process method.
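The subclass contract can be sketched without mlx_graphs installed. In the sketch below, SketchDataset is a hypothetical, simplified stand-in for mlx_graphs.datasets.Dataset (it always downloads and processes, and it is not the library's implementation), and ToyGraph is a hypothetical placeholder for GraphData. A concrete subclass implements only download and process, and process fills self.graphs.

```python
import os
import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class ToyGraph:
    # Hypothetical placeholder for mlx_graphs' GraphData.
    edge_index: list  # [[sources...], [targets...]]
    node_features: list = field(default_factory=list)


class SketchDataset(ABC):
    # Hypothetical stand-in for the base class: it owns self.graphs and
    # calls the two abstract hooks that child classes must provide.
    def __init__(self, name, base_dir=None):
        self._name = name
        self._base_dir = base_dir if base_dir is not None else ".mlx_graphs_data/"
        self.graphs = []
        self.download()
        self.process()

    @property
    def raw_path(self):
        return os.path.join(self._base_dir, self._name, "raw")

    @abstractmethod
    def download(self): ...

    @abstractmethod
    def process(self): ...


class TriangleDataset(SketchDataset):
    # Hypothetical child class: "downloads" a tiny edge list, then parses it.
    def download(self):
        os.makedirs(self.raw_path, exist_ok=True)
        with open(os.path.join(self.raw_path, "edges.txt"), "w") as f:
            f.write("0 1\n1 2\n2 0\n")

    def process(self):
        with open(os.path.join(self.raw_path, "edges.txt")) as f:
            pairs = [tuple(map(int, line.split())) for line in f if line.strip()]
        src, dst = zip(*pairs)
        # One graph: a directed triangle with a single feature per node.
        self.graphs = [ToyGraph(edge_index=[list(src), list(dst)],
                                node_features=[[1.0], [1.0], [1.0]])]


with tempfile.TemporaryDirectory() as tmp:
    ds = TriangleDataset("triangle", base_dir=tmp)
    print(len(ds.graphs), ds.graphs[0].edge_index)  # 1 [[0, 1, 2], [1, 2, 0]]
```

Because download and process are abstract, instantiating the base stand-in directly raises TypeError; only child classes that implement both hooks can be created.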

Parameters:
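  • name (str) – Name of the dataset.

  • base_dir (Optional[str]) – Directory in which to store dataset files. Defaults to the local directory .mlx_graphs_data/.

  • pre_transform (Optional[Callable]) – A function/transform that takes in a GraphData object and returns a transformed version. Unlike transform, it is applied once to the processed graphs before they are saved, not on every access. By default, no pre-transformation is applied.

  • transform (Optional[Callable]) – A function/transform that takes in a GraphData object and returns a transformed version. The transformation is applied on every access, i.e., inside the __getitem__ call. By default, no transformation is applied.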
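The access-time behaviour of transform can be illustrated with a small sketch. TransformedAccess, double_features, and ToyGraph are hypothetical names used only here; the sketch stores graphs untransformed and runs the transform inside __getitem__, which is the behaviour the parameter description states.

```python
from dataclasses import dataclass, replace
from typing import Callable, List, Optional


@dataclass(frozen=True)
class ToyGraph:
    # Hypothetical placeholder for GraphData.
    num_nodes: int
    node_features: tuple


class TransformedAccess:
    # Hypothetical sketch: graphs are stored as-is and `transform`
    # is applied inside __getitem__, i.e. on every access.
    def __init__(self, graphs: List[ToyGraph], transform: Optional[Callable] = None):
        self.graphs = graphs
        self.transform = transform

    def __getitem__(self, idx: int) -> ToyGraph:
        graph = self.graphs[idx]
        return self.transform(graph) if self.transform is not None else graph


def double_features(graph: ToyGraph) -> ToyGraph:
    # Returns a new graph instead of mutating the stored one.
    return replace(graph, node_features=tuple(2 * x for x in graph.node_features))


data = TransformedAccess([ToyGraph(2, (1.0, 3.0))], transform=double_features)
print(data[0].node_features)         # (2.0, 6.0)
print(data.graphs[0].node_features)  # (1.0, 3.0)
```

Writing the transform so that it returns a new object keeps the stored graph intact, so repeated accesses do not compound the transformation.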

Methods

download()

Download the dataset into self.raw_path.

load()

Load the processed dataset

process()

Process the dataset and store graphs in self.graphs

save()

Save the processed dataset

Attributes

name

Name of the dataset

num_edge_classes

Returns the number of edge classes to predict.

num_edge_features

Returns the number of edge features.

num_graph_classes

Returns the number of graph classes to predict.

num_graph_features

Returns the number of graph features.

num_graphs

Returns the number of graphs in the dataset.

num_node_classes

Returns the number of node classes to predict.

num_node_features

Returns the number of node features.

processed_path

The path where processed files are stored.

raw_path

The path where raw files are stored.

abstract download()

Download the dataset into self.raw_path.

load()

Load the processed dataset

property name: str

Name of the dataset

property num_edge_classes: int | dict[str, int]

Returns the number of edge classes to predict.

property num_edge_features: int

Returns the number of edge features.

property num_graph_classes: int | dict[str, int]

Returns the number of graph classes to predict.

property num_graph_features: int

Returns the number of graph features.

property num_graphs: int

Returns the number of graphs in the dataset.

property num_node_classes: int | dict[str, int]

Returns the number of node classes to predict.

property num_node_features: int

Returns the number of node features.
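The counting properties summarize the graphs held in self.graphs. The sketch below shows one plausible way to derive three of them; CountingSketch and ToyGraph are hypothetical, and the rules it uses (feature width read from the first graph, classes counted as distinct labels across all graphs) are assumptions rather than the library's exact logic. The library annotates the class counts as int | dict[str, int]; the sketch covers only the int case.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyGraph:
    # Hypothetical placeholder for GraphData: one feature row per node,
    # plus one integer class label per node.
    node_features: List[List[float]]
    node_labels: List[int]


class CountingSketch:
    # Hypothetical sketch of count properties derived from self.graphs.
    def __init__(self, graphs: List[ToyGraph]):
        self.graphs = graphs

    @property
    def num_graphs(self) -> int:
        return len(self.graphs)

    @property
    def num_node_features(self) -> int:
        # Assumes every graph has the same feature width; reads it from the first graph.
        first = self.graphs[0].node_features
        return len(first[0]) if first else 0

    @property
    def num_node_classes(self) -> int:
        # Assumes classes are the distinct labels seen across all graphs.
        return len({label for g in self.graphs for label in g.node_labels})


ds = CountingSketch([
    ToyGraph(node_features=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], node_labels=[0, 1]),
    ToyGraph(node_features=[[0.7, 0.8, 0.9]], node_labels=[2]),
])
print(ds.num_graphs, ds.num_node_features, ds.num_node_classes)  # 2 3 3
```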

abstract process()

Process the dataset and store graphs in self.graphs

property processed_path: str

The path where processed files are stored. Defaults to <base_dir>/<name>/processed.

property raw_path: str

The path where raw files are stored. Defaults to <base_dir>/<name>/raw.
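The default directory layout can be computed directly from the name and base_dir arguments. dataset_paths is a hypothetical helper that mirrors the documented defaults (<base_dir>/<name>/raw and <base_dir>/<name>/processed, with base_dir falling back to .mlx_graphs_data/); it is not part of the mlx_graphs API.

```python
import os
from typing import Optional, Tuple

DEFAULT_BASE_DIR = ".mlx_graphs_data/"  # default stated in the base_dir parameter docs


def dataset_paths(name: str, base_dir: Optional[str] = None) -> Tuple[str, str]:
    # Hypothetical helper: returns (raw_path, processed_path) following the
    # documented <base_dir>/<name>/raw and <base_dir>/<name>/processed layout.
    root = os.path.join(base_dir if base_dir is not None else DEFAULT_BASE_DIR, name)
    return os.path.join(root, "raw"), os.path.join(root, "processed")


raw, processed = dataset_paths("cora", base_dir="/tmp/data")
print(raw)        # /tmp/data/cora/raw (on POSIX)
print(processed)  # /tmp/data/cora/processed (on POSIX)
```

Keeping raw and processed files in sibling directories under one per-dataset root means several datasets can share a single base_dir without colliding.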

save()

Save the processed dataset
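Since save and load persist only self.graphs by default, a child class that computes extra processed attributes can override both and delegate to the parent for the graphs. In this sketch SketchDataset is a hypothetical stand-in for the base class, label_names is a hypothetical extra attribute, and the use of pickle and the file names are assumptions; the library's own on-disk format may differ.

```python
import os
import pickle
import tempfile


class SketchDataset:
    # Hypothetical stand-in: default save/load persist only self.graphs.
    def __init__(self, processed_path):
        self.processed_path = processed_path
        self.graphs = []

    def save(self):
        os.makedirs(self.processed_path, exist_ok=True)
        with open(os.path.join(self.processed_path, "graphs.pkl"), "wb") as f:
            pickle.dump(self.graphs, f)

    def load(self):
        with open(os.path.join(self.processed_path, "graphs.pkl"), "rb") as f:
            self.graphs = pickle.load(f)


class LabeledDataset(SketchDataset):
    # Overrides save/load to also persist the hypothetical `label_names` attribute.
    def __init__(self, processed_path):
        super().__init__(processed_path)
        self.label_names = []

    def save(self):
        super().save()  # keep the default behaviour for self.graphs
        with open(os.path.join(self.processed_path, "label_names.pkl"), "wb") as f:
            pickle.dump(self.label_names, f)

    def load(self):
        super().load()  # restore self.graphs first
        with open(os.path.join(self.processed_path, "label_names.pkl"), "rb") as f:
            self.label_names = pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    writer = LabeledDataset(tmp)
    writer.graphs = [{"edge_index": [[0], [1]]}]
    writer.label_names = ["cat", "dog"]
    writer.save()

    reader = LabeledDataset(tmp)
    reader.load()
    print(reader.graphs, reader.label_names)  # [{'edge_index': [[0], [1]]}] ['cat', 'dog']
```

Calling super() in both overrides keeps the default handling of self.graphs in one place, so the child class adds only the logic for its extra attribute.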