diff --git a/bazel/ray.bzl b/bazel/ray.bzl index 107e5f46f750d..5c9712f0a4350 100644 --- a/bazel/ray.bzl +++ b/bazel/ray.bzl @@ -1,6 +1,6 @@ -load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_library_public") -load("@bazel_skylib//rules:copy_file.bzl", "copy_file") load("@bazel_common//tools/maven:pom_file.bzl", "pom_file") +load("@bazel_skylib//rules:copy_file.bzl", "copy_file") +load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_library_public") load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test") COPTS_WITHOUT_LOG = select({ @@ -14,6 +14,7 @@ COPTS_WITHOUT_LOG = select({ "//conditions:default": [ "-Wunused-result", "-Wconversion-null", + "-Wmisleading-indentation", ], }) + select({ "//:clang-cl": [ diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 11c0969ad1f27..389e236614fa8 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -243,6 +243,7 @@ py_test_module_list( "test_basic_2.py", "test_basic_4.py", "test_basic_5.py", + "test_wait.py", ], size = "medium", tags = ["exclusive", "minimal", "basic_test", "team:core"], @@ -732,6 +733,7 @@ py_test_module_list( "test_basic_3.py", "test_basic_4.py", "test_basic_5.py", + "test_wait.py", "test_multiprocessing.py", "test_list_actors.py", "test_list_actors_2.py", diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index b146c7043008a..fe4a0da9a23ae 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -657,22 +657,6 @@ def test_put_get(shutdown_only): assert value_before == value_after -def test_wait_timing(shutdown_only): - ray.init(num_cpus=2) - - @ray.remote - def f(): - time.sleep(1) - - future = f.remote() - - start = time.time() - ready, not_ready = ray.wait([future], timeout=0.2) - assert 0.2 < time.time() - start < 0.3 - assert len(ready) == 0 - assert len(not_ready) == 1 - - @pytest.mark.skipif(client_test_enabled(), reason="internal _raylet") def test_function_descriptor(): python_descriptor = ray._raylet.PythonFunctionDescriptor( diff --git a/python/ray/tests/test_basic_2.py b/python/ray/tests/test_basic_2.py index 2f3eaad85fc34..31afaab64e1cc 100644 --- a/python/ray/tests/test_basic_2.py +++ b/python/ray/tests/test_basic_2.py @@ -592,52 +592,6 @@ def call(actor): assert ray.get(actor.get_num_threads.remote()) <= CONCURRENCY -def test_wait(ray_start_regular_shared): - @ray.remote - def f(delay): - time.sleep(delay) - return - - object_refs = [f.remote(0), f.remote(0), f.remote(0), f.remote(0)] - ready_ids, remaining_ids = ray.wait(object_refs) - assert len(ready_ids) == 1 - assert len(remaining_ids) == 3 - ready_ids, remaining_ids = ray.wait(object_refs, num_returns=4) - assert set(ready_ids) == set(object_refs) - assert remaining_ids == [] - - object_refs = [f.remote(0), f.remote(5)] - ready_ids, remaining_ids = ray.wait(object_refs, timeout=0.5, num_returns=2) - assert len(ready_ids) == 1 - assert len(remaining_ids) == 1 - - # Verify that calling wait with duplicate object refs throws an - # exception. - x = ray.put(1) - with pytest.raises(Exception): - ray.wait([x, x]) - - # Make sure it is possible to call wait with an empty list. - ready_ids, remaining_ids = ray.wait([]) - assert ready_ids == [] - assert remaining_ids == [] - - # Test semantics of num_returns with no timeout. - obj_refs = [ray.put(i) for i in range(10)] - (found, rest) = ray.wait(obj_refs, num_returns=2) - assert len(found) == 2 - assert len(rest) == 8 - - # Verify that incorrect usage raises a TypeError. - x = ray.put(1) - with pytest.raises(TypeError): - ray.wait(x) - with pytest.raises(TypeError): - ray.wait(1) - with pytest.raises(TypeError): - ray.wait([1]) - - def test_duplicate_args(ray_start_regular_shared): @ray.remote def f(arg1, arg2, arg1_duplicate, kwarg1=None, kwarg2=None, kwarg1_duplicate=None): diff --git a/python/ray/tests/test_wait.py b/python/ray/tests/test_wait.py new file mode 100644 index 0000000000000..659f7b29c69c2 --- /dev/null +++ b/python/ray/tests/test_wait.py @@ -0,0 +1,119 @@ +# coding: utf-8 + +import pytest +import numpy as np +import time +import logging +import sys +import os + +from ray._private.test_utils import client_test_enabled + + +if client_test_enabled(): + from ray.util.client import ray +else: + import ray + +logger = logging.getLogger(__name__) + + +def test_wait(ray_start_regular): + @ray.remote + def f(delay): + time.sleep(delay) + return + + object_refs = [f.remote(0), f.remote(0), f.remote(0), f.remote(0)] + ready_ids, remaining_ids = ray.wait(object_refs) + assert len(ready_ids) == 1 + assert len(remaining_ids) == 3 + ready_ids, remaining_ids = ray.wait(object_refs, num_returns=4) + assert set(ready_ids) == set(object_refs) + assert remaining_ids == [] + + object_refs = [f.remote(0), f.remote(5)] + ready_ids, remaining_ids = ray.wait(object_refs, timeout=0.5, num_returns=2) + assert len(ready_ids) == 1 + assert len(remaining_ids) == 1 + + # Verify that calling wait with duplicate object refs throws an + # exception. + x = ray.put(1) + with pytest.raises(Exception): + ray.wait([x, x]) + + # Make sure it is possible to call wait with an empty list. + ready_ids, remaining_ids = ray.wait([]) + assert ready_ids == [] + assert remaining_ids == [] + + # Test semantics of num_returns with no timeout. + obj_refs = [ray.put(i) for i in range(10)] + (found, rest) = ray.wait(obj_refs, num_returns=2) + assert len(found) == 2 + assert len(rest) == 8 + + # Verify that incorrect usage raises a TypeError. + x = ray.put(1) + with pytest.raises(TypeError): + ray.wait(x) + with pytest.raises(TypeError): + ray.wait(1) + with pytest.raises(TypeError): + ray.wait([1]) + + +def test_wait_timing(ray_start_2_cpus): + @ray.remote + def f(): + time.sleep(1) + + future = f.remote() + + start = time.time() + ready, not_ready = ray.wait([future], timeout=0.2) + assert 0.2 < time.time() - start < 0.3 + assert len(ready) == 0 + assert len(not_ready) == 1 + + +def test_wait_always_fetch_local(ray_start_cluster): + cluster = ray_start_cluster + cluster.add_node(num_cpus=0, object_store_memory=500e6) # head node + ray.init(address=cluster.address) + worker_node = cluster.add_node(num_cpus=1, object_store_memory=80e6) + + @ray.remote(num_cpus=1) + def return_large_object(): + # 100mb so will spill on worker, but not once on head + return np.zeros(100 * 1024 * 1024, dtype=np.uint8) + + @ray.remote(num_cpus=0) + def small_local_task(): + return 1 + + put_on_worker = ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( + worker_node.node_id, soft=False + ) + x = small_local_task.remote() + y = return_large_object.options(scheduling_strategy=put_on_worker).remote() + z = return_large_object.options(scheduling_strategy=put_on_worker).remote() + # even though x will be found in local, requests should be made + # to start pulling y and z + ray.wait([x, y, z], num_returns=1, fetch_local=True) + time.sleep(3) + + start_time = time.perf_counter() + ray.get([y, z]) + # y and z should be immediately available as pull requests should've + # been made immediately on the ray.wait call + time_to_get = time.perf_counter() - start_time + assert time_to_get < 0.2 + + +if __name__ == "__main__": + if os.environ.get("PARALLEL_CI"): + sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) + else: + sys.exit(pytest.main(["-sv", __file__])) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 1a9142efdf706..909851115d6d2 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -2095,7 +2095,10 @@ Status CoreWorker::Wait(const std::vector &ids, if (fetch_local) { RetryObjectInPlasmaErrors( memory_store_, worker_context_, memory_object_ids, plasma_object_ids, ready); - if (static_cast(ready.size()) < num_objects && !plasma_object_ids.empty()) { + // We make the request to the plasma store even if we have num_objects ready since we + // want to at least make the request to pull these objects if the user specified + // fetch_local so the pulling can start. + if (!plasma_object_ids.empty()) { RAY_RETURN_NOT_OK(plasma_store_provider_->Wait( plasma_object_ids, std::min(static_cast(plasma_object_ids.size()), diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index db1da21f1fcc5..e3a6f108a6be7 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1759,6 +1759,31 @@ void NodeManager::ProcessWaitRequestMessage( current_task_id, /*ray_get=*/false); } + if (message->num_ready_objects() == 0) { + // If we don't need to wait for any, return immediately after making the pull + // requests through AsyncResolveObjects above. + flatbuffers::FlatBufferBuilder fbb; + auto wait_reply = protocol::CreateWaitReply(fbb, + to_flatbuf(fbb, std::vector{}), + to_flatbuf(fbb, std::vector{})); + fbb.Finish(wait_reply); + const auto status = + client->WriteMessage(static_cast(protocol::MessageType::WaitReply), + fbb.GetSize(), + fbb.GetBufferPointer()); + if (status.ok()) { + if (resolve_objects) { + AsyncResolveObjectsFinish(client, current_task_id); + } + } else { + // We failed to write to the client, so disconnect the client. + std::ostringstream stream; + stream << "Failed to write WaitReply to the client. Status " << status + << ", message: " << status.message(); + DisconnectClient(client, rpc::WorkerExitType::SYSTEM_ERROR, stream.str()); + } + return; + } uint64_t num_required_objects = static_cast(message->num_ready_objects()); wait_manager_.Wait( object_ids, diff --git a/src/ray/raylet/wait_manager.cc b/src/ray/raylet/wait_manager.cc index a618fdd9c17a4..8512c75d4c6ad 100644 --- a/src/ray/raylet/wait_manager.cc +++ b/src/ray/raylet/wait_manager.cc @@ -28,7 +28,6 @@ void WaitManager::Wait(const std::vector &object_ids, << "Waiting duplicate objects is not allowed. Please make sure all object IDs are " "unique before calling `WaitManager::Wait`."; RAY_CHECK(timeout_ms >= 0 || timeout_ms == -1); - RAY_CHECK_NE(num_required_objects, 0u); RAY_CHECK_LE(num_required_objects, object_ids.size()); const uint64_t wait_id = next_wait_id_++;