Add GRIT model #777
Conversation
After some experimenting, I've found that using no position encoding (skipping the RRWP encodings) works quite well and drastically reduces the GPU memory requirement. I'll add some options that allow users to do this.
src/graphnet/models/utils.py
Outdated
rel_pe = SparseTensor.from_dense(pe, has_value=True)
rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo()
# rel_pe_idx = torch.stack([rel_pe_row, rel_pe_col], dim=0)
forgot a comment here
src/graphnet/models/utils.py
Outdated
data: Data,
walk_length: int = 8,
attr_name_abs: str = "rrwp", # name: 'rrwp'
attr_name_rel: str = "rrwp", # name: ('rrwp_idx', 'rrwp_val')
forgot a comment here
src/graphnet/models/utils.py
Outdated
edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)
num_nodes = maybe_num_nodes(edge_index, num_nodes)
source = edge_index[0]
# dest = edge_index[1]
forgot a comment here
src/graphnet/models/utils.py
Outdated
adj = adj.view(size)
_edge_index = adj.nonzero(as_tuple=False).t().contiguous()
# _edge_index, _ = remove_self_loops(_edge_index)
forgot a comment here
src/graphnet/models/utils.py
Outdated
add_identity: Add identity matrix to position encoding.
spd: Use shortest path distances.
"""
# device = data.edge_index.device
forgot a comment here
Args:
    in_dim: Dimension of the input tensor.
    out_dim: Dimension of theo output tensor.
theo -> the
if E is not None:
    wE = score.flatten(1)

# Complete attention ccaclculation
ccaclculation -> calculation
out_dim: Dimension of theo output tensor.
num_heads: Number of attention heads.
dropout: Dropout layer probability.
norm: Normalization layer.
Would be informative to mention here that the normalization layer is assumed to be un-instantiated. E.g:
norm: Uninstantiated normalization layer. Must be either BatchNorm1d or LayerNorm.
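Purely as an illustration of what "uninstantiated" means in practice (the `ExampleBlock` class below is hypothetical, not part of this PR):

import torch.nn as nn

class ExampleBlock(nn.Module):
    """Hypothetical module that receives an uninstantiated norm layer."""

    def __init__(self, dim: int, norm: type = nn.BatchNorm1d) -> None:
        super().__init__()
        # The class itself is passed in and only instantiated here.
        self.norm = norm(dim)

block = ExampleBlock(64, norm=nn.BatchNorm1d)      # correct: pass the class
# block = ExampleBlock(64, norm=nn.BatchNorm1d(64))  # wrong: already instantiated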
class KNNGraphRWSE(GraphDefinition):
    """KNN Graph with random walk structural encoding."""
I would suggest expanding the doc string so it's easier to see how this differs from existing representations and, specifically, how the RWSE is accessible. Here's an example:
"""
A KNN graph representation with Random Walk Structural Encoding (RWSE).
Identical to KNNGraph but with an additional field containing RWSE. The encoding can be accessed via
`rwse = graph['rwse']`
"""
src/graphnet/models/graphs/graphs.py
Outdated
class KNNGraphRRWP(GraphDefinition):
    """KNN Graph with relative random walk probabilities."""
Similarly to the other representation, I would suggest expanding the doc string so it's easier to see how this differs from existing representations and how the new fields are accessible. I think there are five new fields in the RRWP case. So here's an example:
"""
A KNN graph representation with Relative Random Walk Probabilities (RRWP).
Identical to KNNGraph but with five additional fields:
abs_pe = graph['abs_pe']          # Absolute positional encoding
rrwp_index = graph['rrwp_index']  # RRWP index (which is used for slicing the vals?)
rrwp_val = graph['rrwp_val']      # RRWP values
degree = graph['deg']             # Degree of each node (num. of incoming edges)
log_degree = graph['log_deg']     # Equal to torch.log10(graph['deg'] + 1)
"""
src/graphnet/models/graphs/graphs.py
Outdated
return graph


class KNNGraphNoPE(GraphDefinition):
Considering the very high level of similarity between this and `KNNGraph`, perhaps it would be beneficial to introduce a new argument to `KNNGraph` that toggles between the vanilla `KNNEdges` and your new `KNNDistanceEdges`. For example, one could introduce the argument `distance_as_edge_features: bool = False` (defaulting to `KNNEdges`) and use `KNNDistanceEdges` if `True`.
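A rough sketch of how that toggle could look inside the graph definition; the import path and the `nb_nearest_neighbours` keyword are assumptions on my part, not taken from this PR:

from graphnet.models.graphs.edges import KNNEdges, KNNDistanceEdges

def _select_edge_definition(
    distance_as_edge_features: bool = False,
    nb_nearest_neighbours: int = 8,
):
    """Pick the edge definition based on the proposed toggle (sketch)."""
    edge_cls = KNNDistanceEdges if distance_as_edge_features else KNNEdges
    return edge_cls(nb_nearest_neighbours=nb_nearest_neighbours)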
use_bias: Apply bias the key and value linear layers.
clamp: Clamp the absolute value of the attention scores to a value.
dropout: Dropout layer probability.
activation: Activation function.
Could mention here that the activation function is assumed to be un-instantiated. E.g.
"""
activation: Reference to uninstantiated activation function. E.g. `torch.nn.LeakyReLU`
"""
rezero: bool = False,
enable_edge_transform: bool = True,
attn_bias: bool = False,
attn_dropout: float = 0.0, # CHECK
Forgot a comment here
if norm_edges
else nn.Identity()
)
else: # TODO: Maybe just set this to nn.Identity. -PW
I think it would be preferable to raise an error if the user passes an incompatible layer, instead of issuing a warning and defaulting to identity.
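Something along these lines, assuming the supported layers are `BatchNorm1d` and `LayerNorm` (a sketch, not the exact code in this PR):

import torch.nn as nn

def _validate_norm(norm: type) -> type:
    """Raise on unsupported norm layers instead of silently falling back."""
    supported = (nn.BatchNorm1d, nn.LayerNorm)
    if norm not in supported:
        raise ValueError(
            f"Unsupported normalization layer {norm!r}; expected one of {supported}."
        )
    return norm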
"""Forward pass.""" | ||
x = data.x | ||
num_nodes = data.num_nodes | ||
log_deg = get_log_deg(data) |
Was there some particular reason why you needed the utility function to grab or calculate this quantity?
Naively I would've thought the degree could've been calculated directly in the forward pass like so:
from torch_geometric.utils import degree
log_deg = torch.log10(degree(data.edge_index[0]) + 1)
Doing it there would save you from needing the utility function and from storing the log of the degree in the graph objects during data loading.
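As a small extension of that snippet, passing `num_nodes` keeps the result well-defined for graphs containing isolated nodes (illustrative, not code from the PR):

import torch
from torch_geometric.utils import degree

# `data` is the usual torch_geometric Data/Batch object inside forward().
deg = degree(data.edge_index[0], num_nodes=data.num_nodes)
log_deg = torch.log10(deg + 1)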
norm: nn.Module = nn.BatchNorm1d,
residual: bool = True,
deg_scaler: bool = True,
activation: nn.Module = nn.ReLU,
Could mention here that the activation function is assumed to be un-instantiated. E.g.
"""
activation: Reference to uninstantiated activation function. E.g. `torch.nn.LeakyReLU`
"""
norm: Normalization layer.
residual: Apply residual connections.
deg_scaler: Apply degree scaling after MHA.
activation: Activation function.
Could mention here that the activation function is assumed to be un-instantiated. E.g.
"""
activation: Reference to uninstantiated activation function. E.g. `torch.nn.LeakyReLU`
"""
edge_enhance: bool = True,
update_edges: bool = True,
attn_clamp: float = 5.0,
activation: nn.Module = nn.ReLU,
Could mention here that the activation function is assumed to be un-instantiated. E.g.
"""
activation: Reference to uninstantiated activation function. E.g. `torch.nn.LeakyReLU`
"""
add_node_attr_as_self_loop: bool = False,
dropout: float = 0.0,
fill_value: float = 0.0,
norm: nn.Module = nn.BatchNorm1d,
Would be informative to mention here that the normalization layer is assumed to be un-instantiated. E.g:
norm: Uninstantiated normalization layer. Must be either BatchNorm1d or LayerNorm.
dim_in: Input dimension.
dim_out: Output dimension.
L: Number of hidden layers.
activation: Activation function.
Could mention here that the activation function is assumed to be un-instantiated. E.g.
"""
activation: Reference to uninstantiated activation function. E.g. `torch.nn.LeakyReLU`
"""
dim_out: Output dimension.
L: Number of hidden layers.
activation: Activation function.
pooling: Pooling method.
We could point out the supported methods in the doc string. Perhaps like this:
"""
pooling: Node-wise pooling operation. Either "mean" or "add".
"""
Hi @pweigel, thank you for this very clean contribution! 🎸
I have added superficial comments, mostly about commented-out code and doc strings.
Have you, by chance, tested these different initial representations and their impact on the performance of the GRIT model in a neutrino telescope setting? The original authors appear to favor RRWP over RWSE (I presume they conclude RRWP > RWSE > NoPE), but that may be very problem-dependent.
Hey @RasmusOrsoe, thanks for taking a look. I think I've made all of the requested changes (and fixed a few other minor things). I haven't had the chance to benchmark the different encodings yet, but I plan to. The RRWP encodings are a bit memory-hungry, so I haven't had the chance to fully train a model beyond some tests to show that it works. Without the encodings, I've trained some models that look very good. At some point in the near future, we should consider a better method of adding the different attributes (
Thanks! I think your idea of adding a positional encoding as a separate module, and storing its values in a dedicated field in the graph structures, is good! I also think this is the intended usage in PyG (see here, and the illustrative snippet below). We've not utilized this in the past because the distinction between having the position as a node feature or as a separately accessible graph feature didn't matter much for the GNN applications we've had so far. I think your use-case is a good example of how it can be beneficial! It looks like the GitHub rollout to the new Ubuntu version has affected your PR checks. @Aske-Rosted has solved this in the main branch (see #779). Could you merge the main branch into yours? The checks should pass after that 💪
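For context, PyG's own transforms follow this pattern of storing an encoding under a dedicated attribute rather than concatenating it onto `data.x`; an illustrative snippet using `AddRandomWalkPE` (not code from this PR):

import torch
from torch_geometric.data import Data
from torch_geometric.transforms import AddRandomWalkPE

data = Data(
    x=torch.randn(4, 3),
    edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]),
)
data = AddRandomWalkPE(walk_length=8, attr_name="rwse")(data)
print(data.rwse.shape)  # [num_nodes, walk_length], stored separately from data.x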
🚀
GRIT: "Graph Inductive Biases in Transformers without Message Passing"
This PR adds a new model based on the GRIT transformer. GRIT uses novel methods for encoding graph information for use in sparse multi-head attention blocks, including a learned position encoding based on random walk probabilities, which enhances the model's expressivity.
PMLR: https://proceedings.mlr.press/v202/ma23c.html
Paper pre-print: https://arxiv.org/abs/2305.17589
Many layers/functions are adapted from the original repository: https://github.com/LiamMa/GRIT/tree/main. The original code uses graphgym to set up most of its modules, so I refactored some things to fit into graphnet. Many of the arguments have been relabeled to be more self-explanatory. In principle, other graph attention mechanisms could be used by replacing the GRIT MHA block.

Since there are a lot of changes, I will quickly summarize the significant new additions and modifications to existing files:

- `KNNEdges` (`KNNDistanceEdges`): extended to include edge values corresponding to the distance between the node pair.
- `KNNGraph`: extended to use the `KNNDistanceEdges` and compute the RRWP values.

This model has many hyperparameters, but the defaults should provide a good starting point. It should be noted that the GPU memory required to train this model is quite high due to the use of global attention.