Skip to content

Commit

Permalink
Fix Zeno visualizer on tasks like GSM8k (#2599)
Browse files Browse the repository at this point in the history
* fix(zeno): Generate unique ids in case of multiple filters

* fix(zeno): Report even non-aggregable metrics, just not as metrics

* pre-commit

---------

Co-authored-by: Baber <[email protected]>
  • Loading branch information
pasky and baberabb authored Jan 7, 2025
1 parent 16cfe46 commit 6d62a69
Showing 1 changed file with 25 additions and 13 deletions.
38 changes: 25 additions & 13 deletions scripts/zeno_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,14 @@ def main():
if model_index == 0: # Only need to assemble data for the first model
metrics = []
for metric in config["metric_list"]:
metrics.append(
ZenoMetric(
name=metric["metric"],
type="mean",
columns=[metric["metric"]],
if metric.get("aggregation") == "mean":
metrics.append(
ZenoMetric(
name=metric["metric"],
type="mean",
columns=[metric["metric"]],
)
)
)
project = client.create_project(
name=args.project_name + (f"_{task}" if len(tasks) > 1 else ""),
view="text-classification",
Expand Down Expand Up @@ -168,7 +169,11 @@ def generate_dataset(
Returns:
pd.Dataframe: A dataframe that is ready to be uploaded to Zeno.
"""
ids = [x["doc_id"] for x in data]
ids = (
[x["doc_id"] for x in data]
if not config.get("filter_list")
else [f"{x['doc_id']}.{x['filter']}" for x in data]
)
labels = [x["target"] for x in data]
instance = [""] * len(ids)

Expand All @@ -190,6 +195,7 @@ def generate_dataset(
return pd.DataFrame(
{
"id": ids,
"doc_id": [x["doc_id"] for x in data],
"data": instance,
"input_len": [len(x) for x in instance],
"labels": labels,
Expand All @@ -208,8 +214,15 @@ def generate_system_df(data, config):
Returns:
pd.Dataframe: A dataframe that is ready to be uploaded to Zeno as a system.
"""
ids = [x["doc_id"] for x in data]
ids = (
[x["doc_id"] for x in data]
if not config.get("filter_list")
else [f"{x['doc_id']}.{x['filter']}" for x in data]
)
system_dict = {"id": ids}
system_dict["doc_id"] = [x["doc_id"] for x in data]
if config.get("filter_list"):
system_dict["filter"] = [x["filter"] for x in data]
system_dict["output"] = [""] * len(ids)

if config["output_type"] == "loglikelihood":
Expand All @@ -228,11 +241,10 @@ def generate_system_df(data, config):
system_dict["output"] = [str(x["filtered_resps"][0]) for x in data]
system_dict["output_length"] = [len(str(x["filtered_resps"][0])) for x in data]

metrics = {}
for metric in config["metric_list"]:
if "aggregation" in metric and metric["aggregation"] == "mean":
metrics[metric["metric"]] = [x[metric["metric"]] for x in data]

metrics = {
metric["metric"]: [x[metric["metric"]] for x in data]
for metric in config["metric_list"]
}
system_dict.update(metrics)
system_df = pd.DataFrame(system_dict)
return system_df
Expand Down

0 comments on commit 6d62a69

Please sign in to comment.