Skip to content

Commit

Permalink
Improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
msschwartz21 committed Feb 10, 2025
1 parent f8cbf39 commit 34e6667
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/traccuracy/track_errors/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _classify_nodes(matched: Matched):
gt_graph = matched.gt_graph

if pred_graph.node_errors and gt_graph.node_errors:
logger.info("Node errors already calculated. Skipping graph annotation")
logger.warning("Node errors already calculated. Skipping graph annotation")
return

# Label as TP if the node is matched
Expand Down Expand Up @@ -72,7 +72,7 @@ def _classify_edges(matched: Matched):
gt_graph = matched.gt_graph

if pred_graph.edge_errors and gt_graph.edge_errors:
logger.info("Edge errors already calculated. Skipping graph annotation")
logger.warning("Edge errors already calculated. Skipping graph annotation")
return

# Node errors are needed for edge annotation
Expand Down
16 changes: 14 additions & 2 deletions tests/track_errors/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_empty_pred(self):
for attrs in matched.gt_graph.nodes.values():
assert NodeFlag.FALSE_NEG in attrs

def test_good_match(self):
def test_good_match(self, caplog):
matched = ex_graphs.good_matched()
_classify_nodes(matched)

Expand All @@ -32,6 +32,12 @@ def test_good_match(self):
for attrs in graph.nodes.values():
assert NodeFlag.TRUE_POS in attrs

# Check that it doesn't run a second time
_classify_nodes(matched)
assert (
"Node errors already calculated. Skipping graph annotation" in caplog.text
)

@pytest.mark.parametrize("t", [0, 1, 2])
def test_fn_node(self, t):
wrong_node = [1, 2, 3][t]
Expand Down Expand Up @@ -130,14 +136,20 @@ def test_empty_pred(self):
for attrs in matched.gt_graph.edges.values():
assert EdgeFlag.FALSE_NEG in attrs

def test_good_match(self):
def test_good_match(self, caplog):
matched = ex_graphs.good_matched()
_classify_edges(matched)

for graph in [matched.gt_graph, matched.pred_graph]:
for attrs in graph.edges.values():
assert EdgeFlag.TRUE_POS in attrs

# Check that it doesn't run a second time
_classify_edges(matched)
assert (
"Edge errors already calculated. Skipping graph annotation" in caplog.text
)

def test_fn_node_end(self):
matched = ex_graphs.fn_node_matched(0)
_classify_edges(matched)
Expand Down

1 comment on commit 34e6667

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Mean (s) BASE 31be450 Mean (s) HEAD 34e6667 Percent Change
test_load_gt_ctc_data[2d] 5.38999 5.465 1.39
test_load_gt_ctc_data[3d] 15.1941 16.3014 7.29
test_load_pred_ctc_data[2d] 1.05216 1.02471 -2.61
test_ctc_checks[2d] 0.73035 0.73815 1.07
test_ctc_checks[3d] 9.63441 9.80326 1.75
test_ctc_matcher[2d] 1.47277 1.48982 1.16
test_ctc_matcher[3d] 16.8753 17.0508 1.04
test_ctc_metrics[2d] 0.25438 0.26742 5.12
test_ctc_metrics[3d] 4.76072 4.18882 -12.01
test_iou_matcher[2d] 1.66632 1.68606 1.18
test_iou_matcher[3d] 17.8547 18.1741 1.79
test_iou_div_metrics[2d] 0.00529 0.00545 3.11
test_iou_div_metrics[3d] 0.01549 0.01556 0.41

Please sign in to comment.