diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -232,7 +232,7 @@
   Example:
 
-    ```
+    ```mlir
     // Rank-reducing extract_slice.
     %1 = tensor.extract_slice %0[0, 0, 0][1, 16, 4][1, 1, 1] :
       tensor<8x16x4xf32> to tensor<16x4xf32>
@@ -406,6 +406,128 @@
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_GatherOp : Tensor_Op<"gather", [
+    NoSideEffect
+  ]> {
+  let summary = "gather a subset of a tensor at specified indices";
+  let description = [{
+    The `gather` operation extracts a subset of the elements from a `source`
+    tensor at the given indices.
+
+    In its most general form, the tensor of indices specifies all the coordinates
+    of every element to extract (i.e. COO format, without the payload).
+    The indices are expected to be confined to coordinate values that fit the
+    range of the `source` tensor; otherwise, the behavior is undefined.
+
+    The leading dimensions of the index tensor give the result tensor its leading
+    dimensions. The trailing dimensions of the result tensor are obtained from
+    the source tensor (see examples). This allows an idiomatic specification and
+    lowering of "gathering multiple N-D slices from the source tensor".
+
+    Note: in the examples below, we separate out the indexing part of the tensor
+    type by a whitespace for readability purposes.
+
+    Example:
+
+    ```mlir
+    // For each 1x2 triple of coordinates in %indices, extract the
+    // element (i.e. 0-D subset) at the coordinates triple in %source.
+    // This corresponds to an implicit gather_dims([0, 1, 2]) attribute.
+    //
+    %out = tensor.gather %source[%indices] :
+      (tensor<4x4x4xf32>, tensor<1x2x 3xindex>) -> tensor<1x2x 1x1x1xf32>
+
+    // Note: result type may be further rank-reduced to tensor<1x2x f32>.
+    ```
+
+    A slice variant is provided to allow specifying whole slices of the source
+    tensor.
+
+    Example:
+
+    ```mlir
+    // For each 6x7 singleton of coordinates in %indices, extract the 2-D
+    // slice %source[*, %indices[...]:%indices[...] + 1, *] with the indices
+    // corresponding to the gather_dims attribute specified by %indices.
+    //
+    %out = tensor.gather %source[%indices] gather_dims([1]) :
+      (tensor<3x4x5xf32>, tensor<6x7x 1xindex>) -> tensor<6x7x 3x1x5xf32>
+
+    // Note: result type may be further rank-reduced to tensor<6x7x 3x5xf32>.
+    ```
+
+    The dimensions specified in the gather_dims attribute are ones for which the
+    result tensor has size `1`.
+    I.e. if the source type is `axbxcxd` and gather_dims is [1, 3], then the
+    shape suffix is `ax1xcx1`.
+    In the future, this op may allow rank-reducing semantics where the shape
+    `ax1xcx1` could be further simplified to `axc`.
+
+    An optional `unique` unit attribute may be specified to indicate that the
+    coordinates in `indices` are statically guaranteed to be unique at runtime.
+    Incorrectly setting the `unique` attribute when the coordinates are not truly
+    unique is undefined behavior.
+
+    Only full slices are meant to be supported by this op; if one desires
+    partial slices (e.g. strided windows), one should compose this op with other
+    tensor ops (e.g. tensor.extract_slice). This is to avoid a slippery slope of
+    complexity that would make the op unusable in practice.
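+
+    For instance, a strided window can be obtained by composition (an
+    illustrative sketch, not an additional form of the op): gather the full
+    slices, then take a strided `tensor.extract_slice` of the result:
+
+    ```mlir
+    %full = tensor.gather %source[%indices] gather_dims([1]) :
+      (tensor<3x4x5xf32>, tensor<6x7x 1xindex>) -> tensor<6x7x 3x1x5xf32>
+    // Keep elements 0 and 2 of the most minor dimension (stride 2).
+    %window = tensor.extract_slice %full[0, 0, 0, 0, 0][6, 7, 3, 1, 2]
+        [1, 1, 1, 1, 2] : tensor<6x7x3x1x5xf32> to tensor<6x7x3x1x2xf32>
+    ```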
+
+    At the tensor level, the index tensor is specified in an AoS form (i.e.
+    the coordinate tuple is the most minor dimension). It is the responsibility
+    of further lowerings and bufferization to implement various concrete
+    layouts.
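+
+    For instance (an illustrative sketch, not a constraint imposed by this
+    op), two coordinate triples into a 3-D source are laid out with each
+    tuple contiguous:
+
+    ```mlir
+    // Two (i, j, k) triples; each row is one contiguous coordinate tuple.
+    %indices = arith.constant dense<[[0, 1, 2], [3, 2, 1]]> : tensor<2x3xindex>
+    ```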
+
+    Note: As currently specified, the operation must lower to an abstraction that
+    performs copies to the output tensor. This is because the buffer type system
+    is currently not rich enough to allow multiple non-contiguous views in the
+    same type. This is visible more clearly in a notional buffer version of the
+    op:
+
+    ```mlir
+    // memref<?x 4x1xf32> is a contiguous buffer of ?x4x1 elements.
+    // gather from random source slices must copy to the contiguous output.
+    %out = memref.gather %source[%indices] gather_dims([1]) :
+      (memref<4x4xf32>, memref<?x 1xindex>) -> memref<?x 4x1xf32>
+
+    // Nested buffer support would allow gather to directly index into the
+    // source buffer (i.e. represent a jagged view into the source).
+    %out = memref.gather %source[%indices] gather_dims([1]) :
+      (memref<4x4xf32>, memref<?x 1xindex>) -> memref<?x memref<4x1xf32>>
+    ```
+  }];
+
+  let arguments = (ins AnyRankedTensor:$source,
+                       AnyRankedTensor:$indices,
+                       OptionalAttr<DenseI64ArrayAttr>:$gather_dims,
+                       UnitAttr:$unique);
+  let results = (outs AnyRankedTensor:$result);
+
+  let assemblyFormat = [{
+    $source `[` $indices `]`
+      (`gather_dims` `(` $gather_dims^ `)`)?
+      (`unique` $unique^)?
+      attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    RankedTensorType getIndicesType() {
+      return getIndices().getType().cast<RankedTensorType>();
+    }
+    RankedTensorType getSourceType() {
+      return getSource().getType().cast<RankedTensorType>();
+    }
+    RankedTensorType getResultType() {
+      return getResult().getType().cast<RankedTensorType>();
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // GenerateOp
 //===----------------------------------------------------------------------===//
@@ -560,7 +682,7 @@
 
   Example:
 
-    ```
+    ```mlir
     // Rank-altering insert_slice.
     %1 = tensor.insert_slice %t into %0[0, 0, 0][1, 16, 4][1, 1, 1] :
       tensor<16x4xf32> into tensor<8x16x4xf32>
@@ -1210,6 +1332,135 @@
   let hasVerifier = 1;
 }
 
+
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_ScatterOp : Tensor_Op<"scatter", [
+    NoSideEffect
+  ]> {
+  let summary =
+      "scatter a tensor into a destination tensor at specified indices";
+  let description = [{
+    The `scatter` operation inserts a `source` tensor into a `dest` tensor at
+    the given indices.
+
+    In its most general form, the tensor of indices specifies all the coordinates
+    of every element to insert (i.e. COO format, without the payload).
+    The indices are expected to be confined to coordinate values that fit the
+    range of the `dest` tensor; otherwise, the behavior is undefined.
+
+    The leading dimensions of the index tensor must match the leading
+    dimensions of the source tensor. The trailing dimensions of the source
+    tensor must match those of the dest tensor, except along the dimensions
+    specified in scatter_dims, where the source tensor has size `1` (see
+    examples). This allows an idiomatic specification and lowering of
+    "scattering multiple N-D slices into the dest tensor".
+    The result type must match the type of the dest tensor.
+
+    Note: in the examples below, we separate out the indexing part of the tensor
+    type by a whitespace for readability purposes.
+
+    Example:
+
+    ```mlir
+    // For each 1x2 triple of coordinates in %indices, insert the
+    // element (i.e. 0-D subset) at the coordinates triple in %dest.
+    // This corresponds to an implicit scatter_dims([0, 1, 2]) attribute.
+    //
+    %out = tensor.scatter %source into %dest[%indices] unique :
+      (tensor<1x2x 1x1x1xf32>, tensor<4x4x4xf32>, tensor<1x2x 3xindex>)
+        -> tensor<4x4x4xf32>
+    ```
+
+    A slice variant is provided to allow specifying insertion of whole tensor
+    slices into the `dest` tensor.
+
+    Example:
+
+    ```mlir
+    // For each singleton of coordinates in the 3 entries of %indices, insert
+    // the 2-D slice into %dest[*, %indices[...]:%indices[...] + 1, *] with
+    // the indices corresponding to the scatter_dims attribute specified by
+    // %indices.
+    //
+    %out = tensor.scatter %source into %dest[%indices] scatter_dims([1]) unique :
+      (tensor<3x 4x1x6xf32>, tensor<4x5x6xf32>, tensor<3x 1xindex>)
+        -> tensor<4x5x6xf32>
+    ```
+
+    The dimensions specified in the scatter_dims attribute are ones for which the
+    source tensor has size `1`.
+    I.e. if the dest type is `axbxcxd` and scatter_dims is [1, 3], then the
+    source type suffix is `ax1xcx1`.
+    In the future, this op may allow rank-reducing semantics where the shape
+    `ax1xcx1` could be further simplified to `axc`.
+
+    A `unique` unit attribute must be specified to indicate that the
+    coordinates are statically guaranteed to be unique at runtime. If coordinates
+    are not truly unique at runtime, the behavior is undefined.
+
+    Only full slices are meant to be supported by this op; if one desires
+    partial slices (e.g. strided windows), one should compose this op with other
+    tensor ops (e.g. tensor.insert_slice). This is to avoid a slippery slope of
+    complexity that would make the op unusable in practice.
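+
+    For instance, a strided update can be expressed by composition (an
+    illustrative sketch, not an additional form of the op): gather the current
+    full slices, overwrite the window via `tensor.insert_slice`, then scatter
+    the full slices back (%w below is an assumed tensor<3x4x1x2xf32> window):
+
+    ```mlir
+    %full = tensor.gather %dest[%indices] gather_dims([1]) unique :
+      (tensor<4x5x6xf32>, tensor<3x 1xindex>) -> tensor<3x 4x1x6xf32>
+    // Overwrite elements 0 and 3 of the most minor dimension (stride 3).
+    %upd = tensor.insert_slice %w into %full[0, 0, 0, 0][3, 4, 1, 2]
+        [1, 1, 1, 3] : tensor<3x4x1x2xf32> into tensor<3x4x1x6xf32>
+    %out = tensor.scatter %upd into %dest[%indices] scatter_dims([1]) unique :
+      (tensor<3x 4x1x6xf32>, tensor<4x5x6xf32>, tensor<3x 1xindex>)
+        -> tensor<4x5x6xf32>
+    ```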
+
+    At the tensor level, the index tensor is specified in an AoS form (i.e.
+    the coordinate tuple is the most minor dimension). It is the responsibility
+    of further lowerings and bufferization to implement various concrete
+    layouts.
+
+    Note: As currently specified, the operation must lower to an abstraction that
+    performs copies to the output tensor. This is because the buffer type system
+    is currently not rich enough to allow multiple non-contiguous views in the
+    same type. This is visible more clearly in a notional buffer version of the
+    op:
+
+    ```mlir
+    // memref<?x 4xf32> is a contiguous buffer of ?x4 elements, scatter into
+    // random dest slices must copy to the contiguous dest.
+    //
+    some_side_effecting_op_writing_into %source, ...: memref<3x 4xf32>
+    memref.scatter %source into %dest[%indices] scatter_dims([1]) unique :
+      (memref<3x 4xf32>, memref<?x 4xf32>, memref<?x 1xindex>)
+
+    // Nested buffer support in the producing op would allow writing directly
+    // into the dest buffer.
+    %v = some_nested_buffer_view_op %dest[%indices] scatter_dims([1]) :
+      memref<?x memref<4xf32>>
+    some_side_effecting_op_writing_into %v, ...: memref<?x memref<4xf32>>
+    ```
+  }];
+
+  let arguments = (ins AnyRankedTensor:$source,
+                       AnyRankedTensor:$dest,
+                       AnyRankedTensor:$indices,
+                       OptionalAttr<DenseI64ArrayAttr>:$scatter_dims,
+                       UnitAttr:$unique);
+  let results = (outs AnyRankedTensor:$result);
+
+  let assemblyFormat = [{
+    $source `into` $dest `[` $indices `]`
+      (`scatter_dims` `(` $scatter_dims^ `)`)?
+      (`unique` $unique^)?
+      attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    RankedTensorType getDestType() {
+      return getDest().getType().cast<RankedTensorType>();
+    }
+    RankedTensorType getIndicesType() {
+      return getIndices().getType().cast<RankedTensorType>();
+    }
+    RankedTensorType getSourceType() {
+      return getSource().getType().cast<RankedTensorType>();
+    }
+    RankedTensorType getResultType() {
+      return getResult().getType().cast<RankedTensorType>();
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // SplatOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include <algorithm>
 
 using namespace mlir;
 using namespace mlir::tensor;
@@ -543,6 +544,49 @@
   results.add(context);
 }
 
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GatherOp::verify() {
+  int64_t sourceRank = getSourceType().getRank();
+  Optional<ArrayRef<int64_t>> maybeGatherDims = getGatherDims();
+
+  ArrayRef<int64_t> gatherDims;
+  if (maybeGatherDims) {
+    gatherDims = *maybeGatherDims;
+    int64_t numGatherDims = gatherDims.size();
+    if (numGatherDims > sourceRank)
+      return emitOpError("gather_dims overflow source rank");
+    for (int64_t val : gatherDims) {
+      if (val < 0)
+        return emitOpError("gather_dims value must be non-negative");
+      if (val >= sourceRank)
+        return emitOpError("gather_dims value must be smaller than source rank");
+    }
+    for (int64_t i = 0, j = 1; j < numGatherDims; ++i, ++j) {
+      if (gatherDims[i] >= gatherDims[j])
+        return emitOpError("gather_dims values must be strictly increasing");
+    }
+  }
+  // The dimensions along which we gather have size `1` in the result type.
+  // I.e. if the source type is `axbxcxd` and gather_dims is [1, 3], then the
+  // shape suffix is `ax1xcx1`: we gather one point along `b` and `d` and the
+  // whole of the remaining dimensions. An absent gather_dims attribute means
+  // all dimensions are gathered.
+  ArrayRef<int64_t> sourceShape = getSourceType().getShape();
+  SmallVector<int64_t> shapeSuffix;
+  shapeSuffix.reserve(sourceRank);
+  for (int64_t i = 0; i < sourceRank; ++i) {
+    if (gatherDims.empty() || std::find(gatherDims.begin(), gatherDims.end(),
                                        i) != gatherDims.end()) {
+      shapeSuffix.push_back(1);
+      continue;
+    }
+    shapeSuffix.push_back(sourceShape[i]);
+  }
+  // TODO: check `shapeSuffix` against the trailing dimensions of the result
+  // type once the rank-reducing semantics are settled.
+  (void)shapeSuffix;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // InsertOp
 //===----------------------------------------------------------------------===//
@@ -2306,6 +2350,22 @@
               InsertSliceOpSourceCastInserter>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// ScatterOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ScatterOp::verify() {
+  // Per the op documentation, the `unique` attribute must always be set.
+  if (!getUnique())
+    return emitOpError("requires 'unique' attribute to be set");
+  // TODO: validate scatter_dims as in GatherOp::verify and check the
+  // source/dest/indices shape compatibility described in the documentation.
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SplatOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -377,3 +377,53 @@
   %w = tensor.splat %v : tensor<8xvector<8xf32>>
   return
 }
+
+// -----
+
+func.func @gather_dims_rank_overflow(
+    %input : tensor<4x4x4xf32>, %indices: tensor<1x2x3xindex>) {
+  // expected-error@+1 {{gather_dims overflow source rank}}
+  %out = tensor.gather %input[%indices] gather_dims([0, 1, 2, 3]) :
+           (tensor<4x4x4xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
+  return
+}
+
+// -----
+
+func.func @gather_dims_negative(
+    %input : tensor<4x4x4xf32>, %indices: tensor<1x2x3xindex>) {
+  // expected-error@+1 {{gather_dims value must be non-negative}}
+  %out = tensor.gather %input[%indices] gather_dims([-1]) :
+           (tensor<4x4x4xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
+  return
+}
+
+// -----
+
+func.func @gather_dims_overflow(
+    %input : tensor<4x4x4xf32>, %indices: tensor<1x2x3xindex>) {
+  // expected-error@+1 {{gather_dims value must be smaller than source rank}}
+  %out = tensor.gather %input[%indices] gather_dims([42]) :
+           (tensor<4x4x4xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
+  return
+}
+
+// -----
+
+func.func @gather_dims_decreasing(
+    %input : tensor<4x4x4xf32>, %indices: tensor<1x2x3xindex>) {
+  // expected-error@+1 {{gather_dims values must be strictly increasing}}
+  %out = tensor.gather %input[%indices] gather_dims([1, 0]) :
+           (tensor<4x4x4xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
+  return
+}
+
+// -----
+
+func.func @gather_dims_not_strictly_increasing(
+    %input : tensor<4x4x4xf32>, %indices: tensor<1x2x3xindex>) {
+  // expected-error@+1 {{gather_dims values must be strictly increasing}}
+  %out = tensor.gather %input[%indices] gather_dims([1, 1]) :
+           (tensor<4x4x4xf32>, tensor<1x2x3xindex>) -> tensor<1x2xf32>
+  return
+}
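+
+// -----
+
+// A sketched additional negative test (an assumption, not part of the
+// original set): it relies on ScatterOp::verify rejecting a missing
+// `unique` attribute, as required by the op documentation.
+func.func @scatter_missing_unique(
+    %source : tensor<3x4x1x6xf32>, %dest : tensor<4x5x6xf32>,
+    %indices : tensor<3x1xindex>) {
+  // expected-error@+1 {{requires 'unique' attribute to be set}}
+  %out = tensor.scatter %source into %dest[%indices] scatter_dims([1]) :
+           (tensor<3x4x1x6xf32>, tensor<4x5x6xf32>, tensor<3x1xindex>)
+             -> tensor<4x5x6xf32>
+  return
+}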