diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/Dockerfile.dgl_jupyter b/DGLPyTorch/DrugDiscovery/SE3Transformer/Dockerfile.dgl_jupyter
new file mode 100644
index 000000000..b85918990
--- /dev/null
+++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/Dockerfile.dgl_jupyter
@@ -0,0 +1,22 @@
+# Copyright Advanced Micro Devices, Inc.
+# Licensed under the Apache License, Version 2.0
+
+#############################################################################
+ARG BASE_IMAGE=rocm/dgl:latest
+FROM ${BASE_IMAGE} AS dgl
+
+SHELL ["/bin/bash", "--login", "-c"]
+ENV DGL_SRC_DIR="/src/dgl"
+ENV DEEP_LEARNING_EXAMPLES="/src/DeepLearningExamples"
+
+# Copy the DeepLearningExamples sources and install the SE(3)-Transformer package.
+COPY . ${DEEP_LEARNING_EXAMPLES}
+
+WORKDIR ${DEEP_LEARNING_EXAMPLES}
+RUN pip install -r requirements.txt && pip install -e .
+ENV LD_LIBRARY_PATH="/src/dgl/build/:${LD_LIBRARY_PATH}"
+
+# Install the notebook dependencies.
+RUN pip install jupyter jupyterlab plotly torchinfo rdkit py3Dmol
+# Run JupyterLab by default.
+CMD ["jupyter", "lab", "--ip=0.0.0.0", "--port=8888", "--no-browser", "--allow-root"]
diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3transformer.ipynb b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3transformer.ipynb
new file mode 100644
index 000000000..b515d360e
--- /dev/null
+++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3transformer.ipynb
@@ -0,0 +1,499 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "3a3efc98-003c-4925-9305-8a19e613b8ab",
+   "metadata": {},
+   "source": [
+    "# SE(3)-Transformer Overview\n",
+    "The SE(3)-Transformer is a Graph Neural Network that uses a variant of self-attention to process 3D points and graphs.\n",
+    "This model is equivariant under continuous 3D roto-translations, meaning that when the inputs (graphs or sets of points) rotate in 3D space\n",
+    "(or more generally experience a proper rigid transformation), the model outputs either stay invariant or transform with the input."
+   ]
+  },
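+  {
+   "cell_type": "markdown",
+   "id": "b7c1e2d3",
+   "metadata": {},
+   "source": [
+    "Before touching the model, here is a tiny, self-contained illustration (toy points only, not the model itself) of the kind of transformation this equivariance refers to: a proper rotation is a rigid transformation, so it leaves all pairwise distances between points unchanged, which is exactly the geometric structure the SE(3)-Transformer exploits."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c8d2f3a4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import math\n",
+    "\n",
+    "import torch\n",
+    "\n",
+    "# Illustration only: a proper rotation about the z-axis (det(R) = +1)\n",
+    "c, s = math.cos(0.3), math.sin(0.3)\n",
+    "R = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])\n",
+    "\n",
+    "points = torch.randn(5, 3)  # a toy set of 3D points\n",
+    "rotated = points @ R.T\n",
+    "\n",
+    "# Rigid transformations preserve pairwise distances, which is why an\n",
+    "# equivariant model can make invariant scalar predictions from geometry.\n",
+    "print(torch.allclose(torch.cdist(points, points), torch.cdist(rotated, rotated), atol=1e-5))"
+   ]
+  },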
+  {
+   "cell_type": "markdown",
+   "id": "5422b34f",
+   "metadata": {},
+   "source": [
+    "These imports set up the full SE(3)-Transformer training and evaluation pipeline on the QM9 molecular dataset — covering data loading, distributed training, optimization, logging, and inference."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2857ce49-1b69-4fab-8006-5cc8e79b0da6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import logging\n",
+    "\n",
+    "import torch.nn as nn\n",
+    "import dgl\n",
+    "\n",
+    "from se3_transformer.data_loading import QM9DataModule\n",
+    "from se3_transformer.model import SE3TransformerPooled\n",
+    "from se3_transformer.model.fiber import Fiber\n",
+    "from se3_transformer.runtime.arguments import PARSER\n",
+    "from se3_transformer.runtime.callbacks import (\n",
+    "    QM9MetricCallback,\n",
+    "    QM9LRSchedulerCallback,\n",
+    ")\n",
+    "from se3_transformer.runtime.loggers import (\n",
+    "    LoggerCollection,\n",
+    "    DLLogger,\n",
+    ")\n",
+    "from se3_transformer.runtime.utils import (\n",
+    "    seed_everything,\n",
+    "    using_tensor_cores,\n",
+    ")\n",
+    "from se3_transformer.runtime.training import train"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7e2e60dc-0d33-4bc7-94f3-d818fb780b0b",
+   "metadata": {},
+   "source": [
+    "## Using CLI Arguments to Set Up Training\n",
+    "\n",
+    "The SE(3)-Transformer example from the Deep Learning Examples repository was originally designed to run as a command-line program — but we can easily adapt it for use in Jupyter notebooks!\n",
+    "The training configuration, including model, optimizer, and runtime settings, is managed through an argparse parser, which we can leverage directly within the notebook for flexible experimentation."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "73d1fe05-7c2f-441e-874e-0ed5f96daa0b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Print all available training and runtime arguments\n",
+    "PARSER.print_help()\n",
+    "\n",
+    "# Adjust the following parameters as needed for your system configuration.\n",
+    "args = PARSER.parse_args(\n",
+    "    [\n",
+    "        \"--epochs\",\n",
+    "        \"5\",\n",
+    "        \"--eval_interval\",\n",
+    "        \"1\",\n",
+    "        \"--batch_size\",\n",
+    "        \"240\",\n",
+    "        \"--num_workers\",\n",
+    "        \"16\",\n",
+    "        \"--precompute_bases\",\n",
+    "        \"--use_layer_norm\",\n",
+    "        \"--norm\",\n",
+    "        \"--save_ckpt_path\",\n",
+    "        \"model_qm9.pth\",\n",
+    "        # If you want to load a model trained for 100 epochs, uncomment the two lines below\n",
+    "        # \"--load_ckpt_path\",\n",
+    "        # \"model_qm9_100.pth\",\n",
+    "    ]\n",
+    ")\n",
+    "# Verify that the args have been set properly\n",
+    "print(args)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9e537654-b905-4248-b2bc-d8fd01e52e96",
+   "metadata": {},
+   "source": [
+    "## Dataset and Model Setup\n",
+    "We start by loading the QM9 molecular dataset using the QM9DataModule, which handles data preprocessing, batching, and splitting for training and evaluation.\n",
+    "Next, we initialize the SE(3)-Transformer model (SE3TransformerPooled) with input, edge, and output fibers that define how geometric and feature information flows through the network.\n",
+    "Finally, we define the L1 loss (nn.L1Loss) — a simple yet effective choice for molecular property regression tasks."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ddc3d1de-70f5-45c5-8dd8-0ec346536f2f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "datamodule = QM9DataModule(**vars(args))\n",
+    "model = SE3TransformerPooled(\n",
+    "    fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),\n",
+    "    fiber_out=Fiber({0: args.num_degrees * args.num_channels}),\n",
+    "    fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),\n",
+    "    output_dim=1,\n",
+    "    tensor_cores=using_tensor_cores(args.amp),  # use Tensor Cores more effectively\n",
+    "    **vars(args),\n",
+    ")\n",
+    "loss_fn = nn.L1Loss()"
+   ]
+  },
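+  {
+   "cell_type": "markdown",
+   "id": "d9e4f5a6",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check of the objective we just defined (illustrative values only): nn.L1Loss with its default mean reduction computes the mean absolute error between predictions and targets."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e0f5a6b7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "\n",
+    "# Toy example of the L1 objective: mean absolute error over a batch\n",
+    "pred = torch.tensor([0.5, 1.0, 2.0])\n",
+    "target = torch.tensor([0.0, 1.0, 1.0])\n",
+    "print(loss_fn(pred, target))  # (0.5 + 0.0 + 1.0) / 3 = 0.5"
+   ]
+  },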
+  {
+   "cell_type": "markdown",
+   "id": "1ce1af6e-1c82-423e-b14c-ed837ab3edeb",
+   "metadata": {},
+   "source": [
+    "# ⌬ Inspecting the Molecules\n",
+    "Before diving into training, it’s helpful to visually inspect the molecules from the QM9 dataset.\n",
+    "We use RDKit to reconstruct 3D molecular structures from the graph data (node positions and atomic numbers) and py3Dmol for interactive visualization right inside the notebook.\n",
+    "\n",
+    "The convert_to_mol() function converts DGL graphs into RDKit Mol objects by building an XYZ block from the atomic coordinates and inferring the bonds. Then, using an interactive widget, we can scroll through the validation set and view each molecule in 3D — a great way to sanity-check data preprocessing and bonding structure.\n",
+    "\n",
+    "Note: You might occasionally see RDKit throw a ValueError like\n",
+    "“Valence of atom X is larger than the allowed maximum”.\n",
+    "This happens when bond inference from raw coordinates produces chemically invalid structures.\n",
+    "It’s expected for a few molecules in QM9, since not all atomic configurations map perfectly back to valid 3D molecules.\n",
+    "You can safely ignore these errors or skip those samples — they don’t affect the rest of the dataset or training."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8c763b1a-1d3a-40e4-bd24-9a7d99e7efa3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Note: RDKit may raise a \"Valence of atom ... is larger than allowed\" error.\n",
+    "# This occurs when the inferred bonds don't form a valid molecule — it's expected for a few QM9 samples.\n",
+    "\n",
+    "from rdkit import Chem\n",
+    "from rdkit.Chem import rdDetermineBonds\n",
+    "import py3Dmol\n",
+    "from ipywidgets import interact, IntSlider\n",
+    "\n",
+    "\n",
+    "def convert_to_mol(graph: dgl.DGLGraph) -> Chem.Mol:\n",
+    "    ptable = Chem.GetPeriodicTable()\n",
+    "\n",
+    "    # extract positions and atomic numbers (column 5 of the node attributes)\n",
+    "    raw_coords = graph.ndata[\"pos\"]\n",
+    "    raw_atomic_numbers = graph.ndata[\"attr\"][:, 5]\n",
+    "    n_atoms = raw_atomic_numbers.shape[0]\n",
+    "\n",
+    "    # build an XYZ block: atom count, blank comment line, one atom per line\n",
+    "    xyz_str = f\"{n_atoms}\\n\\n\"\n",
+    "    for an, coords in zip(raw_atomic_numbers, raw_coords):\n",
+    "        symb = ptable.GetElementSymbol(int(an))\n",
+    "        xyz_str += f\"{symb} {coords[0]} {coords[1]} {coords[2]}\\n\"\n",
+    "    mol = Chem.MolFromXYZBlock(xyz_str)\n",
+    "\n",
+    "    # infer the bonds from the 3D coordinates\n",
+    "    rdDetermineBonds.DetermineBonds(mol)\n",
+    "    return mol\n",
+    "\n",
+    "\n",
+    "# Get the dataloader for the validation dataset\n",
+    "val_loader = datamodule.val_dataloader()\n",
+    "\n",
+    "\n",
+    "# Visualize the validation-set molecule at the given index\n",
+    "def visualize_molecule(index):\n",
+    "    try:\n",
+    "        mol = convert_to_mol(val_loader.dataset[index][0])\n",
+    "        mb = Chem.MolToMolBlock(mol)\n",
+    "        # Render the molecule interactively with py3Dmol\n",
+    "        view = py3Dmol.view(width=400, height=400)\n",
+    "        view.addModel(mb, \"sdf\")\n",
+    "        view.setStyle({\"stick\": {}})\n",
+    "        view.zoomTo()\n",
+    "        return view.show()\n",
+    "    except ValueError as e:\n",
+    "        print(f\"We cannot visualize this molecule: {e}\")\n",
+    "\n",
+    "\n",
+    "# Create the interactive widget\n",
+    "interact(\n",
+    "    visualize_molecule,\n",
+    "    index=IntSlider(\n",
+    "        min=0,\n",
+    "        max=len(val_loader.dataset) - 1,\n",
+    "        step=1,\n",
+    "        value=18,\n",
+    "        description=\"Validation Index:\",\n",
+    "    ),\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2f8a401a-9fa5-49ce-889a-da43ced4fc31",
+   "metadata": {},
+   "source": [
+    "# Model Summary\n",
+    "We can quickly inspect the SE(3)-Transformer architecture using torchinfo.summary, which prints a detailed overview of the layer hierarchy and the number of parameters in each layer. This helps us verify that the model is built correctly before training."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a95615d1-e5a7-4a5f-9861-7712a87ae5bc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from torchinfo import summary\n",
+    "\n",
+    "summary(model)"
+   ]
+  },
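+  {
+   "cell_type": "markdown",
+   "id": "f1a6b7c8",
+   "metadata": {},
+   "source": [
+    "As a cross-check of the summary above, we can also count the parameters directly from the nn.Module (a minimal sketch; the totals should match what torchinfo reports):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a2b7c8d9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Count trainable and total parameters directly from the model\n",
+    "n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
+    "n_total = sum(p.numel() for p in model.parameters())\n",
+    "print(f\"Trainable parameters: {n_trainable:,} / {n_total:,} total\")"
+   ]
+  },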
+  {
+   "cell_type": "markdown",
+   "id": "fe5fd79d-b7a9-4061-bfd0-64b12e195786",
+   "metadata": {},
+   "source": [
+    "# Logging and Callbacks\n",
+    "Before training, we set up logging, seeding, and callbacks to keep the experiment organized and reproducible. The logging level is set to INFO so key messages about configuration and progress are visible. If a random seed is provided, it is initialized to ensure reproducibility across runs.\n",
+    "We create a DLLogger (wrapped in a LoggerCollection) to save logs, and configure callbacks like QM9MetricCallback for validation metrics and QM9LRSchedulerCallback for learning rate scheduling. Finally, all hyperparameters from args are recorded in the logger to track and reproduce experiments consistently."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "cf0496b0-b74d-4b4d-a460-b50c5f84b608",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Initialize logging, set seed, configure loggers and training callbacks\n",
+    "logging.getLogger().setLevel(logging.INFO)\n",
+    "\n",
+    "if args.seed is not None:\n",
+    "    logging.info(f\"Using seed {args.seed}\")\n",
+    "    seed_everything(args.seed)\n",
+    "\n",
+    "logging.info(f\"Saving info to {args.log_dir}/{args.dllogger_name}\")\n",
+    "loggers = [DLLogger(save_dir=args.log_dir, filename=args.dllogger_name)]\n",
+    "logger = LoggerCollection(loggers)\n",
+    "callbacks = [\n",
+    "    QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix=\"validation\"),\n",
+    "    QM9LRSchedulerCallback(logger, epochs=args.epochs),\n",
+    "]\n",
+    "logger.log_hyperparams(vars(args))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "bbe97844-8a8f-423f-b2c3-6ab44be451a8",
+   "metadata": {},
+   "source": [
+    "# Train\n",
+    "With everything configured, we’re ready to kick off training. The train() function orchestrates the entire training loop — running forward and backward passes, computing losses, updating parameters, and periodically evaluating on the validation set. It uses the dataloaders, callbacks, and logger we set up earlier to track progress, log metrics, and manage learning rate schedules throughout the training process."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "47df1b32-d214-4bee-8868-3e390068bc61",
+   "metadata": {
+    "editable": true,
+    "slideshow": {
+     "slide_type": ""
+    },
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "train(\n",
+    "    model,\n",
+    "    loss_fn,\n",
+    "    datamodule.train_dataloader(),\n",
+    "    datamodule.val_dataloader(),\n",
+    "    callbacks,\n",
+    "    logger,\n",
+    "    args,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "87f45770",
+   "metadata": {},
+   "source": [
+    "# Visualizing Training Progress\n",
+    "After training, we can visualize and analyze the logged results. We import Plotly for interactive plotting and dllogger to access the saved training logs. Flushing the logger ensures all metrics have been written to disk before loading them."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1fc20202-5791-4a65-b4d6-edfcb008e847",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import plotly.graph_objects as go\n",
+    "import plotly.io as pio\n",
+    "import pandas as pd\n",
+    "from plotly.subplots import make_subplots\n",
+    "import json\n",
+    "import dllogger\n",
+    "import os\n",
+    "\n",
+    "# If we're loading a checkpoint, we need to use the saved log file;\n",
+    "# otherwise, we'll use the current log file\n",
+    "if args.load_ckpt_path is not None:\n",
+    "    LOG_FILE = os.path.join(\"results\", \"dllogger_results_100.json\")\n",
+    "    if not os.path.exists(LOG_FILE):\n",
+    "        raise FileNotFoundError(\n",
+    "            f\"Log file {LOG_FILE} does not exist; please copy the log file to the results directory or turn off checkpoint loading\"\n",
+    "        )\n",
+    "else:\n",
+    "    LOG_FILE = os.path.join(\"results\", args.log_dir, args.dllogger_name)\n",
+    "    dllogger.flush()\n",
+    "\n",
+    "print(f\"Using log file: {LOG_FILE}\")\n",
+    "pio.renderers.default = \"notebook\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7e19d697",
+   "metadata": {},
+   "source": [
+    "This step parses and organizes the logged training data from the DLLogger log file. We read the file line by line, strip the \"DLLL\" prefix that DLLogger prepends to each JSON record, and filter out records without valid steps. Each log entry is then grouped by its training step, extracting key metrics such as training loss, learning rate, and validation MAE. The results are compiled into a tidy Pandas DataFrame, making it easier to visualize and analyze how model performance and learning dynamics evolved throughout training."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6fa875d8-0a1d-4c27-978d-d5817d4cb9b1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Read and parse the data\n",
+    "with open(LOG_FILE, \"r\") as f:\n",
+    "    logs = [json.loads(line.replace(\"DLLL\", \"\")) for line in f.readlines()]\n",
+    "\n",
+    "# Filter out entries where step is an empty list\n",
+    "logs = [log for log in logs if log.get(\"step\") != []]\n",
+    "\n",
+    "# Create a dictionary to aggregate metrics by step\n",
+    "metrics_by_step = {}\n",
+    "\n",
+    "for log in logs:\n",
+    "    if log.get(\"type\") == \"LOG\":\n",
+    "        step = log.get(\"step\")\n",
+    "\n",
+    "        # Skip if step is not an integer or if it's the PARAMETER step\n",
+    "        if not isinstance(step, int):\n",
+    "            continue\n",
+    "\n",
+    "        # Initialize the entry for this step if it doesn't exist yet\n",
+    "        if step not in metrics_by_step:\n",
+    "            metrics_by_step[step] = {\n",
+    "                \"step\": step,\n",
+    "                \"train loss\": None,\n",
+    "                \"learning rate\": None,\n",
+    "                \"validation MAE\": None,\n",
+    "            }\n",
+    "\n",
+    "        # Update metrics for this step\n",
+    "        data = log.get(\"data\", {})\n",
+    "        if \"train loss\" in data:\n",
+    "            metrics_by_step[step][\"train loss\"] = data[\"train loss\"]\n",
+    "        if \"learning rate\" in data:\n",
+    "            metrics_by_step[step][\"learning rate\"] = data[\"learning rate\"]\n",
+    "        if \"validation MAE\" in data:\n",
+    "            metrics_by_step[step][\"validation MAE\"] = data[\"validation MAE\"]\n",
+    "\n",
+    "# Convert to DataFrame\n",
+    "df = pd.DataFrame(list(metrics_by_step.values()))\n",
+    "df = df.sort_values(\"step\").reset_index(drop=True)\n",
+    "\n",
+    "print(df)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ac481624",
+   "metadata": {},
+   "source": [
+    "To get a clear picture of how training evolved, we plot the key metrics over epochs using Plotly.\n",
+    "The figure below displays training loss, validation MAE, and learning rate in separate subplots, making it easy to observe the model’s convergence and learning dynamics. Ideally, you should see the training loss and validation MAE steadily decreasing as the learning rate adjusts — giving a quick visual confirmation that training progressed smoothly."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5fc27119-2ced-439f-9485-beb6ee2074bf",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create subplots\n",
+    "fig = make_subplots(\n",
+    "    rows=3,\n",
+    "    cols=1,\n",
+    "    subplot_titles=(\"Train Loss\", \"Validation MAE\", \"Learning Rate\"),\n",
+    "    vertical_spacing=0.08,\n",
+    ")\n",
+    "\n",
+    "# Train Loss\n",
+    "fig.add_trace(\n",
+    "    go.Scatter(\n",
+    "        x=df[\"step\"],\n",
+    "        y=df[\"train loss\"],\n",
+    "        mode=\"lines+markers\",\n",
+    "        name=\"Train Loss\",\n",
+    "        line=dict(color=\"blue\"),\n",
+    "    ),\n",
+    "    row=1,\n",
+    "    col=1,\n",
+    ")\n",
+    "\n",
+    "# Validation MAE\n",
+    "fig.add_trace(\n",
+    "    go.Scatter(\n",
+    "        x=df[\"step\"],\n",
+    "        y=df[\"validation MAE\"],\n",
+    "        mode=\"lines+markers\",\n",
+    "        name=\"Validation MAE\",\n",
+    "        line=dict(color=\"red\"),\n",
+    "    ),\n",
+    "    row=2,\n",
+    "    col=1,\n",
+    ")\n",
+    "\n",
+    "# Learning Rate\n",
+    "fig.add_trace(\n",
+    "    go.Scatter(\n",
+    "        x=df[\"step\"],\n",
+    "        y=df[\"learning rate\"],\n",
+    "        mode=\"lines+markers\",\n",
+    "        name=\"Learning Rate\",\n",
+    "        line=dict(color=\"green\"),\n",
+    "    ),\n",
+    "    row=3,\n",
+    "    col=1,\n",
+    ")\n",
+    "\n",
+    "fig.update_xaxes(title_text=\"Epoch\", row=3, col=1)\n",
+    "fig.update_layout(height=1000, showlegend=False, title_text=\"SE(3) Training\")\n",
+    "fig.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "107dfd25-5d7c-42a6-ac17-0f6ed1b55099",
+   "metadata": {},
+   "source": [
+    "# Conclusion\n",
+    "\n",
+    "In this notebook, we walked through the end-to-end workflow for training and evaluating an SE(3)-Transformer model on the QM9 molecular dataset. We explored how to set up training configurations originally designed for CLI use, adapted them for an interactive Jupyter workflow, and visualized molecules directly from graph data to validate preprocessing. We then built and trained the SE(3)-Transformer, logged its performance, and used interactive plots to analyze key metrics like loss, MAE, and learning rate over time.\n",
+    "\n",
+    "With the workflow now validated, this setup provides a strong foundation for scaling up experiments, benchmarking performance, and adapting the SE(3)-Transformer to more complex or domain-specific datasets."
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/DGLPyTorch/DrugDiscovery/SE3Transformer/se3transformer_jupyter_setup.md b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3transformer_jupyter_setup.md
new file mode 100644
index 000000000..d203067ee
--- /dev/null
+++ b/DGLPyTorch/DrugDiscovery/SE3Transformer/se3transformer_jupyter_setup.md
@@ -0,0 +1,27 @@
+# Setup for Jupyter Notebook
+
+There are two steps to use the Jupyter notebook:
+1. Build the Jupyter notebook Docker image
+2. Run the Jupyter notebook Docker image (this starts JupyterLab; you can access the notebook at `http://<hostname>:8888`)
+
+```bash
+# build the jupyter notebook image
+docker build -t dgl:jupyter \
+    --build-arg BASE_IMAGE=<base image> \
+    -f Dockerfile.dgl_jupyter \
+    .
+
+# run the jupyter notebook image (this will start JupyterLab,
+# and you can access the notebook at http://<hostname>:8888)
+docker run -it --rm --privileged \
+    --cap-add=SYS_PTRACE \
+    --ipc=host \
+    --network=host \
+    --device=/dev/kfd \
+    --device=/dev/dri \
+    --group-add video \
+    --security-opt seccomp=unconfined \
+    dgl:jupyter
+```
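+
+## Verifying GPU access (optional)
+
+Once JupyterLab is up, a quick sanity check from a new notebook cell confirms that the
+container can see the GPU. This is a minimal sketch, assuming the ROCm build of PyTorch
+shipped with the base image (on ROCm builds, the `torch.cuda` API is backed by HIP):
+
+```python
+import torch
+
+# On ROCm builds of PyTorch, torch.cuda reports on the AMD GPU(s)
+# passed through to the container via /dev/kfd and /dev/dri.
+print(torch.cuda.is_available())  # expect: True
+if torch.cuda.is_available():
+    print(torch.cuda.get_device_name(0))  # the visible AMD device name
+```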