diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md
--- a/mlir/docs/Dialects/Linalg.md
+++ b/mlir/docs/Dialects/Linalg.md
@@ -40,7 +40,8 @@
including lowering to scalar load/store and other operations or to external
library calls and intrinsics.
-These ops can have ***either tensor or buffer operands***.
+These ops can have ***either tensor or buffer operands***, subject to
+[conventions and limitations](#tensors-and-buffers-conventions-and-limitations).
### Payload-Carrying Ops
Linalg defines two payload carrying operations that implement the [structured ops](
@@ -463,6 +464,55 @@
compilers. As we lay those down and engage more with the community, we expect
multiple rounds of discussions and design changes to the original architecture.
+### Tensors and Buffers: Conventions and Limitations
+
+Tensors are immutable SSA values; buffers are mutable regions of memory,
+subject to side effects and aliasing. As a consequence, output buffers are
+passed as operands, whereas output tensors are new SSA values corresponding to
+op results. Inputs can be arbitrary tensors or buffers and are always passed
+as operands. The convention adopted is as follows:
+
+1. The first `[0 .. args_in)` op operands are read-only input ShapedTypes.
+2. The `n` op results are write-only output RankedTensorTypes. Note that `n <=
+ args_out`.
+3. The operands `[args_in .. args_in + args_out - n)` are MemRefType output
+ buffers.
+4. Any remaining non-ShapedType operands occupy positions `[args_in + args_out
+ - n .. getNumOperands())`.
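+
+For example, under this convention, an op with `args_in = 2` and `args_out =
+2` that returns `n = 1` tensor takes its two read-only inputs as operands `0`
+and `1`, takes a single output buffer as operand `2`, and materializes the
+remaining output as its unique RankedTensorType result.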
+
+In the case of structured ops with fully parallel semantics, inputs and outputs
+can be tensors or buffers without requiring additional constraints.
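+
+For instance, a pointwise addition with fully parallel semantics can operate
+directly on tensors and produce a new result tensor. The following is a
+sketch in the `linalg.generic` notation used in this document (`%a` and `%b`
+are placeholder operands):
+
+```mlir
+#id = affine_map<(i) -> (i)>
+%0 = linalg.generic {
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = [#id, #id, #id],
+  iterator_types = ["parallel"]
+} %a, %b {
+  ^bb(%x: f32, %y: f32):
+    %s = addf %x, %y : f32
+    linalg.yield %s : f32
+} : tensor<?xf32>, tensor<?xf32> -> tensor<?xf32>
+```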
+
+Structured ops with reduction semantics and output tensor(s), however, have
+additional restrictions:
+
+1. they can only return a single tensor;
+2. they cannot have any output buffer operand;
+3. as a consequence of points 1 and 2, they must have exactly one output;
+4. their last input argument must be a tensor of the same shape, and with the
+ same indexing map, as their unique output tensor.
+
+Points 1-3 keep the complexity of the representation in check by allowing only
+one result tensor when reductions are present.
+
+Point 4 is related to the fact that SSA values cannot represent in-place
+updates. Instead, linalg adopts a convention similar to the one used by e.g.
+`vector.outerproduct`: the value that is reduced into is passed as an explicit
+argument and a new result of the same shape is produced.
+
+It is expected that buffer allocation will fold this last input and the result
+into a single output buffer argument, which is why the same indexing map is
+required: the last input operand is said to be "tied" to the result.
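+
+As an illustration, a 1-D sum reduction that returns a tensor can be written
+as follows (a sketch; `%in` holds the values being reduced and `%init` is the
+tied last input that carries the value being reduced into, with the same
+indexing map as the result):
+
+```mlir
+%res = linalg.generic {
+  args_in = 2,
+  args_out = 1,
+  indexing_maps = [ affine_map<(i) -> (i)>, affine_map<(i) -> ()>,
+                    affine_map<(i) -> ()> ],
+  iterator_types = ["reduction"]
+} %in, %init {
+  ^bb(%a: f32, %b: f32):
+    %s = addf %a, %b : f32
+    linalg.yield %s : f32
+} : tensor<?xf32>, tensor<f32> -> tensor<f32>
+```
+
+After buffer allocation, the tied input `%init` and the result `%res` are
+expected to fold into a single output buffer operand, yielding the usual
+buffer form of the op.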
+
+Alternative, more complex representations would allow for:
+
+1. Multiple tensor results and tied inputs in arbitrary order, which could be
+ captured by an ArrayAttr of position pairs.
+2. Relaxing the conditions on the indexing map equalities for each pair, e.g.
+ to allow implicit broadcast of the input.
+
+These representations are deemed unnecessarily complex for now and are left for
+future discussion.
+
### Data Representation: Views
The current implementation uses the [Strided MemRef (a.k.a View)](
https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/MaL8m2nXuio)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -22,6 +22,26 @@
//===------------------------------------------------------------------===//
// Loop types handling.
//===------------------------------------------------------------------===//
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the dims that have the specified iterator type within the
+ current operation.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"getDimsOfType",
+ /*args=*/(ins "StringRef":$iteratorTypeName,
+ "SmallVectorImpl<AffineExpr> &":$res),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ unsigned dim = 0;
+ MLIRContext *ctx = this->getOperation()->getContext();
+ for (auto tn : $_op.iterator_types().
+ template getAsValueRange<StringAttr, StringRef>()) {
+ if (tn == iteratorTypeName)
+ res.push_back(getAffineDimExpr(dim, ctx));
+ ++dim;
+ }
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Return the number of parallel loops within the current operation.
@@ -35,6 +55,18 @@
$_op.iterator_types());
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the dims that are parallel loops within the current operation.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"getParallelDims",
+ /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getDimsOfType(getParallelIteratorTypeName(), res);
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Return the number of reduction loops within the current operation.
@@ -48,6 +80,18 @@
$_op.iterator_types());
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the dims that are reduction loops within the current operation.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"getReductionDims",
+ /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getDimsOfType(getReductionIteratorTypeName(), res);
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Return the number of window loops within the current operation.
@@ -61,6 +105,18 @@
$_op.iterator_types());
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the dims that are window loops within the current operation.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"getWindowDims",
+ /*args=*/(ins "SmallVectorImpl<AffineExpr> &":$res),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getDimsOfType(getWindowIteratorTypeName(), res);
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Return the total number of loops within the current operation.
@@ -186,6 +242,43 @@
return res;
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return `true` if there exists a tied input ShapedType / output
+ RankedTensorType pair. This is the case when the op has return values
+ (which are RankedTensorTypes by construction) and a reduction.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"hasTiedResultTensor",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ if (this->getOperation()->getNumResults() == 0)
+ return false;
+ SmallVector<AffineExpr, 2> redDims;
+ $_op.getReductionDims(redDims);
+ return !redDims.empty();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the index for the input operand that is tied to the result index
+ `resultIdx`.
+ }],
+ /*retTy=*/"unsigned",
+ /*methodName=*/"getTiedInputOperandIndex",
+ /*args=*/(ins "unsigned":$resultIdx),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(resultIdx < this->getOperation()->getNumResults() &&
+ "result index overflow");
+ assert(resultIdx == 0 && "only single result tensor supported for now");
+ unsigned nInputs = $_op.getNumInputs();
+ unsigned nInAndOutBuffers = $_op.getNumInputsAndOutputBuffers();
+ (void)nInAndOutBuffers;
+ assert(nInputs == nInAndOutBuffers && "cannot have output buffer");
+ return nInputs - 1;
+ }]
+ >,
//===------------------------------------------------------------------===//
// Output arguments handling.
@@ -353,7 +446,9 @@
return getInputShapedType(i);
if (i < getNumInputsAndOutputBuffers())
return getOutputBufferType(i - $_op.getNumInputs());
- return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()];
+ return this->getOperation()->getResult(
+ i - getNumInputsAndOutputBuffers()).
+ getType().template cast<RankedTensorType>();
}]>,
InterfaceMethod<
/*desc=*/[{
@@ -407,11 +502,7 @@
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return llvm::to_vector<4>(
- llvm::map_range($_op.indexing_maps(),
- [](Attribute attr) -> AffineMap {
- return attr.cast<AffineMapAttr>().getValue();
- }));
+ return llvm::to_vector<4>(
+ $_op.indexing_maps().template getAsValueRange<AffineMapAttr, AffineMap>());
}]
>,
InterfaceMethod<
@@ -424,10 +515,7 @@
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i < getNumInputsAndOutputs());
- return $_op.indexing_maps()
- .getValue()[i]
- .template cast<AffineMapAttr>()
- .getValue();
+ return getIndexingMaps()[i];
}]
>,
InterfaceMethod<
@@ -440,10 +528,7 @@
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i < $_op.getNumInputs());
- return $_op.indexing_maps()
- .getValue()[i]
- .template cast<AffineMapAttr>()
- .getValue();
+ return getIndexingMaps()[i];
}]
>,
InterfaceMethod<
@@ -456,10 +541,7 @@
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i < $_op.getNumOutputs());
- return $_op.indexing_maps()
- .getValue()[i + $_op.getNumInputs()]
- .template cast<AffineMapAttr>()
- .getValue();
+ return getIndexingMaps()[i + $_op.getNumInputs()];
}]
>,
InterfaceMethod<
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -17,6 +17,8 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+
namespace mlir {
namespace OpTrait {
namespace linalg {
@@ -62,11 +64,56 @@
public:
static LogicalResult verifyTrait(Operation *op) {
ConcreteType concreteOp = cast<ConcreteType>(op);
- auto nOperands = cast<ConcreteType>(op).getNumInputsAndOutputBuffers();
- if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands)))
+ unsigned nInputAndBufferOperands =
+ concreteOp.getNumInputsAndOutputBuffers();
+ if (failed(
+ OpTrait::impl::verifyAtLeastNOperands(op, nInputAndBufferOperands)))
return failure();
+
if (op->getNumResults() > concreteOp.getNumOutputs())
return op->emitError("unexpected #results > #outputs");
+
+ if (!concreteOp.hasTiedResultTensor())
+ return success();
+
+ // Only a single tensor result is supported for now.
+ if (op->getNumResults() != 1)
+ return op->emitError(
+ "expected single tensor result when reduction present");
+
+ if (concreteOp.getNumOutputs() != 1)
+ return op->emitError(
+ "expected single tensor output when result and reduction present");
+
+ // Result-returning op with at least a reduction.
+ SmallVector<AffineExpr, 2> redDims;
+ concreteOp.getReductionDims(redDims);
+
+ // Output tensor indexing map may not depend on reduction index.
+ AffineMap outputMap = concreteOp.getOutputIndexingMap(0);
+ for (auto expr : outputMap.getResults()) {
+ for (auto dim : redDims) {
+ unsigned pos = dim.cast<AffineDimExpr>().getPosition();
+ if (expr.isFunctionOfDim(pos))
+ return op->emitError("unexpected single tensor output indexing map ")
+ << "is function of reduction dim @" << pos;
+ }
+ }
+
+ unsigned nInputs = concreteOp.getNumInputs();
+ if (nInputs < op->getNumResults() + 1)
+ return op->emitError("expected at least one more input than results to "
+ "accomodate reduction and tied results");
+
+ // There must be a tied input operand whose indexing map matches the
+ // tensor result's map.
+ AffineMap lastInputMap =
+ concreteOp.getInputIndexingMap(concreteOp.getTiedInputOperandIndex(0));
+ if (outputMap != lastInputMap)
+ return op->emitError("expected last input operand with indexing map "
+ "matching the tensor result's map");
+
return success();
}
};
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -524,3 +524,88 @@
} : tensor -> (tensor, tensor)
return
}
+
+// -----
+
+func @single_tensor_result(%arg0: memref<?xf32>) {
+ // expected-error @+1 {{expected single tensor result when reduction present}}
+ linalg.generic {
+ args_in = 1,
+ args_out = 2,
+ indexing_maps = [ affine_map<(i) -> (i)>, affine_map<(i) -> (i)>, affine_map<(i) -> (i)> ],
+ iterator_types = [ "reduction" ]
+ } %arg0 {
+ ^bb(%0: f32):
+ %f0 = constant 0.0 : f32
+ linalg.yield %f0, %f0: f32, f32
+ } : memref<?xf32> -> (tensor<?xf32>, tensor<?xf32>)
+ return
+}
+
+// -----
+
+func @single_tensor_result(%arg0: memref<?xf32>) {
+ // expected-error @+1 {{expected single tensor output when result and reduction present}}
+ linalg.generic {
+ args_in = 1,
+ args_out = 2,
+ indexing_maps = [ affine_map<(i) -> (i)>, affine_map<(i) -> (i)>, affine_map<(i) -> (i)> ],
+ iterator_types = [ "reduction" ]
+ } %arg0, %arg0 {
+ ^bb(%0: f32):
+ %f0 = constant 0.0 : f32
+ linalg.yield %f0: f32
+ } : memref<?xf32>, memref<?xf32> -> tensor<?xf32>
+ return
+}
+
+// -----
+
+func @single_tensor_result_not_function_of_reduction(%arg0: memref<?xf32>) {
+ // expected-error @+1 {{unexpected single tensor output indexing map is function of reduction dim @0}}
+ linalg.generic {
+ args_in = 1,
+ args_out = 1,
+ indexing_maps = [ affine_map<(i) -> (i)>, affine_map<(i) -> (i)> ],
+ iterator_types = [ "reduction" ]
+ } %arg0 {
+ ^bb(%0: f32):
+ %f0 = constant 0.0 : f32
+ linalg.yield %f0: f32
+ } : memref<?xf32> -> tensor<?xf32>
+ return
+}
+
+// -----
+
+func @single_tensor_result_last_input_not_matching(%arg0: memref<?xf32>) {
+ // expected-error @+1 {{expected at least one more input than results to accommodate reduction and tied results}}
+ linalg.generic {
+ args_in = 1,
+ args_out = 1,
+ indexing_maps = [ affine_map<(i) -> (i)>, affine_map<(i) -> (0)> ],
+ iterator_types = [ "reduction" ]
+ } %arg0 {
+ ^bb(%0: f32):
+ %f0 = constant 0.0 : f32
+ linalg.yield %f0: f32
+ } : memref<?xf32> -> tensor<?xf32>
+ return
+}
+
+// -----
+
+func @single_tensor_result_last_input_not_matching(%arg0: memref<?xf32>) {
+ // expected-error @+1 {{expected last input operand with indexing map matching the tensor result's map}}
+ linalg.generic {
+ args_in = 2,
+ args_out = 1,
+ indexing_maps = [ affine_map<(i) -> (i)>, affine_map<(i) -> (i)>, affine_map<(i) -> (0)> ],
+ iterator_types = [ "reduction" ]
+ } %arg0, %arg0 {
+ ^bb(%0: f32, %1: f32):
+ %f0 = constant 0.0 : f32
+ linalg.yield %f0: f32
+ } : memref<?xf32>, memref<?xf32> -> tensor<?xf32>
+ return
+}
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -88,7 +88,7 @@
memref)
linalg.matvec %arg0, %arg1, %arg2 : (memref,
memref,
- memref)
+ memref)
linalg.dot %arg1, %arg2, %arg3 : (memref,
memref,
memref)
@@ -653,3 +653,38 @@
// CHECK-LABEL: func @memref_reshape_zero_dim
// CHECK: linalg.reshape %{{.*}} [] : memref<1x1xf32> into memref<f32>
// CHECK: linalg.reshape %{{.*}} [] : memref<f32> into memref<1x1xf32>
+
+
+// -----
+
+#accesses = [
+ affine_map<(i, j, k) -> (j, i, k)>,
+ affine_map<(i, j, k) -> (i, j)>,
+ affine_map<(i, j, k) -> (i, j)>
+]
+
+#trait = {
+ args_in = 2,
+ args_out = 1,
+ indexing_maps = #accesses,
+ iterator_types = ["parallel", "parallel", "reduction"],
+ library_call = "some_external_function_name_1"
+}
+
+func @generic_with_tied_result_tensor(
+ %arg0: tensor<?x?x?xvector<3x4xi4>>, %arg1: tensor<?x?xf32>)
+ -> (tensor<?x?xf32>) {
+ %0 = linalg.indexed_generic #trait %arg0, %arg1 {
+ ^bb(%i: index, %j: index, %k: index, %v0: vector<3x4xi4>, %v1: f32) :
+ %f0 = constant 0.0 : f32
+ linalg.yield %f0 : f32
+ } : tensor<?x?x?xvector<3x4xi4>>, tensor<?x?xf32> -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @generic_with_tied_result_tensor
+// CHECK: linalg.indexed_generic {args_in = 2 : i64, args_out = 1 : i64,
+// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"],
+// CHECK-SAME: library_call = "some_external_function_name_1"}
+// CHECK-SAME: %{{.*}}, %{{.*}}
+// CHECK: tensor<?x?x?xvector<3x4xi4>>, tensor<?x?xf32> -> tensor<?x?xf32>
+// CHECK: return {{.*}} : tensor<?x?xf32>