Add GRIT model #777

Merged · 23 commits into graphnet-team:main · Jan 21, 2025
Conversation

@pweigel (Collaborator) commented Dec 17, 2024

GRIT: "Graph Inductive Biases in Transformers without Message Passing"

This PR includes a new model based on the GRIT transformer. It introduces novel methods for encoding graph information for use in sparse multi-head attention blocks, together with a learned position encoding based on random walk probabilities that enhances the model's expressivity.

PMLR: https://proceedings.mlr.press/v202/ma23c.html
Paper pre-print: https://arxiv.org/abs/2305.17589


Many layers/functions are adapted from the original repository: https://github.com/LiamMa/GRIT/tree/main. The original code uses graphgym to set up most of its modules, so I refactored some things to fit into graphnet. Many of the arguments have been relabeled to be more self-explanatory. In principle, other graph attention mechanisms could be used by replacing the GRIT MHA block.

Since there are a lot of changes, I will quickly summarize the significant new additions and modifications to existing files:

This model has many hyperparameters, but the defaults should provide a good starting point. It should be noted that the GPU memory required to train this model is quite high due to the use of global attention.
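
For orientation, here is a minimal dense sketch of how RRWP-style encodings can be built from the random-walk matrix M = D^-1 A (the implementation in this PR works with sparse tensors and differs in detail; the function below is illustrative only, not the code being merged):

import torch
from torch_geometric.utils import to_dense_adj

def rrwp_sketch(edge_index: torch.Tensor, num_nodes: int, walk_length: int = 8):
    """Stack powers of the random-walk matrix: [I, M, M^2, ..., M^(K-1)]."""
    adj = to_dense_adj(edge_index, max_num_nodes=num_nodes)[0]  # [N, N]
    deg = adj.sum(dim=1, keepdim=True).clamp(min=1.0)
    rw = adj / deg  # M = D^-1 A, row-normalised
    powers = [torch.eye(num_nodes)]
    for _ in range(walk_length - 1):
        powers.append(powers[-1] @ rw)  # M^k
    pe = torch.stack(powers, dim=-1)  # [N, N, walk_length]
    abs_pe = pe.diagonal().transpose(0, 1)  # node-wise (absolute) encoding, [N, walk_length]
    return abs_pe, pe  # pe holds the pair-wise (relative) encodings

The dense [N, N, walk_length] tensor also makes it clear why the relative encodings dominate the memory budget for large graphs.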

@pweigel (Collaborator, Author) commented Dec 24, 2024

After some experimenting, I've found that no position encoding works quite well (skipping the RRWP encodings) and drastically reduces the GPU memory requirement. I'll add some options that allow users to do this.
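
As a rough illustration of the choice this would enable (purely a sketch: the import location of the new graph definitions and their constructor arguments are assumed to follow the existing KNNGraph pattern and may differ from the final code):

from graphnet.models.detector.icecube import IceCube86
# Assumed export location for the graph definitions added in this PR:
from graphnet.models.graphs import KNNGraphNoPE, KNNGraphRRWP

detector = IceCube86()
graph_definition = KNNGraphNoPE(detector=detector)    # no positional encoding: lowest GPU memory
# graph_definition = KNNGraphRRWP(detector=detector)  # full RRWP fields: highest GPU memory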


rel_pe = SparseTensor.from_dense(pe, has_value=True)
rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo()
# rel_pe_idx = torch.stack([rel_pe_row, rel_pe_col], dim=0)
Review comment (Collaborator):

forgot a comment here

data: Data,
walk_length: int = 8,
attr_name_abs: str = "rrwp", # name: 'rrwp'
attr_name_rel: str = "rrwp", # name: ('rrwp_idx', 'rrwp_val')
Review comment (Collaborator):

forgot a comment here

edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)
num_nodes = maybe_num_nodes(edge_index, num_nodes)
source = edge_index[0]
# dest = edge_index[1]
Review comment (Collaborator):

forgot a comment here


adj = adj.view(size)
_edge_index = adj.nonzero(as_tuple=False).t().contiguous()
# _edge_index, _ = remove_self_loops(_edge_index)
Review comment (Collaborator):

forgot a comment here

add_identity: Add identity matrix to position encoding.
spd: Use shortest path distances.
"""
# device = data.edge_index.device
Review comment (Collaborator):

forgot a comment here


Args:
in_dim: Dimension of the input tensor.
out_dim: Dimension of theo output tensor.
Review comment (Collaborator):

theo -> the

if E is not None:
wE = score.flatten(1)

# Complete attention ccaclculation
Review comment (Collaborator):

ccaclculation -> calculation

out_dim: Dimension of theo output tensor.
num_heads: Number of attention heads.
dropout: Dropout layer probability.
norm: Normalization layer.
Review comment (Collaborator):

Would be informative to mention here that the normalization layer is assumed to be un-instantiated. E.g:

norm: Uninstantiated normalization layer. Must be either BatchNorm1d or BatchNorm1d


class KNNGraphRWSE(GraphDefinition):
"""KNN Graph with random walk structural encoding."""

Review comment (Collaborator):

I would suggest expanding the doc string so it's easier to see how this differs from existing representations and, specifically, how the RWSE is accessible. Here's an example:

"""
A KNN graph representation with Random Walk Structural Encoding (RWSE).

Identical to KNNGraph but with an additional field containing RWSE. The encoding can be accessed via 

`rwse = graph['rwse']`
"""



class KNNGraphRRWP(GraphDefinition):
"""KNN Graph with relative random walk probabilities."""
Review comment (Collaborator):

Similarly to the other representation, I would suggest expanding the doc string so it's easier to see how this differs from existing representations and how the new fields are accessible. I think there are five new fields in the RRWP case. So here's an example:

"""
A KNN graph representation with Relative Random Walk Probabilities (RRWP).

Identical to KNNGraph but with five additional fields:

abs_pe = graph['abs_pe'] # Absolute positional encoding
rrwp_index = graph['rrwp_index'] # rrwp index (which is used for slicing the vals?)
rrwp_val = graph['rrwp_val'] # rrwp values
degree= graph['deg'] # Degree of each node (num. of incoming edges)
log_degree = graph['log_deg'] # Equal to torch.log10(graph['deg'] + 1)
"""

return graph


class KNNGraphNoPE(GraphDefinition):
Review comment (Collaborator):

Considering the very high level of similarity between this and KNNGraph, perhaps it would be beneficial to introduce a new argument to KNNGraph that would toggle between the vanilla KNNEdges and your new KNNDistanceEdges. For example, one could invent the argument distance_as_edge_features: bool = False (defaults to KNNEdges) and use KNNDistanceEdges if True.
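
Something along these lines, purely as a sketch (the `distance_as_edge_features` name is the invented one from this comment, the constructor is simplified, and `KNNDistanceEdges` is assumed to take the same arguments as `KNNEdges`):

from graphnet.models.graphs import GraphDefinition
from graphnet.models.graphs.edges import KNNEdges, KNNDistanceEdges  # KNNDistanceEdges is new in this PR

class KNNGraph(GraphDefinition):
    """Sketch: KNN graph with an optional distance-as-edge-feature toggle."""

    def __init__(
        self,
        detector,
        nb_nearest_neighbours: int = 8,
        distance_as_edge_features: bool = False,  # hypothetical new argument
        **kwargs,
    ) -> None:
        # Pick the edge definition based on the toggle; defaults to vanilla KNNEdges.
        edge_cls = KNNDistanceEdges if distance_as_edge_features else KNNEdges
        super().__init__(
            detector=detector,
            edge_definition=edge_cls(nb_nearest_neighbours=nb_nearest_neighbours),
            **kwargs,
        )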

use_bias: Apply bias the key and value linear layers.
clamp: Clamp the absolute value of the attention scores to a value.
dropout: Dropout layer probability.
activation: Activation function.
Review comment (Collaborator):

Could mention here that the activation function is assumed to be un-instantiated. E.g.

"""
activation:  Reference to uninstantiated activation function. E.g. `torch.nn.LeakyReLU`
"""

rezero: bool = False,
enable_edge_transform: bool = True,
attn_bias: bool = False,
attn_dropout: float = 0.0, # CHECK
Review comment (Collaborator):

Forgot a comment here

if norm_edges
else nn.Identity()
)
else: # TODO: Maybe just set this to nn.Identity. -PW
Review comment (Collaborator):

I think it would be preferable to raise an error if the user passes a non-compatible layer, instead of issuing a warning and defaulting to identity.

"""Forward pass."""
x = data.x
num_nodes = data.num_nodes
log_deg = get_log_deg(data)
@RasmusOrsoe (Collaborator) commented Jan 2, 2025:

Was there some particular reason why you needed the utility function to grab or calculate this quantity?

Naively I would've thought the degree could've been calculated directly in the forward pass like so:

from torch_geometric.utils import degree
log_deg = torch.log10(degree(data.edge_index[0]) + 1)

Doing it there would save you from needing the utility function and storing the log of the degree in the graph objects during the data loading

norm: nn.Module = nn.BatchNorm1d,
residual: bool = True,
deg_scaler: bool = True,
activation: nn.Module = nn.ReLU,
Review comment (Collaborator):

Could mention here that the activation function is assumed to be un-instantiated. E.g.

"""
activation:  Reference to uninstantiated activation function. E.g. `torch.nn.LeakyReLU`
"""

norm: Normalization layer.
residual: Apply residual connections.
deg_scaler: Apply degree scaling after MHA.
activation: Activation function.
Review comment (Collaborator):

Could mention here that the activation function is assumed to be un-instantiated. E.g.

"""
activation:  Reference to uninstantiated activation function. E.g. `torch.nn.LeakyReLU`
"""

edge_enhance: bool = True,
update_edges: bool = True,
attn_clamp: float = 5.0,
activation: nn.Module = nn.ReLU,
Review comment (Collaborator):

Could mention here that the activation function is assumed to be un-instantiated. E.g.

"""
activation:  Reference to uninstantiated activation function. E.g. `torch.nn.LeakyReLU`
"""

add_node_attr_as_self_loop: bool = False,
dropout: float = 0.0,
fill_value: float = 0.0,
norm: nn.Module = nn.BatchNorm1d,
Review comment (Collaborator):

Would be informative to mention here that the normalization layer is assumed to be un-instantiated. E.g:

norm: Uninstantiated normalization layer. Must be either BatchNorm1d or BatchNorm1d

dim_in: Input dimension.
dim_out: Output dimension.
L: Number of hidden layers.
activation: Activation function.
Review comment (Collaborator):

Could mention here that the activation function is assumed to be un-instantiated. E.g.

"""
activation:  Reference to uninstantiated activation function. E.g. `torch.nn.LeakyReLU`
"""

dim_out: Output dimension.
L: Number of hidden layers.
activation: Activation function.
pooling: Pooling method.
@RasmusOrsoe (Collaborator) commented Jan 2, 2025:

We could point out the supported methods in the doc string. Perhaps like this:

"""
pooling: Node-wise pooling operation. Either "mean" or "add".
"""

@RasmusOrsoe (Collaborator) left a review:

Hi @pweigel, thank you for this very clean contribution! 🎸

I have added superficial comments, mostly about commented-out code and doc strings.

Have you, by chance, tested these different initial representations and their impact on the performance of the GRIT model in a neutrino telescope setting? The original authors appear to favor RRWP over RWSE (I presume they conclude RRWP > RWSE > NoPE), but that may be very problem-dependent.

@pweigel (Collaborator, Author) commented Jan 10, 2025

Hey @RasmusOrsoe, thanks for taking a look. I think I've made all of the requested changes (and fixed a few other minor things). I haven't had the chance to benchmark the different encodings yet, but I plan to. The RRWP encodings are a bit memory-hungry, so I haven't had the chance to fully train a model beyond some tests to show that it works. Without the encodings, I've trained some models that look very good.

At some point in the near future, we should consider a better method of adding the different attributes (graph.encoding) in a modular fashion that doesn't require a new graph object. It could even be added as a part of GraphDefinition, where you pass some pos_encoding=MyGraphPosEncoding() and it uses the values/indices from the edge and node definitions to add the new attributes. It might be beyond the scope of this PR, but it would definitely be an enhancement.
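
To make the idea a bit more concrete, a purely hypothetical sketch (none of these classes exist in graphnet; the RWSE variant below just wraps PyG's AddRandomWalkPE transform for illustration):

import torch
from torch_geometric.data import Data
from torch_geometric.transforms import AddRandomWalkPE

class GraphPosEncoding(torch.nn.Module):
    """Hypothetical base class: attaches positional-encoding attributes to a graph."""

    def forward(self, graph: Data) -> Data:
        raise NotImplementedError

class RWSEEncoding(GraphPosEncoding):
    """Hypothetical module wrapping PyG's random-walk structural encoding."""

    def __init__(self, walk_length: int = 8) -> None:
        super().__init__()
        self._transform = AddRandomWalkPE(walk_length, attr_name="rwse")

    def forward(self, graph: Data) -> Data:
        return self._transform(graph)  # adds graph["rwse"] with shape [N, walk_length]

GraphDefinition could then accept something like pos_encoding=RWSEEncoding(walk_length=8) and apply it after the node and edge definitions have built the graph object.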

@RasmusOrsoe (Collaborator) commented:


Thanks! I think your idea of adding a positional encoding as a separate module, and storing its values in a dedicated field in the graph structures, is good! I also think this is the intended usage in PyG (see here). We've not utilized this in the past because the distinction between having the position as a node feature or as a separately accessible graph feature didn't matter much for the GNN applications we've had so far. I think your use-case is a good example of how it can be beneficial!

It looks like GitHub's rollout of the new Ubuntu runner version has affected your PR checks. @Aske-Rosted has solved this in the main branch (see #779).

Could you merge the main branch into yours? The checks should pass after that 💪

@RasmusOrsoe self-requested a review on January 21, 2025, 11:29
@RasmusOrsoe (Collaborator) left a review:

🚀

@pweigel pweigel merged commit 79d7baf into graphnet-team:main Jan 21, 2025
13 of 14 checks passed