Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
UnravelSports [JB] committed Jul 19, 2024
1 parent 9a52d63 commit a282086
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 11 deletions.
13 changes: 3 additions & 10 deletions examples/0_getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,10 @@
")\n",
"\n",
"# Initialize the Graph Converter, with dataset, labels and settings\n",
"converter = GraphConverter(\n",
" dataset=kloppy_dataset,\n",
" labels=dummy_labels(kloppy_dataset)\n",
")\n",
"converter = GraphConverter(dataset=kloppy_dataset, labels=dummy_labels(kloppy_dataset))\n",
"\n",
"# Compute the graphs and add them to the CustomSpektralDataset\n",
"dataset = CustomSpektralDataset(\n",
" data=converter.to_spektral_graphs()\n",
")"
"dataset = CustomSpektralDataset(data=converter.to_spektral_graphs())"
]
},
{
Expand All @@ -91,9 +86,7 @@
"\n",
"N, F, S, n_out, n = dataset.dimensions()\n",
"\n",
"train, test = dataset.split_test_train(\n",
" split_train=4, split_test=1, random_seed=42\n",
")\n",
"train, test = dataset.split_test_train(split_train=4, split_test=1, random_seed=42)\n",
"\n",
"loader_tr = DisjointLoader(train, batch_size=16, epochs=150)\n",
"loader_te = DisjointLoader(test, batch_size=16, epochs=1, shuffle=False)"
Expand Down
1 change: 1 addition & 0 deletions examples/1_tutorial_graph_converter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@
"source": [
"import pickle\n",
"\n",
"\n",
"def load_pickle(file_path):\n",
" with open(file_path, \"rb\") as file:\n",
" # Deserialize the object from the file\n",
Expand Down
4 changes: 3 additions & 1 deletion unravel/soccer/graphs/objects/custom_spektral_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ def split_test_train_validation(
num_test = dataset_length - num_train
num_validation = 0

unique_graph_ids = set([g.get("id") if hasattr(g, "id") else None for g in self])
unique_graph_ids = set(
[g.get("id") if hasattr(g, "id") else None for g in self]
)
if unique_graph_ids == {None}:
by_graph_id = False

Expand Down

0 comments on commit a282086

Please sign in to comment.