<!-- Autogenerated by mlir-tblgen; don't manually edit -->
# 'sdy' Dialect
_The Shardy (SDY) dialect defines an axis-based tensor sharding
representation and additional API components to attach shardings to tensors._
[TOC]
## Operations
### `sdy.all_gather` (sdy::AllGatherOp)
_Gathers chunks of a tensor along axes_
Syntax:
```
operation ::= `sdy.all_gather` $gatheringAxes $tensor `out_sharding````=```$outSharding attr-dict `:` type($result)
```
Gathers chunks of a tensor along axes specified in `gatheringAxes`.
The `gatheringAxes` is a list of lists of axes. Each inner list specifies
the axes along which a separate gather should be performed. The outer list
is over the dimensions of the tensor. It will be applied to the sharding of
the operand (`tensor`) to obtain the sharding of the result (`outSharding`).
Note that `outSharding` is not used to determine the sharding of the result.
Instead, the sharding of the result is determined by the sharding of the
operand and the `gatheringAxes`, and `outSharding` must match this inferred
sharding.
Example:
```mlir
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", "b", "c"}, {}, {"d"}\]>]>} : tensor<8x8x8xf32>
%2 = sdy.all_gather [{"b", "c"}, {}, {"d"}\] %1 out_sharding=<@mesh, [{"a"}, {}, {}\]> : tensor<8x8x8xf32>
```
Traits: `SameOperandsAndResultType`
Interfaces: `InferTypeOpInterface`
#### Attributes:
<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>gatheringAxes</code></td><td>::mlir::sdy::ListOfAxisRefListsAttr</td><td></td></tr>
<tr><td><code>outSharding</code></td><td>::mlir::sdy::TensorShardingAttr</td><td>Tensor sharding</td></tr>
</table>
#### Operands:
| Operand | Description |
| :-----: | ----------- |
| `tensor` | tensor of any type values |
#### Results:
| Result | Description |
| :----: | ----------- |
| `result` | tensor of any type values |
### `sdy.constant` (sdy::ConstantOp)
_Constant operation_
Produces an `output` tensor from a constant `value`.
See:
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#constant
NOTE: SDY defines its own constant op that isn't ConstantLike and doesn't
have a folder, so that we'll be able to duplicate constants without any
greedy pattern rewriter folding them back into a single constant. In this
way, constants can be sharded differently for every use, and no propagation
is done between constants (or constant expressions).
Example:
```mlir
%output = sdy.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
```
Traits: `AlwaysSpeculatableImplTrait`
Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`
Effects: `MemoryEffects::Effect{}`
#### Attributes:
<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>value</code></td><td>::mlir::ElementsAttr</td><td>constant vector/tensor attribute</td></tr>
</table>
#### Results:
| Result | Description |
| :----: | ----------- |
| `output` | tensor of any type values |
### `sdy.data_flow_edge` (sdy::DataFlowEdgeOp)
_Data flow edge op._
Syntax:
```
operation ::= `sdy.data_flow_edge` $input (`sharding````=``` $sharding^)? attr-dict `:` type($result)
```
A data flow edge of some op X defines a bridge between a set of sources
(each is either an operand of X or an operand of X's block terminator) and
a set of targets (each is either a result of X or a block argument of X),
such that all sources and targets should be sharded in the same way.
An op can have multiple data flow edges that are orthogonal to one another.
For example:
```mlir
y_0, ..., y_n = while (x_0, ..., x_n)
                ((pred_arg_0,... , pred_arg_n) { ... })
                ((body_arg_0,..., body_arg_n) {
                  ...
                  return return_value_0, ..., return_value_n
                })
```
This while op has n data flow edges; the i-th data flow edge is between
sources `x_i`, `return_value_i` and targets `y_i`, `pred_arg_i`,
`body_arg_i`.
An `sdy.data_flow_edge` takes as input the root target of an edge (can be
any of the targets, but preferably an op result rather than a block
argument), which shouldn't have any other uses. This op isn't pure because
it can take an input that originally didn't have any uses.
The `sdy.data_flow_edge` also holds an optional sharding for all targets of
the edge, and that sharding should be updated instead of the targets'
sharding (if it can be attached) during propagation. This is useful when an
op has many edges, as it's much more efficient to:
- propagate through each edge separately.
- update the sharding of each edge separately instead of all targets at once
(e.g. an op has a single immutable `TensorShardingPerValueAttr` for result
shardings).
- add each edge to the worklist separately when the sharding of a source has
changed.
Propagation will propagate shardings between all sources and targets of a
`sdy.data_flow_edge` as if it were a regular op with the sources as operands
and targets as results, and an identity `sdy.op_sharding_rule`. That means
that forward propagation is from sources to targets and backwards
propagation is from targets to sources.
We don't allow the input of a `sdy.data_flow_edge` to be defined by an
`SdyDialect` op, so we can assume that it's defined by an op that has an
unregistered `sdy.sharding` attribute.
NOTE: it's NOT the responsibility of the `sdy.data_flow_edge` to link
between sources and targets; it's simply attached to the root target of the
edge. The op that this edge is bound to (`while` in the example above) is
responsible for providing this information.
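For illustration, a minimal sketch of the textual form (the mesh name, sharding, and tensor shape below are assumed for the example):
```mlir
// Attached to the root target of an edge (e.g. a while result). The optional
// sharding is what propagation updates, in place of the targets' shardings.
%1 = sdy.data_flow_edge %0 sharding=<@mesh, [{"a"}, {}]> : tensor<32x96xf32>
```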
Traits: `SameOperandsAndResultType`
Interfaces: `InferTypeOpInterface`
#### Attributes:
<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>sharding</code></td><td>::mlir::sdy::TensorShardingAttr</td><td>Tensor sharding</td></tr>
</table>
#### Operands:
| Operand | Description |
| :-----: | ----------- |
| `input` | shaped of any type values |
#### Results:
| Result | Description |
| :----: | ----------- |
| `result` | shaped of any type values |
### `sdy.manual_computation` (sdy::ManualComputationOp)
_Multi-device parallelism operation with manual collectives_
Syntax:
```
operation ::= `sdy.manual_computation` `(`operands`)`
    `in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)
    `out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)
    `manual_axes````=```$manual_axes
    custom<SingleBlockRegionNoBlockId>($body)
    attr-dict
    `:`
    functional-type(operands, results)
```
Jump into a region written in terms of per-device local code with explicit
collectives, where logical shapes match local per-device physical buffer
shapes and collectives correspond exactly to physical cross-device
communication.
The body is local with respect to the `manual_axes`. Propagation will occur
through the body on any free axes - those not in the `manual_axes` list.
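A hedged sketch of the textual form (the mesh `@mesh = <["a"=2]>`, shapes, and body below are assumed for the example). The global shape `16x32` is sharded on `"a"` along the first dimension, so the body operates on the local shape `8x32`:
```mlir
sdy.mesh @mesh = <["a"=2]>

%0 = sdy.manual_computation(%arg0)
    in_shardings=[<@mesh, [{"a"}, {}]>]
    out_shardings=[<@mesh, [{"a"}, {}]>]
    manual_axes={"a"}
    (%arg1: tensor<8x32xf32>) {
  // Per-device local code; cross-device communication would use explicit
  // collectives here.
  %1 = stablehlo.add %arg1, %arg1 : tensor<8x32xf32>
  sdy.return %1 : tensor<8x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
```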
Traits: `IsolatedFromAbove`, `RecursiveMemoryEffects`, `SingleBlockImplicitTerminator<ReturnOp>`, `SingleBlock`
Interfaces: `ShardableDataFlowOpInterface`
#### Attributes:
<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>in_shardings</code></td><td>::mlir::sdy::TensorShardingPerValueAttr</td><td>Tensor sharding per operand/result of an op</td></tr>
<tr><td><code>out_shardings</code></td><td>::mlir::sdy::TensorShardingPerValueAttr</td><td>Tensor sharding per operand/result of an op</td></tr>
<tr><td><code>manual_axes</code></td><td>::mlir::sdy::ManualAxesAttr</td><td></td></tr>
</table>
#### Operands:
| Operand | Description |
| :-----: | ----------- |
| `tensors` | variadic of ranked tensor of any type values |
#### Results:
| Result | Description |
| :----: | ----------- |
| `results` | variadic of ranked tensor of any type values |
### `sdy.mesh` (sdy::MeshOp)
_Named mesh_
Syntax:
```
operation ::= `sdy.mesh` $sym_name `=` $mesh attr-dict
```
Defines a new named mesh. All meshes in a module must have the same number
of devices (except for meshes with a single device_id).
The mesh is a `Symbol` operation that appears in the module's
`SymbolTable` and can be referenced by its `name`.
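For example (the axis names and sizes are illustrative):
```mlir
sdy.mesh @mesh = <["a"=2, "b"=4]>
```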
Traits: `HasParent<ModuleOp>`
Interfaces: `Symbol`
#### Attributes:
<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>sym_name</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
<tr><td><code>mesh</code></td><td>::mlir::sdy::MeshAttr</td><td>Mesh of axes and a list of devices</td></tr>
</table>
### `sdy.named_computation` (sdy::NamedComputationOp)
_Named computation operation_
Syntax:
```
operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
    (`in_shardings````=```custom<StrippedTensorShardingPerValueAttr>($in_shardings)^)?
    (`out_shardings````=```custom<StrippedTensorShardingPerValueAttr>($out_shardings)^)?
    custom<SingleBlockRegionNoBlockId>($body)
    attr-dict
    `:` functional-type($operands, results)
```
Groups a computation, i.e. a block of operations, and gives it a name.
Propagation will flow in/out of the region as if everything was inlined.
This can be used to handle propagating through call instructions to other
functions. Any users of Shardy should write an import/export pass that
converts their call ops to `sdy.named_computation` ops, duplicating/copying
the body of the called function into the body of the `named_computation`.
The types of the block arguments and returned values in the region must
match the types of the op's operands and results.
Example:
```mlir
%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
  sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
```
Traits: `IsolatedFromAbove`, `RecursiveMemoryEffects`, `RecursivelySpeculatableImplTrait`, `SingleBlockImplicitTerminator<ReturnOp>`, `SingleBlock`
Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `ShardableDataFlowOpInterface`
#### Attributes:
<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>name</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
<tr><td><code>in_shardings</code></td><td>::mlir::sdy::TensorShardingPerValueAttr</td><td>Tensor sharding per operand/result of an op</td></tr>
<tr><td><code>out_shardings</code></td><td>::mlir::sdy::TensorShardingPerValueAttr</td><td>Tensor sharding per operand/result of an op</td></tr>
</table>
#### Operands:
| Operand | Description |
| :-----: | ----------- |
| `operands` | variadic of any type |
#### Results:
| Result | Description |
| :----: | ----------- |
| «unnamed» | variadic of any type |
### `sdy.propagation_barrier` (sdy::PropagationBarrierOp)
_Propagation barrier operation_
Syntax:
```
operation ::= `sdy.propagation_barrier` $input `allowed_direction````=```$allowed_direction attr-dict `:` type($input)
```
This op operates like an identity op, outputting the same value it took as
input. But in terms of propagation, this will only allow propagation to flow
through it in a certain direction.
This prevents shardings from being propagated between the uses of the result
of the barrier op and its operand.
- `FORWARD` means shardings can only flow from the operand to the result.
- `BACKWARD` means shardings can only flow from the result to the operand.
- `NONE` means no sharding can propagate through this op.
- Cannot specify `BOTH`, as this op would be redundant.
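A minimal sketch of the textual form (the tensor shape is assumed for the example):
```mlir
// Shardings may propagate from %0 forward to %1, but not from the uses of %1
// backward to %0.
%1 = sdy.propagation_barrier %0 allowed_direction=FORWARD : tensor<8x16xf32>
```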
Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultType`
Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`
Effects: `MemoryEffects::Effect{}`
#### Attributes:
<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>allowed_direction</code></td><td>::mlir::sdy::PropagationDirectionAttr</td><td>propagation direction enum</td></tr>
</table>
#### Operands:
| Operand | Description |
| :-----: | ----------- |
| `input` | ranked tensor of any type values |
#### Results:
| Result | Description |
| :----: | ----------- |
| `result` | ranked tensor of any type values |
### `sdy.reshard` (sdy::ReshardOp)
_Reshards a tensor to a different sharding_
Syntax:
```
operation ::= `sdy.reshard` $input $sharding attr-dict `:` type($result)
```
Reshards the input tensor with the specified sharding, which is different
from the input tensor's existing sharding.
Both ShardingConstraintOp and ReshardOp attach a sharding to a tensor. Their
lifespan is:
1. Before sharding propagation, ShardingConstraintOp is added by users.
2. Sharding propagation consumes ShardingConstraintOp. There is no
ShardingConstraintOp in the results of sharding propagation. Instead,
ReshardOp may be added if needed.
3. A partitioner converts a ReshardOp into a collective op (or an identity
op). There should be no ReshardOp in the results of the partitioner.
// TODO(b/331680067). Add a canonicalization pattern to remove redundant
// reshard ops.
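For illustration (the mesh, shardings, and shape are assumed for the example):
```mlir
// Reshard %0 so that its first dimension is sharded on "b" instead of "a".
%1 = sdy.reshard %0 <@mesh, [{"b"}, {}]> : tensor<8x16xf32>
```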
Traits: `AlwaysSpeculatableImplTrait`, `Elementwise`, `SameOperandsAndResultType`
Interfaces: `ConditionallySpeculatable`, `InferTypeOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`
Effects: `MemoryEffects::Effect{}`
#### Attributes:
<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>sharding</code></td><td>::mlir::sdy::TensorShardingAttr</td><td>Tensor sharding</td></tr>
</table>
#### Operands:
| Operand | Description |
| :-----: | ----------- |
| `input` | tensor of any type values |
#### Results:
| Result | Description |
| :----: | ----------- |
| `result` | tensor of any type values |
### `sdy.return` (sdy::ReturnOp)
_The `sdy.return` operation terminates the regions attached to
`sdy` region-based ops and any other Shardy region-based ops. It is
variadic: it takes as arguments a list of values whose types can be any (but
of the same kind, e.g. `AnyTensor`) and therefore can be reused at various
levels of the Shardy IR stack._
Syntax:
```
operation ::= `sdy.return` attr-dict ($results^ `:` type($results))?
```
Traits: `AlwaysSpeculatableImplTrait`, `Terminator`
Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`
Effects: `MemoryEffects::Effect{}`
#### Operands:
| Operand | Description |
| :-----: | ----------- |
| `results` | variadic of any type |
### `sdy.sharding_constraint` (sdy::ShardingConstraintOp)
_Constrains a tensor to the specified sharding_
Syntax:
```
operation ::= `sdy.sharding_constraint` $input $sharding attr-dict `:` type($result)
```
Attaches a sharding to an intermediate tensor (e.g. the result of a matmul)
to indicate that this is how that tensor, or a subset of its uses, should be
sharded.
If the sharding has open dimensions and unconstrained axes, it means the
tensor can be further sharded along the open dimensions.
This op can either:
- Have no uses (dangling) - which means the attached sharding is how the
input tensor itself should be sharded.
- Have uses - which means the attached sharding is how the uses of the
sharding constraint op should be sharded, while other uses of the input
tensor might have a different sharding (if the input tensor has no other
uses then the behavior is the same as the no uses case).
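A hedged sketch (mesh and shape assumed); the `?` marks an open dimension that propagation may shard further:
```mlir
// Constrain the first dimension to be sharded on "a"; the second is open.
%1 = sdy.sharding_constraint %0 <@mesh, [{"a"}, {?}]> : tensor<8x16xf32>
```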
Traits: `Elementwise`, `SameOperandsAndResultType`
Interfaces: `InferTypeOpInterface`
#### Attributes:
<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>sharding</code></td><td>::mlir::sdy::TensorShardingAttr</td><td>Tensor sharding</td></tr>
</table>
#### Operands:
| Operand | Description |
| :-----: | ----------- |
| `input` | tensor of any type values |
#### Results:
| Result | Description |
| :----: | ----------- |
| `result` | tensor of any type values |
### `sdy.sharding_group` (sdy::ShardingGroupOp)
_Constrains tensors in the group to have the same sharding._
Syntax:
```
operation ::= `sdy.sharding_group` $input `group_id````=```$group_id attr-dict `:` type($input)
```
This op provides an interface to assign tensors to sharding groups
(groups of tensors that will be enforced to have identical shardings).
During propagation, as soon as one group element is sharded, all other
members will be sharded in exactly the same way. This operation takes the
argument group ID and returns no result, but instead modifies the internal
sharding group representation to add the input tensor to the group with the
given ID.
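For example (shapes assumed), adding two tensors to the same group so that they end up identically sharded:
```mlir
sdy.sharding_group %0 group_id=0 : tensor<8x16xf32>
sdy.sharding_group %1 group_id=0 : tensor<8x16xf32>
```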
Interfaces: `InferTypeOpInterface`
#### Attributes:
<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>group_id</code></td><td>::mlir::IntegerAttr</td><td>64-bit signless integer attribute</td></tr>
</table>
#### Operands:
| Operand | Description |
| :-----: | ----------- |
| `input` | ranked tensor of any type values |
## Attributes
### AxisRefAttr
_Reference to either a full axis or a split sub-axis_
Syntax:
```
#sdy.axis_ref<
  ::llvm::StringRef, # name
  SubAxisInfoAttr # sub_axis_info
>
```
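For example (mesh and axis names assumed), axis references as they appear inside a dimension sharding; the sub-axis form is explained under `SubAxisInfoAttr` below:
```mlir
// Full axis "a" on dim 0; sub-axis of "b" (pre-size 2, size 4) on dim 1.
#sdy.sharding<@mesh, [{"a"}, {"b":(2)4}]>
```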
#### Parameters:
| Parameter | C++ type | Description |
| :-------: | :-------: | ----------- |
| name | `::llvm::StringRef` | name |
| sub_axis_info | `SubAxisInfoAttr` | |
### AxisRefListAttr
Syntax:
```
#sdy.axis_ref_list<
  ::llvm::ArrayRef<AxisRefAttr> # value
>
```
#### Parameters:
| Parameter | C++ type | Description |
| :-------: | :-------: | ----------- |
| value | `::llvm::ArrayRef<AxisRefAttr>` | |
### DimMappingAttr
_List of factor indices for a dimension_
All factor indices must be in the range [0, num_factors) and an empty list
indicates that this is a null mapping (this is parsed/printed with `*`),
i.e. the dimension isn't mapped to any factors.
#### Parameters:
| Parameter | C++ type | Description |
| :-------: | :-------: | ----------- |
| factor_indices | `::llvm::ArrayRef<int64_t>` | |
### DimensionShardingAttr
_Dimension sharding_
List of axis names to shard a tensor dimension on from major to minor, a
boolean indicating whether the dimension can be further sharded, and an
optional integer denoting the priority of this dimension sharding, which
will be respected during sharding propagation. Priorities originate from
user sharding annotations and a lower value denotes a higher priority. The
highest priority is assumed when the priority is missing in the annotation.
#### Parameters:
| Parameter | C++ type | Description |
| :-------: | :-------: | ----------- |
| axes | `::llvm::ArrayRef<AxisRefAttr>` | list of axis refs |
| is_closed | `bool` | |
| priority | `std::optional<int64_t>` | |
### ListOfAxisRefListsAttr
Syntax:
```
#sdy.list_of_axis_ref_lists<
  ::llvm::ArrayRef<AxisRefListAttr> # value
>
```
#### Parameters:
| Parameter | C++ type | Description |
| :-------: | :-------: | ----------- |
| value | `::llvm::ArrayRef<AxisRefListAttr>` | |
### ManualAxesAttr
Syntax:
```
#sdy.manual_axes<
  ::llvm::ArrayRef<StringAttr> # value
>
```
#### Parameters:
| Parameter | C++ type | Description |
| :-------: | :-------: | ----------- |
| value | `::llvm::ArrayRef<StringAttr>` | |
### MeshAttr
_Mesh of axes and a list of devices_
Syntax:
```
#sdy.mesh<
  ::llvm::ArrayRef<MeshAxisAttr>, # axes
  ::llvm::ArrayRef<int64_t> # device_ids
>
```
A mesh is a list of axes and an optional list of device IDs specifying the
device ordering.
If the list of axes is empty, the mesh has an implicit unnamed axis of
size 1. In this case, if a device ID list is not provided, the implicit
device ID list is [0]; if a device ID list is provided, it must
contain a single integer of any non-negative value. We call this the
maximal-sharding case.
For all non-maximal-sharding cases, if a device ID list is specified, the
product of the axis sizes should match the number of devices. If a device ID
list is not specified, the implicit device ID list is iota(product(axes)).
For simplicity, we also disallow specifying a device ID list that is the
same as iota(product(axes)); in this case, a device ID list shouldn't be
specified.
Here are some examples of meshes:
- An empty mesh represents a placeholder mesh that can be replaced during
propagation: <[]>
- A mesh with an unnamed axis and an explicit device ID, which is typically
used to represent maximal sharding: <[], device_ids=[3]>
- A mesh with two axes and implicit device IDs iota(6): <["a"=2, "b"=3]>
- A mesh with two axes and explicit device IDs specifying the device
ordering: <["a"=3, "b"=2], device_ids=[0, 2, 4, 1, 3, 5]>
#### Parameters:
| Parameter | C++ type | Description |
| :-------: | :-------: | ----------- |
| axes | `::llvm::ArrayRef<MeshAxisAttr>` | |
| device_ids | `::llvm::ArrayRef<int64_t>` | |
### MeshAxisAttr
_Named axis in a mesh_
Syntax:
```
#sdy.mesh_axis<
  ::llvm::StringRef, # name
  int64_t # size
>
```
#### Parameters:
| Parameter | C++ type | Description |
| :-------: | :-------: | ----------- |
| name | `::llvm::StringRef` | name |
| size | `int64_t` | |
### OpShardingRuleAttr
_Specifies how an operation can be partitioned._
Syntax:
```
#sdy.op_sharding_rule<
  ::llvm::ArrayRef<int64_t>, # factor_sizes
  ::llvm::ArrayRef<TensorMappingAttr>, # operand_mappings
  ::llvm::ArrayRef<TensorMappingAttr>, # result_mappings
  bool # is_custom_rule
>
```
A sharding rule specifies how an operation can be partitioned according to
various properties on the op - any attributes, the shape of operands,
the shape of the results, etc. For example:
```
%0 = stablehlo.add %arg0, %arg1 {
  sdy.sharding_rule = #sdy.op_sharding_rule<
      ([i, j],[i, j])->([i, j])
      {i=8, j=8}>
} : tensor<8x8xf32>
```
```
%1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] {
  sdy.sharding_rule = #sdy.op_sharding_rule<
      ([i, k],[k, j])->([i, j])
      {i=8, j=16, k=8}>
}: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
```
Note that we allow factors with size 1 even though they cannot be sharded;
this is mainly for completeness, as many ops such as pointwise ops have
size-one dimensions that correspond across operands and results.
`is_custom_rule` describes whether this is a rule defined by a user for a
`stablehlo.custom_call` op. The partitioner doesn't know how to partition
these ops, so a user must tell it how. When it is a custom rule, the rule is
always preserved and never removed. `is_custom_rule` can only be true for
`stablehlo.custom_call` ops.
#### Parameters:
| Parameter | C++ type | Description |
| :-------: | :-------: | ----------- |
| factor_sizes | `::llvm::ArrayRef<int64_t>` | |
| operand_mappings | `::llvm::ArrayRef<TensorMappingAttr>` | |
| result_mappings | `::llvm::ArrayRef<TensorMappingAttr>` | |
| is_custom_rule | `bool` | |
### SubAxisInfoAttr
_Info about how this sub-axis is derived from the full axis_
Syntax:
```
#sdy.sub_axis_info<
  int64_t, # pre_size
  int64_t # size
>
```
When splitting a full axis into n sub-axes, the axis is reshaped into
[k_1,...,k_n], and the i-th sub-axis can be expressed by the product of all
axis sizes to its left, `m=prod(k_1,...,k_(i-1))` (aka pre-size), and its
size k_i. Therefore, the sub-axis-info attribute holds those two numbers and
is denoted as follows: `(m)k` for pre-size m and size k.
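A worked example (axis name and sizes assumed): splitting an axis `"a"` of size 12 as [2, 3, 2] yields sub-axes with pre-sizes 1, 2, and 6:
```mlir
// "a" = 12 reshaped into [2, 3, 2]:
//   first sub-axis:  "a":(1)2  (pre-size 1, the empty product)
//   second sub-axis: "a":(2)3  (pre-size 2)
//   third sub-axis:  "a":(6)2  (pre-size 2*3 = 6)
%1 = sdy.sharding_constraint %0 <@mesh, [{"a":(2)3}, {}]> : tensor<12x12xf32>
```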
#### Parameters:
| Parameter | C++ type | Description |
| :-------: | :-------: | ----------- |
| pre_size | `int64_t` | |
| size | `int64_t` | |
### TensorMappingAttr
_Factor mappings for each dimension of a tensor._
Syntax:
```
#sdy.tensor_mapping<
  ::llvm::ArrayRef<DimMappingAttr> # dim_mappings
>
```
#### Parameters:
| Parameter | C++ type | Description |
| :-------: | :-------: | ----------- |
| dim_mappings | `::llvm::ArrayRef<DimMappingAttr>` | |
### TensorShardingAttr
_Tensor sharding_
Syntax:
```
#sdy.sharding<
  ::mlir::Attribute, # mesh_or_ref
  ::llvm::ArrayRef<DimensionShardingAttr>, # dim_shardings
  ::llvm::ArrayRef<AxisRefAttr> # replicated_axes
>
```
A tensor sharding is bound to a specific mesh, and can only reference axis
names from that mesh. The dimension shardings tell us for each dimension of
the tensor, along which axes (or sub-axes) it is sharded from major to
minor. All other axes that don’t shard a dimension are either implicitly or
explicitly (if they appear in the list of replicated axes) replicated.
The mesh this sharding is bound to can either be specified by a symbol
name, referencing a corresponding `MeshOp` symbol, or an inlined `MeshAttr`.
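A hedged sketch (mesh, axes, and rank assumed): dimension 0 is sharded on `"a"` then `"b"` (major to minor), dimension 1 is replicated, and `"c"` is explicitly replicated:
```mlir
#sdy.sharding<@mesh, [{"a", "b"}, {}], replicated={"c"}>
```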
#### Parameters:
| Parameter | C++ type | Description |
| :-------: | :-------: | ----------- |
| mesh_or_ref | `::mlir::Attribute` | mesh attr or flat mesh symbol reference attr |
| dim_shardings | `::llvm::ArrayRef<DimensionShardingAttr>` | |
| replicated_axes | `::llvm::ArrayRef<AxisRefAttr>` | |