Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add weight kwarg to AlchemiscaleClient.action_tasks method #209

Merged
merged 9 commits into from
Dec 15, 2023
18 changes: 18 additions & 0 deletions alchemiscale/interface/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@ def action_tasks(
network_scoped_key,
*,
tasks: List[ScopedKey] = Body(embed=True),
weight: Optional[Union[float, List[float]]] = Body(None, embed=True),
n4js: Neo4jStore = Depends(get_n4js_depends),
token: TokenData = Depends(get_token_data_depends),
) -> List[Union[str, None]]:
Expand All @@ -511,6 +512,23 @@ def action_tasks(
taskhub_sk = n4js.get_taskhub(sk)
actioned_sks = n4js.action_tasks(tasks, taskhub_sk)

try:
if isinstance(weight, float):
n4js.set_task_weights(tasks, taskhub_sk, weight)
elif isinstance(weight, list):
if len(weight) != len(tasks):
detail = "weight (when in a list) must have the same length as tasks"
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=detail,
)

n4js.set_task_weights(
{task: weight for task, weight in zip(tasks, weight)}, taskhub_sk, None
)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))

return [str(sk) if sk is not None else None for sk in actioned_sks]


Expand Down
24 changes: 18 additions & 6 deletions alchemiscale/interface/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,14 +651,16 @@ def get_transformation_status(
return status_counts

def action_tasks(
self, tasks: List[ScopedKey], network: ScopedKey
self,
tasks: List[ScopedKey],
network: ScopedKey,
weight: Optional[Union[float, List[float]]] = None,
) -> List[Optional[ScopedKey]]:
"""Action Tasks for execution via the given AlchemicalNetwork's
TaskHub.

A Task cannot be actioned:
- to an AlchemicalNetwork in a different Scope.
- if it extends another Task that is not complete.
- to an AlchemicalNetwork in a different Scope
- if it has any status other than 'waiting', 'running', or 'error'

Parameters
Expand All @@ -668,16 +670,26 @@ def action_tasks(
network
The AlchemicalNetwork ScopedKey to action the Tasks for.
The Tasks will be added to the network's associated TaskHub.
weight
Weight to be applied to the actioned Tasks. Only values between 0
and 1 are valid weights. Weights can also be provided as a list of
floats with the same length as `tasks`.

Setting `weight` to ``None`` will apply the default weight of 0.5
to newly actioned Tasks, while leaving the weights of any previously
actioned Tasks unchanged. Setting `weight` to anything other than
``None`` will change the weights of previously actioned Tasks
included in `tasks`.

Returns
-------
List[Optional[ScopedKey]]
ScopedKeys for Tasks actioned, in the same order as given as
`tasks` on input. If a Task couldn't be actioned, then ``None`` will
be returned in its place.
`tasks` on input. If a Task couldn't be actioned, then ``None``
will be returned in its place.

"""
data = dict(tasks=[t.dict() for t in tasks])
data = dict(tasks=[t.dict() for t in tasks], weight=weight)
actioned_sks = self._post_resource(f"/networks/{network}/tasks/action", data)

return [ScopedKey.from_str(i) if i is not None else None for i in actioned_sks]
Expand Down
8 changes: 7 additions & 1 deletion alchemiscale/storage/statestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,7 +1120,7 @@ def action_tasks(
AND task.status IN ['waiting', 'running', 'error']

// create the connection
CREATE (th)-[ar:ACTIONS {{weight: 1.0}}]->(task)
CREATE (th)-[ar:ACTIONS {{weight: 0.5}}]->(task)

// set the task property to the scoped key of the Task
// this is a convenience for when we have to loop over relationships in Python
Expand Down Expand Up @@ -1186,6 +1186,9 @@ def set_task_weights(
"Cannot set `weight` to a scalar if `tasks` is a dict"
)

if not all([0 <= weight <= 1 for weight in tasks.values()]):
raise ValueError("weights must be between 0 and 1 (inclusive)")

for t, w in tasks.items():
q = f"""
MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}})-[ar:ACTIONS]->(task:Task {{_scoped_key: '{t}'}})
Expand All @@ -1200,6 +1203,9 @@ def set_task_weights(
"Must set `weight` to a scalar if `tasks` is a list"
)

if not 0 <= weight <= 1:
raise ValueError("weight must be between 0 and 1 (inclusive)")

for t in tasks:
q = f"""
MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}})-[ar:ACTIONS]->(task:Task {{_scoped_key: '{t}'}})
Expand Down
49 changes: 49 additions & 0 deletions alchemiscale/tests/integration/interface/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,55 @@ def test_action_tasks(

assert set(task_sks_e) == set(actioned_sks_e)

@pytest.mark.parametrize(
"weight,shouldfail",
[
(None, False),
(1.0, False),
([1.0], False),
(-1, True),
(1.5, True),
],
)
def test_action_tasks_with_weights(
self,
scope_test,
n4js_preloaded,
user_client: client.AlchemiscaleClient,
network_tyk2,
weight,
shouldfail,
):
n4js = n4js_preloaded

# select the transformation we want to compute
an = network_tyk2
transformation = list(an.edges)[0]

network_sk = user_client.get_scoped_key(an, scope_test)
transformation_sk = user_client.get_scoped_key(transformation, scope_test)

task_sks = user_client.create_tasks(transformation_sk, count=3)

if isinstance(weight, list):
weight = weight * len(task_sks)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will create a list of all the same weights, which isn't bad, but I think it would be more thorough to test if we set a list of different weights that this works as expected. You can use n4js_preloaded.get_task_weights to get at this information.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reworked how this is done


# action these task for this network, in reverse order

if shouldfail:
with pytest.raises(AlchemiscaleClientError):
actioned_sks = user_client.action_tasks(
task_sks,
network_sk,
weight,
)
else:
actioned_sks = user_client.action_tasks(
task_sks,
network_sk,
weight,
)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test that calls action_tasks with weight None and checks with n4j_preloaded.get_task_weights that the weights are unchanged?

Can you also check that if you call action_tasks on already-actioned Tasks with new weights that these get set appropriately?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

def test_cancel_tasks(
self,
scope_test,
Expand Down
14 changes: 7 additions & 7 deletions alchemiscale/tests/integration/storage/test_statestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,14 +940,14 @@ def test_get_set_weights(self, n4js: Neo4jStore, network_tyk2, scope_test):
task_sks = [n4js.create_task(transformation_sk) for i in range(10)]
n4js.action_tasks(task_sks, taskhub_sk)

# weights should all be the default 1.0
# weights should all be the default 0.5
weights = n4js.get_task_weights(task_sks, taskhub_sk)
assert all([w == 1.0 for w in weights])
assert all([w == 0.5 for w in weights])

# set weights on the tasks to be all 10
n4js.set_task_weights(task_sks, taskhub_sk, weight=10)
# set weights on the tasks to be all 1.0
n4js.set_task_weights(task_sks, taskhub_sk, weight=1.0)
weights = n4js.get_task_weights(task_sks, taskhub_sk)
assert all([w == 10 for w in weights])
assert all([w == 1.0 for w in weights])

def test_cancel_task(self, n4js, network_tyk2, scope_test):
an = network_tyk2
Expand Down Expand Up @@ -1257,8 +1257,8 @@ def test_claim_task_byweight(self, n4js: Neo4jStore, network_tyk2, scope_test):
# set weights on the tasks to be all 0, disabling them
n4js.set_task_weights(task_sks, taskhub_sk, weight=0)

# set the weight of the first task to be 10
weight_dict = {task_sks[0]: 10}
# set the weight of the first task to be 1
weight_dict = {task_sks[0]: 1.0}
n4js.set_task_weights(weight_dict, taskhub_sk)

csid = ComputeServiceID("the best task handler")
Expand Down
Loading