import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903

    Computes attention coefficients e_ij = LeakyReLU(a^T [W h_i || W h_j])
    over the edges given by `adj`, then aggregates the transformed node
    features with the softmax-normalized coefficients.
    """

    def __init__(self, in_features, out_features, dropout=0.6, alpha=0.2, concat=True):
        super().__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha  # negative slope of the LeakyReLU
        self.concat = concat  # True for hidden layers, False for the output layer

        # Shared linear transformation W and attention vector a, both
        # initialized with Xavier uniform (the init overwrites the values,
        # so torch.empty suffices).
        self.W = nn.Parameter(torch.empty(in_features, out_features))
        nn.init.xavier_uniform_(self.W, gain=1.414)
        self.a = nn.Parameter(torch.empty(2 * out_features, 1))
        nn.init.xavier_uniform_(self.a, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        # input: (N, in_features) node features; adj: (N, N) adjacency matrix.
        h = torch.mm(input, self.W)  # (N, out_features)
        N = h.size(0)

        # Build all N*N pairwise concatenations [W h_i || W h_j]. This
        # materializes an (N, N, 2*out_features) tensor, so memory grows
        # quadratically with the number of nodes.
        a_input = torch.cat(
            (h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)), dim=1
        ).view(N, N, 2 * self.out_features)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))  # (N, N)

        # Mask out non-edges with a large negative value so they vanish under
        # softmax, then normalize attention over each node's neighbourhood.
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)  # (N, out_features)

        if self.concat:
            # Hidden layer: apply a nonlinearity before the next layer.
            return F.elu(h_prime)
        else:
            return h_prime
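

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original layer: the node count,
    # feature sizes, and random graph below are illustrative assumptions.
    torch.manual_seed(0)
    num_nodes, in_dim, out_dim = 5, 8, 4
    features = torch.randn(num_nodes, in_dim)

    # Random symmetric adjacency with self-loops so every node attends to itself.
    adj = (torch.rand(num_nodes, num_nodes) > 0.5).float()
    adj = ((adj + adj.t()) > 0).float()
    adj.fill_diagonal_(1)

    layer = GraphAttentionLayer(in_dim, out_dim, dropout=0.6, alpha=0.2, concat=True)
    out = layer(features, adj)
    print(out.shape)  # torch.Size([5, 4])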