{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Using Customized Dataset\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Dataset for node_classification\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from cogdl import experiment\n",
        "from cogdl.data import Graph\n",
        "from cogdl.datasets import NodeDataset, generate_random_graph\n",
        "\n",
        "\n",
        "class MyNodeDataset(NodeDataset):\n",
        "    \"\"\"Node-classification dataset whose `process` builds one random `Graph`.\"\"\"\n",
        "\n",
        "    def __init__(self, path=\"data.pt\"):\n",
        "        # Set the path before delegating to the parent initializer.\n",
        "        # NOTE(review): presumably the parent reads `self.path` during init -- confirm.\n",
        "        self.path = path\n",
        "        super().__init__(path, scale_feat=False, metric=\"accuracy\")\n",
        "\n",
        "    def process(self):\n",
        "        \"\"\"Create the dataset and return it as a single `Graph`.\n",
        "\n",
        "        Swap the random tensors below for code that loads your own data.\n",
        "        \"\"\"\n",
        "        num_nodes, num_edges, feat_dim = 100, 300, 30\n",
        "\n",
        "        # Random stand-ins for real data: edges, node features, binary labels.\n",
        "        edges = torch.randint(0, num_nodes, (2, num_edges))\n",
        "        features = torch.randn(num_nodes, feat_dim)\n",
        "        labels = torch.randint(0, 2, (num_nodes,))\n",
        "\n",
        "        # Node-index boundaries of the splits used for node classification:\n",
        "        # [0, train_end) train, [train_end, val_end) validation, the rest test.\n",
        "        train_end = int(0.3 * num_nodes)\n",
        "        val_end = int(0.7 * num_nodes)\n",
        "\n",
        "        def range_mask(start, stop):\n",
        "            \"\"\"Boolean mask over all nodes that is True on [start, stop).\"\"\"\n",
        "            mask = torch.zeros(num_nodes).bool()\n",
        "            mask[start:stop] = True\n",
        "            return mask\n",
        "\n",
        "        return Graph(\n",
        "            x=features,\n",
        "            edge_index=edges,\n",
        "            y=labels,\n",
        "            train_mask=range_mask(0, train_end),\n",
        "            val_mask=range_mask(train_end, val_end),\n",
        "            test_mask=range_mask(val_end, num_nodes),\n",
        "        )\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    # Option 1: train on the dataset class defined above.\n",
        "    dataset = MyNodeDataset()\n",
        "    experiment(dataset=dataset, model=\"gcn\")\n",
        "\n",
        "    # Option 2: hand an already-built graph straight to NodeDataset.\n",
        "    data = generate_random_graph(num_nodes=100, num_edges=300, num_feats=30)\n",
        "    dataset = NodeDataset(data=data)\n",
        "    experiment(dataset=dataset, model=\"gcn\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Dataset for graph_classification\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Import everything this cell uses so it does not depend on names that\n",
        "# happen to be left in the kernel by an earlier cell.\n",
        "import torch\n",
        "from cogdl import experiment\n",
        "from cogdl.data import Graph\n",
        "from cogdl.datasets import GraphDataset\n",
        "\n",
        "\n",
        "class MyGraphDataset(GraphDataset):\n",
        "    \"\"\"Graph-classification dataset made of randomly generated graphs.\"\"\"\n",
        "\n",
        "    def __init__(self, path=\"data.pt\"):\n",
        "        # Set the path before delegating to the parent initializer.\n",
        "        # NOTE(review): presumably the parent reads `self.path` during init -- confirm.\n",
        "        self.path = path\n",
        "        super(MyGraphDataset, self).__init__(path, metric=\"accuracy\")\n",
        "\n",
        "    def process(self):\n",
        "        \"\"\"Return the dataset as a list of `Graph` objects.\n",
        "\n",
        "        Each graph gets random edges and a single random class label.\n",
        "        Swap the random generation below for code that loads your own graphs.\n",
        "        \"\"\"\n",
        "        num_graphs, num_nodes, num_edges, num_classes = 10, 20, 30, 7\n",
        "\n",
        "        graphs = []\n",
        "        for _ in range(num_graphs):\n",
        "            edges = torch.randint(0, num_nodes, (2, num_edges))\n",
        "            label = torch.randint(0, num_classes, (1,))\n",
        "            graphs.append(Graph(edge_index=edges, y=label))\n",
        "        return graphs\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    # Give this dataset its own file instead of relying on the default \"data.pt\",\n",
        "    # which the node-classification example in this notebook also uses.\n",
        "    # NOTE(review): presumably an existing file at `path` is reused instead of\n",
        "    # calling `process()` -- confirm against the CogDL dataset base class.\n",
        "    dataset = MyGraphDataset(path=\"graph_data.pt\")\n",
        "    experiment(model=\"gin\", dataset=dataset)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}