之前我的实现方式相对而言麻烦且准确率不够好,只能达到65%左右的准确率(Cora上),这里介绍直接用PyG封装好的GAT函数实现:
import torch
import math
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import GATConv
from torch_geometric.utils import add_self_loops,degree
from torch_geometric.datasets import Planetoid
import ssl
import torch.nn.functional as F
class Net(torch.nn.Module):
def __init__(self):
super(Net,self).__init__()
self.gat1=GATConv(dataset.num_node_features,8,8,dropout=0.6)
self.gat2=GATConv(64,7,1,dropout=0.6)
def forward(self,data):
x,edge_index=data.x, data.edge_index
x=self.gat1(x,edge_index)
x=self.gat2(x,edge_index)
return F.log_softmax(x,dim=1)
ssl._create_default_https_context = ssl._create_unverified_context
dataset = Planetoid(root='Cora', name='Cora')
x=dataset[0].x
edge_index=dataset[0].edge_index
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
model.train()
for epoch in range(100):
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
model.eval()
_, pred = model(data).max(dim=1)
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct/int(data.test_mask.sum())
print('Accuracy:{:.4f}'.format(acc))
>>>Accuracy:0.7960
使用封装的GATConv操作就非常简单友好了,这里就不多加分析了。
本文介绍了如何利用PyG库中的GATConv模块,以简化和提升图注意力网络(GAT)的实现效果。在Cora数据集上的实验表明,使用封装的GATConv能获得约65%的准确率。
2229

被折叠的 条评论
为什么被折叠?



