Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 27, 2025
1 parent 2674c81 commit 5ff4d2d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 22 deletions.
27 changes: 10 additions & 17 deletions src/aiida/tools/dumping/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@

from __future__ import annotations

import os
from collections import defaultdict
import itertools as it
import logging
import os
from pathlib import Path

from aiida import orm
Expand All @@ -36,7 +35,7 @@ def __init__(
group: orm.Group | str | None = None,
deduplicate: bool = True,
output_path: str | Path | None = None,
global_log_dict: dict[str, Path] | None = None
global_log_dict: dict[str, Path] | None = None,
):
self.deduplicate = deduplicate

Expand All @@ -60,11 +59,9 @@ def __init__(
self.log_dict = {}

def _should_dump_processes(self) -> bool:

return len([node for node in self.nodes if isinstance(node, orm.ProcessNode)]) > 0

def _get_nodes(self):

# Get all nodes that are in the group
if self.group is not None:
nodes = list(self.group.nodes)
Expand Down Expand Up @@ -97,7 +94,6 @@ def _get_nodes(self):
return nodes

def _get_processes(self):

nodes = self.nodes
workflows = [node for node in nodes if isinstance(node, orm.WorkflowNode)]

Expand All @@ -121,15 +117,14 @@ def _get_processes(self):
self.log_dict = {
'calculations': {},
# dict.fromkeys([c.uuid for c in self.calculations], None),
'workflows': dict.fromkeys([w.uuid for w in workflows], None)
'workflows': dict.fromkeys([w.uuid for w in workflows], None),
}

def _dump_processes(self):

self._get_processes()

if len(self.workflows) + len(self.calculations) == 0:
logger.report("No workflows or calculations to dump in group.")
logger.report('No workflows or calculations to dump in group.')
return

self.output_path.mkdir(exist_ok=True, parents=True)
Expand All @@ -138,14 +133,13 @@ def _dump_processes(self):
self._dump_workflows()

def _dump_calculations(self):

calculations_path = self.output_path / 'calculations'

for calculation in self.calculations:
calculation_dumper = self.process_dumper

calculation_dump_path = (
calculations_path / calculation_dumper._generate_default_dump_path(process_node=calculation, prefix='')
calculation_dump_path = calculations_path / calculation_dumper._generate_default_dump_path(
process_node=calculation, prefix=''
)

if calculation.caller is None:
Expand All @@ -160,16 +154,15 @@ def _dump_workflows(self):
workflow_path.mkdir(exist_ok=True, parents=True)

for workflow in self.workflows:

workflow_dumper = self.process_dumper

workflow_dump_path = (
workflow_path / workflow_dumper._generate_default_dump_path(process_node=workflow, prefix=None)
workflow_dump_path = workflow_path / workflow_dumper._generate_default_dump_path(
process_node=workflow, prefix=None
)

if self.deduplicate and workflow.uuid in self.global_log_dict["workflows"].keys():
if self.deduplicate and workflow.uuid in self.global_log_dict['workflows'].keys():
os.symlink(
src=self.global_log_dict["workflows"][workflow.uuid],
src=self.global_log_dict['workflows'][workflow.uuid],
dst=workflow_dump_path,
)
else:
Expand Down
7 changes: 2 additions & 5 deletions src/aiida/tools/dumping/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
from __future__ import annotations

import logging
import itertools as it

from rich.pretty import pprint
from pathlib import Path

from collections import Counter
from aiida import orm
from aiida.manage.configuration.profile import Profile
from aiida.tools.dumping.base import BaseDumper
Expand Down Expand Up @@ -52,7 +50,7 @@ def __init__(
self.process_dumper: ProcessDumper = process_dumper

# self.log_dict: dict[dict[str, Path]] = {}
self.log_dict= {'calculations': {}, 'workflows': {}}
self.log_dict = {'calculations': {}, 'workflows': {}}

def dump(self):
if not self.groups:
Expand Down Expand Up @@ -111,4 +109,3 @@ def _dump_processes_per_group(self):

pprint(group_dumper.log_dict)
pprint(self.log_dict)

0 comments on commit 5ff4d2d

Please sign in to comment.