[Pallas] Reductions with replicated axes.
PiperOrigin-RevId: 717827641
Google-ML-Automation committed Feb 14, 2025
1 parent 60dcded commit b5179df
Showing 4 changed files with 279 additions and 115 deletions.
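As the new code in apply_vector_layout.cc below shows, vector.multi_reduction can now lower sources whose layout offsets are replicated (std::nullopt), provided the element type is 32-bit: a replicated dimension contributes a single vreg to the reduction slice, masking is skipped when the source is fully replicated, a SUM over a replicated axis is recovered by multiplying the accumulator by that axis's size, and MAX/MIN need no correction. The following Pallas sketch illustrates the kind of kernel this targets; it is a hypothetical repro, not part of the commit, and whether Mosaic actually assigns a replicated layout to the broadcast depends on its layout inference.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
  # Broadcasting one row along the sublane axis is the kind of op that can
  # leave Mosaic with a replicated sublane offset; the sum then reduces over
  # that replicated axis (float32, matching the 32-bit restriction above).
  row = x_ref[0:1, :]                     # shape (1, 128)
  tile = jnp.broadcast_to(row, (8, 128))  # replicated along sublanes
  o_ref[0, 0] = jnp.sum(tile)

x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
out = pl.pallas_call(
    kernel, out_shape=jax.ShapeDtypeStruct((1, 1), jnp.float32))(x)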
1 change: 1 addition & 0 deletions jaxlib/mosaic/BUILD
@@ -94,6 +94,7 @@ cc_library(
"@xla//xla:array",
"@xla//xla:shape_util",
"@xla//xla:util",
"@xla//xla/tsl/platform:errors",
] + pallas_extension_deps,
)

3 changes: 2 additions & 1 deletion jaxlib/mosaic/dialect/tpu/layout.cc
@@ -478,7 +478,8 @@ std::unique_ptr<VRegDataBounds> VectorLayout::tileDataBounds(

if (!hasNaturalTopology(target_shape)) {
if (!offsets_[0].has_value() || !offsets_[1].has_value()) {
emitError(UnknownLoc::get(mlir_ctx), "Not implemented");
emitError(UnknownLoc::get(mlir_ctx),
"Not implemented: non-natural topology with replication");
return nullptr;
}
const int64_t so = *offsets_[0];
301 changes: 187 additions & 114 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -69,6 +69,7 @@
#include "jaxlib/mosaic/dialect/tpu/vreg_util.h"
#include "xla/array.h"
#include "xla/layout.h"
#include "xla/tsl/platform/errors.h"
#include "xla/util.h"

// TODO(tlongeri): Prefer returning failure over CHECKs. In particular, be more
@@ -1997,7 +1998,7 @@ LogicalResult tpu_bitcast_rule(RewriteContext &ctx, Operation &op,
if (in_tiling != out_tiling) {
return op.emitOpError(
"Expected tilings are the same after multiplying the "
"second-minor dimension by the ratio of bitwidths.");
"second-minor dimension by the ratio of bitwidths.");
}
auto in_offsets = in_layout.offsets();
auto out_offsets = out_layout.offsets();
@@ -2012,7 +2013,7 @@ LogicalResult tpu_bitcast_rule(RewriteContext &ctx, Operation &op,
in_offsets[1] != out_offsets[1]) {
return op.emitOpError(
"Expected offsets are the same after multiplying the "
"second-minor dimension by the ratio of bitwidths.");
"second-minor dimension by the ratio of bitwidths.");
}
if (in_layout.implicit_dim() != out_layout.implicit_dim()) {
return op.emitOpError(
@@ -3805,7 +3806,7 @@ LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op,
extract_op.replaceAllUsesWith(
builder.create<vector::ExtractOp>(
op.getLoc(), rotated_vreg,
ArrayRef<int64_t>{0, 0})
.getResult());
}
extract_op.erase();
@@ -3956,7 +3957,6 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
false};
break;
}
const std::array<bool, 2> allow_replicated = {!reduces[0], !reduces[1]};

if ((reduces[0] || reduces[1]) &&
!src_layout.hasNativeTiling(ctx.target_shape)) {
@@ -3968,9 +3968,10 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
return multi_reduction_op.emitOpError("Not implemented: Tiling change");
}
for (int i = 0; i < 2; ++i) {
if (reduces[i] && src_layout.offsets()[i] == std::nullopt) {
if (reduces[i] && src_layout.offsets()[i] == std::nullopt &&
element_type.getIntOrFloatBitWidth() != 32) {
return multi_reduction_op.emitOpError(
"Not implemented: Reductions over replicated axes");
"Not implemented: Non-32-bit reductions over replicated axes");
}
// Offsets have to be equal, unless we're reducing over that dimension.
if (src_layout.offsets()[i] != dst_layout.offsets()[i] && !reduces[i]) {
@@ -4034,130 +4035,202 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<int64_t> src_shape = src_ty.getShape();
auto all_results_ok = dst_vregs.EachStatus(
[&](const absl::Span<const int64_t> idx, Value *const dst_vreg) {
// Extract a subset of source vregs that reduce into this result vreg.
SmallVector<int64_t> src_slice_start;
src_slice_start.reserve(src_rank);
SmallVector<int64_t> src_slice_end;
src_slice_end.reserve(src_rank);
for (int64_t i : idx) {
src_slice_start.push_back(i);
src_slice_end.push_back(i + 1);
}
for (int64_t d : dims) {
src_slice_start.insert(src_slice_start.begin() + d, 0);
src_slice_end.insert(src_slice_end.begin() + d, src_vregs.dim(d));
}
xla::Array<Value> reduced_vregs =
src_vregs.Slice(src_slice_start, src_slice_end);
std::optional<Value> acc_vreg;
auto reduce_elementwise = [&](Value lhs, Value rhs) -> Value {
Value result;
switch (tpu_kind) {
case tpu::ReductionKind::SUM:
result =
// Extract a subset of source vregs that reduce into this result vreg.
SmallVector<int64_t> src_slice_start;
src_slice_start.reserve(src_rank);
SmallVector<int64_t> src_slice_end;
src_slice_end.reserve(src_rank);
for (int64_t i : idx) {
src_slice_start.push_back(i);
src_slice_end.push_back(i + 1);
}
for (int64_t d : dims) {
int64_t d_size = src_vregs.dim(d);
src_slice_start.insert(src_slice_start.begin() + d, 0);
if (!src_layout.offsets()[0].has_value() && d == src_rank - 2) {
d_size = 1;
}
if (!src_layout.offsets()[1].has_value() && d == src_rank - 1) {
d_size = 1;
}
src_slice_end.insert(src_slice_end.begin() + d, d_size);
}
xla::Array<Value> reduced_vregs =
src_vregs.Slice(src_slice_start, src_slice_end);
std::optional<Value> acc_vreg;
auto reduce_elementwise = [&](Value lhs, Value rhs) -> Value {
Value result;
switch (tpu_kind) {
case tpu::ReductionKind::SUM:
result =
is_int
? builder.create<arith::AddIOp>(loc, lhs, rhs).getResult()
: builder.create<arith::AddFOp>(loc, lhs, rhs)
.getResult();
break;
case tpu::ReductionKind::MAX:
break;
case tpu::ReductionKind::MAX:
result = is_int ? builder.create<arith::MaxSIOp>(loc, lhs, rhs)
.getResult()
: builder.create<arith::MaximumFOp>(loc, lhs, rhs)
.getResult();
break;
case tpu::ReductionKind::MIN:
: builder.create<arith::MaximumFOp>(loc, lhs, rhs)
.getResult();
break;
case tpu::ReductionKind::MIN:
result = is_int ? builder.create<arith::MinSIOp>(loc, lhs, rhs)
.getResult()
: builder.create<arith::MinimumFOp>(loc, lhs, rhs)
.getResult();
break;
: builder.create<arith::MinimumFOp>(loc, lhs, rhs)
.getResult();
break;
}
return result;
};
auto reduction_status = reduced_vregs.EachStatus(
[&](const absl::Span<const int64_t> red_idx, Value *const src_vreg) {
SmallVector<int64_t> src_idx(red_idx.begin(), red_idx.end());
for (int i = 0; i < src_idx.size(); ++i) {
src_idx[i] += src_slice_start[i];
}
return result;
};
auto reduction_status = reduced_vregs.EachStatus(
[&](const absl::Span<const int64_t> red_idx,
Value *const src_vreg) {
SmallVector<int64_t> src_idx(red_idx.begin(), red_idx.end());
for (int i = 0; i < src_idx.size(); ++i) {
src_idx[i] += src_slice_start[i];
}
const std::unique_ptr<VRegDataBounds> data_bounds =
src_layout.tileDataBounds(builder.getContext(), src_shape,
src_idx, ctx.target_shape,
allow_replicated);
if (data_bounds == nullptr) {
// Op error has already been emitted inside tileDataBounds().
return absl::UnknownError("Unable to obtain data bounds");
}
// TODO(tlongeri): Maybe assemble/disassemble should take
// TypedValue<VectorType> and we could save casts here and
// elsewhere
FailureOr<Value> failure_or_vreg =
maskOOB(ctx, builder, cast<TypedValue<VectorType>>(*src_vreg),
*data_bounds, neutral);
if (failed(failure_or_vreg)) {
op.emitOpError("Failed to mask vreg");
return absl::UnknownError("");
}
Value vreg = failure_or_vreg.value();
if (!acc_vreg.has_value()) {
acc_vreg = vreg;
} else {
acc_vreg = reduce_elementwise(*acc_vreg, vreg);
}
return absl::OkStatus();
});
if (!reduction_status.ok()) {
return reduction_status;
}
TPU_ASSERT_OP(acc_vreg.has_value());
if (reduces[1]) {
acc_vreg = builder.create<tpu::AllReduceOp>(
multi_reduction_op->getLoc(), *acc_vreg, 1, tpu_kind);
const std::unique_ptr<VRegDataBounds> data_bounds =
src_layout.tileDataBounds(builder.getContext(), src_shape,
src_idx, ctx.target_shape,
{true, true});
if (data_bounds == nullptr) {
// Op error has already been emitted inside tileDataBounds().
return absl::UnknownError("Unable to obtain data bounds");
}
Value vreg = *src_vreg;
// If replicated, we don't need to mask.
if (src_layout.offsets()[0].has_value() ||
src_layout.offsets()[1].has_value()) {
// TODO(tlongeri): Maybe assemble/disassemble should take
// TypedValue<VectorType> and we could save casts here and
// elsewhere
FailureOr<Value> failure_or_vreg =
maskOOB(ctx, builder, cast<TypedValue<VectorType>>(*src_vreg),
*data_bounds, neutral);
if (failed(failure_or_vreg)) {
op.emitOpError("Failed to mask vreg");
return absl::UnknownError("");
}
vreg = failure_or_vreg.value();
}
if (!acc_vreg.has_value()) {
acc_vreg = vreg;
} else {
acc_vreg = reduce_elementwise(*acc_vreg, vreg);
}
return absl::OkStatus();
});
TF_RETURN_IF_ERROR(reduction_status);
TPU_ASSERT_OP(acc_vreg.has_value());
const bool is_double_replicated_double_reduced =
reduces[0] && reduces[1] && !src_layout.offsets()[0].has_value() &&
!src_layout.offsets()[1].has_value();
if (reduces[1]) {
if (src_layout.offsets()[1].has_value()) {
acc_vreg = builder.create<tpu::AllReduceOp>(
multi_reduction_op->getLoc(), *acc_vreg, /* dim= */ 1, tpu_kind);
} else {
int64_t size_dim1 = src_layout.getImplicitTiledDims(src_shape, 1)[1];
if (is_double_replicated_double_reduced) {
size_dim1 *= src_layout.getImplicitTiledDims(src_shape, 1)[0];
}
if (reduces[0]) {
// Packed types are compressed along rows, so we need to reduce them
// within each 32-bit word. There's no performance penalty for doing
// this in 32-bit precision, so we take advantage of it.
Type acc_vreg_ty = acc_vreg->getType();
if (acc_layout.packing() > 1) {
Type vreg_ty_32 = nullptr;
if (acc.getType().getElementType().isBF16()) {
vreg_ty_32 =
getNativeVregType(builder.getF32Type(), ctx.target_shape);
switch (tpu_kind) {
case tpu::ReductionKind::SUM:
if (is_int) {
IntegerAttr size_attr = builder.getI32IntegerAttr(size_dim1);
TypedValue<VectorType> source_value = getFullVector(
builder,
getNativeVregType(builder.getI32Type(), ctx.target_shape),
size_attr);
acc_vreg =
builder.create<arith::MulIOp>(loc, *acc_vreg, source_value);
} else {
multi_reduction_op.emitOpError(
"Not implemented: Unsupported reduction dtype");
return absl::UnknownError("");
FloatAttr size_attr = builder.getF32FloatAttr(size_dim1);
TypedValue<VectorType> source_value = getFullVector(
builder,
getNativeVregType(builder.getF32Type(), ctx.target_shape),
size_attr);
acc_vreg =
builder.create<arith::MulFOp>(loc, *acc_vreg, source_value);
}
Value acc_vreg_32 = builder.create<tpu::UnpackSubelementsOp>(
loc, vreg_ty_32, *acc_vreg, 0, tpu::PackFormat::kInterleaved);
for (int i = 1; i < acc_layout.packing(); ++i) {
Value acc_vreg_part_32 = builder.create<tpu::UnpackSubelementsOp>(
loc, vreg_ty_32, *acc_vreg, i, tpu::PackFormat::kInterleaved);
acc_vreg_32 = reduce_elementwise(acc_vreg_32, acc_vreg_part_32);
break;
// We don't need to do anything for other reduction kinds.
case tpu::ReductionKind::MAX:
case tpu::ReductionKind::MIN:
break;
}
}
}
if (reduces[0]) {
// Packed types are compressed along rows, so we need to reduce them
// within each 32-bit word. There's no performance penalty for doing
// this in 32-bit precision, so we take advantage of it.
Type acc_vreg_ty = acc_vreg->getType();
if (acc_layout.packing() > 1) {
Type vreg_ty_32 = nullptr;
if (acc.getType().getElementType().isBF16()) {
vreg_ty_32 =
getNativeVregType(builder.getF32Type(), ctx.target_shape);
} else {
multi_reduction_op.emitOpError(
"Not implemented: Unsupported reduction dtype");
return absl::UnknownError("");
}
Value acc_vreg_32 = builder.create<tpu::UnpackSubelementsOp>(
loc, vreg_ty_32, *acc_vreg, 0, tpu::PackFormat::kInterleaved);
for (int i = 1; i < acc_layout.packing(); ++i) {
Value acc_vreg_part_32 = builder.create<tpu::UnpackSubelementsOp>(
loc, vreg_ty_32, *acc_vreg, i, tpu::PackFormat::kInterleaved);
acc_vreg_32 = reduce_elementwise(acc_vreg_32, acc_vreg_part_32);
}
acc_vreg = acc_vreg_32;
}
// At this point acc_vreg is always 32-bit.
if (src_layout.offsets()[0].has_value()) {
acc_vreg = builder.create<tpu::AllReduceOp>(
multi_reduction_op->getLoc(), *acc_vreg, 0, tpu_kind);
} else if (!is_double_replicated_double_reduced) {
int64_t size_dim0 = src_layout.getImplicitTiledDims(src_shape, 1)[0];
switch (tpu_kind) {
case tpu::ReductionKind::SUM:
if (is_int) {
IntegerAttr size_attr = builder.getI32IntegerAttr(size_dim0);
TypedValue<VectorType> source_value = getFullVector(
builder,
getNativeVregType(builder.getI32Type(), ctx.target_shape),
size_attr);
acc_vreg =
builder.create<arith::MulIOp>(loc, *acc_vreg, source_value);
} else {
FloatAttr size_attr = builder.getF32FloatAttr(size_dim0);
TypedValue<VectorType> source_value = getFullVector(
builder,
getNativeVregType(builder.getF32Type(), ctx.target_shape),
size_attr);
acc_vreg =
builder.create<arith::MulFOp>(loc, *acc_vreg, source_value);
}
acc_vreg = acc_vreg_32;
}
// At this point acc_vreg is always 32-bit.
acc_vreg = builder.create<tpu::AllReduceOp>(
multi_reduction_op->getLoc(), *acc_vreg, 0, tpu_kind);
// We pack the final result back into the original type.
if (acc_layout.packing() > 1) {
SmallVector<int32_t> positions(acc_layout.packing());
break;
case tpu::ReductionKind::MAX:
case tpu::ReductionKind::MIN:
break;
}
}
// We pack the final result back into the original type.
if (acc_layout.packing() > 1) {
SmallVector<int32_t> positions(acc_layout.packing());
std::iota(positions.begin(), positions.end(),
static_cast<int32_t>(0));
SmallVector<Value> parts(acc_layout.packing(), *acc_vreg);
acc_vreg = builder.create<tpu::PackSubelementsOp>(
SmallVector<Value> parts(acc_layout.packing(), *acc_vreg);
acc_vreg = builder.create<tpu::PackSubelementsOp>(
loc, acc_vreg_ty, parts,
builder.getDenseI32ArrayAttr(positions),
tpu::PackFormat::kInterleaved);
}
}
*dst_vreg = *acc_vreg;
return absl::OkStatus();
});
tpu::PackFormat::kInterleaved);
}
}
*dst_vreg = *acc_vreg;
return absl::OkStatus();
});
if (!all_results_ok.ok()) {
return failure();
}
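The else-branches added above rely on a simple identity: along a replicated axis every position holds the same value x, so a SUM over n replicated positions equals n * x (hence the MulIOp/MulFOp by the axis size), while MAX and MIN are just x and need no fixup. A plain JAX check of that identity (illustrative only, no Pallas or Mosaic involved):

import jax.numpy as jnp

x = jnp.float32(3.5)
n = 8
copies = jnp.full((n,), x)   # what a replicated axis of size n holds
assert jnp.sum(copies) == n * x                   # SUM needs the multiply
assert jnp.max(copies) == x and jnp.min(copies) == x  # MAX/MIN need nothing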
@@ -4702,7 +4775,7 @@ LogicalResult tpu_prng_random_bits_rule(RewriteContext &ctx, Operation &op,
const VectorLayout &layout_out = *layouts_out.front();
tpu::PRNGRandomBitsOp rng_op = cast<tpu::PRNGRandomBitsOp>(op);
if (layout_out != VectorLayout(32, {0, 0}, ctx.target_shape,
VectorLayout::ImplicitDim::kNone)) {
return op.emitOpError(
"Unsupported output layout for ") << rng_op->getName();
}