This repository contains the official implementation of our paper "A unified gradient regularization method for heterogeneous graph neural networks", published in Neural Networks (2025).
Short Abstract:
Heterogeneous Graph Neural Networks (HGNNs) often suffer from over-smoothing and unstable training.
We introduce Grug, a unified gradient regularization framework that jointly regularizes node-type and message-level gradients during message passing.
Grug provides theoretical guarantees on stability and diversity, unifies existing dropping and adversarial methods, and achieves state-of-the-art performance on six public datasets with minimal overhead.
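To give a rough intuition for gradient-based regularization of this kind (this is an illustrative toy, not the paper's actual algorithm): a FLAG-style inner loop grows a perturbation along the sign of the loss gradient for M steps and averages the perturbed losses, which Grug generalizes to the node and message matrices of a heterogeneous GNN. A minimal 1-D sketch in plain Python, with all names hypothetical:

```python
def perturbed_loss(w, x, y, alpha=0.01, M=3):
    """Toy model: prediction = w * x with squared-error loss.
    A perturbation `delta` on the input is grown for M steps along the
    sign of the loss gradient (ascent), and the averaged perturbed
    loss is returned as the regularized training objective."""
    delta = 0.0
    total = 0.0
    for _ in range(M):
        pred = w * (x + delta)
        loss = (pred - y) ** 2
        total += loss / M                      # average over the M inner steps
        grad_delta = 2.0 * (pred - y) * w      # d(loss)/d(delta), computed analytically
        delta += alpha if grad_delta >= 0 else -alpha  # sign-gradient ascent step
    return total
```

In Grug the analogous perturbations are applied to the message matrix (weighted by `alpha`) and the node matrix (weighted by `beta`), matching the command-line arguments below.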
📄 Read the paper: https://www.sciencedirect.com/science/article/abs/pii/S0893608025009840
Node classification results (Micro/Macro F1):

| Model | Method | ACM Micro F1 | ACM Macro F1 | DBLP Micro F1 | DBLP Macro F1 | IMDB Micro F1 | IMDB Macro F1 |
|---|---|---|---|---|---|---|---|
| RGCN | Clean | 90.36±0.60 | 90.44±0.59 | 93.67±0.15 | 92.77±0.16 | 56.05±1.84 | 55.43±1.64 |
| | Dropout | 91.60±0.48 | 91.32±0.56 | 94.35±0.14 | 93.53±0.14 | 58.21±0.44 | 61.87±0.33 |
| | DropNode | 91.68±0.75 | 91.74±0.77 | 93.85±0.71 | 93.12±0.71 | 58.38±0.37 | 62.00±0.50 |
| | DropMessage | 92.23±0.43 | 92.11±0.79 | 94.69±0.13 | 93.96±0.13 | 58.54±0.36 | 62.00±0.36 |
| | FLAG | 92.43±0.37 | 92.58±0.34 | 94.22±0.20 | 93.39±0.24 | 58.74±0.33 | 62.10±0.40 |
| | Grug | 93.66±0.20 | 93.79±0.46 | 95.11±0.08 | 94.20±0.12 | 60.60±0.60 | 63.07±0.46 |
| RGAT | Clean | 89.32±1.03 | 89.34±1.03 | 92.98±0.61 | 92.15±0.63 | 53.12±0.80 | 52.75±0.78 |
| | Dropout | 89.38±0.81 | 89.51±0.76 | 93.55±0.48 | 92.69±0.50 | 54.85±0.31 | 53.63±0.26 |
| | DropNode | 89.86±0.67 | 89.98±0.60 | 93.21±0.67 | 92.42±0.72 | 53.86±1.56 | 52.80±1.60 |
| | DropMessage | 90.57±0.62 | 90.57±0.62 | 93.89±0.35 | 93.01±0.39 | 56.48±0.76 | 55.29±0.78 |
| | FLAG | 91.46±0.59 | 91.53±0.58 | 93.38±0.45 | 92.42±0.35 | 56.09±0.68 | 54.47±0.66 |
| | Grug | 92.99±0.35 | 93.47±0.28 | 94.33±0.33 | 93.60±0.32 | 58.80±0.47 | 57.24±0.49 |
| HGT | Clean | 91.11±0.37 | 91.17±0.38 | 94.08±0.41 | 93.19±0.51 | 61.89±0.78 | 66.05±0.45 |
| | Dropout | 91.49±0.33 | 91.57±0.48 | 94.35±0.38 | 93.42±0.49 | 62.66±0.69 | 67.11±0.47 |
| | DropNode | 91.33±0.32 | 91.42±0.40 | 94.44±0.42 | 93.59±0.39 | 63.73±0.85 | 67.03±0.51 |
| | DropMessage | 91.68±0.56 | 91.69±0.57 | 94.69±0.38 | 93.71±0.36 | 63.10±0.49 | 67.32±0.37 |
| | FLAG | 91.17±0.47 | 91.22±0.48 | 94.53±0.38 | 93.73±0.39 | 63.18±0.80 | 67.27±0.64 |
| | Grug | 92.00±0.27 | 92.05±0.24 | 95.00±0.43 | 94.18±0.35 | 63.66±0.87 | 67.97±0.57 |
| HAN | Clean | 89.41±0.17 | 89.52±0.15 | 84.37±2.29 | 83.33±2.21 | 57.11±0.95 | 63.27±0.70 |
| | Dropout | 90.73±0.62 | 90.61±0.59 | 85.02±1.17 | 84.22±1.29 | 57.86±1.21 | 63.88±0.89 |
| | DropNode | 90.65±0.17 | 90.83±0.14 | 85.17±1.32 | 84.60±1.36 | 57.41±1.28 | 63.64±0.81 |
| | DropMessage | 91.56±0.89 | 91.01±0.68 | 87.91±0.97 | 86.88±0.92 | 57.93±1.00 | 63.95±0.59 |
| | FLAG | 91.70±0.83 | 91.39±0.84 | 87.68±0.35 | 86.23±0.50 | 57.23±0.71 | 63.75±0.28 |
| | Grug | 92.00±0.86 | 91.59±0.42 | 89.00±0.43 | 87.15±0.50 | 59.01±0.79 | 64.17±0.23 |
| SimpleHGN | Clean | 92.79±0.50 | 92.92±0.47 | 94.23±0.45 | 93.50±0.51 | 63.07±2.03 | 66.55±1.99 |
| | Dropout | 93.21±0.66 | 93.33±0.52 | 94.38±0.33 | 93.68±0.50 | 63.57±1.76 | 67.24±0.68 |
| | DropNode | 93.17±0.54 | 92.97±0.48 | 94.33±0.37 | 93.59±0.47 | 63.42±0.70 | 67.63±0.75 |
| | DropMessage | 93.34±0.55 | 93.52±0.57 | 94.61±0.28 | 94.13±0.76 | 63.58±1.05 | 67.72±0.68 |
| | FLAG | 93.23±0.74 | 93.51±0.74 | 94.56±0.40 | 93.79±0.44 | 62.91±1.08 | 67.91±0.93 |
| | Grug | 93.78±0.70 | 93.94±0.70 | 94.88±0.30 | 94.49±0.61 | 63.70±0.80 | 68.45±0.73 |
Link prediction results (AUC-ROC):

| Model | Method | Amazon AUC-ROC | LastFM AUC-ROC |
|---|---|---|---|
| RGCN | Clean | 70.54±5.00 | 56.66±5.92 |
| | Dropout | 72.83±5.17 | 56.95±2.37 |
| | DropNode | 70.81±3.85 | 56.97±3.86 |
| | DropMessage | 81.80±0.57 | 58.01±4.65 |
| | FLAG | 78.87±0.81 | 56.96±4.32 |
| | Grug | 80.99±0.57 | 60.98±2.54 |
| RGAT | Clean | 81.98±2.36 | 78.09±3.97 |
| | Dropout | 81.74±6.72 | 78.94±3.43 |
| | DropNode | 83.36±0.77 | 78.10±2.85 |
| | DropMessage | 85.16±0.35 | 78.50±0.52 |
| | FLAG | 85.59±0.25 | 78.66±1.86 |
| | Grug | 86.74±0.27 | 80.21±2.66 |
Requirements:

```
torch >= 1.12.1
torch-geometric >= 2.2.0
dgl >= 0.9.1
```
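Assuming a standard pip setup, the dependencies can be installed with something like the following (the exact wheel index for torch and dgl depends on your CUDA version, so check their official install pages first):

```shell
pip install "torch>=1.12.1" "torch_geometric>=2.2.0" "dgl>=0.9.1"
```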
You can directly run the source code of Grug:
```
python main.py -m <model> -t <task> -d <dataset> -a <alpha> -b <beta> -M <iterations>
```

Arguments:
- `-m <model>`: Model name (e.g., RGCN, RGAT, HGT).
- `-t <task>`: Task name (`node_classification` or `link_prediction`).
- `-d <dataset>`: Dataset name (ACM, DBLP, IMDB, ogbn-mag, etc.).
- `-a <alpha>`: Coefficient for the message-matrix regularization.
- `-b <beta>`: Coefficient for the node-matrix regularization.
- `-M <iterations>`: Number of regularization iterations.
Example:
```
python main.py -m RGCN -t node_classification -d ACM -a 0.01 -b 0.01 -M 3
```

If you find this work helpful, please consider citing our paper:
```bibtex
@article{yang2025unified,
  title     = {A Unified Gradient Regularization Method for Heterogeneous Graph Neural Networks},
  author    = {Yang, Xiao and Zhao, Xuejiao and Shen, Zhiqi},
  journal   = {Neural Networks},
  pages     = {108104},
  year      = {2025},
  publisher = {Elsevier}
}
```