Skip to content

Commit

Permalink
#sdy Add replicated hlo shardings for while/case/if ops with no sdy s…
Browse files Browse the repository at this point in the history
…hardings.

This is needed so when free variables are lifted in StableHLO->HLO conversion, the ops will get a sharding with the free variable shardings.

PiperOrigin-RevId: 724643197
  • Loading branch information
tomnatan30 authored and Google-ML-Automation committed Feb 8, 2025
1 parent 977263c commit a3ede7c
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 22 deletions.
75 changes: 63 additions & 12 deletions xla/service/spmd/shardy/shardy_xla_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ TEST_F(ShardyXLATest, EntryComputationLayoutSingleResult) {
%p1 = f32[3,8,32,4] parameter(1)
%copy.p0 = f32[3,8,32,4] copy(%p0)
%copy.p1 = f32[3,8,32,4] copy(%p1)
%add = f32[3,8,32,4] add(%copy.p0, %copy.p1), sharding={devices=[2,1,1,1]<=[2]}, metadata={op_name="simple_example/add" source_file="source.txt" source_line=42}
%add = f32[3,8,32,4] add(%copy.p0, %copy.p1), sharding={devices=[2,1,1,1]<=[2]}
ROOT %result = f32[3,8,32,4] copy(%add)
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
Expand Down Expand Up @@ -382,7 +382,7 @@ TEST_F(ShardyXLATest, EntryComputationLayoutMissingLayout) {
%p1 = f32[3,8,32,4] parameter(1)
%copy.p0 = f32[3,8,32,4] copy(%p0)
%copy.p1 = f32[3,8,32,4] copy(%p1)
%add = f32[3,8,32,4] add(%copy.p0, %copy.p1), sharding={devices=[2,1,1,1]<=[2]}, metadata={op_name="simple_example/add" source_file="source.txt" source_line=42}
%add = f32[3,8,32,4] add(%copy.p0, %copy.p1), sharding={devices=[2,1,1,1]<=[2]}
ROOT %result = f32[3,8,32,4] copy(%add)
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
Expand Down Expand Up @@ -529,10 +529,10 @@ TEST_F(ShardyXLATest, WhileWithFreeVariables) {
%arg_tuple.8 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) parameter(0)
%get-tuple-element.9 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=0
%get-tuple-element.13 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=4
%add.15 = f32[32,96]{1,0} add(f32[32,96]{1,0} %get-tuple-element.9, f32[32,96]{1,0} %get-tuple-element.13), metadata={source_file="-" source_line=25}
%add.15 = f32[32,96]{1,0} add(f32[32,96]{1,0} %get-tuple-element.9, f32[32,96]{1,0} %get-tuple-element.13)
%get-tuple-element.10 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=1
%get-tuple-element.12 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=3
%add.14 = s32[] add(s32[] %get-tuple-element.10, s32[] %get-tuple-element.12), metadata={source_file="-" source_line=24}
%add.14 = s32[] add(s32[] %get-tuple-element.10, s32[] %get-tuple-element.12)
%get-tuple-element.11 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=2
ROOT %tuple.16 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) tuple(f32[32,96]{1,0} %add.15, s32[] %add.14, s32[] %get-tuple-element.11, s32[] %get-tuple-element.12, f32[32,96]{1,0} %get-tuple-element.13)
}
Expand All @@ -544,7 +544,7 @@ TEST_F(ShardyXLATest, WhileWithFreeVariables) {
%get-tuple-element.23 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=4
%get-tuple-element.20 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=1
%get-tuple-element.21 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=2
ROOT %compare.24 = pred[] compare(s32[] %get-tuple-element.20, s32[] %get-tuple-element.21), direction=LT, metadata={source_file="-" source_line=21}
ROOT %compare.24 = pred[] compare(s32[] %get-tuple-element.20, s32[] %get-tuple-element.21), direction=LT
}
ENTRY %main.30 (Arg_0.1: f32[32,96], Arg_1.2: f32[32,96]) -> f32[32,96] {
Expand All @@ -553,10 +553,10 @@ TEST_F(ShardyXLATest, WhileWithFreeVariables) {
%constant.5 = s32[] constant(32)
%constant.4 = s32[] constant(1)
%Arg_1.2 = f32[32,96]{1,0} parameter(1), sharding={devices=[2,1,2]<=[4] last_tile_dim_replicate}
%tuple.6 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) tuple(f32[32,96]{1,0} %Arg_0.1, s32[] %constant.3, s32[] %constant.5, s32[] %constant.4, f32[32,96]{1,0} %Arg_1.2), metadata={source_file="-" source_line=19}
%while.25 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) while((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %tuple.6), condition=%region_1.17, body=%region_0.7, metadata={source_file="-" source_line=19}
%get-tuple-element.27 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %while.25), index=1, metadata={source_file="-" source_line=19}
%get-tuple-element.26 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %while.25), index=0, metadata={source_file="-" source_line=19}
%tuple.6 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) tuple(f32[32,96]{1,0} %Arg_0.1, s32[] %constant.3, s32[] %constant.5, s32[] %constant.4, f32[32,96]{1,0} %Arg_1.2)
%while.25 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) while((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %tuple.6), condition=%region_1.17, body=%region_0.7
%get-tuple-element.27 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %while.25), index=1
%get-tuple-element.26 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %while.25), index=0
%tuple.28 = (f32[32,96]{1,0}) tuple(f32[32,96]{1,0} %get-tuple-element.26)
ROOT %get-tuple-element.29 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}) %tuple.28), index=0
})";
Expand Down Expand Up @@ -682,15 +682,15 @@ ENTRY %main.0 (Arg_0.0: s64[2]) -> s64[2] {
%custom-call.0 = () custom-call(s64[2] %Arg_0.0), custom_call_target="xla_ffi_python_cpu_callback",
operand_layout_constraints={s64[2]{0}}, custom_call_has_side_effect=true, api_version=API_VERSION_TYPED_FFI,
frontend_attributes={xla.sdy.sharding="#sdy.sharding_per_value<[<@maximal_mesh_0, []>]>"},
sharding={{maximal device=0}}, metadata={op_name="custom-call.2"}, backend_config={descriptor = 126001424235520 : ui64}
sharding={{maximal device=0}}
}
)";
const char* const expected = R"(
// CHECK: ENTRY %main.3 (Arg_0.1: s64[2]) -> s64[2] {
// CHECK-NEXT: ROOT %Arg_0.1 = s64[2] parameter(0), metadata={op_name="Arg_0.0"}
// CHECK-NEXT: ROOT %Arg_0.1 = s64[2] parameter(0)
// CHECK-NEXT{LITERAL}: %custom-call.2 = () custom-call(s64[2] %Arg_0.1), custom_call_target="xla_ffi_python_cpu_callback",
// CHECK-SAME{LITERAL}: operand_layout_constraints={s64[2]{0}}, custom_call_has_side_effect=true, api_version=API_VERSION_TYPED_FFI,
// CHECK-SAME{LITERAL}: sharding={{maximal device=0}}, metadata={op_name="custom-call.2"}, backend_config={descriptor = 126001424235520 : ui64}
// CHECK-SAME{LITERAL}: sharding={{maximal device=0}}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
Expand All @@ -702,5 +702,56 @@ ENTRY %main.0 (Arg_0.0: s64[2]) -> s64[2] {
module->ToString(HloPrintOptions{}.set_include_layout_in_shapes(false)),
expected));
}

TEST_F(ShardyXLATest, WhileShardingOnlyOnFreeVariables) {
const char* const hloString = R"(
HloModule main, entry_computation_layout={(f32[32,96]{1,0}, f32[32,96]{1,0})->f32[32,96]{1,0}}, frontend_attributes={xla.sdy.meshes="{mesh = #sdy.mesh<[\"x\"=4]>}"}
%region_0.6 (arg_tuple.7: (f32[32,96], s32[], f32[32,96])) -> (f32[32,96], s32[], f32[32,96]) {
%arg_tuple.7 = (f32[32,96]{1,0}, s32[], f32[32,96]{1,0}) parameter(0)
%get-tuple-element.8 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], f32[32,96]{1,0}) %arg_tuple.7), index=0
%sine.11 = f32[32,96]{1,0} sine(f32[32,96]{1,0} %get-tuple-element.8)
%custom-call.12 = f32[32,96]{1,0} custom-call(f32[32,96]{1,0} %sine.11), custom_call_target="Sharding", sharding={replicated}, frontend_attributes={xla.sdy.sharding="#sdy.sharding_per_value<[<@mesh, [{}, {}]>]>"}
%get-tuple-element.10 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], f32[32,96]{1,0}) %arg_tuple.7), index=2
%add.13 = f32[32,96]{1,0} add(f32[32,96]{1,0} %custom-call.12, f32[32,96]{1,0} %get-tuple-element.10)
%custom-call.14 = f32[32,96]{1,0} custom-call(f32[32,96]{1,0} %add.13), custom_call_target="Sharding", sharding={replicated}, frontend_attributes={xla.sdy.sharding="#sdy.sharding_per_value<[<@mesh, [{}, {}]>]>"}
%get-tuple-element.9 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], f32[32,96]{1,0}) %arg_tuple.7), index=1
%constant.15 = s32[] constant(1)
%add.16 = s32[] add(s32[] %get-tuple-element.9, s32[] %constant.15)
ROOT %tuple.17 = (f32[32,96]{1,0}, s32[], f32[32,96]{1,0}) tuple(f32[32,96]{1,0} %custom-call.14, s32[] %add.16, f32[32,96]{1,0} %get-tuple-element.10)
}
%region_1.18 (arg_tuple.19: (f32[32,96], s32[], f32[32,96])) -> pred[] {
%arg_tuple.19 = (f32[32,96]{1,0}, s32[], f32[32,96]{1,0}) parameter(0)
%get-tuple-element.20 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], f32[32,96]{1,0}) %arg_tuple.19), index=0
%get-tuple-element.22 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], f32[32,96]{1,0}) %arg_tuple.19), index=2
%get-tuple-element.21 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], f32[32,96]{1,0}) %arg_tuple.19), index=1
%constant.23 = s32[] constant(32)
ROOT %compare.24 = pred[] compare(s32[] %get-tuple-element.21, s32[] %constant.23), direction=LT
}
ENTRY %main.28 (Arg_0.1: f32[32,96], Arg_1.2: f32[32,96]) -> f32[32,96] {
%Arg_0.1 = f32[32,96]{1,0} parameter(0)
%constant.3 = s32[] constant(0)
%Arg_1.2 = f32[32,96]{1,0} parameter(1), sharding={devices=[4,1]<=[4]}, frontend_attributes={xla.sdy.sharding="#sdy.sharding<@mesh, [{\"x\", ?}, {?}]>"}
%custom-call.4 = f32[32,96]{1,0} custom-call(f32[32,96]{1,0} %Arg_1.2), custom_call_target="Sharding", sharding={replicated}, frontend_attributes={xla.sdy.sharding="#sdy.sharding_per_value<[<@mesh, [{?}, {?}]>]>"}
%tuple.5 = (f32[32,96]{1,0}, s32[], f32[32,96]{1,0}) tuple(f32[32,96]{1,0} %Arg_0.1, s32[] %constant.3, f32[32,96]{1,0} %custom-call.4)
%while.25 = (f32[32,96]{1,0}, s32[], f32[32,96]{1,0}) while((f32[32,96]{1,0}, s32[], f32[32,96]{1,0}) %tuple.5), condition=%region_1.18, body=%region_0.6
ROOT %get-tuple-element.26 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], f32[32,96]{1,0}) %while.25), index=0
%get-tuple-element.27 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], f32[32,96]{1,0}) %while.25), index=1
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(hloString));
runShardyWithSdyImport(module.get());

HloInstruction* whileInst =
FindInstruction(module.get(), xla::HloOpcode::kWhile);
EXPECT_NE(whileInst, nullptr);
// Verify the sharding of the while, and specifically that the sharding of the
// result that corresponds to parameter(1) is further sharded.
EXPECT_THAT(whileInst, op::Sharding("{{replicated}, {replicated}, "
"{devices=[4,1]<=[4]}}"));
}

} // namespace sdy
} // namespace xla
33 changes: 25 additions & 8 deletions xla/service/spmd/shardy/stablehlo_round_trip/export_shardings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ limitations under the License.
#include "xla/shape_util.h"
#include "xla/xla_data.pb.h"

namespace stablehlo = ::mlir::stablehlo;

namespace xla {
namespace sdy {

Expand All @@ -87,8 +89,6 @@ using ::mlir::success;
using ::mlir::SymbolTable;
using ::mlir::func::FuncOp;

using ::mlir::stablehlo::CustomCallOp;

using ::mlir::sdy::AxisRefAttr;
using ::mlir::sdy::DimensionShardingAttr;
using ::mlir::sdy::kShardingAttr;
Expand Down Expand Up @@ -149,7 +149,8 @@ SmallVector<AxisRefAttr> getOrderedAxisRefs(

// Convert the shardings from kShardingAttr into kXlaShardingAttr.
LogicalResult exportFunc(FuncOp funcOp, const SymbolTable& symbolTable,
OpBuilder& builder) {
OpBuilder& builder,
bool addMissingShardingToControlFlow) {
std::function<StringAttr(const HloSharding&)> getStringAttr =
[&](const HloSharding& hloSharding) {
return builder.getStringAttr(hloSharding.ToString());
Expand Down Expand Up @@ -187,6 +188,13 @@ LogicalResult exportFunc(FuncOp funcOp, const SymbolTable& symbolTable,
kXlaShardingAttr,
convertToHloShardingAttr(op, shardings, getMeshAttr, getStringAttr));
op->removeAttr(kShardingAttr);
} else if (addMissingShardingToControlFlow &&
mlir::isa<stablehlo::WhileOp, stablehlo::CaseOp,
stablehlo::IfOp>(op) &&
!op->hasAttr(kXlaShardingAttr)) {
// We check if the op already has a `kXlaShardingAttr`, since a manual
// sharding might have been added in shard map export pass.
op->setAttr(kXlaShardingAttr, getStringAttr(HloSharding::Replicate()));
}
});

Expand All @@ -199,6 +207,9 @@ class ExportStablehloShardingsPass
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ExportStablehloShardingsPass)

explicit ExportStablehloShardingsPass(bool addMissingShardingToControlFlow)
: addMissingShardingToControlFlow(addMissingShardingToControlFlow) {}

void runOnOperation() final {
ModuleOp moduleOp = getOperation();

Expand All @@ -208,12 +219,13 @@ class ExportStablehloShardingsPass
auto builder = OpBuilder::atBlockBegin(&moduleOp.getBodyRegion().front());

for (auto funcOp : moduleOp.getOps<FuncOp>()) {
if (mlir::failed(exportFunc(funcOp, symbolTable, builder))) {
if (mlir::failed(exportFunc(funcOp, symbolTable, builder,
addMissingShardingToControlFlow))) {
signalPassFailure();
}
}

moduleOp.walk([&](CustomCallOp customCall) {
moduleOp.walk([&](stablehlo::CustomCallOp customCall) {
// StableHLO doesn't have an equivalent of `erf` and `topk` ops.
// If they have a sharding annotation, we need to move it into
// `mhlo.attributes`, which StableHLO->MHLO conversion would lift back up.
Expand Down Expand Up @@ -255,6 +267,9 @@ class ExportStablehloShardingsPass
void getDependentDialects(mlir::DialectRegistry& registry) const final {
registry.insert<SdyDialect>();
}

private:
bool addMissingShardingToControlFlow;
};

} // namespace
Expand Down Expand Up @@ -376,12 +391,14 @@ StringAttr convertToHloShardingAttr(
HloSharding::Tuple(xla::ShapeUtil::MakeTupleShape(shapes), newShardings));
}

std::unique_ptr<Pass> createExportStablehloShardingsPass() {
return std::make_unique<ExportStablehloShardingsPass>();
std::unique_ptr<Pass> createExportStablehloShardingsPass(
bool addMissingShardingToControlFlow) {
return std::make_unique<ExportStablehloShardingsPass>(
addMissingShardingToControlFlow);
}

void registerStablehloExportShardingsPass() {
mlir::registerPass(createExportStablehloShardingsPass);
mlir::registerPass(std::bind(createExportStablehloShardingsPass, false));
}

} // namespace sdy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ mlir::StringAttr convertToHloShardingAttr(
// Creates a pass that converts the shardings from `kShardingAttr` to
// `kXlaShardingAttr` and removes mesh symbols. Fully or partially manual
// shardings are processed in `ShardMapExportPass`.
std::unique_ptr<mlir::Pass> createExportStablehloShardingsPass();
//
// If `addMissingShardingToControlFlow` is true, the pass will add a replicated
// hlo sharding to control flow ops (while, case, if) that have no sdy sharding.
std::unique_ptr<mlir::Pass> createExportStablehloShardingsPass(
bool addMissingShardingToControlFlow = false);

// Register the xla-sdy-stablehlo-export-shardings pass.
void registerStablehloExportShardingsPass();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@ void addStablehloExportPipeline(mlir::OpPassManager& pm) {
pm.addPass(createExportOpsPass());
pm.addPass(createStablehloRoundTripShardMapExportPass());
pm.addPass(createExportNamedComputationsPass());
pm.addPass(createExportStablehloShardingsPass());
// If we don't add a sharding to a control flow op without one,
// StableHLO -> HLO conversion won't add a sharding for that op even if a
// free variable that has a sharding is lifted as an additional result, and in
// effect the op will have a replicated sharding for all results.
pm.addPass(createExportStablehloShardingsPass(
/*addMissingShardingToControlFlow=*/true));
pm.addPass(createStablehloRoundTripExportCallbackCustomCallsPass());
}

Expand Down
19 changes: 19 additions & 0 deletions xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,25 @@ func.func @sharding_and_op_sharding_rule(%arg0: tensor<8x2xf32>, %arg1: tensor<8
return %0 : tensor<8x2xf64>
}

// CHECK-LABEL: func @while_with_no_sharding
func.func @while_with_no_sharding(
%arg0: tensor<32x96xf32>, %arg1: tensor<32x96xf32>)
-> tensor<32x96xf32> {
// CHECK: %[[C0:.*]] = stablehlo.constant dense<0>
// CHECK: stablehlo.while(%iterArg = %arg0, %iterArg_1 = %[[C0]])
// CHECK-NOT: mhlo.sharding
%0 = stablehlo.constant dense<0> : tensor<i32>
%1 = stablehlo.constant dense<32> : tensor<i32>
%3:2 = stablehlo.while(%iterArg = %arg0, %iterArg_1 = %0) : tensor<32x96xf32>, tensor<i32>
cond {
%4 = stablehlo.compare LT, %iterArg_1, %1 : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %4 : tensor<i1>
} do {
stablehlo.return %iterArg, %iterArg_1 : tensor<32x96xf32>, tensor<i32>
}
return %3#0 : tensor<32x96xf32>
}

// -----

// CHECK-NOT: xla.sdy.meshes
Expand Down
Loading

0 comments on commit a3ede7c

Please sign in to comment.