diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md --- a/mlir/docs/Dialects/Linalg.md +++ b/mlir/docs/Dialects/Linalg.md @@ -40,7 +40,8 @@ including lowering to scalar load/store and other operations or to external library calls and intrinsics. -These ops can have ***either tensor or buffer operands***. +These ops can have ***either tensor or buffer operands***, subject to +[conventions and limitations](#tensors_and_buffers). ### Payload-Carrying Ops Linalg defines two payload carrying operations that implement the [structured ops]( @@ -463,6 +464,76 @@ compilers. As we lay those down and engage more with the community, we expect multiple rounds of discussions and design changes to the original architecture. +### Tensors and Buffers: Conventions and Limitations + +Tensors are immutable SSA values; buffers are mutable regions of memory subject +to side effects and aliasing. As a consequence, output buffers are passed as +operands whereas output tensors are new SSA values corresponding to op results. +Inputs can be arbitrary tensors or buffers and are always passed as operands. + +The following convention is currently in flight and is in the process of +replacing other existing conventions. For now, it applies to "named" structured +ops, which are auto-generated by the linalg-ods tool. + +The convention adopted is as follows: + +1. A first block of `ins` op operands holds read-only inputs of ShapedType. +2. An optional second block of `outs` op operands holds read-write output + buffers of MemRefType. +3. An optional third block of `init` operands holds initialization tensors of + RankedTensorType. Such tensors can appear when the op performs a reduction + and returns a tensor. + +Structured ops with fully parallel semantics have an empty `init`. They may +either write in-place into `outs` buffers or return new tensors. 
+ +Structured ops with reduction semantics and output tensor(s), however, have +additional restrictions: + +1.
They can only return a single tensor for now. +2. They cannot have any output buffer operand (i.e. `outs` is empty). +3. They have exactly one `init` tensor of the same type as the unique output + tensor. Such an `init` tensor does not have an explicit associated indexing + map. Instead, the map of the result tensor is used to signify that the `init` + and the `result` are "tied". + +Points 1. and 2. keep the complexity of the representation in check by allowing +only a single result tensor when reductions are present. + +Point 3. is related to the fact that SSA values cannot represent in-place +updates. Instead, linalg adopts a convention similar to the one in, e.g., +`vector.outerproduct`: the value that is reduced into is passed as an explicit +argument and a new result of the same shape is produced. + +Buffer allocation is expected to fold this last operand (the `init` tensor) and +the result into a single output buffer argument, which is why they share the +same indexing map: the `init` tensor is said to be "tied" to the result. + +Alternative, more complex representations would allow for: + +1. Multiple results and `init` tensors in arbitrary orders, which could be + captured by an extra ArrayAttr of position pairs. +2. Relaxing the indexing map equality conditions on each pair, e.g. to + allow implicit broadcasts of the input. + +These representations are deemed unnecessarily complex for now and are left for +future discussion.
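The reduction rules above can be illustrated with a small model that does not depend on MLIR. The C++17 sketch below approximates the checks that `NamedStructuredOpTrait::verifyTrait` performs in this patch. `OpSignature`, `dimsOfType` and `verifyNamedOp` are hypothetical names, and types are plain strings rather than MLIR types. The sketch leaves out the check that output indexing maps do not depend on reduction dims. As written in this patch, the trait does not appear to enforce rule 2 (empty `outs`), so the sketch omits that rule too.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, MLIR-free stand-in for a named structured op.
// Types are spelled as plain strings (e.g. "tensor<4x8xf32>") for illustration.
struct OpSignature {
  std::vector<std::string> iteratorTypes; // one entry per loop dimension
  std::vector<std::string> initTypes;     // types of the `init` operands
  std::vector<std::string> resultTypes;   // types of the op results
};

// Return the loop dimensions whose iterator type equals `name`.
// Plays the role of `getDimsOfType`, but yields plain indices, not AffineExprs.
std::vector<std::size_t> dimsOfType(const OpSignature &op,
                                    const std::string &name) {
  std::vector<std::size_t> dims;
  for (std::size_t d = 0; d < op.iteratorTypes.size(); ++d)
    if (op.iteratorTypes[d] == name)
      dims.push_back(d);
  return dims;
}

// Return std::nullopt if `op` satisfies the `init`/result rules, otherwise a
// diagnostic worded like the ones emitted by `verifyTrait`.
std::optional<std::string> verifyNamedOp(const OpSignature &op) {
  std::vector<std::size_t> redDims = dimsOfType(op, "reduction");
  // No reduction or no result: `init` must be empty and nothing else applies.
  if (redDims.empty() || op.resultTypes.empty()) {
    if (!op.initTypes.empty())
      return std::string("expected empty `init` when op has no results or "
                         "no reduction dims");
    return std::nullopt;
  }
  // Rule 1: a single tensor result when a reduction is present.
  if (op.resultTypes.size() != 1)
    return std::string("expected single tensor result when reduction present");
  // Rule 3: exactly one `init` tensor per result, with the same type.
  if (op.initTypes.size() != op.resultTypes.size())
    return std::string(
        "expected #init tensors to match #results when reduction present");
  for (std::size_t i = 0; i < op.resultTypes.size(); ++i)
    if (op.initTypes[i] != op.resultTypes[i])
      return "expected init tensor #" + std::to_string(i) +
             " of the same type as result #" + std::to_string(i);
  return std::nullopt;
}
```

Consider a matmul-like signature with iterator types `parallel, parallel, reduction`, one `init` tensor and one result of the same type: it verifies. Dropping its `init` tensor instead yields the "expected #init tensors to match #results" diagnostic.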
+ +As an illustration, the syntax for a `linalg.matmul` writing into a buffer is: + +``` +linalg.matmul ins(%a, %b : memref, tensor) + outs(%c : memref) +``` + +In contrast, the syntax for a `linalg.matmul` returning a new tensor is: + +``` +%d = linalg.matmul ins(%a, %b : tensor, memref) + init(%c : tensor) + -> tensor +``` + ### Data Representation: Views The current implementation uses the [Strided MemRef (a.k.a View)]( https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/MaL8m2nXuio) @@ -570,10 +641,10 @@ produced: ``` - def batchmatmulOp : LinalgNamedStructured_Op<"batchmatmul", [ - NInputs<2>, - NOutputs<1>, - NamedStructuredOpTraits]> { ... } +def batchmatmulOp : LinalgNamedStructured_Op<"batchmatmul", [ + NInputs<2>, + NOutputs<1>, + NamedStructuredOpTrait]> { ... } ``` When `mlir-linalg-ods-gen -gen-impl=1` is called, the following C++ is produced: diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -699,6 +699,14 @@ - `input` must be either an operand or result [variable](#variables), the `operands` directive, or the `results` directive. +* `type_ref` ( input ) + + - Represents a reference to the type of the given input that must have + already been resolved. + - `input` must be either an operand or result [variable](#variables), the + `operands` directive, or the `results` directive. + - Used to pass previously parsed types to custom directives. + #### Literals A literal is either a keyword or punctuation surrounded by \`\`. @@ -762,6 +770,10 @@ - Single: `Type &` - Optional: `Type &` - Variadic: `SmallVectorImpl &` +* TypeRef Directives + - Single: `Type` + - Optional: `Type` + - Variadic: `const SmallVectorImpl &` When a variable is optional, the value should only be specified if the variable is present. Otherwise, the value should remain `None` or null.
@@ -788,6 +800,10 @@ - Single: `Type` - Optional: `Type` - Variadic: `TypeRange` +* TypeRef Directives + - Single: `Type` + - Optional: `Type` + - Variadic: `TypeRange` When a variable is optional, the provided value may be null. diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -85,6 +85,11 @@ SmallVector concat(ArrayRef a, ArrayRef b); +/// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`. +/// Assumes `op` is a LinalgOp. +void getDimsOfType(Operation *op, StringRef iteratorTypeName, + SmallVectorImpl &res); + } // namespace linalg } // namespace mlir @@ -96,5 +101,4 @@ #define GET_OP_CLASSES #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.h.inc" - #endif // MLIR_DIALECT_LINALG_LINALGOPS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -32,6 +32,7 @@ NativeOpTrait<"linalg::NOutputs<" # !cast(args_out) # ">::Impl"> {} def StructuredOpTraits : NativeOpTrait<"linalg::StructuredOpTraits">; +def NamedStructuredOpTrait : NativeOpTrait<"linalg::NamedStructuredOpTrait">; // Base Tablegen class for Linalg ops. // Linalg ops that correspond to library calls operate on linalg::View as their @@ -798,24 +799,7 @@ // Named Linalg ops, implemented as a declarative configurations of generic ops. //===----------------------------------------------------------------------===// -class LinalgNamedStructured_Op props> - : LinalgStructuredBase_Op { - string spec = ?; - // We cannot use an assemblyFormat atm because we need to hook in a custom- - // built implicit region from a static OpClass method. - // TODO: Revisit in the future if/when appropriate. 
- // let assemblyFormat = "`(` operands `)` attr-dict `:` " - // "functional-type(operands, results)"; - - // The parser needs to specialize on the OpType so it has to be auto-generated - // in the linalg-ods tool. - let printer = [{ return ::printNamedStructuredOp(p, *this); }]; - let verifier = [{ return ::verifyNamedStructuredOp(*this); }]; - let hasFolder = 1; - let hasCanonicalizer = 1; -} - -// This file is auto-generated from a tc specification. +// This file is auto-generated from a TC def specification. include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.td" #endif // LINALG_STRUCTURED_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td @@ -25,7 +25,7 @@ //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the number of parallel loops within the current operation. + Return the number of parallel loops. }], /*retTy=*/"unsigned", /*methodName=*/"getNumParallelLoops", @@ -38,7 +38,19 @@ >, InterfaceMethod< /*desc=*/[{ - Return the number of reduction loops within the current operation. + Return the dims that are parallel loops. + }], + /*retTy=*/"void", + /*methodName=*/"getParallelDims", + /*args=*/(ins "SmallVectorImpl &":$res), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return getDimsOfType($_op, getParallelIteratorTypeName(), res); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the number of reduction loops. }], /*retTy=*/"unsigned", /*methodName=*/"getNumReductionLoops", @@ -51,7 +63,19 @@ >, InterfaceMethod< /*desc=*/[{ - Return the number of window loops within the current operation. + Return the dims that are reduction loops. 
+ }], + /*retTy=*/"void", + /*methodName=*/"getReductionDims", + /*args=*/(ins "SmallVectorImpl &":$res), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return getDimsOfType($_op, getReductionIteratorTypeName(), res); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the number of window loops. }], /*retTy=*/"unsigned", /*methodName=*/"getNumWindowLoops", @@ -62,6 +86,18 @@ $_op.iterator_types()); }] >, + InterfaceMethod< + /*desc=*/[{ + Return the dims that are window loops. + }], + /*retTy=*/"void", + /*methodName=*/"getWindowDims", + /*args=*/(ins "SmallVectorImpl &":$res), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return getDimsOfType($_op.getOperation(), getWindowIteratorTypeName(), res); + }] + >, InterfaceMethod< /*desc=*/[{ Return the total number of loops within the current operation. @@ -99,14 +135,14 @@ // linalg.indexed_generic ops). InterfaceMethod< /*desc=*/[{ - Return the number of inputs from the current operation. + Return the number of inputs. }], /*retTy=*/"unsigned", /*methodName=*/"getNumInputs" >, InterfaceMethod< /*desc=*/[{ - Return the number of outputs from the current operation. + Return the number of outputs. }], /*retTy=*/"unsigned", /*methodName=*/"getNumOutputs" @@ -160,7 +196,7 @@ >, InterfaceMethod< /*desc=*/[{ - Return the input operands from the current operation. + Return the input operands. }], /*retTy=*/"Operation::operand_range", /*methodName=*/"getInputs", @@ -187,7 +223,6 @@ return res; }] >, - //===------------------------------------------------------------------===// // Output arguments handling. //===------------------------------------------------------------------===// @@ -267,7 +302,7 @@ }]>, InterfaceMethod< /*desc=*/[{ - Return the output buffers (operands) from the current operation. + Return the output buffers (operands). 
}], /*retTy=*/"Operation::operand_range", /*methodName=*/"getOutputBuffers", @@ -354,7 +389,9 @@ return getInputShapedType(i); if (i < getNumInputsAndOutputBuffers()) return getOutputBufferType(i - $_op.getNumInputs()); - return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()]; + return this->getOperation()->getResult( + i - getNumInputsAndOutputBuffers()). + getType().template cast(); }]>, InterfaceMethod< /*desc=*/[{ @@ -408,11 +445,7 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return llvm::to_vector<4>( - llvm::map_range($_op.indexing_maps(), - [](Attribute attr) -> AffineMap { - return attr.cast().getValue(); - })); + return llvm::to_vector<4>($_op.indexing_maps().template getAsValueRange()); }] >, InterfaceMethod< @@ -425,10 +458,7 @@ /*methodBody=*/"", /*defaultImplementation=*/[{ assert(i < getNumInputsAndOutputs()); - return $_op.indexing_maps() - .getValue()[i] - .template cast() - .getValue(); + return getIndexingMaps()[i]; }] >, InterfaceMethod< @@ -441,10 +471,7 @@ /*methodBody=*/"", /*defaultImplementation=*/[{ assert(i < $_op.getNumInputs()); - return $_op.indexing_maps() - .getValue()[i] - .template cast() - .getValue(); + return getIndexingMaps()[i]; }] >, InterfaceMethod< @@ -457,10 +484,7 @@ /*methodBody=*/"", /*defaultImplementation=*/[{ assert(i < $_op.getNumOutputs()); - return $_op.indexing_maps() - .getValue()[i + $_op.getNumInputs()] - .template cast() - .getValue(); + return getIndexingMaps()[i + $_op.getNumInputs()]; }] >, InterfaceMethod< diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -71,6 +71,80 @@ } }; +/// This class provides a verifier for structured ops that are known to operate +/// on buffers or tensors and that support `ins`, `outs` and `init` arguments. 
+/// This trait must be used in conjunction with an op definition or a trait that +/// provides the methods `getNumInputs` and `getNumOutputs`. +/// +/// Use as a trait as follows: +/// +/// class MatmulOp : public Op { +/// +template +class NamedStructuredOpTrait + : public OpTrait::TraitBase { +public: + unsigned getNumInputs() { + return cast(this->getOperation()).inputs().size(); + } + unsigned getNumOutputs() { + ConcreteType concreteOp = cast(this->getOperation()); + return concreteOp.output_buffers().size() + + concreteOp.output_tensors().size(); + } + static LogicalResult verifyTrait(Operation *op) { + ConcreteType concreteOp = cast(op); + unsigned nInputAndBufferOperands = + concreteOp.getNumInputsAndOutputBuffers(); + if (failed( + OpTrait::impl::verifyAtLeastNOperands(op, nInputAndBufferOperands))) + return failure(); + + SmallVector redDims; + concreteOp.getReductionDims(redDims); + // If there is no result or no reduction, only check that there is no init + // tensor and we are done. + if (redDims.empty() || op->getNumResults() == 0) { + if (!concreteOp.init_tensors().empty()) + return op->emitError("expected empty `init` when op has no " + "results or no reduction dims"); + return success(); + } + + // Only a single tensor result supported atm. + if (op->getNumResults() != 1) + return op->emitError( + "expected single tensor result when reduction present"); + + if (concreteOp.init_tensors().size() != op->getNumResults()) + return op->emitError( + "expected #init tensors to match #results when reduction present"); + + for (unsigned idx = 0, e = op->getNumResults(); idx < e; ++idx) + if (concreteOp.init_tensors()[idx].getType() != op->getResultTypes()[idx]) + return op->emitError("expected init tensor #") + << idx << " of the same type as result #" << idx; + + // Output tensor indexing map may not depend on reduction index. + // TODO: this is not yet tested. Add a test when linalg.generic switches to + // this representation.
+ for (unsigned idx = 0, e = concreteOp.getNumOutputs(); idx < e; ++idx) { + AffineMap outputMap = concreteOp.getOutputIndexingMap(idx); + for (auto expr : outputMap.getResults()) { + for (auto dim : redDims) { + unsigned pos = dim.cast().getPosition(); + if (expr.isFunctionOfDim(pos)) + return op->emitError( + "unexpected single tensor output indexing map ") + << "is function of reduction dim @" << pos; + } + } + } + + return success(); + } +}; + } // namespace linalg } // namespace OpTrait } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td @@ -15,8 +15,6 @@ include "mlir/IR/OpBase.td" -def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">; - //===----------------------------------------------------------------------===// // Shape Inference dialect definitions //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -571,6 +571,10 @@ def AnyVector : VectorOf<[AnyType]>; +// Shaped types. + +def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">; + // Tensor types. 
// Any tensor type whose element type is from the given `allowedTypes` list diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir @@ -30,7 +30,8 @@ } func @conv_1d(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.conv_1d %arg0, %arg1, %arg2 : (memref, memref, memref) + linalg.conv_1d ins (%arg0, %arg1: memref, memref) + outs (%arg2: memref) return } diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir @@ -30,7 +30,8 @@ } func @conv_1d_ncw(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.conv_1d_ncw %arg0, %arg1, %arg2 : (memref, memref, memref) + linalg.conv_1d_ncw ins (%arg0, %arg1: memref, memref) + outs (%arg2: memref) return } diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir @@ -30,7 +30,8 @@ } func @conv_1d_nwc(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.conv_1d_nwc %arg0, %arg1, %arg2 : (memref, memref, memref) + linalg.conv_1d_nwc ins (%arg0, %arg1: memref, memref) + outs (%arg2: memref) return } diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir @@ -30,7 +30,8 @@ } func @conv_2d(%arg0: memref, 
%arg1: memref, %arg2: memref) { - linalg.conv_2d %arg0, %arg1, %arg2 : (memref, memref, memref) + linalg.conv_2d ins (%arg0, %arg1: memref, memref) + outs (%arg2: memref) return } diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir @@ -30,7 +30,8 @@ } func @conv_2d_nchw(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.conv_2d_nchw %arg0, %arg1, %arg2 : (memref, memref, memref) + linalg.conv_2d_nchw ins (%arg0, %arg1: memref, memref) + outs (%arg2: memref) return } diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir @@ -30,7 +30,8 @@ } func @conv_2d_nhwc(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.conv_2d_nhwc %arg0, %arg1, %arg2 : (memref, memref, memref) + linalg.conv_2d_nhwc ins (%arg0, %arg1: memref, memref) + outs (%arg2: memref) return } diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir @@ -30,7 +30,8 @@ } func @conv_3d(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.conv_3d %arg0, %arg1, %arg2 : (memref, memref, memref) + linalg.conv_3d ins (%arg0, %arg1: memref, memref) + outs (%arg2: memref) return } diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir --- 
a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir @@ -30,7 +30,8 @@ } func @conv_3d_ncdhw(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.conv_3d_ncdhw %arg0, %arg1, %arg2 : (memref, memref, memref) + linalg.conv_3d_ncdhw ins (%arg0, %arg1: memref, memref) + outs (%arg2: memref) return } diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir @@ -30,7 +30,8 @@ } func @conv_3d_ndhwc(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.conv_3d_ndhwc %arg0, %arg1, %arg2 : (memref, memref, memref) + linalg.conv_3d_ndhwc ins (%arg0, %arg1: memref, memref) + outs (%arg2: memref) return } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -26,6 +26,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/StringSet.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" @@ -35,13 +36,29 @@ /// Forward declarations. 
template static void buildNamedStructuredOpRegionAndAttributes( - Builder &builder, OperationState &result, TypeRange operandTypes, - TypeRange tensorResultTypes); + OpBuilder &opBuilder, OperationState &result, TypeRange inputTypes, + TypeRange outputBufferTypes, TypeRange initTensorTypes, + TypeRange resultTypes); + template -static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op); +static ParseResult +parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, + TypeRange inputTypes, TypeRange outputBufferTypes, + TypeRange initTensorTypes, TypeRange resultTypes); +static ParseResult +parseNamedStructuredOpResults(OpAsmParser &parser, + SmallVectorImpl &resultTypes); + template static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result); + +static void printNamedStructuredOpResults(OpAsmPrinter &p, + TypeRange resultTypes); + +template +static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op); + template static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op); @@ -248,11 +265,6 @@ static LogicalResult verifyGenericOp(GenericOpType op) { auto nInputViews = op.getNumInputs(); auto nLoops = op.getNumLoops(); - auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers(); - if (nInputsAndOutputBuffers != llvm::size(op.views())) - return op.emitOpError("expected exactly ") - << nInputsAndOutputBuffers - << " inputs (tensor or buffer) and output buffer operands"; auto &region = op.region(); if (!llvm::hasSingleElement(region)) @@ -302,8 +314,27 @@ return success(); } -static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); } -static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); } +static LogicalResult verify(GenericOp op) { + // Temporarily hoisted here to avoid duplicating more code. + // TODO: uniformize with named structured ops.
+ auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers(); + if (nInputsAndOutputBuffers != llvm::size(op.views())) + return op.emitOpError("expected exactly ") + << nInputsAndOutputBuffers + << " inputs (tensor or buffer) and output buffer operands"; + return verifyGenericOp(op); +} + +static LogicalResult verify(IndexedGenericOp op) { + // Temporarily hoisted here to avoid duplicating more code. + // TODO: uniformize with named structured ops. + auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers(); + if (nInputsAndOutputBuffers != llvm::size(op.views())) + return op.emitOpError("expected exactly ") + << nInputsAndOutputBuffers + << " inputs (tensor or buffer) and output buffer operands"; + return verifyGenericOp(op); +} //===----------------------------------------------------------------------===// // ReshapeOp @@ -1098,12 +1129,28 @@ #include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.cpp.inc" +#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc" + #define GET_OP_CLASSES #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" #define GET_OP_CLASSES #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" +/// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`. +/// Assumes `op` is a LinalgOp. +void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName, + SmallVectorImpl &res) { + unsigned dim = 0; + MLIRContext *ctx = op->getContext(); + for (auto tn : + cast(op).iterator_types().getAsValueRange()) { + if (tn == iteratorTypeName) + res.push_back(getAffineDimExpr(dim, ctx)); + ++dim; + } +} + AffineMap mlir::linalg::extractOrIdentityMap(Optional maybeMap, unsigned rank, MLIRContext *context) { @@ -1196,8 +1243,8 @@ } // TODO: Consider making all this boilerplate easy to autogenerate -// with Tablegen. This seems a desirable property in the context of OpInterfaces -// where a Linalg "named" op **isa** LinalgOp. +// with Tablegen. 
This seems a desirable property in the context of +// OpInterfaces where a Linalg "named" op **isa** LinalgOp. OpFoldResult ReshapeOp::fold(ArrayRef operands) { if (succeeded(foldMemRefCast(*this))) return getResult(); @@ -1222,23 +1269,28 @@ //===----------------------------------------------------------------------===// template -void buildNamedStructuredOpRegionAndAttributes(Builder &builder, - OperationState &result, - TypeRange operandTypes, - TypeRange tensorResultTypes) { - Region &region = *result.addRegion(); - Block *body = new Block(); +static void buildNamedStructuredOpRegionAndAttributesImpl( + OpBuilder &opBuilder, Region &region, TypeRange inputTypes, + TypeRange outputBufferTypes, TypeRange initTensorTypes, + TypeRange resultTypes, + std::function errorHandler) { // TODO: atm all operands go through getElementTypeOrSelf, // reconsider when we have evidence we need to. - for (auto t : operandTypes) - body->addArgument(getElementTypeOrSelf(t)); - for (auto t : tensorResultTypes) - body->addArgument(getElementTypeOrSelf(t)); - region.push_back(body); - - OpBuilder opBuilder(builder.getContext()); - opBuilder.setInsertionPointToStart(&region.front()); - mlir::edsc::ScopedContext scope(opBuilder, builder.getUnknownLoc()); + SmallVector argTypes; + for (auto containers : {inputTypes, outputBufferTypes, resultTypes}) + for (auto t : containers) + argTypes.push_back(getElementTypeOrSelf(t)); + + // RAII guard: restores the builder's insertion point on scope exit. 
+ OpBuilder::InsertionGuard guard(opBuilder); + Block *body = opBuilder.createBlock(&region, {}, argTypes); + unsigned actual = body->getNumArguments(); + unsigned expected = NamedStructuredOpType::getNumRegionArgs(); + if (expected != actual) + return errorHandler(expected, actual); + + opBuilder.setInsertionPointToStart(body); + mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc()); NamedStructuredOpType::regionBuilder(*body); // indexing_maps is an auto-generated method.
@@ -1247,59 +1299,133 @@ } template -static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { - std::array silentAttrNames{getIndexingMapsAttrName(), - getIteratorTypesAttrName()}; - p << op.getOperationName() << ' '; - p.printOptionalAttrDict(op.getAttrs(), silentAttrNames); - p << ' ' << op.getOperands(); - p << " : (" << op.getOperandTypes() << ")"; - auto outputTensorTypes = op.getResultTypes(); - if (!outputTensorTypes.empty()) - p << " -> (" << outputTensorTypes << ")"; +void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder, + OperationState &result, + TypeRange inputTypes, + TypeRange outputBufferTypes, + TypeRange initTensorTypes, + TypeRange resultTypes) { + Region &region = *result.addRegion(); + buildNamedStructuredOpRegionAndAttributesImpl( + opBuilder, region, inputTypes, outputBufferTypes, initTensorTypes, + resultTypes, [&](unsigned expected, unsigned actual) { + llvm::errs() << "region expects " << expected << " args, got " + << actual; + assert(expected == actual && "incorrect number of arguments"); + }); +} + +template +static ParseResult +parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, + TypeRange inputTypes, TypeRange outputBufferTypes, + TypeRange initTensorTypes, TypeRange resultTypes) { + ParseResult res = success(); + OpBuilder opBuilder(parser.getBuilder().getContext()); + buildNamedStructuredOpRegionAndAttributesImpl( + opBuilder, region, inputTypes, outputBufferTypes, initTensorTypes, + resultTypes, [&](unsigned expected, unsigned actual) { + res = parser.emitError(parser.getCurrentLocation(), + llvm::formatv("region expects {0} args, got {1}", + expected, actual)); + }); + return res; +} + +static ParseResult +parseNamedStructuredOpResults(OpAsmParser &parser, + SmallVectorImpl &resultTypes) { + if (succeeded(parser.parseOptionalArrow())) + if (parser.parseTypeList(resultTypes)) + return failure(); + return success(); } template static ParseResult parseNamedStructuredOp(OpAsmParser
&parser, OperationState &result) { - SmallVector operandsInfo; - result.getContext()->getOrLoadDialect(); + llvm::SMLoc inputsOperandsLoc, outputBuffersOperandsLoc, + initTensorsOperandsLoc; + SmallVector inputsOperands, + outputBuffersOperands, initTensorsOperands; + SmallVector inputsTypes, outputBuffersTypes, initTensorsTypes, + outputTensorsTypes; + std::unique_ptr regionRegion = std::make_unique(); - // Optional attributes may be added. - if (parser.parseOperandList(operandsInfo) || - parser.parseOptionalAttrDict(result.attributes)) + if (parser.parseOptionalAttrDict(result.attributes) || + parser.parseKeyword("ins") || parser.parseLParen()) return failure(); - SmallVector operandTypes; - if (parser.parseColon() || parser.parseLParen() || - parser.parseTypeList(operandTypes) || parser.parseRParen()) + inputsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(inputsOperands) || parser.parseColon() || + parser.parseTypeList(inputsTypes) || parser.parseRParen()) return failure(); - // Generic ops may specify that a subset of its outputs are tensors. Such - // outputs are specified in the result type. 
- SmallVector tensorResultTypes; - if (parser.parseOptionalArrowTypeList(tensorResultTypes)) + if (succeeded(parser.parseOptionalKeyword("outs"))) { + outputBuffersOperandsLoc = parser.getCurrentLocation(); + if (parser.parseLParen() || + parser.parseOperandList(outputBuffersOperands) || parser.parseColon() || + parser.parseTypeList(outputBuffersTypes) || parser.parseRParen()) + return failure(); + } + if (succeeded(parser.parseOptionalKeyword("init"))) { + initTensorsOperandsLoc = parser.getCurrentLocation(); + if (parser.parseLParen() || parser.parseOperandList(initTensorsOperands) || + parser.parseColon() || parser.parseTypeList(initTensorsTypes) || + parser.parseRParen()) + return failure(); + } + + if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) return failure(); - if (!tensorResultTypes.empty()) - result.addTypes(tensorResultTypes); + if (parseNamedStructuredOpRegion( + parser, *regionRegion, inputsTypes, outputBuffersTypes, + initTensorsTypes, outputTensorsTypes)) + return failure(); - // The number of parsed arguments must equal - // the number of expected arguments for the current operation. 
- auto parsedArgs = operandsInfo.size(); - auto expectedArgs = NamedStructuredOpType::getNumInputs() + - NamedStructuredOpType::getNumOutputs(); - if (parsedArgs != expectedArgs) - return parser.emitError(parser.getNameLoc(), - "expects " + std::to_string(expectedArgs) + - " operands, but found " + - std::to_string(parsedArgs)); + if (parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc, + result.operands) || + parser.resolveOperands(outputBuffersOperands, outputBuffersTypes, + outputBuffersOperandsLoc, result.operands) || + parser.resolveOperands(initTensorsOperands, initTensorsTypes, + initTensorsOperandsLoc, result.operands)) + return failure(); - buildNamedStructuredOpRegionAndAttributes( - parser.getBuilder(), result, operandTypes, tensorResultTypes); + result.addTypes(outputTensorsTypes); + result.addRegion(std::move(regionRegion)); + result.addAttribute("operand_segment_sizes", + parser.getBuilder().getI32VectorAttr( + {static_cast(inputsOperands.size()), + static_cast(outputBuffersOperands.size()), + static_cast(initTensorsOperands.size())})); + return success(); +} - return parser.resolveOperands(operandsInfo, operandTypes, - parser.getCurrentLocation(), result.operands); +static void printNamedStructuredOpResults(OpAsmPrinter &p, + TypeRange resultTypes) { + if (resultTypes.empty()) + return; + p << "-> " << resultTypes; +} + +template +static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { + p << op.getOperationName(); + p.printOptionalAttrDict(op.getAttrs(), + /*elidedAttrs=*/{"operand_segment_sizes"}); + p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")"; + if (!op.output_buffers().empty()) + p << " outs(" << op.output_buffers() << " : " + << op.output_buffers().getTypes() << ")"; + if (!op.init_tensors().empty()) + p << " init(" << op.init_tensors() << " : " << op.init_tensors().getTypes() + << ")"; + p << " "; + printNamedStructuredOpResults(p, op.output_tensors().getTypes()); + p << " "; + 
+  // Region is elided.
 }
 
 template <typename NamedStructuredOpType>
@@ -1354,8 +1480,6 @@
 CANONICALIZERS_AND_FOLDERS(GenericOp)
 CANONICALIZERS_AND_FOLDERS(IndexedGenericOp)
 
-#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
-
 // TODO: Determine whether we can generate the folders and verifiers.
 CANONICALIZERS_AND_FOLDERS(BatchMatmulOp)
 CANONICALIZERS_AND_FOLDERS(DotOp)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -58,6 +58,8 @@
 //===----------------------------------------------------------------------===//
 
 void mlir::linalg::LinalgDialect::initialize() {
+  getContext()->getOrLoadDialect("std");
+
   addTypes<RangeType>();
   addOperations<
 #define GET_OP_LIST
@@ -67,6 +69,7 @@
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
       >();
+
   addInterfaces<LinalgInlinerInterface>();
 }
 
diff --git a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
--- a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
+++ b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir
@@ -8,7 +8,8 @@
 // CHECK-DAG: #[[$map5:.*]] = affine_map<(d0) -> (d0)>
 
 func @conv_1d(%arg0: memref, %arg1: memref, %arg2: memref) {
-  linalg.conv_1d %arg0, %arg1, %arg2 : (memref, memref, memref)
+  linalg.conv_1d ins(%arg0, %arg1 : memref, memref)
+                 outs(%arg2 : memref)
   return
 }
 
diff --git a/mlir/test/Dialect/Linalg/affine.mlir b/mlir/test/Dialect/Linalg/affine.mlir
--- a/mlir/test/Dialect/Linalg/affine.mlir
+++ b/mlir/test/Dialect/Linalg/affine.mlir
@@ -15,7 +15,8 @@
   %A = view %arg0[%c0][%M, %K] : memref to memref
   %B = view %arg0[%c0][%K, %N] : memref to memref
   %C = view %arg0[%c0][%M, %N] : memref to memref
-  linalg.matmul %A, %B, %C : (memref, memref, memref)
+  linalg.matmul ins(%A, %B: memref, memref)
+               outs(%C: memref)
   return
 }
 
@@ -102,7 +103,8 @@
 // Named ops to loops.
//----------------------------------------------------------------------------// func @named_batch_matmul(%A: memref, %B: memref, %C: memref) { - linalg.batch_matmul %A, %B, %C : (memref, memref, memref) -> () + linalg.batch_matmul ins(%A, %B: memref, memref) + outs(%C : memref) return } // CHECK-LABEL: @named_batch_matmul diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -14,8 +14,9 @@ // CHECK: linalg.slice {{.*}} : memref<16x16xf32>, !linalg.range, !linalg.range, memref %4 = linalg.slice %3[%r0, %r0] : memref, !linalg.range, !linalg.range, memref - // CHECK: linalg.matmul{{.*}}: (memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>) - linalg.matmul %3, %3, %3 : (memref, memref, memref) + // CHECK: linalg.matmul ins({{.*}}memref<16x16xf32>, memref<16x16xf32>) outs({{.*}}memref<16x16xf32>) + linalg.matmul ins(%3, %3: memref, memref) + outs(%3: memref) return %4: memref } diff --git a/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir b/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir --- a/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir +++ b/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir @@ -1,5 +1,4 @@ -// RUN: mlir-opt %s -test-linalg-transform-patterns=test-affine-min-scf-canonicalization-patterns -//| FileCheck %s +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-affine-min-scf-canonicalization-patterns | FileCheck %s // CHECK-LABEL: scf_for func @scf_for(%A : memref, %step : index) { diff --git a/mlir/test/Dialect/Linalg/fusion-2-level.mlir b/mlir/test/Dialect/Linalg/fusion-2-level.mlir --- a/mlir/test/Dialect/Linalg/fusion-2-level.mlir +++ b/mlir/test/Dialect/Linalg/fusion-2-level.mlir @@ -12,7 +12,8 @@ %0 = dim %C, %c0 : memref %1 = dim %C, %c1 : memref %2 = dim %D, %c1 : memref - linalg.matmul %A, %B, %C : (memref, memref, memref) + linalg.matmul ins(%A, %B: memref, memref) + outs(%C: 
memref) scf.for %arg5 = %c0 to %0 step %c20 { scf.for %arg6 = %c0 to %2 step %c30 { scf.for %arg7 = %c0 to %1 step %c40 { @@ -28,7 +29,8 @@ %14 = std.subview %5[%arg8, %arg10][%c2, %c4][%c1, %c1] : memref to memref %16 = std.subview %7[%arg10, %arg9][%c4, %c3][%c1, %c1]: memref to memref %17 = std.subview %8[%arg8, %arg9][%c2, %c4][%c1, %c1] : memref to memref - linalg.matmul %14, %16, %17 : (memref, memref, memref) + linalg.matmul ins(%14, %16: memref, memref) + outs(%17: memref) } } } diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir --- a/mlir/test/Dialect/Linalg/fusion.mlir +++ b/mlir/test/Dialect/Linalg/fusion.mlir @@ -14,10 +14,9 @@ %0 = dim %A, %c0 : memref %1 = dim %A, %c1 : memref %2 = dim %B, %c1 : memref - linalg.matmul %A, %B, %C : - (memref, - memref, - memref) + linalg.matmul ins(%A, %B : memref, + memref) + outs(%C : memref) scf.for %arg5 = %c0 to %0 step %c2 { scf.for %arg6 = %c0 to %2 step %c3 { scf.for %arg7 = %c0 to %1 step %c4 { @@ -30,10 +29,9 @@ %8 = std.subview %C[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul %5, %7, %8 : - (memref, - memref, - memref) + linalg.matmul ins(%5, %7 : memref, + memref) + outs(%8: memref) } } } @@ -61,10 +59,9 @@ %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index - linalg.matmul %A, %B, %C : - (memref, - memref, - memref) + linalg.matmul ins(%A, %B : memref, + memref) + outs(%C: memref) %0 = dim %C, %c0 : memref %1 = dim %C, %c1 : memref %2 = dim %D, %c1 : memref @@ -80,10 +77,9 @@ %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul %5, %7, %8 : - (memref, - memref, - memref) + linalg.matmul ins(%5, %7 : memref, + memref) + outs(%8 : memref) } } } @@ -113,10 +109,9 @@ %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index - linalg.matmul %A, %B, %C : - (memref, - memref, - memref) + linalg.matmul ins(%A, %B : memref, + memref) + outs(%C : memref) %0 = dim %D, %c0 : memref 
%1 = dim %D, %c1 : memref %2 = dim %C, %c1 : memref @@ -132,10 +127,9 @@ %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul %5, %7, %8 : - (memref, - memref, - memref) + linalg.matmul ins(%5, %7 : memref, + memref) + outs(%8 : memref) } } } @@ -165,14 +159,12 @@ %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index - linalg.matmul %A, %B, %C : - (memref, - memref, - memref) - linalg.matmul %A, %B, %D : - (memref, - memref, - memref) + linalg.matmul ins(%A, %B : memref, + memref) + outs(%C : memref) + linalg.matmul ins(%A, %B : memref, + memref) + outs(%D : memref) %0 = dim %C, %c0 : memref %1 = dim %C, %c1 : memref %2 = dim %D, %c1 : memref @@ -188,10 +180,9 @@ %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul %5, %7, %8 : - (memref, - memref, - memref) + linalg.matmul ins(%5, %7 : memref, + memref) + outs(%8 : memref) } } } @@ -227,14 +218,12 @@ %0 = dim %B, %c1 : memref %1 = dim %D, %c0 : memref %2 = dim %D, %c1 : memref - linalg.matmul %A, %B, %C : - (memref, - memref, - memref) - linalg.matmul %C, %B, %D : - (memref, - memref, - memref) + linalg.matmul ins(%A, %B : memref, + memref) + outs(%C : memref) + linalg.matmul ins(%C, %B : memref, + memref) + outs(%D : memref) scf.for %arg5 = %c0 to %1 step %c2 { scf.for %arg6 = %c0 to %0 step %c3 { scf.for %arg7 = %c0 to %2 step %c4 { @@ -247,10 +236,9 @@ %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul %5, %7, %8 : - (memref, - memref, - memref) + linalg.matmul ins(%5, %7 : memref, + memref) + outs(%8 : memref) } } } @@ -275,9 +263,9 @@ // CHECK-DAG: %[[A_I0:.*]] = subview %[[A]][%[[I]], %{{.*}}] // CHECK-DAG: %[[B_00:.*]] = subview %[[B]][%{{.*}}, %{{.*}}] // CHECK-DAG: %[[C_I0_:.*]] = subview %[[C]][%[[I]], %{{.*}}] -// CHECK: linalg.matmul %[[A_I0]], %[[B_00]], %[[C_I0_]] -// CHECK: linalg.matmul %[[C_I0]], %[[B_0K]], %[[D_IK_]] -// CHECK: linalg.matmul %[[D_IK]], %[[B_KJ]], 
%[[E_IJ]] +// CHECK: linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0_]] +// CHECK: linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK_]] +// CHECK: linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]] // ----- @@ -297,14 +285,12 @@ %c3 = constant 3 : index %c2 = constant 2 : index %0 = dim %C, %c1 : memref - linalg.matmul %A, %B, %C : - (memref, - memref, - memref) - linalg.matmul %A, %C, %E : - (memref, - memref, - memref) + linalg.matmul ins(%A, %B : memref, + memref) + outs(%C : memref) + linalg.matmul ins(%A, %C : memref, + memref) + outs(%E : memref) %1 = dim %C, %c0 : memref %2 = dim %D, %c1 : memref scf.for %arg5 = %c0 to %1 step %c2 { @@ -322,10 +308,9 @@ %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul %5, %7, %8 : - (memref, - memref, - memref) + linalg.matmul ins(%5, %7 : memref, + memref) + outs(%8 : memref) } } } @@ -359,14 +344,12 @@ %2 = dim %C, %c1 : memref %3 = dim %C, %c0 : memref %4 = dim %D, %c1 : memref - linalg.matmul %A, %C, %E : - (memref, - memref, - memref) - linalg.matmul %A, %B, %C : - (memref, - memref, - memref) + linalg.matmul ins(%A, %C : memref, + memref) + outs(%E : memref) + linalg.matmul ins(%A, %B : memref, + memref) + outs(%C : memref) scf.for %arg5 = %c0 to %0 step %c2 { scf.for %arg6 = %c0 to %2 step %c3 { scf.for %arg7 = %c0 to %1 step %c4 { @@ -379,10 +362,9 @@ %10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul %7, %9, %10 : - (memref, - memref, - memref) + linalg.matmul ins(%7, %9 : memref, + memref) + outs(%10 : memref) } } } @@ -398,10 +380,9 @@ %10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul %7, %9, %10 : - (memref, - memref, - memref) + linalg.matmul ins(%7, %9 : memref, + memref) + outs(%10 : memref) } } } @@ -414,7 +395,7 @@ // CHECK: %[[C_1:.*]] = dim %[[C]], %c1{{_[0-9]*}} : memref // CHECK: %[[C_0:.*]] = dim %[[C]], %c0{{_[0-9]*}} : memref // CHECK: %[[D_1:.*]] = dim 
%[[D]], %c1{{_[0-9]*}} : memref -// CHECK: linalg.matmul %[[A]], %[[C]], %[[E]] +// CHECK: linalg.matmul ins(%[[A]], %[[C]]{{.*}} outs(%[[E]] // CHECK: scf.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} { // CHECK: scf.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { // CHECK: scf.for %{{.*}} = %{{.*}} to %[[A_1]] step %{{.*}} { @@ -445,14 +426,12 @@ %c2 = constant 2 : index %0 = dim %A, %c0 : memref %1 = dim %A, %c1 : memref - linalg.matmul %A, %C, %D : - (memref, - memref, - memref) - linalg.matmul %A, %B, %C : - (memref, - memref, - memref) + linalg.matmul ins(%A, %C : memref, + memref) + outs(%D : memref) + linalg.matmul ins(%A, %B : memref, + memref) + outs(%C : memref) %2 = dim %D, %c1 : memref scf.for %arg5 = %c0 to %0 step %c2 { scf.for %arg6 = %c0 to %2 step %c3 { @@ -469,10 +448,9 @@ %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul %5, %7, %8 : - (memref, - memref, - memref) + linalg.matmul ins(%5, %7 : memref, + memref) + outs(%8 : memref) } } } @@ -742,10 +720,9 @@ %B = alloca(%dim, %dim)[%s0, %s1] : memref %C = alloc(%dim, %dim)[%s0, %s1] : memref - linalg.matmul %A, %B, %C : - (memref, - memref, - memref) + linalg.matmul ins(%A, %B : memref, + memref) + outs(%C : memref) scf.for %i = %c0 to %dim step %c2 { scf.for %j = %c0 to %dim step %c3 { @@ -759,10 +736,9 @@ %2 = std.subview %C[%i, %j][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul %0, %1, %2 : - (memref, - memref, - memref) + linalg.matmul ins(%0, %1 : memref, + memref) + outs(%2 : memref) } } } diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -428,13 +428,6 @@ // ----- -func @generic_result_0_element_type(%arg0: memref) { - // expected-error @+1 {{'linalg.dot' expects 3 operands, but found 2}} - linalg.dot %arg0, %arg0 : (memref, memref) -} - -// ----- - func @conv_rank_limit(%arg0: memref, %arg1: memref, %arg2: 
memref) { // expected-error @+1 {{expects memref ranks to be greater than 2}} linalg.conv(%arg0, %arg1, %arg2) : memref, memref, memref @@ -511,7 +504,8 @@ func @named_ops(%a3: memref, %b3: memref, %c3: memref) { // expected-error @+1 {{op expected indexing_map #1 results to match view rank: 'memref'}} - linalg.batch_matmul %a3, %b3, %c3 : (memref, memref, memref) -> () + linalg.batch_matmul ins(%a3, %b3: memref, memref) + outs(%c3 : memref) return } @@ -531,3 +525,52 @@ } : tensor -> (tensor, tensor) return } + +// ----- + +func @empty_init_expected(%m: memref, %t: tensor) { + // expected-error @+1 {{expected empty `init` when op has no results or no reduction dims}} + linalg.matmul ins(%m, %m: memref, memref) + outs(%m : memref) + init(%t : tensor) + return +} + +// ----- + +func @incorrect_region_arg_count(%m: memref) { + // expected-error @+3 {{region expects 3 args, got 4}} + %res = linalg.matmul ins(%m, %m : memref, memref) + -> tensor, tensor + return +} + +// ----- + +func @single_tensor_result(%m: memref, %t: tensor) { + // expected-error @+1 {{expected single tensor result when reduction present}} + %res:2 = linalg.matmul ins(%m : memref) + init(%t, %t : tensor, tensor) + -> tensor, tensor + return +} + +// ----- + +func @matching_inits(%m: memref, %t: tensor) { + // expected-error @+1 {{expected #init tensors to match #results when reduction present}} + %res = linalg.matmul ins(%m, %m : memref, memref) + init(%t, %t : tensor, tensor) + -> tensor + return +} + +// ----- + +func @matching_inits(%m: memref, %t: tensor) { + // expected-error @+1 {{expected init tensor #0 of the same type as result #0}} + %res = linalg.matmul ins(%m, %m : memref, memref) + init(%t : tensor) + -> tensor + return +} diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -39,7 +39,8 @@ %A = view %arg0[%c0][%M, %K] : memref to memref %B = view %arg0[%c0][%K, %N] 
: memref to memref %C = view %arg0[%c0][%M, %N] : memref to memref - linalg.matmul %A, %B, %C : (memref, memref, memref) + linalg.matmul ins(%A, %B: memref, memref) + outs(%C: memref) return } // CHECKLOOP-LABEL: func @matmul(%{{.*}}: memref, @@ -83,7 +84,8 @@ %2 = view %arg0[%c0][%M, %N] : memref to memref %3 = view %arg0[%c0][%M] : memref to memref %4 = view %arg0[%c0][%N] : memref to memref - linalg.matvec %2, %3, %4 : (memref, memref, memref) + linalg.matvec ins(%2, %3: memref, memref) + outs(%4 : memref) return } // CHECKLOOP-LABEL: func @matvec(%{{.*}}: memref, @@ -123,7 +125,8 @@ %1 = view %arg0[%c0][%M] : memref to memref %2 = view %arg0[%c0][%M] : memref to memref %3 = view %arg0[%c0][] : memref to memref - linalg.dot %1, %2, %3 : (memref, memref, memref) + linalg.dot ins(%1, %2 : memref, memref) + outs(%3 : memref) return } // CHECKLOOP-LABEL: func @dot(%{{.*}}: memref, @@ -154,9 +157,9 @@ func @dot_view(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.dot %arg0, %arg1, %arg2 : (memref, - memref, - memref) + linalg.dot ins(%arg0, %arg1 : memref, + memref) + outs(%arg2: memref) return } // CHECKLOOP-LABEL: func @dot_view( @@ -880,7 +883,8 @@ // Named ops to loops. 
//----------------------------------------------------------------------------// func @named_batch_matmul(%A: memref, %B: memref, %C: memref) { - linalg.batch_matmul %A, %B, %C : (memref, memref, memref) -> () + linalg.batch_matmul ins(%A, %B : memref, memref) + outs(%C : memref) return } // CHECKLOOP-LABEL: @named_batch_matmul @@ -1288,7 +1292,8 @@ // CHECKPARALLEL: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref func @conv1d_no_symbols(%in : memref, %filter : memref, %out : memref) -> () { - linalg.conv_1d %in, %filter, %out : (memref, memref, memref) + linalg.conv_1d ins(%in, %filter : memref, memref) + outs(%out : memref) return } @@ -1330,7 +1335,8 @@ func @conv2d_no_symbols(%in : memref, %filter : memref, %out : memref) -> () { - linalg.conv_2d %in, %filter, %out : (memref, memref, memref) + linalg.conv_2d ins(%in, %filter : memref, memref) + outs(%out: memref) return } // CHECKLOOP-LABEL: @conv2d_no_symbols @@ -1382,7 +1388,8 @@ func @conv3d_no_symbols(%in : memref, %filter : memref, %out : memref) -> () { - linalg.conv_3d %in, %filter, %out : (memref, memref, memref) + linalg.conv_3d ins(%in, %filter : memref, memref) + outs(%out : memref) return } diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir --- a/mlir/test/Dialect/Linalg/promote.mlir +++ b/mlir/test/Dialect/Linalg/promote.mlir @@ -27,10 +27,10 @@ %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref to memref %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref to memref %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref to memref - linalg.matmul %11, %14, %17 : - (memref, - memref, - memref) + linalg.matmul + ins(%11, %14: memref, + memref) + outs(%17: memref) } } } @@ -67,10 +67,7 @@ // CHECK: linalg.copy(%[[vB]], %[[partialB]]) : memref, memref // CHECK: linalg.copy(%[[vC]], %[[partialC]]) : memref, memref // -// CHECK: linalg.matmul %[[partialA]], %[[partialB]], %[[partialC]] : -// CHECK: memref, -// CHECK: 
memref, -// CHECK: memref +// CHECK: linalg.matmul ins(%[[partialA]], %[[partialB]]{{.*}} outs(%[[partialC]] // // CHECK: linalg.copy(%[[partialC]], %[[vC]]) : // CHECK: memref, @@ -103,10 +100,10 @@ %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref to memref %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref to memref %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref to memref - linalg.matmul %11, %14, %17 : - (memref, - memref, - memref) + linalg.matmul + ins(%11, %14: memref, + memref) + outs(%17: memref) } } } @@ -140,10 +137,7 @@ // CHECK: linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref, memref // CHECK: linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref, memref // -// CHECK: linalg.matmul %[[partialA_f64]], %[[partialB_f64]], %[[partialC_f64]] : -// CHECK: memref, -// CHECK: memref, -// CHECK: memref +// CHECK: linalg.matmul ins(%[[partialA_f64]], %[[partialB_f64]]{{.*}} outs(%[[partialC_f64]] // // CHECK: linalg.copy(%[[partialC_f64]], %[[vC_f64]]) : // CHECK: memref, diff --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir --- a/mlir/test/Dialect/Linalg/promotion_options.mlir +++ b/mlir/test/Dialect/Linalg/promotion_options.mlir @@ -2,8 +2,9 @@ func @gemm(%a : memref, %b : memref, %c : memref) { - linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "START"} - : (memref, memref, memref) + linalg.matmul {__internal_linalg_transform__ = "START"} + ins(%a, %b: memref, memref) + outs(%c: memref) return } @@ -26,7 +27,7 @@ // CHECK: linalg.copy(%[[T7]], %[[T19]]) // CHECK: linalg.fill(%[[T21]], %[[C42]]) // CHECK: linalg.copy(%[[T17]], %[[T21]]) -// CHECK: linalg.matmul %[[T19]], %[[T12]], %[[T21]] +// CHECK: linalg.matmul ins(%[[T19]], %[[T12]]{{.*}} outs(%[[T21]] // CHECK-NOT: linalg.fill // CHECK: linalg.copy(%[[T21]], %[[T17]]) // CHECK: dealloc %[[T18]] diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- 
a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -83,30 +83,30 @@ %arg1: memref, %arg2: memref, %arg3: memref) { - linalg.matmul %arg0, %arg0, %arg0 : (memref, - memref, - memref) - linalg.matvec %arg0, %arg1, %arg2 : (memref, - memref, - memref) - linalg.dot %arg1, %arg2, %arg3 : (memref, - memref, - memref) + linalg.matmul ins(%arg0, %arg0 : memref, + memref) + outs(%arg0 : memref) + linalg.matvec ins(%arg0, %arg1: memref, + memref) + outs(%arg2: memref) + linalg.dot ins(%arg1, %arg2: memref, + memref) + outs(%arg3: memref) return } // CHECK-LABEL: func @ops(% -// CHECK-NEXT: linalg.matmul %{{.*}}, %{{.*}}, %{{.*}} : -// CHECK-SAME: (memref, -// CHECK-SAME: memref, -// CHECK-SAME: memref) -// CHECK-NEXT: linalg.matvec %{{.*}}, %{{.*}}, %{{.*}} : -// CHECK-SAME: (memref, -// CHECK-SAME: memref, -// CHECK-SAME: memref) -// CHECK-NEXT: linalg.dot %{{.*}}, %{{.*}}, %{{.*}} : -// CHECK-SAME: (memref, -// CHECK-SAME: memref, -// CHECK-SAME: memref) +// CHECK: linalg.matmul +// CHECK-SAME: ins(%{{.*}}, %{{.*}} : memref, +// CHECK-SAME: memref) +// CHECK-SAME: outs(%{{.*}} : memref) +// CHECK: linalg.matvec +// CHECK-SAME: ins(%{{.*}}, %{{.*}}: memref, +// CHECK-SAME: memref) +// CHECK-SAME: outs(%{{.*}}: memref) +// CHECK: linalg.dot +// CHECK-SAME: ins(%{{.*}}, %{{.*}}: memref, +// CHECK-SAME: memref) +// CHECK-SAME: outs(%{{.*}}: memref) // ----- @@ -619,17 +619,27 @@ // CHECK: linalg.reshape {{.*}} [#[[$reshapeD01]], #[[$reshapeD2]]] // CHECK-SAME: memref into memref - -// TODO: Return tensors need a semantics convention update. 
func @named_ops(%a3: memref, %b3: memref, %c3: memref, - %ta3: tensor, %tb3: tensor, %tc3: tensor) { - linalg.batch_matmul %a3, %b3, %c3 : (memref, memref, memref) -> () - linalg.batch_matmul %ta3, %tb3, %c3 : (tensor, tensor, memref) -> () - return + %ta3: tensor, %tb3: tensor, %tc3: tensor) + -> (tensor, tensor) +{ + linalg.batch_matmul ins(%a3, %b3: memref, memref) + outs(%c3: memref) + linalg.batch_matmul ins(%ta3, %tb3: tensor, tensor) + outs(%c3: memref) + %res1 = linalg.batch_matmul ins(%ta3, %tb3: tensor, tensor) + init(%tc3: tensor) + -> tensor + %res2 = linalg.batch_matmul ins(%ta3, %b3: tensor, memref) + init(%tc3: tensor) + -> tensor + return %res1, %res2 : tensor, tensor } // CHECK-LABEL: func @named_ops // CHECK: linalg.batch_matmul // CHECK: linalg.batch_matmul +// CHECK: linalg.batch_matmul +// CHECK: linalg.batch_matmul // ----- diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir --- a/mlir/test/Dialect/Linalg/standard.mlir +++ b/mlir/test/Dialect/Linalg/standard.mlir @@ -13,9 +13,9 @@ func @dot(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.dot %arg0, %arg1, %arg2 : (memref, - memref, - memref) + linalg.dot ins(%arg0, %arg1: memref, + memref) + outs(%arg2: memref) return } // CHECK-LABEL: func @dot( diff --git a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir --- a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir @@ -2,8 +2,9 @@ func @gemm1(%a : memref, %b : memref, %c : memref) { - linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute1"} - : (memref, memref, memref) + linalg.matmul {__internal_linalg_transform__ = "distribute1"} + ins(%a, %b: memref, memref) + outs(%c: memref) return } // CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)> @@ -21,14 +22,15 @@ // CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]] // CHECK: %[[OFFSETX:.*]] = affine.apply 
#[[MAP0]]()[%[[BIDX]]] // CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX]]] -// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]] +// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]] // ----- func @gemm2(%a : memref, %b : memref, %c : memref) { - linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute2"} - : (memref, memref, memref) + linalg.matmul {__internal_linalg_transform__ = "distribute2"} + ins(%a, %b: memref, memref) + outs(%c:memref) return } // CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)> @@ -52,14 +54,15 @@ // CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]] // CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]] // CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]] -// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]] +// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]] // ----- func @gemm3(%a : memref, %b : memref, %c : memref) { - linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute3"} - : (memref, memref, memref) + linalg.matmul {__internal_linalg_transform__ = "distribute3"} + ins(%a, %b: memref, memref) + outs(%c: memref) return } // CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)> @@ -80,14 +83,15 @@ // CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[ARG3]], %[[ARG5]]] // CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG5]], %[[ARG4]]] // CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[ARG3]], %[[ARG4]]] -// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]] +// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]] // ----- func @gemm4(%a : memref, %b : memref, %c : memref) { - linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute4"} - : (memref, memref, memref) + linalg.matmul {__internal_linalg_transform__ = "distribute4"} + ins(%a, %b: memref, memref) + outs(%c: memref) return } // CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)> @@ -108,14 +112,15 @@ // CHECK: %[[OFFSETY_2:.*]] = 
affine.apply #[[MAP0]]()[%[[BIDY]]] // CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]] // CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]] -// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]] +// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]] // ----- func @gemm5(%a : memref, %b : memref, %c : memref) { - linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute5"} - : (memref, memref, memref) + linalg.matmul {__internal_linalg_transform__ = "distribute5"} + ins(%a, %b: memref, memref) + outs(%c: memref) return } // CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)> @@ -138,14 +143,15 @@ // CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[ARG3]]] // CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]] // CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[OFFSETY_2]], %[[ARG3]]] -// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]] +// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]] // ----- func @gemm6(%a : memref, %b : memref, %c : memref) { - linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute6"} - : (memref, memref, memref) + linalg.matmul {__internal_linalg_transform__ = "distribute6"} + ins(%a, %b: memref, memref) + outs(%c: memref) return } // CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)> @@ -165,4 +171,4 @@ // CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[OFFSETX]]] // CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]] // CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[ARG3]], %[[OFFSETX_2]]] -// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]] +// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]] diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir --- a/mlir/test/Dialect/Linalg/tile.mlir +++ b/mlir/test/Dialect/Linalg/tile.mlir @@ -31,10 +31,10 @@ func @matmul(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.matmul %arg0, %arg1, %arg2 : - (memref, - memref, - memref) + 
linalg.matmul + ins(%arg0, %arg1: memref, + memref) + outs(%arg2: memref) return } // TILE-2-LABEL: func @matmul( @@ -50,10 +50,7 @@ // TILE-2: %[[szK:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localK]]] // TILE-2: %[[N:.*]] = dim %{{.*}}, %c1 : memref // TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[szK]], %[[N]]] [1, 1] : memref to memref -// TILE-2: linalg.matmul %[[sAi]], %{{.*}}, %[[sCi]] : -// TILE-2: (memref, -// TILE-2: memref, -// TILE-2: memref) +// TILE-2: linalg.matmul ins(%[[sAi]]{{.*}} outs(%[[sCi]] // TILE-02-LABEL: func @matmul( // TILE-02-DAG: %[[C0:.*]] = constant 0 : index @@ -68,10 +65,7 @@ // TILE-02: %[[localK:.*]] = dim %{{.*}}, %c1 // TILE-02: %[[szK:.*]] = affine.min #[[$bound_map]](%[[J]])[%[[localK]]] // TILE-02: %[[sCj:.*]] = subview %{{.*}}[0, %[[J]]] [%[[M]], %[[szK]]] [1, 1] : memref to memref -// TILE-02: linalg.matmul %{{.*}}, %[[sBj]], %[[sCj]] : -// TILE-02: (memref, -// TILE-02: memref, -// TILE-02: memref) +// TILE-02: linalg.matmul ins(%{{.*}}, %[[sBj]]{{.*}} outs(%[[sCj]] // TILE-002-LABEL: func @matmul( // TILE-002-DAG: %[[C0:.*]] = constant 0 : index @@ -86,10 +80,7 @@ // TILE-002: %[[szK:.*]] = affine.min #[[$bound_map]](%[[K]])[%[[localK]]] // TILE-002: %[[N:.*]] = dim %{{.*}}, %c1 : memref // TILE-002: %[[sBj:.*]] = subview %{{.*}}[%[[K]], 0] [%[[szK]], %[[N]]] [1, 1] : memref to memref -// TILE-002: linalg.matmul %[[sAj]], %[[sBj]], %{{.*}} : -// TILE-002: (memref, -// TILE-002: memref, -// TILE-002: memref) +// TILE-002: linalg.matmul ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}} // TILE-234-LABEL: func @matmul( // TILE-234-DAG: %[[C0:.*]] = constant 0 : index @@ -118,10 +109,7 @@ // TILE-234: %[[szN:.*]] = affine.min #[[$bound_map_3]](%[[J]])[%[[localN]]] // TILE-234: %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%[[szM]], %[[szN]]] [1, 1] : memref to memref // -// TILE-234: linalg.matmul %[[sAik]], %[[sBkj]], %[[sCij]] : -// TILE-234: (memref, -// TILE-234: memref, -// TILE-234: memref) +// TILE-234: 
linalg.matmul ins(%[[sAik]], %[[sBkj]]{{.*}} outs(%[[sCij]] // When the buffer shapes are known at compile time, it is possible to avoid // the "min" in subview size computation. This test uses buffer sizes divisible @@ -130,10 +118,10 @@ func @matmul_static(%arg0: memref<10x16xf32, offset: ?, strides: [?, 1]>, %arg1: memref<16x12xf32, offset: ?, strides: [?, 1]>, %arg2: memref<10x12xf32, offset: ?, strides: [?, 1]>) { - linalg.matmul %arg0, %arg1, %arg2 : - (memref<10x16xf32, offset: ?, strides: [?, 1]>, - memref<16x12xf32, offset: ?, strides: [?, 1]>, - memref<10x12xf32, offset: ?, strides: [?, 1]>) + linalg.matmul + ins(%arg0, %arg1: memref<10x16xf32, offset: ?, strides: [?, 1]>, + memref<16x12xf32, offset: ?, strides: [?, 1]>) + outs(%arg2: memref<10x12xf32, offset: ?, strides: [?, 1]>) return } // TILE-2-LABEL: func @matmul_static( @@ -148,7 +136,7 @@ // TILE-2: %[[sAi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[MIN2]], 16] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref // TILE-2: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[I]]) // TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]], 0] [%[[MIN22]], 12] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref -// TILE-2: linalg.matmul %[[sAi]], %{{.*}}, %[[sCi]] +// TILE-2: linalg.matmul ins(%[[sAi]], %{{.*}}{{.*}} outs(%[[sCi]] // TILE-02-LABEL: func @matmul_static( // TILE-02-DAG: %[[C0:.*]] = constant 0 : index @@ -159,10 +147,7 @@ // TILE-02: %[[sBj:.*]] = subview %{{.*}}[0, %[[J]]] [16, %[[MIN2]]] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref<16x?xf32, #[[$strided2D]]> // TILE-02: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[J]]) // TILE-02: %[[sCj:.*]] = subview %{{.*}}[0, %[[J]]] [10, %[[MIN22]]] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]> -// TILE-02: linalg.matmul %{{.*}}, %[[sBj]], %[[sCj]] : -// TILE-02: (memref<10x16xf32, #[[$strided2D]]>, -// TILE-02: memref<16x?xf32, #[[$strided2D]]>, -// TILE-02: memref<10x?xf32, #[[$strided2D]]>) +// TILE-02: 
linalg.matmul ins(%{{.*}}, %[[sBj]]{{.*}} outs(%[[sCj]] // TILE-002-LABEL: func @matmul_static( // TILE-002-DAG: %[[C0:.*]] = constant 0 : index @@ -173,10 +158,7 @@ // TILE-002: %[[sAj:.*]] = subview %{{.*}}[0, %[[K]]] [10, %[[MIN2]]] [1, 1] : memref<10x16xf32, #[[$strided2D]]> to memref<10x?xf32, #[[$strided2D]]> // TILE-002: %[[MIN22:.*]] = affine.min #[[$bound_map_static]](%[[K]]) // TILE-002: %[[sBj:.*]] = subview %{{.*}}[%[[K]], 0] [%[[MIN22]], 12] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref -// TILE-002: linalg.matmul %[[sAj]], %[[sBj]], %{{.*}} : -// TILE-002: (memref<10x?xf32, #[[$strided2D]]>, -// TILE-002: memref, -// TILE-002: memref<10x12xf32, #[[$strided2D]]>) +// TILE-002: linalg.matmul ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}} // TILE-234-LABEL: func @matmul_static( // TILE-234-DAG: %[[C0:.*]] = constant 0 : index @@ -193,16 +175,13 @@ // TILE-234: %[[sBkj:.*]] = subview %{{.*}}[%[[K]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<16x12xf32, #[[$strided2D]]> to memref // TILE-234: %[[sCij:.*]] = subview %{{.*}}[%[[I]], %[[J]]] [%{{.*}}, %{{.*}}] [1, 1] : memref<10x12xf32, #[[$strided2D]]> to memref // -// TILE-234: linalg.matmul %[[sAik]], %[[sBkj]], %[[sCij]] : -// TILE-234: (memref, -// TILE-234: memref, -// TILE-234: memref) +// TILE-234: linalg.matmul ins(%[[sAik]], %[[sBkj]]{{.*}} outs(%[[sCij]] func @matvec(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.matvec %arg0, %arg1, %arg2 : ( - memref, - memref, - memref) + linalg.matvec + ins(%arg0, %arg1: memref, + memref) + outs(%arg2: memref) return } // TILE-2-LABEL: func @matvec( @@ -220,7 +199,7 @@ // TILE-2: %[[localN:.*]] = dim %{{.*}}, %c0 // TILE-2: %[[szN:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localN]]] // TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szN]]] [1] : memref to memref -// TILE-2: linalg.matvec %[[sAi]], %{{.*}}, %[[sCi]] : (memref, memref, memref) +// TILE-2: linalg.matvec ins(%[[sAi]], %{{.*}} outs(%[[sCi]] // TILE-02-LABEL: func @matvec( // 
TILE-02-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref @@ -237,7 +216,7 @@ // TILE-02: %[[localN:.*]] = dim %{{.*}}, %c0 // TILE-02: %[[szN:.*]] = affine.min #[[$bound_map]](%[[J]])[%[[localN]]] // TILE-02: %[[sBj:.*]] = subview %{{.*}}[%[[J]]] [%[[szN]]] [1] : memref to memref -// TILE-02: linalg.matvec %[[sAj]], %[[sBj]], %{{.*}} : (memref, memref, memref) +// TILE-02: linalg.matvec ins(%[[sAj]], %[[sBj]]{{.*}} outs(%{{.*}} // TILE-002-LABEL: func @matvec( // TILE-002-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref @@ -268,12 +247,12 @@ // TILE-234: %[[szM:.*]] = affine.min #[[$bound_map_2]](%[[I]])[%[[localM]]] // TILE-234: %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref to memref // -// TILE-234: linalg.matvec %[[sAij]], %[[sBj]], %[[sCi]] : (memref, memref, memref) +// TILE-234: linalg.matvec ins(%[[sAij]], %[[sBj]]{{.*}} outs(%[[sCi]] func @dot(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.dot %arg0, %arg1, %arg2 : (memref, - memref, - memref) + linalg.dot + ins(%arg0, %arg1: memref, memref) + outs(%arg2: memref) return } // TILE-2-LABEL: func @dot( @@ -287,7 +266,7 @@ // TILE-2: %[[localM:.*]] = dim %{{.*}}, %c0 // TILE-2: %[[szM:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localM]]] // TILE-2: %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref to memref -// TILE-2: linalg.dot %[[sAi]], %[[sBi]], {{.*}} : (memref, memref, memref) +// TILE-2: linalg.dot ins(%[[sAi]], %[[sBi]]{{.*}} outs( // TILE-02-LABEL: func @dot( // TILE-02-NOT: scf.for @@ -306,7 +285,7 @@ // TILE-234: %[[localM:.*]] = dim %{{.*}}, %c0 // TILE-234: %[[szM:.*]] = affine.min #[[$bound_map_2]](%[[I]])[%[[localM]]] // TILE-234: %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref to memref -// TILE-234: linalg.dot %[[sAi]], %[[sBi]], %{{.*}} : (memref, memref, memref) +// TILE-234: linalg.dot ins(%[[sAi]], %[[sBi]]{{.*}} outs( func @fill_static(%arg0: memref<127x99xf32>, %arg1: f32) { linalg.fill(%arg0, %arg1) : memref<127x99xf32>, f32 diff --git 
a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir --- a/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir +++ b/mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir @@ -6,8 +6,8 @@ %arg1 : memref, %arg2 : memref) { - linalg.matmul %arg0, %arg1, %arg2 - : (memref, memref, memref) + linalg.matmul ins(%arg0, %arg1: memref, memref) + outs(%arg2: memref) return } // CHECK-LABEL: func @gemm @@ -21,7 +21,7 @@ // CHECK: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG5]]] // CHECK: %[[SV2:.*]] = subview %{{.*}}[%[[ARG5]], %[[ARG4]]] // CHECK: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]] -// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]] +// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]] // TILE1-LABEL: func @gemm // TILE1-DAG: %[[C2:.*]] = constant 2 : index @@ -30,7 +30,7 @@ // TILE1: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0] // TILE1: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], 0] // TILE1-NOT: subview -// TILE1: linalg.matmul %[[SV1]], %{{.*}}, %[[SV3]] +// TILE1: linalg.matmul ins(%[[SV1]], %{{.*}} outs(%[[SV3]] // TILE2-LABEL: func @gemm // TILE2-DAG: %[[C2:.*]] = constant 2 : index @@ -40,7 +40,7 @@ // TILE2: %[[SV1:.*]] = subview %{{.*}}[%[[ARG3]], 0] // TILE2: %[[SV2:.*]] = subview %{{.*}}[0, %[[ARG4]]] // TILE2: %[[SV3:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG4]]] -// TILE2: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]] +// TILE2: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]] // ----- diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir @@ -5,10 +5,10 @@ func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, %C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) { - 
linalg.matmul %A, %B, %C {__internal_linalg_transform__ = "START"} : - (memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, - memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, - memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) + linalg.matmul {__internal_linalg_transform__ = "START"} + ins(%A, %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, + memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) + outs(%C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) return } @@ -31,7 +31,8 @@ func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref) { // VECTOR-CONTRACTION: vector.contract // VECTOR-CONTRACTION-SAME: vector<1584xf32>, vector<1584xf32> into f32 - linalg.dot %A, %B, %C : (memref<1584xf32>, memref<1584xf32>, memref) + linalg.dot ins(%A, %B: memref<1584xf32>, memref<1584xf32>) + outs(%C: memref) return } @@ -39,8 +40,8 @@ func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %C: memref<1584xf32>) { // VECTOR-CONTRACTION: vector.contract // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584xf32> into vector<1584xf32> - linalg.matvec %A, %B, %C : - (memref<1584x1584xf32>, memref<1584xf32>, memref<1584xf32>) + linalg.matvec ins(%A, %B: memref<1584x1584xf32>, memref<1584xf32>) + outs(%C: memref<1584xf32>) return } @@ -48,8 +49,8 @@ func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf32>, %C: memref<1584x1584xf32>) { // VECTOR-CONTRACTION: vector.contract // VECTOR-CONTRACTION-SAME: vector<1584x1584xf32>, vector<1584x1584xf32> into vector<1584x1584xf32> - linalg.matmul %A, %B, %C : - (memref<1584x1584xf32>, memref<1584x1584xf32>, memref<1584x1584xf32>) + linalg.matmul ins(%A, %B: memref<1584x1584xf32>, memref<1584x1584xf32>) + outs(%C: memref<1584x1584xf32>) return } @@ -57,7 +58,8 @@ func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) { // VECTOR-CONTRACTION: vector.contract // 
VECTOR-CONTRACTION-SAME: vector<1584x1584x1584xf32>, vector<1584x1584x1584xf32> into vector<1584x1584x1584xf32> - linalg.batch_matmul %A, %B, %C : - (memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>) + linalg.batch_matmul + ins(%A, %B: memref<1584x1584x1584xf32>, memref<1584x1584x1584xf32>) + outs(%C: memref<1584x1584x1584xf32>) return } diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -14,10 +14,11 @@ func @dot(%x: memref, %y: memref, %v: memref) { - linalg.dot %x, %y, %v { __internal_linalg_transform__ = "MEM" } : - (memref, - memref, - memref) + linalg.dot { __internal_linalg_transform__ = "MEM" } + ins(%x, %y: memref, + memref) + outs(%v: memref) + return } // CHECK-LABEL: func @dot @@ -36,10 +37,10 @@ func @matvec(%A: memref, %x: memref, %y: memref) { - linalg.matvec %A, %x, %y : - (memref, - memref, - memref) + linalg.matvec + ins(%A, %x: memref, + memref) + outs(%y: memref) return } // CHECK-LABEL: func @matvec @@ -48,15 +49,17 @@ // CHECK-DAG: %[[c6:.*]] = constant 6 : index // CHECK: scf.parallel {{.*}} step (%[[c5]]) // CHECK: scf.for {{.*}} step %[[c6]] -// CHECK: linalg.matvec {{.*}}, {{.*}}, {{.*}} : (memref, memref, memref) +// CHECK: linalg.matvec +// CHECK: ins({{.*}}, {{.*}}: memref, memref) +// CHECK: outs({{.*}}: memref) func @matmul(%A: memref, %B: memref, %C: memref) { - linalg.matmul %A, %B, %C { __internal_linalg_transform__ = "MEM" } : - (memref, - memref, - memref) + linalg.matmul { __internal_linalg_transform__ = "MEM" } + ins(%A, %B: memref, + memref) + outs(%C: memref) return } // CHECK-LABEL: func @matmul @@ -85,10 +88,9 @@ // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c2]] { // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c3]] { // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c4]] { -// CHECK: linalg.matmul 
{{.*}}, {{.*}}, {{.*}} : ( -// CHECK: memref, -// CHECK: memref, -// CHECK: memref) +// CHECK: linalg.matmul +// CHECK: ins({{.*}}, {{.*}}: memref, memref) +// CHECK: outs({{.*}}: memref) #matmul_trait = { args_in = 2, @@ -137,8 +139,9 @@ func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>, %C: memref<8x32xf32>) { - linalg.matmul %A, %B, %C { __internal_linalg_transform__ = "VECTORIZE"} : - (memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>) + linalg.matmul { __internal_linalg_transform__ = "VECTORIZE"} + ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>) + outs(%C: memref<8x32xf32>) return } // CHECK-LABEL: func @vectorization_test_2 @@ -236,10 +239,10 @@ func @matvec_perm(%A: memref, %x: memref, %y: memref) { - linalg.matvec %A, %x, %y {__internal_linalg_transform__ = "__with_perm__"} : - (memref, - memref, + linalg.matvec {__internal_linalg_transform__ = "__with_perm__"} + ins(%A, %x: memref, memref) + outs(%y: memref) return } // CHECK-LABEL: func @matvec_perm @@ -248,15 +251,17 @@ // CHECK-DAG: %[[c6:.*]] = constant 6 : index // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]] // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]] -// CHECK: linalg.matvec {{.*}}, {{.*}}, {{.*}} : (memref, memref, memref) +// CHECK: linalg.matvec +// CHECK: ins({{.*}}, {{.*}}: memref, memref) +// CHECK: outs({{.*}}: memref) func @matmul_perm(%A: memref, %B: memref, %C: memref) { - linalg.matmul %A, %B, %C {__internal_linalg_transform__ = "__with_perm__"} : - (memref, - memref, + linalg.matmul {__internal_linalg_transform__ = "__with_perm__"} + ins(%A, %B: memref, memref) + outs(%C : memref) return } // CHECK-LABEL: func @matmul_perm @@ -279,10 +284,9 @@ // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c20]] { // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c30]] { // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c40]] { -// CHECK: linalg.matmul {{.*}}, {{.*}}, {{.*}} : ( -// CHECK: memref, -// CHECK: memref, -// CHECK: memref) +// 
CHECK: linalg.matmul +// CHECK: ins({{.*}}, {{.*}}: memref, memref) +// CHECK: outs({{.*}}: memref) func @promote_subview_matmul(%arg0: memref, %arg1: memref, @@ -304,10 +308,10 @@ memref to memref %5 = subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] : memref to memref - linalg.matmul %3, %4, %5 {__internal_linalg_transform__ = "_promote_views_"} : - (memref, - memref, - memref) + linalg.matmul {__internal_linalg_transform__ = "_promote_views_"} + ins(%3, %4: memref, + memref) + outs(%5: memref) } } } @@ -336,8 +340,9 @@ // CHECK: linalg.copy(%[[s0]], %[[l0]]) : memref, memref // CHECK: linalg.copy(%[[s1]], %[[l1]]) : memref, memref // CHECK: linalg.copy(%[[s2]], %[[l2]]) : memref, memref -// CHECK: linalg.matmul %[[v0]], %[[v1]], %[[v2]] : -// CHECK: (memref, memref, memref) +// CHECK: linalg.matmul +// CHECK-SAME: ins(%[[v0]], %[[v1]] : memref, memref) +// CHECK-SAME: outs(%[[v2]] : memref) func @promote_first_subview_matmul(%arg0: memref, %arg1: memref, @@ -359,10 +364,10 @@ memref to memref %5 = std.subview %arg2[%arg3, %arg4][%c2000, %c3000][%c1, %c1] : memref to memref - linalg.matmul %3, %4, %5 {__internal_linalg_transform__ = "_promote_first_view_"} : - (memref, - memref, - memref) + linalg.matmul {__internal_linalg_transform__ = "_promote_first_view_"} + ins(%3, %4: memref, + memref) + outs(%5: memref) } } } @@ -391,10 +396,9 @@ // CHECK: linalg.copy(%[[s0]], %[[l0]]) : memref, memref // CHECK-NOT: linalg.copy(%[[s1]], %[[l1]]) : memref, memref // CHECK-NOT: linalg.copy(%[[s2]], %[[l2]]) : memref, memref^ -// CHECK: linalg.matmul %[[v0]], %[[s1]], %[[s2]] : -// CHECK: (memref, -// CHECK: memref, -// CHECK: memref) +// CHECK: linalg.matmul +// CHECK-SAME: ins(%[[v0]], %[[s1]] : memref, memref) +// CHECK-SAME: outs(%[[s2]] : memref) func @aligned_promote_fill(%arg0: memref) { %c2000 = constant 2000 : index @@ -421,8 +425,9 @@ func @tile_permute_parallel_loop(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.matmul %arg0, %arg1, %arg2 
{__internal_linalg_transform__ = "par__with_perm__"} - : (memref, memref, memref) + linalg.matmul {__internal_linalg_transform__ = "par__with_perm__"} + ins(%arg0, %arg1: memref, memref) + outs(%arg2: memref) return } // CHECK-LABEL: func @tile_permute_parallel_loop diff --git a/mlir/test/IR/slice.mlir b/mlir/test/IR/slice.mlir --- a/mlir/test/IR/slice.mlir +++ b/mlir/test/IR/slice.mlir @@ -5,8 +5,10 @@ %b = alloc(%arg2, %arg1) : memref %c = alloc(%arg0, %arg1) : memref %d = alloc(%arg0, %arg1) : memref - linalg.matmul %a, %b, %c : (memref, memref, memref) - linalg.matmul %a, %b, %d : (memref, memref, memref) + linalg.matmul ins(%a, %b : memref, memref) + outs(%c : memref) + linalg.matmul ins(%a, %b : memref, memref) + outs(%d : memref) dealloc %c : memref dealloc %b : memref dealloc %a : memref diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -308,6 +308,25 @@ return failure(); return success(); } +static ParseResult +parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, + Type optOperandType, + const SmallVectorImpl &varOperandTypes) { + if (parser.parseKeyword("type_refs_capture")) + return failure(); + + Type operandType2, optOperandType2; + SmallVector varOperandTypes2; + if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, + varOperandTypes2)) + return failure(); + + if (operandType != operandType2 || optOperandType != optOperandType2 || + varOperandTypes != varOperandTypes2) + return failure(); + + return success(); +} static ParseResult parseCustomDirectiveOperandsAndTypes( OpAsmParser &parser, OpAsmParser::OperandType &operand, Optional &optOperand, @@ -365,6 +384,14 @@ printer << ", " << optOperandType; printer << " -> (" << varOperandTypes << ")"; } +static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, + Type operandType, + Type optOperandType, + 
TypeRange varOperandTypes) { + printer << " type_refs_capture "; + printCustomDirectiveResults(printer, operandType, optOperandType, + varOperandTypes); +} static void printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand, Value optOperand, OperandRange varOperands, diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -53,8 +53,6 @@ let results = (outs TensorOf<[ComplexF64]>); } -def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">; - def TupleOp : TEST_Op<"tuple_32_bit"> { let results = (outs TupleOf<[I32, F32]>); } @@ -1518,6 +1516,22 @@ }]; } +def FormatCustomDirectiveResultsWithTypeRefs + : TEST_Op<"format_custom_directive_results_with_type_refs", + [AttrSizedResultSegments]> { + let results = (outs AnyType:$result, Optional:$optResult, + Variadic:$varResults); + let assemblyFormat = [{ + custom( + type($result), type($optResult), type($varResults) + ) + custom( + type_ref($result), type_ref($optResult), type_ref($varResults) + ) + attr-dict + }]; +} + def FormatCustomDirectiveSuccessors : TEST_Op<"format_custom_directive_successors", [Terminator]> { let successors = (successor AnySuccessor:$successor, diff --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir --- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir +++ b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir @@ -51,7 +51,8 @@ %B = view %bB[%c0][%c16] : memref to memref %C = view %bC[%c0][] : memref to memref - linalg.dot %A, %B, %C : (memref, memref, memref) + linalg.dot ins(%A, %B : memref, memref) + outs(%C : memref) %res = load %C[] : memref dealloc %bC : memref @@ -83,7 +84,8 @@ %B = view %bB[%c0][%c16, %c2] : memref to memref %C = view %bC[%c0][%c2, %c2] : memref to memref - linalg.matmul %A, %B, %C : (memref, memref, memref) + linalg.matmul ins(%A, %B 
: memref, memref) + outs(%C : memref) %res = load %C[%c0, %c1] : memref dealloc %bC : memref diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -1,9 +1,9 @@ // RUN: mlir-linalg-ods-gen %s -gen-ods-decl=1 | FileCheck %s --check-prefix=ODS // RUN: mlir-linalg-ods-gen %s -gen-impl=1 | FileCheck %s --check-prefix=IMPL -// ODS-LABEL: def Test1Op : LinalgNamedStructured_Op<"test1", [ -// ODS-NEXT: NInputs<2> -// ODS-NEXT: NOutputs<1> +// ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1", [ +// ODS-NEXT: NamedStructuredOpTrait +// ODS-NEXT: AttrSizedOperandSegments // ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp"> // // IMPL-LABEL: ArrayAttr Test1Op::iterator_types() { @@ -25,9 +25,9 @@ C(m) = std_addf(std_mulf(A(m, k), B(k))); } -// ODS-LABEL: def Test2Op : LinalgNamedStructured_Op<"test2", [ -// ODS-NEXT: NInputs<2> -// ODS-NEXT: NOutputs<1> +// ODS-LABEL: def Test2Op : LinalgStructuredBase_Op<"test2", [ +// ODS-NEXT: NamedStructuredOpTrait +// ODS-NEXT: AttrSizedOperandSegments // ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp"> // // IMPL-LABEL: ArrayAttr Test2Op::iterator_types() { @@ -49,9 +49,9 @@ C(m, n) = std_addf(std_mulf(A(m, k), B(k, n))); } -// ODS-LABEL: def Test3Op : LinalgNamedStructured_Op<"test3", [ -// ODS-NEXT: NInputs<2> -// ODS-NEXT: NOutputs<1> +// ODS-LABEL: def Test3Op : LinalgStructuredBase_Op<"test3", [ +// ODS-NEXT: NamedStructuredOpTrait +// ODS-NEXT: AttrSizedOperandSegments // ODS-NEXT: SingleBlockImplicitTerminator<"YieldOp"> // // IMPL-LABEL: ArrayAttr Test3Op::iterator_types() { diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -230,8 +230,66 @@ def DirectiveTypeZOperandInvalidI : 
TestFormat_Op<"type_operand_invalid_i", [{ type($result) type($result) }]>, Results<(outs I64:$result)>; + +//===----------------------------------------------------------------------===// +// type_ref + +// CHECK: error: 'type_ref' of 'operand' is not bound by a prior 'type' directive +def DirectiveTypeZZTypeRefOperandInvalidC : TestFormat_Op<"type_ref_operand_invalid_c", [{ + type_ref($operand) type(operands) +}]>, Arguments<(ins I64:$operand)>; +// CHECK: error: 'operands' 'type_ref' is not bound by a prior 'type' directive +def DirectiveTypeZZTypeRefOperandInvalidD : TestFormat_Op<"type_ref_operand_invalid_d", [{ + type_ref(operands) type($operand) +}]>, Arguments<(ins I64:$operand)>; +// CHECK: error: 'type_ref' of 'operand' is not bound by a prior 'type' directive +def DirectiveTypeZZTypeRefOperandInvalidE : TestFormat_Op<"type_ref_operand_invalid_e", [{ + type_ref($operand) type($operand) +}]>, Arguments<(ins I64:$operand)>; +// CHECK: error: 'type_ref' of 'result' is not bound by a prior 'type' directive +def DirectiveTypeZZTypeRefOperandInvalidG : TestFormat_Op<"type_ref_operand_invalid_g", [{ + type_ref($result) type(results) +}]>, Results<(outs I64:$result)>; +// CHECK: error: 'results' 'type_ref' is not bound by a prior 'type' directive +def DirectiveTypeZZTypeRefOperandInvalidH : TestFormat_Op<"type_ref_operand_invalid_h", [{ + type_ref(results) type($result) +}]>, Results<(outs I64:$result)>; +// CHECK: error: 'type_ref' of 'result' is not bound by a prior 'type' directive +def DirectiveTypeZZTypeRefOperandInvalidI : TestFormat_Op<"type_ref_operand_invalid_i", [{ + type_ref($result) type($result) +}]>, Results<(outs I64:$result)>; + +// CHECK-NOT: error +def DirectiveTypeZZTypeRefOperandB : TestFormat_Op<"type_ref_operand_valid_b", [{ + type_ref(operands) attr-dict +}]>; +// CHECK-NOT: error +def DirectiveTypeZZTypeRefOperandD : TestFormat_Op<"type_ref_operand_valid_d", [{ + type(operands) type_ref($operand) attr-dict +}]>, Arguments<(ins 
I64:$operand)>; +// CHECK-NOT: error +def DirectiveTypeZZTypeRefOperandE : TestFormat_Op<"type_ref_operand_valid_e", [{ + type($operand) type_ref($operand) attr-dict +}]>, Arguments<(ins I64:$operand)>; +// CHECK-NOT: error +def DirectiveTypeZZTypeRefOperandF : TestFormat_Op<"type_ref_operand_valid_f", [{ + type(results) type_ref(results) attr-dict +}]>; +// CHECK-NOT: error +def DirectiveTypeZZTypeRefOperandG : TestFormat_Op<"type_ref_operand_valid_g", [{ + type($result) type_ref(results) attr-dict +}]>, Results<(outs I64:$result)>; +// CHECK-NOT: error +def DirectiveTypeZZTypeRefOperandH : TestFormat_Op<"type_ref_operand_valid_h", [{ + type(results) type_ref($result) attr-dict +}]>, Results<(outs I64:$result)>; +// CHECK-NOT: error +def DirectiveTypeZZTypeRefOperandI : TestFormat_Op<"type_ref_operand_valid_i", [{ + type($result) type_ref($result) attr-dict +}]>, Results<(outs I64:$result)>; + // CHECK-NOT: error: -def DirectiveTypeZOperandValid : TestFormat_Op<"type_operand_valid", [{ +def DirectiveTypeZZZOperandValid : TestFormat_Op<"type_operand_valid", [{ type(operands) type(results) attr-dict }]>; diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -237,6 +237,12 @@ // CHECK: test.format_custom_directive_results : i64 -> (i64) test.format_custom_directive_results : i64 -> (i64) +// CHECK: test.format_custom_directive_results_with_type_refs : i64, i64 -> (i64) type_refs_capture : i64, i64 -> (i64) +test.format_custom_directive_results_with_type_refs : i64, i64 -> (i64) type_refs_capture : i64, i64 -> (i64) + +// CHECK: test.format_custom_directive_results_with_type_refs : i64 -> (i64) type_refs_capture : i64 -> (i64) +test.format_custom_directive_results_with_type_refs : i64 -> (i64) type_refs_capture : i64 -> (i64) + func @foo() { // CHECK: test.format_custom_directive_successors ^bb1, ^bb2 test.format_custom_directive_successors ^bb1, 
^bb2 diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -980,7 +980,7 @@ /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`. void printODS(llvm::raw_ostream &os, StringRef cppOpName, - StringRef linalgOpName); + StringRef linalgOpName, ComprehensionParsingState &state); /// Print the C++ StructuredOpsInterface impl of `iterator_types`. void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName, @@ -1419,7 +1419,8 @@ return failure(); } if (genODSDecl) { - printODS(os, cppOpName, tcName); + auto &state = perComprehensionStates.back(); + printODS(os, cppOpName, tcName, state); os << "\n"; } if (genODSImpl) { @@ -1442,31 +1443,72 @@ /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`. void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, - StringRef linalgOpName) { - const char *header = R"FMT( def {0} : LinalgNamedStructured_Op<"{1}", [ - NInputs<{2}>, - NOutputs<{3}>, + StringRef linalgOpName, + ComprehensionParsingState &state) { + const char *header = R"FMT( def {0} : LinalgStructuredBase_Op<"{1}", [ + NamedStructuredOpTrait, + AttrSizedOperandSegments, SingleBlockImplicitTerminator<"YieldOp">]> { - let arguments = (ins Variadic:$views); + let arguments = (ins Variadic:$inputs, + Variadic:$output_buffers, + Variadic:$init_tensors); let results = (outs Variadic:$output_tensors); - let regions = (region SizedRegion<1>:$region); - let builders = [OpBuilder< - "OpBuilder &b, OperationState &result, TypeRange outputTypes, " - # "ValueRange views", + let regions = (region AnyRegion:$region); + + let builders = [ OpBuilder< + "OpBuilder &b, OperationState &result," + "ValueRange inputs, ValueRange outputBuffers", + [{{ + result.addOperands(inputs); + result.addOperands(outputBuffers); + 
result.addAttribute( + "operand_segment_sizes", + b.getI32VectorAttr({{static_cast(inputs.size()), + static_cast(outputBuffers.size()), + static_cast(0)})); + buildNamedStructuredOpRegionAndAttributes<{0}>( + b, + result, + TypeRange(inputs), + TypeRange(outputBuffers), + TypeRange(), + TypeRange()); + }]>, OpBuilder< + "OpBuilder &b, OperationState &result, TypeRange resultTensorTypes," + "ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors", [{{ - result.addOperands(views); - result.addTypes(outputTypes); + result.addOperands(inputs); + result.addOperands(outputBuffers); + result.addOperands(initTensors); + result.addTypes(resultTensorTypes); + result.addAttribute( + "operand_segment_sizes", + b.getI32VectorAttr({{static_cast(inputs.size()), + static_cast(outputBuffers.size()), + static_cast(initTensors.size())})); buildNamedStructuredOpRegionAndAttributes<{0}>( - b, result, TypeRange(views), outputTypes); + b, + result, + TypeRange(inputs), + TypeRange(outputBuffers), + TypeRange(initTensors), + resultTensorTypes); }]> ]; - let parser = [{ - return ::parseNamedStructuredOp<{0}>(parser, result); - }]; + let printer = [{{ return ::printNamedStructuredOp(p, *this); }]; + let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }]; + let verifier = [{{ return ::verifyNamedStructuredOp(*this); }]; + let hasFolder = 1; + let hasCanonicalizer = 1; + let extraClassDeclaration = [{{ + // Auto-generated. ArrayAttr iterator_types(); ArrayAttr indexing_maps(); static void regionBuilder(Block &block); + + // Generic methods. + static unsigned getNumRegionArgs() {{ return {4}; } std::string getLibraryCallName() {{ return generateLibraryCallName(getOperation()); } @@ -1481,7 +1523,8 @@ nInputs++; } - os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs); + os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs, + state.orderedTensorArgs.size()); } /// Print the C++ StructuredOpsInterface impl of `iterator_types`. 
@@ -1680,7 +1723,7 @@ } // Include the proper Linalg header for end-to-end tblgen testing without - // resorting to non-portable shgell manipulations. + // resorting to non-portable shell manipulations. if (testEmitIncludeTdHeader) output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\""; diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -53,6 +53,7 @@ ResultsDirective, SuccessorsDirective, TypeDirective, + TypeRefDirective, /// This element is a literal. Literal, @@ -230,7 +231,19 @@ /// The operand that is used to format the directive. std::unique_ptr operand; }; -} // end anonymous namespace + +/// This class represents the `type_ref` directive. +class TypeRefDirective + : public DirectiveElement { +public: + TypeRefDirective(std::unique_ptr arg) : operand(std::move(arg)) {} + Element *getOperand() const { return operand.get(); } + +private: + /// The operand that is used to format the directive. + std::unique_ptr operand; +}; +} // namespace //===----------------------------------------------------------------------===// // LiteralElement @@ -805,6 +818,19 @@ << llvm::formatv( " ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n", name); + } else if (auto *dir = dyn_cast(element)) { + ArgumentLengthKind lengthKind; + StringRef name = getTypeListName(dir->getOperand(), lengthKind); + // Refer to the previously encountered TypeDirective for name. + // Take a `const ::mlir::SmallVector<::mlir::Type, 1> &` in the declaration + // to properly track the types that will be parsed and pushed later on. 
+ if (lengthKind != ArgumentLengthKind::Single) + body << " const ::mlir::SmallVector<::mlir::Type, 1> &" << name + << "TypesRef(" << name << "Types);\n"; + else + body << llvm::formatv( + " ::llvm::ArrayRef<::mlir::Type> {0}RawTypesRef({0}RawTypes);\n", + name); } else if (auto *dir = dyn_cast(element)) { ArgumentLengthKind ignored; body << " ::llvm::ArrayRef<::mlir::Type> " @@ -844,6 +870,15 @@ else body << llvm::formatv("{0}Successor", name); + } else if (auto *dir = dyn_cast(&param)) { + ArgumentLengthKind lengthKind; + StringRef listName = getTypeListName(dir->getOperand(), lengthKind); + if (lengthKind == ArgumentLengthKind::Variadic) + body << llvm::formatv("{0}TypesRef", listName); + else if (lengthKind == ArgumentLengthKind::Optional) + body << llvm::formatv("{0}TypeRef", listName); + else + body << formatv("{0}RawTypesRef[0]", listName); } else if (auto *dir = dyn_cast(&param)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); @@ -876,6 +911,16 @@ "{0}Operand;\n", operand->getVar()->name); } + } else if (auto *dir = dyn_cast(&param)) { + // Reference to an optional which may or may not have been set. + // Retrieve from vector if not empty. + ArgumentLengthKind lengthKind; + StringRef listName = getTypeListName(dir->getOperand(), lengthKind); + if (lengthKind == ArgumentLengthKind::Optional) + body << llvm::formatv( + " ::mlir::Type {0}TypeRef = {0}TypesRef.empty() " + "? Type() : {0}TypesRef[0];\n", + listName); } else if (auto *dir = dyn_cast(&param)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); @@ -907,6 +952,9 @@ body << llvm::formatv(" if ({0}Operand.hasValue())\n" " {0}Operands.push_back(*{0}Operand);\n", var->name); + } else if (auto *dir = dyn_cast(&param)) { + // In the `type_ref` case, do not parse a new Type that needs to be added. + // Just do nothing here.
} else if (auto *dir = dyn_cast(&param)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); @@ -1101,6 +1149,15 @@ } else if (isa(element)) { body << llvm::formatv(successorListParserCode, "full"); + } else if (auto *dir = dyn_cast(element)) { + ArgumentLengthKind lengthKind; + StringRef listName = getTypeListName(dir->getOperand(), lengthKind); + if (lengthKind == ArgumentLengthKind::Variadic) + body << llvm::formatv(variadicTypeParserCode, listName); + else if (lengthKind == ArgumentLengthKind::Optional) + body << llvm::formatv(optionalTypeParserCode, listName); + else + body << formatv(typeParserCode, listName); } else if (auto *dir = dyn_cast(element)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); @@ -1431,6 +1488,17 @@ } else if (auto *successor = dyn_cast(&param)) { body << successor->getVar()->name << "()"; + } else if (auto *dir = dyn_cast(&param)) { + auto *typeOperand = dir->getOperand(); + auto *operand = dyn_cast(typeOperand); + auto *var = operand ? operand->getVar() + : cast(typeOperand)->getVar(); + if (var->isVariadic()) + body << var->name << "().getTypes()"; + else if (var->isOptional()) + body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name); + else + body << var->name << "().getType()"; } else if (auto *dir = dyn_cast(&param)) { auto *typeOperand = dir->getOperand(); auto *operand = dyn_cast(typeOperand); @@ -1604,6 +1672,9 @@ } else if (auto *dir = dyn_cast(element)) { body << " p << "; genTypeOperandPrinter(dir->getOperand(), body) << ";\n"; + } else if (auto *dir = dyn_cast(element)) { + body << " p << "; + genTypeOperandPrinter(dir->getOperand(), body) << ";\n"; } else if (auto *dir = dyn_cast(element)) { body << " p.printFunctionalType("; genTypeOperandPrinter(dir->getInputs(), body) << ", "; @@ -1670,6 +1741,7 @@ kw_results, kw_successors, kw_type, + kw_type_ref, keyword_end, // String valued tokens.
@@ -1874,6 +1946,7 @@ .Case("results", Token::kw_results) .Case("successors", Token::kw_successors) .Case("type", Token::kw_type) + .Case("type_ref", Token::kw_type_ref) .Default(Token::identifier); return Token(kind, str); } @@ -1994,8 +2067,9 @@ LogicalResult parseSuccessorsDirective(std::unique_ptr &element, llvm::SMLoc loc, bool isTopLevel); LogicalResult parseTypeDirective(std::unique_ptr &element, Token tok, - bool isTopLevel); - LogicalResult parseTypeDirectiveOperand(std::unique_ptr &element); + bool isTopLevel, bool isTypeRef = false); + LogicalResult parseTypeDirectiveOperand(std::unique_ptr &element, + bool isTypeRef = false); //===--------------------------------------------------------------------===// // Lexer Utilities @@ -2440,6 +2514,8 @@ return parseResultsDirective(element, dirTok.getLoc(), isTopLevel); case Token::kw_successors: return parseSuccessorsDirective(element, dirTok.getLoc(), isTopLevel); + case Token::kw_type_ref: + return parseTypeDirective(element, dirTok, isTopLevel, /*isTypeRef=*/true); case Token::kw_type: return parseTypeDirective(element, dirTok, isTopLevel); @@ -2505,7 +2581,10 @@ return ::mlir::success(); }; for (auto &ele : elements) { - if (auto *typeEle = dyn_cast(ele.get())) { + if (auto *typeEle = dyn_cast(ele.get())) { + if (failed(checkTypeOperand(typeEle->getOperand()))) + return failure(); + } else if (auto *typeEle = dyn_cast(ele.get())) { if (failed(checkTypeOperand(typeEle->getOperand()))) return ::mlir::failure(); } else if (auto *typeEle = dyn_cast(ele.get())) { @@ -2565,7 +2644,7 @@ // Literals, custom directives, and type directives may be used, // but they can't anchor the group. .Case([&](Element *) { + OptionalElement, TypeRefDirective, TypeDirective>([&](Element *) { if (isAnchor) return emitError(childLoc, "only variables can be used to anchor " "an optional group"); @@ -2628,6 +2707,13 @@ // After parsing all of the elements, ensure that all type directives refer // only to variables. 
   for (auto &ele : elements) {
+    if (auto *typeEle = dyn_cast(ele.get())) {
+      if (!isa(typeEle->getOperand())) {
+        return emitError(curLoc,
+                         "type_ref directives within a custom directive "
+                         "may only refer to variables");
+      }
+    }
     if (auto *typeEle = dyn_cast(ele.get())) {
       if (!isa(typeEle->getOperand())) {
         return emitError(curLoc, "type directives within a custom directive "
@@ -2649,8 +2735,8 @@
     return ::mlir::failure();

   // Verify that the element can be placed within a custom directive.
-  if (!isa(parameters.back().get())) {
+  if (!isa(parameters.back().get())) {
     return emitError(childLoc, "only variables and types may be used as "
                                "parameters to a custom directive");
   }
@@ -2727,22 +2813,26 @@
 LogicalResult
 FormatParser::parseTypeDirective(std::unique_ptr &element, Token tok,
-                                 bool isTopLevel) {
+                                 bool isTopLevel, bool isTypeRef) {
   llvm::SMLoc loc = tok.getLoc();
   if (!isTopLevel)
     return emitError(loc, "'type' is only valid as a top-level directive");

   std::unique_ptr operand;
   if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
-      failed(parseTypeDirectiveOperand(operand)) ||
+      failed(parseTypeDirectiveOperand(operand, isTypeRef)) ||
       failed(parseToken(Token::r_paren, "expected ')' after argument list")))
     return ::mlir::failure();
-  element = std::make_unique(std::move(operand));
+  if (isTypeRef)
+    element = std::make_unique(std::move(operand));
+  else
+    element = std::make_unique(std::move(operand));
   return ::mlir::success();
 }

 LogicalResult
-FormatParser::parseTypeDirectiveOperand(std::unique_ptr &element) {
+FormatParser::parseTypeDirectiveOperand(std::unique_ptr &element,
+                                        bool isTypeRef) {
   llvm::SMLoc loc = curToken.getLoc();
   if (failed(parseElement(element, /*isTopLevel=*/false)))
     return ::mlir::failure();
@@ -2752,23 +2842,36 @@
   if (auto *var = dyn_cast(element.get())) {
     unsigned opIdx = var->getVar() - op.operand_begin();
-    if (fmt.allOperandTypes || seenOperandTypes.test(opIdx))
+    if (!isTypeRef && (fmt.allOperandTypes ||
+                       seenOperandTypes.test(opIdx)))
       return emitError(loc, "'type' of '" + var->getVar()->name +
                                 "' is already bound");
+    if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
+      return emitError(loc, "'type_ref' of '" + var->getVar()->name +
+                                "' is not bound by a prior 'type' directive");
     seenOperandTypes.set(opIdx);
   } else if (auto *var = dyn_cast(element.get())) {
     unsigned resIdx = var->getVar() - op.result_begin();
-    if (fmt.allResultTypes || seenResultTypes.test(resIdx))
+    if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.test(resIdx)))
       return emitError(loc, "'type' of '" + var->getVar()->name +
                                 "' is already bound");
+    if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.test(resIdx)))
+      return emitError(loc, "'type_ref' of '" + var->getVar()->name +
+                                "' is not bound by a prior 'type' directive");
     seenResultTypes.set(resIdx);
   } else if (isa(&*element)) {
-    if (fmt.allOperandTypes || seenOperandTypes.any())
+    if (!isTypeRef && (fmt.allOperandTypes || seenOperandTypes.any()))
       return emitError(loc, "'operands' 'type' is already bound");
+    if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.all()))
+      return emitError(
+          loc,
+          "'operands' 'type_ref' is not bound by a prior 'type' directive");
     fmt.allOperandTypes = true;
   } else if (isa(&*element)) {
-    if (fmt.allResultTypes || seenResultTypes.any())
+    if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.any()))
       return emitError(loc, "'results' 'type' is already bound");
+    if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.all()))
+      return emitError(
+          loc, "'results' 'type_ref' is not bound by a prior 'type' directive");
     fmt.allResultTypes = true;
   } else {
     return emitError(loc, "invalid argument to 'type' directive");
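The hunks in `parseTypeDirectiveOperand` make `type` and `type_ref` mirror each other:

- `type` fails if a type is already bound.
- `type_ref` fails unless that type is already bound.
- On the aggregate `operands`/`results` forms, `type` rejects any overlap with earlier bindings (`.any()`), while `type_ref` requires every individual type to be bound first (`.all()`).

Below is a minimal standalone sketch of the operand side. It assumes a plain `std::vector<bool>` in place of `seenOperandTypes` and a `bool` in place of `fmt.allOperandTypes`. The class and method names are hypothetical, and the diagnostics are shortened (the patch also names the variable).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical model of the operand-side binding checks; result types would
// be tracked by a second, identical instance.
class OperandTypeBindings {
public:
  explicit OperandTypeBindings(std::size_t numOperands)
      : seen(numOperands, false) {}

  // Models `type($x)` / `type_ref($x)` on operand `idx`.
  // Returns an empty string on success, otherwise the diagnostic text.
  std::string bindOne(std::size_t idx, bool isTypeRef) {
    bool bound = allBound || seen[idx];
    if (!isTypeRef && bound)
      return "'type' is already bound";
    if (isTypeRef && !bound)
      return "'type_ref' is not bound by a prior 'type' directive";
    seen[idx] = true;
    return "";
  }

  // Models `type(operands)` / `type_ref(operands)`.
  std::string bindAll(bool isTypeRef) {
    auto isSet = [](bool b) { return b; };
    bool anySeen = std::any_of(seen.begin(), seen.end(), isSet);
    bool allSeen = std::all_of(seen.begin(), seen.end(), isSet);
    // `type(operands)` must not overlap any earlier binding...
    if (!isTypeRef && (allBound || anySeen))
      return "'operands' 'type' is already bound";
    // ...while `type_ref(operands)` needs every operand type bound already.
    if (isTypeRef && !(allBound || allSeen))
      return "'operands' 'type_ref' is not bound by a prior 'type' directive";
    allBound = true;
    return "";
  }

private:
  std::vector<bool> seen;
  bool allBound = false;
};
```

With this model, a format that writes `type_ref($x)` before any `type($x)` is rejected. The same `type_ref` succeeds once it appears after the binding `type` directive.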