Skip to content

Commit b13a909

Browse files
committed
PyG 2.5 compatibility
1 parent 871add0 commit b13a909

File tree

4 files changed

+11
-5
lines changed

4 files changed

+11
-5
lines changed

config/transductive/inference.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ optimizer:
3636

3737
train:
3838
gpus: {{ gpus }}
39-
batch_size: 8
39+
batch_size: {{ bs }}
4040
num_epoch: {{ epochs }}
4141
log_interval: 100
4242
batch_per_epoch: {{ bpe }}

script/run.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import math
44
import pprint
55
from itertools import islice
6+
from tqdm import tqdm
67

78
import torch
89
import torch_geometric as pyg
@@ -131,7 +132,7 @@ def test(cfg, model, test_data, device, logger, filtered_data=None, return_metri
131132
rankings = []
132133
num_negatives = []
133134
tail_rankings, num_tail_negs = [], [] # for explicit tail-only evaluation needed for 5 datasets
134-
for batch in test_loader:
135+
for batch in tqdm(test_loader):
135136
t_batch, h_batch = tasks.all_negative(test_data, batch)
136137
t_pred = model(test_data, t_batch)
137138
h_pred = model(test_data, h_batch)

ultra/layers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ def propagate(self, edge_index, size=None, **kwargs):
104104
size = self._check_input(edge_index, size)
105105
coll_dict = self._collect(self._fused_user_args, edge_index, size, kwargs)
106106

107-
msg_aggr_kwargs = self.inspector.distribute("message_and_aggregate", coll_dict)
107+
# PyG 2.5+ renaming
108+
# msg_aggr_kwargs = self.inspector.distribute("message_and_aggregate", coll_dict)
109+
msg_aggr_kwargs = self.inspector.collect_param_data("message_and_aggregate", coll_dict)
108110
for hook in self._message_and_aggregate_forward_pre_hooks.values():
109111
res = hook(self, (edge_index, msg_aggr_kwargs))
110112
if res is not None:
@@ -115,7 +117,8 @@ def propagate(self, edge_index, size=None, **kwargs):
115117
if res is not None:
116118
out = res
117119

118-
update_kwargs = self.inspector.distribute("update", coll_dict)
120+
#update_kwargs = self.inspector.distribute("update", coll_dict)
121+
update_kwargs = self.inspector.collect_param_data("update", coll_dict)
119122
out = self.update(out, **update_kwargs)
120123

121124
for hook in self._propagate_forward_hooks.values():

ultra/rspmm/rspmm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,9 @@ def generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="
181181
def load_extension(name, sources, extra_cflags=None, extra_cuda_cflags=None, **kwargs):
182182
if extra_cflags is None:
183183
extra_cflags = ["-Ofast"]
184-
if torch.backends.openmp.is_available():
184+
# Torch 2.2.1 for MacOS is compiled with OpenMP and compiling kernels with OpenMP
185+
# requires bringing llvm and libomp, so skip that for MacOS and resort to a standard CPU version
186+
if torch.backends.openmp.is_available() and not sys.platform.startswith('darwin'):
185187
extra_cflags += ["-fopenmp", "-DAT_PARALLEL_OPENMP"]
186188
else:
187189
extra_cflags.append("-DAT_PARALLEL_NATIVE")

0 commit comments

Comments
 (0)