diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h @@ -13,7 +13,7 @@ #ifndef MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_ #define MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_ -#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" +#include "Intrinsics.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/EDSC/Builders.h" #include "mlir/EDSC/Intrinsics.h" @@ -38,12 +38,12 @@ } /// A StructuredIndexed represents a captured value that can be indexed and -/// passed to the `makeLinalgGenericOp`. It allows writing intuitive index +/// passed to the `makeGenericLinalgOp`. It allows writing intuitive index /// expressions such as: /// /// ``` /// StructuredIndexed A(vA), B(vB), C(vC); -/// makeLinalgGenericOp({A({m, n}), B({k, n})}, {C({m, n})}, ... ); +/// makeGenericLinalgOp({A({m, n}), B({k, n})}, {C({m, n})}, ... ); /// ``` struct StructuredIndexed { StructuredIndexed(Value v) : value(v) {} @@ -68,7 +68,7 @@ inline void defaultRegionBuilder(ArrayRef args) {} -Operation *makeLinalgGenericOp( +Operation *makeGenericLinalgOp( ArrayRef iteratorTypes, ArrayRef inputs, ArrayRef outputs, function_ref)> regionBuilder = diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -19,9 +19,10 @@ let name = "linalg"; let description = [{ The `linalg` dialect groups together a set of types, operations and - transformations that are useful to implement a structured abstraction where - ops can lower to scalar load/store and operations or to more general library - calls. + transformations that are useful to implement a structured abstraction on + buffers and tensors. 
These abstractions facilitate transformations and + can lower to scalar load/store and other operations or to more general + library calls. The `linalg` dialect manipulates the following types and operations: @@ -67,12 +68,13 @@ A set of payload carrying operations that implement the [structured ops]( https://docs.google.com/presentation/d/1P-j1GrH6Q5gLBjao0afQ-GfvcAeF-QU4GXXeSy0eJ9I/edit#slide=id.p ) - abstraction on buffers. `linalg` has `2` generic operations `linalg.generic` - and `linalg.indexed_generic` for expressing custom operations. This is - subject to further evolution as transformations and analyses continue to be - developed. + abstraction on tensors and buffers. `linalg` has `2` generic operations + `linalg.generic` and `linalg.indexed_generic` for expressing custom + operations. + This is subject to further evolution as transformations and analyses + continue to be developed. - Additionally, `linalg` provides some common named operations: + Additionally, `linalg` provides some common named operations: * `linalg.copy`, * `linalg.fill`, diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -59,7 +59,8 @@ } def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>, - Arguments<(ins AnyStridedMemRef:$view, Variadic>:$indexings)>, + Arguments<(ins AnyStridedMemRef:$view, + Variadic>:$indexings)>, Results<(outs AnyStridedMemRef)> { let summary = "Produce a rank-reduced `subview` of a base `view`."; let description = [{ @@ -108,11 +109,11 @@ let extraClassDeclaration = [{ enum { FirstIndexingOperand = 1 }; - unsigned getRank() { return getViewType().getRank(); } - Type getElementType() { return getViewType().getElementType(); } - MemRefType getViewType() { return getType().cast(); } + unsigned getRank() { return getShapedType().getRank(); } + Type getElementType() { return
getShapedType().getElementType(); } + ShapedType getShapedType() { return getType().cast(); } unsigned getBaseViewRank() { return getBaseViewType().getRank(); } - MemRefType getBaseViewType() { return view()->getType().cast(); } + ShapedType getBaseViewType() { return view()->getType().cast();} // Get the underlying indexing at a given rank. Value indexing(unsigned rank) { return *(indexings().begin() + rank); } @@ -131,7 +132,7 @@ def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>, Arguments<(ins AnyStridedMemRef:$view, AffineMapAttr:$permutation)>, Results<(outs AnyStridedMemRef)> { - let summary = "transpose operation produces a new strided memref (metadata-only)"; + let summary = "`transpose` produces a new strided memref (metadata-only)"; let description = [{ The `linalg.transpose` op produces a strided memref whose sizes and strides are a permutation of the original `view`. This is a pure metadata @@ -151,14 +152,14 @@ let verifier = [{ if (!permutation().isPermutation()) return emitOpError("expected a permutation map"); - if (permutation().getNumDims() != getViewType().getRank()) + if (permutation().getNumDims() != getShapedType().getRank()) return emitOpError("expected a permutation map of same rank as the view"); return success(); }]; let extraClassDeclaration = [{ static StringRef getPermutationAttrName() { return "permutation"; } - MemRefType getViewType() { return view()->getType().cast(); } + ShapedType getShapedType() { return view()->getType().cast(); } }]; } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -29,7 +29,7 @@ class NOutputs : NativeOpTrait<"linalg::NOutputs<" # !cast(args_out) # ">::Impl"> {} -def ViewTraits : NativeOpTrait<"linalg::ViewTraits">; +def StructuredOpTraits : 
NativeOpTrait<"linalg::StructuredOpTraits">; // The linalg 'LinalgStructuredInterface' provides access to the 'LinalgOp' // interface. @@ -89,23 +89,32 @@ "Value ", "getOutput", (ins "unsigned":$i) >, InterfaceMethod<[{ - Query the index of the given input value, or `None` if the value is not - an input. + Return the index of the given input value `v`, or `None` if the value is + not an input. }], - "llvm::Optional", "getIndexOfInput", (ins "Value ":$view) + "llvm::Optional", "getIndexOfInput", (ins "Value ":$v) >, InterfaceMethod<[{ Query the index of the given view value, or `None` if the value is not - an view. + a view. }], "llvm::Optional", "getIndexOfOutput", (ins "Value ":$view) >, InterfaceMethod<[{ - Query the type of the input view at the given index. - }], "MemRefType", "getInputViewType", (ins "unsigned":$i)>, + Query the shaped type of the input at the given index. + }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>, InterfaceMethod<[{ Query the type of the output view at the given index. - }], "MemRefType", "getOutputViewType", (ins "unsigned":$i)>, + }], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>, + InterfaceMethod<[{ + Query whether the op has no results and only MemRef inputs and outputs. + }], "bool", "hasBufferSemantics">, + InterfaceMethod<[{ + Query the subset of input operands that are of ranked tensor type. + }], "SmallVector", "getInputTensorTypes">, + InterfaceMethod<[{ + Query the subset of output operands that are of ranked tensor type. + }], "SmallVector", "getOutputTensorTypes">, StaticInterfaceMethod<[{ Create an operation of the current type with the given location, @@ -147,7 +156,7 @@ // depending on the specific Linalg op.
class LinalgStructuredBase_Op props> : Op { + !listconcat(props, [StructuredOpTraits, LinalgStructuredInterface])> { let parser = [{ return parseLinalgStructuredOp(parser, result); }]; let printer = [{ printLinalgStructuredOp(p, *this); }]; } @@ -340,7 +349,7 @@ ArrayAttr iterator_types() { // Outer parallel loops are always the number of output dimensions; i.e. // [ b, xs, q] in the TF notation above. - unsigned nPar = getOutputViewType(0).getRank(); + unsigned nPar = getOutputShapedType(0).getRank(); unsigned nRed = getNumInputFeatureDimensions(); // Window loops are a special kind of reduction that is never tiled or // parallelized across; i.e. [zs] in the TF notation above whose number @@ -374,8 +383,17 @@ let verifier = [{ return ::verify(*this); }]; } +def LinalgOperand: Type< + Or<[AnyRankedTensor.predicate, AnyStridedMemRef.predicate]>>; + +class LinalgOperandOfRank: Type< + And<[ + LinalgOperand.predicate, + CPred<"$_self.cast().getRank() == " # rank>] + >>; + class GenericOpBase : LinalgStructuredBase_Op { - let arguments = (ins Variadic:$views, + let arguments = (ins Variadic:$views, I64Attr:$args_in, I64Attr:$args_out, AffineMapArrayAttr:$indexing_maps, @@ -383,6 +401,7 @@ OptionalAttr:$doc, OptionalAttr:$fun, OptionalAttr:$library_call); + let results = (outs Variadic:$output_tensors); let regions = (region AnyRegion:$region); let extraClassDeclaration = [{ SmallVector linalgTraitAttrNames() { @@ -511,6 +530,28 @@ } } ``` + + To allow progressive lowering from the value world (a.k.a tensor values) to + the buffer world (a.k.a memref values), a `linalg.generic` op accepts + mixing input and output ranked tensor values with input and output memrefs. + + ```mlir + %1 = linalg.generic #trait_attribute %A, %B, %C {other-attributes} : + tensor, + memref, + tensor + -> (tensor) + ``` + + In this case, the number of return values must match the number of output + tensor arguments. The semantics is that the `linalg.generic` op + produces (i.e. 
allocates and fills) its return values. + Tensor values must be legalized by a buffer allocation pass before most + transformations can be applied. In particular, transformations that create + control flow around linalg.generic operations are not expected to mix with + tensors because SSA values do not escape naturally. Still, transformations + and rewrites that take advantage of tensor SSA values are expected to be + useful and will be added in the near future. }]; let verifier = [{ return ::verify(*this); }]; } @@ -555,9 +596,11 @@ Example: Defining a #matmul_trait attribute in MLIR can be done as follows: ```mlir - func @fma(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32) + func @fma(%offset_m: index, %offset_n: index, %offset_k: index, + %a: f32, %b: f32, %c: f32) -> f32 { + "some_optional_condition"(%offset_m, %offset_n, %offset_k) %d = mulf %a, %b: f32 %e = addf %c, %d: f32 return %e: f32 @@ -587,7 +630,7 @@ This may lower to either: ```mlir - call @linalg_matmul(%A, %B, %C) : + call @linalg_matmul(%offset_m, %offset_n, %offset_k, %A, %B, %C) : (memref, memref, memref) @@ -609,6 +652,29 @@ } } ``` + + To allow progressive lowering from the value world (a.k.a tensor values) to + the buffer world (a.k.a memref values), a `linalg.indexed_generic` op + accepts mixing input and output ranked tensor values with input and output + memrefs. + + ```mlir + %1 = linalg.indexed_generic #trait_attribute %A, %B, %C {other-attributes} + : tensor, + memref, + tensor + -> (tensor) + ``` + + In this case, the number of return values must match the number of output + tensor arguments. The semantics is that the `linalg.indexed_generic` op + produces (i.e. allocates and fills) its return values. + Tensor values must be legalized by a buffer allocation pass before most + transformations can be applied. 
In particular, transformations that create + control flow around linalg.indexed_generic ops are not expected to mix with + tensors because SSA values do not escape naturally. Still, transformations + and rewrites that take advantage of tensor SSA values are expected to be + useful and will be added in the near future. }]; let verifier = [{ return ::verify(*this); }]; } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -20,7 +20,7 @@ namespace linalg { /// This class provides the API for ops that are known to have a specified -/// number of inputs, all passed as operands. This is used as a trait like this: +/// number of inputs, all passed as operands. Use as a trait as follows: /// /// class DotOp : public Op::Impl> { /// @@ -34,7 +34,7 @@ }; /// This class provides the API for ops that are known to have a specified -/// number of inputs, all passed as operands. This is used as a trait like this: +/// number of outputs, all passed as operands. Use as a trait as follows: /// /// class DotOp : public Op::Impl> { /// @@ -47,78 +47,101 @@ }; }; -/// This class provides the API for ops that are known to operate on views. This -/// trait must be used in conjunction with an op definition or a trait that -/// provides the methods `getNumInputs` and `getNumOutputs`. This is used as a -/// trait like this: +/// This class provides the API for structured ops that are known to operate on +/// buffers or tensors. This trait must be used in conjunction with an op +/// definition or a trait that provides the methods `getNumInputs` and +/// `getNumOutputs`.
Use as a trait as follows: /// -/// class DotOp : public Op { +/// class DotOp : public Op { /// template -class ViewTraits : public OpTrait::TraitBase { +class StructuredOpTraits + : public OpTrait::TraitBase { private: - /// Return the number of input views. For internal use only. + /// Return the number of inputs. For internal use only. unsigned nInputs() { return cast(this->getOperation()).getNumInputs(); } - /// Return the number of input views. For internal use only. + /// Return the number of outputs. For internal use only. unsigned nOutputs() { return cast(this->getOperation()).getNumOutputs(); } public: - /// Return the `i`-th input view. + /// Return the `i`-th input value. Value getInput(unsigned i) { assert(i < nInputs()); return this->getOperation()->getOperand(i); } - /// Return the index of `view` in the list of input views if found, llvm::None + /// Return the index of `value` in the list of inputs if found, llvm::None /// otherwise. - Optional getIndexOfInput(Value view) { - auto it = llvm::find(getInputs(), view); + Optional getIndexOfInput(Value value) { + auto it = llvm::find(getInputs(), value); if (it != getInputs().end()) return it - getInputs().begin(); return llvm::None; } - /// Return the `i`-th input view type. - MemRefType getInputViewType(unsigned i) { - return getInput(i)->getType().template cast(); + /// Return the `i`-th input shaped type. + ShapedType getInputShapedType(unsigned i) { + return getInput(i)->getType().template cast(); } - /// Return the range over input views. + /// Return the range over inputs. Operation::operand_range getInputs() { auto range = this->getOperation()->getOperands(); return {range.begin(), range.begin() + nInputs()}; } - /// Return the `i`-th output view. + /// Return the `i`-th output.
Value getOutput(unsigned i) { return this->getOperation()->getOperand(nInputs() + i); } - /// Return the index of `view` in the list of output views if found, + /// Return the index of `value` in the list of output values if found, /// llvm::None otherwise. - Optional getIndexOfOutput(Value view) { - auto it = llvm::find(getOutputs(), view); + Optional getIndexOfOutput(Value value) { + auto it = llvm::find(getOutputs(), value); if (it != getOutputs().end()) return it - getOutputs().begin(); return llvm::None; } - /// Return the `i`-th output view type. - MemRefType getOutputViewType(unsigned i) { - return getOutput(i)->getType().template cast(); - } - /// Return the range over output views. + /// Return the `i`-th output shaped type. + ShapedType getOutputShapedType(unsigned i) { + return getOutput(i)->getType().template cast(); + } + /// Query whether the op has no results and only MemRef inputs and outputs. + bool hasBufferSemantics() { + return this->getOperation()->getNumResults() == 0 && + llvm::all_of(getInputsAndOutputs(), + [](Value v) { return v.getType().isa(); }); + } + /// Query the subset of input operands that are of ranked tensor type. + SmallVector getInputTensorTypes() { + SmallVector res; + for (auto type : getInputs().getTypes()) + if (auto t = type.template dyn_cast()) + res.push_back(t); + return res; + } + /// Query the subset of output operands that are of ranked tensor type. + SmallVector getOutputTensorTypes() { + SmallVector res; + for (auto type : getOutputs().getTypes()) + if (auto t = type.template dyn_cast()) + res.push_back(t); + return res; + } + /// Return the range over outputs. Operation::operand_range getOutputs() { auto range = this->getOperation()->getOperands(); return {range.begin() + nInputs(), range.begin() + getNumInputsAndOutputs()}; } - /// Return the number of input and output views. + /// Return the number of inputs and outputs. unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); } - /// Return the `i`-th view type.
- MemRefType getViewType(unsigned i) { - return (i < nInputs()) ? getInputViewType(i) - : getOutputViewType(i - nInputs()); + /// Return the shaped type of the `i`-th operand (inputs, then outputs). + ShapedType getShapedType(unsigned i) { + return (i < nInputs()) ? getInputShapedType(i) + : getOutputShapedType(i - nInputs()); } - /// Return the range over input and output views. + /// Return the range over inputs and outputs. Operation::operand_range getInputsAndOutputs() { auto range = this->getOperation()->getOperands(); return {range.begin(), range.begin() + getNumInputsAndOutputs()}; @@ -143,8 +166,8 @@ cast(this->getOperation()).iterator_types()); } static LogicalResult verifyTrait(Operation *op) { - auto nViews = cast(op).getNumInputsAndOutputs(); - if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nViews))) + auto nOperands = cast(op).getNumInputsAndOutputs(); + if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands))) return failure(); return success(); } diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -18,24 +18,24 @@ include "mlir/Dialect/AffineOps/AffineOps.td" def HasNoLinalgTransformMarker : CPred<[{ - !$0.getAttrOfType(LinalgTransforms::kLinalgTransformMarker) + !op.getAttrOfType(LinalgTransforms::kLinalgTransformMarker) }]>; class HasLinalgTransformMarker : CPred<[{ - $0.getAttrOfType( + op.getAttrOfType( LinalgTransforms::kLinalgTransformMarker) && - $0.getAttrOfType( + op.getAttrOfType( LinalgTransforms::kLinalgTransformMarker).getValue() == "}] # str # [{"}]>; class IsProducedByOpOfType : - CPred<"isProducedByOpOfType<" # str # ">($0, $1)">; + CPred<"isProducedByOpOfType<" # str # ">(op, $0)">; class AffineMapDomainHasDim : CPred<[{ - $0.getAttrOfType(getIndexingMapsAttrName()).getValue()[0].
+ op.getAttrOfType(getIndexingMapsAttrName()).getValue()[0]. cast().getValue().getNumDims() ==}] # n # [{}]>; class HasOperandsOfType: CPred<[{ - llvm::any_of($0.getOperands(), + llvm::any_of(op.getOperands(), [](Value v) { return dyn_cast_or_null<}] # type # [{>(v->getDefiningOp()); }) @@ -50,7 +50,7 @@ // patterns. class TileAndFuseLinalgOp< list sizes, list operandIndices, string value> : NativeCodeCall< - "if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, $0, {" # + "if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, op, {" # StrJoinInt.result # "}, {" # StrJoinInt.result # "}," # " \"" # value # "\")))" # " return matchFailure();">; @@ -67,7 +67,7 @@ // of elements as `sizes`. class TileLinalgOp sizes, string value, list permutation=[]> : NativeCodeCall< - "if (failed(tileLinalgOpAndSetMarker($_builder, $0, {" # + "if (failed(tileLinalgOpAndSetMarker($_builder, op, {" # StrJoinInt.result # "}, \"" # value # "\", {" # StrJoinInt.result # "})))" # " return matchFailure();">; @@ -76,33 +76,37 @@ // Linalg to loop patterns. //===----------------------------------------------------------------------===// class LinalgOpToLoops : NativeCodeCall< - "if (failed(linalgOpToLoops<" # OpType # ">($_builder, $0))) " # + "if (failed(linalgOpToLoops<" # OpType # ">($_builder, op))) " # " return matchFailure();">; class LinalgOpToAffineLoops : NativeCodeCall< - "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, $0))) " # + "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, op))) " # " return matchFailure();">; //===----------------------------------------------------------------------===// -// Linalg to vector contraction patterns. +// Linalg to vector patterns precondition and DRR. 
//===----------------------------------------------------------------------===// -class LinalgOpToVectorContraction : NativeCodeCall< - "if (failed(vectorizeGenericOp($_builder, $0))) " # - " return matchFailure();">; +def PreconditionVectorizeGenericLinalgOp : CPred< + "succeeded(vectorizeGenericLinalgOpPrecondition(op))">; +def VectorizeGenericLinalgOp : NativeCodeCall< + "vectorizeGenericLinalgOp($_builder, op)">; //===----------------------------------------------------------------------===// -// Linalg generic permutation patterns. +// Linalg generic permutation patterns precondition and DRR. //===----------------------------------------------------------------------===// +class PreconditionPermuteGenericLinalgOp permutation> : CPred< + "succeeded(permuteGenericLinalgOpPrecondition(op, {" # + StrJoinInt.result # "}))">; class PermuteGenericLinalgOp permutation, string value> : NativeCodeCall< - "if (failed(permuteGenericLinalgOp($_builder, $0, {" # - StrJoinInt.result # "}, \"" # value # "\"))) " # - " return matchFailure();">; + "permuteGenericLinalgOp($_builder, op, {" # StrJoinInt.result # + "}, \"" # value # "\")">; //===----------------------------------------------------------------------===// -// Linalg promote subview operands. +// Linalg promote subview operands precondition and DRR. 
//===----------------------------------------------------------------------===// -class LinalgOpPromoteSubviews : NativeCodeCall< - "if (failed(linalgOpPromoteSubviews($_builder, $0))) " # - " return matchFailure();">; +def PreconditionPromoteSubviewsLinalgOp : CPred< + "succeeded(promoteSubviewsLinalgOpPrecondition(op))">; +def PromoteSubviewsLinalgOp : NativeCodeCall< + "promoteSubviewsLinalgOp($_builder, op)">; #endif // LINALG_TRANSFORMS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h @@ -79,16 +79,24 @@ LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op); /// Rewrite a linalg.generic into a suitable vector.contraction op. -LogicalResult vectorizeGenericOp(PatternRewriter &rewriter, Operation *op); +LogicalResult vectorizeGenericLinalgOpPrecondition(Operation *op); +SmallVector vectorizeGenericLinalgOp(PatternRewriter &rewriter, + Operation *op); /// Emits a `generic` or `indexed_generic` operation with the `indexing_maps` /// and `iterator_types` permutated according to `permutation`. -LogicalResult permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op, - ArrayRef permutation, - StringRef linalgMarker); - -/// Promote std.subviews feeding linalg operations -LogicalResult linalgOpPromoteSubviews(PatternRewriter &rewriter, Operation *op); +LogicalResult +permuteGenericLinalgOpPrecondition(Operation *op, + ArrayRef permutation); +SmallVector permuteGenericLinalgOp(PatternRewriter &rewriter, + Operation *op, + ArrayRef permutation, + StringRef linalgMarker); + +/// Promote std.subviews feeding linalg operations. 
+LogicalResult promoteSubviewsLinalgOpPrecondition(Operation *op); +SmallVector promoteSubviewsLinalgOp(PatternRewriter &rewriter, + Operation *op); } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h b/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h --- a/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h @@ -64,8 +64,9 @@ // // This will be extended in the future to support more advanced use cases than // simple pointwise ops. -Value unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op, - ArrayRef targetShape); +SmallVector +unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op, + ArrayRef targetShape); } // namespace vector } // namespace mlir diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -186,7 +186,8 @@ auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)) .cast(); - BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType())); + BaseViewConversionHelper desc( + lowering.convertType(sliceOp.getShapedType())); // TODO(ntv): extract sizes and emit asserts. SmallVector strides(memRefType.getRank()); @@ -215,7 +216,7 @@ desc.setOffset(baseOffset); // Corner case, no sizes or strides: early return the descriptor. - if (sliceOp.getViewType().getRank() == 0) + if (sliceOp.getShapedType().getRank() == 0) return rewriter.replaceOp(op, {desc}), matchSuccess(); Value zero = @@ -279,7 +280,7 @@ return rewriter.replaceOp(op, {baseDesc}), matchSuccess(); BaseViewConversionHelper desc( - lowering.convertType(transposeOp.getViewType())); + lowering.convertType(transposeOp.getShapedType())); // Copy the base and aligned pointers from the old descriptor to the new // one. 
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -32,7 +32,7 @@ } } -Operation *mlir::edsc::makeLinalgGenericOp( +Operation *mlir::edsc::makeGenericLinalgOp( ArrayRef iteratorTypes, ArrayRef inputs, ArrayRef outputs, function_ref)> regionBuilder, @@ -68,6 +68,7 @@ edsc::ScopedContext::getBuilder() .create( edsc::ScopedContext::getLocation(), + ArrayRef{}, // TODO(ntv): support tensors values, IntegerAttr::get(IntegerType::get(64, ctx), nInputs), IntegerAttr::get(IntegerType::get(64, ctx), nOutputs), @@ -118,7 +119,7 @@ ValueHandle a(args[0]); linalg_yield(unaryOp(a)); }; - return makeLinalgGenericOp(iterTypes, {I}, {O}, fun); + return makeGenericLinalgOp(iterTypes, {I}, {O}, fun); } Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I, @@ -141,7 +142,7 @@ ValueHandle a(args[0]), b(args[1]); linalg_yield(binaryOp(a, b)); }; - return makeLinalgGenericOp(iterTypes, {I1, I2}, {O}, fun); + return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun); } Operation *mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1, @@ -170,7 +171,7 @@ AffineExpr m, n, k; bindDims(ScopedContext::getContext(), m, n, k); StructuredIndexed A(vA), B(vB), C(vC); - return makeLinalgGenericOp( + return makeGenericLinalgOp( {IterType::Parallel, IterType::Parallel, IterType::Reduction}, {A({m, k}), B({k, n})}, {C({m, n})}, @@ -198,7 +199,7 @@ unsigned numDims = c.cast().getPosition() + 1; StructuredIndexed I(vI), W(vW), O(vO); // clang-format off - return makeLinalgGenericOp( + return makeGenericLinalgOp( {par, par, par, par, red, red, red}, { I({b, // Roundtrip to flattened form to serve as canonicalization and ensure @@ -231,7 +232,7 @@ bindDims(ctx, b, dm, c, h, w, kh, kw); unsigned numDims = kw.cast().getPosition() + 1; StructuredIndexed I(vI), W(vW), O(vO); - return makeLinalgGenericOp( + return makeGenericLinalgOp( {par, 
par, par, par, par, red, red}, { I({b, // Roundtrip to flattened form to serve as canonicalization and ensure diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -61,6 +61,10 @@ p.printRegion(op.region()); p.printOptionalAttrDict(op.getAttrs(), attrNames); p << ": " << op.getOperandTypes(); + + auto outputTensorTypes = op.getResultTypes(); + if (!outputTensorTypes.empty()) + p << " -> " << outputTensorTypes; } static void print(OpAsmPrinter &p, GenericOp op) { printGenericOp(p, op); } @@ -92,6 +96,13 @@ if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColonTypeList(operandTypes)) return failure(); + // Generic ops may specify that a subset of its outputs are tensors. Such + // outputs are specified in the result type. + SmallVector tensorResultTypes; + if (parser.parseOptionalArrowTypeList(tensorResultTypes)) + return failure(); + if (!tensorResultTypes.empty()) + result.addTypes(tensorResultTypes); return parser.resolveOperands(operandsInfo, operandTypes, parser.getCurrentLocation(), result.operands); } @@ -103,13 +114,13 @@ auto nViews = op.getNumInputsAndOutputs(); auto nInputViews = op.getNumInputs(); if (block.getNumArguments() != nViews) - return op.emitError( + return op.emitOpError( "op expected number of block arguments to match number of views"); for (unsigned i = 0; i < nViews; ++i) { - auto viewType = op.getViewType(i); + auto viewType = op.getShapedType(i); if (viewType.getElementType() != block.getArgument(i)->getType()) - return op.emitError("op expected block argument ") + return op.emitOpError("expected block argument ") << i << " of the same type as elemental type of " << ((i < nInputViews) ? 
"input " : "output ") << "view: " << viewType; @@ -122,22 +133,22 @@ auto nLoops = op.getNumLoops(); auto nViews = op.getNumInputsAndOutputs(); if (block.getNumArguments() != nViews + nLoops) - return op.emitError( + return op.emitOpError( "op expected number of block arguments to match number of views + " "number of loops"); for (unsigned i = 0; i < nLoops; ++i) { if (!block.getArgument(i)->getType().isIndex()) - return op.emitError("op expected block argument ") + return op.emitOpError("expected block argument ") << i << " to be of IndexType"; } for (unsigned i = 0; i < nViews; ++i) { unsigned memrefArgIndex = i + nLoops; - auto viewType = op.getViewType(i); + auto viewType = op.getShapedType(i); if (viewType.getElementType() != block.getArgument(memrefArgIndex)->getType()) - return op.emitError("op expected block argument ") + return op.emitOpError("expected block argument ") << memrefArgIndex << " of the same type as elemental type of " << ((i < nInputViews) ? "input " : "output ") << "view: " << viewType; @@ -152,24 +163,24 @@ auto nViews = op.getNumInputsAndOutputs(); auto nInputViews = op.getNumInputs(); if (funType.getNumInputs() != nViews) - return op.emitError("op expected fun arguments to match number of views"); + return op.emitOpError("expected fun arguments to match number of views"); if (funType.getNumResults() != op.getNumOutputs()) - return op.emitError( + return op.emitOpError( "op expected fun results to match number of output views"); for (auto en : llvm::enumerate(op.indexing_maps())) { auto idx = en.index(); - auto view = (idx < nInputViews) ? op.getInputViewType(idx) - : op.getOutputViewType(idx - nInputViews); + auto view = (idx < nInputViews) ? 
op.getInputShapedType(idx) + : op.getOutputShapedType(idx - nInputViews); if (funType.getInput(idx) != view.getElementType()) - return op.emitError("op expected fun argument ") + return op.emitOpError("expected fun argument ") << idx << " of the same type as elemental type " << view.getElementType() << " of view " << idx; if (idx >= nInputViews) { auto resultIdx = idx - nInputViews; if (funType.getResult(resultIdx) != view.getElementType()) - return op.emitError("op expected fun result ") + return op.emitOpError("expected fun result ") << resultIdx << " of the same type as elemental type " << view.getElementType() << " of view " << idx; } @@ -184,30 +195,30 @@ auto nOutputs = op.getNumOutputs(); auto nViews = op.getNumInputsAndOutputs(); if (funType.getNumInputs() != nViews + nLoops) - return op.emitError( + return op.emitOpError( "op expected fun arguments to match number of views + number of loops"); if (funType.getNumResults() != nOutputs) - return op.emitError( + return op.emitOpError( "op expected fun results to match number of output views"); for (unsigned i = 0; i < nLoops; ++i) { if (!funType.getInput(i).isIndex()) - return op.emitError("op expected fun argument ") + return op.emitOpError("expected fun argument ") << i << " to be of IndexType"; } for (auto en : llvm::enumerate(op.indexing_maps())) { auto idx = en.index(); auto funIdx = nLoops + idx; - auto view = (idx < nInputViews) ? op.getInputViewType(idx) - : op.getOutputViewType(idx - nInputViews); + auto view = (idx < nInputViews) ? 
op.getInputShapedType(idx) + : op.getOutputShapedType(idx - nInputViews); if (funType.getInput(funIdx) != view.getElementType()) - return op.emitError("op expected fun argument ") + return op.emitOpError("expected fun argument ") << funIdx << " of the same type as elemental type " << view.getElementType() << " of view " << idx; if (idx >= nInputViews) { auto resultIdx = idx - nInputViews; if (funType.getResult(resultIdx) != view.getElementType()) - return op.emitError("op expected fun result ") + return op.emitOpError("expected fun result ") << resultIdx << " of the same type as elemental type " << view.getElementType() << " of view " << idx; } @@ -221,19 +232,19 @@ auto nLoops = op.getNumLoops(); auto nViews = op.getNumInputsAndOutputs(); if (nViews != llvm::size(op.views())) - return op.emitError("op expected exactly ") << nViews << " view operands"; + return op.emitOpError("expected exactly ") << nViews << " view operands"; auto &region = op.region(); auto funOp = op.getFunction(); auto funType = funOp ? funOp.getType() : FunctionType(); if (!region.empty()) { if (region.getBlocks().size() != 1) - return op.emitError("op expected region with 1 block"); + return op.emitOpError("expected region with 1 block"); if (failed(verifyBlockArgs(op, region.getBlocks().front()))) return failure(); } else { if (!funOp || !funOp.getType()) - return op.emitError( - "op expected fun attribute to refer to a defined symbol"); + return op.emitOpError( + "expected fun attribute to refer to a defined symbol"); if (failed(verifyFuncArgs(op, funType))) return failure(); @@ -245,35 +256,52 @@ auto idx = en.index(); auto m = en.value().template cast().getValue(); indexingMaps.push_back(m); // Save reference to map for further checks. - auto view = (idx < nInputViews) ? op.getInputViewType(idx) - : op.getOutputViewType(idx - nInputViews); + auto view = (idx < nInputViews) ?
op.getInputShapedType(idx) + : op.getOutputShapedType(idx - nInputViews); if (m.getNumSymbols() != 0) - return op.emitError("op expected indexing_map #") + return op.emitOpError("expected indexing_map #") << idx << " to have no symbols"; if (m.getNumDims() != nLoops) - return op.emitError("op expected indexing_map #") + return op.emitOpError("expected indexing_map #") << idx << " to have " << nLoops << " dim(s) to match the number of loops"; if (m.getNumResults() == 1 && view.getRank() == 0) { auto cst = m.getResult(0).template dyn_cast(); if (!cst || cst.getValue() != 0) - return op.emitError("op expected indexing_map #") + return op.emitOpError("expected indexing_map #") << idx << " to be 0 to match 0-D view: " << view; } if (m.getNumResults() != view.getRank()) - return op.emitError("op expected indexing_map #") + return op.emitOpError("expected indexing_map #") << idx << " results to match view rank: " << view; } auto concatMap = concatAffineMaps(indexingMaps); auto aggregateMap = inversePermutation(concatMap); if (!aggregateMap) - return op.emitError("op expected the concatenation of maps in indexing_map " - "to be invertible"); + return op.emitOpError( + "expected the concatenation of maps in indexing_map " + "to be invertible"); + + auto outputTensorTypes = op.getOutputTensorTypes(); + if (outputTensorTypes.size() != op.getNumResults()) + return op.emitOpError("expected #output tensor operands (") + << outputTensorTypes.size() << ") to match #results (" + << op.getNumResults() << ")"; + + unsigned index = 0; + for (auto it : llvm::zip(op.getResultTypes(), outputTensorTypes)) { + auto resTy = std::get<0>(it); + auto outOpTy = std::get<1>(it); + if (resTy != outOpTy) + return op.emitOpError("result #") + << index << " must be " << outOpTy << ", but got " << resTy; + ++index; + } return success(); } @@ -465,13 +493,13 @@ // The operand number and types must match the view element types.
auto nOutputViews = genericOp.getNumOutputs(); if (op.getNumOperands() != nOutputViews) - return op.emitOpError("op expected ") + return op.emitOpError("expected ") << nOutputViews << " operand to match enclosing linalg.generic op"; for (unsigned i = 0; i != nOutputViews; ++i) { - auto elementType = genericOp.getOutputViewType(i).getElementType(); + auto elementType = genericOp.getOutputShapedType(i).getElementType(); if (op.getOperand(i)->getType() != elementType) - return op.emitError("type of return operand ") + return op.emitOpError("type of return operand ") << i << " (" << op.getOperand(i)->getType() << ") doesn't match view element type (" << elementType << ")"; } @@ -481,7 +509,7 @@ static LogicalResult verify(YieldOp op) { auto *parentOp = op.getParentOp(); if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty()) - return op.emitOpError("op expected single non-empty parent region"); + return op.emitOpError("expected single non-empty parent region"); auto genericOp = dyn_cast(parentOp); if (genericOp) @@ -536,7 +564,7 @@ } static LogicalResult verify(FillOp op) { - auto viewType = op.getOutputViewType(0); + auto viewType = op.getOutputShapedType(0); auto fillType = op.value()->getType(); if (viewType.getElementType() != fillType) return op.emitOpError("expects fill type to match view elemental type"); @@ -544,8 +572,8 @@ } static LogicalResult verify(CopyOp op) { - auto outputViewType = op.getOutputViewType(0); - auto inputViewType = op.getInputViewType(0); + auto outputViewType = op.getOutputShapedType(0); + auto inputViewType = op.getInputShapedType(0); if (inputViewType.getElementType() != outputViewType.getElementType()) return op.emitOpError("expects views of the same type"); if (inputViewType.getRank() != outputViewType.getRank()) @@ -675,8 +703,8 @@ // I(input_perm(ivs)) -> O(output_perm(ivs)) auto maybeInputMap = copyOp.inputPermutation(); auto maybeOutputMap = copyOp.outputPermutation(); - unsigned inputRank = 
copyOp.getInputViewType(0).getRank(); - unsigned outputRank = copyOp.getOutputViewType(0).getRank(); + unsigned inputRank = copyOp.getInputShapedType(0).getRank(); + unsigned outputRank = copyOp.getOutputShapedType(0).getRank(); return SmallVector{ extractOrIdentityMap(maybeInputMap, inputRank, context), extractOrIdentityMap(maybeOutputMap, outputRank, context)}; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -114,6 +114,9 @@ return false; } +//============================================================================// +// Precondition and transformation for vectorization of Linalg generic ops. +//============================================================================// static bool hasMultiplyAddBody(linalg::GenericOp op) { auto &r = op.region(); if (r.empty()) @@ -153,12 +156,8 @@ genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp); } -LogicalResult mlir::linalg::vectorizeGenericOp(PatternRewriter &rewriter, - Operation *op) { - LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE - "]: Rewrite linalg op as vector.contract: " - << *op << ":\n"); - +LogicalResult +mlir::linalg::vectorizeGenericLinalgOpPrecondition(Operation *op) { // TODO(ntv): This is in fact much more general than just vectorization for // matmul ops. 
auto genericOp = dyn_cast(op); @@ -175,7 +174,20 @@ if (!llvm::all_of(genericOp.getInputsAndOutputs(), isStaticMemRefWithIdentityLayout)) return failure(); + return success(); +} + +SmallVector +mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter, + Operation *op) { + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE + "]: Rewrite linalg op as vector.contract: " + << *op << ":\n"); + + assert(succeeded(vectorizeGenericLinalgOpPrecondition(op)) && + "DRR failure case must be a precondition"); + auto genericOp = cast(op); edsc::ScopedContext scope(rewriter, op->getLoc()); using edsc::intrinsics::std_load; using edsc::intrinsics::std_store; @@ -188,16 +200,35 @@ auto vRes = vector_contract(vA, vB, vC, genericOp.indexing_maps(), genericOp.iterator_types()); std_store(vRes, vectorMemRefC); + return {}; +} + +//============================================================================// +// Precondition and transformation for permutation of Linalg generic ops. +//============================================================================// +LogicalResult mlir::linalg::permuteGenericLinalgOpPrecondition( + Operation *op, ArrayRef permutation) { + if (permutation.empty()) + return failure(); + // Transformation applies to generic ops only. + if (!isa(op) && !isa(op)) + return failure(); + LinalgOp linOp = cast(op); + // Transformation applies to buffers only. + if (!linOp.hasBufferSemantics()) + return failure(); return success(); } -LogicalResult +SmallVector mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op, ArrayRef permutation, StringRef linalgMarker) { - // If permutation is empty, there is nothing to be done. 
- if (permutation.empty()) - return failure(); + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Permute dims for linalg op: " << *op + << ":\n"); + + assert(succeeded(permuteGenericLinalgOpPrecondition(op, permutation)) && + "DRR failure case must be a precondition"); auto linOp = cast(op); auto permutationMap = inversePermutation( @@ -220,19 +251,41 @@ op->setAttr(LinalgTransforms::kLinalgTransformMarker, rewriter.getStringAttr(linalgMarker)); linOp.clone(rewriter, linOp.getLoc(), op->getOperands()); - return success(); + return {}; } -LogicalResult mlir::linalg::linalgOpPromoteSubviews(PatternRewriter &rewriter, - Operation *op) { +//============================================================================// +// Precondition and transformation for Linalg subview promotion. +//============================================================================// +LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition(Operation *op) { LinalgOp linOp = dyn_cast(op); + // Transformation applies to buffers only. 
+ if (!linOp || !linOp.hasBufferSemantics()) + return failure(); + if (llvm::none_of(linOp.getInputsAndOutputs(), [](Value v) { + return isa_and_nonnull(v.getDefiningOp()); + })) + return failure(); + return success(); +} + +SmallVector +mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter, + Operation *op) { + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: " + << *op << ":\n"); + + assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) && + "DRR failure case must be a precondition"); + + LinalgOp linOp = cast(op); SetVector subViews; for (auto it : linOp.getInputsAndOutputs()) if (auto sv = dyn_cast_or_null(it->getDefiningOp())) subViews.insert(sv); if (!subViews.empty()) { - auto resOp = promoteSubViewOperands(rewriter, linOp, subViews); - return success(resOp); + promoteSubViewOperands(rewriter, linOp, subViews); + return {}; } - return failure(); + llvm_unreachable("DRR failure case must be a precondition"); } diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -477,7 +477,7 @@ } // Entry point for unrolling declarative pattern rewrites. -Value mlir::vector::unrollSingleResultOpMatchingType( +SmallVector mlir::vector::unrollSingleResultOpMatchingType( PatternRewriter &builder, Operation *op, ArrayRef targetShape) { assert(op->getNumResults() == 1 && "Expected single result operation"); @@ -497,8 +497,8 @@ } // Unroll 'op' with 'iterationBounds' to 'targetShape'. 
- return unrollSingleResultStructuredOp(op, iterationBounds, vectors, - resultIndex, targetShape, builder); + return SmallVector{unrollSingleResultStructuredOp( + op, iterationBounds, vectors, resultIndex, targetShape, builder)}; } // Generates slices of 'vectorType' according to 'sizes' and 'strides, and diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -423,6 +423,51 @@ // ----- +func @generic_result_tensor_type(%arg0: memref(off + i)>) { + // expected-error @+1 {{op result #0 must be ranked tensor of any type values, but got 'f32'}} + %0 = linalg.generic { + args_in = 0, + args_out = 1, + indexing_maps = [ (i) -> (i) ], + iterator_types = ["parallel"] + } %arg0 { + ^bb(%i: f32): + linalg.yield %i: f32 + }: memref(off + i)> -> f32 +} + +// ----- + +func @generic_result_tensor_count(%arg0: memref(off + i)>) { + // expected-error @+1 {{op expected #output tensor operands (0) to match #results (1)}} + %0 = linalg.generic { + args_in = 0, + args_out = 1, + indexing_maps = [ (i) -> (i) ], + iterator_types = ["parallel"] + } %arg0 { + ^bb(%i: f32): + linalg.yield %i: f32 + }: memref(off + i)> -> tensor +} + +// ----- + +func @generic_result_tensor_type(%arg0: tensor) { + // expected-error @+1 {{op result #0 must be 'tensor', but got 'tensor'}} + %0 = linalg.generic { + args_in = 0, + args_out = 1, + indexing_maps = [ (i) -> (i) ], + iterator_types = ["parallel"] + } %arg0 { + ^bb(%i: f32): + linalg.yield %i: f32 + }: tensor -> tensor +} + +// ----- + func @generic_fun_result_0_element_type(%arg0: memref) { // expected-error @+1 {{'linalg.dot' op expected 3 or more operands}} linalg.dot(%arg0, %arg0): memref, memref diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -139,6 +139,29 @@ // CHECK-LABEL: func 
@generic // CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref, #[[strided2D]]>, memref +func @generic_with_tensor_input(%arg0: tensor>, %arg1: memref) { + linalg.generic #trait %arg0, %arg1 {foo = 1} : tensor>, memref + return +} +// CHECK-LABEL: func @generic_with_tensor_input +// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: tensor>, memref + +func @generic_with_tensor_output(%arg0: memref, offset: ?, strides: [?, 1]>, %arg1: tensor) -> (tensor) { + %0 = linalg.generic #trait %arg0, %arg1 {foo = 1} : memref, offset: ?, strides: [?, 1]>, tensor -> tensor + return %0 : tensor +} +// CHECK-LABEL: func @generic_with_tensor_output +// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref, #[[strided2D]]>, tensor -> tensor +// CHECK: return {{.*}} : tensor + +func @generic_with_tensor_input_and_output(%arg0: tensor>, %arg1: tensor) -> (tensor) { + %0 = linalg.generic #trait %arg0, %arg1 {foo = 1} : tensor>, tensor -> tensor + return %0 : tensor +} +// CHECK-LABEL: func @generic_with_tensor_input_and_output +// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: tensor>, tensor -> tensor +// CHECK: return {{.*}} : tensor + #trait2 = { args_in = 1, args_out = 1, diff --git 
a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td --- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td +++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td @@ -19,11 +19,11 @@ //===----------------------------------------------------------------------===// // Test Linalg fusion patterns. //===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$consumer $A, $B, $C), - (TileAndFuseLinalgOp<[100, 150], [0], "L1"> $consumer), +def : Pat<(MatmulOp:$op $A, $_, $_), + (TileAndFuseLinalgOp<[100, 150], [0], "L1">), [ - (Constraint $consumer), - (Constraint> $consumer, $A), + (Constraint), + (Constraint> $A), ], // In the buffer world there is no use-def chains or dags so benefits // cannot be computed automatically from the length of the matched @@ -36,91 +36,103 @@ //===----------------------------------------------------------------------===// // Linalg tiling patterns. 
//===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[2000, 3000, 4000], "L3"> $op), +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[2000, 3000, 4000], "L3">), [(Constraint]>> $op)]>; -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[200, 300, 400], "L2"> $op), - [(Constraint> $op)]>; -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[20, 30, 40], "L1"> $op), - [(Constraint> $op)]>; -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[2, 3, 4], "REG"> $op), - [(Constraint> $op)]>; + HasLinalgTransformMarker<"MEM">]>>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[200, 300, 400], "L2">), + [(Constraint>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[20, 30, 40], "L1">), + [(Constraint>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[2, 3, 4], "REG">), + [(Constraint>)]>; -def : Pattern<(MatvecOp:$op $A, $b, $c), - [(TileLinalgOp<[5, 6], "L1"> $op)], - [(Constraint $op)]>; +def : Pattern<(MatvecOp:$op $_, $_, $_), + [(TileLinalgOp<[5, 6], "L1">)], + [(Constraint)]>; -def : Pattern<(DotOp:$op $a, $b, $c), - [(TileLinalgOp<[8000], "L1"> $op)], +def : Pattern<(DotOp:$op $_, $_, $_), + [(TileLinalgOp<[8000], "L1">)], [(Constraint, HasLinalgTransformMarker<"L3">, - HasLinalgTransformMarker<"L2">]>> $op)]>; -def : Pattern<(DotOp:$op $a, $b, $c), - [(TileLinalgOp<[8], "REG"> $op)], - [(Constraint> $op)]>; + HasLinalgTransformMarker<"L2">]>>)]>; +def : Pattern<(DotOp:$op $_, $_, $_), + [(TileLinalgOp<[8], "REG">)], + [(Constraint>)]>; //===----------------------------------------------------------------------===// // Linalg tiling and permutation patterns. 
//===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]> $op), - [(Constraint> $op)]>; -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]> $op), - [(Constraint> $op)]>; -def : Pat<(MatmulOp:$op $A, $B, $C), - (TileLinalgOp<[20, 30, 40], "REG__with_perm__"> $op), - [(Constraint> $op)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[2000, 3000, 4000], "L2__with_perm__", [1,2,0]>), + [(Constraint>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[200, 300, 400], "L1__with_perm__", [1,0,2]>), + [(Constraint>)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (TileLinalgOp<[20, 30, 40], "REG__with_perm__">), + [(Constraint>)]>; -def : Pattern<(MatvecOp:$op $A, $b, $c), - [(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]> $op)], - [(Constraint> $op)]>; +def : Pattern<(MatvecOp:$op $_, $_, $_), + [(TileLinalgOp<[5, 6], "L1__with_perm__", [1,0]>)], + [(Constraint>)]>; -def : Pattern<(DotOp:$op $a, $b, $c), - [(TileLinalgOp<[8000], "L1__with_perm__"> $op)], - [(Constraint> $op)]>; -def : Pattern<(DotOp:$op $a, $b, $c), - [(TileLinalgOp<[8], "REG__with_perm__"> $op)], - [(Constraint> $op)]>; +def : Pattern<(DotOp:$op $_, $_, $_), + [(TileLinalgOp<[8000], "L1__with_perm__">)], + [(Constraint>)]>; +def : Pattern<(DotOp:$op $_, $_, $_), + [(TileLinalgOp<[8], "REG__with_perm__">)], + [(Constraint>)]>; //===----------------------------------------------------------------------===// // Linalg to loops patterns. //===----------------------------------------------------------------------===// -def : Pattern<(DotOp:$op $a, $b, $c), - [(LinalgOpToLoops<"DotOp"> $op)], - [(Constraint> $op)]>; +def : Pattern<(DotOp:$op $_, $_, $_), + [(LinalgOpToLoops<"DotOp">)], + [(Constraint>)]>; //===----------------------------------------------------------------------===// // Linalg to vector contraction patterns. 
//===----------------------------------------------------------------------===// -def : Pattern<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8), - [(LinalgOpToVectorContraction<"GenericOp"> $op)], - [(Constraint> $op)]>; +def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_), + [(VectorizeGenericLinalgOp)], + [(Constraint, + PreconditionVectorizeGenericLinalgOp + ]>>)]>; //===----------------------------------------------------------------------===// // Linalg generic permutation patterns. //===----------------------------------------------------------------------===// -def : Pat<(GenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8), - (PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op), - [(Constraint]>> $op)]>; +def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_), + (PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op), + [(Constraint, + PreconditionPermuteGenericLinalgOp<[1, 2, 0]> + ]>>)]>; -def : Pat<(IndexedGenericOp:$op $_1, $_2, $_3, $_4, $_5, $_6, $_7, $_8), - (PermuteGenericLinalgOp<[1,2,0],"PERMUTED"> $op), - [(Constraint]>> $op)]>; +def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_), + (PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op), + [(Constraint, + PreconditionPermuteGenericLinalgOp<[1, 2, 0]> + ]>>)]>; //===----------------------------------------------------------------------===// // Linalg subview operands promotion. 
//===----------------------------------------------------------------------===// -def : Pat<(MatmulOp:$op $A, $B, $C), - (LinalgOpPromoteSubviews<"MatmulOp"> $op), - [(Constraint> $op), - (Constraint> $op)]>; +def : Pat<(MatmulOp:$op $_, $_, $_), + (PromoteSubviewsLinalgOp), + [(Constraint, + HasLinalgTransformMarker<"_promote_views_">]>> + )]>; #endif // TEST_LINALG_TRANSFORMS_PATTERNS diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -573,8 +573,14 @@ auto val = handleResultPattern(resultTree, offsets[i], 0); os.indent(4) << "\n"; // Resolve each symbol for all range use so that we can loop over them. + // We need an explicit cast to `SmallVector` to capture the cases where + // `{0}` resolves to an `Operation::result_range` as well as cases that + // are not iterable (e.g. vector that gets wrapped in additional braces by + // RewriterGen). os << symbolInfoMap.getAllRangeUse( - val, " for (auto v : {0}) {{ tblgen_repl_values.push_back(v); }", + val, + " for (auto v : SmallVector{ {0} }) {{ " + "tblgen_repl_values.push_back(v); }", "\n"); } os.indent(4) << "\n";