diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -12,6 +12,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgTraits.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -39,19 +39,19 @@ // Loop types handling. //========================================================================// InterfaceMethod< - "Query the number of parallel loops within the current operation.", + "Return the number of parallel loops within the current operation.", "unsigned", "getNumParallelLoops" >, InterfaceMethod< - "Query the number of reduction loops within the current operation.", + "Return the number of reduction loops within the current operation.", "unsigned", "getNumReductionLoops" >, InterfaceMethod< - "Query the number of window loops within the current operation.", + "Return the number of window loops within the current operation.", "unsigned", "getNumWindowLoops" >, InterfaceMethod< - "Query the number of loops within the current operation.", + "Return the number of loops within the current operation.", "unsigned", "getNumLoops">, InterfaceMethod< @@ -63,10 +63,10 @@ // Input arguments handling.
//========================================================================// InterfaceMethod< - "Query the number of inputs from the current operation.", + "Return the number of inputs from the current operation.", "unsigned", "getNumInputs" >, - InterfaceMethod<"Query the input view at the given index.", + InterfaceMethod<"Return the input view at the given index.", "Value ", "getInput", (ins "unsigned":$i) >, InterfaceMethod<[{ @@ -76,41 +76,40 @@ "llvm::Optional", "getIndexOfInput", (ins "Value ":$v) >, InterfaceMethod< - "Query the input operands from the current operation.", + "Return the input operands from the current operation.", "Operation::operand_range", "getInputs" >, InterfaceMethod<[{ - Query the type of the input shape at the given index. + Return the type of the input shape at the given index. }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>, InterfaceMethod<[{ - Query the subset of input operands that are of ranked tensor type. + Return the subset of input operands that are of ranked tensor type. }], "SmallVector", "getInputTensorTypes">, - //========================================================================// // Output arguments handling. //========================================================================// InterfaceMethod< - "Query the number of outputs from the current operation.", + "Return the number of outputs from the current operation.", "unsigned", "getNumOutputs" >, - InterfaceMethod<"Query the output buffer at the given index.", + InterfaceMethod<"Return the output buffer at the given index.", "Value ", "getOutputBuffer", (ins "unsigned":$i) >, InterfaceMethod<[{ - Query the index of the given buffer value, or `None` if the value is not - part of the output buffers. + Return the index of the given buffer value, or `None` if the value is + not part of the output buffers.
}], "llvm::Optional", "getIndexOfOutputBuffer", (ins "Value ":$view) >, InterfaceMethod<[{ - Query the type of the output buffer at the given index. + Return the type of the output buffer at the given index. }], "MemRefType", "getOutputBufferType", (ins "unsigned":$i)>, InterfaceMethod<[{ - Query the results that are of ranked tensor type. + Return the results that are of ranked tensor type. }], "SmallVector", "getOutputTensorTypes">, InterfaceMethod< - "Query the output buffers (operands) from the current operation.", + "Return the output buffers (operands) from the current operation.", "Operation::operand_range", "getOutputBuffers" >, @@ -136,18 +135,44 @@ // Other interface methods. //========================================================================// InterfaceMethod< - "Query the iterator types attribute within the current operation.", + "Return the reference iterators for this named op (if any are specified). " + "These reference iterators are used to specify the default behavior of " + "the op. Typically this would be a static method but in order to allow " + "rank-polymorphic ops, this needs to be per object instance. Named ops " + "must define referenceIterators, even if empty for the 0-D case. " + "Generic ops on the other hand have a None `referenceIterators`", + "llvm::Optional>", "referenceIterators" + >, + InterfaceMethod< + "Return the reference indexing maps for this named op (if any are " + "specified). Typically this would be a static method but in order to " + "allow rank-polymorphic ops, this needs to be per object instance. Named " + "ops must define referenceIndexingMaps, even if empty for the 0-D case. 
" + "Generic ops on the other hand have a None `referenceIndexingMaps`", + "llvm::Optional>", "referenceIndexingMaps" + >, + InterfaceMethod< + "Return the iterator types attribute within the current operation.", "ArrayAttr", "iterator_types" >, InterfaceMethod< - "Query the indexing maps attribute within the current operation.", + "Return the indexing maps attribute within the current operation.", "ArrayAttr", "indexing_maps" >, + InterfaceMethod<"Return the input or output indexing map at index `i`.", + "AffineMap", "getIndexingMap", (ins "unsigned":$i) + >, + InterfaceMethod<"Return the input indexing map at index `i`.", + "AffineMap", "getInputIndexingMap", (ins "unsigned":$i) + >, + InterfaceMethod<"Return the output indexing map at index `i`.", + "AffineMap", "getOutputIndexingMap", (ins "unsigned":$i) + >, InterfaceMethod<[{ - Query whether the op has only MemRef input and outputs. + Return whether the op has only MemRef inputs and outputs. }], "bool", "hasBufferSemantics">, InterfaceMethod<[{ - Query whether the op has only RankedTensor input and outputs. + Return whether the op has only RankedTensor inputs and outputs. }], "bool", "hasTensorSemantics">, //========================================================================// @@ -204,7 +229,7 @@ } //////////////////////////////////////////////////////////////////////////////// -// Concrete Linalg ops. +// Named Linalg ops, implemented as special configurations of a generic op. //////////////////////////////////////////////////////////////////////////////// def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> { let description = [{ @@ -266,14 +291,19 @@ builder, result, input, output, AffineMapAttr(), AffineMapAttr()); }]>]; let extraClassDeclaration = libraryCallName # [{ + // Defined in C++ for now. + // TODO(ntv): auto-generate. ArrayAttr indexing_maps(); - ArrayAttr iterator_types() { + // Rank-polymorphic. + // input(ivs) -> output(ivs) with parallel iterators.
+ llvm::Optional> referenceIterators() { unsigned nPar = input().getType().cast().getRank(); - MLIRContext *ctx = getContext(); - SmallVector iters( - nPar, StringAttr::get(getParallelIteratorTypeName(), ctx)); - return ArrayAttr::get(iters, ctx); + return SmallVector(nPar, getParallelIteratorTypeName()); + } + + llvm::Optional> referenceIndexingMaps() { + llvm_unreachable("NYI referenceIndexingMaps for CopyOp"); } }]; let verifier = [{ return ::verify(*this); }]; @@ -282,21 +312,25 @@ } def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> { - let arguments = (ins AnyStridedMemRef:$input, + let arguments = (ins AnyStridedMemRef:$output, AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>:$value); let extraClassDeclaration = libraryCallName # [{ + // Defined in C++ for now. + // TODO(ntv): auto-generate. ArrayAttr indexing_maps(); - ArrayAttr iterator_types() { - unsigned nPar = input().getType().cast().getRank(); - MLIRContext *ctx = getContext(); - SmallVector iters( - nPar, StringAttr::get(getParallelIteratorTypeName(), ctx)); - return ArrayAttr::get(iters, ctx); + // Rank-polymorphic. + // filling_value -> O(ivs) with parallel iterators. + llvm::Optional> referenceIterators() { + unsigned nPar = output().getType().cast().getRank(); + return SmallVector(nPar, getParallelIteratorTypeName()); } + llvm::Optional> referenceIndexingMaps() { + llvm_unreachable("NYI referenceIndexingMaps for FillOp"); + } }]; let verifier = [{ return ::verify(*this); }]; let hasFolder = 1; } @@ -305,12 +339,16 @@ AnyStridedMemRefOfRank<1>, AnyStridedMemRefOfRank<0>); let extraClassDeclaration = libraryCallName # [{ + // Defined in C++ for now. + // TODO(ntv): auto-generate.
ArrayAttr indexing_maps(); - ArrayAttr iterator_types() { - MLIRContext *ctx = getContext(); - return ArrayAttr::get( - StringAttr::get(getReductionIteratorTypeName(), ctx), ctx); + llvm::Optional> referenceIterators() { + return SmallVector{getReductionIteratorTypeName()}; + } + + llvm::Optional> referenceIndexingMaps() { + llvm_unreachable("NYI referenceIndexingMaps for DotOp"); } }]; @@ -322,14 +360,18 @@ AnyStridedMemRefOfRank<1>, AnyStridedMemRefOfRank<1>); let extraClassDeclaration = libraryCallName # [{ + // Defined in C++ for now. + // TODO(ntv): auto-generate. ArrayAttr indexing_maps(); - ArrayAttr iterator_types() { - MLIRContext *ctx = getContext(); - Attribute iters[2]{ - StringAttr::get(getParallelIteratorTypeName(), ctx), - StringAttr::get(getReductionIteratorTypeName(), ctx)}; - return ArrayAttr::get(iters, ctx); + llvm::Optional> referenceIterators() { + return SmallVector{ + getParallelIteratorTypeName(), + getReductionIteratorTypeName()}; + } + + llvm::Optional> referenceIndexingMaps() { + llvm_unreachable("NYI referenceIndexingMaps for MatvecOp"); } }]; @@ -341,15 +383,19 @@ AnyStridedMemRefOfRank<2>, AnyStridedMemRefOfRank<2>); let extraClassDeclaration = libraryCallName # [{ + // Defined in C++ for now. + // TODO(ntv): auto-generate.
ArrayAttr indexing_maps(); - ArrayAttr iterator_types() { - MLIRContext *ctx = getContext(); - Attribute iters[3]{ - StringAttr::get(getParallelIteratorTypeName(), ctx), - StringAttr::get(getParallelIteratorTypeName(), ctx), - StringAttr::get(getReductionIteratorTypeName(), ctx)}; - return ArrayAttr::get(iters, ctx); + llvm::Optional> referenceIterators() { + return SmallVector{ + getParallelIteratorTypeName(), + getParallelIteratorTypeName(), + getReductionIteratorTypeName()}; + } + + llvm::Optional> referenceIndexingMaps() { + llvm_unreachable("NYI referenceIndexingMaps for MatmulOp"); } }]; @@ -387,11 +433,13 @@ unsigned getNumInputFeatureDimensions() { return 1; } unsigned getNumOutputFeatureDimensions() { return 1; } + // Defined in C++ for now. + // TODO(ntv): auto-generate. ArrayAttr indexing_maps(); - ArrayAttr iterator_types() { + llvm::Optional> referenceIterators() { // Outer parallel loops are always the number of output dimensions; i.e. - // [ b, xs, q] in the TF notation above. + // [b, xs, q] in the TF notation above. unsigned nPar = getOutputShapedType(0).getRank(); unsigned nRed = getNumInputFeatureDimensions(); // Window loops are a special kind of reduction that is never tiled or @@ -400,13 +448,11 @@ // This may evolve in the future.
unsigned nWin = nPar - getNumBatchDimensions() - getNumInputFeatureDimensions(); - MLIRContext *ctx = getContext(); - SmallVector iters( - nPar, StringAttr::get(getParallelIteratorTypeName(), ctx)); + SmallVector iters(nPar, getParallelIteratorTypeName()); iters.reserve(nPar + nRed + nWin); - iters.append(nRed, StringAttr::get(getReductionIteratorTypeName(), ctx)); - iters.append(nWin, StringAttr::get(getWindowIteratorTypeName(), ctx)); - return ArrayAttr::get(iters, ctx); + iters.append(nRed, getReductionIteratorTypeName()); + iters.append(nWin, getWindowIteratorTypeName()); + return iters; } int64_t getStride(unsigned i) { @@ -422,6 +468,10 @@ return dilations()->getValue()[i] .cast().getValue().getSExtValue(); } + + llvm::Optional> referenceIndexingMaps() { + llvm_unreachable("NYI referenceIndexingMaps for ConvOp"); + } }]; let verifier = [{ return ::verify(*this); }]; @@ -438,6 +488,9 @@ CPred<"$_self.cast().getRank() == " # rank>] >>; +//////////////////////////////////////////////////////////////////////////////// +// Generic Linalg ops. +//////////////////////////////////////////////////////////////////////////////// class GenericOpBase : LinalgStructuredBase_Op { let arguments = (ins Variadic:$views, I64Attr:$args_in, @@ -457,34 +510,36 @@ getIteratorTypesAttrName() }; } + unsigned getNumInputs() { return args_in().getSExtValue(); } + unsigned getNumOutputs() { return args_out().getSExtValue(); } + FuncOp getFunction() { auto moduleOp = getParentOfType(); return fun().hasValue() ? moduleOp.lookupSymbol(fun().getValue()) : FuncOp(); } + StringRef getLibraryCallName() { return library_call().hasValue() ?
library_call().getValue() : ""; } - AffineMap getIndexingMap(unsigned i) { - assert(i < getNumInputsAndOutputs()); - return indexing_maps().getValue()[i].cast().getValue(); - } - AffineMap getInputIndexingMap(unsigned i) { - assert(i < getNumInputs()); - return indexing_maps().getValue()[i].cast().getValue(); - } - AffineMap getOutputIndexingMap(unsigned i) { - assert(i < getNumOutputs()); - return indexing_maps().getValue()[i + getNumInputs()] - .cast().getValue(); - } + + llvm::Optional> referenceIterators() { + // Generic ops carry an explicit `iterator_types` attribute instead. + return llvm::None; + } + + llvm::Optional> referenceIndexingMaps() { + // Generic ops carry an explicit `indexing_maps` attribute instead. + return llvm::None; + } }]; let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parseGenericOp(parser, result); }]; } +/// Index-free GenericOp. def GenericOp : GenericOpBase<"generic"> { let description = [{ Generic Linalg op form where the key properties of the computation are @@ -609,6 +664,8 @@ let hasFolder = 1; } +/// GenericOp with Indexing (i.e. multi-for style in which the region is passed +/// the enclosing loop induction variables). def IndexedGenericOp : GenericOpBase<"indexed_generic"> { let description = [{ Indexed Generic Linalg op form where the key properties of the computation diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -11,6 +11,8 @@ #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Support/LLVM.h" @@ -214,6 +216,100 @@ //==========================================================================// // Other interface methods.
//==========================================================================// + + // Get or build the iterator_types ArrayAttr. + ArrayAttr iterator_types() { + // Return the attribute if it is present. + if (auto attr = this->getOperation()->getAttr("iterator_types")) + return attr.template cast(); + + // If not, form the attribute using the reference iterator types for the + // ConcreteType. + auto maybeReferenceIteratorTypes = + cast(this->getOperation()).referenceIterators(); + + // Ops without the attribute must provide reference iterators (generic ops + // are expected to always carry the attribute). Trap explicitly rather + // than dereference an empty Optional below. + if (!maybeReferenceIteratorTypes) { + this->getOperation()->dump(); + llvm_unreachable("Op missing referenceIterators"); + } + + // Build the attribute from the reference iterator types. + auto *ctx = this->getOperation()->getContext(); + auto attrRange = llvm::map_range(*maybeReferenceIteratorTypes, + [ctx](StringRef str) -> Attribute { + return StringAttr::get(str, ctx); + }); + auto attr = ArrayAttr::get(llvm::to_vector<4>(attrRange), ctx); + // TODO(ntv): Need to memoize this. Can't just store as an attribute atm as + // it will impact parser, printer and tests. + // this->getOperation()->setAttr("iterator_types", attr); + return attr; + } + + // Get or build the indexing_maps ArrayAttr. + ArrayAttr indexing_maps() { + // Return the attribute if it is present. + if (auto attr = this->getOperation()->getAttr("indexing_maps")) + return attr.template cast(); + + // If not, form the attribute using the reference indexing map for the + // ConcreteType. + auto maybeReferenceIndexingMaps = + cast(this->getOperation()).referenceIndexingMaps(); + + // Ops without the attribute must provide reference indexing maps.
+ if (!maybeReferenceIndexingMaps) { + this->getOperation()->dump(); + llvm_unreachable("Op missing referenceIndexingMaps"); + } + + // Build the attribute from the reference indexing maps. + auto *ctx = this->getOperation()->getContext(); + auto attrRange = + llvm::map_range(*maybeReferenceIndexingMaps, [ctx](AffineMap map) { + // 0-D corner case because there is no such thing as a concrete empty + // map type. + if (!map) + map = AffineMap::get(0, 0, getAffineConstantExpr(0, ctx)); + return AffineMapAttr::get(map); + }); + SmallVector attrs{attrRange.begin(), attrRange.end()}; + auto attr = ArrayAttr::get(attrs, ctx); + // TODO(ntv): Need to memoize this. Can't just store as an attribute atm as + // it will impact parser, printer and tests. + // this->getOperation()->setAttr("indexing_maps", attr); + return attr; + } + + // Query through ConcreteType so that an op-provided indexing_maps() (e.g. + // the C++ definition on named ops) takes precedence over the default above. + AffineMap getIndexingMap(unsigned i) { + assert(i < getNumInputsAndOutputs()); + return cast<ConcreteType>(this->getOperation()).indexing_maps() + .getValue()[i] + .template cast() + .getValue(); + } + + AffineMap getInputIndexingMap(unsigned i) { + assert(i < nInputs()); + return cast<ConcreteType>(this->getOperation()).indexing_maps() + .getValue()[i] + .template cast() + .getValue(); + } + + AffineMap getOutputIndexingMap(unsigned i) { + assert(i < nOutputs()); + return cast<ConcreteType>(this->getOperation()).indexing_maps() + .getValue()[i + nInputs()] + .template cast() + .getValue(); + } + /// Query whether the op has only buffer inputs and no returns.
bool hasBufferSemantics() { return this->getOperation()->getNumResults() == 0 && diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -866,6 +866,15 @@ return success(); } +static AffineMap extractOrIdentityMap(Optional maybeMap, + unsigned rank, MLIRContext *context) { + if (maybeMap) + return maybeMap.getValue(); + if (rank == 0) + return AffineMap(); + return AffineMap::getMultiDimIdentityMap(rank, context); +} + namespace mlir { namespace linalg { @@ -880,15 +889,6 @@ } // namespace linalg } // namespace mlir -static AffineMap extractOrIdentityMap(Optional maybeMap, - unsigned rank, MLIRContext *context) { - if (maybeMap) - return maybeMap.getValue(); - if (rank == 0) - return AffineMap(); - return AffineMap::getMultiDimIdentityMap(rank, context); -} - // Returns `num` AffineDimExpr dimensions at positions [curIdx, curIdx + num) // and increments `curIdx` to `curIdx + num`.
static SmallVector @@ -997,23 +997,18 @@ AffineMap::get(idx, 0, concat(concat(bs, ws), qs)), // output[b, x[0], ..., x[N-1], k] AffineMap::get(idx, 0, concat(concat(bs, xs), ks))}; - } else if (auto genericOp = dyn_cast(op)) { - SmallVector res; - unsigned nViews = genericOp.getNumInputsAndOutputs(); - res.reserve(nViews); - for (unsigned i = 0, e = nViews; i < e; ++i) { - res.push_back(genericOp.getIndexingMap(i)); - } - return res; - } else if (auto indexedGenericOp = dyn_cast(op)) { - SmallVector res; - unsigned nViews = indexedGenericOp.getNumInputsAndOutputs(); - res.reserve(nViews); - for (unsigned i = 0, e = nViews; i < e; ++i) - res.push_back(indexedGenericOp.getIndexingMap(i)); - return res; } - llvm_unreachable("Missing loopToOperandRangesMaps for op"); + // All remaining ops (e.g. generic and indexed_generic) expose their + // indexing maps through the LinalgOp interface. + auto linalgOp = cast(op); + unsigned nViews = linalgOp.getNumInputsAndOutputs(); + assert(nViews == linalgOp.indexing_maps().size() && + "expected one indexing map per input and output"); + SmallVector res; + res.reserve(nViews); + for (unsigned i = 0; i < nViews; ++i) + res.push_back(linalgOp.getIndexingMap(i)); + return res; } static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {