Add GRIT model #777

Merged 23 commits on Jan 21, 2025.

Commits
b2a8b25  Add GRIT layers and model (pweigel, Dec 14, 2024)
c7b8dba  Move linear encoders to embedding.py (pweigel, Dec 14, 2024)
4d6ad04  Add GRIT to __init__.py (pweigel, Dec 14, 2024)
8a8c9fb  Fix imports, add RRWP utils and graph (pweigel, Dec 14, 2024)
3d52f70  Cleaning up grit model/layers (pweigel, Dec 14, 2024)
4fecd5b  Cleaning up GRIT layers (pweigel, Dec 14, 2024)
02d4a0d  Remove duplicate pyg_softmax function (pweigel, Dec 14, 2024)
924d165  Merge the attention calc into forward, reduce amount of saved data (pweigel, Dec 14, 2024)
7ccb6f0  Significant improvements to naming, docstrings (pweigel, Dec 17, 2024)
bb232e9  Remove TODO (pweigel, Dec 17, 2024)
1a2d80c  Fix normalization layers (pweigel, Dec 17, 2024)
5c4b0aa  Updating new graph/edge definitions (pweigel, Dec 17, 2024)
c861d9d  Added example training script, bug fixes (pweigel, Dec 17, 2024)
e540b0e  Improving SANGraphHead and simplify dims (pweigel, Dec 17, 2024)
1e46b75  Merge branch 'graphnet-team:main' into grit (pweigel, Dec 17, 2024)
48d5e85  Add newline to make flake8 happy (pweigel, Dec 17, 2024)
68634cc  Improvements to GRIT arguments, added new position encodings, fixed p… (pweigel, Dec 29, 2024)
9673f50  Updating docstrings, removing old comments (pweigel, Jan 6, 2025)
aa70dc0  More docstring fixes, simplifying graph calculations (pweigel, Jan 6, 2025)
417ad3b  Remove KNNGraphNoPE, add distance setting for KNNGraph (pweigel, Jan 6, 2025)
422c182  Remove NoPE from __init__.py (pweigel, Jan 6, 2025)
f896062  Fix log_deg shape (pweigel, Jan 7, 2025)
acab6e5  Merge branch 'graphnet-team:main' into grit (pweigel, Jan 16, 2025)
examples/04_training/08_train_grit_model.py (245 additions, 0 deletions)
@@ -0,0 +1,245 @@
"""Example of training Model."""

import os
from typing import Any, Dict, List, Optional

from pytorch_lightning.loggers import WandbLogger
import torch
from torch.optim.adam import Adam

from graphnet.constants import EXAMPLE_DATA_DIR, EXAMPLE_OUTPUT_DIR
from graphnet.data.constants import FEATURES, TRUTH
from graphnet.models import StandardModel
from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.gnn import GRIT
from graphnet.models.graphs import KNNGraphRRWP
from graphnet.models.task.reconstruction import EnergyReconstruction
from graphnet.training.callbacks import PiecewiseLinearLR
from graphnet.training.loss_functions import LogCoshLoss
from graphnet.utilities.argparse import ArgumentParser
from graphnet.utilities.logging import Logger
from graphnet.data import GraphNeTDataModule
from graphnet.data.dataset import SQLiteDataset
from graphnet.data.dataset import ParquetDataset

# Constants
features = FEATURES.PROMETHEUS
truth = TRUTH.PROMETHEUS


def main(
    path: str,
    pulsemap: str,
    target: str,
    truth_table: str,
    gpus: Optional[List[int]],
    max_epochs: int,
    early_stopping_patience: int,
    batch_size: int,
    num_workers: int,
    wandb: bool = False,
) -> None:
    """Run example."""
    # Construct Logger
    logger = Logger()

    # Initialise Weights & Biases (W&B) run
    if wandb:
        # Make sure W&B output directory exists
        wandb_dir = "./wandb/"
        os.makedirs(wandb_dir, exist_ok=True)
        wandb_logger = WandbLogger(
            project="example-script",
            entity="graphnet-team",
            save_dir=wandb_dir,
            log_model=True,
        )

    logger.info(f"features: {features}")
    logger.info(f"truth: {truth}")

    # Configuration
    config: Dict[str, Any] = {
        "path": path,
        "pulsemap": pulsemap,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "target": target,
        "early_stopping_patience": early_stopping_patience,
        "fit": {
            "gpus": gpus,
            "max_epochs": max_epochs,
            "distribution_strategy": "ddp_find_unused_parameters_true",
        },
        "dataset_reference": (
            SQLiteDataset if path.endswith(".db") else ParquetDataset
        ),
    }

    archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model_without_configs")
    run_name = "grit_{}_example".format(config["target"])
    if wandb:
        # Log configuration to W&B
        wandb_logger.experiment.config.update(config)

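    # KNNGraphRRWP builds a k-nearest-neighbours graph and attaches relative
    # random-walk probability (RRWP) positional encodings; `walk_length` sets
    # the number of random-walk steps used for the encoding.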
    walk_length = 6
    graph_definition = KNNGraphRRWP(
        detector=Prometheus(),
        input_feature_names=features,
        nb_nearest_neighbours=5,
        walk_length=walk_length,
    )
    dm = GraphNeTDataModule(
        dataset_reference=config["dataset_reference"],
        dataset_args={
            "truth": truth,
            "truth_table": truth_table,
            "features": features,
            "graph_definition": graph_definition,
            "pulsemaps": [config["pulsemap"]],
            "path": config["path"],
        },
        train_dataloader_kwargs={
            "batch_size": config["batch_size"],
            "num_workers": config["num_workers"],
        },
        test_dataloader_kwargs={
            "batch_size": config["batch_size"],
            "num_workers": config["num_workers"],
        },
    )

    training_dataloader = dm.train_dataloader
    validation_dataloader = dm.val_dataloader

    # Building model
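    # GRIT is a graph-transformer backbone that consumes the RRWP encodings;
    # `ksteps` should match the walk length used in the graph definition.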
    backbone = GRIT(
        nb_inputs=graph_definition.nb_outputs,
        hidden_dim=32,
        ksteps=walk_length,
    )

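    # Regress energy in log10-space: targets and predictions are compared as
    # log10(E) during training, and predictions are mapped back to linear
    # energy at inference time.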
    task = EnergyReconstruction(
        hidden_size=backbone.nb_outputs,
        target_labels=config["target"],
        loss_function=LogCoshLoss(),
        transform_prediction_and_target=lambda x: torch.log10(x),
        transform_inference=lambda x: torch.pow(10, x),
    )

    model = StandardModel(
        graph_definition=graph_definition,
        backbone=backbone,
        tasks=[task],
        optimizer_class=Adam,
        optimizer_kwargs={"lr": 1e-03, "eps": 1e-03},
        scheduler_class=PiecewiseLinearLR,
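        # Piecewise-linear LR schedule stepped per batch: ramp from 1% of the
        # base rate to full over the first half-epoch, then decay back to 1%
        # by the end of training.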
        scheduler_kwargs={
            "milestones": [
                0,
                len(training_dataloader) / 2,
                len(training_dataloader) * config["fit"]["max_epochs"],
            ],
            "factors": [1e-02, 1, 1e-02],
        },
        scheduler_config={
            "interval": "step",
        },
    )

    # Training model
    model.fit(
        training_dataloader,
        validation_dataloader,
        early_stopping_patience=config["early_stopping_patience"],
        logger=wandb_logger if wandb else None,
        **config["fit"],
    )

    # Get predictions
    additional_attributes = model.target_labels
    assert isinstance(additional_attributes, list)  # mypy

    results = model.predict_as_dataframe(
        validation_dataloader,
        additional_attributes=additional_attributes + ["event_no"],
        gpus=config["fit"]["gpus"],
    )

    # Save predictions and model to file
    db_name = path.split("/")[-1].split(".")[0]
    path = os.path.join(archive, db_name, run_name)
    logger.info(f"Writing results to {path}")
    os.makedirs(path, exist_ok=True)

    results.to_csv(f"{path}/results.csv")

    model.save(f"{path}/model.pth")
    model.save_state_dict(f"{path}/state_dict.pth")
    model.save_config(f"{path}/model_config.yml")


if __name__ == "__main__":

    # Parse command-line arguments
    parser = ArgumentParser(
        description="""
Train GRIT model without the use of config files.
"""
    )

    parser.add_argument(
        "--path",
        help="Path to dataset file (default: %(default)s)",
        default=f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db",
    )

    parser.add_argument(
        "--pulsemap",
        help="Name of pulsemap to use (default: %(default)s)",
        default="total",
    )

    parser.add_argument(
        "--target",
        help=(
            "Name of feature to use as regression target (default: "
            "%(default)s)"
        ),
        default="total_energy",
    )

    parser.add_argument(
        "--truth-table",
        help="Name of truth table to be used (default: %(default)s)",
        default="mc_truth",
    )

    parser.with_standard_arguments(
        "gpus",
        ("max-epochs", 1),
        "early-stopping-patience",
        ("batch-size", 16),
        "num-workers",
    )

    parser.add_argument(
        "--wandb",
        action="store_true",
        help="If True, Weights & Biases are used to track the experiment.",
    )

    args, unknown = parser.parse_known_args()

    main(
        args.path,
        args.pulsemap,
        args.target,
        args.truth_table,
        args.gpus,
        args.max_epochs,
        args.early_stopping_patience,
        args.batch_size,
        args.num_workers,
        args.wandb,
    )
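
To try the example as committed (a sketch, assuming the bundled Prometheus
example data and the argument defaults defined above; flag spellings follow
the script's ArgumentParser):

    python examples/04_training/08_train_grit_model.py --gpus 0 --max-epochs 5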