diff --git a/xla/service/layout_assignment.cc b/xla/service/layout_assignment.cc index 18ea43a2284ef..e415e5bc949d1 100644 --- a/xla/service/layout_assignment.cc +++ b/xla/service/layout_assignment.cc @@ -2427,9 +2427,8 @@ absl::Status LayoutAssignment::RunOnComputation( // parameter layout constraints). TF_RETURN_IF_ERROR(AddMandatoryConstraints(channel_constraints, constraints)); - // Add any backend-specific constraints. - TF_RETURN_IF_ERROR(AddBackendConstraints(constraints)); - + // Custom call constraints should be propagated with more priority and + // carefully than mandatory constraints but not more that backend constraints. for (HloInstruction* instruction : constraints->computation()->MakeInstructionPostOrder()) { if (!IsLayoutConstrainedCustomCall(instruction)) { @@ -2457,6 +2456,9 @@ absl::Status LayoutAssignment::RunOnComputation( } } + // Add any backend-specific constraints. + TF_RETURN_IF_ERROR(AddBackendConstraints(constraints)); + // Propagates layouts from mandatory and backend constraints. TF_RETURN_IF_ERROR(PropagateConstraints(constraints));