import random
import networkx as nx
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from cogdl.utils.transform import dropout_adj
from cogdl.wrappers.tools.wrapper_utils import evaluate_node_embeddings_using_logreg
from tqdm import tqdm
from .. import UnsupervisedModelWrapper
class SelfAuxiliaryModelWrapper(UnsupervisedModelWrapper):
    """Train a model jointly on its supervised objective and a
    self-supervised auxiliary (pretext) task selected by ``auxiliary_task``.

    The pretext task is implemented by an ``SSLTask`` agent that perturbs
    the input graph (``transform_data``) and adds a loss term on the node
    embeddings (``make_loss``).
    """

    @staticmethod
    def add_args(parser):
        """Register wrapper-specific command-line arguments on *parser*."""
        # fmt: off
        parser.add_argument("--auxiliary-task", type=str, default="edge_mask",
                            help="Option: edge_mask, attribute_mask, distance2clusters,"
                                 " pairwise_distance, pairwise_attr_sim")
        parser.add_argument("--dropedge-rate", type=float, default=0.0)
        parser.add_argument("--mask-ratio", type=float, default=0.1)
        parser.add_argument("--sampling", action="store_true")
        # fmt: on

    def __init__(self, model, optimizer_cfg, auxiliary_task, dropedge_rate, mask_ratio, sampling):
        super(SelfAuxiliaryModelWrapper, self).__init__()
        self.auxiliary_task = auxiliary_task
        self.optimizer_cfg = optimizer_cfg
        self.hidden_size = optimizer_cfg["hidden_size"]
        self.dropedge_rate = dropedge_rate
        self.mask_ratio = mask_ratio
        self.sampling = sampling
        self.model = model
        # SSLTask agent; created lazily in generate_virtual_labels().
        self.agent = None

    def train_step(self, subgraph):
        """Return supervised loss + self-supervised auxiliary loss."""
        graph = subgraph
        # local_graph() scopes the agent's graph perturbations to this step.
        with graph.local_graph():
            graph = self.agent.transform_data(graph)
            pred = self.model(graph)
            sup_loss = self.default_loss_fn(pred, graph.y)
            pred = self.model.embed(graph)
            ssl_loss = self.agent.make_loss(pred)
            return sup_loss + ssl_loss

    def test_step(self, graph):
        """Evaluate frozen embeddings with a logistic-regression probe."""
        self.model.eval()
        with torch.no_grad():
            pred = self.model.embed(graph)
            y = graph.y
            result = evaluate_node_embeddings_using_logreg(pred, y, graph.train_mask, graph.test_mask)
            self.note("test_acc", result)

    def pre_stage(self, stage, data_w):
        # Build the SSL agent once, before the first training stage.
        if stage == 0:
            data = data_w.get_dataset().data
            self.generate_virtual_labels(data)

    def generate_virtual_labels(self, data):
        """Instantiate the SSL agent named by ``self.auxiliary_task``.

        Raises:
            Exception: if ``auxiliary_task`` is not one of the known tasks.
        """
        # NOTE(review): self.device is presumably provided by the base
        # wrapper class -- confirm.
        if self.auxiliary_task == "edge_mask":
            self.agent = EdgeMask(self.hidden_size, self.mask_ratio, self.device)
        elif self.auxiliary_task == "attribute_mask":
            self.agent = AttributeMask(data, self.hidden_size, data.train_mask, self.mask_ratio, self.device)
        elif self.auxiliary_task == "pairwise_distance":
            self.agent = PairwiseDistance(
                self.hidden_size, [(1, 2), (2, 3), (3, 5)], self.sampling, self.dropedge_rate, 256, self.device,
            )
        elif self.auxiliary_task == "distance2clusters":
            self.agent = Distance2Clusters(self.hidden_size, 30, self.device)
        elif self.auxiliary_task == "pairwise_attr_sim":
            self.agent = PairwiseAttrSim(self.hidden_size, 5, self.device)
        else:
            raise Exception(
                "auxiliary task must be edge_mask, attribute_mask, pairwise_distance, distance2clusters,"
                " or pairwise_attr_sim"
            )

    def setup_optimizer(self):
        """Build an Adam optimizer from ``optimizer_cfg`` (lr, weight decay)."""
        lr, wd = self.optimizer_cfg["lr"], self.optimizer_cfg["weight_decay"]
        return torch.optim.Adam(self.parameters(), lr=lr, weight_decay=wd)
class SSLTask:
    """Abstract base for self-supervised pretext tasks.

    Subclasses override ``transform_data`` (perturb/augment the input
    graph) and ``make_loss`` (auxiliary loss on node embeddings).
    """

    def __init__(self, device):
        # Target device for tensors created by the task.
        self.device = device
        # Optional cache for a perturbed edge list; filled by subclasses.
        self.cached_edges = None

    def transform_data(self, graph):
        """Return a transformed view of *graph* for the pretext task."""
        raise NotImplementedError

    def make_loss(self, embeddings):
        """Compute the self-supervised loss from node *embeddings*."""
        raise NotImplementedError
class EdgeMask(SSLTask):
    """Pretext task: hide a fraction of edges and train a binary
    classifier to distinguish the hidden (true) edges from randomly
    sampled non-edges, using |h_u - h_v| as the pair representation."""

    def __init__(self, hidden_size, mask_ratio, device):
        super().__init__(device)
        # Two-way classifier: edge vs. non-edge.
        self.linear = nn.Linear(hidden_size, 2).to(device)
        self.mask_ratio = mask_ratio

    def transform_data(self, graph):
        """Split the edges into kept/hidden sets, draw negatives, and
        return the graph with only the kept edges."""
        device = graph.x.device
        total = graph.num_edges
        # if self.cached_edges is None:
        row, col = graph.edge_index
        all_edges = torch.stack([row, col])
        order = np.random.permutation(total)
        keep = int(total * (1 - self.mask_ratio))
        hidden_idx = order[keep:]
        kept_idx = order[:keep]
        self.masked_edges = all_edges[:, hidden_idx]
        self.cached_edges = all_edges[:, kept_idx]
        n_masked = len(hidden_idx)
        self.neg_edges = self.neg_sample(n_masked, graph).to(self.masked_edges.device)
        # Label 1 = real (masked) edge, 0 = sampled non-edge.
        self.pseudo_labels = torch.cat([torch.ones(n_masked), torch.zeros(n_masked)]).long().to(device)
        self.node_pairs = torch.cat([self.masked_edges, self.neg_edges], 1).to(device)
        graph.edge_index = self.cached_edges
        return graph

    def make_loss(self, embeddings):
        """Cross-entropy over edge-vs-non-edge predictions."""
        pair_repr = torch.abs(embeddings[self.node_pairs[0]] - embeddings[self.node_pairs[1]])
        logits = F.log_softmax(self.linear(pair_repr), dim=1)
        return F.nll_loss(logits, self.pseudo_labels)

    def neg_sample(self, edge_num, graph):
        """Draw *edge_num* node pairs that are not existing edges."""
        num_nodes = graph.num_nodes
        existing = torch.stack(graph.edge_index).t().cpu().numpy()
        exclude = set((pair[0], pair[1]) for pair in list(existing))
        gen = self.sample(exclude, num_nodes)
        negatives = [next(gen) for _ in range(edge_num)]
        return torch.tensor(negatives).t()

    def sample(self, exclude, num_nodes):
        """Generator of distinct non-edge pairs via rejection sampling."""
        while True:
            pair = tuple(np.random.randint(0, num_nodes, 2))
            if pair[0] == pair[1] or pair in exclude:
                continue
            exclude.add(pair)
            exclude.add((pair[1], pair[0]))
            yield pair
class AttributeMask(SSLTask):
    """Pretext task: zero out the feature rows of randomly chosen
    unlabelled nodes and train a regressor to reconstruct the original
    features from the node embeddings."""

    def __init__(self, graph, hidden_size, train_mask, mask_ratio, device):
        # NOTE(review): `train_mask` is unused; transform_data reads
        # graph.train_mask instead. Kept for interface compatibility.
        super().__init__(device)
        # Maps embeddings back to the input-feature space.
        self.linear = nn.Linear(hidden_size, graph.x.shape[1]).to(device)
        self.cached_features = None
        self.mask_ratio = mask_ratio

    def transform_data(self, graph):
        """Mask features of some nodes outside the training set.

        Stores the *original* feature rows as reconstruction targets and
        returns the graph with those rows zeroed.
        """
        device = graph.x.device
        num_nodes = graph.num_nodes
        # Only mask nodes not used by the supervised objective.
        unlabelled = torch.where(~graph.train_mask)[0]
        perm = np.random.permutation(unlabelled.cpu().numpy())
        mask_nnz = int(num_nodes * self.mask_ratio)
        self.masked_nodes = perm[:mask_nnz]
        # BUG FIX: the original zeroed graph.x in place and *then* read the
        # targets, so pseudo_labels were all zeros and the dataset's feature
        # tensor was permanently corrupted. Clone first, capture targets
        # before zeroing.
        x_feat = graph.x.clone()
        self.pseudo_labels = x_feat[self.masked_nodes].clone().to(device)
        x_feat[self.masked_nodes] = 0
        graph.x = x_feat
        return graph

    def make_loss(self, embeddings):
        """MSE between reconstructed and original masked features."""
        reconstructed = self.linear(embeddings[self.masked_nodes])
        return F.mse_loss(reconstructed, self.pseudo_labels, reduction="mean")
class PairwiseDistance(SSLTask):
    """Pretext task: classify node pairs by their shortest-path-distance
    interval.

    ``class_split`` is a list of half-open distance intervals, e.g.
    ``[(1, 2), (2, 3), (3, 5)]``; one extra class covers pairs that are
    farther away / unreachable.
    """

    def __init__(self, hidden_size, class_split, sampling, dropedge_rate, num_centers, device):
        super().__init__(device)
        # One class per interval plus one for "beyond max_distance".
        self.nclass = len(class_split) + 1
        self.class_split = class_split
        self.max_distance = self.class_split[self.nclass - 2][1]
        self.sampling = sampling
        self.dropedge_rate = dropedge_rate
        self.num_centers = num_centers
        self.linear = nn.Linear(hidden_size, self.nclass).to(device)
        self.get_distance_cache = False

    def get_distance(self, graph):
        """Precompute node pairs grouped by distance class.

        With ``self.sampling`` set, runs BFS from ``num_centers`` random
        center nodes and collects pairs per class; otherwise computes
        all-pairs shortest paths (up to ``max_distance``) via networkx.
        """
        num_nodes = graph.num_nodes
        num_edges = graph.num_edges
        edge_index = graph.edge_index
        if self.sampling:
            self.dis_node_pairs = [[] for i in range(self.nclass)]
            node_idx = random.sample(range(num_nodes), self.num_centers)
            adj = sp.coo_matrix(
                (np.ones(num_edges), (edge_index[0].cpu().numpy(), edge_index[1].cpu().numpy())),
                shape=(num_nodes, num_nodes),
            ).tocsr()
            num_samples = tqdm(range(self.num_centers))
            for i in num_samples:
                num_samples.set_description(f"Generating node pairs {i:03d}")
                idx = node_idx[i]
                queue = [idx]
                dis = -np.ones(num_nodes)
                dis[idx] = 0
                head = 0
                tail = 0
                cur_class = 0
                stack = []
                # BFS from the center; whenever the frontier crosses the
                # current interval's upper bound, flush `stack` into that
                # distance class (subsampled to at most 1024 partners).
                while head <= tail:
                    u = queue[head]
                    if cur_class != self.nclass - 1 and dis[u] >= self.class_split[cur_class][1]:
                        sampled = random.sample(stack, 1024) if len(stack) > 1024 else stack
                        # Fixed: `== []` is unreliable once the slot holds an
                        # ndarray; test emptiness by length instead.
                        if len(self.dis_node_pairs[cur_class]) == 0:
                            self.dis_node_pairs[cur_class] = np.array([[idx] * len(sampled), sampled]).transpose()
                        else:
                            self.dis_node_pairs[cur_class] = np.concatenate(
                                (self.dis_node_pairs[cur_class], np.array([[idx] * len(sampled), sampled]).transpose()),
                                axis=0,
                            )
                        cur_class += 1
                        if cur_class == self.nclass - 1:
                            break
                        stack = []
                    if u != idx:
                        stack.append(u)
                    head += 1
                    i_s = adj.indptr[u]
                    i_e = adj.indptr[u + 1]
                    # Renamed from `i` to avoid shadowing the tqdm loop index.
                    for ptr in range(i_s, i_e):
                        v = adj.indices[ptr]
                        if dis[v] == -1:
                            dis[v] = dis[u] + 1
                            tail += 1
                            queue.append(v)
                # Nodes never reached belong to the last class.
                remain = list(np.where(dis == -1)[0])
                sampled = random.sample(remain, 1024) if len(remain) > 1024 else remain
                if len(self.dis_node_pairs[cur_class]) == 0:
                    self.dis_node_pairs[cur_class] = np.array([[idx] * len(sampled), sampled]).transpose()
                else:
                    self.dis_node_pairs[cur_class] = np.concatenate(
                        (self.dis_node_pairs[cur_class], np.array([[idx] * len(sampled), sampled]).transpose()), axis=0
                    )
            if self.class_split[0][1] == 2:
                # The first class is then exactly the existing edges.
                self.dis_node_pairs[0] = torch.stack(edge_index).cpu().numpy().transpose()
            # Balance all classes down to the size of the smallest one.
            num_per_class = np.min(np.array([len(dis) for dis in self.dis_node_pairs]))
            for i in range(self.nclass):
                sampled = np.random.choice(np.arange(len(self.dis_node_pairs[i])), num_per_class, replace=False)
                self.dis_node_pairs[i] = self.dis_node_pairs[i][sampled]
        else:
            G = nx.Graph()
            G.add_edges_from(torch.stack(edge_index).cpu().numpy().transpose())
            path_length = dict(nx.all_pairs_shortest_path_length(G, cutoff=self.max_distance))
            # Fixed: np.int was removed in NumPy 1.24; builtin int is equivalent.
            distance = -np.ones((num_nodes, num_nodes), dtype=int)
            for u, p in path_length.items():
                for v, d in p.items():
                    distance[u][v] = d - 1
            self.distance = distance
            self.dis_node_pairs = []
            for i in range(self.nclass - 1):
                tmp = np.array(
                    np.where((distance >= self.class_split[i][0]) * (distance < self.class_split[i][1]))
                ).transpose()
                np.random.shuffle(tmp)
                self.dis_node_pairs.append(tmp)
            # -1 marks pairs beyond the cutoff (or disconnected): last class.
            tmp = np.array(np.where(distance == -1)).transpose()
            np.random.shuffle(tmp)
            self.dis_node_pairs.append(tmp)

    def transform_data(self, graph):
        """Compute distance classes once; re-apply DropEdge every call."""
        if not self.get_distance_cache:
            self.get_distance(graph)
            self.get_distance_cache = True
        graph.edge_index, _ = dropout_adj(edge_index=graph.edge_index, drop_rate=self.dropedge_rate)
        return graph

    def make_loss(self, embeddings, sample=True, k=4000):
        """Cross-entropy over predicted distance classes of node pairs."""
        node_pairs, pseudo_labels = self.sample(sample, k)
        pair_repr = self.linear(torch.abs(embeddings[node_pairs[0]] - embeddings[node_pairs[1]]))
        output = F.log_softmax(pair_repr, dim=1)
        return F.nll_loss(output, pseudo_labels)

    def sample(self, sample, k):
        """Return a (2, P) tensor of node pairs and their class labels.

        When *sample* is true, a random contiguous chunk of *k* pairs is
        drawn from each class; otherwise all pairs are returned.
        """
        chunks = []
        label_chunks = []
        for i in range(self.nclass):
            tmp = self.dis_node_pairs[i]
            if sample:
                start = int(random.random() * (len(tmp) - k))
                chunks.append(torch.tensor(tmp[start : start + k]).long().t())
                label_chunks.append(torch.ones(k).long() * i)
            else:
                chunks.append(torch.tensor(tmp).long().t())
                label_chunks.append(torch.ones(len(tmp)).long() * i)
        sampled = torch.cat(chunks, 1)
        pseudo_labels = torch.cat(label_chunks)
        return sampled.to(self.device), pseudo_labels.to(self.device)
class Distance2Clusters(SSLTask):
    """Pretext task: regress, for every node, its shortest-path distance
    to the central (highest-degree) node of each cluster."""

    def __init__(self, hidden_size, num_clusters, device):
        super().__init__(device)
        self.num_clusters = num_clusters
        # Predicts a distance per cluster from each node embedding.
        self.linear = nn.Linear(hidden_size, num_clusters).to(device)
        self.gen_cluster_info_cache = False

    def transform_data(self, graph):
        # Cluster assignments and distance targets are computed once.
        if not self.gen_cluster_info_cache:
            self.gen_cluster_info(graph)
            self.gen_cluster_info_cache = True
        return graph

    def gen_cluster_info(self, graph, use_metis=False):
        """Partition nodes (METIS on the graph or KMeans on features),
        pick each cluster's highest-degree node as its center, and store
        every node's distance to each center in ``self.distance_vec``."""
        edge_index = graph.edge_index
        num_nodes = graph.num_nodes
        x = graph.x
        G = nx.Graph()
        G.add_edges_from(torch.stack(edge_index).cpu().numpy().transpose())
        if use_metis:
            import metis

            _, parts = metis.part_graph(G, self.num_clusters)
        else:
            from sklearn.cluster import KMeans

            clustering = KMeans(n_clusters=self.num_clusters, random_state=0).fit(x.cpu())
            parts = clustering.labels_
        node_clusters = [[] for i in range(self.num_clusters)]
        for i, p in enumerate(parts):
            node_clusters[p].append(i)
        self.central_nodes = np.array([])
        # distance_vec[v][c] = shortest-path distance from v to center of
        # cluster c (0 if unreachable, since the array starts zeroed).
        self.distance_vec = np.zeros((num_nodes, self.num_clusters))
        for i in range(self.num_clusters):
            subgraph = G.subgraph(node_clusters[i])
            center = None
            for node in subgraph.nodes:
                if center is None or subgraph.degree[node] > subgraph.degree[center]:
                    center = node
            if center is None:
                # Empty cluster: no center, leave its distance column at 0.
                continue
            # Fixed: np.append returns a new array; the original discarded
            # the result, leaving central_nodes permanently empty.
            self.central_nodes = np.append(self.central_nodes, center)
            distance = dict(nx.shortest_path_length(G, source=center))
            for node in distance:
                self.distance_vec[node][i] = distance[node]
        self.distance_vec = torch.tensor(self.distance_vec).float().to(self.device)

    def make_loss(self, embeddings):
        """MSE between predicted and true distance-to-center vectors."""
        output = self.linear(embeddings)
        return F.mse_loss(output, self.distance_vec, reduction="mean")
class PairwiseAttrSim(SSLTask):
    """Pretext task: regress the cosine feature-similarity of selected
    node pairs (each node's k least similar, k most similar, and k random
    peers) from the node embeddings."""

    def __init__(self, hidden_size, k, device):
        super().__init__(device)
        self.k = k
        # Scalar similarity predictor over |h_u - h_v|.
        self.linear = nn.Linear(hidden_size, 1).to(self.device)
        self.get_attr_sim_cache = False

    def get_avg_distance(self, graph, idx_sorted, k, sampled):
        """Diagnostic: average shortest-path distance over each node's
        first-k, last-k (by similarity rank), and randomly sampled peers."""
        edge_index = graph.edge_index
        num_nodes = graph.num_nodes
        self.G = nx.Graph()
        self.G.add_edges_from(torch.stack(edge_index).cpu().numpy().transpose())

        def _avg(distance, nodes):
            # Mean distance over the reachable nodes in *nodes*, divided by
            # num_nodes so the per-node contributions sum to an average.
            total = 0
            cnt = 0
            for node in nodes:
                if node in distance:
                    total += distance[node]
                    cnt += 1
            return total / cnt / num_nodes if cnt else 0

        avg_min = 0
        avg_max = 0
        avg_sampled = 0
        for i in range(num_nodes):
            distance = dict(nx.shortest_path_length(self.G, source=i))
            avg_min += _avg(distance, idx_sorted[i, :k])
            avg_max += _avg(distance, idx_sorted[i, -k - 1 :])
            avg_sampled += _avg(distance, idx_sorted[i, sampled])
        return avg_min, avg_max, avg_sampled

    def get_attr_sim(self, graph):
        """Build the node-pair list and their cosine-similarity targets."""
        x = graph.x
        num_nodes = graph.num_nodes
        from sklearn.metrics.pairwise import cosine_similarity

        sims = cosine_similarity(x.cpu().numpy())
        # Rows sorted by ascending similarity: [:k] least, [-k-1:] most similar.
        idx_sorted = sims.argsort(1)
        sampled = self.sample(self.k, num_nodes)
        # Collect pairs in lists and build tensors once (the original
        # concatenated tensors inside the loop, which is quadratic).
        pairs = []
        labels = []
        for i in range(num_nodes):
            for node in np.hstack((idx_sorted[i, : self.k], idx_sorted[i, -self.k - 1 :], idx_sorted[i, sampled])):
                pairs.append([i, node])
                labels.append(sims[i][node])
        print(
            "max k avg distance: {%.4f}, min k avg distance: {%.4f}, sampled k avg distance: {%.4f}"
            % (self.get_avg_distance(graph, idx_sorted, self.k, sampled))
        )
        # BUG FIX: store as (2, num_pairs) so make_loss can index endpoint
        # rows as node_pairs[0] / node_pairs[1]. The original stored
        # (num_pairs, 2) and therefore compared only the first two pairs.
        self.node_pairs = torch.tensor(pairs).long().t().to(self.device)
        self.pseudo_labels = torch.tensor(labels).float().to(self.device)

    def sample(self, k, num_nodes):
        """Draw *k* similarity ranks from the middle of the sorted range
        (avoiding the extreme k positions at either end)."""
        picks = []
        for _ in range(k):
            picks.append(int(random.random() * (num_nodes - self.k * 2)) + self.k)
        return np.array(picks)

    def transform_data(self, graph):
        # BUG FIX: set the cache flag, otherwise the O(n^2) similarity
        # matrix (and the BFS diagnostics) were recomputed on every call.
        if not self.get_attr_sim_cache:
            self.get_attr_sim(graph)
            self.get_attr_sim_cache = True
        return graph

    def make_loss(self, embeddings):
        """MSE between predicted and true pairwise similarities."""
        node_pairs = self.node_pairs
        output = self.linear(torch.abs(embeddings[node_pairs[0]] - embeddings[node_pairs[1]]))
        # Squeeze (P, 1) -> (P,) to match pseudo_labels and avoid the
        # unintended (P, P) broadcast in mse_loss.
        return F.mse_loss(output.squeeze(-1), self.pseudo_labels, reduction="mean")