# Source code for rrgcn.random_rgcn_conv

# Based on PyG's RGCNConv implementation
# https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/conv/rgcn_conv.py
from typing import Tuple, Union

import numpy as np
import torch
from torch import Tensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.rgcn_conv import masked_edge_index
from torch_geometric.typing import Adj, OptTensor
from torch_sparse import SparseTensor, matmul

from .util import glorot_seed


class RandomRGCNConv(MessagePassing):
    r"""Relational graph convolution whose weights are regenerated from seeds.

    No weight tensors are stored on the module: every call to
    :meth:`forward` rebuilds the per-relation and root weight matrices via
    ``glorot_seed`` from integer seeds derived from a single master seed.
    """

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        num_relations: int,
        seed: Union[int, None] = None,
        **kwargs,
    ):
        r"""Random graph convolution operation, characterized by a single seed.

        Args:
            in_channels (int or tuple): Size of each input sample. A tuple
                corresponds to the sizes of source and target
                dimensionalities. In case no input features are given, this
                argument should correspond to the number of nodes in your
                graph.
            out_channels (int): Size of each output sample.
            num_relations (int): Number of relations.
            seed (int, optional): Random seed (fully characterizes the
                layer). If :obj:`None`, a seed is drawn from NumPy's global
                random state.
            **kwargs (optional): Additional arguments of
                :class:`torch_geometric.nn.conv.MessagePassing`.
        """
        super().__init__(aggr="mean", node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)
        # Source (left) and target (right) input dimensionalities.
        self.in_channels_l = in_channels[0]
        self.in_channels_r = in_channels[1]

        if seed is None:
            self.seed = np.random.randint(1000000)
        else:
            self.seed = seed

        # One derived seed per relation, plus a final one for the root
        # (self-loop) weight. The call is kept exactly as-is so that a given
        # master seed keeps producing the same derived seeds.
        rng = np.random.default_rng(self.seed)
        self.seeds = rng.integers(1e6, size=num_relations + 1)
        self.root = self.seeds[-1]

    def forward(
        self,
        x: Union[OptTensor, Tuple[OptTensor, Tensor]],
        edge_index: Adj,
        edge_type: OptTensor = None,
    ):
        r"""
        Args:
            x: The input node features. Can be either a
                :obj:`[num_nodes, in_channels]` node feature matrix, or an
                optional one-dimensional node index tensor (in which case
                the indices select rows of the seeded random weight
                matrices). If :obj:`None`, the indices
                :obj:`0 .. in_channels - 1` are used.
                Furthermore, :obj:`x` can be of type :obj:`tuple` denoting
                source and destination node features.
            edge_index (LongTensor or SparseTensor): The edge indices.
            edge_type: The one-dimensional relation type/index for each edge
                in :obj:`edge_index`.
                Should be only :obj:`None` in case :obj:`edge_index` is of
                type :class:`torch_sparse.tensor.SparseTensor`.
                (default: :obj:`None`)
        """
        is_pair = isinstance(x, tuple)

        # Source-side input: first tuple element, or `x` itself.
        x_l: OptTensor = x[0] if is_pair else x
        if x_l is None:
            # No source features given: fall back to one index per source
            # node, placed on the same device as the other inputs.
            if is_pair:
                device = x[1].device
            elif isinstance(edge_index, SparseTensor):
                device = edge_index.storage.col().device
            else:
                device = edge_index.device
            x_l = torch.arange(self.in_channels_l, device=device)

        # Target-side input: second tuple element, or the same as the source.
        x_r: Tensor = x[1] if is_pair else x_l

        size = (x_l.size(0), x_r.size(0))

        if isinstance(edge_index, SparseTensor):
            # For sparse adjacency input the relation ids are read from the
            # stored values.
            edge_type = edge_index.storage.value()
        assert edge_type is not None

        out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device)

        for i in range(self.num_relations):
            # Rebuild this relation's weight from its seed instead of
            # keeping it around between calls.
            weight = glorot_seed(
                (self.in_channels_l, self.out_channels),
                seed=self.seeds[i],
                device=x_r.device,
            )
            tmp = masked_edge_index(edge_index, edge_type == i)

            if x_l.dtype == torch.long:
                # Index input: look up rows of the weight matrix.
                out += self.propagate(tmp, x=weight[x_l], size=size)
            else:
                out += self.propagate(tmp, x=(x_l @ weight), size=size)
            del weight

        if self.root is not None:
            # The root weight is applied to the target-side input, so its
            # row count must be the target dimensionality.
            root = glorot_seed(
                (self.in_channels_r, self.out_channels),
                seed=self.root,
                device=x_r.device,
            )
            out += root[x_r] if x_r.dtype == torch.long else x_r @ root

        return out

    def message(self, x_j: Tensor) -> Tensor:
        """Pass the (already transformed) source features through unchanged."""
        return x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        """Aggregate ``x`` over ``adj_t`` using this layer's ``aggr`` reduction."""
        # Drop the stored values before the sparse matmul.
        adj_t = adj_t.set_value(None, layout=None)
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}({self.in_channels}, "
            f"{self.out_channels}, num_relations={self.num_relations})"
        )