Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make consistent behavior when denom is 0 and add docstrings #206

Draft
wants to merge 2 commits into
base: lnk-metric
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 51 additions & 7 deletions src/traccuracy/metrics/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,19 +109,52 @@ def _compute(self, data: Matched):

return errors

def _get_tra(self, errors, n_nodes, n_edges):
def _get_tra(self, errors: dict[str, int], n_nodes: int, n_edges: int) -> float:
"""Get the TRA score from the error counts and total number of gt nodes and edges

Args:
errors (dict[str, int]): A dictionary containing the AOGM
n_nodes (int): The number of nodes in the ground truth graph
n_edges (int): The number of edges in the ground truth graph

Returns:
float: the TRA score, computed with the CTC metric weights, or np.nan if
the AOGM_0 is 0
"""
aogm_0 = n_nodes * self.v_weights["fn"] + n_edges * self.e_weights["fn"]
if aogm_0 == 0:
raise RuntimeError(
f"AOGM0 is 0 - cannot compute TRA from GT graph with {n_nodes} nodes and"
+ f" {n_edges} edges with {self.v_weights['fn']} vertex FN weight and"
+ f" {self.e_weights['fn']} edge FN weight"
warnings.warn(
UserWarning(
f"AOGM0 is 0 - cannot compute TRA from GT graph with {n_nodes} nodes and"
+ f" {n_edges} edges with {self.v_weights['fn']} vertex FN weight and"
+ f" {self.e_weights['fn']} edge FN weight"
),
stacklevel=1,
)
return np.nan
aogm = errors["AOGM"]
tra = 1 - min(aogm, aogm_0) / aogm_0
return tra

def _get_det(self, errors, n_nodes):
def _get_det(self, errors: dict[str, int], n_nodes: int) -> float:
"""Get the DET score from the error counts and total number of gt nodes

Args:
errors (dict[str, int]): A dictionary containing the counts
of each type of node error (fp_nodes, fn_nodes, ns_nodes)
n_nodes (int): The number of nodes in the ground truth graph

Returns:
float: the DET score, computed with the CTC metric weights, or np.nan
if there are no nodes in the gt graph
"""
if n_nodes == 0:
warnings.warn(
UserWarning("No nodes in the GT graph, cannot compute DET."),
stacklevel=1,
)
return np.nan

aogmd_0 = n_nodes * self.v_weights["fn"]
aogmd = get_weighted_vertex_error_sum(
{
Expand All @@ -136,7 +169,18 @@ def _get_det(self, errors, n_nodes):
det = 1 - min(aogmd, aogmd_0) / aogmd_0
return det

def _get_lnk(self, errors, n_edges):
def _get_lnk(self, errors: dict[str, int], n_edges: int):
"""Get the DET score from the error counts and total number of gt edges

Args:
errors (dict[str, int]): A dictionary containing the counts
of each type of edge error (fp_edges, fn_edges, ws_edges)
n_edges (int): The number of edges in the ground truth graph

Returns:
float: the TRA score, computed with the CTC metric weights, or np.nan if
there are no edges in the GT graph
"""
if n_edges == 0:
warnings.warn(
UserWarning("No edges in the GT graph, cannot compute LNK."),
Expand Down
Loading