diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -19,9 +19,10 @@ let name = "linalg";
   let description = [{
     The `linalg` dialect groups together a set of types, operations and
-    transformations that are useful to implement a structured abstraction where
-    ops can lower to scalar load/store and operations or to more general library
-    calls.
+    transformations that are useful to implement a structured abstraction on
+    buffers and tensors. These abstractions are useful for transformations and
+    can lower to scalar load/store and other operations or to more general
+    library calls.

     The `linalg` dialect manipulates the following types and operations:

@@ -67,12 +68,13 @@ A set of payload carrying operations that implement the [structured ops](
     https://docs.google.com/presentation/d/1P-j1GrH6Q5gLBjao0afQ-GfvcAeF-QU4GXXeSy0eJ9I/edit#slide=id.p
     )
-    abstraction on buffers. `linalg` has `2` generic operations `linalg.generic`
-    and `linalg.indexed_generic` for expressing custom operations. This is
-    subject to further evolution as transformations and analyses continue to be
-    developed.
+    abstraction on tensors and buffers. `linalg` has `2` generic operations
+    `linalg.generic` and `linalg.indexed_generic` for expressing custom
+    operations.
+    This is subject to further evolution as transformations and analyses
+    continue to be developed.

-    Additionally, `linalg` provides some common named operations:
+    Additionally, `linalg` provides some commonly named operations:

       * `linalg.copy`,
       * `linalg.fill`,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -59,7 +59,8 @@ }

 def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
-    Arguments<(ins AnyStridedMemRef:$view, Variadic<AnyTypeOf<[Range, Index]>>:$indexings)>,
+    Arguments<(ins AnyStridedMemRef:$view,
+                   Variadic<AnyTypeOf<[Range, Index]>>:$indexings)>,
     Results<(outs AnyStridedMemRef)> {
   let summary = "Produce a rank-reduced `subview` of a base `view`.";
   let description = [{
@@ -108,11 +109,11 @@ let extraClassDeclaration = [{
     enum { FirstIndexingOperand = 1 };

-    unsigned getRank() { return getViewType().getRank(); }
-    Type getElementType() { return getViewType().getElementType(); }
-    MemRefType getViewType() { return getType().cast<MemRefType>(); }
+    unsigned getRank() { return getShapedType().getRank(); }
+    Type getElementType() { return getShapedType().getElementType(); }
+    ShapedType getShapedType() { return getType().cast<ShapedType>(); }
     unsigned getBaseViewRank() { return getBaseViewType().getRank(); }
-    MemRefType getBaseViewType() { return view()->getType().cast<MemRefType>(); }
+    ShapedType getBaseViewType() { return view()->getType().cast<ShapedType>();}

     // Get the underlying indexing at a given rank.
     Value indexing(unsigned rank) { return *(indexings().begin() + rank); }
@@ -131,7 +132,7 @@ def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
     Arguments<(ins AnyStridedMemRef:$view, AffineMapAttr:$permutation)>,
     Results<(outs AnyStridedMemRef)> {
-  let summary = "transpose operation produces a new strided memref (metadata-only)";
+  let summary = "`transpose` produces a new strided memref (metadata-only)";
   let description = [{
     The `linalg.transpose` op produces a strided memref whose sizes and strides
     are a permutation of the original `view`. This is a pure metadata
@@ -151,14 +152,14 @@ let verifier = [{
     if (!permutation().isPermutation())
       return emitOpError("expected a permutation map");
-    if (permutation().getNumDims() != getViewType().getRank())
+    if (permutation().getNumDims() != getShapedType().getRank())
       return emitOpError("expected a permutation map of same rank as the view");
     return success();
   }];

   let extraClassDeclaration = [{
     static StringRef getPermutationAttrName() { return "permutation"; }
-    MemRefType getViewType() { return view()->getType().cast<MemRefType>(); }
+    ShapedType getShapedType() { return view()->getType().cast<ShapedType>(); }
   }];
 }
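The `getViewType`/`MemRefType` to `getShapedType`/`ShapedType` migrations above (and throughout the rest of the patch) rely on `ShapedType` being the common base class of `MemRefType` and `RankedTensorType`. A minimal sketch of the uniform query this enables; the helper and its name are illustrative only, not part of the patch:

```cpp
#include "mlir/IR/StandardTypes.h"

using namespace mlir;

// Both MemRefType (buffers) and RankedTensorType (tensors) derive from
// ShapedType, so rank and element-type queries need no buffer-vs-tensor
// branching. Hypothetical helper, for illustration only.
static bool sameElementTypeAndRank(Type a, Type b) {
  auto sa = a.dyn_cast<ShapedType>();
  auto sb = b.dyn_cast<ShapedType>();
  if (!sa || !sb || !sa.hasRank() || !sb.hasRank())
    return false;
  return sa.getElementType() == sb.getElementType() &&
         sa.getRank() == sb.getRank();
}
```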
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -89,23 +89,32 @@ "Value ", "getOutput", (ins "unsigned":$i)
     >,
     InterfaceMethod<[{
-      Query the index of the given input value, or `None` if the value is not
-      an input.
+      Return the index of the given input value `v`, or `None` if the value is
+      not an input.
     }],
-      "llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value ":$view)
+      "llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value ":$v)
     >,
     InterfaceMethod<[{
       Query the index of the given view value, or `None` if the value is not
-      an view.
+      a view.
     }],
       "llvm::Optional<unsigned>", "getIndexOfOutput", (ins "Value ":$view)
     >,
     InterfaceMethod<[{
-      Query the type of the input view at the given index.
-    }], "MemRefType", "getInputViewType", (ins "unsigned":$i)>,
+      Query the type of the input shape at the given index.
+    }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>,
     InterfaceMethod<[{
       Query the type of the output view at the given index.
-    }], "MemRefType", "getOutputViewType", (ins "unsigned":$i)>,
+    }], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>,
+    InterfaceMethod<[{
+      Query whether the op has only MemRef input and outputs.
+    }], "bool", "hasBufferSemantics">,
+    InterfaceMethod<[{
+      Query the subset of input operands that are of ranked tensor type.
+    }], "SmallVector<RankedTensorType, 4>", "getInputTensorTypes">,
+    InterfaceMethod<[{
+      Query the subset of output operands that are of ranked tensor type.
+    }], "SmallVector<RankedTensorType, 4>", "getOutputTensorTypes">,

     StaticInterfaceMethod<[{
       Create an operation of the current type with the given location,
@@ -340,7 +349,7 @@ ArrayAttr iterator_types() {
       // Outer parallel loops are always the number of output dimensions; i.e.
       // [ b, xs, q] in the TF notation above.
-      unsigned nPar = getOutputViewType(0).getRank();
+      unsigned nPar = getOutputShapedType(0).getRank();
       unsigned nRed = getNumInputFeatureDimensions();
       // Window loops are a special kind of reduction that is never tiled or
       // parallelized across; i.e. [zs] in the TF notation above whose number
@@ -374,8 +383,17 @@ let verifier = [{ return ::verify(*this); }];
 }

+def LinalgOperand: Type<
+  Or<[AnyRankedTensor.predicate, AnyStridedMemRef.predicate]>>;
+
+class LinalgOperandOfRank<int rank>: Type<
+  And<[
+    LinalgOperand.predicate,
+    CPred<"$_self.cast<ShapedType>().getRank() == " # rank>]
+  >>;
+
 class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
-  let arguments = (ins Variadic<AnyStridedMemRef>:$views,
+  let arguments = (ins Variadic<LinalgOperand>:$views,
                    I64Attr:$args_in,
                    I64Attr:$args_out,
                    AffineMapArrayAttr:$indexing_maps,
@@ -383,6 +401,7 @@ OptionalAttr<StrAttr>:$doc,
                    OptionalAttr<FlatSymbolRefAttr>:$fun,
                    OptionalAttr<StrAttr>:$library_call);
+  let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
   let regions = (region AnyRegion:$region);
   let extraClassDeclaration = [{
     SmallVector<StringRef, 8> linalgTraitAttrNames() {
@@ -511,6 +530,28 @@ }
       }
     }
     ```
+
+    To allow progressive lowering from the value world (a.k.a. tensor values)
+    to the buffer world (a.k.a. memref values), a `linalg.generic` op accepts
+    mixing input and output ranked tensor values with input and output
+    memrefs.
+
+    ```mlir
+    %1 = linalg.generic #trait_attribute %A, %B, %C {other-attributes} :
+      tensor<?x?xf32>,
+      memref<?x?xf32, stride_specification>,
+      tensor<?x?xf32>
+      -> (tensor<?x?xf32>)
+    ```
+
+    In this case, the number of return values must match the number of output
+    tensor arguments. The semantics are that the `linalg.generic` op
+    produces (i.e. allocates and fills) its return values.
+    Tensor values must be legalized by a buffer allocation pass before most
+    transformations can be applied. In particular, transformations that create
+    control flow around linalg.generic operations are not expected to mix with
+    tensors because SSA values do not escape naturally. Still, transformations
+    and rewrites that take advantage of tensor SSA values are expected to be
+    useful and will be added in the near future.
   }];

   let verifier = [{ return ::verify(*this); }];
 }

@@ -555,9 +596,11 @@ Example: Defining a #matmul_trait attribute in MLIR can be done as follows:
     ```mlir
-      func @fma(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32)
+      func @fma(%offset_m: index, %offset_n: index, %offset_k: index,
+                %a: f32, %b: f32, %c: f32)
         -> f32
       {
+        "some_optional_condition"(%offset_m, %offset_n, %offset_k)
         %d = mulf %a, %b: f32
         %e = addf %c, %d: f32
         return %e: f32
@@ -587,7 +630,7 @@ This may lower to either:
     ```mlir
-      call @linalg_matmul(%A, %B, %C) :
+      call @linalg_matmul(%offset_m, %offset_n, %offset_k, %A, %B, %C) :
        (memref<?x?xf32, stride_specification>,
         memref<?x?xf32, stride_specification>,
         memref<?x?xf32, stride_specification>)
@@ -609,6 +652,29 @@ }
       }
     }
     ```
+
+    To allow progressive lowering from the value world (a.k.a. tensor values)
+    to the buffer world (a.k.a. memref values), a `linalg.indexed_generic` op
+    accepts mixing input and output ranked tensor values with input and
+    output memrefs.
+
+    ```mlir
+    %1 = linalg.indexed_generic #trait_attribute %A, %B, %C {other-attributes}
+      : tensor<?x?xf32>,
+        memref<?x?xf32, stride_specification>,
+        tensor<?x?xf32>
+      -> (tensor<?x?xf32>)
+    ```
+
+    In this case, the number of return values must match the number of output
+    tensor arguments. The semantics are that the `linalg.indexed_generic` op
+    produces (i.e. allocates and fills) its return values.
+    Tensor values must be legalized by a buffer allocation pass before most
+    transformations can be applied. In particular, transformations that create
+    control flow around linalg.indexed_generic operations are not expected to
+    mix with tensors because SSA values do not escape naturally. Still,
+    transformations and rewrites that take advantage of tensor SSA values are
+    expected to be useful and will be added in the near future.
   }];

   let verifier = [{ return ::verify(*this); }];
 }
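The result-count contract stated in both op descriptions above reduces to a simple structural check. A hedged C++ restatement follows; the helper name and shape are ours, and the authoritative version is the verifier extension in LinalgOps.cpp further down:

```cpp
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Every output operand of ranked tensor type must be mirrored, in order and
// in type, by one op result; memref (buffer) outputs contribute no results.
static bool resultsMatchOutputTensors(
    ArrayRef<Type> resultTypes,
    ArrayRef<RankedTensorType> outputTensorTypes) {
  if (resultTypes.size() != outputTensorTypes.size())
    return false;
  for (auto pair : llvm::zip(resultTypes, outputTensorTypes))
    if (std::get<0>(pair) != std::get<1>(pair))
      return false;
  return true;
}
```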
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -20,7 +20,7 @@ namespace linalg {

 /// This class provides the API for ops that are known to have a specified
-/// number of inputs, all passed as operands. This is used as a trait like this:
+/// number of inputs, all passed as operands. Use as a trait as follows:
 ///
 ///   class DotOp : public Op<DotOp, OpTrait::NInputs<2>::Impl> {
 ///
@@ -34,7 +34,7 @@ };

 /// This class provides the API for ops that are known to have a specified
-/// number of inputs, all passed as operands. This is used as a trait like this:
+/// number of outputs, all passed as operands. Use as a trait as follows:
 ///
 ///   class DotOp : public Op<DotOp, OpTrait::NOutputs<1>::Impl> {
 ///
@@ -47,79 +47,101 @@ };
 };

-/// This class provides the API for ops that are known to operate on views. This
-/// trait must be used in conjunction with an op definition or a trait that
-/// provides the methods `getNumInputs` and `getNumOutputs`. This is used as a
-/// trait like this:
+/// This class provides the API for structured ops that are known to operate on
+/// buffers or tensors. This trait must be used in conjunction with an op
+/// definition or a trait that provides the methods `getNumInputs` and
+/// `getNumOutputs`. Use as a trait as follows:
 ///
 ///   class DotOp : public Op<DotOp, OpTrait::StructuredOpTraits> {
 ///
 template <typename ConcreteType>
 class StructuredOpTraits
     : public OpTrait::TraitBase<ConcreteType, StructuredOpTraits> {
 private:
-  /// Return the number of input views. For internal use only.
+  /// Return the number of inputs. For internal use only.
   unsigned nInputs() {
     return cast<ConcreteType>(this->getOperation()).getNumInputs();
   }
-  /// Return the number of input views. For internal use only.
+  /// Return the number of outputs. For internal use only.
   unsigned nOutputs() {
     return cast<ConcreteType>(this->getOperation()).getNumOutputs();
   }

 public:
-  /// Return the `i`-th input view.
+  /// Return the `i`-th input value.
   Value getInput(unsigned i) {
     assert(i < nInputs());
     return this->getOperation()->getOperand(i);
   }
-  /// Return the index of `view` in the list of input views if found, llvm::None
+  /// Return the index of `value` in the list of inputs if found, llvm::None
   /// otherwise.
-  Optional<unsigned> getIndexOfInput(Value view) {
-    auto it = llvm::find(getInputs(), view);
+  Optional<unsigned> getIndexOfInput(Value value) {
+    auto it = llvm::find(getInputs(), value);
     if (it != getInputs().end())
       return it - getInputs().begin();
     return llvm::None;
   }
-  /// Return the `i`-th input view type.
-  MemRefType getInputViewType(unsigned i) {
-    return getInput(i)->getType().template cast<MemRefType>();
+  /// Return the `i`-th input buffer type.
+  ShapedType getInputShapedType(unsigned i) {
+    return getInput(i)->getType().template cast<ShapedType>();
   }
-  /// Return the range over input views.
+  /// Return the range over inputs.
   Operation::operand_range getInputs() {
     auto range = this->getOperation()->getOperands();
     return {range.begin(), range.begin() + nInputs()};
   }
-  /// Return the `i`-th output view.
+  /// Return the `i`-th output.
   Value getOutput(unsigned i) {
     return this->getOperation()->getOperand(nInputs() + i);
   }
-  /// Return the index of `view` in the list of output views if found,
+  /// Return the index of `value` in the list of output values if found,
   /// llvm::None otherwise.
-  Optional<unsigned> getIndexOfOutput(Value view) {
-    auto it = llvm::find(getOutputs(), view);
+  Optional<unsigned> getIndexOfOutput(Value value) {
+    auto it = llvm::find(getOutputs(), value);
     if (it != getOutputs().end())
       return it - getOutputs().begin();
     return llvm::None;
   }
-  /// Return the `i`-th output view type.
-  MemRefType getOutputViewType(unsigned i) {
-    return getOutput(i)->getType().template cast<MemRefType>();
-  }
-  /// Return the range over output views.
+  /// Return the `i`-th output buffer type.
+  ShapedType getOutputShapedType(unsigned i) {
+    return getOutput(i)->getType().template cast<ShapedType>();
+  }
+  /// Query whether the op has only MemRef input and outputs.
+  bool hasBufferSemantics() {
+    return this->getOperation()->getNumResults() == 0 &&
+           llvm::all_of(getInputsAndOutputs(),
+                        [](Value v) { return v.getType().isa<MemRefType>(); });
+  }
+  /// Query the subset of input operands that are of ranked tensor type.
+  SmallVector<RankedTensorType, 4> getInputTensorTypes() {
+    SmallVector<RankedTensorType, 4> res;
+    for (auto type : getInputs().getTypes())
+      if (auto t = type.template dyn_cast<RankedTensorType>())
+        res.push_back(t);
+    return res;
+  }
+  /// Query the subset of output operands that are of ranked tensor type.
+  SmallVector<RankedTensorType, 4> getOutputTensorTypes() {
+    SmallVector<RankedTensorType, 4> res;
+    for (auto type : getOutputs().getTypes())
+      if (auto t = type.template dyn_cast<RankedTensorType>())
+        res.push_back(t);
+    return res;
+  }
+  /// Return the range over outputs.
   Operation::operand_range getOutputs() {
     auto range = this->getOperation()->getOperands();
     return {range.begin() + nInputs(),
             range.begin() + getNumInputsAndOutputs()};
   }
-  /// Return the number of input and output views.
+  /// Return the number of inputs and outputs.
   unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); }
-  /// Return the `i`-th view type.
-  MemRefType getViewType(unsigned i) {
-    return (i < nInputs()) ? getInputViewType(i)
-                           : getOutputViewType(i - nInputs());
+  /// Return the `i`-th buffer type.
+  ShapedType getShapedType(unsigned i) {
+    return (i < nInputs()) ? getInputShapedType(i)
+                           : getOutputShapedType(i - nInputs());
   }
-  /// Return the range over input and output views.
+  /// Return the range over inputs and outputs.
   Operation::operand_range getInputsAndOutputs() {
     auto range = this->getOperation()->getOperands();
     return {range.begin(), range.begin() + getNumInputsAndOutputs()};
@@ -144,8 +166,8 @@ cast<ConcreteType>(this->getOperation()).iterator_types());
   }
   static LogicalResult verifyTrait(Operation *op) {
-    auto nViews = cast<ConcreteType>(op).getNumInputsAndOutputs();
-    if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nViews)))
+    auto nOperands = cast<ConcreteType>(op).getNumInputsAndOutputs();
+    if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands)))
       return failure();
     return success();
   }
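The new `hasBufferSemantics()` query is the intended guard for buffer-only transformations. A minimal sketch of how a rewrite is expected to use it; the function name and structure are illustrative, not part of the patch:

```cpp
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;
using namespace mlir::linalg;

// Buffer-only rewrites bail out while any operand or result is still a
// tensor SSA value; a buffer allocation pass must run first.
static LogicalResult rewriteIfBufferBacked(LinalgOp linalgOp) {
  if (!linalgOp.hasBufferSemantics())
    return failure();
  // From here on, every input and output is guaranteed to be a MemRef.
  return success();
}
```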
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td
@@ -84,25 +84,29 @@ "  return matchFailure();">;

 //===----------------------------------------------------------------------===//
-// Linalg to vector contraction patterns.
+// Linalg to vector patterns precondition and DRR.
 //===----------------------------------------------------------------------===//
-class VectorizeGenericLinalgOp<string value> : NativeCodeCall<
-  "if (failed(vectorizeGenericLinalgOp($_builder, op))) " #
-  "  return matchFailure();">;
+def PreconditionVectorizeGenericLinalgOp : CPred<
+  "succeeded(vectorizeGenericLinalgOpPrecondition(op))">;
+def VectorizeGenericLinalgOp : NativeCodeCall<
+  "vectorizeGenericLinalgOp($_builder, op)">;

 //===----------------------------------------------------------------------===//
-// Linalg generic permutation patterns.
+// Linalg generic permutation patterns precondition and DRR.
 //===----------------------------------------------------------------------===//
+class PreconditionPermuteGenericLinalgOp<list<int> permutation> : CPred<
+  "succeeded(permuteGenericLinalgOpPrecondition(op, {" #
+  StrJoinInt<permutation>.result # "}))">;
 class PermuteGenericLinalgOp<list<int> permutation, string value> : NativeCodeCall<
-  "if (failed(permuteGenericLinalgOp($_builder, op, {" #
-  StrJoinInt<permutation>.result # "}, \"" # value # "\"))) " #
-  "  return matchFailure();">;
+  "permuteGenericLinalgOp($_builder, op, {" # StrJoinInt<permutation>.result #
+  "}, \"" # value # "\")">;

 //===----------------------------------------------------------------------===//
-// Linalg promote subview operands.
+// Linalg promote subview operands precondition and DRR.
 //===----------------------------------------------------------------------===//
-class PromoteSubviewsLinalgOp<string value> : NativeCodeCall<
-  "if (failed(promoteSubviewsLinalgOp($_builder, op))) " #
-  "  return matchFailure();">;
+def PreconditionPromoteSubviewsLinalgOp : CPred<
+  "succeeded(promoteSubviewsLinalgOpPrecondition(op))">;
+def PromoteSubviewsLinalgOp : NativeCodeCall<
+  "promoteSubviewsLinalgOp($_builder, op)">;

 #endif // LINALG_TRANSFORMS
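In C++ terms, the generated DRR pattern now behaves roughly as follows. This is a sketch under the signatures reconstructed below in LinalgTransforms.h; the real driver is RewriterGen output, not hand-written code:

```cpp
#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::linalg;

// The precondition is a side-effect-free constraint that may fail; the
// transformation itself asserts the precondition and can no longer fail.
static LogicalResult applyVectorizationIfLegal(PatternRewriter &rewriter,
                                               Operation *op) {
  if (failed(vectorizeGenericLinalgOpPrecondition(op)))
    return failure(); // Pattern does not match; the IR is left untouched.
  // Returns the values replacing the matched op.
  SmallVector<Value, 0> replacements = vectorizeGenericLinalgOp(rewriter, op);
  (void)replacements; // Empty here: vectorization rewrites buffers in place.
  return success();
}
```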
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h
@@ -79,17 +79,24 @@ LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op);

 /// Rewrite a linalg.generic into a suitable vector.contraction op.
-LogicalResult vectorizeGenericLinalgOp(PatternRewriter &rewriter,
-                                       Operation *op);
+LogicalResult vectorizeGenericLinalgOpPrecondition(Operation *op);
+SmallVector<Value, 0> vectorizeGenericLinalgOp(PatternRewriter &rewriter,
+                                               Operation *op);

 /// Emits a `generic` or `indexed_generic` operation with the `indexing_maps`
 /// and `iterator_types` permutated according to `permutation`.
-LogicalResult permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
-                                     ArrayRef<unsigned> permutation,
-                                     StringRef linalgMarker);
-
-/// Promote std.subviews feeding linalg operations
-LogicalResult promoteSubviewsLinalgOp(PatternRewriter &rewriter, Operation *op);
+LogicalResult
+permuteGenericLinalgOpPrecondition(Operation *op,
+                                   ArrayRef<unsigned> permutation);
+SmallVector<Value, 0> permuteGenericLinalgOp(PatternRewriter &rewriter,
+                                             Operation *op,
+                                             ArrayRef<unsigned> permutation,
+                                             StringRef linalgMarker);
+
+/// Promote std.subviews feeding linalg operations.
+LogicalResult promoteSubviewsLinalgOpPrecondition(Operation *op);
+SmallVector<Value, 0> promoteSubviewsLinalgOp(PatternRewriter &rewriter,
+                                              Operation *op);

 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h b/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h
--- a/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorTransforms.h
@@ -64,8 +64,9 @@ //
 // This will be extended in the future to support more advanced use cases than
 // simple pointwise ops.
-Value unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op,
-                                       ArrayRef<int64_t> targetShape);
+SmallVector<Value, 1>
+unrollSingleResultOpMatchingType(PatternRewriter &builder, Operation *op,
+                                 ArrayRef<int64_t> targetShape);

 } // namespace vector
 } // namespace mlir
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -186,7 +186,8 @@ auto int64Ty = lowering.convertType(rewriter.getIntegerType(64))
                        .cast<LLVM::LLVMType>();

-    BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType()));
+    BaseViewConversionHelper desc(
+        lowering.convertType(sliceOp.getShapedType()));

     // TODO(ntv): extract sizes and emit asserts.
     SmallVector<Value, 4> strides(memRefType.getRank());
@@ -215,7 +216,7 @@ desc.setOffset(baseOffset);

     // Corner case, no sizes or strides: early return the descriptor.
-    if (sliceOp.getViewType().getRank() == 0)
+    if (sliceOp.getShapedType().getRank() == 0)
       return rewriter.replaceOp(op, {desc}), matchSuccess();

     Value zero =
@@ -279,7 +280,7 @@ return rewriter.replaceOp(op, {baseDesc}), matchSuccess();

     BaseViewConversionHelper desc(
-        lowering.convertType(transposeOp.getViewType()));
+        lowering.convertType(transposeOp.getShapedType()));

     // Copy the base and aligned pointers from the old descriptor to the new
     // one.
diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
--- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
+++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -68,6 +68,7 @@ edsc::ScopedContext::getBuilder()
       .create<linalg::GenericOp>(
           edsc::ScopedContext::getLocation(),
+          ArrayRef<Type>{}, // TODO(ntv): support tensors
           values,
           IntegerAttr::get(IntegerType::get(64, ctx), nInputs),
           IntegerAttr::get(IntegerType::get(64, ctx), nOutputs),
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -61,6 +61,10 @@ p.printRegion(op.region());
   p.printOptionalAttrDict(op.getAttrs(), attrNames);
   p << ": " << op.getOperandTypes();
+
+  auto outputTensorTypes = op.getResultTypes();
+  if (!outputTensorTypes.empty())
+    p << " -> " << outputTensorTypes;
 }

 static void print(OpAsmPrinter &p, GenericOp op) { printGenericOp(p, op); }
@@ -92,6 +96,13 @@ if (parser.parseOptionalAttrDict(result.attributes) ||
       parser.parseColonTypeList(operandTypes))
     return failure();
+  // Generic ops may specify that a subset of its outputs are tensors. Such
+  // outputs are specified in the result type.
+  SmallVector<Type, 8> tensorResultTypes;
+  if (parser.parseOptionalArrowTypeList(tensorResultTypes))
+    return failure();
+  if (!tensorResultTypes.empty())
+    result.addTypes(tensorResultTypes);
   return parser.resolveOperands(operandsInfo, operandTypes,
                                 parser.getCurrentLocation(), result.operands);
 }
@@ -107,7 +118,7 @@ "expected number of block arguments to match number of views");

   for (unsigned i = 0; i < nViews; ++i) {
-    auto viewType = op.getViewType(i);
+    auto viewType = op.getShapedType(i);
     if (viewType.getElementType() != block.getArgument(i)->getType())
       return op.emitOpError("expected block argument ")
              << i << " of the same type as elemental type of "
@@ -134,7 +145,7 @@ for (unsigned i = 0; i < nViews; ++i) {
     unsigned memrefArgIndex = i + nLoops;
-    auto viewType = op.getViewType(i);
+    auto viewType = op.getShapedType(i);
     if (viewType.getElementType() !=
         block.getArgument(memrefArgIndex)->getType())
       return op.emitOpError("expected block argument ")
@@ -159,8 +170,8 @@ for (auto en : llvm::enumerate(op.indexing_maps())) {
     auto idx = en.index();
-    auto view = (idx < nInputViews) ? op.getInputViewType(idx)
-                                    : op.getOutputViewType(idx - nInputViews);
+    auto view = (idx < nInputViews) ? op.getInputShapedType(idx)
+                                    : op.getOutputShapedType(idx - nInputViews);
     if (funType.getInput(idx) != view.getElementType())
       return op.emitOpError("expected fun argument ")
              << idx << " of the same type as elemental type "
@@ -197,8 +208,8 @@ for (auto en : llvm::enumerate(op.indexing_maps())) {
     auto idx = en.index();
     auto funIdx = nLoops + idx;
-    auto view = (idx < nInputViews) ? op.getInputViewType(idx)
-                                    : op.getOutputViewType(idx - nInputViews);
+    auto view = (idx < nInputViews) ? op.getInputShapedType(idx)
+                                    : op.getOutputShapedType(idx - nInputViews);
     if (funType.getInput(funIdx) != view.getElementType())
       return op.emitOpError("expected fun argument ")
              << funIdx << " of the same type as elemental type "
@@ -245,8 +256,8 @@ auto idx = en.index();
     auto m = en.value().template cast<AffineMapAttr>().getValue();
     indexingMaps.push_back(m); // Save reference to map for further checks.
-    auto view = (idx < nInputViews) ? op.getInputViewType(idx)
-                                    : op.getOutputViewType(idx - nInputViews);
+    auto view = (idx < nInputViews) ? op.getInputShapedType(idx)
+                                    : op.getOutputShapedType(idx - nInputViews);

     if (m.getNumSymbols() != 0)
       return op.emitOpError("expected indexing_map #")
@@ -275,6 +286,22 @@ return op.emitOpError("expected the concatenation of maps in indexing_map "
                           "to be invertible");

+  auto outputTensorTypes = op.getOutputTensorTypes();
+  if (outputTensorTypes.size() != op.getNumResults())
+    return op.emitOpError("expected #output tensor operands (")
+           << outputTensorTypes.size() << ") to match #results ("
+           << op.getNumResults() << ")";
+
+  unsigned index = 0;
+  for (auto it : llvm::zip(op.getResultTypes(), outputTensorTypes)) {
+    auto resTy = std::get<0>(it);
+    auto outOpTy = std::get<1>(it);
+    if (resTy != outOpTy)
+      return op.emitOpError("result #")
+             << index << " must be " << outOpTy << ", but got " << resTy;
+    ++index;
+  }
+
   return success();
 }

@@ -465,11 +492,11 @@ // The operand number and types must match the view element types.
   auto nOutputViews = genericOp.getNumOutputs();
   if (op.getNumOperands() != nOutputViews)
-    return op.emitOpError("op expected ")
+    return op.emitOpError("expected ")
            << nOutputViews << " operand to match enclosing linalg.generic op";

   for (unsigned i = 0; i != nOutputViews; ++i) {
-    auto elementType = genericOp.getOutputViewType(i).getElementType();
+    auto elementType = genericOp.getOutputShapedType(i).getElementType();
     if (op.getOperand(i)->getType() != elementType)
       return op.emitOpError("type of return operand ")
              << i << " (" << op.getOperand(i)->getType()
@@ -481,7 +508,7 @@ static LogicalResult verify(YieldOp op) {
   auto *parentOp = op.getParentOp();
   if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
-    return op.emitOpError("op expected single non-empty parent region");
+    return op.emitOpError("expected single non-empty parent region");

   auto genericOp = dyn_cast<GenericOp>(parentOp);
   if (genericOp)
@@ -536,7 +563,7 @@ }

 static LogicalResult verify(FillOp op) {
-  auto viewType = op.getOutputViewType(0);
+  auto viewType = op.getOutputShapedType(0);
   auto fillType = op.value()->getType();
   if (viewType.getElementType() != fillType)
     return op.emitOpError("expects fill type to match view elemental type");
@@ -544,8 +571,8 @@ }

 static LogicalResult verify(CopyOp op) {
-  auto outputViewType = op.getOutputViewType(0);
-  auto inputViewType = op.getInputViewType(0);
+  auto outputViewType = op.getOutputShapedType(0);
+  auto inputViewType = op.getInputShapedType(0);
   if (inputViewType.getElementType() != outputViewType.getElementType())
     return op.emitOpError("expects views of the same type");
   if (inputViewType.getRank() != outputViewType.getRank())
@@ -675,8 +702,8 @@ // I(input_perm(ivs)) -> O(output_perm(ivs))
   auto maybeInputMap = copyOp.inputPermutation();
   auto maybeOutputMap = copyOp.outputPermutation();
-  unsigned inputRank = copyOp.getInputViewType(0).getRank();
-  unsigned outputRank = copyOp.getOutputViewType(0).getRank();
+  unsigned inputRank = copyOp.getInputShapedType(0).getRank();
+  unsigned outputRank = copyOp.getOutputShapedType(0).getRank();
   return SmallVector<AffineMap, 4>{
       extractOrIdentityMap(maybeInputMap, inputRank, context),
       extractOrIdentityMap(maybeOutputMap, outputRank, context)};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp
@@ -114,6 +114,9 @@ return false;
 }

+//============================================================================//
+// Precondition and transformation for vectorization of Linalg generic ops.
+//============================================================================//
 static bool hasMultiplyAddBody(linalg::GenericOp op) {
   auto &r = op.region();
   if (r.empty())
@@ -153,12 +156,8 @@ genericOp.indexing_maps() == maps && hasMultiplyAddBody(genericOp);
 }

-LogicalResult mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter,
-                                                     Operation *op) {
-  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
-                       "]: Rewrite linalg op as vector.contract: "
-                    << *op << ":\n");
-
+LogicalResult
+mlir::linalg::vectorizeGenericLinalgOpPrecondition(Operation *op) {
   // TODO(ntv): This is in fact much more general than just vectorization for
   // matmul ops.
   auto genericOp = dyn_cast<linalg::GenericOp>(op);
@@ -175,7 +174,20 @@ if (!llvm::all_of(genericOp.getInputsAndOutputs(),
                     isStaticMemRefWithIdentityLayout))
     return failure();
+  return success();
+}
+
+SmallVector<Value, 0>
+mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter,
+                                       Operation *op) {
+  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
+                       "]: Rewrite linalg op as vector.contract: "
+                    << *op << ":\n");
+
+  assert(succeeded(vectorizeGenericLinalgOpPrecondition(op)) &&
+         "DRR failure case must be a precondition");

+  auto genericOp = cast<linalg::GenericOp>(op);
   edsc::ScopedContext scope(rewriter, op->getLoc());
   using edsc::intrinsics::std_load;
   using edsc::intrinsics::std_store;
@@ -188,16 +200,35 @@ auto vRes = vector_contract(vA, vB, vC, genericOp.indexing_maps(),
                               genericOp.iterator_types());
   std_store(vRes, vectorMemRefC);
+  return {};
+}
+
+//============================================================================//
+// Precondition and transformation for permutation of Linalg generic ops.
+//============================================================================//
+LogicalResult mlir::linalg::permuteGenericLinalgOpPrecondition(
+    Operation *op, ArrayRef<unsigned> permutation) {
+  if (permutation.empty())
+    return failure();
+  // Transformation applies to generic ops only.
+  if (!isa<GenericOp>(op) && !isa<IndexedGenericOp>(op))
+    return failure();
+  LinalgOp linOp = cast<LinalgOp>(op);
+  // Transformation applies to buffers only.
+  if (!linOp.hasBufferSemantics())
+    return failure();
   return success();
 }

-LogicalResult
+SmallVector<Value, 0>
 mlir::linalg::permuteGenericLinalgOp(PatternRewriter &rewriter, Operation *op,
                                      ArrayRef<unsigned> permutation,
                                      StringRef linalgMarker) {
-  // If permutation is empty, there is nothing to be done.
-  if (permutation.empty())
-    return failure();
+  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Permute dims for linalg op: " << *op
+                    << ":\n");
+
+  assert(succeeded(permuteGenericLinalgOpPrecondition(op, permutation)) &&
+         "DRR failure case must be a precondition");

   auto linOp = cast<LinalgOp>(op);
   auto permutationMap = inversePermutation(
@@ -220,19 +251,41 @@ op->setAttr(LinalgTransforms::kLinalgTransformMarker,
               rewriter.getStringAttr(linalgMarker));
   linOp.clone(rewriter, linOp.getLoc(), op->getOperands());
-  return success();
+  return {};
 }

-LogicalResult mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
-                                                    Operation *op) {
+//============================================================================//
+// Precondition and transformation for Linalg subview promotion.
+//============================================================================//
+LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition(Operation *op) {
   LinalgOp linOp = dyn_cast<LinalgOp>(op);
+  // Transformation applies to buffers only.
+  if (!linOp || !linOp.hasBufferSemantics())
+    return failure();
+  if (llvm::none_of(linOp.getInputsAndOutputs(), [](Value v) {
+        return isa_and_nonnull<SubViewOp>(v.getDefiningOp());
+      }))
+    return failure();
+  return success();
+}
+
+SmallVector<Value, 0>
+mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
+                                      Operation *op) {
+  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: Promote subviews for linalg op: "
+                    << *op << ":\n");
+
+  assert(succeeded(promoteSubviewsLinalgOpPrecondition(op)) &&
+         "DRR failure case must be a precondition");
+
+  LinalgOp linOp = cast<LinalgOp>(op);
   SetVector<Value> subViews;
   for (auto it : linOp.getInputsAndOutputs())
     if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
       subViews.insert(sv);
   if (!subViews.empty()) {
-    auto resOp = promoteSubViewOperands(rewriter, linOp, subViews);
-    return success(resOp);
+    promoteSubViewOperands(rewriter, linOp, subViews);
+    return {};
   }
-  return failure();
+  llvm_unreachable("DRR failure case must be a precondition");
 }
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -462,7 +462,7 @@ }

 // Entry point for unrolling declarative pattern rewrites.
-Value mlir::vector::unrollSingleResultOpMatchingType(
+SmallVector<Value, 1> mlir::vector::unrollSingleResultOpMatchingType(
     PatternRewriter &builder, Operation *op, ArrayRef<int64_t> targetShape) {
   assert(op->getNumResults() == 1 && "Expected single result operation");

@@ -482,8 +482,8 @@ }

   // Unroll 'op' with 'iterationBounds' to 'targetShape'.
-  return unrollSingleResultStructuredOp(op, iterationBounds, vectors,
-                                        resultIndex, targetShape, builder);
+  return SmallVector<Value, 1>{unrollSingleResultStructuredOp(
+      op, iterationBounds, vectors, resultIndex, targetShape, builder)};
 }

 // Generates slices of 'vectorType' according to 'sizes' and 'strides', and
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -423,6 +423,51 @@

 // -----

+func @generic_result_tensor_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
+  // expected-error @+1 {{op result #0 must be ranked tensor of any type values, but got 'f32'}}
+  %0 = linalg.generic {
+    args_in = 0,
+    args_out = 1,
+    indexing_maps = [ (i) -> (i) ],
+    iterator_types = ["parallel"]
+  } %arg0 {
+    ^bb(%i: f32):
+      linalg.yield %i: f32
+  }: memref<?xf32, (i)[off]->(off + i)> -> f32
+}
+
+// -----
+
+func @generic_result_tensor_count(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
+  // expected-error @+1 {{op expected #output tensor operands (0) to match #results (1)}}
+  %0 = linalg.generic {
+    args_in = 0,
+    args_out = 1,
+    indexing_maps = [ (i) -> (i) ],
+    iterator_types = ["parallel"]
+  } %arg0 {
+    ^bb(%i: f32):
+      linalg.yield %i: f32
+  }: memref<?xf32, (i)[off]->(off + i)> -> tensor<f32>
+}
+
+// -----
+
+func @generic_result_tensor_type(%arg0: tensor<?xf32>) {
+  // expected-error @+1 {{op result #0 must be 'tensor<?xf32>', but got 'tensor<f32>'}}
+  %0 = linalg.generic {
+    args_in = 0,
+    args_out = 1,
+    indexing_maps = [ (i) -> (i) ],
+    iterator_types = ["parallel"]
+  } %arg0 {
+    ^bb(%i: f32):
+      linalg.yield %i: f32
+  }: tensor<?xf32> -> tensor<f32>
+}
+
+// -----
+
 func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
   // expected-error @+1 {{'linalg.dot' op expected 3 or more operands}}
   linalg.dot(%arg0, %arg0): memref<?xf32, (i)[off]->(off + i)>, memref<?xf32, (i)[off]->(off + i)>
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -139,6 +139,29 @@ // CHECK-LABEL: func @generic
 // CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>

+func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
+  linalg.generic #trait %arg0, %arg1 {foo = 1} : tensor<?x?xvector<3x4xi4>>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+  return
+}
+// CHECK-LABEL: func @generic_with_tensor_input
+// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: tensor<?x?xvector<3x4xi4>>, memref<?x?x?xf32, #[[strided3D]]>
+
+func @generic_with_tensor_output(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>, %arg1: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
+  %0 = linalg.generic #trait %arg0, %arg1 {foo = 1} : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @generic_with_tensor_output
+// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+// CHECK: return {{.*}} : tensor<?x?x?xf32>
+
+func @generic_with_tensor_input_and_output(%arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
+  %0 = linalg.generic #trait %arg0, %arg1 {foo = 1} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func @generic_with_tensor_input_and_output
+// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
+// CHECK: return {{.*}} : tensor<?x?x?xf32>
+
 #trait2 = {
   args_in = 1,
   args_out = 1,
diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
--- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
+++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td
@@ -100,27 +100,39 @@ // Linalg to vector contraction patterns.
 //===----------------------------------------------------------------------===//
 def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
-              [(VectorizeGenericLinalgOp<"GenericOp">)],
-              [(Constraint<HasLinalgTransformMarker<"_marked_matmul_">>)]>;
+              [(VectorizeGenericLinalgOp)],
+              [(Constraint<And<[HasLinalgTransformMarker<"_marked_matmul_">,
+                  PreconditionVectorizeGenericLinalgOp
+                ]>>)]>;

 //===----------------------------------------------------------------------===//
 // Linalg generic permutation patterns.
 //===----------------------------------------------------------------------===//
 def : Pat<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
-          (PermuteGenericLinalgOp<[1,2,0],"PERMUTED">),
-          [(Constraint<And<[HasNoLinalgTransformMarker, AffineMapDomainHasDim<3>]>>)]>;
+          (PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op),
+          [(Constraint<And<[HasNoLinalgTransformMarker, AffineMapDomainHasDim<3>,
+              PreconditionPermuteGenericLinalgOp<[1, 2, 0]>
+            ]>>)]>;

 def : Pat<(IndexedGenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_),
-          (PermuteGenericLinalgOp<[1,2,0],"PERMUTED">),
-          [(Constraint<And<[HasNoLinalgTransformMarker, AffineMapDomainHasDim<3>]>>)]>;
+          (PermuteGenericLinalgOp<[1, 2, 0], "PERMUTE"> $op),
+          [(Constraint<And<[HasNoLinalgTransformMarker, AffineMapDomainHasDim<3>,
+              PreconditionPermuteGenericLinalgOp<[1, 2, 0]>
+            ]>>)]>;

 //===----------------------------------------------------------------------===//
 // Linalg subview operands promotion.
 //===----------------------------------------------------------------------===//
 def : Pat<(MatmulOp:$op $_, $_, $_),
-          (PromoteSubviewsLinalgOp<"MatmulOp">),
-          [(Constraint<HasOperandsOfType<"SubViewOp">>),
-           (Constraint<HasLinalgTransformMarker<"_promote_views_">>)]>;
+          (PromoteSubviewsLinalgOp),
+          [(Constraint<And<[PreconditionPromoteSubviewsLinalgOp,
+              HasLinalgTransformMarker<"_promote_views_">]>>
+           )]>;

 #endif // TEST_LINALG_TRANSFORMS_PATTERNS
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -329,8 +329,9 @@ os.indent(indent) << "{\n";
   indent += 2;
   os.indent(indent) << formatv(
-      "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");\n", depth,
-      attr.getStorageType(), namedAttr->name);
+      "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");"
+      "(void)tblgen_attr;\n",
+      depth, attr.getStorageType(), namedAttr->name);

   // TODO(antiagainst): This should use getter method to avoid duplication.
   if (attr.hasDefaultValue()) {
@@ -573,8 +574,14 @@ auto val = handleResultPattern(resultTree, offsets[i], 0);
     os.indent(4) << "\n";
     // Resolve each symbol for all range use so that we can loop over them.
+    // We need an explicit cast to `SmallVector` to capture the cases where
+    // `{0}` resolves to an `Operation::result_range` as well as cases that
+    // are not iterable (e.g. vector that gets wrapped in additional braces by
+    // RewriterGen).
     os << symbolInfoMap.getAllRangeUse(
-        val, "    for (auto v : {0}) {{ tblgen_repl_values.push_back(v); }",
+        val,
+        "    for (auto v : SmallVector<Value, 4>{ {0} }) {{ "
+        "tblgen_repl_values.push_back(v); }",
         "\n");
   }
   os.indent(4) << "\n";
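The net effect of the RewriterGen change can be illustrated with the shape of the C++ it emits. The block below is representative only, not verbatim generator output, and `replacement` stands for whatever `{0}` expands to in the template:

```cpp
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Wrapping `{0}` in a SmallVector makes a single Value, an
// Operation::result_range, and a brace-wrapped vector all iterable through
// the same loop before replacing the matched op.
static void replaceWithValues(PatternRewriter &rewriter, Operation *op,
                              Value replacement) {
  SmallVector<Value, 4> tblgen_repl_values;
  for (auto v : SmallVector<Value, 4>{replacement})
    tblgen_repl_values.push_back(v);
  rewriter.replaceOp(op, tblgen_repl_values);
}
```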