Skip to content

Commit

Permalink
Merge pull request #6 from UnravelSports/feat/kloppy-polars
Browse files Browse the repository at this point in the history
⚽ Polars implementation
  • Loading branch information
UnravelSports authored Jan 27, 2025
2 parents 663a024 + 77ab8c2 commit 7c8fc7e
Show file tree
Hide file tree
Showing 31 changed files with 2,934 additions and 323 deletions.
142 changes: 94 additions & 48 deletions examples/1_kloppy_gnn_train.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/2_big_data_bowl_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.11.11"
}
},
"nbformat": 4,
Expand Down
794 changes: 794 additions & 0 deletions examples/deprecated/1_kloppy_gnn_train.ipynb

Large diffs are not rendered by default.

6 changes: 1 addition & 5 deletions examples/graphs_faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ In section 6.1 we can see what this looks like in Python.
| `max_ball_acceleration` | float | The maximum acceleration of the ball in yards per second squared. Used for normalizing node features. | 10.0 | 🏈 |
| `attacking_non_qb_node_value` | float | Value for the node feature when player is NOT the QB, but is on the attacking team | 0.1 | 🏈 |
| `chunk_size` | int | Set to determine size of conversions from Polars to Graphs. Preferred setting depends on available computing power | 2_000 | 🏈 |
| `ball_carrier_threshold` | float | The distance threshold to determine the ball carrier in meters. If no ball carrier within ball_carrier_threshold, we skip the frame. | 25.0 ||
| `boundary_correction` | float | A correction factor for boundary calculations, used to correct out of bounds as a percentage (Used as 1+boundary_correction, i.e., 0.05). Not setting this might lead to players outside the pitch markings to have values that fall slightly outside of our normalization range. When we set boundary_correction, any players outside the pitch will be moved to be on the closest line. | None ||
| `infer_ball_ownership` | bool | Infers 'attacking_team' if no 'ball_owning_team' exist (in Kloppy TrackingDataset) by finding the player closest to the ball using ball xyz, uses 'ball_carrier_threshold' as a cut-off. | True ||
| `infer_goalkeepers` | bool | Set True if no GK label is provided, set False for incomplete (broadcast tracking) data that might not have a GK in every frame. | True ||
| `non_potential_receiver_node_value` | float | Value for the node feature when player is NOT a potential receiver of a pass (when on opposing team or in possession of the ball). Should be between 0 and 1, inclusive. | 0.1 ||


Expand All @@ -64,7 +60,7 @@ In section 6.1 we can see what this looks like in Python.
#### C. What features does each Graph have?

<details>
<summary> <b><i> 🌀 ⚽ Expand for a full list of Soccer features </b></i></summary>
<summary> <b><i> 🌀 ⚽ Expand for a full list of Soccer features (note: `SoccerGraphConverter` and `SoccerGraphConverterPolars` have slightly different features) </i></b></summary>

| Variable | Datatype | Index | Features |
|----------|-----------------------------------|-------|---------------------------------------------------------------------------------------------------------------------------------|
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
numpy==1.26.4
spektral==1.2.0
kloppy==3.15.0
kloppy==3.16.0
tensorflow>=2.14.0; platform_machine != 'arm64' or platform_system != 'Darwin'
tensorflow-macos>=2.14.0; platform_machine == 'arm64' and platform_system == 'Darwin'
keras==2.14.0
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def read_version():
python_requires="~=3.11",
install_requires=[
"spektral==1.2.0",
"kloppy==3.15.0",
"kloppy==3.16.0",
"tensorflow>=2.14.0;platform_machine != 'arm64' or platform_system != 'Darwin'",
"tensorflow-macos>=2.14.0;platform_machine == 'arm64' and platform_system == 'Darwin'",
"keras==2.14.0",
Expand Down
17 changes: 5 additions & 12 deletions tests/test_bigdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
AmericanFootballGraphConverter,
AmericanFootballPitchDimensions,
)
from unravel.american_football.graphs.dataset import Constant
from unravel.utils import (
add_graph_id_column,
add_dummy_label_column,
flatten_to_reshaped_array,
make_sparse,
CustomSpektralDataset,
Expand Down Expand Up @@ -53,10 +52,8 @@ def dataset(self, coordinates: str, players: str, plays: str):
plays_file_path=plays,
)
bdb_dataset.load()
bdb_dataset.add_graph_ids(by=["gameId", "playId"], column_name="graph_id")
bdb_dataset.add_dummy_labels(
by=["gameId", "playId", "frameId"], column_name="label"
)
bdb_dataset.add_graph_ids(by=["gameId", "playId"])
bdb_dataset.add_dummy_labels(by=["gameId", "playId", "frameId"])
return bdb_dataset

@pytest.fixture
Expand Down Expand Up @@ -141,8 +138,6 @@ def node_feature_values(self):
@pytest.fixture
def arguments(self):
return dict(
label_col="label",
graph_id_col="graph_id",
max_player_speed=8.0,
max_ball_speed=28.0,
max_player_acceleration=10.0,
Expand All @@ -161,8 +156,6 @@ def arguments(self):
@pytest.fixture
def non_default_arguments(self):
return dict(
label_col="label",
graph_id_col="graph_id",
max_player_speed=12.0,
max_ball_speed=24.0,
max_player_acceleration=11.0,
Expand Down Expand Up @@ -199,8 +192,8 @@ def test_settings(self, gnnc_non_default, non_default_arguments):
assert settings.pitch_dimensions.y_dim.min == -26.65
assert settings.pitch_dimensions.end_zone == 50.0

assert settings.ball_id == "football"
assert settings.qb_id == "QB"
assert Constant.BALL == "football"
assert Constant.QB == "QB"
assert settings.max_height == 225.0
assert settings.min_height == 150.0
assert settings.max_weight == 200.0
Expand Down
188 changes: 188 additions & 0 deletions tests/test_kloppy_polars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from pathlib import Path
from unravel.soccer import SoccerGraphConverterPolars, KloppyPolarsDataset
from unravel.utils import (
dummy_labels,
dummy_graph_ids,
CustomSpektralDataset,
)

from kloppy import skillcorner
from kloppy.domain import Ground, TrackingDataset, Orientation
from typing import List, Dict

from spektral.data import Graph

import pytest

import numpy as np


class TestKloppyPolarsData:
    """Tests for turning a Kloppy tracking dataset into Spektral graphs.

    The pipeline under test is ``skillcorner.load`` -> ``KloppyPolarsDataset``
    -> ``SoccerGraphConverterPolars.to_spektral_graphs``.

    NOTE(review): the ``base_dir`` fixture is not defined in this class; it is
    presumably supplied by a shared ``conftest.py`` -- confirm.
    """

    @pytest.fixture
    def match_data(self, base_dir: Path) -> Path:
        """Return the path of the SkillCorner match metadata file."""
        return base_dir / "files" / "skillcorner_match_data.json"

    @pytest.fixture
    def structured_data(self, base_dir: Path) -> Path:
        """Return the path of the gzipped SkillCorner structured data file."""
        return base_dir / "files" / "skillcorner_structured_data.json.gz"

    @pytest.fixture()
    def kloppy_dataset(
        self, match_data: Path, structured_data: Path
    ) -> TrackingDataset:
        """Load the SkillCorner files through kloppy, limited to 500 frames."""
        return skillcorner.load(
            raw_data=structured_data,
            meta_data=match_data,
            coordinates="tracab",
            include_empty_frames=False,
            limit=500,
        )

    @pytest.fixture()
    def kloppy_polars_dataset(
        self, kloppy_dataset: TrackingDataset
    ) -> KloppyPolarsDataset:
        """Wrap the kloppy dataset and add dummy labels and graph ids.

        Both the labels and the graph ids are keyed on ``game_id`` and
        ``frame_id``.
        """
        dataset = KloppyPolarsDataset(
            kloppy_dataset=kloppy_dataset,
            ball_carrier_threshold=25.0,
        )
        dataset.load()
        dataset.add_dummy_labels(by=["game_id", "frame_id"])
        dataset.add_graph_ids(by=["game_id", "frame_id"])
        return dataset

    @staticmethod
    def _build_converter(
        dataset: KloppyPolarsDataset, *, pad: bool
    ) -> SoccerGraphConverterPolars:
        """Build a converter with the shared test settings.

        The two converter fixtures differ only in ``pad``; every other
        setting lives here so they cannot drift apart.
        """
        return SoccerGraphConverterPolars(
            dataset=dataset,
            chunk_size=20_000,
            non_potential_receiver_node_value=0.1,
            max_player_speed=12.0,
            max_player_acceleration=12.0,
            max_ball_speed=13.5,
            max_ball_acceleration=100,
            self_loop_ball=True,
            adjacency_matrix_connect_type="ball",
            adjacency_matrix_type="split_by_team",
            label_type="binary",
            defending_team_node_value=0.0,
            random_seed=False,
            pad=pad,
            verbose=False,
        )

    @pytest.fixture()
    def spc_padding(
        self, kloppy_polars_dataset: KloppyPolarsDataset
    ) -> SoccerGraphConverterPolars:
        """Converter configured with ``pad=True``."""
        return self._build_converter(kloppy_polars_dataset, pad=True)

    @pytest.fixture()
    def soccer_polars_converter(
        self, kloppy_polars_dataset: KloppyPolarsDataset
    ) -> SoccerGraphConverterPolars:
        """Converter configured with ``pad=False``."""
        return self._build_converter(kloppy_polars_dataset, pad=False)

    def test_padding(self, spc_padding: SoccerGraphConverterPolars):
        """Check the number and type of graphs produced with ``pad=True``."""
        graphs = spc_padding.to_spektral_graphs()

        assert len(graphs) == 384
        assert isinstance(graphs[0], Graph)

    def test_to_spektral_graph(
        self, soccer_polars_converter: SoccerGraphConverterPolars
    ):
        """Check graph contents and dataset splitting with ``pad=False``.

        Covers the first graph's id, the shapes and sampled values of its
        node features (``x``), edge features (``e``) and adjacency matrix
        (``a``), the dimensions reported by ``CustomSpektralDataset``, and
        the sizes of the train/test/validation splits.
        """
        graphs = soccer_polars_converter.to_spektral_graphs()

        assert graphs[0].id == "2417-1529"
        assert len(graphs) == 489
        assert isinstance(graphs[0], Graph)

        # Node features: one row per node, 15 features each.
        x = graphs[0].x
        n_players = x.shape[0]
        assert x.shape == (n_players, 15)
        assert 0.4524340998288571 == pytest.approx(x[0, 0], abs=1e-5)
        assert 0.9948105277764999 == pytest.approx(x[0, 4], abs=1e-5)
        assert 0.2941671698429814 == pytest.approx(x[8, 2], abs=1e-5)

        # Edge features: one row per edge, 6 features each.
        e = graphs[0].e
        assert e.shape == (129, 6)
        assert 0.0 == pytest.approx(e[0, 0], abs=1e-5)
        assert 0.5 == pytest.approx(e[0, 4], abs=1e-5)
        assert 0.7140882876637022 == pytest.approx(e[8, 2], abs=1e-5)

        # Adjacency matrix: square, sized by the number of nodes.
        a = graphs[0].a
        assert a.shape == (n_players, n_players)
        assert 1.0 == pytest.approx(a[0, 0], abs=1e-5)
        assert 1.0 == pytest.approx(a[0, 4], abs=1e-5)
        assert 0.0 == pytest.approx(a[8, 2], abs=1e-5)

        dataset = CustomSpektralDataset(graphs=graphs)
        N, F, S, n_out, n = dataset.dimensions()
        assert N == 20
        assert F == 15
        assert S == 6
        assert n_out == 1
        assert n == 489

        # The 4:1:1 split yields the same sizes with and without by_graph_id.
        for by_graph_id in (True, False):
            train, test, val = dataset.split_test_train_validation(
                split_train=4,
                split_test=1,
                split_validation=1,
                by_graph_id=by_graph_id,
                random_seed=42,
            )
            assert train.n_graphs == 326
            assert test.n_graphs == 81
            assert val.n_graphs == 82

        train, test = dataset.split_test_train(
            split_train=4, split_test=1, by_graph_id=False, random_seed=42
        )
        assert train.n_graphs == 391
        assert test.n_graphs == 98

        # A test share larger than the train share works without by_graph_id...
        train, test = dataset.split_test_train(
            split_train=4, split_test=5, by_graph_id=False, random_seed=42
        )
        assert train.n_graphs == 217
        assert test.n_graphs == 272

        # ...but is expected to raise when splitting by graph id.
        with pytest.raises(
            NotImplementedError,
            match="Make sure split_train > split_test >= split_validation, other behaviour is not supported when by_graph_id is True...",
        ):
            dataset.split_test_train(
                split_train=4, split_test=5, by_graph_id=True, random_seed=42
            )
8 changes: 2 additions & 6 deletions tests/test_spektral.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,8 @@ def bdb_dataset(self, coordinates: str, players: str, plays: str):
plays_file_path=plays,
)
bdb_dataset.load()
bdb_dataset.add_graph_ids(by=["gameId", "playId"], column_name="graph_id")
bdb_dataset.add_dummy_labels(
by=["gameId", "playId", "frameId"], column_name="label"
)
bdb_dataset.add_graph_ids(by=["gameId", "playId"])
bdb_dataset.add_dummy_labels(by=["gameId", "playId", "frameId"])
return bdb_dataset

@pytest.fixture
Expand Down Expand Up @@ -122,8 +120,6 @@ def bdb_converter(
) -> AmericanFootballGraphConverter:
return AmericanFootballGraphConverter(
dataset=bdb_dataset,
label_col="label",
graph_id_col="graph_id",
max_player_speed=8.0,
max_ball_speed=28.0,
max_player_acceleration=10.0,
Expand Down
2 changes: 1 addition & 1 deletion unravel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.2.0"
__version__ = "0.3.0"

from .soccer import *
from .american_football import *
Expand Down
Loading

0 comments on commit 7c8fc7e

Please sign in to comment.