Skip to content

Commit

Permalink
Merge branch 'main' into lnk-metric
Browse files Browse the repository at this point in the history
  • Loading branch information
DragaDoncila authored Feb 12, 2025
2 parents b0ed0c1 + 31be450 commit 8759618
Show file tree
Hide file tree
Showing 11 changed files with 211 additions and 45 deletions.
39 changes: 38 additions & 1 deletion src/traccuracy/matchers/_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Hashable
Expand All @@ -18,6 +19,9 @@ class Matcher(ABC):
on a particular dataset
"""

# Set explicitly only if the matching type is guaranteed by the matcher
_matching_type = None

def compute_mapping(
self, gt_graph: TrackingGraph, pred_graph: TrackingGraph
) -> Matched:
Expand Down Expand Up @@ -71,7 +75,10 @@ def _compute_mapping(
@property
def info(self):
    """Return a dict describing this matcher.

    Contains the class name, every constructor parameter stored on the
    instance, and — only when the matcher guarantees it — the matching type.
    """
    details = {"name": type(self).__name__}
    details.update(self.__dict__)
    # Only advertise a matching type when one is explicitly guaranteed.
    if self._matching_type:
        details["matching type"] = self._matching_type
    return details


class Matched:
Expand Down Expand Up @@ -111,6 +118,36 @@ def __init__(
self.gt_pred_map = dict(gt_pred_map)
self.pred_gt_map = dict(pred_gt_map)

self._matching_type = self.matcher_info.get("matching type")

@property
def matching_type(self):
    """The gt-to-pred matching type of this mapping.

    One of "one-to-one", "one-to-many", "many-to-one" or "many-to-many".
    Taken from the matcher info when the matcher guaranteed a type,
    otherwise derived from the mapping on first access and cached.
    """
    # An empty mapping carries no evidence either way; warn that we fall
    # back to the most restrictive type. (Warn before the cache check so
    # the caller is always alerted when the mapping is empty.)
    if not self.mapping:
        warnings.warn(
            "Mapping is empty. Defaulting to type of one-to-one", stacklevel=2
        )

    if self._matching_type is not None:
        return self._matching_type

    # "many" on the pred side means some gt node maps to multiple pred nodes,
    # and symmetrically for the gt side.
    pred_side = "many" if any(len(m) > 1 for m in self.gt_pred_map.values()) else "one"
    gt_side = "many" if any(len(m) > 1 for m in self.pred_gt_map.values()) else "one"

    self._matching_type = f"{gt_side}-to-{pred_side}"
    # Record the derived type on matcher_info so downstream consumers see it.
    self.matcher_info["matching type"] = self._matching_type
    return self._matching_type

def _get_match(self, node: Hashable, map: dict[Hashable, list]):
if node in map:
match = map[node]
Expand Down
3 changes: 3 additions & 0 deletions src/traccuracy/matchers/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class CTCMatcher(Matcher):
for complete details.
"""

# CTC can return many-to-one or one-to-one
_matching_type = None

def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph):
"""Run ctc matching
Expand Down
4 changes: 4 additions & 0 deletions src/traccuracy/matchers/_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ def __init__(self, iou_threshold=0.6, one_to_one=False):
self.iou_threshold = iou_threshold
self.one_to_one = one_to_one

# If either condition is met, matching must be one to one
if one_to_one or iou_threshold > 0.5:
self._matching_type = "one-to-one"

def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph):
"""Computes IOU mapping for a set of grpahs
Expand Down
25 changes: 21 additions & 4 deletions src/traccuracy/metrics/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,36 @@
from traccuracy.matchers._base import Matched


MATCHING_TYPES = ["one-to-one", "one-to-many", "many-to-one", "many-to-many"]


class Metric(ABC):
"""The base class for Metrics
Data should be passed directly into the compute method
Kwargs should be specified in the constructor
"""

@abstractmethod
def __init__(self, valid_matches: list):
    """Record the matching types this metric supports.

    Subclasses must call this from their own ``__init__``.

    Args:
        valid_matches (list): Matching types the metric can be computed on.
            Each entry must be one of MATCHING_TYPES.

    Raises:
        TypeError: If ``valid_matches`` is empty.
        ValueError: If any entry is not a recognized matching type.
    """
    # Check that we have gotten a non-empty list of valid match types
    if not valid_matches:
        raise TypeError("New metrics must provide a list of valid matching types")

    for mtype in valid_matches:
        if mtype not in MATCHING_TYPES:
            # BUG FIX: the second fragment was missing its f-prefix, so
            # "{MATCHING_TYPES}" was emitted literally instead of the list.
            raise ValueError(
                f"Matching type {mtype} is not supported. "
                f"Choose from {MATCHING_TYPES}."
            )

    self.valid_match_types = valid_matches

def _validate_matcher(self, matched: Matched) -> bool:
    """Return True when the matching type of ``matched`` is supported by
    this metric, False otherwise.

    Raises:
        AttributeError: If the subclass never called ``Metric.__init__``
            and therefore has no ``valid_match_types`` attribute.
    """
    supported = getattr(self, "valid_match_types", None)
    if supported is None:
        raise AttributeError("Metric subclass does not define valid_match_types")
    return matched.matching_type in supported

@abstractmethod
def _compute(self, matched: Matched) -> dict:
Expand Down Expand Up @@ -75,7 +92,7 @@ def compute(self, matched: Matched, override_matcher: bool = False) -> Results:

@property
def info(self):
"""Dictionary with Matcher name and any parameters"""
"""Dictionary with Metric name and any parameters"""
return {"name": self.__class__.__name__, **self.__dict__}


Expand Down
9 changes: 3 additions & 6 deletions src/traccuracy/metrics/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def __init__(
edge_fn_weight=1,
edge_ws_weight=1,
):
valid_matching_types = ["one-to-one", "many-to-one"]
super().__init__(valid_matching_types)

self.v_weights = {
"ns": vertex_ns_weight,
"fp": vertex_fp_weight,
Expand All @@ -37,12 +40,6 @@ def __init__(
"ws": edge_ws_weight,
}

def _validate_matcher(self, matched: Matched) -> bool:
valid_matchers = {"IOUMatcher", "CTCMatcher"}
name = matched.matcher_info["name"]

return name in valid_matchers

def _compute(self, data: Matched):
evaluate_ctc_events(data)

Expand Down
17 changes: 3 additions & 14 deletions src/traccuracy/metrics/_divisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,21 +70,10 @@ class DivisionMetrics(Metric):
"""

def __init__(self, max_frame_buffer=0):
self.frame_buffer = max_frame_buffer

def _validate_matcher(self, matched: Matched) -> bool:
"Matcher must be one to one"
name = matched.matcher_info["name"]
valid = False
valid_matching_types = ["one-to-one"]
super().__init__(valid_matching_types)

if name == "IOUMatcher":
if matched.matcher_info["one_to_one"]:
valid = True
# Threshold of greater than 0.5 ensures one to one
if matched.matcher_info["iou_threshold"] > 0.5:
valid = True

return valid
self.frame_buffer = max_frame_buffer

def _compute(self, data: Matched):
"""Runs `_evaluate_division_events` and calculates summary metrics for each frame buffer
Expand Down
9 changes: 2 additions & 7 deletions src/traccuracy/metrics/_track_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,10 @@ class TrackOverlapMetrics(Metric):
"""

def __init__(self, include_division_edges: bool = True):
    """Initialize track-overlap metrics.

    Args:
        include_division_edges (bool): Whether division edges are included
            when computing tracklets. Defaults to True.
    """
    # Register the matching types this metric accepts with the Metric base.
    super().__init__(["many-to-one", "one-to-one"])
    self.include_division_edges = include_division_edges

def _validate_matcher(self, matched: Matched) -> bool:
"""Supports many to one matching"""
valid_matchers = {"IOUMatcher", "CTCMatcher"}
name = matched.matcher_info["name"]

return name in valid_matchers

def _compute(self, matched: Matched) -> dict:
gt_tracklets = matched.gt_graph.get_tracklets(
include_division_edges=self.include_division_edges
Expand Down
46 changes: 46 additions & 0 deletions tests/matchers/test_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import networkx as nx
import pytest

import tests.examples.graphs as ex_graphs
Expand All @@ -18,6 +19,10 @@ def __init__(self, param="default"):
self.param = param


class DummyMatcherWithType(DummyMatcher):
    # Declares a guaranteed matching type so that Matched picks it up from
    # the matcher info instead of deriving it from the mapping.
    _matching_type = "one-to-one"


def test_matched_info():
matcher = DummyMatcher()
# Check that matcher info is correctly generated
Expand Down Expand Up @@ -67,3 +72,44 @@ def test_get_matches(self):
assert matched.get_pred_gt_matches(pred) == gt
assert matched.get_gt_pred_matches(gt[0]) == [pred]
assert matched.get_gt_pred_matches(gt[1]) == [pred]

def test_matching_type_cache(self):
    """The matching type is derived lazily and then cached."""
    matched = ex_graphs.good_matched()
    # Nothing computed yet.
    assert matched._matching_type is None
    # First access derives the type; subsequent reads hit the cache.
    assert matched.matching_type == "one-to-one"
    assert matched._matching_type == "one-to-one"

def test_matching_type_from_info(self):
    """A matcher that guarantees a matching type propagates it to Matched."""
    empty = TrackingGraph(nx.DiGraph())
    result = DummyMatcherWithType().compute_mapping(empty, empty)
    assert result._matching_type == "one-to-one"

def test_matching_type(self):
    """Matching type is derived correctly from the gt/pred mapping."""
    graph = TrackingGraph(nx.DiGraph())

    # Test empty mapping. BUG FIX: warnings.warn does not raise, so the
    # warning must be captured with pytest.warns, not pytest.raises —
    # otherwise the assertion under the context manager never runs.
    matched = Matched(graph, graph, [], {})
    with pytest.warns(
        UserWarning, match="Mapping is empty. Defaulting to type of one-to-one"
    ):
        assert matched.matching_type == "one-to-one"

    # One to one
    matched = ex_graphs.good_matched()
    assert matched.matching_type == "one-to-one"

    # One to many (with more than 2)
    matched = Matched(graph, graph, [(1, 2), (1, 3), (1, 4), (5, 6)], {})
    assert matched.matching_type == "one-to-many"

    # Many to one
    matched = Matched(graph, graph, [(2, 1), (3, 1), (4, 5)], {})
    assert matched.matching_type == "many-to-one"

    # Many to many
    matched = Matched(graph, graph, [(1, 2), (1, 3), (4, 5), (6, 5)], {})
    assert matched.matching_type == "many-to-many"
84 changes: 84 additions & 0 deletions tests/metrics/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import networkx as nx
import pytest

from traccuracy import TrackingGraph
from traccuracy.matchers import Matched
from traccuracy.metrics._base import Metric


class TestMetric:
    """Exercise the validation machinery on the Metric base class."""

    # Shared fixture: an empty Matched; its derived type defaults to one-to-one.
    matched = Matched(TrackingGraph(nx.DiGraph()), TrackingGraph(nx.DiGraph()), [], {})

    def test_missing_attribute(self):
        """_validate_matcher raises if the subclass skipped super().__init__."""

        # Should fail if super init isn't called and subclass init isn't used.
        # Error doesn't occur until metric._validate_matcher is called.
        class DummyMetric(Metric):
            def __init__(self):
                pass

            def _compute(self):
                pass

        metric = DummyMetric()
        with pytest.raises(
            AttributeError, match="Metric subclass does not define valid_match_types"
        ):
            metric._validate_matcher(self.matched)

    def test_empty_list(self):
        """An empty list of valid matching types is rejected at init time."""

        class DummyMetric(Metric):
            def __init__(self):
                super().__init__(valid_matches=[])

            def _compute(self):
                pass

        with pytest.raises(
            TypeError, match="New metrics must provide a list of valid matching types"
        ):
            DummyMetric()

    def test_invalid_option(self):
        """A matching type outside MATCHING_TYPES is rejected at init time."""
        bad_option = "not-valid"

        class DummyMetric(Metric):
            def __init__(self):
                super().__init__(valid_matches=[bad_option])

            def _compute(self):
                pass

        with pytest.raises(ValueError, match=r"Matching type .* is not supported."):
            DummyMetric()

    def test_matcher_override(self):
        """Validation failure raises; override_matcher downgrades it to a warning."""

        class DummyMetric(Metric):
            def __init__(self):
                super().__init__(valid_matches=["one-to-one"])

            def _compute(self):
                return {"success": True}

        graph = TrackingGraph(nx.DiGraph())
        matched = Matched(
            graph, graph, [(1, 2), (1, 3)], {"matching type": "many-to-many"}
        )

        metric = DummyMetric()

        # Fail without override
        message = (
            "The matched data uses a matcher that does not meet the requirements "
            "of the metric. Check the documentation for the metric for more information."
        )
        with pytest.raises(TypeError, match=message):
            metric.compute(matched)

        # Override triggers warning. BUG FIX: warnings.warn does not raise,
        # so the warning must be captured with pytest.warns, not
        # pytest.raises — with raises, the assertions below would be
        # skipped (or the test would fail without warnings-as-errors).
        message = (
            "Overriding matcher/metric validation may result in "
            "unpredictable/incorrect metric results"
        )
        with pytest.warns(UserWarning, match=message):
            results = metric.compute(matched, override_matcher=True)
        assert "success" in results.results
5 changes: 0 additions & 5 deletions tests/metrics/test_divisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,6 @@ def pred_hela():


def test_iou_div_metrics(gt_hela, pred_hela):
# Fail validation if one-to-one not enabled
iou_matched = IOUMatcher(iou_threshold=0.1).compute_mapping(gt_hela, pred_hela)
with pytest.raises(TypeError):
div_results = DivisionMetrics().compute(iou_matched)

iou_matched = IOUMatcher(iou_threshold=0.1, one_to_one=True).compute_mapping(
gt_hela, pred_hela
)
Expand Down
Loading

1 comment on commit 8759618

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Mean (s) BASE 31be450 Mean (s) HEAD 8759618 Percent Change
test_load_gt_ctc_data[2d] 5.60525 5.56292 -0.76
test_load_gt_ctc_data[3d] 15.7445 15.3953 -2.22
test_load_pred_ctc_data[2d] 1.06099 1.0541 -0.65
test_ctc_checks[2d] 0.75511 0.74045 -1.94
test_ctc_checks[3d] 9.89331 10.0416 1.5
test_ctc_matcher[2d] 1.47226 1.58136 7.41
test_ctc_matcher[3d] 16.8566 17.25 2.33
test_ctc_metrics[2d] 0.26094 0.27179 4.16
test_ctc_metrics[3d] 4.16021 2.02662 -51.29
test_iou_matcher[2d] 1.69091 1.74861 3.41
test_iou_matcher[3d] 18.0343 18.3222 1.6
test_iou_div_metrics[2d] 0.0054 0.00584 8.11
test_iou_div_metrics[3d] 0.01603 0.01592 -0.71

Please sign in to comment.