diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md
--- a/mlir/docs/Dialects/Linalg.md
+++ b/mlir/docs/Dialects/Linalg.md
@@ -40,7 +40,8 @@
including lowering to scalar load/store and other operations or to external
library calls and intrinsics.
-These ops can have ***either tensor or buffer operands***.
+These ops can have ***either tensor or buffer operands***, subject to
+[conventions and limitations](#tensors-and-buffers-conventions-and-limitations).
### Payload-Carrying Ops
Linalg defines two payload carrying operations that implement the [structured ops](
@@ -463,6 +464,76 @@
compilers. As we lay those down and engage more with the community, we expect
multiple rounds of discussions and design changes to the original architecture.
+### Tensors and Buffers: Conventions and Limitations
+
+Tensors are immutable SSA values; buffers are mutable regions of memory subject
+to side-effects and aliasing. As a consequence, output buffers are passed as
+operands whereas output tensors are new SSA values corresponding to op results.
+Inputs can be arbitrary tensors or buffers and are always passed as operands.
+
+The convention described below is in-flight: it is in the process of replacing
+other existing conventions. It currently applies to "named" structured ops,
+which are auto-generated by the linalg-ods tool.
+
+The convention adopted is as follows:
+
+1. A first block of `ins` op operands holds read-only inputs of ShapedType.
+2. An optional second block of `outs` op operands holds read-write output
+   buffers of MemRefType.
+3. An optional third block of `init` operands holds initialization tensors of
+   RankedTensorType. Such tensors can appear when the op performs a reduction
+   and returns a tensor.
+
+Structured ops with fully parallel semantics have an empty `init`. They may
+either write in-place into `outs` buffers or return new tensors.
+
+Structured ops with reduction semantics and output tensor(s), however, have
+additional restrictions:
+
+1. They can only return a single tensor for now.
+2. They cannot have any output buffer operand (i.e. `outs` is empty).
+3. They have exactly one `init` tensor of the same shape as the unique output
+   tensor. Such an `init` tensor does not have an explicit associated indexing
+   map. Instead, the map of the result tensor is used to signify that the
+   `init` and the `result` are "tied": for a matmul-like op with iterators
+   `(m, n, k)`, the result map `affine_map<(m, n, k) -> (m, n)>` also serves
+   as the implicit indexing map of the `init` tensor.
+
+Points 1. and 2. keep the complexity of the representation in check by allowing
+only a single result tensor when reductions are present.
+
+Point 3. is related to the fact that SSA values cannot represent in-place
+updates. Instead, linalg adopts a convention similar to the one used by e.g.
+`vector.outerproduct`: the value that is reduced into is passed as an explicit
+argument and a new result of the same shape is produced.
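+
+For reference, this is what the convention looks like for `vector.outerproduct`
+(operand names are illustrative):
+
+```
+// %acc is the value reduced into; %res is a new value of type vector<4x8xf32>.
+%res = vector.outerproduct %lhs, %rhs, %acc : vector<4xf32>, vector<8xf32>
+```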
+
+It is expected that buffer allocation will fold this last `init` operand onto
+the result in a single output buffer argument, which is why the same indexing
+map is required: the last input operand is said to be "tied" to the result.
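+
+As a hypothetical sketch, assuming a buffer allocation pass that materializes
+the result in the buffer backing `init`:
+
+```
+// Before bufferization: the `init` tensor is tied to the result.
+%d = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
+                  init(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
+// After bufferization: `init` and the result fold into a single `outs` buffer.
+linalg.matmul ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
+              outs(%C : memref<?x?xf32>)
+```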
+
+Alternative, more complex representations would allow for:
+
+1. Multiple results and `init` tensors in arbitrary orders, which could be
+ captured by an extra ArrayAttr of position pairs.
+2. Relaxing the conditions on the indexing map equalities on each pair and,
+   e.g., allowing implicit broadcasts of the input.
+
+These representations are deemed unnecessarily complex for now and are left for
+future discussion.
+
+As an illustration, the syntax for a `linalg.matmul` writing into a buffer is:
+
+```
+linalg.matmul ins(%a, %b : memref<?x?xf32>, tensor<?x?xf32>)
+              outs(%c : memref<?x?xf32>)
+```
+
+The syntax for a `linalg.matmul` returning a new tensor is:
+
+```
+%d = linalg.matmul ins(%a, %b : tensor<?x?xf32>, memref<?x?xf32>)
+                   init(%c : tensor<?x?xf32>)
+                     -> tensor<?x?xf32>
+```
+
### Data Representation: Views
The current implementation uses the [Strided MemRef (a.k.a View)](
https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/MaL8m2nXuio)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -32,6 +32,7 @@
NativeOpTrait<"linalg::NOutputs<" # !cast<string>(args_out) # ">::Impl"> {}
def StructuredOpTraits : NativeOpTrait<"linalg::StructuredOpTraits">;
+def NamedStructuredOpTraits : NativeOpTrait<"linalg::NamedStructuredOpTraits">;
// Base Tablegen class for Linalg ops.
// Linalg ops that correspond to library calls operate on linalg::View as their
@@ -798,24 +799,7 @@
// Named Linalg ops, implemented as a declarative configurations of generic ops.
//===----------------------------------------------------------------------===//
-class LinalgNamedStructured_Op<string mnemonic, list<OpTrait> props>
- : LinalgStructuredBase_Op<mnemonic, props> {
- string spec = ?;
- // We cannot use an assemblyFormat atm because we need to hook in a custom-
- // built implicit region from a static OpClass method.
- // TODO: Revisit in the future if/when appropriate.
- // let assemblyFormat = "`(` operands `)` attr-dict `:` "
- // "functional-type(operands, results)";
-
- // The parser needs to specialize on the OpType so it has to be auto-generated
- // in the linalg-ods tool.
- let printer = [{ return ::printNamedStructuredOp(p, *this); }];
- let verifier = [{ return ::verifyNamedStructuredOp(*this); }];
- let hasFolder = 1;
- let hasCanonicalizer = 1;
-}
-
-// This file is auto-generated from a tc specification.
+// This file is auto-generated from a TC def specification.
include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.td"
#endif // LINALG_STRUCTURED_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -25,7 +25,27 @@
//===------------------------------------------------------------------===//
InterfaceMethod<
/*desc=*/[{
- Return the number of parallel loops within the current operation.
+ Return the dims that are `iteratorTypeName` loops.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"getDimsOfType",
+ /*args=*/(ins "StringRef":$iteratorTypeName,
+ "SmallVectorImpl &":$res),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ unsigned dim = 0;
+ MLIRContext *ctx = this->getOperation()->getContext();
+ for (auto tn : $_op.iterator_types().
+ template getAsValueRange<StringAttr, StringRef>()) {
+ if (tn == iteratorTypeName)
+ res.push_back(getAffineDimExpr(dim, ctx));
+ ++dim;
+ }
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the number of parallel loops.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumParallelLoops",
@@ -38,7 +58,19 @@
>,
InterfaceMethod<
/*desc=*/[{
- Return the number of reduction loops within the current operation.
+ Return the dims that are parallel loops.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"getParallelDims",
+ /*args=*/(ins "SmallVectorImpl &":$res),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getDimsOfType(getParallelIteratorTypeName(), res);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the number of reduction loops.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumReductionLoops",
@@ -51,7 +83,19 @@
>,
InterfaceMethod<
/*desc=*/[{
- Return the number of window loops within the current operation.
+ Return the dims that are reduction loops.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"getReductionDims",
+ /*args=*/(ins "SmallVectorImpl &":$res),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getDimsOfType(getReductionIteratorTypeName(), res);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the number of window loops.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumWindowLoops",
@@ -62,6 +106,18 @@
$_op.iterator_types());
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return the dims that are window loops.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"getWindowDims",
+ /*args=*/(ins "SmallVectorImpl &":$res),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return getDimsOfType(getWindowIteratorTypeName(), res);
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Return the total number of loops within the current operation.
@@ -99,14 +155,14 @@
// linalg.indexed_generic ops).
InterfaceMethod<
/*desc=*/[{
- Return the number of inputs from the current operation.
+ Return the number of inputs.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumInputs"
>,
InterfaceMethod<
/*desc=*/[{
- Return the number of outputs from the current operation.
+ Return the number of outputs.
}],
/*retTy=*/"unsigned",
/*methodName=*/"getNumOutputs"
@@ -160,7 +216,7 @@
>,
InterfaceMethod<
/*desc=*/[{
- Return the input operands from the current operation.
+ Return the input operands.
}],
/*retTy=*/"Operation::operand_range",
/*methodName=*/"getInputs",
@@ -187,7 +243,6 @@
return res;
}]
>,
-
//===------------------------------------------------------------------===//
// Output arguments handling.
//===------------------------------------------------------------------===//
@@ -267,7 +322,7 @@
}]>,
InterfaceMethod<
/*desc=*/[{
- Return the output buffers (operands) from the current operation.
+ Return the output buffers (operands).
}],
/*retTy=*/"Operation::operand_range",
/*methodName=*/"getOutputBuffers",
@@ -354,7 +409,9 @@
return getInputShapedType(i);
if (i < getNumInputsAndOutputBuffers())
return getOutputBufferType(i - $_op.getNumInputs());
- return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()];
+ return this->getOperation()->getResult(
+ i - getNumInputsAndOutputBuffers()).
+ getType().template cast<ShapedType>();
}]>,
InterfaceMethod<
/*desc=*/[{
@@ -408,11 +465,7 @@
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- return llvm::to_vector<4>(
- llvm::map_range($_op.indexing_maps(),
- [](Attribute attr) -> AffineMap {
- return attr.cast<AffineMapAttr>().getValue();
- }));
+ return llvm::to_vector<4>($_op.indexing_maps().template getAsValueRange<AffineMapAttr, AffineMap>());
}]
>,
InterfaceMethod<
@@ -425,10 +478,7 @@
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i < getNumInputsAndOutputs());
- return $_op.indexing_maps()
- .getValue()[i]
- .template cast<AffineMapAttr>()
- .getValue();
+ return getIndexingMaps()[i];
}]
>,
InterfaceMethod<
@@ -441,10 +491,7 @@
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i < $_op.getNumInputs());
- return $_op.indexing_maps()
- .getValue()[i]
- .template cast<AffineMapAttr>()
- .getValue();
+ return getIndexingMaps()[i];
}]
>,
InterfaceMethod<
@@ -457,10 +504,7 @@
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(i < $_op.getNumOutputs());
- return $_op.indexing_maps()
- .getValue()[i + $_op.getNumInputs()]
- .template cast<AffineMapAttr>()
- .getValue();
+ return getIndexingMaps()[i + $_op.getNumInputs()];
}]
>,
InterfaceMethod<
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -71,6 +71,80 @@
}
};
+/// This class provides a verifier for structured ops that are known to operate
+/// on buffers or tensors and that support `ins`, `outs` and `init` arguments.
+/// This trait must be used in conjunction with an op definition or a trait that
+/// provides the methods `getNumInputs` and `getNumOutputs`.
+///
+/// Use as a trait as follows:
+///
+/// class MatmulOp : public Op<MatmulOp, OpTrait::linalg::NamedStructuredOpTraits> {
+///
+template <typename ConcreteType>
+class NamedStructuredOpTraits
+ : public OpTrait::TraitBase<ConcreteType, NamedStructuredOpTraits> {
+public:
+ unsigned getNumInputs() {
+ return cast<ConcreteType>(this->getOperation()).inputs().size();
+ }
+ unsigned getNumOutputs() {
+ ConcreteType concreteOp = cast<ConcreteType>(this->getOperation());
+ return concreteOp.output_buffers().size() +
+ concreteOp.output_tensors().size();
+ }
+ static LogicalResult verifyTrait(Operation *op) {
+ ConcreteType concreteOp = cast<ConcreteType>(op);
+ unsigned nInputAndBufferOperands =
+ concreteOp.getNumInputsAndOutputBuffers();
+ if (failed(
+ OpTrait::impl::verifyAtLeastNOperands(op, nInputAndBufferOperands)))
+ return failure();
+
+ SmallVector<AffineExpr, 4> redDims;
+ concreteOp.getReductionDims(redDims);
+ // If there is no reduction or no result, only check that there is no init
+ // tensor and we are done.
+ if (redDims.empty() || op->getNumResults() == 0) {
+ if (!concreteOp.init_tensors().empty())
+ return op->emitError("expected empty init_tensors when op has no "
+ "results or no reduction dims");
+ return success();
+ }
+
+ // Only a single tensor result supported atm.
+ if (op->getNumResults() != 1)
+ return op->emitError(
+ "expected single tensor result when reduction present");
+
+ if (concreteOp.init_tensors().size() != op->getNumResults())
+ return op->emitError(
+ "expected #init tensors to match #results when reduction present");
+
+ for (unsigned idx = 0, e = op->getNumResults(); idx < e; ++idx)
+ if (concreteOp.init_tensors()[idx].getType() != op->getResultTypes()[idx])
+ return op->emitError("expected init tensor #")
+ << idx << " of the same type as result #" << idx;
+
+ // Output tensor indexing map may not depend on reduction index.
+ // TODO: this is not yet tested. Add a test when linalg.generic switches to
+ // this representation.
+ for (unsigned idx = 0, e = concreteOp.getNumOutputs(); idx < e; ++idx) {
+ AffineMap outputMap = concreteOp.getOutputIndexingMap(idx);
+ for (auto expr : outputMap.getResults()) {
+ for (auto dim : redDims) {
+ unsigned pos = dim.cast<AffineDimExpr>().getPosition();
+ if (expr.isFunctionOfDim(pos))
+ return op->emitError(
+ "unexpected single tensor output indexing map ")
+ << "is function of reduction dim @" << pos;
+ }
+ }
+ }
+
+ return success();
+ }
+};
+
} // namespace linalg
} // namespace OpTrait
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td
@@ -15,8 +15,6 @@
include "mlir/IR/OpBase.td"
-def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">;
-
//===----------------------------------------------------------------------===//
// Shape Inference dialect definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -571,6 +571,10 @@
def AnyVector : VectorOf<[AnyType]>;
+// Shaped types.
+
+def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">;
+
// Tensor types.
// Any tensor type whose element type is from the given `allowedTypes` list
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir
@@ -34,7 +34,8 @@
}
func @conv_1d(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
- linalg.conv_1d %arg0, %arg1, %arg2 : (memref<?xf32>, memref<?xf32>, memref<?xf32>)
+ linalg.conv_1d ins (%arg0, %arg1: memref<?xf32>, memref<?xf32>)
+ outs (%arg2: memref<?xf32>)
return
}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir
@@ -34,7 +34,8 @@
}
func @conv_1d_ncw(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
- linalg.conv_1d_ncw %arg0, %arg1, %arg2 : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>)
+ linalg.conv_1d_ncw ins (%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs (%arg2: memref<?x?x?xf32>)
return
}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir
@@ -34,7 +34,8 @@
}
func @conv_1d_nwc(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
- linalg.conv_1d_nwc %arg0, %arg1, %arg2 : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>)
+ linalg.conv_1d_nwc ins (%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs (%arg2: memref<?x?x?xf32>)
return
}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir
@@ -34,7 +34,8 @@
}
func @conv_2d(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
- linalg.conv_2d %arg0, %arg1, %arg2 : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ linalg.conv_2d ins (%arg0, %arg1: memref<?x?xf32>, memref<?x?xf32>)
+ outs (%arg2: memref<?x?xf32>)
return
}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
@@ -34,7 +34,8 @@
}
func @conv_2d_nchw(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
- linalg.conv_2d_nchw %arg0, %arg1, %arg2 : (memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+ linalg.conv_2d_nchw ins (%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+ outs (%arg2: memref<?x?x?x?xf32>)
return
}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir
@@ -34,7 +34,8 @@
}
func @conv_2d_nhwc(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
- linalg.conv_2d_nhwc %arg0, %arg1, %arg2 : (memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+ linalg.conv_2d_nhwc ins (%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+ outs (%arg2: memref<?x?x?x?xf32>)
return
}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir
@@ -34,7 +34,8 @@
}
func @conv_3d(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
- linalg.conv_3d %arg0, %arg1, %arg2 : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>)
+ linalg.conv_3d ins (%arg0, %arg1: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs (%arg2: memref<?x?x?xf32>)
return
}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir
@@ -34,7 +34,8 @@
}
func @conv_3d_ncdhw(%arg0: memref<?x?x?x?x?xf32>, %arg1: memref<?x?x?x?x?xf32>, %arg2: memref<?x?x?x?x?xf32>) {
- linalg.conv_3d_ncdhw %arg0, %arg1, %arg2 : (memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+ linalg.conv_3d_ncdhw ins (%arg0, %arg1: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+ outs (%arg2: memref<?x?x?x?x?xf32>)
return
}
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir
@@ -34,7 +34,8 @@
}
func @conv_3d_ndhwc(%arg0: memref<?x?x?x?x?xf32>, %arg1: memref<?x?x?x?x?xf32>, %arg2: memref<?x?x?x?x?xf32>) {
- linalg.conv_3d_ndhwc %arg0, %arg1, %arg2 : (memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+ linalg.conv_3d_ndhwc ins (%arg0, %arg1: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+ outs (%arg2: memref<?x?x?x?x?xf32>)
return
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -34,14 +34,20 @@
/// Forward declarations.
template <typename NamedStructuredOpType>
-static void buildNamedStructuredOpRegionAndAttributes(
- Builder &builder, OperationState &result, TypeRange operandTypes,
- TypeRange tensorResultTypes);
-template <typename NamedStructuredOpType>
-static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
+static ParseResult
+parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion,
+ TypeRange inputTypes, TypeRange outputBufferTypes,
+ TypeRange initTensorTypes, TypeRange resultTypes);
template <typename NamedStructuredOpType>
-static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
- OperationState &result);
+static void buildNamedStructuredOpRegionAndAttributes(
+ OpBuilder &opBuilder, OperationState &result, TypeRange inputTypes,
+ TypeRange outputBufferTypes, TypeRange initTensorTypes,
+ TypeRange resultTypes);
+static ParseResult
+parseNamedStructuredOpResults(OpAsmParser &parser,
+ SmallVectorImpl<Type> &resultTypes);
+static void printNamedStructuredOpResults(OpAsmPrinter &p,
+ TypeRange resultTypes);
template <typename NamedStructuredOpType>
static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op);
@@ -247,11 +253,6 @@
static LogicalResult verifyGenericOp(GenericOpType op) {
auto nInputViews = op.getNumInputs();
auto nLoops = op.getNumLoops();
- auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers();
- if (nInputsAndOutputBuffers != llvm::size(op.views()))
- return op.emitOpError("expected exactly ")
- << nInputsAndOutputBuffers
- << " inputs (tensor or buffer) and output buffer operands";
auto ®ion = op.region();
if (!llvm::hasSingleElement(region))
@@ -301,8 +302,27 @@
return success();
}
-static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
-static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
+static LogicalResult verify(GenericOp op) {
+ // Temporarily hoisted here to avoid duplicating more code.
+ // TODO: uniformize with named structured ops.
+ auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers();
+ if (nInputsAndOutputBuffers != llvm::size(op.views()))
+ return op.emitOpError("expected exactly ")
+ << nInputsAndOutputBuffers
+ << " inputs (tensor or buffer) and output buffer operands";
+ return verifyGenericOp(op);
+}
+
+static LogicalResult verify(IndexedGenericOp op) {
+ // Temporarily hoisted here to avoid duplicating more code.
+ // TODO: uniformize with named structured ops.
+ auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers();
+ if (nInputsAndOutputBuffers != llvm::size(op.views()))
+ return op.emitOpError("expected exactly ")
+ << nInputsAndOutputBuffers
+ << " inputs (tensor or buffer) and output buffer operands";
+ return verifyGenericOp(op);
+}
//===----------------------------------------------------------------------===//
// ReshapeOp
@@ -1097,6 +1117,8 @@
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.cpp.inc"
+#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
@@ -1221,23 +1243,26 @@
//===----------------------------------------------------------------------===//
template <typename NamedStructuredOpType>
-void buildNamedStructuredOpRegionAndAttributes(Builder &builder,
- OperationState &result,
- TypeRange operandTypes,
- TypeRange tensorResultTypes) {
- Region ®ion = *result.addRegion();
+static void buildNamedStructuredOpRegionAndAttributesImpl(
+ OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes,
+ TypeRange outputBufferTypes, TypeRange initTensorTypes,
+ TypeRange resultTypes,
+ std::function<void(unsigned, unsigned)> errorHandler) {
Block *body = new Block();
// TODO: atm all operands go through getElementTypeOrSelf,
// reconsider when we have evidence we need to.
- for (auto t : operandTypes)
- body->addArgument(getElementTypeOrSelf(t));
- for (auto t : tensorResultTypes)
- body->addArgument(getElementTypeOrSelf(t));
+ for (auto containers : {inputTypes, outputBufferTypes, resultTypes})
+ for (auto t : containers)
+ body->addArgument(getElementTypeOrSelf(t));
region.push_back(body);
- OpBuilder opBuilder(builder.getContext());
+ unsigned actual = body->getNumArguments();
+ unsigned expected = NamedStructuredOpType::getNumRegionArgs();
+ if (expected != actual)
+ return errorHandler(expected, actual);
+
opBuilder.setInsertionPointToStart(®ion.front());
- mlir::edsc::ScopedContext scope(opBuilder, builder.getUnknownLoc());
+ mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc());
NamedStructuredOpType::regionBuilder(*body);
// indexing_maps is an auto-generated method.
@@ -1246,59 +1271,56 @@
}
template <typename NamedStructuredOpType>
-static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
- std::array silentAttrNames{getIndexingMapsAttrName(),
- getIteratorTypesAttrName()};
- p << op.getOperationName() << ' ';
- p.printOptionalAttrDict(op.getAttrs(), silentAttrNames);
- p << ' ' << op.getOperands();
- p << " : (" << op.getOperandTypes() << ")";
- auto outputTensorTypes = op.getResultTypes();
- if (!outputTensorTypes.empty())
- p << " -> (" << outputTensorTypes << ")";
+void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
+ OperationState &result,
+ TypeRange inputTypes,
+ TypeRange outputBufferTypes,
+ TypeRange initTensorTypes,
+ TypeRange resultTypes) {
+ // TODO: why does programmatic creation fail if not using a local builder?
+ OpBuilder localOpBuilder(opBuilder.getContext());
+ Region ®ion = *result.addRegion();
+ buildNamedStructuredOpRegionAndAttributesImpl(
+ localOpBuilder, region, inputTypes, outputBufferTypes, initTensorTypes,
+ resultTypes, [&](unsigned expected, unsigned actual) {
+ llvm::errs() << "region expects " << expected << " args, got "
+ << actual;
+ assert(expected != actual && "incorrect number of arguments");
+ });
}
template <typename NamedStructuredOpType>
-static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
- OperationState &result) {
- SmallVector<OpAsmParser::OperandType, 8> operandsInfo;
- result.getContext()->getOrLoadDialect<StandardOpsDialect>();
-
- // Optional attributes may be added.
- if (parser.parseOperandList(operandsInfo) ||
- parser.parseOptionalAttrDict(result.attributes))
- return failure();
-
- SmallVector<Type, 8> operandTypes;
- if (parser.parseColon() || parser.parseLParen() ||
- parser.parseTypeList(operandTypes) || parser.parseRParen())
- return failure();
-
- // Generic ops may specify that a subset of its outputs are tensors. Such
- // outputs are specified in the result type.
- SmallVector<Type, 1> tensorResultTypes;
- if (parser.parseOptionalArrowTypeList(tensorResultTypes))
- return failure();
-
- if (!tensorResultTypes.empty())
- result.addTypes(tensorResultTypes);
-
- // The number of parsed arguments must equal
- // the number of expected arguments for the current operation.
- auto parsedArgs = operandsInfo.size();
- auto expectedArgs = NamedStructuredOpType::getNumInputs() +
- NamedStructuredOpType::getNumOutputs();
- if (parsedArgs != expectedArgs)
- return parser.emitError(parser.getNameLoc(),
- "expects " + std::to_string(expectedArgs) +
- " operands, but found " +
- std::to_string(parsedArgs));
+static ParseResult
+parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion,
+ TypeRange inputTypes, TypeRange outputBufferTypes,
+ TypeRange initTensorTypes, TypeRange resultTypes) {
+ ParseResult res = success();
+ OpBuilder opBuilder(parser.getBuilder().getContext());
+ buildNamedStructuredOpRegionAndAttributesImpl(
+ opBuilder, region, inputTypes, outputBufferTypes, initTensorTypes,
+ resultTypes, [&](unsigned expected, unsigned actual) {
+ std::string msg = std::string("region expects ") +
+ std::to_string(expected) +
+ std::string(" args, got ") + std::to_string(actual);
+ res = parser.emitError(parser.getCurrentLocation(), msg);
+ });
+ return res;
+}
- buildNamedStructuredOpRegionAndAttributes(
- parser.getBuilder(), result, operandTypes, tensorResultTypes);
+static ParseResult
+parseNamedStructuredOpResults(OpAsmParser &parser,
+ SmallVectorImpl<Type> &resultTypes) {
+ if (succeeded(parser.parseOptionalArrow()))
+ if (parser.parseTypeList(resultTypes))
+ return failure();
+ return success();
+}
- return parser.resolveOperands(operandsInfo, operandTypes,
- parser.getCurrentLocation(), result.operands);
+static void printNamedStructuredOpResults(OpAsmPrinter &p,
+ TypeRange resultTypes) {
+ if (resultTypes.empty())
+ return;
+ p << "-> " << resultTypes;
}
template <typename NamedStructuredOpType>
@@ -1353,8 +1375,6 @@
CANONICALIZERS_AND_FOLDERS(GenericOp)
CANONICALIZERS_AND_FOLDERS(IndexedGenericOp)
-#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
-
// TODO: Determine whether we can generate the folders and verifiers.
CANONICALIZERS_AND_FOLDERS(BatchMatmulOp)
CANONICALIZERS_AND_FOLDERS(DotOp)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -58,6 +58,8 @@
//===----------------------------------------------------------------------===//
void mlir::linalg::LinalgDialect::initialize() {
+ getContext()->getOrLoadDialect("std");
+
addTypes<RangeType>();
addOperations<
#define GET_OP_LIST
@@ -67,6 +69,7 @@
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>();
+
addInterfaces<LinalgInlinerInterface>();
}
diff --git a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
--- a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
+++ b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
@@ -13,7 +13,8 @@
// CHECK-DAG: #[[$map10:.*]] = affine_map<(d0, d1, d2, d3) -> ()>
func @conv_1d(%arg0: memref<3xf32>, %arg1: memref<3xf32>, %arg2: memref) {
- linalg.conv_1d %arg0, %arg1, %arg2 : (memref<3xf32>, memref<3xf32>, memref)
+ linalg.conv_1d ins(%arg0, %arg1 : memref<3xf32>, memref<3xf32>)
+ outs(%arg2 : memref)
return
}
@@ -30,7 +31,8 @@
// CHECK: return
func @conv_1d_ncw(%arg0: memref<1x3x3xf32>, %arg1: memref<1x3x3xf32>, %arg2: memref) {
- linalg.conv_1d_ncw %arg0, %arg1, %arg2 : (memref<1x3x3xf32>, memref<1x3x3xf32>, memref)
+ linalg.conv_1d_ncw ins(%arg0, %arg1 : memref<1x3x3xf32>, memref<1x3x3xf32>)
+ outs(%arg2 : memref)
return
}
@@ -48,7 +50,8 @@
func @conv_1d_nwc(%arg0: memref<1x3x3xf32>, %arg1: memref<1x3x3xf32>, %arg2: memref) {
- linalg.conv_1d_nwc %arg0, %arg1, %arg2 : (memref<1x3x3xf32>, memref<1x3x3xf32>, memref)
+ linalg.conv_1d_nwc ins(%arg0, %arg1 : memref<1x3x3xf32>, memref<1x3x3xf32>)
+ outs(%arg2 : memref)
return
}
@@ -65,7 +68,8 @@
// CHECK: return
func @conv_2d(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>, %arg2: memref) {
- linalg.conv_2d %arg0, %arg1, %arg2 : (memref<3x3xf32>, memref<3x3xf32>, memref)
+ linalg.conv_2d ins(%arg0, %arg1 : memref<3x3xf32>, memref<3x3xf32>)
+ outs(%arg2 : memref)
return
}
@@ -82,7 +86,8 @@
// CHECK: return
func @conv_2d_nchw(%arg0: memref<1x3x3x3xf32>, %arg1: memref<1x3x3x3xf32>, %arg2: memref) {
- linalg.conv_2d_nchw %arg0, %arg1, %arg2 : (memref<1x3x3x3xf32>, memref<1x3x3x3xf32>, memref)
+ linalg.conv_2d_nchw ins(%arg0, %arg1 : memref<1x3x3x3xf32>, memref<1x3x3x3xf32>)
+ outs(%arg2 : memref)
return
}
@@ -99,7 +104,8 @@
// CHECK: return
func @conv_2d_nhwc(%arg0: memref<1x3x3x3xf32>, %arg1: memref<1x3x3x3xf32>, %arg2: memref) {
- linalg.conv_2d_nhwc %arg0, %arg1, %arg2 : (memref<1x3x3x3xf32>, memref<1x3x3x3xf32>, memref)
+ linalg.conv_2d_nhwc ins(%arg0, %arg1 : memref<1x3x3x3xf32>, memref<1x3x3x3xf32>)
+ outs(%arg2 : memref)
return
}
@@ -116,7 +122,8 @@
// CHECK: return
func @conv_3d(%arg0: memref<3x3x3xf32>, %arg1: memref<3x3x3xf32>, %arg2: memref) {
- linalg.conv_3d %arg0, %arg1, %arg2 : (memref<3x3x3xf32>, memref<3x3x3xf32>, memref)
+ linalg.conv_3d ins(%arg0, %arg1 : memref<3x3x3xf32>, memref<3x3x3xf32>)
+ outs(%arg2 : memref)
return
}
@@ -133,7 +140,8 @@
// CHECK: return
func @conv_3d_ncdhw(%arg0: memref<1x3x3x3x3xf32>, %arg1: memref<1x3x3x3x3xf32>, %arg2: memref) {
- linalg.conv_3d_ncdhw %arg0, %arg1, %arg2 : (memref<1x3x3x3x3xf32>, memref<1x3x3x3x3xf32>, memref)
+ linalg.conv_3d_ncdhw ins(%arg0, %arg1 : memref<1x3x3x3x3xf32>, memref<1x3x3x3x3xf32>)
+ outs(%arg2 : memref)
return
}
@@ -150,7 +158,8 @@
// CHECK: return
func @conv_3d_ndhwc(%arg0: memref<1x3x3x3x3xf32>, %arg1: memref<1x3x3x3x3xf32>, %arg2: memref) {
- linalg.conv_3d_ndhwc %arg0, %arg1, %arg2 : (memref<1x3x3x3x3xf32>, memref<1x3x3x3x3xf32>, memref)
+ linalg.conv_3d_ndhwc ins(%arg0, %arg1 : memref<1x3x3x3x3xf32>, memref<1x3x3x3x3xf32>)
+ outs(%arg2 : memref)
return
}
diff --git a/mlir/test/Dialect/Linalg/affine.mlir b/mlir/test/Dialect/Linalg/affine.mlir
--- a/mlir/test/Dialect/Linalg/affine.mlir
+++ b/mlir/test/Dialect/Linalg/affine.mlir
@@ -15,7 +15,8 @@
%A = view %arg0[%c0][%M, %K] : memref to memref
%B = view %arg0[%c0][%K, %N] : memref to memref
%C = view %arg0[%c0][%M, %N] : memref to memref
- linalg.matmul %A, %B, %C : (memref, memref, memref)
+ linalg.matmul ins(%A, %B: memref, memref)
+ outs(%C: memref)
return
}
@@ -102,7 +103,8 @@
// Named ops to loops.
//----------------------------------------------------------------------------//
func @named_batch_matmul(%A: memref, %B: memref, %C: memref) {
- linalg.batch_matmul %A, %B, %C : (memref, memref, memref) -> ()
+ linalg.batch_matmul ins(%A, %B: memref, memref)
+ outs(%C : memref)
return
}
// CHECK-LABEL: @named_batch_matmul
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -14,8 +14,9 @@
// CHECK: linalg.slice {{.*}} : memref<16x16xf32>, !linalg.range, !linalg.range, memref
%4 = linalg.slice %3[%r0, %r0] : memref, !linalg.range, !linalg.range, memref
- // CHECK: linalg.matmul{{.*}}: (memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>)
- linalg.matmul %3, %3, %3 : (memref, memref, memref)
+ // CHECK: linalg.matmul ins({{.*}}memref<16x16xf32>, memref<16x16xf32>) outs({{.*}}memref<16x16xf32>)
+ linalg.matmul ins(%3, %3: memref, memref)
+ outs(%3: memref)
return %4: memref
}
diff --git a/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir b/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
--- a/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
+++ b/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-affine-min-scf-canonicalization-patterns
-//| FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-affine-min-scf-canonicalization-patterns | FileCheck %s
// CHECK-LABEL: scf_for
func @scf_for(%A : memref, %step : index) {
diff --git a/mlir/test/Dialect/Linalg/fusion-2-level.mlir b/mlir/test/Dialect/Linalg/fusion-2-level.mlir
--- a/mlir/test/Dialect/Linalg/fusion-2-level.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-2-level.mlir
@@ -12,7 +12,8 @@
%0 = dim %C, %c0 : memref
%1 = dim %C, %c1 : memref
%2 = dim %D, %c1 : memref
- linalg.matmul %A, %B, %C : (memref, memref, memref)
+ linalg.matmul ins(%A, %B: memref, memref)
+ outs(%C: memref)
scf.for %arg5 = %c0 to %0 step %c20 {
scf.for %arg6 = %c0 to %2 step %c30 {
scf.for %arg7 = %c0 to %1 step %c40 {
@@ -28,7 +29,8 @@
%14 = std.subview %5[%arg8, %arg10][%c2, %c4][%c1, %c1] : memref to memref
%16 = std.subview %7[%arg10, %arg9][%c4, %c3][%c1, %c1]: memref to memref
%17 = std.subview %8[%arg8, %arg9][%c2, %c4][%c1, %c1] : memref to memref
- linalg.matmul %14, %16, %17 : (memref, memref, memref)
+ linalg.matmul ins(%14, %16: memref, memref)
+ outs(%17: memref)
}
}
}
diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir
--- a/mlir/test/Dialect/Linalg/fusion.mlir
+++ b/mlir/test/Dialect/Linalg/fusion.mlir
@@ -14,10 +14,9 @@
%0 = dim %A, %c0 : memref
%1 = dim %A, %c1 : memref
%2 = dim %B, %c1 : memref
- linalg.matmul %A, %B, %C :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%A, %B : memref,
+ memref)
+ outs(%C : memref)
scf.for %arg5 = %c0 to %0 step %c2 {
scf.for %arg6 = %c0 to %2 step %c3 {
scf.for %arg7 = %c0 to %1 step %c4 {
@@ -30,10 +29,9 @@
%8 = std.subview %C[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref to
memref
- linalg.matmul %5, %7, %8 :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%5, %7 : memref,
+ memref)
+ outs(%8: memref)
}
}
}
@@ -61,10 +59,9 @@
%c4 = constant 4 : index
%c3 = constant 3 : index
%c2 = constant 2 : index
- linalg.matmul %A, %B, %C :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%A, %B : memref,
+ memref)
+ outs(%C: memref)
%0 = dim %C, %c0 : memref
%1 = dim %C, %c1 : memref
%2 = dim %D, %c1 : memref
@@ -80,10 +77,9 @@
%8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref to
memref
- linalg.matmul %5, %7, %8 :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%5, %7 : memref,
+ memref)
+ outs(%8 : memref)
}
}
}
@@ -113,10 +109,9 @@
%c4 = constant 4 : index
%c3 = constant 3 : index
%c2 = constant 2 : index
- linalg.matmul %A, %B, %C :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%A, %B : memref,
+ memref)
+ outs(%C : memref)
%0 = dim %D, %c0 : memref
%1 = dim %D, %c1 : memref
%2 = dim %C, %c1 : memref
@@ -132,10 +127,9 @@
%8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref to
memref
- linalg.matmul %5, %7, %8 :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%5, %7 : memref,
+ memref)
+ outs(%8 : memref)
}
}
}
@@ -165,14 +159,12 @@
%c4 = constant 4 : index
%c3 = constant 3 : index
%c2 = constant 2 : index
- linalg.matmul %A, %B, %C :
- (memref,
- memref,
- memref)
- linalg.matmul %A, %B, %D :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%A, %B : memref,
+ memref)
+ outs(%C : memref)
+ linalg.matmul ins(%A, %B : memref,
+ memref)
+ outs(%D : memref)
%0 = dim %C, %c0 : memref
%1 = dim %C, %c1 : memref
%2 = dim %D, %c1 : memref
@@ -188,10 +180,9 @@
%8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref to
memref
- linalg.matmul %5, %7, %8 :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%5, %7 : memref,
+ memref)
+ outs(%8 : memref)
}
}
}
@@ -227,14 +218,12 @@
%0 = dim %B, %c1 : memref
%1 = dim %D, %c0 : memref
%2 = dim %D, %c1 : memref
- linalg.matmul %A, %B, %C :
- (memref,
- memref,
- memref)
- linalg.matmul %C, %B, %D :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%A, %B : memref,
+ memref)
+ outs(%C : memref)
+ linalg.matmul ins(%C, %B : memref,
+ memref)
+ outs(%D : memref)
scf.for %arg5 = %c0 to %1 step %c2 {
scf.for %arg6 = %c0 to %0 step %c3 {
scf.for %arg7 = %c0 to %2 step %c4 {
@@ -247,10 +236,9 @@
%8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref to
memref
- linalg.matmul %5, %7, %8 :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%5, %7 : memref,
+ memref)
+ outs(%8 : memref)
}
}
}
@@ -275,9 +263,9 @@
// CHECK-DAG: %[[A_I0:.*]] = subview %[[A]][%[[I]], %{{.*}}]
// CHECK-DAG: %[[B_00:.*]] = subview %[[B]][%{{.*}}, %{{.*}}]
// CHECK-DAG: %[[C_I0_:.*]] = subview %[[C]][%[[I]], %{{.*}}]
-// CHECK: linalg.matmul %[[A_I0]], %[[B_00]], %[[C_I0_]]
-// CHECK: linalg.matmul %[[C_I0]], %[[B_0K]], %[[D_IK_]]
-// CHECK: linalg.matmul %[[D_IK]], %[[B_KJ]], %[[E_IJ]]
+// CHECK: linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0_]]
+// CHECK: linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK_]]
+// CHECK: linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]]
// -----
@@ -297,14 +285,12 @@
%c3 = constant 3 : index
%c2 = constant 2 : index
%0 = dim %C, %c1 : memref
- linalg.matmul %A, %B, %C :
- (memref,
- memref,
- memref)
- linalg.matmul %A, %C, %E :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%A, %B : memref,
+ memref)
+ outs(%C : memref)
+ linalg.matmul ins(%A, %C : memref,
+ memref)
+ outs(%E : memref)
%1 = dim %C, %c0 : memref
%2 = dim %D, %c1 : memref
scf.for %arg5 = %c0 to %1 step %c2 {
@@ -322,10 +308,9 @@
%8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref to
memref
- linalg.matmul %5, %7, %8 :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%5, %7 : memref,
+ memref)
+ outs(%8 : memref)
}
}
}
@@ -359,14 +344,12 @@
%2 = dim %C, %c1 : memref
%3 = dim %C, %c0 : memref
%4 = dim %D, %c1 : memref
- linalg.matmul %A, %C, %E :
- (memref,
- memref,
- memref)
- linalg.matmul %A, %B, %C :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%A, %C : memref,
+ memref)
+ outs(%E : memref)
+ linalg.matmul ins(%A, %B : memref,
+ memref)
+ outs(%C : memref)
scf.for %arg5 = %c0 to %0 step %c2 {
scf.for %arg6 = %c0 to %2 step %c3 {
scf.for %arg7 = %c0 to %1 step %c4 {
@@ -379,10 +362,9 @@
%10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref to
memref
- linalg.matmul %7, %9, %10 :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%7, %9 : memref,
+ memref)
+ outs(%10 : memref)
}
}
}
@@ -398,10 +380,9 @@
%10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref to
memref
- linalg.matmul %7, %9, %10 :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%7, %9 : memref,
+ memref)
+ outs(%10 : memref)
}
}
}
@@ -414,7 +395,7 @@
// CHECK: %[[C_1:.*]] = dim %[[C]], %c1{{_[0-9]*}} : memref
// CHECK: %[[C_0:.*]] = dim %[[C]], %c0{{_[0-9]*}} : memref
// CHECK: %[[D_1:.*]] = dim %[[D]], %c1{{_[0-9]*}} : memref
-// CHECK: linalg.matmul %[[A]], %[[C]], %[[E]]
+// CHECK: linalg.matmul ins(%[[A]], %[[C]]{{.*}} outs(%[[E]]
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
// CHECK: scf.for %{{.*}} = %{{.*}} to %[[A_1]] step %{{.*}} {
@@ -445,14 +426,12 @@
%c2 = constant 2 : index
%0 = dim %A, %c0 : memref
%1 = dim %A, %c1 : memref
- linalg.matmul %A, %C, %D :
- (memref,
- memref,
- memref)
- linalg.matmul %A, %B, %C :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%A, %C : memref,
+ memref)
+ outs(%D : memref)
+ linalg.matmul ins(%A, %B : memref,
+ memref)
+ outs(%C : memref)
%2 = dim %D, %c1 : memref
scf.for %arg5 = %c0 to %0 step %c2 {
scf.for %arg6 = %c0 to %2 step %c3 {
@@ -469,10 +448,9 @@
%8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] :
memref to
memref
- linalg.matmul %5, %7, %8 :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%5, %7 : memref,
+ memref)
+ outs(%8 : memref)
}
}
}
@@ -742,10 +720,9 @@
%B = alloca(%dim, %dim)[%s0, %s1] : memref
%C = alloc(%dim, %dim)[%s0, %s1] : memref
- linalg.matmul %A, %B, %C :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%A, %B : memref,
+ memref)
+ outs(%C : memref)
scf.for %i = %c0 to %dim step %c2 {
scf.for %j = %c0 to %dim step %c3 {
@@ -759,10 +736,9 @@
%2 = std.subview %C[%i, %j][%c2, %c3][%c1, %c1] :
memref to
memref
- linalg.matmul %0, %1, %2 :
- (memref,
- memref,
- memref)
+ linalg.matmul ins(%0, %1 : memref,
+ memref)
+ outs(%2 : memref)
}
}
}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -428,13 +428,6 @@
// -----
-func @generic_result_0_element_type(%arg0: memref) {
- // expected-error @+1 {{'linalg.dot' expects 3 operands, but found 2}}
- linalg.dot %arg0, %arg0 : (memref, memref)
-}
-
-// -----
-
func @conv_rank_limit(%arg0: memref, %arg1: memref, %arg2: memref) {
// expected-error @+1 {{expects memref ranks to be greater than 2}}
linalg.conv(%arg0, %arg1, %arg2) : memref, memref, memref
@@ -511,7 +504,8 @@
func @named_ops(%a3: memref, %b3: memref, %c3: memref) {
// expected-error @+1 {{op expected indexing_map #1 results to match view rank: 'memref'}}
- linalg.batch_matmul %a3, %b3, %c3 : (memref, memref, memref) -> ()
+ linalg.batch_matmul ins(%a3, %b3: memref, memref)
+ outs(%c3 : memref)
return
}
@@ -531,3 +525,52 @@
} : tensor -> (tensor, tensor)
return
}
+
+// -----
+
+func @empty_init_expected(%m: memref, %t: tensor) {
+ // expected-error @+1 {{expected empty init_tensors when op has no results or no reduction dims}}
+ linalg.matmul ins(%m, %m: memref, memref)
+ outs(%m : memref)
+ init(%t : tensor)
+ return
+}
+
+// -----
+
+func @incorrect_region_arg_count(%m: memref) {
+ // expected-error @+3 {{region expects 3 args, got 4}}
+ %res = linalg.matmul ins(%m, %m : memref, memref)
+ -> tensor, tensor
+ return
+}
+
+// -----
+
+func @single_tensor_result(%m: memref, %t: tensor) {
+ // expected-error @+1 {{expected single tensor result when reduction present}}
+ %res:2 = linalg.matmul ins(%m : memref)
+ init(%t, %t : tensor, tensor)
+ -> tensor, tensor
+ return
+}
+
+// -----
+
+func @matching_inits(%m: memref, %t: tensor) {
+ // expected-error @+1 {{expected #init tensors to match #results when reduction present}}
+ %res = linalg.matmul ins(%m, %m : memref, memref)
+ init(%t, %t : tensor, tensor)
+ -> tensor
+ return
+}
+
+// -----
+
+func @matching_inits(%m: memref, %t: tensor) {
+ // expected-error @+1 {{expected init tensor #0 of the same type as result #0}}
+ %res = linalg.matmul ins(%m, %m : memref, memref)
+ init(%t : tensor)
+ -> tensor
+ return
+}
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -39,7 +39,8 @@
%A = view %arg0[%c0][%M, %K] : memref to memref
%B = view %arg0[%c0][%K, %N] : memref to memref
%C = view %arg0[%c0][%M, %N] : memref to memref
- linalg.matmul %A, %B, %C : (memref, memref, memref)
+ linalg.matmul ins(%A, %B: memref, memref)
+ outs(%C: memref)
return
}
// CHECKLOOP-LABEL: func @matmul(%{{.*}}: memref,
@@ -83,7 +84,8 @@
%2 = view %arg0[%c0][%M, %N] : memref to memref
%3 = view %arg0[%c0][%M] : memref to memref
%4 = view %arg0[%c0][%N] : memref to memref
- linalg.matvec %2, %3, %4 : (memref, memref, memref)
+ linalg.matvec ins(%2, %3: memref, memref)
+ outs(%4 : memref)
return
}
// CHECKLOOP-LABEL: func @matvec(%{{.*}}: memref,
@@ -123,7 +125,8 @@
%1 = view %arg0[%c0][%M] : memref to memref
%2 = view %arg0[%c0][%M] : memref to memref
%3 = view %arg0[%c0][] : memref to memref
- linalg.dot %1, %2, %3 : (memref, memref, memref)
+ linalg.dot ins(%1, %2 : memref, memref)
+ outs(%3 : memref)
return
}
// CHECKLOOP-LABEL: func @dot(%{{.*}}: memref,
@@ -154,9 +157,9 @@
func @dot_view(%arg0: memref, %arg1: memref, %arg2: memref) {
- linalg.dot %arg0, %arg1, %arg2 : (memref,
- memref,
- memref)
+ linalg.dot ins(%arg0, %arg1 : memref,
+ memref)
+ outs(%arg2: memref)
return
}
// CHECKLOOP-LABEL: func @dot_view(
@@ -880,7 +883,8 @@
// Named ops to loops.
//----------------------------------------------------------------------------//
func @named_batch_matmul(%A: memref, %B: memref, %C: memref) {
- linalg.batch_matmul %A, %B, %C : (memref, memref, memref) -> ()
+ linalg.batch_matmul ins(%A, %B : memref, memref)
+ outs(%C : memref)
return
}
// CHECKLOOP-LABEL: @named_batch_matmul
@@ -1288,7 +1292,8 @@
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref
func @conv1d_no_symbols(%in : memref, %filter : memref, %out : memref) -> () {
- linalg.conv_1d %in, %filter, %out : (memref, memref, memref)
+ linalg.conv_1d ins(%in, %filter : memref, memref)
+ outs(%out : memref)
return
}
@@ -1330,7 +1335,8 @@
func @conv2d_no_symbols(%in : memref, %filter : memref, %out : memref) -> () {
- linalg.conv_2d %in, %filter, %out : (memref, memref, memref)
+ linalg.conv_2d ins(%in, %filter : memref, memref)
+ outs(%out: memref)
return
}
// CHECKLOOP-LABEL: @conv2d_no_symbols
@@ -1382,7 +1388,8 @@
func @conv3d_no_symbols(%in : memref, %filter : memref, %out : memref) -> () {
- linalg.conv_3d %in, %filter, %out : (memref, memref, memref)
+ linalg.conv_3d ins(%in, %filter : memref, memref)
+ outs(%out : memref)
return
}
diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -27,10 +27,10 @@
%11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref to memref
%14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref to memref
%17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref to memref
- linalg.matmul %11, %14, %17 :
- (memref,
- memref,
- memref)
+ linalg.matmul
+ ins(%11, %14: memref,
+ memref)
+ outs(%17: memref)
}
}
}
@@ -67,10 +67,7 @@
// CHECK: linalg.copy(%[[vB]], %[[partialB]]) : memref, memref
// CHECK: linalg.copy(%[[vC]], %[[partialC]]) : memref, memref
//
-// CHECK: linalg.matmul %[[partialA]], %[[partialB]], %[[partialC]] :
-// CHECK: memref,
-// CHECK: memref,
-// CHECK: memref
+// CHECK: linalg.matmul ins(%[[partialA]], %[[partialB]]{{.*}} outs(%[[partialC]]
//
// CHECK: linalg.copy(%[[partialC]], %[[vC]]) :
// CHECK: memref,
@@ -103,10 +100,10 @@
%11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref to memref
%14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref to memref
%17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref to memref
- linalg.matmul %11, %14, %17 :
- (memref,
- memref,
- memref)
+ linalg.matmul
+ ins(%11, %14: memref,
+ memref)
+ outs(%17: memref)
}
}
}
@@ -140,10 +137,7 @@
// CHECK: linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref, memref
// CHECK: linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref, memref
//
-// CHECK: linalg.matmul %[[partialA_f64]], %[[partialB_f64]], %[[partialC_f64]] :
-// CHECK: memref,
-// CHECK: memref,
-// CHECK: memref
+// CHECK: linalg.matmul ins(%[[partialA_f64]], %[[partialB_f64]]{{.*}} outs(%[[partialC_f64]]
//
// CHECK: linalg.copy(%[[partialC_f64]], %[[vC_f64]]) :
// CHECK: memref,
diff --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir
--- a/mlir/test/Dialect/Linalg/promotion_options.mlir
+++ b/mlir/test/Dialect/Linalg/promotion_options.mlir
@@ -2,8 +2,9 @@
func @gemm(%a : memref, %b : memref, %c : memref)
{
- linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "START"}
- : (memref, memref, memref)
+ linalg.matmul {__internal_linalg_transform__ = "START"}
+ ins(%a, %b: memref, memref)
+ outs(%c: memref)
return
}
@@ -26,7 +27,7 @@
// CHECK: linalg.copy(%[[T7]], %[[T19]])
// CHECK: linalg.fill(%[[T21]], %[[C42]])
// CHECK: linalg.copy(%[[T17]], %[[T21]])
-// CHECK: linalg.matmul %[[T19]], %[[T12]], %[[T21]]
+// CHECK: linalg.matmul ins(%[[T19]], %[[T12]]{{.*}} outs(%[[T21]]
// CHECK-NOT: linalg.fill
// CHECK: linalg.copy(%[[T21]], %[[T17]])
// CHECK: dealloc %[[T18]]
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -83,30 +83,30 @@
%arg1: memref,
%arg2: memref,
%arg3: memref) {
- linalg.matmul %arg0, %arg0, %arg0 : (memref,
- memref,
- memref)
- linalg.matvec %arg0, %arg1, %arg2 : (memref,
- memref,
- memref)
- linalg.dot %arg1, %arg2, %arg3 : (memref,
- memref,
- memref)
+ linalg.matmul ins(%arg0, %arg0 : memref,
+ memref)
+ outs(%arg0 : memref)
+ linalg.matvec ins(%arg0, %arg1: memref,
+ memref)
+ outs(%arg2: memref)
+ linalg.dot ins(%arg1, %arg2: memref,
+ memref)
+ outs(%arg3: memref)
return
}
// CHECK-LABEL: func @ops(%
-// CHECK-NEXT: linalg.matmul %{{.*}}, %{{.*}}, %{{.*}} :
-// CHECK-SAME: (memref,
-// CHECK-SAME: memref,
-// CHECK-SAME: memref)
-// CHECK-NEXT: linalg.matvec %{{.*}}, %{{.*}}, %{{.*}} :
-// CHECK-SAME: (memref,
-// CHECK-SAME: memref,
-// CHECK-SAME: memref)
-// CHECK-NEXT: linalg.dot %{{.*}}, %{{.*}}, %{{.*}} :
-// CHECK-SAME: (memref,
-// CHECK-SAME: memref,
-// CHECK-SAME: memref)
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%{{.*}}, %{{.*}} : memref,
+// CHECK-SAME: memref)
+// CHECK-SAME: outs(%{{.*}} : memref)
+// CHECK: linalg.matvec
+// CHECK-SAME: ins(%{{.*}}, %{{.*}}: memref,
+// CHECK-SAME: memref)
+// CHECK-SAME: outs(%{{.*}}: memref)
+// CHECK: linalg.dot
+// CHECK-SAME: ins(%{{.*}}, %{{.*}}: memref,
+// CHECK-SAME: memref)
+// CHECK-SAME: outs(%{{.*}}: memref)
// -----
@@ -619,17 +619,27 @@
// CHECK: linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]]
// CHECK-SAME: memref into memref
-
-// TODO: Return tensors need a semantics convention update.
func @named_ops(%a3: memref, %b3: memref, %c3: memref,
- %ta3: tensor, %tb3: tensor, %tc3: tensor) {
- linalg.batch_matmul %a3, %b3, %c3 : (memref, memref, memref) -> ()
- linalg.batch_matmul %ta3, %tb3, %c3 : (tensor, tensor, memref) -> ()
- return
+ %ta3: tensor, %tb3: tensor, %tc3: tensor)
+ -> (tensor, tensor)
+{
+ linalg.batch_matmul ins(%a3, %b3: memref, memref)
+ outs(%c3: memref)
+ linalg.batch_matmul ins(%ta3, %tb3: tensor, tensor)
+ outs(%c3: memref)
+ %res1 = linalg.batch_matmul ins(%ta3, %tb3: tensor, tensor)
+ init(%tc3: tensor)
+ -> tensor
+ %res2 = linalg.batch_matmul ins(%ta3, %b3: tensor, memref)
+ init(%tc3: tensor)
+ -> tensor
+ return %res1, %res2 : tensor, tensor
}
// CHECK-LABEL: func @named_ops
// CHECK: linalg.batch_matmul
// CHECK: linalg.batch_matmul
+// CHECK: linalg.batch_matmul
+// CHECK: linalg.batch_matmul
// -----
diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir
--- a/mlir/test/Dialect/Linalg/standard.mlir
+++ b/mlir/test/Dialect/Linalg/standard.mlir
@@ -13,9 +13,9 @@
func @dot(%arg0: memref,
%arg1: memref,
%arg2: memref) {
- linalg.dot %arg0, %arg1, %arg2 : (memref,
- memref,
- memref)
+ linalg.dot ins(%arg0, %arg1: memref,
+ memref)
+ outs(%arg2: memref)
return
}
// CHECK-LABEL: func @dot(
diff --git a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
@@ -2,8 +2,9 @@
func @gemm1(%a : memref, %b : memref, %c : memref)
{
- linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute1"}
- : (memref, memref, memref)
+ linalg.matmul {__internal_linalg_transform__ = "distribute1"}
+ ins(%a, %b: memref, memref)
+ outs(%c: memref)
return
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
@@ -21,14 +22,15 @@
// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX]]]
-// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
// -----
func @gemm2(%a : memref, %b : memref, %c : memref)
{
- linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute2"}
- : (memref, memref, memref)
+ linalg.matmul {__internal_linalg_transform__ = "distribute2"}
+ ins(%a, %b: memref, memref)
+ outs(%c:memref)
return
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
@@ -52,14 +54,15 @@
// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
// CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
-// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
// -----
func @gemm3(%a : memref, %b : memref, %c : memref)
{
- linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute3"}
- : (memref, memref, memref)
+ linalg.matmul {__internal_linalg_transform__ = "distribute3"}
+ ins(%a, %b: memref, memref)
+ outs(%c: memref)
return
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
@@ -80,14 +83,15 @@
// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[ARG3]], %[[ARG5]]]
// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG5]], %[[ARG4]]]
// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[ARG3]], %[[ARG4]]]
-// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
// -----
func @gemm4(%a : memref, %b : memref, %c : memref)
{
- linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute4"}
- : (memref, memref, memref)
+ linalg.matmul {__internal_linalg_transform__ = "distribute4"}
+ ins(%a, %b: memref, memref)
+ outs(%c: memref)
return
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
@@ -108,14 +112,15 @@
// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
// CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
-// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
// -----
func @gemm5(%a : memref, %b : memref, %c : memref)
{
- linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute5"}
- : (memref, memref, memref)
+ linalg.matmul {__internal_linalg_transform__ = "distribute5"}
+ ins(%a, %b: memref, memref)
+ outs(%c: memref)
return
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
@@ -138,14 +143,15 @@
// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[ARG3]]]
// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[ARG3]]]
-// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
// -----
func @gemm6(%a : memref, %b : memref, %c : memref)
{
- linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute6"}
- : (memref, memref, memref)
+ linalg.matmul {__internal_linalg_transform__ = "distribute6"}
+ ins(%a, %b: memref, memref)
+ outs(%c: memref)
return
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
@@ -165,4 +171,4 @@
// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[OFFSETX]]]
// CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[ARG3]], %[[OFFSETX_2]]]
-// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir
--- a/mlir/test/Dialect/Linalg/tile.mlir
+++ b/mlir/test/Dialect/Linalg/tile.mlir
@@ -31,10 +31,10 @@
func @matmul(%arg0: memref,
%arg1: memref,
%arg2: memref) {
- linalg.matmul %arg0, %arg1, %arg2 :
- (memref,
- memref,
- memref)
+ linalg.matmul
+ ins(%arg0, %arg1: memref,
+ memref)
+ outs(%arg2: memref)
return
}
// TILE-2-LABEL: func @matmul(
@@ -50,10 +50,7 @@
// TILE-2: %[[szK:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localK]]]
// TILE-2: %[[N:.*]] = dim %{{.*}}, %c1 : memref
// TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[szK]], %[[N]]] [1, 1] : memref to memref
-// TILE-2: linalg.matmul %[[sAi]], %{{.*}}, %[[sCi]] :
-// TILE-2: (memref,
-// TILE-2: memref,
-// TILE-2: memref)
+// TILE-2: linalg.matmul ins(%[[sAi]]{{.*}} outs(%[[sCi]]
// TILE-02-LABEL: func @matmul(
// TILE-02-DAG: %[[C0:.*]] = constant 0 : index
@@ -68,10 +65,7 @@
// TILE-02: %[[localK:.*]] = dim %{{.*}}, %c1
// TILE-02: %[[szK:.*]] = affine.min #[[$bound_map]](%[[J]])[%[[localK]]]
// TILE-02: %[[sCj:.*]] = subview %{{.*}}[0, %[[J]]] [%[[M]], %[[szK]]] [1, 1] : memref to memref
-// TILE-02: linalg.matmul %{{.*}}, %[[sBj]], %[[sCj]] :
-// TILE-02: (memref,
-// TILE-02: memref,
-// TILE-02: memref)
+// TILE-02: linalg.matmul ins(%{{.*}}, %[[sBj]]{{.*}} outs(%[[sCj]]
// TILE-002-LABEL: func @matmul(
// TILE-002-DAG: %[[C0:.*]] = constant 0 : index
@@ -86,10 +80,7 @@
// TILE-002: %[[szK:.*]] = affine.min #[[$bound_map]](%[[K]])[%[[localK]]]
// TILE-002: %[[N:.*]] = dim %{{.*}}, %c1 : memref
// TILE-002: %[[sBj:.*]] = subview %{{.*}}[%[[K]], 0] [%[[szK]], %[[N]]] [1, 1] : memref to memref
-// TILE-002: linalg.matmul %[[sAj]], %[[sBj]], %{{.*}} :
-// TILE-002: (memref,
-// TILE-002: memref,
-// TILE-002: memref)
+// TILE-002: linalg.matmul ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}}
// TILE-234-LABEL: func @matmul(
// TILE-234-DAG: %[[C0:.*]] = constant 0 : index
@@ -118,10 +109,7 @@
// TILE-234: %[[szN:.*]] = affine.min #[[$bound_map_3]](%[[J]])[%[[localN]]]
// TILE-234: %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%[[szM]], %[[szN]]] [1, 1] : memref to memref
//
-// TILE-234: linalg.matmul %[[sAik]], %[[sBkj]], %[[sCij]] :
-// TILE-234: (memref,
-// TILE-234: memref,
-// TILE-234: memref)
+// TILE-234: linalg.matmul ins(%[[sAik]], %[[sBkj]]{{.*}} outs(%[[sCij]]
// When the buffer shapes are known at compile time, it is possible to avoid
// the "min" in subview size computation. This test uses buffer sizes divisible
@@ -130,10 +118,10 @@
func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>,
%arg1: memref<16x12xf32, offset: ?, strides: [?, 1]>,
%arg2: memref<10x12xf32, offset: ?, strides: [?, 1]>) {
- linalg.matmul %arg0, %arg1, %arg2 :
- (memref<10x16xf32, offset: ?, strides: [?, 1]>,
- memref<16x12xf32, offset: ?, strides: [?, 1]>,
- memref<10x12xf32, offset: ?, strides: [?, 1]>)
+ linalg.matmul
+ ins(%arg0, %arg1: memref<10x16xf32, offset: ?, strides: [?, 1]>,
+ memref<16x12xf32, offset: ?, strides: [?, 1]>)
+ outs(%arg2: memref<10x12xf32, offset: ?, strides: [?, 1]>)
return
}
// TILE-2-LABEL: func @matmul_static(
@@ -148,7 +136,7 @@
// TILE-2: %[[sAi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[MIN2]], 16] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref
// TILE-2: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[I]])
// TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[MIN22]], 12] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref
-// TILE-2: linalg.matmul %[[sAi]], %{{.*}}, %[[sCi]]
+// TILE-2: linalg.matmul ins(%[[sAi]], %{{.*}}{{.*}} outs(%[[sCi]]
// TILE-02-LABEL: func @matmul_static(
// TILE-02-DAG: %[[C0:.*]] = constant 0 : index
@@ -159,10 +147,7 @@
// TILE-02: %[[sBj:.*]] = subview %{{.*}}[0, %[[J]]] [16, %[[MIN2]]] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<16x?xf32, #[[$strided2D]]>
// TILE-02: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[J]])
// TILE-02: %[[sCj:.*]] = subview %{{.*}}[0, %[[J]]] [10, %[[MIN22]]] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]>
-// TILE-02: linalg.matmul %{{.*}}, %[[sBj]], %[[sCj]] :
-// TILE-02: (memref<10x16xf32, #[[$strided2D]]>,
-// TILE-02: memref<16x?xf32, #[[$strided2D]]>,
-// TILE-02: memref<10x?xf32, #[[$strided2D]]>)
+// TILE-02: linalg.matmul ins(%{{.*}}, %[[sBj]]{{.*}} outs(%[[sCj]]
// TILE-002-LABEL: func @matmul_static(
// TILE-002-DAG: %[[C0:.*]] = constant 0 : index
@@ -173,10 +158,7 @@
// TILE-002: %[[sAj:.*]] = subview %{{.*}}[0, %[[K]]] [10, %[[MIN2]]] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]>
// TILE-002: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[K]])
// TILE-002: %[[sBj:.*]] = subview %{{.*}}[%[[K]], 0] [%[[MIN22]], 12] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref
-// TILE-002: linalg.matmul %[[sAj]], %[[sBj]], %{{.*}} :
-// TILE-002: (memref<10x?xf32, #[[$strided2D]]>,
-// TILE-002: memref,
-// TILE-002: memref<10x12xf32, #[[$strided2D]]>)
+// TILE-002: linalg.matmul ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}}
// TILE-234-LABEL: func @matmul_static(
// TILE-234-DAG: %[[C0:.*]] = constant 0 : index
@@ -193,16 +175,13 @@
// TILE-234: %[[sBkj:.*]] = subview %{{.*}}[%[[K]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref
// TILE-234: %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref
//
-// TILE-234: linalg.matmul %[[sAik]], %[[sBkj]], %[[sCij]] :
-// TILE-234: (memref,
-// TILE-234: memref,
-// TILE-234: memref)
+// TILE-234: linalg.matmul ins(%[[sAik]], %[[sBkj]]{{.*}} outs(%[[sCij]]
func @matvec(%arg0: memref, %arg1: memref, %arg2: memref) {
- linalg.matvec %arg0, %arg1, %arg2 : (
- memref,
- memref,
- memref)
+ linalg.matvec
+ ins(%arg0, %arg1: memref,
+ memref)
+ outs(%arg2: memref)
return
}
// TILE-2-LABEL: func @matvec(
@@ -220,7 +199,7 @@
// TILE-2: %[[localN:.*]] = dim %{{.*}}, %c0
// TILE-2: %[[szN:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localN]]]
// TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szN]]] [1] : memref to memref
-// TILE-2: linalg.matvec %[[sAi]], %{{.*}}, %[[sCi]] : (memref, memref, memref)
+// TILE-2: linalg.matvec ins(%[[sAi]], %{{.*}} outs(%[[sCi]]
// TILE-02-LABEL: func @matvec(
// TILE-02-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
@@ -237,7 +216,7 @@
// TILE-02: %[[localN:.*]] = dim %{{.*}}, %c0
// TILE-02: %[[szN:.*]] = affine.min #[[$bound_map]](%[[J]])[%[[localN]]]
// TILE-02: %[[sBj:.*]] = subview %{{.*}}[%[[J]]] [%[[szN]]] [1] : memref to memref
-// TILE-02: linalg.matvec %[[sAj]], %[[sBj]], %{{.*}} : (memref, memref, memref)
+// TILE-02: linalg.matvec ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}}
// TILE-002-LABEL: func @matvec(
// TILE-002-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
@@ -268,12 +247,12 @@
// TILE-234: %[[szM:.*]] = affine.min #[[$bound_map_2]](%[[I]])[%[[localM]]]
// TILE-234: %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref to memref
//
-// TILE-234: linalg.matvec %[[sAij]], %[[sBj]], %[[sCi]] : (memref, memref, memref)
+// TILE-234: linalg.matvec ins(%[[sAij]], %[[sBj]]{{.*}} outs(%[[sCi]]
func @dot(%arg0: memref, %arg1: memref, %arg2: memref) {
- linalg.dot %arg0, %arg1, %arg2 : (memref,
- memref,
- memref)
+ linalg.dot
+ ins(%arg0, %arg1: memref, memref)
+ outs(%arg2: memref)
return
}
// TILE-2-LABEL: func @dot(
@@ -287,7 +266,7 @@
// TILE-2: %[[localM:.*]] = dim %{{.*}}, %c0
// TILE-2: %[[szM:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localM]]]
// TILE-2: %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref to memref
-// TILE-2: linalg.dot %[[sAi]], %[[sBi]], {{.*}} : (memref, memref, memref)
+// TILE-2: linalg.dot ins(%[[sAi]], %[[sBi]]{{.*}} outs(
// TILE-02-LABEL: func @dot(
// TILE-02-NOT: scf.for
@@ -306,7 +285,7 @@
// TILE-234: %[[localM:.*]] = dim %{{.*}}, %c0
// TILE-234: %[[szM:.*]] = affine.min #[[$bound_map_2]](%[[I]])[%[[localM]]]
// TILE-234: %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref to memref
-// TILE-234: linalg.dot %[[sAi]], %[[sBi]], %{{.*}} : (memref, memref, memref)
+// TILE-234: linalg.dot ins(%[[sAi]], %[[sBi]]{{.*}} outs(
func @fill_static(%arg0: memref<127x99xf32>, %arg1: f32) {
linalg.fill(%arg0, %arg1) : memref<127x99xf32>, f32
diff --git a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
--- a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
+++ b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
@@ -6,8 +6,8 @@
%arg1 : memref,
%arg2 : memref)
{
- linalg.matmul %arg0, %arg1, %arg2
- : (memref, memref, memref)
+ linalg.matmul ins(%arg0, %arg1: memref, memref)
+ outs(%arg2: memref)
return
}
// CHECK-LABEL: func @gemm
@@ -21,7 +21,7 @@
// CHECK: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG5]]]
// CHECK: %[[SV2:.*]] = subview %{{.*}}[%[[ARG5]], %[[ARG4]]]
// CHECK: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]]
-// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
// TILE1-LABEL: func @gemm
// TILE1-DAG: %[[C2:.*]] = constant 2 : index
@@ -30,7 +30,7 @@
// TILE1: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0]
// TILE1: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], 0]
// TILE1-NOT: subview
-// TILE1: linalg.matmul %[[SV1]], %{{.*}}, %[[SV3]]
+// TILE1: linalg.matmul ins(%[[SV1]], %{{.*}} outs(%[[SV3]]
// TILE2-LABEL: func @gemm
// TILE2-DAG: %[[C2:.*]] = constant 2 : index
@@ -40,7 +40,7 @@
// TILE2: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0]
// TILE2: %[[SV2:.*]] = subview %{{.*}}[0, %[[ARG4]]]
// TILE2: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]]
-// TILE2: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+// TILE2: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
// -----
diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -5,10 +5,10 @@
func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
%B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
%C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
- linalg.matmul %A, %B, %C {__internal_linalg_transform__ = "START"} :
- (memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
- memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
- memref<1584x1584xf32, offset: 0, strides: [1584, 1]>)
+ linalg.matmul {__internal_linalg_transform__ = "START"}
+ ins(%A, %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
+ memref<1584x1584xf32, offset: 0, strides: [1584, 1]>)
+ outs(%C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>)
return
}
@@ -36,7 +36,8 @@
func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref) {
// VECTOR-CONTRACTION: vector.contract
// VECTOR-CONTRACTION-SAME: vector<1584xf32>, vector<1584xf32> into f32
- linalg.dot %A, %B, %C : (memref<1584xf32>, memref<1584xf32>, memref)
+ linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>)
+ outs(%C: memref)
return
}
@@ -44,8 +45,8 @@
func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) {
// VECTOR-CONTRACTION: vector.contract
// VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32>
- linalg.matvec %A, %B, %C :
- (memref<1584x1584xf32>, memref<1584xf32>, memref<1584xf32>)
+ linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>)
+ outs(%C: memref<1584xf32>)
return
}
@@ -53,8 +54,8 @@
func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) {
// VECTOR-CONTRACTION: vector.contract
// VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32>
- linalg.matmul %A, %B, %C :
- (memref<1584x1584xf32>, memref<1584x1584xf32>, memref<1584x1584xf32>)
+ linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>)
+ outs(%C: memref<1584x1584xf32>)
return
}
@@ -62,7 +63,8 @@
func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) {
// VECTOR-CONTRACTION: vector.contract
// VECTOR-CONTRACTION-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32>
- linalg.batch_matmul %A, %B, %C :
- (memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
+ linalg.batch_matmul
+ ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>)
+ outs(%C: memref<1584x1584x1584xf32>)
return
}
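Note also where attributes land after the rewrite: the generated `assemblyFormat` (see the mlir-linalg-ods-gen.cpp changes at the end of this patch) begins with `attr-dict`, so markers such as `__internal_linalg_transform__` now sit between the op name and `ins` rather than after the operand list. An illustrative sketch, with shapes chosen for the example:

```
// Old: the attribute dictionary trails the operand list.
linalg.matmul %A, %B, %C {__internal_linalg_transform__ = "START"}
  : (memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>)

// New: the attribute dictionary follows the op name directly.
linalg.matmul {__internal_linalg_transform__ = "START"}
  ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
  outs(%C : memref<8x32xf32>)
```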
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -14,10 +14,11 @@
func @dot(%x: memref,
%y: memref,
%v: memref) {
- linalg.dot %x, %y, %v { __internal_linalg_transform__ = "MEM" } :
- (memref,
- memref,
- memref)
+ linalg.dot { __internal_linalg_transform__ = "MEM" }
+ ins(%x, %y: memref,
+ memref)
+ outs(%v: memref)
+
return
}
// CHECK-LABEL: func @dot
@@ -36,10 +37,10 @@
func @matvec(%A: memref,
%x: memref,
%y: memref) {
- linalg.matvec %A, %x, %y :
- (memref,
- memref,
- memref)
+ linalg.matvec
+ ins(%A, %x: memref,
+ memref)
+ outs(%y: memref)
return
}
// CHECK-LABEL: func @matvec
@@ -48,15 +49,17 @@
// CHECK-DAG: %[[c6:.*]] = constant 6 : index
// CHECK: scf.parallel {{.*}} step (%[[c5]])
// CHECK: scf.for {{.*}} step %[[c6]]
-// CHECK: linalg.matvec {{.*}}, {{.*}}, {{.*}} : (memref, memref, memref)
+// CHECK: linalg.matvec
+// CHECK: ins({{.*}}, {{.*}}: memref, memref)
+// CHECK: outs({{.*}}: memref)
func @matmul(%A: memref,
%B: memref,
%C: memref) {
- linalg.matmul %A, %B, %C { __internal_linalg_transform__ = "MEM" } :
- (memref,
- memref,
- memref)
+ linalg.matmul { __internal_linalg_transform__ = "MEM" }
+ ins(%A, %B: memref,
+ memref)
+ outs(%C: memref)
return
}
// CHECK-LABEL: func @matmul
@@ -85,10 +88,9 @@
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2]] {
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3]] {
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] {
-// CHECK: linalg.matmul {{.*}}, {{.*}}, {{.*}} : (
-// CHECK: memref,
-// CHECK: memref,
-// CHECK: memref)
+// CHECK: linalg.matmul
+// CHECK: ins({{.*}}, {{.*}}: memref, memref)
+// CHECK: outs({{.*}}: memref)
#matmul_trait = {
args_in = 2,
@@ -137,8 +139,9 @@
func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<8x32xf32>) {
- linalg.matmul %A, %B, %C { __internal_linalg_transform__ = "VECTORIZE"} :
- (memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>)
+ linalg.matmul { __internal_linalg_transform__ = "VECTORIZE"}
+ ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>)
+ outs(%C: memref<8x32xf32>)
return
}
// CHECK-LABEL: func @vectorization_test_2
@@ -236,10 +239,10 @@
func @matvec_perm(%A: memref,
%x: memref,
%y: memref) {
- linalg.matvec %A, %x, %y {__internal_linalg_transform__ = "__with_perm__"} :
- (memref,
- memref,
+ linalg.matvec {__internal_linalg_transform__ = "__with_perm__"}
+ ins(%A, %x: memref,
memref)
+ outs(%y: memref)
return
}
// CHECK-LABEL: func @matvec_perm
@@ -248,15 +251,17 @@
// CHECK-DAG: %[[c6:.*]] = constant 6 : index
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]]
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]]
-// CHECK: linalg.matvec {{.*}}, {{.*}}, {{.*}} : (memref, memref, memref)
+// CHECK: linalg.matvec
+// CHECK: ins({{.*}}, {{.*}}: memref, memref)
+// CHECK: outs({{.*}}: memref)
func @matmul_perm(%A: memref,
%B: memref,
%C: memref) {
- linalg.matmul %A, %B, %C {__internal_linalg_transform__ = "__with_perm__"} :
- (memref,
- memref,
+ linalg.matmul {__internal_linalg_transform__ = "__with_perm__"}
+ ins(%A, %B: memref,
memref)
+ outs(%C : memref)
return
}
// CHECK-LABEL: func @matmul_perm
@@ -279,10 +284,9 @@
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] {
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] {
// CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] {
-// CHECK: linalg.matmul {{.*}}, {{.*}}, {{.*}} : (
-// CHECK: memref,
-// CHECK: memref,
-// CHECK: memref)
+// CHECK: linalg.matmul
+// CHECK: ins({{.*}}, {{.*}}: memref, memref)
+// CHECK: outs({{.*}}: memref)
func @promote_subview_matmul(%arg0: memref,
%arg1: memref,
@@ -304,10 +308,10 @@
memref to memref
%5 = subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
memref to memref
- linalg.matmul %3, %4, %5 {__internal_linalg_transform__ = "_promote_views_"} :
- (memref,
- memref,
- memref)
+ linalg.matmul {__internal_linalg_transform__ = "_promote_views_"}
+ ins(%3, %4: memref,
+ memref)
+ outs(%5: memref)
}
}
}
@@ -336,8 +340,9 @@
// CHECK: linalg.copy(%[[s0]], %[[l0]]) : memref, memref
// CHECK: linalg.copy(%[[s1]], %[[l1]]) : memref, memref
// CHECK: linalg.copy(%[[s2]], %[[l2]]) : memref, memref
-// CHECK: linalg.matmul %[[v0]], %[[v1]], %[[v2]] :
-// CHECK: (memref, memref, memref)
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%[[v0]], %[[v1]] : memref, memref)
+// CHECK-SAME: outs(%[[v2]] : memref)
func @promote_first_subview_matmul(%arg0: memref,
%arg1: memref,
@@ -359,10 +364,10 @@
memref to memref
%5 = std.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] :
memref to memref
- linalg.matmul %3, %4, %5 {__internal_linalg_transform__ = "_promote_first_view_"} :
- (memref,
- memref,
- memref)
+ linalg.matmul {__internal_linalg_transform__ = "_promote_first_view_"}
+ ins(%3, %4: memref,
+ memref)
+ outs(%5: memref)
}
}
}
@@ -391,10 +396,9 @@
// CHECK: linalg.copy(%[[s0]], %[[l0]]) : memref, memref
// CHECK-NOT: linalg.copy(%[[s1]], %[[l1]]) : memref, memref
// CHECK-NOT: linalg.copy(%[[s2]], %[[l2]]) : memref, memref^
-// CHECK: linalg.matmul %[[v0]], %[[s1]], %[[s2]] :
-// CHECK: (memref,
-// CHECK: memref,
-// CHECK: memref)
+// CHECK: linalg.matmul
+// CHECK-SAME: ins(%[[v0]], %[[s1]] : memref, memref)
+// CHECK-SAME: outs(%[[s2]] : memref)
func @aligned_promote_fill(%arg0: memref) {
%c2000 = constant 2000 : index
@@ -421,8 +425,9 @@
func @tile_permute_parallel_loop(%arg0: memref,
%arg1: memref,
%arg2: memref) {
- linalg.matmul %arg0, %arg1, %arg2 {__internal_linalg_transform__ = "par__with_perm__"}
- : (memref, memref, memref)
+ linalg.matmul {__internal_linalg_transform__ = "par__with_perm__"}
+ ins(%arg0, %arg1: memref, memref)
+ outs(%arg2: memref)
return
}
// CHECK-LABEL: func @tile_permute_parallel_loop
diff --git a/mlir/test/IR/slice.mlir b/mlir/test/IR/slice.mlir
--- a/mlir/test/IR/slice.mlir
+++ b/mlir/test/IR/slice.mlir
@@ -5,8 +5,10 @@
%b = alloc(%arg2, %arg1) : memref
%c = alloc(%arg0, %arg1) : memref
%d = alloc(%arg0, %arg1) : memref
- linalg.matmul %a, %b, %c : (memref, memref, memref)
- linalg.matmul %a, %b, %d : (memref, memref, memref)
+ linalg.matmul ins(%a, %b : memref, memref)
+ outs(%c : memref)
+ linalg.matmul ins(%a, %b : memref, memref)
+ outs(%d : memref)
dealloc %c : memref
dealloc %b : memref
dealloc %a : memref
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -53,8 +53,6 @@
let results = (outs TensorOf<[ComplexF64]>);
}
-def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">;
-
def TupleOp : TEST_Op<"tuple_32_bit"> {
let results = (outs TupleOf<[I32, F32]>);
}
diff --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
--- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
+++ b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
@@ -51,7 +51,8 @@
%B = view %bB[%c0][%c16] : memref to memref
%C = view %bC[%c0][] : memref to memref
- linalg.dot %A, %B, %C : (memref, memref, memref)
+ linalg.dot ins(%A, %B : memref, memref)
+ outs(%C : memref)
%res = load %C[] : memref
dealloc %bC : memref
@@ -83,7 +84,8 @@
%B = view %bB[%c0][%c16, %c2] : memref to memref
%C = view %bC[%c0][%c2, %c2] : memref to memref
- linalg.matmul %A, %B, %C : (memref, memref, memref)
+ linalg.matmul ins(%A, %B : memref, memref)
+ outs(%C : memref)
%res = load %C[%c0, %c1] : memref
dealloc %bC : memref
diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
--- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
+++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
@@ -1,9 +1,9 @@
// RUN: mlir-linalg-ods-gen %s -gen-ods-decl=1 | FileCheck %s --check-prefix=ODS
// RUN: mlir-linalg-ods-gen %s -gen-impl=1 | FileCheck %s --check-prefix=IMPL
-// ODS-LABEL: def Test1Op : LinalgNamedStructured_Op<"test1", [
-// ODS-NEXT: NInputs<2>
-// ODS-NEXT: NOutputs<1>
+// ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1", [
+// ODS-NEXT: NamedStructuredOpTraits
+// ODS-NEXT: AttrSizedOperandSegments
// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
//
// IMPL-LABEL: ArrayAttr Test1Op::iterator_types() {
@@ -25,9 +25,9 @@
C(m) = std_addf(std_mulf(A(m, k), B(k)));
}
-// ODS-LABEL: def Test2Op : LinalgNamedStructured_Op<"test2", [
-// ODS-NEXT: NInputs<2>
-// ODS-NEXT: NOutputs<1>
+// ODS-LABEL: def Test2Op : LinalgStructuredBase_Op<"test2", [
+// ODS-NEXT: NamedStructuredOpTraits
+// ODS-NEXT: AttrSizedOperandSegments
// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
//
// IMPL-LABEL: ArrayAttr Test2Op::iterator_types() {
@@ -49,9 +49,9 @@
C(m, n) = std_addf(std_mulf(A(m, k), B(k, n)));
}
-// ODS-LABEL: def Test3Op : LinalgNamedStructured_Op<"test3", [
-// ODS-NEXT: NInputs<2>
-// ODS-NEXT: NOutputs<1>
+// ODS-LABEL: def Test3Op : LinalgStructuredBase_Op<"test3", [
+// ODS-NEXT: NamedStructuredOpTraits
+// ODS-NEXT: AttrSizedOperandSegments
// ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp">
//
// IMPL-LABEL: ArrayAttr Test3Op::iterator_types() {
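The trait swap these checks pin down (`NInputs`/`NOutputs` replaced by `AttrSizedOperandSegments`) means the generated ops no longer hard-code operand counts; each instance instead records how its flat operand list splits into the `inputs`, `output_buffers`, and `init_tensors` groups. In the generic printed form this appears as the `operand_segment_sizes` attribute that the builders in the next file populate. A hedged sketch for a two-input, one-buffer matmul (types illustrative, region body elided):

```
// 2 inputs, 1 output buffer, 0 init tensors.
"linalg.matmul"(%A, %B, %C) ( {
  // ...region reconstructed by the named-op parser...
}) {operand_segment_sizes = dense<[2, 1, 0]> : vector<3xi32>}
  : (memref<4x8xf32>, memref<8x16xf32>, memref<4x16xf32>) -> ()
```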
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -980,7 +980,10 @@
/// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
void printODS(llvm::raw_ostream &os, StringRef cppOpName,
- StringRef linalgOpName);
+ StringRef linalgOpName, ComprehensionParsingState &state);
+
+ /// Print the C++ parser and printer for `cppOpName`.
+ void printParserAndPrinter(llvm::raw_ostream &os, StringRef cppOpName);
/// Print the C++ StructuredOpsInterface impl of `iterator_types`.
void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName,
@@ -1419,13 +1422,15 @@
return failure();
}
if (genODSDecl) {
- printODS(os, cppOpName, tcName);
+ auto &state = perComprehensionStates.back();
+ printODS(os, cppOpName, tcName, state);
os << "\n";
}
if (genODSImpl) {
auto &state = perComprehensionStates.back();
std::string extraMethods;
llvm::raw_string_ostream ss(extraMethods);
+ printParserAndPrinter(ss, cppOpName);
printReferenceIterators(ss, cppOpName, state);
printReferenceIndexingMaps(ss, cppOpName, state);
printRegionBuilder(ss, cppOpName, state);
@@ -1442,31 +1447,89 @@
/// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
- StringRef linalgOpName) {
- const char *header = R"FMT( def {0} : LinalgNamedStructured_Op<"{1}", [
- NInputs<{2}>,
- NOutputs<{3}>,
+ StringRef linalgOpName,
+ ComprehensionParsingState &state) {
+ const char *header = R"FMT( def {0} : LinalgStructuredBase_Op<"{1}", [
+ NamedStructuredOpTraits,
+ AttrSizedOperandSegments,
SingleBlockImplicitTerminator<"YieldOp">]> {
- let arguments = (ins Variadic:$views);
+ let arguments = (ins Variadic<AnyShapedType>:$inputs,
+ Variadic<AnyMemRef>:$output_buffers,
+ Variadic<AnyRankedTensor>:$init_tensors);
let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
- let regions = (region SizedRegion<1>:$region);
- let builders = [OpBuilder<
- "OpBuilder &b, OperationState &result, TypeRange outputTypes, "
- # "ValueRange views",
+ let regions = (region AnyRegion:$region);
+
+ // Format uses a custom return to parse an optional `:` type-list.
+ // Format uses a custom region to elide the programmatically constructed
+ // region.
+ let assemblyFormat = [{
+ attr-dict
+ `ins` `(` $inputs `:` type($inputs) `)`
+ (`outs` `(` $output_buffers^ `:` type($output_buffers) `)`)?
+ (`init` `(` $init_tensors^ `:` type($init_tensors) `)`)?
+ custom<NamedStructuredOpResults>(type($output_tensors))
+ custom<{0}NamedStructuredOpRegion>(
+ $region,
+ type_ref($inputs),
+ type_ref($output_buffers),
+ type_ref($init_tensors),
+ type_ref($output_tensors))
+ }];
+
+ let builders = [ OpBuilder<
+ "OpBuilder &b, OperationState &result,"
+ "ValueRange inputs, ValueRange outputBuffers",
[{{
- result.addOperands(views);
- result.addTypes(outputTypes);
+ result.addOperands(inputs);
+ result.addOperands(outputBuffers);
+ result.addAttribute(
+ "operand_segment_sizes",
+ b.getI32VectorAttr({{static_cast<int32_t>(inputs.size()),
+ static_cast<int32_t>(outputBuffers.size()),
+ static_cast<int32_t>(0)}));
buildNamedStructuredOpRegionAndAttributes<{0}>(
- b, result, TypeRange(views), outputTypes);
+ b,
+ result,
+ TypeRange(inputs),
+ TypeRange(outputBuffers),
+ TypeRange(),
+ TypeRange());
+ }]>, OpBuilder<
+ "OpBuilder &b, OperationState &result, TypeRange resultTensorTypes,"
+ "ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors",
+ [{{
+ result.addOperands(inputs);
+ result.addOperands(outputBuffers);
+ result.addOperands(initTensors);
+ result.addTypes(resultTensorTypes);
+ result.addAttribute(
+ "operand_segment_sizes",
+ b.getI32VectorAttr({{static_cast<int32_t>(inputs.size()),
+ static_cast<int32_t>(outputBuffers.size()),
+ static_cast<int32_t>(initTensors.size())}));
+ buildNamedStructuredOpRegionAndAttributes<{0}>(
+ b,
+ result,
+ TypeRange(inputs),
+ TypeRange(outputBuffers),
+ TypeRange(initTensors),
+ resultTensorTypes);
}]>
];
- let parser = [{
- return ::parseNamedStructuredOp<{0}>(parser, result);
- }];
+
+ let verifier = [{{ return ::verifyNamedStructuredOp(*this); }];
+ let hasFolder = 1;
+ let hasCanonicalizer = 1;
+
let extraClassDeclaration = [{{
+ // Auto-generated.
ArrayAttr iterator_types();
ArrayAttr indexing_maps();
static void regionBuilder(Block &block);
+
+ // Generic methods.
+ static unsigned getNumRegionArgs() {{ return {4}; }
std::string getLibraryCallName() {{
return generateLibraryCallName(getOperation());
}
@@ -1481,7 +1544,42 @@
nInputs++;
}
- os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs);
+ os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs,
+ state.orderedTensorArgs.size());
+}
+
+/// Print the C++ parser and printer for `cppOpName`.
+void TCParser::printParserAndPrinter(llvm::raw_ostream &os,
+ StringRef cppOpName) {
+ const char *parserAndPrinterFmt =
+ R"FMT(
+ static ParseResult parse{0}NamedStructuredOpRegion(
+ OpAsmParser &parser,
+ Region &region,
+ TypeRange inputOperands,
+ TypeRange outputBufferOperands,
+ TypeRange initTensorOperands,
+ TypeRange results) {{
+ return parseNamedStructuredOpRegion<{0}>(
+ parser,
+ region,
+ inputOperands,
+ outputBufferOperands,
+ initTensorOperands,
+ results);
+ }
+ static void print{0}NamedStructuredOpRegion(
+ OpAsmPrinter &printer,
+ Region &region,
+ TypeRange inputOperands,
+ TypeRange outputBufferOperands,
+ TypeRange initTensorOperands,
+ TypeRange results) {{
+ // noop
+ }
+ )FMT";
+
+ os << llvm::formatv(parserAndPrinterFmt, cppOpName);
}
/// Print the C++ StructuredOpsInterface impl of `iterator_types`.
@@ -1680,7 +1778,7 @@
}
// Include the proper Linalg header for end-to-end tblgen testing without
- // resorting to non-portable shgell manipulations.
+ // resorting to non-portable shell manipulations.
if (testEmitIncludeTdHeader)
output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\"";
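The `// noop` region printer above is deliberate: a named structured op owns a real region, filled in by `regionBuilder` from the TC spec (for matmul, `C(m, n) = std_addf(std_mulf(A(m, k), B(k, n)))`), but the custom directive elides it when printing and rebuilds it when parsing. A hedged sketch of what hides behind the compact form (types illustrative):

```
// Printed (compact) form: no region in sight.
linalg.matmul ins(%A, %B : memref<4x8xf32>, memref<8x16xf32>)
             outs(%C : memref<4x16xf32>)

// Region the parser reconstructs, as it would appear in generic form:
//   ^bb0(%a: f32, %b: f32, %c: f32):
//     %0 = mulf %a, %b : f32
//     %1 = addf %0, %c : f32
//     linalg.yield %1 : f32
```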