A Unified Gradient Regularization Method for Heterogeneous Graph Neural Networks

This repository contains the official implementation of our paper "A unified gradient regularization method for heterogeneous graph neural networks", published in Neural Networks (2025).

Short Abstract:
Heterogeneous Graph Neural Networks (HGNNs) often suffer from over-smoothing and unstable training.
We introduce Grug, a unified gradient regularization framework that jointly regularizes node-type and message-level gradients during message passing.
Grug provides theoretical guarantees on stability and diversity, unifies existing dropping and adversarial methods, and achieves state-of-the-art performance on six public datasets with minimal overhead.
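As a rough intuition (not the paper's exact algorithm), a Grug-style scheme can be pictured as a FLAG-like multi-step perturbation applied jointly to the message matrix and the node matrix, with coefficients alpha and beta (matching the CLI flags below) and M ascent steps. The function name, the sign-ascent update, and the `grad_fn` callback here are illustrative assumptions:

```python
import numpy as np

def perturb_like_grug(msg, node, grad_fn, alpha=0.01, beta=0.01, M=3):
    """Illustrative M-step sign-gradient ascent on the message matrix (msg)
    and node matrix (node). grad_fn returns (dL/dmsg, dL/dnode) evaluated at
    the perturbed inputs; a real HGNN would obtain these via autograd."""
    d_msg = np.zeros_like(msg)    # perturbation on message-level features
    d_node = np.zeros_like(node)  # perturbation on node-level features
    for _ in range(M):
        g_msg, g_node = grad_fn(msg + d_msg, node + d_node)
        d_msg += alpha * np.sign(g_msg)   # ascent step scaled by alpha
        d_node += beta * np.sign(g_node)  # ascent step scaled by beta
    return d_msg, d_node
```

Training would then minimize the loss at the perturbed inputs, which constrains the gradients the clean model sees during message passing.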

📄 Read the paper: https://www.sciencedirect.com/science/article/abs/pii/S0893608025009840


Node Classification Results (Micro/Macro F1)

| Model | Method | ACM Micro F1 | ACM Macro F1 | DBLP Micro F1 | DBLP Macro F1 | IMDB Micro F1 | IMDB Macro F1 |
|---|---|---|---|---|---|---|---|
| RGCN | Clean | 90.36±0.60 | 90.44±0.59 | 93.67±0.15 | 92.77±0.16 | 56.05±1.84 | 55.43±1.64 |
| | Dropout | 91.60±0.48 | 91.32±0.56 | 94.35±0.14 | 93.53±0.14 | 58.21±0.44 | 61.87±0.33 |
| | DropNode | 91.68±0.75 | 91.74±0.77 | 93.85±0.71 | 93.12±0.71 | 58.38±0.37 | 62.00±0.50 |
| | DropMessage | 92.23±0.43 | 92.11±0.79 | 94.69±0.13 | 93.96±0.13 | 58.54±0.36 | 62.00±0.36 |
| | FLAG | 92.43±0.37 | 92.58±0.34 | 94.22±0.20 | 93.39±0.24 | 58.74±0.33 | 62.10±0.40 |
| | Grug | 93.66±0.20 | 93.79±0.46 | 95.11±0.08 | 94.20±0.12 | 60.60±0.60 | 63.07±0.46 |
| RGAT | Clean | 89.32±1.03 | 89.34±1.03 | 92.98±0.61 | 92.15±0.63 | 53.12±0.80 | 52.75±0.78 |
| | Dropout | 89.38±0.81 | 89.51±0.76 | 93.55±0.48 | 92.69±0.50 | 54.85±0.31 | 53.63±0.26 |
| | DropNode | 89.86±0.67 | 89.98±0.60 | 93.21±0.67 | 92.42±0.72 | 53.86±1.56 | 52.80±1.60 |
| | DropMessage | 90.57±0.62 | 90.57±0.62 | 93.89±0.35 | 93.01±0.39 | 56.48±0.76 | 55.29±0.78 |
| | FLAG | 91.46±0.59 | 91.53±0.58 | 93.38±0.45 | 92.42±0.35 | 56.09±0.68 | 54.47±0.66 |
| | Grug | 92.99±0.35 | 93.47±0.28 | 94.33±0.33 | 93.60±0.32 | 58.80±0.47 | 57.24±0.49 |
| HGT | Clean | 91.11±0.37 | 91.17±0.38 | 94.08±0.41 | 93.19±0.51 | 61.89±0.78 | 66.05±0.45 |
| | Dropout | 91.49±0.33 | 91.57±0.48 | 94.35±0.38 | 93.42±0.49 | 62.66±0.69 | 67.11±0.47 |
| | DropNode | 91.33±0.32 | 91.42±0.40 | 94.44±0.42 | 93.59±0.39 | 63.73±0.85 | 67.03±0.51 |
| | DropMessage | 91.68±0.56 | 91.69±0.57 | 94.69±0.38 | 93.71±0.36 | 63.10±0.49 | 67.32±0.37 |
| | FLAG | 91.17±0.47 | 91.22±0.48 | 94.53±0.38 | 93.73±0.39 | 63.18±0.80 | 67.27±0.64 |
| | Grug | 92.00±0.27 | 92.05±0.24 | 95.00±0.43 | 94.18±0.35 | 63.66±0.87 | 67.97±0.57 |
| HAN | Clean | 89.41±0.17 | 89.52±0.15 | 84.37±2.29 | 83.33±2.21 | 57.11±0.95 | 63.27±0.70 |
| | Dropout | 90.73±0.62 | 90.61±0.59 | 85.02±1.17 | 84.22±1.29 | 57.86±1.21 | 63.88±0.89 |
| | DropNode | 90.65±0.17 | 90.83±0.14 | 85.17±1.32 | 84.60±1.36 | 57.41±1.28 | 63.64±0.81 |
| | DropMessage | 91.56±0.89 | 91.01±0.68 | 87.91±0.97 | 86.88±0.92 | 57.93±1.00 | 63.95±0.59 |
| | FLAG | 91.70±0.83 | 91.39±0.84 | 87.68±0.35 | 86.23±0.50 | 57.23±0.71 | 63.75±0.28 |
| | Grug | 92.00±0.86 | 91.59±0.42 | 89.00±0.43 | 87.15±0.50 | 59.01±0.79 | 64.17±0.23 |
| SimpleHGN | Clean | 92.79±0.50 | 92.92±0.47 | 94.23±0.45 | 93.50±0.51 | 63.07±2.03 | 66.55±1.99 |
| | Dropout | 93.21±0.66 | 93.33±0.52 | 94.38±0.33 | 93.68±0.50 | 63.57±1.76 | 67.24±0.68 |
| | DropNode | 93.17±0.54 | 92.97±0.48 | 94.33±0.37 | 93.59±0.47 | 63.42±0.70 | 67.63±0.75 |
| | DropMessage | 93.34±0.55 | 93.52±0.57 | 94.61±0.28 | 94.13±0.76 | 63.58±1.05 | 67.72±0.68 |
| | FLAG | 93.23±0.74 | 93.51±0.74 | 94.56±0.40 | 93.79±0.44 | 62.91±1.08 | 67.91±0.93 |
| | Grug | 93.78±0.70 | 93.94±0.70 | 94.88±0.30 | 94.49±0.61 | 63.70±0.80 | 68.45±0.73 |

Link Prediction Results (AUC-ROC)

| Model | Method | Amazon AUC-ROC | LastFM AUC-ROC |
|---|---|---|---|
| RGCN | Clean | 70.54±5.00 | 56.66±5.92 |
| | Dropout | 72.83±5.17 | 56.95±2.37 |
| | DropNode | 70.81±3.85 | 56.97±3.86 |
| | DropMessage | 81.80±0.57 | 58.01±4.65 |
| | FLAG | 78.87±0.81 | 56.96±4.32 |
| | Grug | 80.99±0.57 | 60.98±2.54 |
| RGAT | Clean | 81.98±2.36 | 78.09±3.97 |
| | Dropout | 81.74±6.72 | 78.94±3.43 |
| | DropNode | 83.36±0.77 | 78.10±2.85 |
| | DropMessage | 85.16±0.35 | 78.50±0.52 |
| | FLAG | 85.59±0.25 | 78.66±1.86 |
| | Grug | 86.74±0.27 | 80.21±2.66 |

Requirements

  • torch >= 1.12.1
  • torch-geometric >= 2.2.0
  • dgl >= 0.9.1

Experiments

Datasets

Baseline Models

Gradient Regularization Backbones

Quick Start

You can run Grug directly from the command line:

python main.py -m <model> -t <task> -d <dataset> -a <alpha> -b <beta> -M <iterations>

Arguments:

  • -m model: Model name (e.g., RGCN, RGAT, HGT).
  • -t task: Task name (node_classification or link_prediction).
  • -d dataset: Dataset name (ACM, DBLP, IMDB, ogbn-mag, etc.).
  • -a alpha: Coefficient for message matrix regularization.
  • -b beta: Coefficient for node matrix regularization.
  • -M iterations: Number of regularization iterations.
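The flags above roughly correspond to an argparse setup like the one below. This is a hypothetical reconstruction inferred from the README; the actual parser in the repository's main.py may use different defaults or option names:

```python
import argparse

# Hypothetical reconstruction of main.py's CLI from the flags documented above.
parser = argparse.ArgumentParser(description="Grug training entry point")
parser.add_argument("-m", "--model", default="RGCN")
parser.add_argument("-t", "--task", default="node_classification",
                    choices=["node_classification", "link_prediction"])
parser.add_argument("-d", "--dataset", default="ACM")
parser.add_argument("-a", "--alpha", type=float, default=0.01,
                    help="coefficient for message matrix regularization")
parser.add_argument("-b", "--beta", type=float, default=0.01,
                    help="coefficient for node matrix regularization")
parser.add_argument("-M", "--iterations", type=int, default=3,
                    help="number of regularization iterations")

# Parse the README's example invocation.
args = parser.parse_args(["-m", "RGCN", "-t", "node_classification",
                          "-d", "ACM", "-a", "0.01", "-b", "0.01", "-M", "3"])
```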

Example:

python main.py -m RGCN -t node_classification -d ACM -a 0.01 -b 0.01 -M 3

Citation

If you find this work helpful, please consider citing our paper:

@article{yang2025unified,
  title = {A Unified Gradient Regularization Method for Heterogeneous Graph Neural Networks},
  author = {Yang, Xiao and Zhao, Xuejiao and Shen, Zhiqi},
  journal = {Neural Networks},
  pages = {108104},
  year = {2025},
  publisher = {Elsevier}
}
