Skip to content

Commit

Permalink
Sorting the collection processing order to always handle the most …
Browse files Browse the repository at this point in the history
…nested collections first; this allows more deeply nested objects to exist
  • Loading branch information
yimuchen committed Feb 6, 2025
1 parent 7c551ea commit 1fd75d2
Showing 1 changed file with 24 additions and 47 deletions.
71 changes: 24 additions & 47 deletions src/coffea/nanoevents/schemas/treemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,14 @@ class TreeMakerSchema(BaseSchema):

def __init__(self, base_form, *args, **kwargs):
super().__init__(base_form, *args, **kwargs)
old_style_form = {
k: v for k, v in zip(self._form["fields"], self._form["contents"])
}
old_style_form = {k: v for k, v in zip(self._form["fields"], self._form["contents"])}
output = self._build_collections(old_style_form)
self._form["fields"] = [k for k in output.keys()]
self._form["contents"] = [v for v in output.values()]

def _build_collections(self, branch_forms):
# Turn any special classes into the appropriate awkward form
composite_objects = list(
{k.split("/")[0].rstrip("_") for k in branch_forms if "/" in k}
)
composite_objects = list({k.split("/")[0].rstrip("_") for k in branch_forms if "/" in k})

composite_behavior = { # Dictionary for overriding the default behavior
"Tracks": "LorentzVector"
Expand All @@ -60,7 +56,8 @@ def _build_collections(self, branch_forms):
components = { # Extracting the various composite object names
k.split(".")[-1]: k
for k in branch_forms
if k.startswith(objname + "/") or
if k.startswith(objname + "/")
or
# Second case for skimming
k.startswith(objname + "_/")
}
Expand Down Expand Up @@ -98,18 +95,20 @@ def _build_collections(self, branch_forms):
)
branch_forms[objname] = form
else:
raise ValueError(
f"Unrecognized class with split branches of object {objname}: {components.values()}"
)
raise ValueError(f"Unrecognized class with split branches of object {objname}: {components.values()}")

# Generating collection from branch name
collections = [k for k in branch_forms if "_" in k and not k.startswith("n")]
collections = {
"_".join(k.split("_")[:-1])
for k in collections
if k.split("_")[-1] != "AK8"
# Excluding per-event variables with AK8 variants like Mjj and MT
}
collections = sorted(
{
"_".join(k.split("_")[:-1])
for k in collections
if k.split("_")[-1] != "AK8"
# Excluding per-event variables with AK8 variants like Mjj and MT
},
key=lambda colname: colname.count("_"),
reverse=True,
)

subcollections = []

Expand All @@ -122,9 +121,7 @@ def _build_collections(self, branch_forms):
countitems = [x for x in items if x.endswith("Counts")]
subcols = {x[:-6] for x in countitems} # List of subcollection names
for subcol in subcols:
items = [
k for k in items if not k.startswith(subcol) or k.endswith("Counts")
]
items = [k for k in items if not k.startswith(subcol) or k.endswith("Counts")]
subname = subcol[len(cname) + 1 :]
subcollections.append(
{
Expand All @@ -136,23 +133,17 @@ def _build_collections(self, branch_forms):
)

if cname not in branch_forms:
collection = zip_forms(
{k[len(cname) + 1]: branch_forms.pop(k) for k in items}, cname
)
collection = zip_forms({k[len(cname) + 1]: branch_forms.pop(k) for k in items}, cname)
branch_forms[cname] = collection
else:
collection = branch_forms[cname]
if not collection["class"].startswith("ListOffsetArray"):
print(collection["class"])
raise NotImplementedError(
f"{cname} isn't a jagged array, not sure what to do"
)
raise NotImplementedError(f"{cname} isn't a jagged array, not sure what to do")
for item in items:
Itemname = item[len(cname) + 1 :]
collection["content"]["fields"].append(Itemname)
collection["content"]["contents"].append(
branch_forms.pop(item)["content"]
)
collection["content"]["contents"].append(branch_forms.pop(item)["content"])

for sub in subcollections:
nest_jagged_forms(
Expand Down Expand Up @@ -201,9 +192,7 @@ def _is_compat(a):
if isinstance(t, ak.types.ArrayType):
if isinstance(t._content, ak.types.NumpyType):
return True
if isinstance(t._content, ak.types.ListType) and isinstance(
t._content._content, ak.types.NumpyType
):
if isinstance(t._content, ak.types.ListType) and isinstance(t._content._content, ak.types.NumpyType):
return True
return False

Expand All @@ -221,33 +210,21 @@ def zip_composite(array):
"y": "/.fY",
"z": "/.fZ",
}
return ak.zip(
{
_rename_lookup.get(n, n): _make_packed(array[n])
for n in array.fields
if _is_compat(array[n])
}
)
return ak.zip({_rename_lookup.get(n, n): _make_packed(array[n]) for n in array.fields if _is_compat(array[n])})

# Looping over events structure
out = {}
for bname in events.fields:
if events[bname].fields:
sub_collection = [ # Handing nested structures first
x.replace("Counts", "")
for x in events[bname].fields
if x.endswith("Counts")
x.replace("Counts", "") for x in events[bname].fields if x.endswith("Counts")
]
if sub_collection:
for subname in sub_collection:
if events[bname][subname].fields:
out[f"{bname}_{subname}"] = zip_composite(
ak.flatten(events[bname][subname], axis=-1)
)
out[f"{bname}_{subname}"] = zip_composite(ak.flatten(events[bname][subname], axis=-1))
else:
out[f"{bname}_{subname}"] = _make_packed(
ak.flatten(events[bname][subname], axis=-1)
)
out[f"{bname}_{subname}"] = _make_packed(ak.flatten(events[bname][subname], axis=-1))
out[bname] = zip_composite(events[bname])
else:
out[bname] = _make_packed(events[bname])
Expand Down

0 comments on commit 1fd75d2

Please sign in to comment.