Skip to content

Commit

Permalink
[XLA:GPU] Model bytes accessed.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 724344346
  • Loading branch information
golechwierowicz authored and Google-ML-Automation committed Feb 7, 2025
1 parent c2d22ce commit 5eec5b3
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 106 deletions.
3 changes: 2 additions & 1 deletion xla/service/gpu/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ xla_cc_test(
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_query",
"//xla/service:hlo_cost_analysis",
"//xla/service/gpu:gpu_device_info_for_tests",
"//xla/stream_executor:device_description",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
Expand Down Expand Up @@ -225,6 +226,7 @@ cc_library(
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:cublas_cudnn",
"//xla/stream_executor:device_description",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
Expand All @@ -245,7 +247,6 @@ xla_cc_test(
":hlo_op_profiles",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:test_helpers",
"//xla/service:hlo_cost_analysis",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
Expand Down
161 changes: 87 additions & 74 deletions xla/service/gpu/model/gpu_hlo_cost_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,15 @@ limitations under the License.
#include "xla/service/hlo_module_config.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace gpu {

namespace {
// Use the "reserved" keys for these properties so lookups are fast.
static constexpr absl::string_view kIRSizeKey = HloCostAnalysis::kReserved0Key;

Expand All @@ -61,6 +63,47 @@ static constexpr absl::string_view kCollNumDevicesKey =
static constexpr absl::string_view kCollBytesTransferred =
"Number of bytes transferred.";

// Returns the number of ranks participating in this collective, defined as
// the size of the largest replica group (never less than one). The group
// mode is derived from the instruction's channel id / global-device-id flags.
template <typename T>
absl::StatusOr<int64_t> NumRanks(const T& instr) {
  const HloModuleConfig& config = instr.GetModule()->config();
  TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode,
                      GetCollectiveOpGroupMode(instr.channel_id().has_value(),
                                               instr.use_global_device_ids()));

  // Resolve the per-group participant counts for this instruction's replica
  // groups under the resolved group mode.
  TF_ASSIGN_OR_RETURN(
      std::vector<int64_t> participant_counts,
      GetPariticipantCountsForReplicaGroups(
          config.replica_count(), config.num_partitions(),
          instr.replica_groups(), group_mode));

  // The cost model is driven by the largest group; report at least one rank.
  int64_t max_ranks = 1;
  for (int64_t participants : participant_counts) {
    if (participants > max_ranks) {
      max_ranks = participants;
    }
  }
  return max_ranks;
}

// Sums the byte sizes (as reported by `get_shape`) of every array leaf of
// `shape`. Leaves under top-level tuple index `index_to_skip` are excluded;
// the default of -1 excludes nothing.
int64_t ShapeSize(const Shape& shape,
                  const GpuHloCostAnalysis::ShapeSizeFunction& get_shape,
                  int64_t index_to_skip = -1) {
  int64_t total_bytes = 0;
  ShapeUtil::ForEachLeafShape(
      shape, [&](const Shape& leaf, const ShapeIndex& leaf_index) {
        const bool excluded =
            !leaf_index.empty() && leaf_index.front() == index_to_skip;
        if (excluded || !leaf.IsArray()) {
          return;
        }
        total_bytes += get_shape(leaf);
      });
  return total_bytes;
}

} // namespace

// We use static tables to look up system bandwidths for different
// type of hardware below.
// TODO TJ this needs to be hosted somewhere more centralized.
Expand Down Expand Up @@ -333,25 +376,8 @@ int64_t GpuHloCostAnalysis::GetFlopsForElementwiseOp(

absl::Status GpuHloCostAnalysis::HandleAllReduce(
const HloInstruction* allreduce) {
const HloModuleConfig& config = allreduce->GetModule()->config();
TF_ASSIGN_OR_RETURN(
CollectiveOpGroupMode group_mode,
GetCollectiveOpGroupMode(
allreduce->channel_id().has_value(),
Cast<HloAllReduceInstruction>(allreduce)->use_global_device_ids()));

// Get number of ranks for this instruction based on replica groups and mode.
int64_t num_devices = config.num_partitions();
int64_t num_replicas = config.replica_count();
TF_ASSIGN_OR_RETURN(
std::vector<int64_t> participant_counts,
GetPariticipantCountsForReplicaGroups(
num_replicas, num_devices, allreduce->replica_groups(), group_mode));
int64_t num_ranks = 1;

for (auto count : participant_counts) {
num_ranks = std::max(num_ranks, count);
}
TF_ASSIGN_OR_RETURN(int64_t num_ranks,
NumRanks(*Cast<HloAllReduceInstruction>(allreduce)));

VLOG(5) << "Computing cost for " << num_ranks << " ranks in "
<< allreduce->ToString();
Expand Down Expand Up @@ -463,47 +489,44 @@ absl::Status GpuHloCostAnalysis::HandleReduce(const HloInstruction* hlo) {

// Models cost for all-reduce-start: every byte of the (possibly
// tuple-shaped) output is transferred, and the reduction contributes
// elementwise FLOPs.
//
// NOTE(review): the original span interleaved the pre-refactor body
// (ForEachLeafShape accumulation + set_output_bytes_accessed, with
// kCollBytesTransferred assigned twice) with its replacement; only the
// coherent replacement is kept here.
absl::Status GpuHloCostAnalysis::HandleAllReduceStart(
    const HloInstruction* hlo) {
  int64_t bytes_transferred = ShapeSize(hlo->shape(), options_.shape_size);

  // FLOPs are derived from the reduction computation's root elementwise op
  // applied across the output shape.
  current_properties_[kFlopsKey] = GetFlopsForElementwiseOp(
      hlo->to_apply()->root_instruction()->opcode(), hlo->shape());
  current_properties_[kBytesAccessedKey] = bytes_transferred;
  current_properties_[kCollBytesTransferred] = bytes_transferred;
  return absl::OkStatus();
}

// Models cost for all-gather: the full gathered output is transferred, and
// bytes accessed is split into per-rank shard reads/writes.
//
// NOTE(review): the original span carried leftover pre-refactor lines
// (GetShapeSize accumulation + set_output_bytes_accessed) interleaved with
// the replacement; the dead residue is removed here.
absl::Status GpuHloCostAnalysis::HandleAllGather(const HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(int64_t num_ranks,
                      NumRanks(*Cast<HloAllGatherInstruction>(hlo)));

  // Total gathered output size; each rank contributes an equal shard.
  int64_t bytes_transferred = ShapeSize(hlo->shape(), options_.shape_size);
  int64_t rank_size_bytes = bytes_transferred / num_ranks;
  // Writes cover (2 * num_ranks - 1) shards, reads cover num_ranks shards
  // (matches the (n-1 + n) write / n read split checked in the unit tests).
  // Presumably this models a ring-style all-gather — TODO confirm.
  int64_t write_bytes = rank_size_bytes * (2 * num_ranks - 1);
  int64_t read_bytes = rank_size_bytes * num_ranks;

  current_properties_[kBytesAccessedKey] = write_bytes + read_bytes;
  current_properties_[kCollBytesTransferred] = bytes_transferred;

  return absl::OkStatus();
}

// Models cost for all-gather-start. The async-start result is a tuple whose
// element 0 is the operand (input) view, so that element is excluded from
// the transferred-byte count.
//
// NOTE(review): the original span interleaved the removed ForEachLeafShape
// body (which skipped tuple index 0 inline) with the ShapeSize-based
// replacement; only the replacement is kept here.
absl::Status GpuHloCostAnalysis::HandleAllGatherStart(
    const HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(int64_t num_ranks,
                      NumRanks(*Cast<HloAllGatherInstruction>(hlo)));

  // Skip tuple element 0: it expresses the input of the collective
  // operation, not bytes produced by it.
  int64_t bytes_transferred =
      ShapeSize(hlo->shape(), options_.shape_size, /*index_to_skip=*/0);
  int64_t rank_size_bytes = bytes_transferred / num_ranks;
  // Same shard accounting as HandleAllGather: (2 * num_ranks - 1) shards
  // written, num_ranks shards read.
  int64_t write_bytes = rank_size_bytes * (2 * num_ranks - 1);
  int64_t read_bytes = rank_size_bytes * num_ranks;

  current_properties_[kBytesAccessedKey] = write_bytes + read_bytes;
  current_properties_[kCollBytesTransferred] = bytes_transferred;

  return absl::OkStatus();
}

Expand All @@ -513,37 +536,27 @@ absl::Status GpuHloCostAnalysis::HandleAsyncStart(const HloInstruction* hlo) {
VLOG(2) << "Only Reduce Scatter is supported.";
return absl::OkStatus();
}
int index_to_skip = 1;
int64_t bytes_transferred = 0;
ShapeUtil::ForEachLeafShape(
hlo->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
// Skip second element of a tuple as it is an output but it is not
// actual bytes transferred.
if (index.empty() || index.front() == index_to_skip) {
return;
}
if (subshape.IsArray()) {
bytes_transferred += GetShapeSize(subshape);
}
});

current_properties_[kCollBytesTransferred] = bytes_transferred;
return absl::OkStatus();
return HandleReduceScatter(async_start->async_wrapped_instruction());
}

// Models cost for reduce-scatter: the combined size of all input operands is
// transferred; bytes accessed mirror all-gather with the read/write roles
// swapped, and the reduction contributes elementwise FLOPs.
//
// NOTE(review): the original span merged the removed per-operand
// ForEachLeafShape loop with its replacement and lost the removed loop's
// closing brace in the diff rendering (brace-imbalanced as shown); only the
// coherent replacement is kept here.
absl::Status GpuHloCostAnalysis::HandleReduceScatter(
    const HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(int64_t num_ranks,
                      NumRanks(*Cast<HloReduceScatterInstruction>(hlo)));

  // Bytes transferred are the combined sizes of all input operands.
  int64_t bytes_transferred = 0;
  for (HloInstruction* operand : hlo->operands()) {
    bytes_transferred += ShapeSize(operand->shape(), options_.shape_size);
  }
  int64_t rank_size_bytes = bytes_transferred / num_ranks;
  // Mirror of the all-gather accounting: num_ranks shards written,
  // (2 * num_ranks - 1) shards read (matches the unit-test expectations).
  int64_t write_bytes = rank_size_bytes * num_ranks;
  int64_t read_bytes = rank_size_bytes * (2 * num_ranks - 1);

  current_properties_[kBytesAccessedKey] = write_bytes + read_bytes;
  current_properties_[kCollBytesTransferred] = bytes_transferred;
  // FLOPs come from the reduction computation's root elementwise op applied
  // across the output shape.
  current_properties_[kFlopsKey] = GetFlopsForElementwiseOp(
      hlo->to_apply()->root_instruction()->opcode(), hlo->shape());

  return absl::OkStatus();
}
Expand Down
37 changes: 27 additions & 10 deletions xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -665,15 +665,17 @@ ENTRY entry_computation {
const HloInstruction* all_reduce =
module->entry_computation()->root_instruction()->operand(0);
EXPECT_EQ(analysis_.BytesTransferred(*all_reduce), 4096 * 4);
EXPECT_EQ(analysis_.bytes_accessed(*all_reduce), 4096 * 4);
}

TEST_F(GpuHloCostAnalysisTest, AllGather) {
absl::string_view hlo_string = R"(
HloModule m
HloModule m, num_partitions=4
ENTRY entry_computation {
p = f32[1024] parameter(0)
ROOT _ = f32[4096] all-gather(p), dimensions={0}
ROOT _ = f32[4096] all-gather(p), dimensions={0}, use_global_device_ids=true,
replica_groups={{0,1,2,3}}, channel_id=1
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
Expand All @@ -684,17 +686,21 @@ ENTRY entry_computation {
const HloInstruction* all_gather =
module->entry_computation()->root_instruction();
EXPECT_EQ(analysis_.BytesTransferred(*all_gather), 4096 * 4);
// write = (3 + 4) * 1024 * 4 bytes
// read = 4 * 1024 * 4 bytes
EXPECT_EQ(analysis_.bytes_accessed(*all_gather), 45056);
}

TEST_F(GpuHloCostAnalysisTest, AsyncAllGather) {
absl::string_view hlo_string = R"(
HloModule m
HloModule m, num_partitions=4
ENTRY entry_computation {
p.0 = f32[1024] parameter(0)
p.1 = f32[512] parameter(1)
ag-start = ((f32[1024],f32[512]), (f32[4096],f32[2048])) all-gather-start(p.0,p.1),
dimensions={0}
dimensions={0}, use_global_device_ids=true, replica_groups={{0,1,2,3}},
channel_id=1
ROOT _ = (f32[4096],f32[2048]) all-gather-done(ag-start)
}
)";
Expand All @@ -707,11 +713,14 @@ ENTRY entry_computation {
module->entry_computation()->root_instruction()->operand(0);
// Output is (f32[4096], f32[2048]).
EXPECT_EQ(analysis_.BytesTransferred(*all_gather), 4096 * 4 + 2048 * 4);
// write = (3 + 4) * (1024 + 512) * 4 bytes
// read = 4 * (1024 + 512) * 4 bytes
EXPECT_EQ(analysis_.bytes_accessed(*all_gather), 67584);
}

TEST_F(GpuHloCostAnalysisTest, ReduceScatter) {
absl::string_view hlo_string = R"(
HloModule m
HloModule m, num_partitions=4
add {
param_0 = f32[] parameter(0)
Expand All @@ -721,7 +730,8 @@ add {
ENTRY entry_computation {
p = f32[4096] parameter(0)
ROOT _ = f32[1024] reduce-scatter(p), dimensions={0}, to_apply=add
ROOT _ = f32[1024] reduce-scatter(p), dimensions={0}, to_apply=add,
use_global_device_ids=true, replica_groups={{0,1,2,3}}, channel_id=1
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
Expand All @@ -732,11 +742,14 @@ ENTRY entry_computation {
const HloInstruction* reduce_scatter =
module->entry_computation()->root_instruction();
EXPECT_EQ(analysis_.BytesTransferred(*reduce_scatter), 4096 * 4);
// read = (3 + 4) * 1024 * 4 bytes
// write = 4 * 1024 * 4 bytes
EXPECT_EQ(analysis_.bytes_accessed(*reduce_scatter), 45056);
}

TEST_F(GpuHloCostAnalysisTest, AsyncReduceScatter) {
absl::string_view hlo_string = R"(
HloModule m
HloModule m, num_partitions=4
add {
param_0 = f32[] parameter(0)
Expand All @@ -748,14 +761,15 @@ async_computation {
param_3 = f32[4096] parameter(0)
param_4 = f32[2048] parameter(1)
ROOT r = (f32[1024],f32[512]) reduce-scatter(param_3,param_4),
dimensions={0},
to_apply=add
dimensions={0}, to_apply=add, use_global_device_ids=true,
replica_groups={{0,1,2,3}}, channel_id=1
}
ENTRY entry_computation {
p.0 = f32[4096] parameter(0)
p.1 = f32[2048] parameter(1)
rs-start = ((f32[4096],f32[2048]),(f32[1024],f32[512])) async-start(p.0,p.1), calls=async_computation
rs-start = ((f32[4096],f32[2048]),(f32[1024],f32[512])) async-start(p.0,p.1),
calls=async_computation
ROOT _ = (f32[1024],f32[512]) async-done(rs-start)
}
)";
Expand All @@ -768,6 +782,9 @@ ENTRY entry_computation {
module->entry_computation()->root_instruction()->operand(0);
// Output is (f32[1024],f32[512]).
EXPECT_EQ(analysis_.BytesTransferred(*reduce_scatter), 4096 * 4 + 2048 * 4);
// read = (3 + 4) * (1024 + 512) * 4 bytes
// write = 4 * (1024 + 512) * 4 bytes
EXPECT_EQ(analysis_.bytes_accessed(*reduce_scatter), 67584);
}

TEST_F(GpuHloCostAnalysisTest, CustomOpProfileIsUsed) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ TEST_F(SolGpuCostModelStatsCollectionTest,
EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
CHECK: ar-start
CHECK-SAME: collective_backend_config
CHECK-SAME: "exec_time_us":1495
CHECK-SAME: "exec_time_us":1
)"));
}

Expand Down
Loading

0 comments on commit 5eec5b3

Please sign in to comment.