diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -232,7 +232,7 @@ Example: - ``` + ```mlir // Rank-reducing extract_slice. %1 = tensor.extract_slice %0[0, 0, 0][1, 16, 4][1, 1, 1] : tensor<8x16x4xf32> to tensor<16x4xf32> @@ -372,8 +372,8 @@ "$_self.cast().getNumElements(), " "$_self.cast().getElementType())"> ]> { - string summary = "tensor from elements operation."; - string description = [{ + let summary = "tensor from elements operation."; + let description = [{ Create a N-D tensor from a range of same-type arguments. The number of provided `elements` should equal to the number of the elements in the result type. The `elements` correspond to a flattened tensor. @@ -406,6 +406,144 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// GatherOp +//===----------------------------------------------------------------------===// + +def Tensor_GatherOp : Tensor_Op<"gather", [ + NoSideEffect + ]> { + let summary = "gather a subset of a tensor at specified indices"; + let description = [{ + The `gather` operation extracts a subset of the elements from a `source` + tensor at the given indices. + + In its most general form, the tensor of indices specifies all the coordinates + of every element to extract (i.e. COO format, without the payload). + The indices are expected to be confined to coordinate values that fit the + range of the `source` tensor, otherwise the behavior is undefined. + + The leading dimensions of the index tensor give the result tensor its leading + dimensions. The trailing dimensions of the result tensor are obtained from + the source tensor by omitting the dimensions specified in `gather_dims` + (rank-reducing semantics) or setting them to `1` (rank-preserving semantics) + (see examples). 
+ The trailing dimension of the index tensor contains the coordinates and is + expected to have its size equal to the number of dimensions being gathered. + This convention allows an idiomatic specification and lowering of "gathering + multiple N-D slices from the source tensor". + + Note: in the examples below, we separate out the indexing part of the tensor + type by a whitespace for readability purposes. + + Example: + + ```mlir + // For each 1x2 triple of coordinates in %indices, extract the + // element (i.e. 0-D subset) at the coordinates triple in %source. + // + %out = tensor.gather %source[%indices] gather_dims([0, 1, 2]) : + (tensor<4x4x4xf32>, tensor<1x2x 3xindex>) -> tensor<1x2x 1x1x1xf32> + + // Note: result type may be further rank-reduced to tensor<1x2x f32>. + ``` + + A slice variant is provided to allow specifying whole slices of the source + tensor. + + Example: + + ```mlir + // For each 5x6 singleton of coordinates in %indices, extract the 2-D + // slice %source[*, %indices[...]:%indices[...] + 1, *] with the indices + // corresponding to the `gather_dims` attribute specified by %indices. + // + %out = tensor.gather %source[%indices] gather_dims([1]) : + (tensor<3x4x5xf32>, tensor<6x7x 1xindex>) -> tensor<6x7x 3x1x5xf32> + + // Note: result type may be further rank-reduced to tensor<6x7x 3x5xf32>. + ``` + + The dimensions specified in the gather_dims attribute are ones for which the + result tensor has size `1`. + I.e. if the source type is `axbxcxd` and the coordinates are [1, 3], then + the shape suffix is `ax1xcx1`. + Gather also allows rank-reducing semantics where the shape `ax1xcx1` can be + further simplified to `axc`. + + The elemental type of the indices tensor can be any integer type. + In the absence of target-specific or problem specific information the default + type one should use is `index`. + + This operation does not support unranked tensors. 
+
+ An optional `unique` unit attribute may be specified to indicate that the
+ coordinates in `indices` are statically guaranteed to be unique at runtime.
+ Incorrectly setting the `unique` attribute when the coordinates are not truly
+ unique is undefined behavior.
+
+ Only full slices are meant to be supported by this op, if one desires
+ partial slices (e.g. strided windows) one should compose this op with other
+ tensor ops (e.g. tensor.extract_slice). This is to avoid a slippery slope of
+ complexity that would make the op unusable in practice.
+
+ At the tensor-level, the index tensor is specified in an AoS form (i.e.
+ coordinate tuple is the most minor). It is the responsibility of further
+ lowerings and bufferization to implement various concrete layouts.
+
+ Note: As currently specified, the operation must lower to an abstraction that
+ performs copies to the output tensor. This is because the buffer type system
+ is currently not rich enough to allow multiple non-contiguous views in the
+ same type. This is visible more clearly in a notional buffer version of the
+ op:
+
+ ```mlir
+ // memref is a contiguous buffer of ?x4x1 elements.
+ // gather from random source slices must copy to the contiguous output.
+ %out = memref.gather %source[%indices] gather_dims([1]) :
+ (memref<4x4xf32>, memref) -> memref
+
+ // Nested buffer support would allow gather to directly index into the
+ // source buffer (i.e. represent a jagged view into the source).
+ %out = memref.gather %source[%indices] gather_dims([1]) :
+ (memref<4x4xf32>, memref) -> memref>
+ ```
+ }];
+
+ let arguments = (ins AnyRankedTensor:$source,
+ RankedTensorOf<[AnySignlessIntegerOrIndex]>:$indices,
+ DenseI64ArrayAttr:$gather_dims,
+ UnitAttr:$unique);
+ let results = (outs AnyRankedTensor:$result);
+
+ let assemblyFormat = [{
+ $source `[` $indices `]`
+ `gather_dims` `(` $gather_dims `)`
+ (`unique` $unique^)? 
+ attr-dict
+ `:` functional-type(operands, results)
+ }];
+
+ let extraClassDeclaration = [{
+ // TODO: InferTypeOpInterface once enough confidence is built with
+ // tensor and its lowering to memref.
+ static RankedTensorType inferResultType(RankedTensorType sourceType,
+ RankedTensorType indicesType,
+ ArrayRef gatherDims,
+ bool rankReduced);
+ RankedTensorType getIndicesType() {
+ return getIndices().getType().cast();
+ }
+ RankedTensorType getSourceType() {
+ return getSource().getType().cast();
+ }
+ RankedTensorType getResultType() {
+ return getResult().getType().cast();
+ }
+ }];
+ let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // GenerateOp
 //===----------------------------------------------------------------------===//
@@ -414,8 +552,8 @@
     [RecursiveSideEffects,
     DeclareOpInterfaceMethods,
     SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
-  string summary = "Creates a dynamically sized tensor from elements";
-  string description = [{
+  let summary = "Creates a dynamically sized tensor from elements";
+  let description = [{
     This operation creates a dynamically sized tensor with elements of any
     type. It expects one index operand per dynamic extent of the result
     tensor.
@@ -560,7 +698,7 @@
     Example:

-    ```
+    ```mlir
     // Rank-altering insert_slice.
     %1 = tensor.insert_slice %t into %0[0, 0, 0][1, 16, 4][1, 1, 1] :
       tensor<16x4xf32> into tensor<8x16x4xf32>
@@ -1210,6 +1348,147 @@
   let hasVerifier = 1;
 }

+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_ScatterOp : Tensor_Op<"scatter", [
+ NoSideEffect
+ ]> {
+ let summary =
+ "scatter a tensor into a destination tensor at specified indices";
+ let description = [{
+ The `scatter` operation inserts a `source` tensor into a `dest` tensor at
+ the given indices. 
+ + In its most general form, the tensor of indices specifies all the coordinates + of every element to insert (i.e. COO format, without the payload). + The indices are expected to be confined to coordinate values that fit the + range of the `dest` tensor, otherwise the behavior is undefined. + + The leading dimensions of the index tensor must match that of the dest + tensor. The trailing dimensions of the dest tensor must match those of the + source tensor by omitting the dimensions specified in scatter_dims + (rank-reducing semantics) or setting them to `1` (rank-preserving semantics) + (see examples). + This convention allows an idiomatic specification and lowering of + "scattering multiple N-D slices into the dest tensor". + The result type must match the type of the dest tensor. + + Note: in the examples below, we separate out the indexing part of the tensor + type by a whitespace for readability purposes. + + Example: + + ```mlir + // For each 1x2 triple of coordinates in %indices, insert the + // element (i.e. 0-D subset) at the coordinates triple in %dest. + // + %out = tensor.scatter %source into %dest[%indices] + scatter_dims([0, 1, 2]) unique : + (tensor<1x2x 1x1x1xf32>, tensor<4x4x4xf32>, tensor<1x2x 3xindex>) + -> tensor<4x4x4xf32> + + // Note: source type may be further rank-reduced to tensor<1x2x f32>. + ``` + + A slice variant is provided to allow specifying insertion of whole tensor + slices into the `dest` tensor. + + Example: + + ```mlir + // For each 3 singleton of coordinates in %indices, insert the 2-D + // slice into %dest[*, %indices[...]:%indices[...] + 1, *] with the + // indices corresponding to the scatter_dims attribute specified by + // %indices. + // + %out = tensor.scatter %source into %dest[%indices] scatter_dims([1]) unique : + (tensor<3x 4x1x6xf32>, tensor<4x5x6xf32>, tensor<3x 1xindex>) + -> tensor<4x5x6xf32> + ``` + + The dimensions specified in the scatter_dims attribute are ones for which the + source tensor has size `1`. 
+ I.e. if the dest type is `axbxcxd` and the coordinates are [1, 3], then
+ the source type suffix is `ax1xcx1`.
+ Scatter also allows rank-reducing semantics where the shape `ax1xcx1` can be
+ further simplified to `axc`.
+
+ The elemental type of the indices tensor can be any integer type.
+ In the absence of target-specific or problem specific information the default
+ type one should use is `index`.
+
+ This operation does not support unranked tensors.
+
+ A `unique` unit attribute must be specified to indicate that the
+ coordinates are statically guaranteed to be unique at runtime. If coordinates
+ are not truly unique at runtime, the behavior is undefined.
+
+ Only full slices are meant to be supported by this op, if one desires
+ partial slices (e.g. strided windows) one should compose this op with other
+ tensor ops (e.g. tensor.insert_slice). This is to avoid a slippery slope of
+ complexity that would make the op unusable in practice.
+
+ At the tensor-level, the index tensor is specified in an AoS form (i.e.
+ coordinate tuple is the most minor). It is the responsibility of further
+ lowerings and bufferization to implement various concrete layouts.
+
+ Note: As currently specified, the operation must lower to an abstraction that
+ performs copies to the output tensor. This is because the buffer type system
+ is currently not rich enough to allow multiple non-contiguous views in the
+ same type. This is visible more clearly in a notional buffer version of the
+ op:
+
+ ```mlir
+ // memref is a contiguous buffer of ?x4 elements, scatter into
+ // random dest slices must copy to the contiguous dest.
+ //
+ some_side_effecting_op_writing_into %source, ...: memref<3x 4xf32>
+ memref.scatter %source into %dest[%indices] scatter_dims([1]) unique :
+ (memref<3x 4xf32>, memref, memref)
+
+ // Nested buffer support in the producing op would allow writing directly
+ // into the dest buffer. 
+ %v = some_nested_buffer_view_op %dest[%indices] scatter_dims([1]) unique : + memref> + some_side_effecting_op_writing_into %v, ...: memref> + ``` + }]; + + let arguments = (ins AnyRankedTensor:$source, + AnyRankedTensor:$dest, + RankedTensorOf<[AnySignlessIntegerOrIndex]>:$indices, + DenseI64ArrayAttr:$scatter_dims, + UnitAttr:$unique); + let results = (outs AnyRankedTensor:$result); + + let assemblyFormat = [{ + $source `into` $dest `[` $indices `]` + `scatter_dims` `(` $scatter_dims `)` + (`unique` $unique^)? + attr-dict + `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + RankedTensorType getDestType() { + return getDest().getType().cast(); + } + RankedTensorType getIndicesType() { + return getIndices().getType().cast(); + } + RankedTensorType getSourceType() { + return getSource().getType().cast(); + } + RankedTensorType getResultType() { + return getResult().getType().cast(); + } + }]; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // SplatOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -17,8 +17,11 @@ #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/StringRef.h" +#include using namespace mlir; using namespace mlir::tensor; @@ -543,6 +546,89 @@ results.add(context); } +//===----------------------------------------------------------------------===// +// GatherOp +//===----------------------------------------------------------------------===// + +/// Return the inferred result type for a gatherOp where: +/// - sourceType is the type of the source tensor gathered 
from +/// - indicesType is the type of the indices used to gather +/// - gatherDims are the dims along which the gather occurs. +/// Return a full rank or ranked-reduced variant of the type depending on +/// the value of rankReduced. +/// +/// The leading dimensions of the index tensor give the result tensor its +/// leading dimensions. +/// The trailing dimensions of the result tensor are obtained from the source +/// tensor by setting the dimensions specified in gather_dims to `1` (if +/// rankedReduced is false), or skipping them (otherwise). +RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType, + RankedTensorType indicesType, + ArrayRef gatherDims, + bool rankReduced) { + SmallVector resultShape(indicesType.getShape().drop_back()); + resultShape.reserve(resultShape.size() + sourceType.getRank()); + for (int64_t idx : llvm::seq(0, sourceType.getRank())) { + if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) { + if (!rankReduced) + resultShape.push_back(1); + continue; + } + resultShape.push_back(sourceType.getDimSize(idx)); + } + return RankedTensorType::Builder(sourceType).setShape(resultShape); +} + +static LogicalResult +verifyGatherOrScatterDims(Operation *op, ArrayRef dims, int64_t rank, + StringRef gatherOrScatter, StringRef sourceOrDest) { + if (dims.empty()) + return op->emitOpError(gatherOrScatter) << "_dims must be non-empty"; + + int64_t numGatherDims = dims.size(); + if (numGatherDims > rank) + return op->emitOpError(gatherOrScatter) + << "_dims overflow " << sourceOrDest << " rank"; + for (int64_t val : dims) { + if (val < 0) + return op->emitOpError(gatherOrScatter) + << "_dims value must be non-negative"; + if (val >= rank) + return op->emitOpError(gatherOrScatter) + << "_dims value must be smaller than " << sourceOrDest << " rank"; + } + for (int64_t i = 1; i < numGatherDims; ++i) { + if (dims[i - 1] >= dims[i]) + return op->emitOpError(gatherOrScatter) + << "_dims values must be strictly increasing"; + } + 
return success(); +} + +LogicalResult GatherOp::verify() { + int64_t sourceRank = getSourceType().getRank(); + ArrayRef gatherDims = getGatherDims(); + if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, sourceRank, + "gather", "source"))) + return failure(); + + RankedTensorType expectedResultType = GatherOp::inferResultType( + getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false); + RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType( + getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true); + if (getResultType() != expectedResultType && + getResultType() != expectedRankReducedResultType) { + return emitOpError("result type " + "mismatch: " + "expected ") + << expectedResultType << " or its rank-reduced variant " + << expectedRankReducedResultType << " (got: " << getResultType() + << ")"; + } + + return success(); +} + //===----------------------------------------------------------------------===// // InsertOp //===----------------------------------------------------------------------===// @@ -2306,6 +2392,42 @@ InsertSliceOpSourceCastInserter>(context); } +//===----------------------------------------------------------------------===// +// ScatterOp +//===----------------------------------------------------------------------===// + +LogicalResult ScatterOp::verify() { + int64_t destRank = getDestType().getRank(); + ArrayRef scatterDims = getScatterDims(); + if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, destRank, + "scatter", "dest"))) + return failure(); + + if (!getUnique()) + return emitOpError("requires 'unique' attribute to be set"); + // TODO: we could also check statically that there are fewer leading index + // tensor dims than the dest dims. If this is not the case, the unique + // attribute cannot be true. + + // Use the GatherOp::inferResultType on the `dest` type and verify the + // expected type matches the source type. 
+ RankedTensorType expectedSourceType = GatherOp::inferResultType( + getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false); + RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType( + getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true); + if (getSourceType() != expectedSourceType && + getSourceType() != expectedRankReducedSourceType) { + return emitOpError("source type " + "mismatch: " + "expected ") + << expectedSourceType << " or its rank-reduced variant " + << expectedRankReducedSourceType << " (got: " << getSourceType() + << ")"; + } + + return success(); +} + //===----------------------------------------------------------------------===// // SplatOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -377,3 +377,140 @@ %w = tensor.splat %v : tensor<8xvector<8xf32>> return } + +// ----- + +func.func @gather_empty_dims( + %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + // expected-error@+1 {{gather_dims must be non-empty}} + %out = tensor.gather %source[%indices] gather_dims([]): + (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32> + return +} + +// ----- + +func.func @gather_coordinate_rank_overflow( + %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + // expected-error@+1 {{gather_dims overflow source rank}} + %out = tensor.gather %source[%indices] gather_dims([0, 1, 2, 3]): + (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32> + return +} + +// ----- + +func.func @gather_coordinate_negative( + %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + // expected-error@+1 {{gather_dims value must be non-negative}} + %out = tensor.gather %source[%indices] gather_dims([-1]): + (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> + return +} 
+ +// ----- + +func.func @gather_coordinate_overflow( + %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + // expected-error@+1 {{gather_dims value must be smaller than source rank}} + %out = tensor.gather %source[%indices] gather_dims([42]): + (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> + return +} + +// ----- + +func.func @gather_coordinate_overflow( + %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + // expected-error@+1 {{gather_dims values must be strictly increasing}} + %out = tensor.gather %source[%indices] gather_dims([1, 0]): + (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> + return +} + +// ----- + +func.func @gather_wrong_result_type( + %source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + // expected-error@+1 {{result type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor<1x2x1xf32>')}} + %out = tensor.gather %source[%indices] gather_dims([0, 2]): + (tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32> + return +} + +// ----- + +func.func @scatter_empty_dims( + %source : tensor, + %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + // expected-error@+1 {{scatter_dims must be non-empty}} + %out = tensor.scatter %source into %dest[%indices] scatter_dims([]) unique: + (tensor, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32> + return +} + +// ----- + +func.func @scatter_coordinate_rank_overflow( + %source : tensor, + %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + // expected-error@+1 {{scatter_dims overflow dest rank}} + %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 1, 2, 3]) unique: + (tensor, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32> + return +} + +// ----- + +func.func @scatter_coordinate_negative( + %source : tensor, + %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + // expected-error@+1 {{scatter_dims value must be 
non-negative}} + %out = tensor.scatter %source into %dest[%indices] scatter_dims([-1]) unique: + (tensor, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> + return +} + +// ----- + +func.func @scatter_coordinate_overflow( + %source : tensor, + %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + // expected-error@+1 {{scatter_dims value must be smaller than dest rank}} + %out = tensor.scatter %source into %dest[%indices] scatter_dims([42]) unique: + (tensor, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> + return +} + +// ----- + +func.func @scatter_coordinate_overflow( + %source : tensor, + %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + // expected-error@+1 {{scatter_dims values must be strictly increasing}} + %out = tensor.scatter %source into %dest[%indices] scatter_dims([1, 0]) unique: + (tensor, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> + return +} + +// ----- + +func.func @scatter_missing_unique( + %source : tensor, + %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + // expected-error@+1 {{requires 'unique' attribute to be set}} + %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]): + (tensor, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32> + return +} + +// ----- + +func.func @scatter_wrong_result_type( + %source : tensor, + %dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) { + // expected-error@+1 {{source type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor')}} + %out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]) unique: + (tensor, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32> + return +} diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -260,3 +260,22 @@ %u = "tensor.splat"(%s) : (f32) -> tensor<4xf32> return } + +// 
----- + +// CHECK-LABEL: func @gather_scatter +func.func @gather_scatter( + %dest : tensor<4x5x6xf32>, %indices: tensor<1x3x2xindex>, %indices_i32: tensor<1x3x2xi32>) { + %gathered = tensor.gather %dest[%indices_i32] gather_dims([1, 2]) unique: + (tensor<4x5x6xf32>, tensor<1x3x2xi32>) -> tensor<1x3x4x1x1xf32> + %rank_reduced_gathered = tensor.gather %dest[%indices] gather_dims([1, 2]) unique: + (tensor<4x5x6xf32>, tensor<1x3x2xindex>) -> tensor<1x3x4xf32> + + %scattered = tensor.scatter %gathered into %dest[%indices] + scatter_dims([1, 2]) unique: + (tensor<1x3x4x1x1xf32>, tensor<4x5x6xf32>, tensor<1x3x2xindex>) -> tensor<4x5x6xf32> + %rank_reduced_scattered = tensor.scatter %rank_reduced_gathered into %dest[%indices_i32] + scatter_dims([1, 2]) unique: + (tensor<1x3x4xf32>, tensor<4x5x6xf32>, tensor<1x3x2xi32>) -> tensor<4x5x6xf32> + return +}