Skip to content

Commit

Permalink
removed spektralgraph
Browse files Browse the repository at this point in the history
  • Loading branch information
UnravelSports [JB] committed Jul 22, 2024
1 parent a82025d commit 65281fe
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 70 deletions.
7 changes: 4 additions & 3 deletions tests/test_kloppy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
dummy_labels,
dummy_graph_ids,
CustomSpektralDataset,
SpektralGraph,
GraphFrame,
)

from kloppy import skillcorner
from kloppy.domain import Ground, TrackingDataset, Orientation
from typing import List, Dict

from spektral.data import Graph

import pytest

import numpy as np
Expand Down Expand Up @@ -176,7 +177,7 @@ def test_to_spektral_graph(self, gnnc: GraphConverter):

data = spektral_graphs
assert len(data) == 387
assert isinstance(data[0], SpektralGraph)
assert isinstance(data[0], Graph)
# note: these shape tests fail if we add more features (ie. acceleration)

x = data[0].x
Expand Down Expand Up @@ -265,7 +266,7 @@ def test_to_spektral_graph_padding_random(

data = spektral_graphs
assert len(data) == 387
assert isinstance(data[0], SpektralGraph)
assert isinstance(data[0], Graph)
# note: these shape tests fail if we add more features (ie. acceleration)

x = data[0].x
Expand Down
4 changes: 3 additions & 1 deletion unravel/soccer/graphs/graph_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
SecondSpectrumCoordinateSystem,
)

from spektral.data import Graph

from .exceptions import (
MissingLabelsError,
MissingDatasetError,
Expand Down Expand Up @@ -256,7 +258,7 @@ def to_graph_frames(self) -> dict:

return self.graph_frames

def to_spektral_graphs(self) -> List[SpektralGraph]:
def to_spektral_graphs(self) -> List[Graph]:
if not self.graph_frames:
self.to_graph_frames()

Expand Down
1 change: 0 additions & 1 deletion unravel/utils/objects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,5 @@
from .default_ball import DefaultBall
from .default_tracking import DefaultTrackingModel
from .custom_spektral_dataset import CustomSpektralDataset
from .spektral_graph import SpektralGraph
from .graph_frame import GraphFrame
from .graph_settings import GraphSettings
10 changes: 5 additions & 5 deletions unravel/utils/objects/custom_spektral_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

from collections.abc import Sequence

from spektral.data import Dataset
from spektral.data import Dataset, Graph

from .graph_frame import GraphFrame, SpektralGraph
from .graph_frame import GraphFrame

from ..exceptions import NoGraphIdsWarning

Expand Down Expand Up @@ -45,11 +45,11 @@ def __init__(self, **kwargs):

super().__init__(**kwargs)

def __convert(self, data) -> List[SpektralGraph]:
def __convert(self, data) -> List[Graph]:
"""
Convert incoming data to correct List[Graph] format
"""
if isinstance(data[0], SpektralGraph):
if isinstance(data[0], Graph):
return data
elif isinstance(data[0], GraphFrame):
return [g.to_spektral_graph() for g in self.data]
Expand All @@ -61,7 +61,7 @@ def __convert(self, data) -> List[SpektralGraph]:
else:
raise NotImplementedError()

def read(self) -> List[SpektralGraph]:
def read(self) -> List[Graph]:
"""
Overriding the read function - to return a list of Graph objects
"""
Expand Down
8 changes: 5 additions & 3 deletions unravel/utils/objects/graph_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from dataclasses import dataclass, field

from warnings import *

from spektral.data import Graph

from ..exceptions import AdjcacenyMatrixTypeNotSet
from ..features import (
AdjacencyMatrixType,
Expand All @@ -18,7 +21,6 @@
make_sparse,
)
from .graph_settings import GraphSettings
from .spektral_graph import SpektralGraph
from .default_tracking import DefaultTrackingModel


Expand Down Expand Up @@ -52,9 +54,9 @@ def __post_init__(self):
if self._quality_check(X, E):
self.graph_data = dict(x=X, a=sparse_A, e=E, y=Y, id=self.graph_id)

def to_spektral_graph(self) -> SpektralGraph:
def to_spektral_graph(self) -> Graph:
if self.graph_data:
return SpektralGraph(
return Graph(
x=self.graph_data["x"],
a=self.graph_data["a"],
e=self.graph_data["e"],
Expand Down
57 changes: 0 additions & 57 deletions unravel/utils/objects/spektral_graph.py

This file was deleted.

0 comments on commit 65281fe

Please sign in to comment.