diff --git a/docs/spec_draft.md b/docs/spec_draft.md
index 2bbfc3e768..8cc8c406ec 100644
--- a/docs/spec_draft.md
+++ b/docs/spec_draft.md
@@ -179,7 +179,19 @@ particular process. (From that perspective, unqualified `name` can be viewed as
 a shorthand for `name@(replica_id(), partition_id())`).
 
 The execution order across processes is implementation-defined, except for the
-synchronization introduced by collective ops as described below.
+synchronization introduced by point-to-point communication and collective ops
+as described below.
+
+### Point-to-point communication
+
+StableHLO processes can communicate with each other through
+**StableHLO channels**. A channel is represented by a positive id of type
+`si64`. Through various ops, it is possible to send values to channels and
+receive them from channels.
+
+Further formalization, e.g. where these channel ids come from, how
+processes' programs become aware of them and what kind of synchronization is
+introduced by them, is TBD.
 
 ### Collective ops
 
@@ -197,9 +209,8 @@ and what happens if they don't, is TBD.
 If the process group involves cross-partition communication, i.e. there are
 processes in the process group whose partition ids are different, then execution
 of the collective op needs a **StableHLO channel**, and the collective op must
-provide a positive `channel_id` of type `si64`. Further formalization, e.g.
-where these channel ids are coming from and how they are synchronized between
-programs, is TBD. Cross-replica communication doesn't need channels.
+provide a positive `channel_id` of type `si64`. Cross-replica communication
+doesn't need channels.
 
 The computations performed by the collective ops are specific to individual ops
 and are described in individual op sections below. However, the strategies by
@@ -385,6 +396,7 @@ syntax.
 * [popcnt](#stablehlopopcnt)
 * [power](#stablehlopower)
 * [real](#stablehloreal)
+ * [recv](#stablehlorecv)
 * [reduce](#stablehloreduce)
 * [remainder](#stablehloremainder)
 * [replica_id](#stablehloreplica_id)
@@ -397,6 +409,7 @@ syntax.
 * [rsqrt](#stablehlorsqrt)
 * [scatter](#stablehloscatter)
 * [select](#stablehloselect)
+ * [send](#stablehlosend)
 * [shift_left](#stablehloshift_left)
 * [shift_right_arithmetic](#stablehloshift_right_arithmetic)
 * [shift_right_logical](#stablehloshift_right_logical)
@@ -3080,6 +3093,55 @@ More formally, for each element `x`: `real(x) = is_complex(x) ? x.real : x`.
 
 [Back to Ops](#index-of-ops)
 
+## stablehlo.recv
+
+### Semantics
+
+Receives data from a channel with `channel_id` and produces `results`.
+
+If `is_host_transfer` is `true`, then the operation transfers data from the
+host. Otherwise, it transfers data from another device. What this means is
+implementation-defined.
+
+`results` consist of payload values which come first and a token which comes
+last. The operation produces a token to reify its side effects as a value that
+other operations can take a data dependency on.
+
+### Inputs
+
+| Name               | Type                                            |
+|--------------------|-------------------------------------------------|
+| `token`            | `token`                                         |
+| `channel_id`       | constant of type `si64`                         |
+| `channel_type`     | enum of `DEVICE_TO_DEVICE` and `HOST_TO_DEVICE` |
+| `is_host_transfer` | constant of type `i1`                           |
+
+### Outputs
+
+| Name      | Type                                                      |
+|-----------|-----------------------------------------------------------|
+| `results` | variadic number of tensors of any supported type or token |
+
+### Constraints
+ * (C1) [todo](https://github.com/openxla/stablehlo/issues/579) `channel_type` must be
+   * `HOST_TO_DEVICE`, if `is_host_transfer` $=$ `true`,
+   * `DEVICE_TO_DEVICE`, otherwise.
+ * (C2) size(`results`) $\ge$ 1.
+ * (C3) type(`results`[-1]) $=$ `token`.
+
+### Examples
+
+```mlir
+%results:2 = "stablehlo.recv"(%token) {
+  // channel_id = 5 : i64,
+  // channel_type = #stablehlo<channel_type HOST_TO_DEVICE>,
+  channel_handle = #stablehlo.channel_handle<handle = 5, type = 3>,
+  is_host_transfer = true
+} : (!stablehlo.token) -> (tensor<3x4xi32>, !stablehlo.token)
+```
+
+[Back to Ops](#index-of-ops)
+
 ## stablehlo.reduce
 
 ### Semantics
@@ -3731,6 +3793,53 @@ where `pred_val = rank(pred) == 0 ? pred : pred[i0, ..., iR-1]`.
 
 [Back to Ops](#index-of-ops)
 
+## stablehlo.send
+
+### Semantics
+
+Sends `inputs` to a channel with `channel_id`.
+
+The operation takes a token and produces a token to reify its side effects
+as a value that other operations can take a data dependency on.
+
+If `is_host_transfer` is `true`, then the operation transfers data to the
+host. Otherwise, it transfers data to another device. What this means is
+implementation-defined.
+
+### Inputs
+
+| Name               | Type                                             |
+|--------------------|--------------------------------------------------|
+| `inputs`           | variadic number of tensors of any supported type |
+| `token`            | `token`                                          |
+| `channel_id`       | constant of type `si64`                          |
+| `channel_type`     | enum of `DEVICE_TO_DEVICE` and `DEVICE_TO_HOST`  |
+| `is_host_transfer` | constant of type `i1`                            |
+
+### Outputs
+
+| Name     | Type    |
+|----------|---------|
+| `result` | `token` |
+
+### Constraints
+ * (C1) [todo](https://github.com/openxla/stablehlo/issues/579) `channel_type` must be
+   * `DEVICE_TO_HOST`, if `is_host_transfer` $=$ `true`,
+   * `DEVICE_TO_DEVICE`, otherwise.
+
+### Examples
+
+```mlir
+%result = "stablehlo.send"(%operand, %token) {
+  // channel_id = 5 : i64,
+  // channel_type = #stablehlo<channel_type DEVICE_TO_HOST>,
+  channel_handle = #stablehlo.channel_handle<handle = 5, type = 2>,
+  is_host_transfer = true
+} : (tensor<3x4xi32>, !stablehlo.token) -> !stablehlo.token
+```
+
+[Back to Ops](#index-of-ops)
+
 ## stablehlo.shift_left
 
 ### Semantics
diff --git a/docs/status.md b/docs/status.md
index 06a734ebfb..2c6da2bc82 100644
--- a/docs/status.md
+++ b/docs/status.md
@@ -116,7 +116,7 @@ one of the following tracking labels.
 | power                  | yes | revisit | yes        | yes     | no |
 | real                   | yes | yes     | yes        | yes     | no |
 | real_dynamic_slice     | no  | revisit | no         | yes     | no |
-| recv                   | no  | revisit | no         | no      | no |
+| recv                   | yes | revisit | infeasible | no      | no |
 | reduce                 | yes | revisit | yes        | revisit | no |
 | reduce_precision       | no  | yes*    | yes*       | yes     | no |
 | reduce_scatter         | no  | revisit | no         | no      | no |
@@ -134,7 +134,7 @@ one of the following tracking labels.
 | scatter                | yes | revisit | no         | no      | no |
 | select                 | yes | yes     | yes        | yes     | no |
 | select_and_scatter     | no  | revisit | no         | no      | no |
-| send                   | no  | revisit | no         | no      | no |
+| send                   | yes | revisit | yes        | no      | no |
 | set_dimension_size     | no  | yes*    | yes*       | yes     | no |
 | shift_left             | yes | revisit | yes        | yes     | no |
 | shift_right_arithmetic | yes | revisit | yes        | yes     | no |
diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp
index 90db065bc2..74d4f567fc 100644
--- a/stablehlo/dialect/StablehloOps.cpp
+++ b/stablehlo/dialect/StablehloOps.cpp
@@ -2850,6 +2850,17 @@ LogicalResult MapOp::reifyReturnTypeShapes(
                                  &reifiedReturnShapes);
 }
 
+//===----------------------------------------------------------------------===//
+// Send Op
+//===----------------------------------------------------------------------===//
+
+LogicalResult SendOp::inferReturnTypes(
+    MLIRContext* context, Optional<Location>, ValueRange operands,
+    DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
+  inferredReturnTypes.push_back(operands[operands.size() - 1].getType());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // RecvOp
 //===----------------------------------------------------------------------===//
diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td
index 17a2acd23f..bb13abd9f8 100644
--- a/stablehlo/dialect/StablehloOps.td
+++ b/stablehlo/dialect/StablehloOps.td
@@ -1022,18 +1022,26 @@ def StableHLO_OutfeedOp : StableHLO_Op<"outfeed", []> {
   let results = (outs HLO_Token);
 }
 
-def StableHLO_SendOp : StableHLO_Op<"send", []> {
+def StableHLO_SendOp : StableHLO_Op<"send",
+    [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let summary = "Send operator";
 
   let description = [{
-    Sends the given operand data to a Recv instruction in another computation
-    that shares the same channel handle. Does not return any data. Similar to
-    the Recv operation, Send operation represents synchronous communication,
-    and is internally decomposed into 2 HLO instructions (Send and SendDone) to
-    enable asynchronous data transfers.
+    Sends `inputs` to a channel with `channel_id`.
 
-    See https://www.tensorflow.org/xla/operation_semantics#send.
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec_draft.md#stablehlosend
+
+    Example:
+    ```mlir
+    %result = "stablehlo.send"(%operand, %token) {
+      // channel_id = 5 : i64,
+      // channel_type = #stablehlo<channel_type DEVICE_TO_HOST>,
+      channel_handle = #stablehlo.channel_handle<handle = 5, type = 2>,
+      is_host_transfer = true
+    } : (tensor<3x4xi32>, !stablehlo.token) -> !stablehlo.token
+    ```
   }];
 
   let arguments = (ins
@@ -1051,14 +1059,20 @@ def StableHLO_RecvOp : StableHLO_Op<"recv", []> {
   let summary = "Recv operator";
 
   let description = [{
-    Receives data of the given shape from a Send instruction in another
-    computation that shares the same channel handle. Returns a tuple containing
-    value for the received data and a token. Recv operation represents
-    synchronous communication. However, the instruction is internally decomposed
-    into 2 HLO instructions (Recv and RecvDone) to enable asynchronous data
-    transfers.
+    Receives data from a channel with `channel_id` and produces `results`.
 
-    See https://www.tensorflow.org/xla/operation_semantics#recv.
+    See:
+    https://github.com/openxla/stablehlo/blob/main/docs/spec_draft.md#stablehlorecv
+
+    Example:
+    ```mlir
+    %results:2 = "stablehlo.recv"(%token) {
+      // channel_id = 5 : i64,
+      // channel_type = #stablehlo<channel_type HOST_TO_DEVICE>,
+      channel_handle = #stablehlo.channel_handle<handle = 5, type = 3>,
+      is_host_transfer = true
+    } : (!stablehlo.token) -> (tensor<3x4xi32>, !stablehlo.token)
+    ```
  }];
 
  let arguments = (ins