{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Using a Customized GNN\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Use CogDL's GNN Layers to Define a Model\nCogDL implements popular GNN layers in `cogdl.layers`, and they can serve as building blocks for designing new GNNs. Here is how to implement the Jumping Knowledge Network (JKNet) with `GCNLayer` in CogDL.\nJKNet collects the outputs of all layers and concatenates them. A final linear layer then maps the concatenated representation to the output size:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport torch.nn as nn\nfrom cogdl.layers import GCNLayer\nfrom cogdl.models import BaseModel\n\nclass JKNet(BaseModel):\n    def __init__(self, in_feats, out_feats, hidden_size, num_layers):\n        super(JKNet, self).__init__()\n        # feature sizes: input -> hidden -> ... -> hidden\n        shapes = [in_feats] + [hidden_size] * num_layers\n        self.layers = nn.ModuleList([\n            GCNLayer(shapes[i], shapes[i+1])\n            for i in range(num_layers)\n        ])\n        # the concatenated outputs of all layers have num_layers * hidden_size features\n        self.fc = nn.Linear(hidden_size * num_layers, out_feats)\n\n    def forward(self, graph):\n        # symmetric normalization of adjacency matrix\n        graph.sym_norm()\n        h = graph.x\n        out = []\n        for layer in self.layers:\n            h = layer(graph, h)\n            out.append(h)\n        # jumping knowledge: concatenate the outputs of every layer\n        out = torch.cat(out, dim=1)\n        return self.fc(out)"
      ]
    },
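    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The next cell is a hedged, dense-tensor sketch of the two ideas JKNet relies on. It is an illustration, not CogDL's implementation. The first idea is symmetric normalization `D^(-1/2) A D^(-1/2)`, which is assumed here to be what `graph.sym_norm()` applies to the adjacency matrix. The second is jumping-knowledge aggregation, which concatenates every layer's output so that the final linear layer sees `num_layers * hidden_size` features. The cell uses plain PyTorch only. The toy graph (a 3-node path with self-loops already added), the helper `sym_norm_dense`, and the layer sizes are all hypothetical.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\n\n# Hypothetical dense helper for intuition only: CogDL works on sparse edge lists.\n# Computes D^(-1/2) A D^(-1/2), where D is the diagonal degree matrix of A.\ndef sym_norm_dense(adj):\n    deg = adj.sum(dim=1)\n    d_inv_sqrt = deg.pow(-0.5)\n    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0  # guard against isolated nodes\n    return d_inv_sqrt.unsqueeze(1) * adj * d_inv_sqrt.unsqueeze(0)\n\n# Toy 3-node path graph with self-loops already added (hypothetical data)\nadj = torch.tensor([[1., 1., 0.],\n                    [1., 1., 1.],\n                    [0., 1., 1.]])\na_hat = sym_norm_dense(adj)\n\n# Jumping-knowledge aggregation: keep every layer's output, then concatenate\ntorch.manual_seed(0)\nin_feats, hidden_size, num_layers = 5, 4, 3\nshapes = [in_feats] + [hidden_size] * num_layers\nlinears = [torch.nn.Linear(shapes[i], shapes[i + 1]) for i in range(num_layers)]\nx = torch.randn(3, in_feats)\nh, outs = x, []\nfor lin in linears:\n    h = a_hat @ lin(h)  # one GCN-style step: transform, then propagate\n    outs.append(h)\njk = torch.cat(outs, dim=1)\nprint(jk.shape)  # torch.Size([3, 12])"
      ]
    },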
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Define Your GNN Module\nIn most cases, you will build a layer module with a new message propagation and aggregation scheme. The code snippet below shows how to implement a `GCNLayer` using `Graph` and the efficient sparse matrix operators in CogDL.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
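        "import torch\nfrom cogdl.utils import spmm\n\nclass GCNLayer(torch.nn.Module):\n    \"\"\"\n    A simple graph convolution layer: a linear transformation followed by\n    neighborhood aggregation over the graph's (weighted) adjacency matrix.\n\n    Args:\n        in_feats: int\n            Input feature size\n        out_feats: int\n            Output feature size\n    \"\"\"\n    def __init__(self, in_feats, out_feats):\n        super(GCNLayer, self).__init__()\n        self.fc = torch.nn.Linear(in_feats, out_feats)\n\n    def forward(self, graph, x):\n        # transform node features\n        h = self.fc(x)\n        # aggregate neighbor features by sparse matrix multiplication with the adjacency matrix\n        h = spmm(graph, h)\n        return h"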
      ]
    },
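    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sanity check of what `GCNLayer` computes, the next cell replaces CogDL's `Graph` and `spmm` with a plain PyTorch sparse COO matrix. The matrix is built from a toy edge list with unit edge weights. `SparseGCNLayerSketch` and the data are hypothetical. Conceptually, `spmm(graph, h)` multiplies the graph's (weighted) adjacency matrix by `h`, so the sparse and dense products should agree. Applying `fc` before aggregation is a common design choice. Ignoring the bias term, `A (X W) = (A X) W` holds by associativity, and transforming first means the sparse multiplication runs on the narrower `out_feats`-wide matrix whenever `out_feats < in_feats`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\n\n# Hypothetical stand-in for CogDL's GCNLayer: `adj` is a torch sparse tensor\n# playing the role of CogDL's Graph, and torch.sparse.mm plays the role of spmm.\nclass SparseGCNLayerSketch(torch.nn.Module):\n    def __init__(self, in_feats, out_feats):\n        super().__init__()\n        self.fc = torch.nn.Linear(in_feats, out_feats)\n\n    def forward(self, adj, x):\n        h = self.fc(x)  # transform node features: X W + b\n        return torch.sparse.mm(adj, h)  # aggregate neighbors: A (X W + b)\n\n# Toy edge list (row -> col) of a 3-node path with self-loops, unit weights\nedge_index = torch.tensor([[0, 0, 1, 1, 1, 2, 2],\n                           [0, 1, 0, 1, 2, 1, 2]])\nedge_weight = torch.ones(edge_index.size(1))\nadj = torch.sparse_coo_tensor(edge_index, edge_weight, (3, 3))\n\ntorch.manual_seed(0)\nlayer = SparseGCNLayerSketch(5, 2)\nx = torch.randn(3, 5)\nout_sparse = layer(adj, x)\nout_dense = adj.to_dense() @ layer.fc(x)  # the same product with a dense matrix\nprint(torch.allclose(out_sparse, out_dense, atol=1e-6))  # True"
      ]
    },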
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
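        "## Use Custom Models with CogDL\nNow that you have defined your own GNN, you can use CogDL's built-in datasets and training pipelines to train and evaluate your model right away. The `mw` and `dw` arguments of `experiment` select the model wrapper and data wrapper, here the ones for node classification.\n\n"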
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from cogdl import experiment\nfrom cogdl.datasets import build_dataset_from_name\n\ndata = build_dataset_from_name(\"cora\")[0]\n# Use the JKNet model as defined above: 32 hidden units, 4 layers\nmodel = JKNet(data.num_features, data.num_classes, 32, 4)\nexperiment(model=model, dataset=\"cora\", mw=\"node_classification_mw\", dw=\"node_classification_dw\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}