[xla:cpu] Add OneDnnThreadPool based on parallel loop runner #22125

Merged: 1 commit on Feb 8, 2025
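This change adds a OneDnnThreadPool adapter that implements oneDNN's threadpool_iface on top of XLA's ParallelLoopRunner, so oneDNN primitives can run their intra-op parallelism on the runner's Eigen thread pool. A minimal usage sketch follows, mirroring the wiring in the test below; the pool name and size are illustrative assumptions, not part of the diff:

#define EIGEN_USE_THREADS

#include "oneapi/dnnl/dnnl.hpp"
#include "oneapi/dnnl/dnnl_threadpool.hpp"
#include "unsupported/Eigen/CXX11/Tensor"
#include "xla/backends/cpu/runtime/onednn/onednn_threadpool.h"
#include "xla/backends/cpu/runtime/parallel_loop_runner.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/threadpool.h"

void StreamSketch() {
  // An Eigen device over a TSL thread pool (the size is an arbitrary example).
  tsl::thread::ThreadPool threads(tsl::Env::Default(), "example", 8);
  Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
                                 threads.NumThreads());

  // The runner schedules parallel loops; the adapter exposes it to oneDNN.
  xla::cpu::ParallelLoopRunner runner(&device);
  xla::cpu::OneDnnThreadPool threadpool(&runner);

  // Primitives executed on this stream parallelize through the runner.
  dnnl::engine engine(dnnl::engine::kind::cpu, 0);
  dnnl::stream stream =
      dnnl::threadpool_interop::make_stream(engine, &threadpool);
}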
64 changes: 64 additions & 0 deletions xla/backends/cpu/runtime/onednn/BUILD
@@ -0,0 +1,64 @@
# copybara:uncomment_begin(google-only)
# load("//xla:xla.bzl", "xla_cc_test")
# load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
#
# package(
#     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
#     default_visibility = [":friends"],
#     licenses = ["notice"],
# )
#
# package_group(
#     name = "friends",
#     includes = [
#         "//xla:friends",
#     ],
# )
#
# cc_library(
#     name = "onednn_interop",
#     hdrs = ["onednn_interop.h"],
#     deps = [
#         "@com_google_absl//absl/base:core_headers",
#         "@com_google_absl//absl/status",
#         "@onednn//:mkl_dnn",
#         "//xla:util",
#         "//xla/tsl/platform:logging",
#     ],
# )
#
# cc_library(
#     name = "onednn_threadpool",
#     hdrs = ["onednn_threadpool.h"],
#     deps = [
#         ":onednn_interop",
#         "@onednn//:mkl_dnn",
#         "//xla/backends/cpu/runtime:parallel_loop_runner",
#     ],
# )
#
# xla_cc_test(
#     name = "onednn_threadpool_test",
#     srcs = ["onednn_threadpool_test.cc"],
#     deps = [
#         ":onednn_interop",
#         ":onednn_threadpool",
#         "@com_google_googletest//:gtest_main",
#         "@com_google_absl//absl/algorithm:container",
#         "@com_google_absl//absl/status",
#         "@com_google_absl//absl/status:statusor",
#         "@com_google_absl//absl/synchronization",
#         "@eigen_archive//:eigen3",
#         "@onednn//:mkl_dnn",
#         "@pthreadpool",
#         "//xla/backends/cpu/runtime:parallel_loop_runner",
#         "//xla/tsl/concurrency:async_value",
#         "//xla/tsl/lib/core:status_test_util",
#         "//xla/tsl/platform:env",
#         "//xla/tsl/platform:statusor",
#         "//xla/tsl/platform:test",
#         "//xla/tsl/platform:test_benchmark",
#         "//xla/tsl/platform:test_main",
#     ],
# )
# copybara:uncomment_end
84 changes: 84 additions & 0 deletions xla/backends/cpu/runtime/onednn/onednn_interop.h
@@ -0,0 +1,84 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_CPU_RUNTIME_ONEDNN_ONEDNN_INTEROP_H_
#define XLA_BACKENDS_CPU_RUNTIME_ONEDNN_ONEDNN_INTEROP_H_

#include "oneapi/dnnl/dnnl_graph.hpp"
#include "absl/base/optimization.h"
#include "absl/status/status.h"
#include "xla/tsl/platform/logging.h"
#include "xla/util.h"

namespace xla::cpu {

#define ONEDNN_RETURN_IF_ERROR(expr)             \
  do {                                           \
    absl::Status s = OneDnnStatusToStatus(expr); \
    if (!s.ok()) {                               \
      return s;                                  \
    }                                            \
  } while (0)

#define ONEDNN_LOG_IF_ERROR(expr)                   \
  do {                                              \
    absl::Status s = OneDnnStatusToStatus(expr);    \
    if (!s.ok()) {                                  \
      LOG(ERROR) << "DNNL operation failed: " << s; \
    }                                               \
  } while (0)

// Converts oneDNN status to absl::Status.
inline absl::Status OneDnnStatusToStatus(dnnl::graph::status status) {
  if (ABSL_PREDICT_TRUE(status == dnnl::graph::status::success)) {
    return absl::OkStatus();
  }

  auto error_message = [](dnnl::graph::status status) {
    switch (status) {
      case dnnl::graph::status::success:
        return "";
      case dnnl::graph::status::out_of_memory:
        return "out of memory";
      case dnnl::graph::status::invalid_arguments:
        return "invalid arguments";
      case dnnl::graph::status::unimplemented:
        return "unimplemented";
      case dnnl::graph::status::last_impl_reached:
        return "last implementation reached";
      case dnnl::graph::status::runtime_error:
        return "runtime error";
      case dnnl::graph::status::not_required:
        return "not required";
      case dnnl::graph::status::invalid_graph:
        return "invalid graph";
      case dnnl::graph::status::invalid_graph_op:
        return "invalid graph op";
      case dnnl::graph::status::invalid_shape:
        return "invalid shape";
      case dnnl::graph::status::invalid_data_type:
        return "invalid data type";
    }
    // Defensive fallback for out-of-range enum values; also keeps compilers
    // from warning that control reaches the end of a non-void lambda.
    return "unknown status";
  };

  return Internal("DNNL operation failed: %s", error_message(status));
}

} // namespace xla::cpu

#endif // XLA_BACKENDS_CPU_RUNTIME_ONEDNN_ONEDNN_INTEROP_H_
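For reference, a small sketch of how ONEDNN_RETURN_IF_ERROR composes with graph construction; AddOpToGraph is a hypothetical helper named for illustration, following the same pattern as CreateExpGraph in the test below:

// Sketch: convert dnnl::graph::status into absl::Status and early-return.
// `AddOpToGraph` is a hypothetical name, not part of this change.
absl::Status AddOpToGraph(dnnl::graph::graph& g, const dnnl::graph::op& op) {
  ONEDNN_RETURN_IF_ERROR(g.add_op(op));  // early-returns on non-success
  g.finalize();  // finalize() returns void, so there is no status to check
  return absl::OkStatus();
}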
60 changes: 60 additions & 0 deletions xla/backends/cpu/runtime/onednn/onednn_threadpool.h
@@ -0,0 +1,60 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_BACKENDS_CPU_RUNTIME_ONEDNN_ONEDNN_THREADPOOL_H_
#define XLA_BACKENDS_CPU_RUNTIME_ONEDNN_ONEDNN_THREADPOOL_H_

#include <cstddef>
#include <cstdint>
#include <functional>

#include "oneapi/dnnl/dnnl_threadpool_iface.hpp"
#include "xla/backends/cpu/runtime/parallel_loop_runner.h"

namespace xla::cpu {

class OneDnnThreadPool final
    : public dnnl::threadpool_interop::threadpool_iface {
 public:
  explicit OneDnnThreadPool(ParallelLoopRunner* runner) : runner_(runner) {}

  int get_num_threads() const final;
  bool get_in_parallel() const final;
  uint64_t get_flags() const final;

  void parallel_for(int n, const std::function<void(int, int)>& fn) final;

 private:
  ParallelLoopRunner* runner_;
};

inline int OneDnnThreadPool::get_num_threads() const {
  return runner_->num_threads();
}

inline bool OneDnnThreadPool::get_in_parallel() const {
  return runner_->is_in_runner();
}

inline uint64_t OneDnnThreadPool::get_flags() const { return 0; }

inline void OneDnnThreadPool::parallel_for(
    int n, const std::function<void(int, int)>& fn) {
  runner_->Parallelize(n, [fn, n](size_t task_index) { fn(task_index, n); });
}

} // namespace xla::cpu

#endif // XLA_BACKENDS_CPU_RUNTIME_ONEDNN_ONEDNN_THREADPOOL_H_
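The parallel_for override above preserves oneDNN's (ithr, nthr) callback contract by scheduling n independent tasks and reporting n as the thread count. A minimal sketch of what a callback observes, assuming a live ParallelLoopRunner as constructed in the test below:

// Sketch: each of the n invocations receives a distinct (task_index, n) pair.
void ParallelForSketch(xla::cpu::ParallelLoopRunner* runner) {
  xla::cpu::OneDnnThreadPool threadpool(runner);
  threadpool.parallel_for(/*n=*/4, [](int ithr, int nthr) {
    // For n == 4 this body runs four times with ithr in {0, 1, 2, 3} and
    // nthr == 4, possibly on different threads of the underlying pool.
  });
}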
118 changes: 118 additions & 0 deletions xla/backends/cpu/runtime/onednn/onednn_threadpool_test.cc
@@ -0,0 +1,118 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/cpu/runtime/onednn/onednn_threadpool.h"

#include <cmath>
#include <cstdint>
#include <vector>

#include "oneapi/dnnl/dnnl.hpp"
#include "oneapi/dnnl/dnnl_common.hpp"
#include "oneapi/dnnl/dnnl_graph.hpp"
#include "oneapi/dnnl/dnnl_threadpool.hpp"
#include <gtest/gtest.h>
#include "absl/status/statusor.h"
#include "xla/backends/cpu/runtime/onednn/onednn_interop.h"
#include "xla/backends/cpu/runtime/parallel_loop_runner.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/test.h"
#include "xla/tsl/platform/threadpool.h"

#define EIGEN_USE_THREADS
#include "unsupported/Eigen/CXX11/Tensor"

namespace xla::cpu {
namespace {

// Creates a graph with a single Exp operation.
static absl::StatusOr<dnnl::graph::graph> CreateExpGraph(
    const dnnl::graph::logical_tensor& src_tensor,
    const dnnl::graph::logical_tensor& dst_tensor) {
  dnnl::graph::op exp_op(0, dnnl::graph::op::kind::Exp, {src_tensor},
                         {dst_tensor});

  dnnl::graph::graph g(dnnl::engine::kind::cpu);
  ONEDNN_RETURN_IF_ERROR(g.add_op(exp_op));
  g.finalize();

  return g;
}

TEST(OneDnnThreadPoolTest, Exp) {
  tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 32);
  Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
                                 threads.NumThreads());

  ParallelLoopRunner runner(&device);
  OneDnnThreadPool threadpool(&runner);

  int64_t d0 = 100;
  int64_t d1 = 1000;
  int64_t num_elements = d0 * d1;

  // We use row-major layout for both source and destination tensors.
  dnnl::graph::logical_tensor::dims src_dims = {d0, d1};
  dnnl::graph::logical_tensor::dims dst_dims = {d0, d1};

  dnnl::graph::logical_tensor::dims src_strides = {d1, 1};
  dnnl::graph::logical_tensor::dims dst_strides = {d1, 1};

  dnnl::graph::logical_tensor src_tensor(
      0, dnnl::graph::logical_tensor::data_type::f32, src_dims, src_strides);
  dnnl::graph::logical_tensor dst_tensor(
      1, dnnl::graph::logical_tensor::data_type::f32, dst_dims, dst_strides);

  // Compile oneDNN graph with a single Exp operation.
  TF_ASSERT_OK_AND_ASSIGN(dnnl::graph::graph g,
                          CreateExpGraph(src_tensor, dst_tensor));
  std::vector<dnnl::graph::partition> partitions = g.get_partitions();

  // Create oneDNN engine for running the graph on CPU.
  dnnl::engine engine(dnnl::engine::kind::cpu, 0);

  // Create oneDNN stream backed by the parallel loop runner.
  dnnl::stream stream =
      dnnl::stream(dnnl::threadpool_interop::make_stream(engine, &threadpool));

  // Compile graph partitions for the given engine.
  std::vector<dnnl::graph::compiled_partition> compiled_partitions;
  for (const auto& partition : partitions) {
    compiled_partitions.push_back(
        partition.compile({src_tensor}, {dst_tensor}, engine));
  }

  // Create tensors for source and destination.
  std::vector<float> src_data(num_elements, 1.0f);
  std::vector<float> dst_data(num_elements, 0.0f);

  dnnl::graph::tensor src(src_tensor, engine, src_data.data());
  dnnl::graph::tensor dst(dst_tensor, engine, dst_data.data());

  // Execute the compiled oneDNN graph on the CPU stream.
  compiled_partitions[0].execute(stream, {src}, {dst});

  // Wait for the completion of parallel loops scheduled into the runner.
  tsl::BlockUntilReady(runner.done_event());

  for (int i = 0; i < num_elements; ++i) {
    EXPECT_NEAR(dst_data[i], std::exp(1.0f), 1e-5);
  }
}

} // namespace
} // namespace xla::cpu
4 changes: 4 additions & 0 deletions xla/backends/cpu/runtime/parallel_loop_runner.cc
@@ -65,6 +65,10 @@ size_t ParallelLoopRunner::num_threads() const {
  return device_.load()->numThreadsInPool();
}

bool ParallelLoopRunner::is_in_runner() const {
  // Eigen reports a thread id of -1 for callers outside the pool.
  return device_.load()->currentThreadId() > -1;
}

tsl::AsyncValueRef<tsl::Chain> ParallelLoopRunner::TakeDoneEvent(
    ParallelLoopRunner&& runner) {
  return std::move(runner.done_event_);
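is_in_runner() relies on Eigen's thread-id contract: ThreadPoolDevice::currentThreadId() returns the calling thread's pool index, or -1 when the caller is not a pool thread. A standalone sketch of that contract, under the assumption of an already-constructed device:

#define EIGEN_USE_THREADS
#include "unsupported/Eigen/CXX11/Tensor"

void ThreadIdSketch(Eigen::ThreadPoolDevice& device) {
  // Called from a non-pool thread: currentThreadId() is -1, so this is false.
  bool in_pool = device.currentThreadId() > -1;
  (void)in_pool;

  device.enqueueNoNotification([&device] {
    // Called from a pool thread: the id is in [0, numThreadsInPool()).
    bool in_pool = device.currentThreadId() > -1;  // true
    (void)in_pool;
  });
}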
4 changes: 4 additions & 0 deletions xla/backends/cpu/runtime/parallel_loop_runner.h
@@ -115,8 +115,12 @@ class ParallelLoopRunner {
  const Eigen::ThreadPoolDevice* device() const { return device_; }
  void set_device(const Eigen::ThreadPoolDevice* device) { device_ = device; }

  // Returns the number of threads in the underlying thread pool.
  size_t num_threads() const;

  // Returns true if the current thread belongs to the underlying thread pool.
  bool is_in_runner() const;

 private:
  // Forward declarations of the parallel tasks.
  struct ParallelTask1D;