Skip to content

Commit

Permalink
add hash for layer key and skip partitioned nodes in torch conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanhhughes committed Nov 22, 2024
1 parent 5ed3549 commit 61dfc31
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
8 changes: 7 additions & 1 deletion python/src/spark_dsg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from spark_dsg._dsg_bindings import *
from spark_dsg._dsg_bindings import (BoundingBoxType, DsgLayers,
DynamicSceneGraph, EdgeAttributes,
LayerView, NodeAttributes,
LayerKey, LayerView, NodeAttributes,
SceneGraphLayer,
compute_ancestor_bounding_box)
from spark_dsg.open3d_visualization import render_to_open3d
Expand Down Expand Up @@ -110,6 +110,12 @@ def _add_metadata_interface(obj):
obj.add_metadata = _add_metadata


def _hash_layerkey(key):
return hash((key.layer, key.partition))


LayerKey.__hash__ = _hash_layerkey

_add_metadata_interface(DynamicSceneGraph)
_add_metadata_interface(NodeAttributes)
_add_metadata_interface(EdgeAttributes)
Expand Down
6 changes: 6 additions & 0 deletions python/src/spark_dsg/torch_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,9 @@ def scene_graph_to_torch_heterogeneous(
id_map = {}

for node in G.nodes:
if node.layer.partition != 0:
continue

if node.layer not in node_features:
node_features[node.layer] = []
node_positions[node.layer] = []
Expand Down Expand Up @@ -356,6 +359,9 @@ def scene_graph_to_torch_heterogeneous(
for edge in G.edges:
source = G.get_node(edge.source)
target = G.get_node(edge.target)
if source.layer.partition != 0 or target.layer.partition != 0:
continue

edge_type = edge_map[source.layer][target.layer][1]
if edge_type not in edge_indices:
edge_indices[edge_type] = []
Expand Down

0 comments on commit 61dfc31

Please sign in to comment.