{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Render matplotlib figures inline; harmless here, as no cell below plots anything.\n%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# \u81ea\u5b9a\u4e49GNN\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## \u7528 CogDL \u4e2d\u7684 GNN \u5c42\u5b9a\u4e49\u6a21\u578b\nCogDL \u5728 cogdl.layers \u4e2d\u5b9e\u73b0\u4e86\u6d41\u884c\u7684 GNN \u5c42\uff0c\u5b83\u4eec\u53ef\u4ee5\u4f5c\u4e3a\u6a21\u5757\u6765\u5e2e\u52a9\u60a8\u8bbe\u8ba1\u65b0\u7684 GNN\u3002\u4ee5\u4e0b\u793a\u4f8b\u5c55\u793a\u4e86\u5982\u4f55\u5728 CogDL \u4e2d\u4f7f\u7528 GCNLayer \u5b9e\u73b0 Jumping Knowledge Network (JKNet)\u3002JKNet \u6536\u96c6\u6240\u6709\u5c42\u7684\u8f93\u51fa\uff0c\u5e76\u5c06\u5b83\u4eec\u8fde\u63a5\u5728\u4e00\u8d77\u6765\u83b7\u5f97\u7ed3\u679c\uff1a\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from torch import nn\n",
        "\n",
        "from cogdl.layers import GCNLayer\n",
        "from cogdl.models import BaseModel\n",
        "\n",
        "\n",
        "class JKNet(BaseModel):\n",
        "    \"\"\"Jumping Knowledge Network (JKNet) built from stacked GCN layers.\n",
        "\n",
        "    The output of every layer is kept, and all outputs are concatenated\n",
        "    along the feature dimension before a final linear layer.\n",
        "\n",
        "    Args:\n",
        "        in_feats: int\n",
        "            Input feature size\n",
        "        out_feats: int\n",
        "            Output size of the final linear layer (e.g. the number of classes)\n",
        "        hidden_size: int\n",
        "            Output size of every GCN layer\n",
        "        num_layers: int\n",
        "            Number of stacked GCN layers; must be at least 1\n",
        "        layer_cls: callable, optional\n",
        "            Factory called as ``layer_cls(in_size, out_size)`` for each layer.\n",
        "            Defaults to ``cogdl.layers.GCNLayer``. The default is bound when\n",
        "            this cell runs, so a later cell that redefines ``GCNLayer`` does\n",
        "            not silently swap the layers used by this model.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if ``num_layers`` is smaller than 1.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, in_feats, out_feats, hidden_size, num_layers, layer_cls=GCNLayer):\n",
        "        super(JKNet, self).__init__()\n",
        "        if num_layers < 1:\n",
        "            # Without this check the failure only surfaces later as torch.cat([]) in forward().\n",
        "            raise ValueError(f\"JKNet requires num_layers >= 1, got {num_layers}\")\n",
        "        shapes = [in_feats] + [hidden_size] * num_layers\n",
        "        self.layers = nn.ModuleList([\n",
        "            layer_cls(d_in, d_out)\n",
        "            for d_in, d_out in zip(shapes[:-1], shapes[1:])\n",
        "        ])\n",
        "        # Every layer's output is concatenated, hence hidden_size * num_layers inputs.\n",
        "        self.fc = nn.Linear(hidden_size * num_layers, out_feats)\n",
        "\n",
        "    def forward(self, graph):\n",
        "        # symmetric normalization of adjacency matrix\n",
        "        graph.sym_norm()\n",
        "        h = graph.x\n",
        "        out = []\n",
        "        for layer in self.layers:\n",
        "            h = layer(graph, h)\n",
        "            out.append(h)\n",
        "        # Jumping knowledge: concatenate all intermediate representations.\n",
        "        out = torch.cat(out, dim=1)\n",
        "        return self.fc(out)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## \u5b9a\u4e49\u60a8\u7684 GNN \u6a21\u5757\n\u5728\u5927\u591a\u6570\u60c5\u51b5\u4e0b\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528\u65b0\u7684\u6d88\u606f\u4f20\u64ad\u548c\u805a\u5408\u65b9\u6848\u6784\u5efa\u5c42\u6a21\u5757\u3002\u8fd9\u91cc\u7684\u4ee3\u7801\u7247\u6bb5\u5c55\u793a\u4e86\u5982\u4f55\u5728 CogDL \u4e2d\u4f7f\u7528 Graph \u548c\u9ad8\u6548\u7684\u7a00\u758f\u77e9\u9635\u7b97\u5b50\u6765\u5b9e\u73b0 GCNLayer\u3002\u6ce8\u610f\uff1a\u4e0b\u9762\u5b9a\u4e49\u7684 GCNLayer \u4e0e\u524d\u9762\u5bfc\u5165\u7684 cogdl.layers.GCNLayer \u540c\u540d\uff0c\u4f1a\u8986\u76d6\u540e\u8005\u3002\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from cogdl.utils import spmm\n",
        "\n",
        "\n",
        "class GCNLayer(torch.nn.Module):\n",
        "    \"\"\"A minimal graph convolution layer.\n",
        "\n",
        "    Node features are first passed through a linear transform and then\n",
        "    aggregated over the graph with the sparse matrix product ``spmm``.\n",
        "\n",
        "    Note: this class has the same name as the ``cogdl.layers.GCNLayer``\n",
        "    imported earlier, so running this cell replaces that name in the\n",
        "    notebook namespace.\n",
        "\n",
        "    Args:\n",
        "        in_feats: int\n",
        "            Input feature size\n",
        "        out_feats: int\n",
        "            Output feature size\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, in_feats, out_feats):\n",
        "        super().__init__()\n",
        "        self.fc = torch.nn.Linear(in_feats, out_feats)\n",
        "\n",
        "    def forward(self, graph, x):\n",
        "        # Project in_feats -> out_feats first, then propagate over the graph structure.\n",
        "        return spmm(graph, self.fc(x))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## \u5c06\u81ea\u5b9a\u4e49\u7684 GNN \u6a21\u578b\u4e0e CogDL \u4e00\u8d77\u4f7f\u7528\n\u73b0\u5728\u60a8\u5df2\u7ecf\u5b9a\u4e49\u4e86\u81ea\u5df1\u7684 GNN\uff0c\u60a8\u53ef\u4ee5\u4f7f\u7528 CogDL \u4e2d\u7684\u6570\u636e\u96c6/\u4efb\u52a1\u6765\u7acb\u5373\u8bad\u7ec3\u548c\u8bc4\u4f30\u6a21\u578b\u7684\u6027\u80fd\u3002\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from cogdl import experiment\n",
        "from cogdl.datasets import build_dataset_from_name\n",
        "\n",
        "# Keep the dataset name in one place: the graph loaded here only supplies the\n",
        "# model's input/output sizes, while experiment() receives the dataset by name.\n",
        "DATASET = \"cora\"\n",
        "HIDDEN_SIZE = 32  # output size of every GCN layer\n",
        "NUM_LAYERS = 4  # number of stacked GCN layers\n",
        "\n",
        "data = build_dataset_from_name(DATASET)[0]\n",
        "# Use the JKNet model as defined above\n",
        "model = JKNet(\n",
        "    in_feats=data.num_features,\n",
        "    out_feats=data.num_classes,\n",
        "    hidden_size=HIDDEN_SIZE,\n",
        "    num_layers=NUM_LAYERS,\n",
        ")\n",
        "experiment(model=model, dataset=DATASET, mw=\"node_classification_mw\", dw=\"node_classification_dw\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}