diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md
--- a/mlir/docs/Dialects/Linalg.md
+++ b/mlir/docs/Dialects/Linalg.md
@@ -40,7 +40,8 @@
 including lowering to scalar load/store and other operations or to external
 library calls and intrinsics.
 
-These ops can have ***either tensor or buffer operands***.
+These ops can have ***either tensor or buffer operands***, subject to
+[conventions and limitations](#tensors-and-buffers-conventions-and-limitations).
 
 ### Payload-Carrying Ops
 Linalg defines two payload carrying operations that implement the [structured
 ops](
@@ -463,6 +464,55 @@
 compilers. As we lay those down and engage more with the community, we expect
 multiple rounds of discussions and design changes to the original architecture.
 
+### Tensors and Buffers: Conventions and Limitations
+
+Tensors are immutable SSA values; buffers are mutable regions of memory subject
+to side-effects and aliasing. As a consequence, output buffers are passed as
+operands whereas output tensors are new SSA values corresponding to op results.
+Inputs can be arbitrary tensors or buffers and are always passed as operands.
+The convention adopted is as follows:
+
+1. The first `[0 .. args_in)` op operands are read-only inputs of ShapedType.
+2. The `n` results are write-only outputs of RankedTensorType. Note that
+   `n <= args_out`.
+3. The operands `[args_in .. args_in + args_out - n)` are output buffers of
+   MemRefType.
+4. Other non-ShapedType operands may appear as operands
+   `[args_in + args_out - n .. getNumOperands())`.
+
+In the case of structured ops with fully parallel semantics, inputs and outputs
+can be tensors or buffers without requiring additional constraints.
+
+Structured ops with reduction semantics and output tensor(s), however, have
+additional restrictions:
+
+1. They can only return a single tensor.
+2. They cannot have any output buffer operand.
+3. As a consequence of points 1 and 2, they must have exactly one output.
+4. Their last input argument must be a tensor of the same shape and with the
+   same indexing map as their unique output tensor.
+
+Points 1-3 keep the complexity of the representation in check by allowing only
+one result tensor when reductions are present.
+
+Point 4 is related to the fact that SSA values cannot represent in-place
+updates. Instead, linalg adopts a convention similar to the one used by e.g.
+`vector.outerproduct`: the value that is reduced into is passed as an explicit
+argument and a new result of the same shape is produced.
+
+It is expected that buffer allocation will fold this last input onto the result
+in a single output buffer argument, which is why the same indexing map is
+required: the last input operand is said to be "tied" to the result (see the
+sketch at the end of this section).
+
+Alternative, more complex representations would allow for:
+
+1. Multiple tensor results and tied inputs in arbitrary orders, which could be
+   captured by an ArrayAttr of position pairs.
+2. Relaxing the conditions on the indexing map equalities on each pair and e.g.
+   allowing implicit broadcasts of the input.
+
+These representations are deemed unnecessarily complex for now and are left for
+future discussion.
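+
+To make the adopted convention concrete, consider a 1-D sum reduction with a
+tied result tensor. The snippet below is an illustrative sketch only: the
+function name, shapes, and trait values are made up for exposition rather than
+taken from the tests; the syntax follows `linalg.generic` as used above.
+
+```mlir
+#trait = {
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // %input
+    affine_map<(i) -> (0)>,  // %init, tied to the result
+    affine_map<(i) -> (0)>   // result, same indexing map as %init
+  ],
+  iterator_types = ["reduction"]
+}
+
+func @sum(%input: tensor<?xf32>, %init: tensor<1xf32>) -> tensor<1xf32> {
+  // The value reduced into is passed as the last input %init; a new tensor
+  // of the same shape is produced as the op result.
+  %0 = linalg.generic #trait %input, %init {
+    ^bb(%in: f32, %acc: f32):
+      %sum = addf %in, %acc : f32
+      linalg.yield %sum : f32
+  } : tensor<?xf32>, tensor<1xf32> -> tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+```
+
+After buffer allocation, `%init` and `%0` are expected to fold into a single
+output buffer argument, at which point the op takes the usual buffer form.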
+
 ### Data Representation: Views
 The current implementation uses the [Strided MemRef (a.k.a View)](
 https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/MaL8m2nXuio)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -22,6 +22,26 @@
     //===------------------------------------------------------------------===//
     // Loop types handling.
     //===------------------------------------------------------------------===//
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the dims that have the given iterator type within the current
+        operation.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"getDimsOfType",
+      /*args=*/(ins "StringRef":$iteratorTypeName,
+                    "SmallVectorImpl<AffineExpr> &":$res),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        unsigned dim = 0;
+        MLIRContext *ctx = this->getOperation()->getContext();
+        for (auto tn : $_op.iterator_types()
+                           .template getAsValueRange<StringAttr, StringRef>()) {
+          if (tn == iteratorTypeName)
+            res.push_back(getAffineDimExpr(dim, ctx));
+          ++dim;
+        }
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the number of parallel loops within the current operation.
@@ -35,6 +55,18 @@
                                  $_op.iterator_types());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the dims that are parallel loops within the current operation.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"getParallelDims",
+      /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return getDimsOfType(getParallelIteratorTypeName(), res);
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the number of reduction loops within the current operation.
@@ -48,6 +80,18 @@
                                  $_op.iterator_types());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the dims that are reduction loops within the current operation.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"getReductionDims",
+      /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return getDimsOfType(getReductionIteratorTypeName(), res);
+      }]
+    >,
     InterfaceMethod<
      /*desc=*/[{
        Return the number of window loops within the current operation.
@@ -61,6 +105,18 @@
                                  $_op.iterator_types());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the dims that are window loops within the current operation.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"getWindowDims",
+      /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return getDimsOfType(getWindowIteratorTypeName(), res);
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the total number of loops within the current operation.
@@ -186,6 +242,43 @@
         return res;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return `true` if there exists a tied input ShapedType / output
+        RankedTensorType pair. This is the case when the op has return values
+        (which are RankedTensorTypes by construction) and a reduction.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasTiedResultTensor",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        if (this->getOperation()->getNumResults() == 0)
+          return false;
+        SmallVector<AffineExpr, 4> redDims;
+        $_op.getReductionDims(redDims);
+        return !redDims.empty();
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the index of the input operand that is tied to the result at
+        index `resultIdx`.
+      }],
+      /*retTy=*/"unsigned",
+      /*methodName=*/"getTiedInputOperandIndex",
+      /*args=*/(ins "unsigned":$resultIdx),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(resultIdx < this->getOperation()->getNumResults() &&
+               "result index overflow");
+        assert(resultIdx == 0 && "only single result tensor supported for now");
+        unsigned nInputs = $_op.getNumInputs();
+        assert(nInputs == $_op.getNumInputsAndOutputBuffers() &&
+               "cannot have output buffer");
+        return nInputs - 1;
+      }]
+    >,
 
     //===------------------------------------------------------------------===//
     // Output arguments handling.
@@ -353,7 +446,9 @@
           return getInputShapedType(i);
         if (i < getNumInputsAndOutputBuffers())
           return getOutputBufferType(i - $_op.getNumInputs());
-        return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()];
+        return this->getOperation()
+            ->getResult(i - getNumInputsAndOutputBuffers())
+            .getType()
+            .template cast<ShapedType>();
       }]>,
     InterfaceMethod<
       /*desc=*/[{
@@ -407,11 +502,7 @@
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return llvm::to_vector<4>(
-            llvm::map_range($_op.indexing_maps(),
-                            [](Attribute attr) -> AffineMap {
-                              return attr.cast<AffineMapAttr>().getValue();
-                            }));
+        return llvm::to_vector<4>(
+            $_op.indexing_maps()
+                .template getAsValueRange<AffineMapAttr, AffineMap>());
       }]
     >,
     InterfaceMethod<
@@ -424,10 +515,7 @@
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert(i < getNumInputsAndOutputs());
-        return $_op.indexing_maps()
-            .getValue()[i]
-            .template cast<AffineMapAttr>()
-            .getValue();
+        return getIndexingMaps()[i];
       }]
     >,
     InterfaceMethod<
@@ -440,10 +528,7 @@
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert(i < $_op.getNumInputs());
-        return $_op.indexing_maps()
-            .getValue()[i]
-            .template cast<AffineMapAttr>()
-            .getValue();
+        return getIndexingMaps()[i];
       }]
     >,
     InterfaceMethod<
@@ -456,10 +541,7 @@
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert(i < $_op.getNumOutputs());
-        return $_op.indexing_maps()
-            .getValue()[i + $_op.getNumInputs()]
-            .template cast<AffineMapAttr>()
-            .getValue();
+        return getIndexingMaps()[i + $_op.getNumInputs()];
       }]
     >,
     InterfaceMethod<
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -17,6 +17,8 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Support/LLVM.h"
 
+#include "llvm/ADT/SmallVector.h"
+
 namespace mlir {
 namespace OpTrait {
 namespace linalg {
@@ -62,11 +64,56 @@
 public:
   static LogicalResult verifyTrait(Operation *op) {
     ConcreteType concreteOp = cast<ConcreteType>(op);
-    auto nOperands = cast<ConcreteType>(op).getNumInputsAndOutputBuffers();
-    if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands)))
+    unsigned nInputAndBufferOperands =
+        concreteOp.getNumInputsAndOutputBuffers();
+    if (failed(OpTrait::impl::verifyAtLeastNOperands(
+            op, nInputAndBufferOperands)))
       return failure();
+
     if (op->getNumResults() > concreteOp.getNumOutputs())
       return op->emitError("unexpected #results > #outputs");
+
+    if (!concreteOp.hasTiedResultTensor())
+      return success();
+
+    // Only a single tensor result is supported for now.
+    if (op->getNumResults() != 1)
+      return op->emitError(
+          "expected single tensor result when reduction present");
+
+    if (concreteOp.getNumOutputs() != 1)
+      return op->emitError(
+          "expected single tensor output when result and reduction present");
+
+    // Result-returning op with at least a reduction.
+    SmallVector<AffineExpr, 4> redDims;
+    concreteOp.getReductionDims(redDims);
+
+    // The output tensor's indexing map may not depend on a reduction index.
+    AffineMap outputMap = concreteOp.getOutputIndexingMap(0);
+    for (auto expr : outputMap.getResults()) {
+      for (auto dim : redDims) {
+        unsigned pos = dim.cast<AffineDimExpr>().getPosition();
+        if (expr.isFunctionOfDim(pos))
+          return op->emitError("unexpected single tensor output indexing map ")
+                 << "is function of reduction dim @" << pos;
+      }
+    }
+
+    unsigned nInputs = concreteOp.getNumInputs();
+    if (nInputs < op->getNumResults() + 1)
+      return op->emitError("expected at least one more input than results to "
+                           "accommodate reduction and tied results");
+
+    // There must be a last input operand that is tied to the tensor result,
+    // i.e. whose indexing map matches the result's map.
+    AffineMap lastInputMap =
+        concreteOp.getInputIndexingMap(concreteOp.getTiedInputOperandIndex(0));
+    if (outputMap != lastInputMap)
+      return op->emitError("expected last input operand with indexing map "
+                           "matching the tensor result's map");
+
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -524,3 +524,88 @@
   } : tensor<?xf32> -> (tensor<?xf32>, tensor<?xf32>)
   return
 }
+
+// -----
+
+func @single_tensor_result(%arg0: memref<?xf32>) {
+  // expected-error @+1 {{expected single tensor result when reduction present}}
+  linalg.generic {
+    args_in = 1,
+    args_out = 2,
+    indexing_maps = [ affine_map<(i) -> (i)>, affine_map<(i) -> (i)>, affine_map<(i) -> (i)> ],
+    iterator_types = [ "reduction" ]
+  } %arg0 {
+    ^bb(%0: f32):
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0, %f0 : f32, f32
+  } : memref<?xf32> -> (tensor<?xf32>, tensor<?xf32>)
+  return
+}
+
+// -----
+
+func @single_tensor_output(%arg0: memref<?xf32>) {
+  // expected-error @+1 {{expected single tensor output when result and reduction present}}
+  linalg.generic {
+    args_in = 1,
+    args_out = 2,
+    indexing_maps = [ affine_map<(i) -> (i)>, affine_map<(i) -> (i)>, affine_map<(i) -> (i)> ],
+    iterator_types = [ "reduction" ]
+  } %arg0, %arg0 {
+    ^bb(%0: f32):
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } : memref<?xf32>, memref<?xf32> -> tensor<?xf32>
+  return
+}
+
+// -----
+
+func @single_tensor_result_not_function_of_reduction(%arg0: memref<?xf32>) {
+  // expected-error @+1 {{unexpected single tensor output indexing map is function of reduction dim @0}}
+  linalg.generic {
+    args_in = 1,
+    args_out = 1,
+    indexing_maps = [ affine_map<(i) -> (i)>, affine_map<(i) -> (i)> ],
+    iterator_types = [ "reduction" ]
+  } %arg0 {
+    ^bb(%0: f32):
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } : memref<?xf32> -> tensor<?xf32>
+  return
+}
+
+// -----
+
+func @single_tensor_result_too_few_inputs(%arg0: memref<?xf32>) {
+  // expected-error @+1 {{expected at least one more input than results to accommodate reduction and tied results}}
+  linalg.generic {
+    args_in = 1,
+    args_out = 1,
+    indexing_maps = [ affine_map<(i) -> (i)>, affine_map<(i) -> (0)> ],
+    iterator_types = [ "reduction" ]
+  } %arg0 {
+    ^bb(%0: f32):
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } : memref<?xf32> -> tensor<?xf32>
+  return
+}
+
+// -----
+
+func @single_tensor_result_last_input_not_matching(%arg0: memref<?xf32>) {
+  // expected-error @+1 {{expected last input operand with indexing map matching the tensor result's map}}
+  linalg.generic {
+    args_in = 2,
+    args_out = 1,
+    indexing_maps = [ affine_map<(i) -> (i)>, affine_map<(i) -> (i)>, affine_map<(i) -> (0)> ],
+    iterator_types = [ "reduction" ]
+  } %arg0, %arg0 {
+    ^bb(%0: f32, %1: f32):
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } : memref<?xf32>, memref<?xf32> -> tensor<?xf32>
+  return
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -88,7 +88,7 @@
                 memref<?xf32, offset: ?, strides: [1]>)
   linalg.matvec %arg0, %arg1, %arg2 : (memref<?x?xf32, offset: ?, strides: [?, 1]>,
                                        memref<?xf32, offset: ?, strides: [1]>,
-                                       memref<?xf32, offset: ?, strides: [1]>)
+                                       memref<?xf32, offset: ?, strides: [1]>)
   linalg.dot %arg1, %arg2, %arg3 : (memref<?xf32, offset: ?, strides: [1]>,
                                     memref<?xf32, offset: ?, strides: [1]>,
                                     memref<f32>)
@@ -653,3 +653,38 @@
 // CHECK-LABEL: func @memref_reshape_zero_dim
 //       CHECK:   linalg.reshape %{{.*}} [] : memref<1x1xf32> into memref<f32>
 //       CHECK:   linalg.reshape %{{.*}} [] : memref<f32> into memref<1x1xf32>
+
+
+// -----
+
+#accesses = [
+  affine_map<(i, j, k) -> (j, i, k)>,
+  affine_map<(i, j, k) -> (i, j)>,
+  affine_map<(i, j, k) -> (i, j)>
+]
+
+#trait = {
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel", "reduction"],
+  library_call = "some_external_function_name_1"
+}
+
+func @generic_with_tied_result_tensor(
+    %arg0: tensor<?x?x?xvector<3x4xi4>>, %arg1: tensor<?x?xf32>)
+    -> (tensor<?x?xf32>) {
+  %0 = linalg.indexed_generic #trait %arg0, %arg1 {
+    ^bb(%i: index, %j: index, %k: index, %v0: vector<3x4xi4>, %v1: f32):
+      %f0 = constant 0.0 : f32
+      linalg.yield %f0 : f32
+  } : tensor<?x?x?xvector<3x4xi4>>, tensor<?x?xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @generic_with_tied_result_tensor
+//       CHECK:   linalg.indexed_generic {args_in = 2 : i64, args_out = 1 : i64,
+//  CHECK-SAME:     indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"],
+//  CHECK-SAME:     library_call = "some_external_function_name_1"}
+//  CHECK-SAME:     %{{.*}}, %{{.*}}
+//       CHECK:   tensor<?x?x?xvector<3x4xi4>>, tensor<?x?xf32> -> tensor<?x?xf32>
+//       CHECK:   return {{.*}} : tensor<?x?xf32>
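+
+// -----
+
+// Illustrative sketch (a hypothetical test, not part of the original change):
+// per the convention documented in Linalg.md, an op with fully parallel
+// iterator types may return a tensor without any tied input operand, so no
+// extra accumulator input is needed here.
+#trait2 = {
+  args_in = 1,
+  args_out = 1,
+  indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+  iterator_types = ["parallel"],
+  library_call = "some_external_function_name_2"
+}
+
+func @generic_parallel_tensor_result(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.generic #trait2 %arg0 {
+    ^bb(%v: f32):
+      linalg.yield %v : f32
+  } : tensor<?xf32> -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+// CHECK-LABEL: func @generic_parallel_tensor_result
+//       CHECK:   linalg.generic
+//       CHECK:   } : tensor<?xf32> -> tensor<?xf32>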