diff --git a/config/transductive/inference.yaml b/config/transductive/inference.yaml
index 03f1b6a..4aa050d 100644
--- a/config/transductive/inference.yaml
+++ b/config/transductive/inference.yaml
@@ -36,7 +36,7 @@ optimizer:
 
 train:
   gpus: {{ gpus }}
-  batch_size: 8
+  batch_size: {{ bs }}
   num_epoch: {{ epochs }}
   log_interval: 100
   batch_per_epoch: {{ bpe }}
diff --git a/script/run.py b/script/run.py
index a7ae350..d4f8e06 100644
--- a/script/run.py
+++ b/script/run.py
@@ -3,6 +3,7 @@
 import math
 import pprint
 from itertools import islice
+from tqdm import tqdm
 
 import torch
 import torch_geometric as pyg
@@ -131,7 +132,7 @@ def test(cfg, model, test_data, device, logger, filtered_data=None, return_metri
     rankings = []
     num_negatives = []
     tail_rankings, num_tail_negs = [], []  # for explicit tail-only evaluation needed for 5 datasets
-    for batch in test_loader:
+    for batch in tqdm(test_loader):
         t_batch, h_batch = tasks.all_negative(test_data, batch)
         t_pred = model(test_data, t_batch)
         h_pred = model(test_data, h_batch)
diff --git a/ultra/layers.py b/ultra/layers.py
index 32dca7d..9976605 100644
--- a/ultra/layers.py
+++ b/ultra/layers.py
@@ -104,7 +104,9 @@ def propagate(self, edge_index, size=None, **kwargs):
         size = self._check_input(edge_index, size)
         coll_dict = self._collect(self._fused_user_args, edge_index, size, kwargs)
 
-        msg_aggr_kwargs = self.inspector.distribute("message_and_aggregate", coll_dict)
+        # PyG 2.5+ renaming
+        # msg_aggr_kwargs = self.inspector.distribute("message_and_aggregate", coll_dict)
+        msg_aggr_kwargs = self.inspector.collect_param_data("message_and_aggregate", coll_dict)
         for hook in self._message_and_aggregate_forward_pre_hooks.values():
             res = hook(self, (edge_index, msg_aggr_kwargs))
             if res is not None:
@@ -115,7 +117,8 @@ def propagate(self, edge_index, size=None, **kwargs):
             if res is not None:
                 out = res
 
-        update_kwargs = self.inspector.distribute("update", coll_dict)
+        #update_kwargs = self.inspector.distribute("update", coll_dict)
+        update_kwargs = self.inspector.collect_param_data("update", coll_dict)
         out = self.update(out, **update_kwargs)
 
         for hook in self._propagate_forward_hooks.values():
diff --git a/ultra/rspmm/rspmm.py b/ultra/rspmm/rspmm.py
index 52273f0..dbba6ad 100644
--- a/ultra/rspmm/rspmm.py
+++ b/ultra/rspmm/rspmm.py
@@ -181,7 +181,9 @@ def generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="
 def load_extension(name, sources, extra_cflags=None, extra_cuda_cflags=None, **kwargs):
     if extra_cflags is None:
         extra_cflags = ["-Ofast"]
-        if torch.backends.openmp.is_available():
+        # Torch 2.2.1 for MacOS is compiled with OpenMP and compiling kernels with OpenMP
+        # requires bringing llvm and libomp, so skip that for MacOS and resort to a standard CPU version
+        if torch.backends.openmp.is_available() and not sys.platform.startswith('darwin'):
             extra_cflags += ["-fopenmp", "-DAT_PARALLEL_OPENMP"]
         else:
             extra_cflags.append("-DAT_PARALLEL_NATIVE")