Add spec for SendOp and RecvOp (#580)
fixes #527 

verification will be revisited in
#579 and
#667
sdasgup3 authored Dec 2, 2022
1 parent 8f4dfc2 commit 8a2ce22
Showing 4 changed files with 154 additions and 20 deletions.
117 changes: 113 additions & 4 deletions docs/spec_draft.md
@@ -179,7 +179,19 @@ particular process. (From that perspective, unqualified `name` can be viewed as
a shorthand for `name@(replica_id(), partition_id())`).

The execution order across processes is implementation-defined, except for the
synchronization introduced by point-to-point communication and collective ops
as described below.

### Point-to-point communication

StableHLO processes can communicate with each other through
**StableHLO channels**. A channel is represented by a positive id of type
`si64`. Through various ops, it is possible to send values to channels and
receive them from channels.

Further formalization, e.g. where these channel ids come from, how processes'
programs become aware of them, and what kind of synchronization they
introduce, is TBD.
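For illustration, a device-to-device transfer pairs a `send` on one process
with a `recv` on another process over the same channel id. The sketch below
follows the op syntax used later in this spec; the `type = 1` encoding for
`DEVICE_TO_DEVICE` in `channel_handle` is an assumption, made by analogy with
the host-transfer encodings (`type = 2`, `type = 3`) in the op examples:

```mlir
// On the sending process: forward a tensor into channel 5.
// type = 1 is assumed to encode DEVICE_TO_DEVICE.
%send_token = "stablehlo.send"(%operand, %token) {
  channel_handle = #stablehlo.channel_handle<handle = 5, type = 1>,
  is_host_transfer = false
} : (tensor<3x4xi32>, !stablehlo.token) -> !stablehlo.token

// On the receiving process: read the tensor back out of channel 5.
%results:2 = "stablehlo.recv"(%token2) {
  channel_handle = #stablehlo.channel_handle<handle = 5, type = 1>,
  is_host_transfer = false
} : (!stablehlo.token) -> (tensor<3x4xi32>, !stablehlo.token)
```

The tokens consumed and produced by both ops let a program order these
side-effecting transfers relative to other operations.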

### Collective ops

@@ -197,9 +209,8 @@ and what happens if they don't, is TBD.
If the process group involves cross-partition communication, i.e. there are
processes in the process group whose partition ids are different, then execution
of the collective op needs a **StableHLO channel**, and the collective op must
provide a positive `channel_id` of type `si64`. Cross-replica communication
doesn't need channels.

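As a sketch of what this looks like in practice (the attribute names and
generic syntax are assumed from the StableHLO dialect, and `type = 1` is
assumed to encode `DEVICE_TO_DEVICE`), a collective op carries its channel id
via a `channel_handle` attribute:

```mlir
// A cross-partition all_reduce summing over a group of two processes.
// The channel_handle supplies the positive channel_id required for
// cross-partition communication.
%result = "stablehlo.all_reduce"(%operand) ({
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
  %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  "stablehlo.return"(%0) : (tensor<f32>) -> ()
}) {
  replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
  channel_handle = #stablehlo.channel_handle<handle = 1, type = 1>
} : (tensor<4xf32>) -> tensor<4xf32>
```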
The computations performed by the collective ops are specific to individual ops
and are described in individual op sections below. However, the strategies by
@@ -385,6 +396,7 @@ syntax.
* [popcnt](#stablehlopopcnt)
* [power](#stablehlopower)
* [real](#stablehloreal)
* [recv](#stablehlorecv)
* [reduce](#stablehloreduce)
* [remainder](#stablehloremainder)
* [replica_id](#stablehloreplica_id)
@@ -397,6 +409,7 @@ syntax.
* [rsqrt](#stablehlorsqrt)
* [scatter](#stablehloscatter)
* [select](#stablehloselect)
* [send](#stablehlosend)
* [shift_left](#stablehloshift_left)
* [shift_right_arithmetic](#stablehloshift_right_arithmetic)
* [shift_right_logical](#stablehloshift_right_logical)
@@ -3080,6 +3093,55 @@ More formally, for each element `x`: `real(x) = is_complex(x) ? x.real : x`.

[Back to Ops](#index-of-ops)

## stablehlo.recv

### Semantics

Receives data from a channel with `channel_id` and produces `results`.

If `is_host_transfer` is `true`, then the operation transfers data from the
host. Otherwise, it transfers data from another device. What this means is
implementation-defined.

`results` consist of payload values which come first and a token which comes
last. The operation produces a token to reify its side effects as a value that
other operations can take a data dependency on.

### Inputs

| Name | Type |
|--------------------|-------------------------------------------------|
| `token` | `token` |
| `channel_id` | constant of type `si64` |
| `channel_type` | enum of `DEVICE_TO_DEVICE` and `HOST_TO_DEVICE` |
| `is_host_transfer` | constant of type `i1` |

### Outputs

| Name | Type |
|-----------|-----------------------------------------------------------|
| `results` | variadic number of tensors of any supported type or token |

### Constraints

* (C1) [todo](https://github.com/openxla/stablehlo/issues/579) `channel_type` must be
  * `HOST_TO_DEVICE`, if `is_host_transfer` $=$ `true`,
  * `DEVICE_TO_DEVICE`, otherwise.
* (C2) size(`results`) $\ge$ 1.
* (C3) type(`results`[-1]) $=$ `token`.

### Examples

```mlir
%results:2 = "stablehlo.recv"(%token) {
// channel_id = 5 : i64,
// channel_type = #stablehlo<channel_type HOST_TO_DEVICE>,
channel_handle = #stablehlo.channel_handle<handle = 5, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<3x4xi32>, !stablehlo.token)
```

[Back to Ops](#index-of-ops)

## stablehlo.reduce

### Semantics
@@ -3731,6 +3793,53 @@ where `pred_val = rank(pred) == 0 ? pred : pred[i0, ..., iR-1]`.

[Back to Ops](#index-of-ops)

## stablehlo.send

### Semantics

Sends `inputs` to a channel with `channel_id`.

The operation takes a token and produces a token to reify its side effects
as a value that other operations can take a data dependency on.

If `is_host_transfer` is `true`, then the operation transfers data to the
host. Otherwise, it transfers data to another device. What this means is
implementation-defined.

### Inputs

| Name | Type |
|--------------------|--------------------------------------------------|
| `inputs` | variadic number of tensors of any supported type |
| `token` | `token` |
| `channel_id` | constant of type `si64` |
| `channel_type` | enum of `DEVICE_TO_DEVICE` and `DEVICE_TO_HOST` |
| `is_host_transfer` | constant of type `i1` |

### Outputs

| Name | Type |
|-----------|---------|
| `result` | `token` |

### Constraints

* (C1) [todo](https://github.com/openxla/stablehlo/issues/579) `channel_type` must be
  * `DEVICE_TO_HOST`, if `is_host_transfer` $=$ `true`,
  * `DEVICE_TO_DEVICE`, otherwise.

### Examples

```mlir
%result = "stablehlo.send"(%operand, %token) {
// channel_id = 5 : i64,
// channel_type = #stablehlo<channel_type DEVICE_TO_HOST>,
channel_handle = #stablehlo.channel_handle<handle = 5, type = 2>,
is_host_transfer = true
} : (tensor<3x4xi32>, !stablehlo.token) -> !stablehlo.token
```

[Back to Ops](#index-of-ops)

## stablehlo.shift_left

### Semantics
4 changes: 2 additions & 2 deletions docs/status.md
@@ -116,7 +116,7 @@ one of the following tracking labels.
| power | yes | revisit | yes | yes | no |
| real | yes | yes | yes | yes | no |
| real_dynamic_slice | no | revisit | no | yes | no |
| recv | yes | revisit | infeasible | no | no |
| reduce | yes | revisit | yes | revisit | no |
| reduce_precision | no | yes* | yes* | yes | no |
| reduce_scatter | no | revisit | no | no | no |
@@ -134,7 +134,7 @@ one of the following tracking labels.
| scatter | yes | revisit | no | no | no |
| select | yes | yes | yes | yes | no |
| select_and_scatter | no | revisit | no | no | no |
| send | yes | revisit | yes | no | no |
| set_dimension_size | no | yes* | yes* | yes | no |
| shift_left | yes | revisit | yes | yes | no |
| shift_right_arithmetic | yes | revisit | yes | yes | no |
11 changes: 11 additions & 0 deletions stablehlo/dialect/StablehloOps.cpp
@@ -2850,6 +2850,17 @@ LogicalResult MapOp::reifyReturnTypeShapes(
&reifiedReturnShapes);
}

//===----------------------------------------------------------------------===//
// Send Op
//===----------------------------------------------------------------------===//

LogicalResult SendOp::inferReturnTypes(
MLIRContext* context, Optional<Location>, ValueRange operands,
DictionaryAttr, RegionRange, SmallVectorImpl<Type>& inferredReturnTypes) {
inferredReturnTypes.push_back(operands[operands.size() - 1].getType());
return success();
}

//===----------------------------------------------------------------------===//
// RecvOp
//===----------------------------------------------------------------------===//
42 changes: 28 additions & 14 deletions stablehlo/dialect/StablehloOps.td
@@ -1022,18 +1022,26 @@ def StableHLO_OutfeedOp : StableHLO_Op<"outfeed", []> {
let results = (outs HLO_Token);
}

def StableHLO_SendOp : StableHLO_Op<"send",
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {

let summary = "Send operator";

let description = [{
Sends `inputs` to a channel `channel_id`.

See:
https://github.com/openxla/stablehlo/blob/main/docs/spec_draft.md#stablehlosend

Example:
```mlir
%result = "stablehlo.send"(%operand, %token) {
// channel_id = 5 : i64,
// channel_type = #stablehlo<channel_type DEVICE_TO_HOST>,
channel_handle = #stablehlo.channel_handle<handle = 5, type = 2>,
is_host_transfer = true
} : (tensor<3x4xi32>, !stablehlo.token) -> !stablehlo.token
```
}];

let arguments = (ins
@@ -1051,14 +1059,20 @@ def StableHLO_RecvOp : StableHLO_Op<"recv", []> {
let summary = "Recv operator";

let description = [{
Receives data from a channel with `channel_id` and produces `results`.

See:
https://github.com/openxla/stablehlo/blob/main/docs/spec_draft.md#stablehlorecv

Example:
```mlir
%results:2 = "stablehlo.recv"(%token) {
// channel_id = 5 : i64,
// channel_type = #stablehlo<channel_type HOST_TO_DEVICE>,
channel_handle = #stablehlo.channel_handle<handle = 5, type = 3>,
is_host_transfer = true
} : (!stablehlo.token) -> (tensor<3x4xi32>, !stablehlo.token)
```
}];

let arguments = (ins
