diff --git a/alchemiscale/interface/api.py b/alchemiscale/interface/api.py index 8d278531..10c91961 100644 --- a/alchemiscale/interface/api.py +++ b/alchemiscale/interface/api.py @@ -502,6 +502,7 @@ def action_tasks( network_scoped_key, *, tasks: List[ScopedKey] = Body(embed=True), + weight: Optional[Union[float, List[float]]] = Body(None, embed=True), n4js: Neo4jStore = Depends(get_n4js_depends), token: TokenData = Depends(get_token_data_depends), ) -> List[Union[str, None]]: @@ -511,6 +512,23 @@ def action_tasks( taskhub_sk = n4js.get_taskhub(sk) actioned_sks = n4js.action_tasks(tasks, taskhub_sk) + try: + if isinstance(weight, float): + n4js.set_task_weights(tasks, taskhub_sk, weight) + elif isinstance(weight, list): + if len(weight) != len(tasks): + detail = "weight (when in a list) must have the same length as tasks" + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=detail, + ) + + n4js.set_task_weights( + {task: weight for task, weight in zip(tasks, weight)}, taskhub_sk, None + ) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + return [str(sk) if sk is not None else None for sk in actioned_sks] diff --git a/alchemiscale/interface/client.py b/alchemiscale/interface/client.py index 98a1c1cd..09b10500 100644 --- a/alchemiscale/interface/client.py +++ b/alchemiscale/interface/client.py @@ -651,14 +651,16 @@ def get_transformation_status( return status_counts def action_tasks( - self, tasks: List[ScopedKey], network: ScopedKey + self, + tasks: List[ScopedKey], + network: ScopedKey, + weight: Optional[Union[float, List[float]]] = None, ) -> List[Optional[ScopedKey]]: """Action Tasks for execution via the given AlchemicalNetwork's TaskHub. A Task cannot be actioned: - - to an AlchemicalNetwork in a different Scope. - - if it extends another Task that is not complete. + - to an AlchemicalNetwork in a different Scope - if it has any status other than 'waiting', 'running', or 'error' Parameters @@ -668,16 +670,26 @@ def action_tasks( network The AlchemicalNetwork ScopedKey to action the Tasks for. The Tasks will be added to the network's associated TaskHub. + weight + Weight to be applied to the actioned Tasks. Only values between 0 + and 1 are valid weights. Weights can also be provided as a list of + floats with the same length as `tasks`. + + Setting `weight` to ``None`` will apply the default weight of 0.5 + to newly actioned Tasks, while leaving the weights of any previously + actioned Tasks unchanged. Setting `weight` to anything other than + ``None`` will change the weights of previously actioned Tasks + included in `tasks`. Returns ------- List[Optional[ScopedKey]] ScopedKeys for Tasks actioned, in the same order as given as - `tasks` on input. If a Task couldn't be actioned, then ``None`` will - be returned in its place. + `tasks` on input. If a Task couldn't be actioned, then ``None`` + will be returned in its place. """ - data = dict(tasks=[t.dict() for t in tasks]) + data = dict(tasks=[t.dict() for t in tasks], weight=weight) actioned_sks = self._post_resource(f"/networks/{network}/tasks/action", data) return [ScopedKey.from_str(i) if i is not None else None for i in actioned_sks] diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 9eeb50b2..7c9de471 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -1120,7 +1120,7 @@ def action_tasks( AND task.status IN ['waiting', 'running', 'error'] // create the connection - CREATE (th)-[ar:ACTIONS {{weight: 1.0}}]->(task) + CREATE (th)-[ar:ACTIONS {{weight: 0.5}}]->(task) // set the task property to the scoped key of the Task // this is a convenience for when we have to loop over relationships in Python @@ -1186,6 +1186,9 @@ def set_task_weights( "Cannot set `weight` to a scalar if `tasks` is a dict" ) + if not all([0 <= weight <= 1 for weight in tasks.values()]): + raise ValueError("weights must be between 0 and 1 (inclusive)") + for t, w in tasks.items(): q = f""" MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}})-[ar:ACTIONS]->(task:Task {{_scoped_key: '{t}'}}) @@ -1200,6 +1203,9 @@ def set_task_weights( "Must set `weight` to a scalar if `tasks` is a list" ) + if not 0 <= weight <= 1: + raise ValueError("weight must be between 0 and 1 (inclusive)") + for t in tasks: q = f""" MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}})-[ar:ACTIONS]->(task:Task {{_scoped_key: '{t}'}}) diff --git a/alchemiscale/tests/integration/interface/client/test_client.py b/alchemiscale/tests/integration/interface/client/test_client.py index dcc333cb..1434d367 100644 --- a/alchemiscale/tests/integration/interface/client/test_client.py +++ b/alchemiscale/tests/integration/interface/client/test_client.py @@ -653,6 +653,105 @@ def test_action_tasks( assert set(task_sks_e) == set(actioned_sks_e) + @pytest.mark.parametrize( + "weight,shouldfail", + [ + (None, False), + (1.0, False), + ([0.25, 0.5, 0.75], False), + (-1, True), + (1.5, True), + ], + ) + def test_action_tasks_with_weights( + self, + scope_test, + n4js_preloaded, + user_client: client.AlchemiscaleClient, + network_tyk2, + weight, + shouldfail, + ): + n4js = n4js_preloaded + + # select the transformation we want to compute + an = network_tyk2 + transformation = list(an.edges)[0] + + network_sk = user_client.get_scoped_key(an, scope_test) + transformation_sk = user_client.get_scoped_key(transformation, scope_test) + + task_sks = user_client.create_tasks(transformation_sk, count=3) + + if shouldfail: + with pytest.raises(AlchemiscaleClientError): + actioned_sks = user_client.action_tasks( + task_sks, + network_sk, + weight, + ) + else: + actioned_sks = user_client.action_tasks( + task_sks, + network_sk, + weight, + ) + + th_sk = n4js.get_taskhub(network_sk) + task_weights = n4js.get_task_weights(task_sks, th_sk) + + _weight = weight + + if weight is None: + _weight = [0.5] * len(task_sks) + elif not isinstance(weight, list): + _weight = [weight] * len(task_sks) + + assert task_weights == _weight + + # actioning tasks again with None should preserve + # task weights + user_client.action_tasks(task_sks, network_sk, weight=None) + + task_weights = n4js.get_task_weights(task_sks, th_sk) + assert task_weights == _weight + + def test_action_tasks_update_weights( + self, + scope_test, + n4js_preloaded, + user_client: client.AlchemiscaleClient, + network_tyk2, + ): + n4js = n4js_preloaded + + # select the transformation we want to compute + an = network_tyk2 + transformation = list(an.edges)[0] + + network_sk = user_client.get_scoped_key(an, scope_test) + th_sk = n4js.get_taskhub(network_sk) + transformation_sk = user_client.get_scoped_key(transformation, scope_test) + + task_sks = user_client.create_tasks(transformation_sk, count=3) + user_client.action_tasks(task_sks, network_sk) + + # base case + assert [0.5, 0.5, 0.5] == n4js.get_task_weights(task_sks, th_sk) + + new_weights = [1.0, 0.7, 0.4] + user_client.action_tasks(task_sks, network_sk, new_weights) + + assert new_weights == n4js.get_task_weights(task_sks, th_sk) + + # action a couple more tasks along with existing ones, then check weights as expected + new_task_sks = user_client.create_tasks(transformation_sk, count=2) + user_client.action_tasks(task_sks + new_task_sks, network_sk) + + assert new_weights + [0.5] * 2 == n4js.get_task_weights( + task_sks + new_task_sks, th_sk + ) + def test_cancel_tasks( self, scope_test, diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py index 15561a25..45b307a3 100644 --- a/alchemiscale/tests/integration/storage/test_statestore.py +++ b/alchemiscale/tests/integration/storage/test_statestore.py @@ -940,14 +940,14 @@ def test_get_set_weights(self, n4js: Neo4jStore, network_tyk2, scope_test): task_sks = [n4js.create_task(transformation_sk) for i in range(10)] n4js.action_tasks(task_sks, taskhub_sk) - # weights should all be the default 1.0 + # weights should all be the default 0.5 weights = n4js.get_task_weights(task_sks, taskhub_sk) - assert all([w == 1.0 for w in weights]) + assert all([w == 0.5 for w in weights]) - # set weights on the tasks to be all 10 - n4js.set_task_weights(task_sks, taskhub_sk, weight=10) + # set weights on the tasks to be all 1.0 + n4js.set_task_weights(task_sks, taskhub_sk, weight=1.0) weights = n4js.get_task_weights(task_sks, taskhub_sk) - assert all([w == 10 for w in weights]) + assert all([w == 1.0 for w in weights]) def test_cancel_task(self, n4js, network_tyk2, scope_test): an = network_tyk2 @@ -1257,8 +1257,8 @@ def test_claim_task_byweight(self, n4js: Neo4jStore, network_tyk2, scope_test): # set weights on the tasks to be all 0, disabling them n4js.set_task_weights(task_sks, taskhub_sk, weight=0) - # set the weight of the first task to be 10 - weight_dict = {task_sks[0]: 10} + # set the weight of the first task to be 1 + weight_dict = {task_sks[0]: 1.0} n4js.set_task_weights(weight_dict, taskhub_sk) csid = ComputeServiceID("the best task handler")