import re
import copy
from contextlib import contextmanager
import scipy.sparse as sp
import networkx as nx
import torch
import numpy as np
from cogdl.utils import (
csr2coo,
coo2csr_index,
add_remaining_self_loops,
symmetric_normalization,
row_normalization,
get_degrees,
)
from cogdl.utils import RandomWalker
from cogdl.operators.sample import sample_adj_c, subgraph_c
# NOTE(review): deliberately clobbers the compiled `subgraph_c` kernel imported
# just above, forcing Graph.subgraph onto the scipy fallback path — presumably
# the C operator is unreliable/unavailable here; confirm before re-enabling.
subgraph_c = None  # noqa: F811
class BaseGraph(object):
    """Dict-like attribute plumbing shared by :class:`Graph` and :class:`Adjacency`.

    Attributes are exposed both as object attributes and via ``obj[key]``;
    dunder-named attributes are treated as private bookkeeping and excluded
    from :attr:`keys`, iteration, and bulk transforms.
    """

    def __init__(self):
        pass

    def eval(self):
        # No-op hook; subclasses that track train/eval mode override this.
        pass

    def train(self):
        # No-op hook; subclasses that track train/eval mode override this.
        pass

    def __getitem__(self, key):
        r"""Gets the data of the attribute :obj:`key`."""
        return getattr(self, key)

    def __setitem__(self, key, value):
        """Sets the attribute :obj:`key` to :obj:`value`."""
        setattr(self, key, value)

    @property
    def keys(self):
        r"""Names of all present (non-None, non-dunder) graph attributes."""
        present = (name for name in self.__dict__.keys() if self[name] is not None)
        return [name for name in present if not (name[:2] == "__" or name[-2:] == "__")]

    def __len__(self):
        r"""Returns the number of all present attributes."""
        # NOTE: intentionally fixed at 1 (the attribute-count behavior is
        # disabled upstream); kept as-is for compatibility.
        return 1

    def __contains__(self, key):
        r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the
        data."""
        return key in self.keys

    def __iter__(self):
        r"""Iterates over all present attributes in the data, yielding their
        attribute names and content."""
        for name in sorted(self.keys):
            yield name, self[name]

    def __call__(self, *keys):
        r"""Iterates over the attributes :obj:`*keys`, yielding name/content
        pairs; with no arguments, iterates over every present attribute."""
        selected = keys if keys else sorted(self.keys)
        for name in selected:
            value = self[name]
            if value is not None:
                yield name, value

    def cat_dim(self, key, value):
        r"""Returns the dimension in which the attribute :obj:`key` with
        content :obj:`value` gets concatenated when creating batches.
        .. note::
            This method is for internal use only, and should only be overridden
            if the batch concatenation process is corrupted for a specific data
            attribute.
        """
        # `*index*` and `*face*` attributes concatenate along the last
        # dimension; everything else along the first.
        return -1 if ("index" in key or "face" in key) else 0

    def __inc__(self, key, value):
        r"""Returns the incremental count to cumulatively increase the value
        of the next attribute of :obj:`key` when creating batches.
        .. note::
            This method is for internal use only, and should only be overridden
            if the batch concatenation process is corrupted for a specific data
            attribute.
        """
        # Only `*index*` and `*face*` attributes are cumulatively shifted
        # (by the node count) when batching.
        return self.__num_nodes__ if ("index" in key or "face" in key) else 0

    def __cat_dim__(self, key, value=None):
        return self.cat_dim(key, value)

    def apply(self, func, *keys):
        r"""Applies :obj:`func` to the attributes :obj:`*keys` (all present
        attributes when no keys are given); only tensors and Adjacency
        objects are transformed."""
        for name, item in self(*keys):
            if isinstance(item, (Adjacency, torch.Tensor)):
                self[name] = func(item)
        return self

    def contiguous(self, *keys):
        r"""Ensures a contiguous memory layout for all attributes :obj:`*keys`.
        If :obj:`*keys` is not given, all present attributes are ensured to
        have a contiguous memory layout."""
        return self.apply(lambda x: x.contiguous(), *keys)

    def to(self, device, *keys):
        r"""Performs tensor dtype and/or device conversion to all attributes
        :obj:`*keys`.
        If :obj:`*keys` is not given, the conversion is applied to all present
        attributes."""
        return self.apply(lambda x: x.to(device), *keys)

    def cuda(self, *keys):
        return self.apply(lambda x: x.cuda(), *keys)
class Adjacency(BaseGraph):
    """Edge container for a graph.

    Holds the edge list in COO form (``row``, ``col``) and/or CSR form
    (``row_ptr``, ``col``) along with optional per-edge data (``weight``,
    ``attr``, ``types``).  Layout conversions and degree normalizations are
    computed lazily and cached in private dunder attributes.
    """

    def __init__(self, row=None, col=None, row_ptr=None, weight=None, attr=None, num_nodes=None, types=None, **kwargs):
        super(Adjacency, self).__init__()
        self.row = row
        self.col = col
        self.row_ptr = row_ptr
        self.weight = weight
        self.attr = attr
        self.types = types
        self.__num_nodes__ = num_nodes
        self.__normed__ = None  # which normalization ("sym"/"row"/"col") is cached, if any
        self.__in_norm__ = self.__out_norm__ = None  # per-node scaling vectors (CSR path)
        self.__symmetric__ = True
        for key, item in kwargs.items():
            self[key] = item

    def set_weight(self, weight):
        """Replace edge weights and invalidate every cached normalization."""
        self.weight = weight
        self.__normed__ = None
        self.__in_norm__ = self.__out_norm__ = None
        self.__symmetric__ = False

    def get_weight(self, indicator=None):
        """If `indicator` is not None, the normalization will not be implemented"""
        if self.weight is None or self.weight.shape[0] != self.col.shape[0]:
            # Missing or stale weights: fall back to unit weights.
            self.weight = torch.ones(self.num_edges, device=self.device)
        weight = self.weight
        if indicator is not None:
            return weight
        if self.__in_norm__ is not None:
            if self.row is None:
                # Materialize COO rows from CSR so per-edge scaling can index them.
                num_nodes = self.row_ptr.size(0) - 1
                row = torch.arange(num_nodes, device=self.device)
                row_count = self.row_ptr[1:] - self.row_ptr[:-1]
                self.row = row.repeat_interleave(row_count)
            # NOTE: replaces (not multiplies) the raw weight, matching the
            # CSR normalization scheme used by generate_normalization.
            weight = self.__in_norm__[self.row].view(-1)
        if self.__out_norm__ is not None:
            weight = self.__out_norm__[self.col].view(-1)
        return weight

    def add_remaining_self_loops(self):
        """Add missing self-loops, then rebuild the CSR layout in place."""
        if self.attr is not None and len(self.attr.shape) == 1:
            # 1-D attr is treated as an edge weight; pad loops with 0.
            edge_index, weight_attr = add_remaining_self_loops(
                (self.row, self.col), edge_weight=self.attr, fill_value=0, num_nodes=self.num_nodes
            )
            self.row, self.col = edge_index
            self.attr = weight_attr
            self.weight = torch.ones_like(self.row).float()
        else:
            edge_index, self.weight = add_remaining_self_loops(
                (self.row, self.col), fill_value=1, num_nodes=self.num_nodes
            )
            self.row, self.col = edge_index
            self.attr = None
        self.row_ptr, reindex = coo2csr_index(self.row, self.col, num_nodes=self.num_nodes)
        self.row = self.row[reindex]
        self.col = self.col[reindex]

    def padding_self_loops(self):
        """Append one self-loop per node (weight 0.01, attr 0) and re-sort to CSR."""
        device = self.row.device
        row, col = torch.arange(self.num_nodes, device=device), torch.arange(self.num_nodes, device=device)
        self.row = torch.cat((self.row, row))
        self.col = torch.cat((self.col, col))
        if self.weight is not None:
            values = torch.zeros(self.num_nodes, device=device) + 0.01
            self.weight = torch.cat((self.weight, values))
        if self.attr is not None:
            attr = torch.zeros(self.num_nodes, device=device)
            self.attr = torch.cat((self.attr, attr))
        self.row_ptr, reindex = coo2csr_index(self.row, self.col, num_nodes=self.num_nodes)
        self.row = self.row[reindex]
        self.col = self.col[reindex]

    def remove_self_loops(self):
        """Drop all self-loop edges (and their per-edge data), then re-CSR."""
        mask = self.row == self.col
        inv_mask = ~mask
        self.row = self.row[inv_mask]
        self.col = self.col[inv_mask]
        for item in self.__attr_keys__():
            if self[item] is not None:
                self[item] = self[item][inv_mask]
        self.convert_csr()

    def sym_norm(self):
        """Symmetric (D^-1/2 A D^-1/2) normalization."""
        if self.row is None:
            self.generate_normalization("sym")
        else:
            self.normalize_adj("sym")

    def row_norm(self):
        """Row (D^-1 A) normalization; the result is no longer symmetric."""
        if self.row is None:
            self.generate_normalization("row")
        else:
            self.normalize_adj("row")
        self.__symmetric__ = False

    def col_norm(self):
        """Column (A D^-1) normalization; the result is no longer symmetric."""
        if self.row is None:
            self.generate_normalization("col")
        else:
            self.normalize_adj("col")
        self.__symmetric__ = False

    def generate_normalization(self, norm="sym"):
        """CSR path: cache per-node in/out scaling vectors instead of rewriting weights."""
        if self.__normed__:
            return
        degrees = (self.row_ptr[1:] - self.row_ptr[:-1]).float()
        if norm == "sym":
            edge_norm = torch.pow(degrees, -0.5).to(self.device)
            edge_norm[torch.isinf(edge_norm)] = 0  # isolated nodes: 0 instead of inf
            self.__out_norm__ = self.__in_norm__ = edge_norm.view(-1, 1)
        elif norm == "row":
            edge_norm = torch.pow(degrees, -1).to(self.device)
            edge_norm[torch.isinf(edge_norm)] = 0
            self.__out_norm__ = None
            self.__in_norm__ = edge_norm.view(-1, 1)
        elif norm == "col":
            # Column norm has no cached-vector form here; materialize COO and
            # normalize the weights directly with row/col swapped.
            self.row, _, _ = csr2coo(self.row_ptr, self.col, self.weight)
            self.weight = row_normalization(self.num_nodes, self.col, self.row, self.weight)
        else:
            raise NotImplementedError
        self.__normed__ = norm

    def normalize_adj(self, norm="sym"):
        """COO path: rewrite edge weights with the requested normalization."""
        if self.__normed__:
            return
        if self.weight is None or self.weight.shape[0] != self.col.shape[0]:
            self.weight = torch.ones(self.num_edges, device=self.device)
        if norm == "sym":
            self.weight = symmetric_normalization(self.num_nodes, self.row, self.col, self.weight)
        elif norm == "row":
            self.weight = row_normalization(self.num_nodes, self.row, self.col, self.weight)
        elif norm == "col":
            # Column normalization == row normalization with row/col swapped.
            self.weight = row_normalization(self.num_nodes, self.col, self.row, self.weight)
        else:
            raise NotImplementedError
        self.__normed__ = norm

    def convert_csr(self):
        self._to_csr()

    def _to_csr(self):
        """Sort edges into CSR order and permute all per-edge attributes accordingly."""
        self.row_ptr, reindex = coo2csr_index(self.row, self.col, num_nodes=self.num_nodes)
        self.col = self.col[reindex]
        self.row = self.row[reindex]
        for key in self.__attr_keys__():
            if key == "weight" and self[key] is None:
                self.weight = torch.ones(self.row.shape[0]).to(self.row.device)
            if self[key] is not None:
                self[key] = self[key][reindex]

    def is_symmetric(self):
        return self.__symmetric__

    def set_symmetric(self, val):
        assert val in [True, False]
        self.__symmetric__ = val

    def degrees(self, node_idx=None):
        """Per-node out-degree; `node_idx` optionally selects a subset of nodes."""
        if self.row_ptr is not None:
            degs = (self.row_ptr[1:] - self.row_ptr[:-1]).float()
            if node_idx is not None:
                return degs[node_idx]
            return degs
        else:
            degs = get_degrees(self.row, self.col, num_nodes=self.num_nodes)
            # Fix: the COO path previously ignored `node_idx`, returning all
            # degrees regardless of the requested subset.
            if node_idx is not None:
                return degs[node_idx]
            return degs

    @property
    def edge_index(self):
        # Lazily materialize COO rows from CSR on first access.
        if self.row is None:
            self.row, _, _ = csr2coo(self.row_ptr, self.col, self.weight)
        return self.row, self.col

    @edge_index.setter
    def edge_index(self, edge_index):
        row, col = edge_index
        self.row, self.col = row, col
        # CSR layout is now stale; rebuild lazily on demand.
        self.row_ptr = None

    @property
    def row_indptr(self):
        if self.row_ptr is None:
            self._to_csr()
        return self.row_ptr

    @property
    def num_edges(self):
        if self.row is not None:
            return self.row.shape[0]
        elif self.row_ptr is not None:
            return self.row_ptr[-1]
        else:
            return None

    @property
    def num_nodes(self):
        if self.__num_nodes__ is not None:
            return self.__num_nodes__
        if self.row_ptr is not None:
            return self.row_ptr.shape[0] - 1
        else:
            # Infer from the largest endpoint index and cache the result.
            self.__num_nodes__ = max(self.row.max().item(), self.col.max().item()) + 1
            return self.__num_nodes__

    @property
    def row_ptr_v(self):
        return self.row_ptr

    @property
    def device(self):
        return self.row.device if self.row is not None else self.row_ptr.device

    @property
    def keys(self):
        r"""Returns all names of graph attributes."""
        keys = [key for key in self.__dict__.keys() if self[key] is not None]
        keys = [key for key in keys if key[:2] != "__" and key[-2:] != "__"]
        return keys

    def __out_repr__(self):
        if self.row is not None:
            info = ["{}={}".format("edge_index", [2] + list(self.row.size()))]
        else:
            info = ["{}={}".format(key, list(self[key].size())) for key in ["row", "col"] if self[key] is not None]
        attr_key = self.__attr_keys__()
        info += ["edge_{}={}".format(key, list(self[key].size())) for key in attr_key if self[key] is not None]
        return info

    def __getitem__(self, item):
        assert type(item) == str, f"{item} must be str"
        # Normalize "_key" and "edge_key" aliases to the bare attribute name.
        if item[0] == "_" and item[1] != "_":
            item = item[1:]
        if item.startswith("edge_") and item != "edge_index":
            item = item[5:]
        if item in self.__dict__:
            return self.__dict__[item]
        else:
            raise KeyError(f"{item} not in Adjacency")

    def __copy__(self):
        result = self.__class__()
        for key in self.keys:
            setattr(result, key, copy.copy(self[key]))
        result.__num_nodes__ = self.__num_nodes__
        return result

    def __deepcopy__(self, memodict={}):
        result = self.__class__()
        memodict[id(self)] = result
        for k in self.keys:
            v = self[k]
            setattr(result, k, copy.deepcopy(v, memodict))
        result.__num_nodes__ = self.__num_nodes__
        return result

    def __repr__(self):
        info = [
            "{}={}".format(key, list(self[key].size()))
            for key in self.keys
            if not key.startswith("__") and self[key] is not None
        ]
        return "{}({})".format(self.__class__.__name__, ", ".join(info))

    def __attr_keys__(self):
        # Per-edge data keys: everything except the row/col index arrays.
        return [x for x in self.keys if "row" not in x and "col" not in x]

    def clone(self):
        return Adjacency.from_dict({k: v.clone() for k, v in self})

    def to_scipy_csr(self):
        """Export as a scipy.sparse CSR matrix of shape (num_nodes, num_nodes)."""
        data = self.get_weight().cpu().numpy()
        num_nodes = int(self.num_nodes)
        if self.row_ptr is None:
            row = self.row.cpu().numpy()
            col = self.col.cpu().numpy()
            mx = sp.csr_matrix((data, (row, col)), shape=(num_nodes, num_nodes))
        else:
            row_ptr = self.row_ptr.cpu().numpy()
            col_ind = self.col.cpu().numpy()
            mx = sp.csr_matrix((data, col_ind, row_ptr), shape=(num_nodes, num_nodes))
        return mx

    def to_networkx(self, weighted=True):
        """Export as an undirected networkx.Graph (optionally with edge weights)."""
        gnx = nx.Graph()
        gnx.add_nodes_from(np.arange(self.num_nodes))
        row, col = self.edge_index
        row = row.tolist()
        col = col.tolist()
        if weighted:
            weight = self.get_weight().tolist()
            gnx.add_weighted_edges_from([(row[i], col[i], weight[i]) for i in range(len(row))])
        else:
            # Fix: row/col are already Python lists here; the previous code
            # called torch.stack on them, which raises a TypeError.
            gnx.add_edges_from(zip(row, col))
        return gnx

    def random_walk(self, seeds, length=1, restart_p=0.0, parallel=True):
        # Build and cache the walker on first use.
        if not hasattr(self, "__walker__"):
            scipy_adj = self.to_scipy_csr()
            self.__walker__ = RandomWalker(scipy_adj)
        return self.__walker__.walk(seeds, length, restart_p=restart_p, parallel=parallel)

    @staticmethod
    def from_dict(dictionary):
        r"""Creates a data object from a python dictionary."""
        data = Adjacency()
        for key, item in dictionary.items():
            data[key] = item
        return data
# Maps the public `edge_*` kwarg names to the short field names stored on
# Adjacency (see Graph.__getitem__/__setitem__ key normalization).
KEY_MAP = {"edge_weight": "weight", "edge_attr": "attr", "edge_types": "types"}
# Canonical attribute-name constants used when classifying Graph(**kwargs) keys.
EDGE_INDEX = "edge_index"
EDGE_WEIGHT = "edge_weight"
EDGE_ATTR = "edge_attr"
ROW_PTR = "row_indptr"
COL_INDICES = "col_indices"
def is_adj_key_train(key):
    """True for adjacency-related keys carrying a ``_train`` suffix."""
    if not key.endswith("_train"):
        return False
    return is_read_adj_key(key)
def is_adj_key(key):
    """True for a raw Adjacency field name or any ``edge_*`` alias."""
    if key.startswith("edge_"):
        return True
    return key in ("row", "col", "row_ptr", "attr", "weight", "types")
def is_read_adj_key(key):
    """True if ``key`` denotes adjacency data in ``Graph(**kwargs)``:
    either it contains one of the canonical edge_* names as a substring
    (so suffixed variants like ``edge_index_train`` match too), or it is a
    plain adjacency field name."""
    # Idiom fix: `any(...)` short-circuits instead of materializing a list
    # and summing booleans; behavior is identical.
    return any(x in key for x in (EDGE_INDEX, EDGE_WEIGHT, EDGE_ATTR)) or is_adj_key(key)
class Graph(BaseGraph):
    """Graph data object: node features ``x`` / labels ``y`` plus adjacency.

    Keeps a full-graph adjacency (``_adj_full``) and, when any
    ``*_train`` edge kwarg is given, a separate training adjacency
    (``_adj_train``) for inductive settings; ``train()``/``eval()`` switch
    which one ``_adj`` points to.
    """

    def __init__(self, x=None, y=None, **kwargs):
        super(Graph, self).__init__()
        if x is not None:
            if not torch.is_tensor(x):
                raise ValueError("Node features must be Tensor")
        self.x = x
        self.y = y
        self.grb_adj = None
        num_nodes = x.shape[0] if x is not None else None
        for key, item in kwargs.items():
            if key == "num_nodes":
                self.__num_nodes__ = item
                num_nodes = item
            elif key == "grb_adj":
                self.grb_adj = item
            elif not is_read_adj_key(key):
                # Plain node-level attribute (masks, features, ...).
                self[key] = item
        if "edge_index_train" in kwargs:
            # Inductive setting: build a separate training adjacency from the
            # `*_train` kwargs.
            self._adj_train = Adjacency(num_nodes=num_nodes)
            for key, item in kwargs.items():
                if is_adj_key_train(key):
                    _key = re.search(r"(.*)_train", key).group(1)
                    if _key.startswith("edge_"):
                        _key = _key.split("edge_")[1]
                    if _key == "index":
                        self._adj_train.edge_index = item
                    else:
                        self._adj_train[_key] = item
        else:
            self._adj_train = None
        self._adj_full = Adjacency(num_nodes=num_nodes)
        for key, item in kwargs.items():
            if is_read_adj_key(key) and not is_adj_key_train(key):
                if key.startswith("edge_"):
                    key = key.split("edge_")[-1]
                if key == "index":
                    self._adj_full.edge_index = item
                else:
                    self._adj_full[key] = item
        self._adj = self._adj_full
        self.__is_train__ = False
        self.__temp_adj_stack__ = list()
        self.__temp_storage__ = dict()

    def train(self):
        """Switch the active adjacency to the training graph (if any)."""
        self.__is_train__ = True
        if self._adj_train is not None:
            self._adj = self._adj_train
        return self

    def eval(self):
        """Switch the active adjacency back to the full graph."""
        self._adj = self._adj_full
        self.__is_train__ = False
        return self

    def add_remaining_self_loops(self):
        self._adj_full.add_remaining_self_loops()
        if self._adj_train is not None:
            self._adj_train.add_remaining_self_loops()
        return self

    def padding_self_loops(self):
        self._adj.padding_self_loops()
        return self

    def remove_self_loops(self):
        self._adj_full.remove_self_loops()
        if self._adj_train is not None:
            self._adj_train.remove_self_loops()
        return self

    def row_norm(self):
        self._adj.row_norm()

    def col_norm(self):
        self._adj.col_norm()

    def sym_norm(self):
        self._adj.sym_norm()

    def normalize(self, key="sym"):
        """Dispatch to row/col/sym normalization on the active adjacency."""
        assert key in ["row", "sym", "col"], "Support row/col/sym normalization"
        getattr(self, f"{key}_norm")()

    def is_symmetric(self):
        return self._adj.is_symmetric()

    def set_symmetric(self):
        self._adj.set_symmetric(True)

    def set_asymmetric(self):
        self._adj.set_symmetric(False)

    def is_inductive(self):
        return self._adj_train is not None

    def mask2nid(self, split):
        """Convert a train/val/test boolean mask to node ids; pass ids through."""
        mask = getattr(self, f"{split}_mask")
        if mask is not None:
            if mask.dtype is torch.bool:
                return torch.where(mask)[0]
            return mask

    @property
    def train_nid(self):
        return self.mask2nid("train")

    @property
    def val_nid(self):
        return self.mask2nid("val")

    @property
    def test_nid(self):
        return self.mask2nid("test")

    @contextmanager
    def local_graph(self):
        """Context manager: mutate the graph freely inside; the active
        adjacency and top-level attributes are restored on exit."""
        self.__temp_adj_stack__.append(self._adj)
        adj = copy.copy(self._adj)
        # Snapshot non-adjacency attributes so they can be restored afterwards.
        others = [(key, val) for key, val in self.__dict__.items() if not key.startswith("__") and "adj" not in key]
        self._adj = adj
        yield
        del adj
        self._adj = self.__temp_adj_stack__.pop()
        for key, val in others:
            self[key] = val

    @property
    def edge_index(self):
        return self._adj.edge_index

    @property
    def edge_weight(self):
        """Return actual edge_weight"""
        return self._adj.get_weight()

    @property
    def raw_edge_weight(self):
        """Return edge_weight without __in_norm__ and __out_norm__, only used for SpMM"""
        return self._adj.get_weight("raw")

    @property
    def edge_attr(self):
        return self._adj.attr

    @property
    def edge_types(self):
        return self._adj.types

    @edge_index.setter
    def edge_index(self, edge_index):
        if edge_index is None:
            self._adj.row = None
            self._adj.col = None
            self.__num_nodes__ = 0
        else:
            row, col = edge_index
            if self._adj.row is not None and row.shape[0] != self._adj.row.shape[0]:
                # Edge count changed: cached CSR layout is stale.
                self._adj.row_ptr = None
            self._adj.row = row
            self._adj.col = col
            if self.x is not None:
                self._adj.__num_nodes__ = self.x.shape[0]
                self.__num_nodes__ = self.x.shape[0]
            else:
                self.__num_nodes__ = None

    @edge_weight.setter
    def edge_weight(self, edge_weight):
        self._adj.set_weight(edge_weight)

    @edge_attr.setter
    def edge_attr(self, edge_attr):
        self._adj.attr = edge_attr

    @edge_types.setter
    def edge_types(self, edge_types):
        self._adj.types = edge_types

    @property
    def row_indptr(self):
        return self._adj.row_indptr

    @property
    def col_indices(self):
        if self._adj.row_ptr is None:
            self._adj._to_csr()
        return self._adj.col

    @row_indptr.setter
    def row_indptr(self, row_ptr):
        self._adj.row_ptr = row_ptr

    @col_indices.setter
    def col_indices(self, col_indices):
        self._adj.col = col_indices

    @property
    def in_norm(self):
        return self._adj.__in_norm__

    @property
    def out_norm(self):
        return self._adj.__out_norm__

    @property
    def keys(self):
        keys = [key for key in self.__dict__.keys() if self[key] is not None]
        keys = [key for key in keys if key[:2] != "__" and key[-2:] != "__"]
        return keys

    @property
    def device(self):
        return self._adj.device

    def degrees(self):
        return self._adj.degrees()

    def __keys__(self):
        # Node-level attribute names (everything except adjacency storage).
        keys = [key for key in self.keys if "adj" not in key]
        return keys

    def __old_keys__(self):
        keys = self.__keys__()
        keys += [EDGE_INDEX, EDGE_ATTR]
        return keys

    def __getitem__(self, key):
        r"""Gets the data of the attribute :obj:`key`, routing adjacency keys
        to the active Adjacency object."""
        if is_adj_key(key):
            if key[0] == "_" and key[1] != "_":
                key = key[1:]
            if key.startswith("edge_") and key != "edge_index":
                key = key[5:]
            return getattr(self._adj, key)
        else:
            return getattr(self, key)

    def __setitem__(self, key, value):
        if is_adj_key(key):
            if key[0] == "_" and key[1] != "_":
                key = key[1:]
            if key.startswith("edge_") and key != "edge_index":
                key = key[5:]
            self._adj[key] = value
        else:
            setattr(self, key, value)

    @property
    def num_edges(self):
        r"""Returns the number of edges in the graph."""
        return self._adj.num_edges

    @property
    def num_features(self):
        r"""Returns the number of features per node in the graph."""
        if self.x is None:
            return 0
        return 1 if self.x.dim() == 1 else self.x.size(1)

    @property
    def num_nodes(self):
        if hasattr(self, "__num_nodes__") and self.__num_nodes__ is not None:
            return self.__num_nodes__
        elif self.x is not None:
            return self.x.shape[0]
        else:
            return self._adj.num_nodes

    @property
    def num_classes(self):
        # 1-D labels: max label + 1; one-hot/multi-label: last dim size.
        if self.y is not None:
            return int(torch.max(self.y) + 1) if self.y.dim() == 1 else self.y.shape[-1]

    @num_nodes.setter
    def num_nodes(self, num_nodes):
        self.__num_nodes__ = num_nodes

    @staticmethod
    def from_pyg_data(data):
        val = {k: v for k, v in data}
        return Graph(**val)

    def clone(self):
        return Graph.from_dict({k: v.clone() for k, v in self})

    def store(self, key):
        """Deep-copy the attribute `key` (graph- or adjacency-level) into
        temporary storage; paired with `restore`."""
        if hasattr(self, key) and not callable(getattr(self, key)):
            self.__temp_storage__[key] = copy.deepcopy(getattr(self, key))
        if hasattr(self._adj, key) and not callable(getattr(self._adj, key)):
            self.__temp_storage__[key] = copy.deepcopy(getattr(self._adj, key))

    def restore(self, key):
        """Restore an attribute saved by `store` and drop the stashed copy."""
        if key in self.__temp_storage__:
            if hasattr(self, key) and not callable(getattr(self, key)):
                setattr(self, key, self.__temp_storage__[key])
            elif hasattr(self._adj, key) and not callable(getattr(self._adj, key)):
                # Fix: previously `self(self._adj, key, value)` created a
                # throwaway generator via BaseGraph.__call__ and never
                # assigned anything; use setattr on the adjacency.
                setattr(self._adj, key, self.__temp_storage__[key])
            self.__temp_storage__.pop(key)

    def __delitem__(self, key):
        if hasattr(self, key):
            self[key] = None

    def __repr__(self):
        info = [
            "{}={}".format(key, list(self[key].size()))
            for key in self.__keys__()
            if not key.startswith("_") and hasattr(self[key], "size")
        ]
        info += self._adj.__out_repr__()
        return "{}({})".format(self.__class__.__name__, ", ".join(info))

    def sample_adj(self, batch, size=-1, replace=True):
        """Sample (at most `size`) neighbors for each node in `batch`.

        Uses the compiled sampler when available, otherwise a scipy-based
        fallback; returns (sampled node ids, CSR subgraph).
        """
        if sample_adj_c is not None:
            if not torch.is_tensor(batch):
                batch = torch.tensor(batch, dtype=torch.long)
            (row_ptr, col_indices, nodes, edges) = sample_adj_c(
                self.row_indptr, self.col_indices, batch, size, replace
            )
        else:
            if torch.is_tensor(batch):
                batch = batch.cpu().numpy()
            # Cache a scipy CSR matrix of the active adjacency.
            if self.__is_train__ and self._adj_train is not None:
                key = "__mx_train__"
            else:
                key = "__mx__"
            if not hasattr(self, key):
                row, col = self._adj.row.numpy(), self._adj.col.numpy()
                val = self.edge_weight.numpy()
                N = self.num_nodes
                self[key] = sp.csr_matrix((val, (row, col)), shape=(N, N))
            adj = self[key][batch, :]
            indptr = adj.indptr
            indices = adj.indices
            if size != -1:
                indptr, indices = self._sample_adj(len(batch), indices, indptr, size)
                indptr = indptr.numpy()
                indices = indices.numpy()
            col_nodes = np.unique(indices)
            # Batch nodes first, then newly reached neighbors.
            _node_idx = np.concatenate([batch, np.setdiff1d(col_nodes, batch)])
            nodes = torch.tensor(_node_idx, dtype=torch.long)
            assoc_dict = {v: i for i, v in enumerate(_node_idx)}
            # Relabel columns into the local id space of `nodes`.
            col_indices = torch.tensor([assoc_dict[i] for i in indices], dtype=torch.long)
            row_ptr = torch.tensor(indptr, dtype=torch.long)
        if row_ptr.shape[0] - 1 < nodes.shape[0]:
            # Pad indptr for neighbor-only nodes that have no outgoing rows.
            padding = torch.full((nodes.shape[0] - row_ptr.shape[0] + 1,), row_ptr[-1].item(), dtype=row_ptr.dtype)
            row_ptr = torch.cat([row_ptr, padding])
        g = Graph(row_ptr=row_ptr, col=col_indices)
        return nodes, g

    def _sample_adj(self, batch_size, indices, indptr, size):
        """Uniformly sample exactly `size` neighbors per row (with replacement)."""
        if not torch.is_tensor(indices):
            indices = torch.from_numpy(indices)
        if not torch.is_tensor(indptr):
            indptr = torch.from_numpy(indptr)
        assert indptr.shape[0] - 1 == batch_size
        row_counts = (indptr[1:] - indptr[:-1]).long()
        rand = torch.rand(batch_size, size)
        rand = rand * row_counts.view(-1, 1)
        rand = rand.long()
        rand = rand + indptr[:-1].view(-1, 1)
        edge_cols = indices[rand].view(-1)
        row_ptr = torch.arange(0, batch_size * size + size, size)
        return row_ptr, edge_cols

    def csr_subgraph(self, node_idx, keep_order=False):
        """Induced subgraph over `node_idx` using the compiled CSR kernel."""
        if self._adj.row_ptr_v is None:
            self._adj._to_csr()
        if torch.is_tensor(node_idx):
            node_idx = node_idx.cpu()
        else:
            node_idx = torch.as_tensor(node_idx)
        if not keep_order:
            node_idx = torch.unique(node_idx)
        indptr, indices, nodes, edges = subgraph_c(self._adj.row_ptr, self._adj.col, node_idx)
        nodes_idx = node_idx.to(self._adj.device)
        data = Graph(row_ptr=indptr, col=indices)
        for key in self.__keys__():
            data[key] = self[key][nodes_idx]
        for key in self._adj.keys:
            if "row" in key or "col" in key:
                continue
            if key.startswith("__"):
                continue
            data._adj[key] = self._adj[key][edges]
        data.num_nodes = node_idx.shape[0]
        data.edge_weight = None
        return data

    def subgraph(self, node_idx, keep_order=False):
        """Induced subgraph over `node_idx` (compiled kernel or scipy fallback)."""
        if subgraph_c is not None:
            if isinstance(node_idx, list):
                node_idx = torch.as_tensor(node_idx, dtype=torch.long)
            elif isinstance(node_idx, np.ndarray):
                node_idx = torch.from_numpy(node_idx)
            return self.csr_subgraph(node_idx, keep_order)
        else:
            if isinstance(node_idx, list):
                node_idx = np.array(node_idx, dtype=np.int64)
            elif torch.is_tensor(node_idx):
                node_idx = node_idx.long().cpu().numpy()
            if self.__is_train__ and self._adj_train is not None:
                key = "__mx_train__"
            else:
                key = "__mx__"
            if not hasattr(self, key):
                row = self._adj.row.numpy()
                col = self._adj.col.numpy()
                val = self.edge_weight.numpy()
                N = self.num_nodes
                self[key] = sp.csr_matrix((val, (row, col)), shape=(N, N))
            sub_adj = self[key][node_idx, :][:, node_idx].tocoo()
            sub_g = Graph()
            row = torch.from_numpy(sub_adj.row).long()
            col = torch.from_numpy(sub_adj.col).long()
            sub_g.edge_index = (row, col)
            sub_g.edge_weight = torch.from_numpy(sub_adj.data)
            sub_g.num_nodes = len(node_idx)
            for key in self.__keys__():
                sub_g[key] = self[key][node_idx]
            sub_g._adj._to_csr()
            return sub_g.to(self._adj.device)

    def edge_subgraph(self, edge_idx, require_idx=True):
        """Subgraph induced by the edges `edge_idx`, with nodes relabeled
        to a compact id range."""
        row, col = self._adj.edge_index
        row = row[edge_idx]
        col = col[edge_idx]
        edge_index = torch.stack([row, col])
        nodes, new_edge_index = torch.unique(edge_index, return_inverse=True)
        g = Graph(edge_index=new_edge_index)
        for key in self.__keys__():
            g[key] = self[key][nodes]
        if require_idx:
            return g, nodes, edge_idx
        else:
            return g

    def random_walk(self, seeds, max_nodes_per_seed, restart_p=0.0, parallel=True):
        return self._adj.random_walk(seeds, max_nodes_per_seed, restart_p, parallel)

    def random_walk_with_restart(self, seeds, max_nodes_per_seed, restart_p=0.0, parallel=True):
        return self._adj.random_walk(seeds, max_nodes_per_seed, restart_p, parallel)

    def to_scipy_csr(self):
        return self._adj.to_scipy_csr()

    def to_networkx(self):
        return self._adj.to_networkx()

    @staticmethod
    def from_dict(dictionary):
        r"""Creates a data object from a python dictionary."""
        data = Graph()
        for key, item in dictionary.items():
            data[key] = item
        return data

    def nodes(self):
        return torch.arange(self.num_nodes)

    def set_grb_adj(self, adj):
        self.grb_adj = adj
# @property
# def requires_grad(self):
# return False
#
# @requires_grad.setter
# def requires_grad(self, x):
# print(f"Set `requires_grad` to {x}")