diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -12,6 +12,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgTraits.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -81,7 +81,6 @@ Query the subset of input operands that are of ranked tensor type. }], "SmallVector", "getInputTensorTypes">, - //========================================================================// // Output arguments handling. //========================================================================// @@ -130,6 +129,18 @@ //========================================================================// // Other interface methods. //========================================================================// + InterfaceMethod< + "Query the reference iterator types attribute for this named op. " + "Typically this would be a static method but in order to allow " + "rank-polymorphic ops, this needs to be per object instance.", + "llvm::Optional>", "referenceIteratorTypes" + >, + InterfaceMethod< + "Query the reference indexing maps attribute for this named op. 
" + "Typically this would be a static method but in order to allow " + "rank-polymorphic ops, this needs to be per object instance.", + "llvm::Optional>", "referenceIndexingMaps" + >, InterfaceMethod< "Query the iterator types attribute within the current operation.", "ArrayAttr", "iterator_types" @@ -138,9 +149,17 @@ "Query the indexing maps attribute within the current operation.", "ArrayAttr", "indexing_maps" >, - InterfaceMethod<[{ - Query whether the op has only MemRef input and outputs. - }], "bool", "hasBufferSemantics">, + InterfaceMethod<"Query the input or output indexing map at index `i`.", + "AffineMap", "getIndexingMap", (ins "unsigned":$i) + >, + InterfaceMethod<"Query the input indexing map at index `i`.", + "AffineMap", "getInputIndexingMap", (ins "unsigned":$i) + >, + InterfaceMethod<"Query the output indexing map at index `i`.", + "AffineMap", "getOutputIndexingMap", (ins "unsigned":$i) + >, + InterfaceMethod<"Query whether the op has only MemRef input and outputs.", + "bool", "hasBufferSemantics">, //========================================================================// // Other static interface methods. @@ -195,8 +214,23 @@ let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)"; } +class AffineExpressions map> { + list indexing_maps = map; +} + +class LinalgNamedStructured_Op props> + : Op { + bit hasLibraryImpl = 0; + list iterators = ?; + list iterators_types = ?; + list input_indexing_maps = ?; + list output_indexing_maps = ?; + let assemblyFormat = "`(` operands `)` attr-dict `:` type(operands)"; +} + //////////////////////////////////////////////////////////////////////////////// -// Concrete Linalg ops. +// Named Linalg ops, implemented as special configurations of a generic op. 
//////////////////////////////////////////////////////////////////////////////// def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> { let description = [{ @@ -258,15 +292,18 @@ builder, result, input, output, AffineMapAttr(), AffineMapAttr()); }]>]; let extraClassDeclaration = libraryCallName # [{ - ArrayAttr indexing_maps(); - - ArrayAttr iterator_types() { + // Rank-polymorphic. + // filling_value -> O(ivs) with parallel iterators. + llvm::Optional> referenceIteratorTypes() { unsigned nPar = input().getType().cast().getRank(); - MLIRContext *ctx = getContext(); - SmallVector iters( - nPar, StringAttr::get(getParallelIteratorTypeName(), ctx)); - return ArrayAttr::get(iters, ctx); + return SmallVector(nPar, getParallelIteratorTypeName()); } + + llvm::Optional> referenceIndexingMaps() { + llvm_unreachable("NYI referenceIndexingMaps for CopyOp"); + } + + ArrayAttr indexing_maps(); }]; let verifier = [{ return ::verify(*this); }]; @@ -274,20 +311,25 @@ } def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> { - let arguments = (ins AnyStridedMemRef:$input, + let arguments = (ins AnyStridedMemRef:$output, AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>:$value); let extraClassDeclaration = libraryCallName # [{ - ArrayAttr indexing_maps(); - - ArrayAttr iterator_types() { - unsigned nPar = input().getType().cast().getRank(); - MLIRContext *ctx = getContext(); - SmallVector iters( - nPar, StringAttr::get(getParallelIteratorTypeName(), ctx)); - return ArrayAttr::get(iters, ctx); + // Rank-polymorphic. + // filling_value -> O(ivs) with parallel iterators. 
+ llvm::Optional> referenceIteratorTypes() { + unsigned rank = output().getType().cast().getRank(); + return SmallVector(rank, getParallelIteratorTypeName()); + } + llvm::Optional> referenceIndexingMaps() { + auto ctx = getContext(); + unsigned rank = output().getType().cast().getRank(); + if (rank == 0) + return SmallVector{ + AffineMap::get(0, 0, {getAffineConstantExpr(0, ctx)})}; + return SmallVector{ + AffineMap::getMultiDimIdentityMap(rank, ctx)}; } }]; - let verifier = [{ return ::verify(*this); }]; let hasFolder = 1; } @@ -297,33 +339,31 @@ AnyStridedMemRefOfRank<1>, AnyStridedMemRefOfRank<0>); let extraClassDeclaration = libraryCallName # [{ - ArrayAttr indexing_maps(); + llvm::Optional> referenceIteratorTypes() { + return SmallVector{getReductionIteratorTypeName()}; + } - ArrayAttr iterator_types() { - MLIRContext *ctx = getContext(); - return ArrayAttr::get( - StringAttr::get(getReductionIteratorTypeName(), ctx), ctx); + llvm::Optional> referenceIndexingMaps() { + llvm_unreachable("NYI referenceIndexingMaps for DotOp"); } + + ArrayAttr indexing_maps(); }]; let hasFolder = 1; } -def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> { +def MatvecOp : LinalgNamedStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> { let arguments = (ins AnyStridedMemRefOfRank<2>, AnyStridedMemRefOfRank<1>, AnyStridedMemRefOfRank<1>); - let extraClassDeclaration = libraryCallName # [{ - ArrayAttr indexing_maps(); - - ArrayAttr iterator_types() { - MLIRContext *ctx = getContext(); - Attribute iters[2]{ - StringAttr::get(getParallelIteratorTypeName(), ctx), - StringAttr::get(getReductionIteratorTypeName(), ctx)}; - return ArrayAttr::get(iters, ctx); - } - }]; + let hasLibraryImpl = 1; + let iterators = ["i", "r_j"]; + let iterators_types = ["parallel", "reduction"]; + // A(i, r_j) * B(r_j) -> C(i) + let input_indexing_maps = [AffineExpressions<["i", "r_j"]>, + AffineExpressions<["r_j"]>]; + let output_indexing_maps = [AffineExpressions<["i"]>]; let hasFolder = 
1; } @@ -335,13 +375,15 @@ let extraClassDeclaration = libraryCallName # [{ ArrayAttr indexing_maps(); - ArrayAttr iterator_types() { - MLIRContext *ctx = getContext(); - Attribute iters[3]{ - StringAttr::get(getParallelIteratorTypeName(), ctx), - StringAttr::get(getParallelIteratorTypeName(), ctx), - StringAttr::get(getReductionIteratorTypeName(), ctx)}; - return ArrayAttr::get(iters, ctx); + llvm::Optional> referenceIteratorTypes() { + return SmallVector{ + getParallelIteratorTypeName(), + getParallelIteratorTypeName(), + getReductionIteratorTypeName()}; + } + + llvm::Optional> referenceIndexingMaps() { + llvm_unreachable("NYI referenceIndexingMaps for MatmulOp"); } }]; @@ -376,14 +418,16 @@ // TODO(ntv) extend to support more than 1 dimensions and potentially // grouping too. unsigned getNumBatchDimensions() { return 1; } + unsigned getNumInputFeatureDimensions() { return 1; } + unsigned getNumOutputFeatureDimensions() { return 1; } ArrayAttr indexing_maps(); - ArrayAttr iterator_types() { + llvm::Optional> referenceIteratorTypes() { // Outer parallel loops are always the number of output dimensions; i.e. - // [ b, xs, q] in the TF notation above. + // [b, xs, q] in the TF notation above. unsigned nPar = getOutputShapedType(0).getRank(); unsigned nRed = getNumInputFeatureDimensions(); // Window loops are a special kind of reduction that is never tiled or @@ -392,13 +436,11 @@ // This may evolve in the future. 
unsigned nWin = nPar - getNumBatchDimensions() - getNumInputFeatureDimensions(); - MLIRContext *ctx = getContext(); - SmallVector iters( - nPar, StringAttr::get(getParallelIteratorTypeName(), ctx)); + SmallVector iters(nPar, getParallelIteratorTypeName()); iters.reserve(nPar + nRed + nWin); - iters.append(nRed, StringAttr::get(getReductionIteratorTypeName(), ctx)); - iters.append(nWin, StringAttr::get(getWindowIteratorTypeName(), ctx)); - return ArrayAttr::get(iters, ctx); + iters.append(nRed, getReductionIteratorTypeName()); + iters.append(nWin, getWindowIteratorTypeName()); + return iters; } int64_t getStride(unsigned i) { @@ -414,6 +456,10 @@ return dilations()->getValue()[i] .cast().getValue().getSExtValue(); } + + llvm::Optional> referenceIndexingMaps() { + llvm_unreachable("NYI referenceIndexingMaps for MatmulOp"); + } }]; let verifier = [{ return ::verify(*this); }]; @@ -430,6 +476,9 @@ CPred<"$_self.cast().getRank() == " # rank>] >>; +//////////////////////////////////////////////////////////////////////////////// +// Generic Linalg ops. +//////////////////////////////////////////////////////////////////////////////// class GenericOpBase : LinalgStructuredBase_Op { let arguments = (ins Variadic:$views, I64Attr:$args_in, @@ -449,34 +498,36 @@ getIteratorTypesAttrName() }; } + unsigned getNumInputs() { return args_in().getSExtValue(); } + unsigned getNumOutputs() { return args_out().getSExtValue(); } + FuncOp getFunction() { auto moduleOp = getParentOfType(); return fun().hasValue() ? moduleOp.lookupSymbol(fun().getValue()) : FuncOp(); } + StringRef getLibraryCallName() { return library_call().hasValue() ? 
library_call().getValue() : ""; } - AffineMap getIndexingMap(unsigned i) { - assert(i < getNumInputsAndOutputs()); - return indexing_maps().getValue()[i].cast().getValue(); - } - AffineMap getInputIndexingMap(unsigned i) { - assert(i < getNumInputs()); - return indexing_maps().getValue()[i].cast().getValue(); - } - AffineMap getOutputIndexingMap(unsigned i) { - assert(i < getNumOutputs()); - return indexing_maps().getValue()[i + getNumInputs()] - .cast().getValue(); - } + + llvm::Optional> referenceIteratorTypes() { + llvm_unreachable( + "No such thing as reference iterator types for a generic op."); + } + + llvm::Optional> referenceIndexingMaps() { + llvm_unreachable( + "No such thing as reference indexing maps for a generic op."); + } }]; let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parseGenericOp(parser, result); }]; } +/// Index-free GenericOp (i.e. XLA-style, IotaOp is explicit etc) def GenericOp : GenericOpBase<"generic"> { let description = [{ Generic Linalg op form where the key properties of the computation are @@ -601,6 +652,8 @@ let hasFolder = 1; } +/// GenericOp with Indexing (i.e. multi-for style in which the region is passed +/// the enclosing loop induction variables) def IndexedGenericOp : GenericOpBase<"indexed_generic"> { let description = [{ Indexed Generic Linalg op form where the key properties of the computation diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -11,6 +11,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Support/LLVM.h" @@ -208,6 +209,94 @@ //==========================================================================// // Other interface methods. 
//==========================================================================// + + // Get or build the indexing_maps ArrayAttr. + ArrayAttr iterator_types() { + // Return the attribute if it is present. + if (auto attr = this->getOperation()->getAttr("iterator_types")) + return attr.template cast(); + // If not, form the attribute using the reference iterator types for the + // ConcreteType. + auto maybeReferenceIteratorTypes = + cast(this->getOperation()).referenceIteratorTypes(); + // If there is no reference, this must be a generic op. + // TODO(ntv): Traits are used to define ops. Split into cpp to avoid + // cyclic dependency. + auto name = this->getOperation()->getName().getStringRef(); + if (!maybeReferenceIteratorTypes && name != "generic" && + name != "indexed_generic") { + this->getOperation()->dump(); + llvm_unreachable("Op missing referenceIteratorTypes"); + } + // If we have a reference, build the reference attribute and set it in the + // op before returning. + auto *ctx = this->getOperation()->getContext(); + auto attrRange = llvm::map_range(*maybeReferenceIteratorTypes, + [ctx](StringRef str) -> Attribute { + return StringAttr::get(str, ctx); + }); + auto attr = ArrayAttr::get(llvm::to_vector<4>(attrRange), ctx); + this->getOperation()->setAttr("iterator_types", attr); + return attr; + } + + // Get or build the indexing_maps ArrayAttr. + ArrayAttr indexing_maps() { + // Return the attribute if it is present. + if (auto attr = this->getOperation()->getAttr("indexing_maps")) + return attr.template cast(); + // If not, form the attribute using the reference indexing map for the + // ConcreteType. + auto maybeReferenceIndexingMaps = + cast(this->getOperation()).referenceIndexingMaps(); + // If there is no reference, this must be a generic op. 
+ auto name = this->getOperation()->getName().getStringRef(); + if (!maybeReferenceIndexingMaps && name != "generic" && + name != "indexed_generic") { + this->getOperation()->dump(); + llvm_unreachable("Op missing referenceIndexingMaps"); + } + // If we have a reference, build the reference attribute and set it in the + // op before returning. + auto *ctx = this->getOperation()->getContext(); + auto attrRange = + llvm::map_range(*maybeReferenceIndexingMaps, [ctx](AffineMap map) { + // 0-D corner case because there is no such thing as a concrete empty + // map type. + if (!map) + map = AffineMap::get(0, 0, getAffineConstantExpr(0, ctx)); + return AffineMapAttr::get(map); + }); + SmallVector attrs{attrRange.begin(), attrRange.end()}; + auto attr = ArrayAttr::get(attrs, ctx); + this->getOperation()->setAttr("indexing_maps", attr); + return attr; + } + + AffineMap getIndexingMap(unsigned i) { + assert(i < getNumInputsAndOutputs()); + return indexing_maps() + .getValue()[i] + .template cast() + .getValue(); + } + + AffineMap getInputIndexingMap(unsigned i) { + assert(i < nInputs()); + return indexing_maps() + .getValue()[i] + .template cast() + .getValue(); + } + + AffineMap getOutputIndexingMap(unsigned i) { + assert(i < nOutputs()); + return indexing_maps() + .getValue()[i + nInputs()] + .template cast() + .getValue(); + } + /// Query whether the op has only buffer inputs and no returns. bool hasBufferSemantics() { return this->getOperation()->getNumResults() == 0 && diff --git a/mlir/include/mlir/TableGen/ODSDialectHook.h b/mlir/include/mlir/TableGen/ODSDialectHook.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/TableGen/ODSDialectHook.h @@ -0,0 +1,42 @@ +//===- ODSDialectHook.h - Dialect customization hooks into ODS --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines ODS customization hooks for dialects to programmatically +// emit dialect specific contents in ODS C++ code emission. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_ODSDIALECTHOOK_H_ +#define MLIR_TABLEGEN_ODSDIALECTHOOK_H_ + +#include + +namespace llvm { +class StringRef; +} + +namespace mlir { +namespace tblgen { +class Operator; +class OpClass; + +// The emission function for dialect specific content. It takes in an Operator +// and updates the OpClass accordingly. +using DialectEmitFunction = + std::function; + +// ODSDialectHookRegistration provides a global initializer that registers a +// dialect specific content emission function. +struct ODSDialectHookRegistration { + ODSDialectHookRegistration(llvm::StringRef dialectName, + DialectEmitFunction emitFn); +}; +} // namespace tblgen +} // namespace mlir + +#endif // MLIR_TABLEGEN_ODSDIALECTHOOK_H_ diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -46,6 +46,9 @@ // Returns this op's dialect name. StringRef getDialectName() const; + // Returns the dialect of the op. + const Dialect &getDialect() const { return dialect; } + // Returns the operation name. The name will follow the "." // format if its dialect name is not empty. std::string getOperationName() const; @@ -156,14 +159,8 @@ StringRef getExtraClassDeclaration() const; // Returns the Tablegen definition this operator was constructed from. - // TODO(antiagainst,zinenko): do not expose the TableGen record, this is a - // temporary solution to OpEmitter requiring a Record because Operator does - // not provide enough methods. const llvm::Record &getDef() const; - // Returns the dialect of the op. 
- const Dialect &getDialect() const { return dialect; } - // Prints the contents in this operator to the given `os`. This is used for // debugging purposes. void print(llvm::raw_ostream &os) const; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -866,6 +866,15 @@ return success(); } +static AffineMap extractOrIdentityMap(Optional maybeMap, + unsigned rank, MLIRContext *context) { + if (maybeMap) + return maybeMap.getValue(); + if (rank == 0) + return AffineMap(); + return AffineMap::getMultiDimIdentityMap(rank, context); +} + namespace mlir { namespace linalg { @@ -880,15 +889,6 @@ } // namespace linalg } // namespace mlir -static AffineMap extractOrIdentityMap(Optional maybeMap, - unsigned rank, MLIRContext *context) { - if (maybeMap) - return maybeMap.getValue(); - if (rank == 0) - return AffineMap(); - return AffineMap::getMultiDimIdentityMap(rank, context); -} - // Returns `num` AffineDimExpr dimensions at positions [curIdx, curIdx + num) // and increments `curIdx` to `curIdx + num`. 
static SmallVector @@ -938,12 +938,6 @@ extractOrIdentityMap(maybeInputMap, inputRank, context), extractOrIdentityMap(maybeOutputMap, outputRank, context)}; } - if (auto fillOp = dyn_cast(op)) { - // filling_value -> O(ivs) - unsigned rank = fillOp.getNumParallelLoops(); - return SmallVector{ - extractOrIdentityMap(llvm::None, rank, context)}; - } auto i = getAffineDimExpr(0, context); auto j = getAffineDimExpr(1, context); auto k = getAffineDimExpr(2, context); @@ -951,11 +945,6 @@ // A(r_i) * B(r_i) -> C() return SmallVector{AffineMap::get(1, 0, {i}), AffineMap::get(1, 0, {i}), AffineMap()}; - if (isa(op)) - // A(i, r_j) * B(r_j) -> C(i) - return SmallVector{AffineMap::get(2, 0, {i, j}), - AffineMap::get(2, 0, {j}), - AffineMap::get(2, 0, {i})}; if (isa(op)) // A(i, r_k) * B(r_k, j) -> C(i, j) return SmallVector{AffineMap::get(3, 0, {i, k}), @@ -997,23 +986,15 @@ AffineMap::get(idx, 0, concat(concat(bs, ws), qs)), // output[b, x[0], ..., x[N-1], k] AffineMap::get(idx, 0, concat(concat(bs, xs), ks))}; - } else if (auto genericOp = dyn_cast(op)) { - SmallVector res; - unsigned nViews = genericOp.getNumInputsAndOutputs(); - res.reserve(nViews); - for (unsigned i = 0, e = nViews; i < e; ++i) { - res.push_back(genericOp.getIndexingMap(i)); - } - return res; - } else if (auto indexedGenericOp = dyn_cast(op)) { - SmallVector res; - unsigned nViews = indexedGenericOp.getNumInputsAndOutputs(); - res.reserve(nViews); - for (unsigned i = 0, e = nViews; i < e; ++i) - res.push_back(indexedGenericOp.getIndexingMap(i)); - return res; } - llvm_unreachable("Missing loopToOperandRangesMaps for op"); + SmallVector res; + auto linalgOp = cast(op); + unsigned nViews = linalgOp.getNumInputsAndOutputs(); + res.reserve(nViews); + for (unsigned i = 0, e = nViews; i < e; ++i) + res.push_back(linalgOp.getIndexingMap(i)); + assert(nViews == linalgOp.indexing_maps().size()); + return res; } static void appendMangledType(llvm::raw_string_ostream &ss, Type t) { @@ -1068,15 +1049,9 @@ ArrayAttr 
mlir::linalg::DotOp::indexing_maps() { return getIndexingMaps(getOperation()); } -ArrayAttr mlir::linalg::FillOp::indexing_maps() { - return getIndexingMaps(getOperation()); -} ArrayAttr mlir::linalg::MatmulOp::indexing_maps() { return getIndexingMaps(getOperation()); } -ArrayAttr mlir::linalg::MatvecOp::indexing_maps() { - return getIndexingMaps(getOperation()); -} // TODO(ntv, rriddle): Consider making all this boilerplate easy to autogenerate // with Tablegen. This seems a desirable property in the context of OpInterfaces diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir --- a/mlir/test/Dialect/Linalg/promote.mlir +++ b/mlir/test/Dialect/Linalg/promote.mlir @@ -63,13 +63,14 @@ // CHECK: linalg.fill(%[[fullA]], {{.*}}) : memref, f32 // CHECK: linalg.fill(%[[fullB]], {{.*}}) : memref, f32 // CHECK: linalg.fill(%[[fullC]], {{.*}}) : memref, f32 -// CHECK: linalg.copy(%[[vA]], %[[partialA]]) : memref, memref -// CHECK: linalg.copy(%[[vB]], %[[partialB]]) : memref, memref -// CHECK: linalg.copy(%[[vC]], %[[partialC]]) : memref, memref +// TODO(ntv, rriddle): Pretty-printing/parsing behavior of "named" ops wrt attribute dictionary +// CHECK: linalg.copy(%[[vA]], %[[partialA]]) {iterator_types = ["parallel", "parallel"]} : memref, memref +// CHECK: linalg.copy(%[[vB]], %[[partialB]]) {iterator_types = ["parallel", "parallel"]} : memref, memref +// CHECK: linalg.copy(%[[vC]], %[[partialC]]) {iterator_types = ["parallel", "parallel"]} : memref, memref // // CHECK: linalg.matmul(%[[fullA]], %[[fullB]], %[[fullC]]) : memref, memref, memref // -// CHECK: linalg.copy(%[[partialC]], %[[vC]]) : memref, memref +// CHECK: linalg.copy(%[[partialC]], %[[vC]]) {iterator_types = ["parallel", "parallel"]} : memref, memref // // CHECK: dealloc %[[tmpA]] : memref<32xi8> // CHECK: dealloc %[[tmpB]] : memref<48xi8> diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- 
a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -1,29 +1,11 @@ -// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// RUN: mlir-opt -split-input-file %s | FileCheck %s +// | mlir-opt | FileCheck %s // TODO(pifon): Re-enable LLVM lowering test after IndexedGenericOp is lowered. // // Test that we can lower all the way to LLVM without crashing, don't check results here. // DISABLED: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1 -// CHECK-DAG: #[[strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> -// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK-DAG: #[[strided2DOFF0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)> -// CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)> -// CHECK-DAG: #[[strided3DOFF0:.*]] = affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * s1 + d2)> -// CHECK-DAG: #[[strided6D:.*]] = affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5)> - -// CHECK-DAG: #[[map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)> -// CHECK-DAG: #[[map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)> - -// CHECK-DAG: #[[reshapeD01:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> -// CHECK-DAG: #[[reshapeD2:.*]] = affine_map<(d0, d1, d2) -> (d2)> -// CHECK-DAG: #[[reshapeD0:.*]] = affine_map<(d0, d1, d2) -> (d0)> -// CHECK-DAG: #[[reshapeD12:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> -// CHECK-DAG: #[[reshapeD012:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK-DAG: #[[reshape5D01:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)> -// CHECK-DAG: #[[reshape5D2:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2)> -// CHECK-DAG: #[[reshape5D34:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)> - func @range(%arg0: index, %arg1: index, %arg2: index) { %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range return @@ -31,43 +13,107 @@ // CHECK-LABEL: func 
@range(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { // CHECK-NEXT: linalg.range %{{.*}} : %{{.*}} : %{{.*}} : !linalg.range +// ----- + +// CHECK-DAG: #[[strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> + func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) { %c0 = constant 0 : index %0 = muli %arg0, %arg0 : index %1 = alloc (%0) : memref %2 = linalg.range %arg0:%arg1:%arg2 : !linalg.range - %3 = view %1[%c0][%arg0, %arg0] : memref to memref - %4 = linalg.slice %3[%2, %2] : memref, !linalg.range, !linalg.range, memref - %5 = linalg.slice %3[%2, %arg2] : memref, !linalg.range, index, memref - %6 = linalg.slice %3[%arg2, %2] : memref, index, !linalg.range, memref - %7 = linalg.slice %3[%arg2, %arg3] : memref, index, index, memref - %8 = view %1[%c0][%arg0, %arg0] : memref to memref, offset: ?, strides: [?, 1]> + %3 = view %1[%c0][%arg0, %arg0] : + memref to memref + %4 = linalg.slice %3[%2, %2] : + memref, + !linalg.range, + !linalg.range, + memref + %5 = linalg.slice %3[%2, %arg2] : memref, + !linalg.range, + index, + memref + %6 = linalg.slice %3[%arg2, %2] : memref, + index, + !linalg.range, + memref + %7 = linalg.slice %3[%arg2, %arg3] : memref, + index, + index, + memref + %8 = view %1[%c0][%arg0, %arg0] : + memref to memref, offset: ?, strides: [?, 1]> dealloc %1 : memref return } -// CHECK-LABEL: func @views(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { +// CHECK-LABEL: func @views // CHECK: muli %{{.*}}, %{{.*}} : index // CHECK-NEXT: alloc(%{{.*}}) : memref // CHECK-NEXT: range -// CHECK-NEXT: std.view %{{.*}}[%{{.*}}][%{{.*}}] : memref to memref -// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref, !linalg.range, !linalg.range, memref -// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref, !linalg.range, index, memref -// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref, 
index, !linalg.range, memref -// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref, index, index, memref -// CHECK-NEXT: view %{{.*}}[%{{.*}}][%{{.*}}] : memref to memref, #[[strided2D]]> +// CHECK-NEXT: std.view %{{.*}}[%{{.*}}][%{{.*}}] : +// CHECK-SAME: memref to memref +// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : +// CHECK-SAME: memref, +// CHECK-SAME: !linalg.range, +// CHECK-SAME: !linalg.range, +// CHECK-SAME: memref +// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : +// CHECK-SAME: memref, +// CHECK-SAME: !linalg.range, +// CHECK-SAME: index, +// CHECK-SAME: memref +// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : +// CHECK-SAME: memref, +// CHECK-SAME: index, +// CHECK-SAME: !linalg.range, +// CHECK-SAME: memref +// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : +// CHECK-SAME: memref, +// CHECK-SAME: index, +// CHECK-SAME: index, +// CHECK-SAME: memref +// CHECK-NEXT: view %{{.*}}[%{{.*}}][%{{.*}}] : +// CHECK-SAME: memref to memref, #[[strided2D]]> // CHECK-NEXT: dealloc %{{.*}} : memref -func @ops(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref) { - linalg.matmul(%arg0, %arg0, %arg0) : memref, memref, memref - linalg.matvec(%arg0, %arg1, %arg2) : memref, memref, memref - linalg.dot(%arg1, %arg2, %arg3) : memref, memref, memref +// ----- + +// CHECK-DAG: #[[strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> + +func @ops(%arg0: memref, + %arg1: memref, + %arg2: memref, + %arg3: memref) { + linalg.matmul(%arg0, %arg0, %arg0) : memref, + memref, + memref + linalg.matvec(%arg0, %arg1, %arg2) : memref, + memref, + memref + linalg.dot(%arg1, %arg2, %arg3) : memref, + memref, + memref return } // CHECK-LABEL: func @ops(% -// CHECK: {{.*}}: memref, %{{.*}}: memref, %{{.*}}: memref, %{{.*}}: memref) { -// CHECK-NEXT: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref, memref, memref -// CHECK-NEXT: linalg.matvec(%{{.*}}, 
%{{.*}}, %{{.*}}) : memref, memref, memref -// CHECK-NEXT: linalg.dot(%{{.*}}, %{{.*}}, %{{.*}}) : memref, memref, memref +// CHECK-NEXT: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : +// CHECK-SAME: memref, +// CHECK-SAME: memref, +// CHECK-SAME: memref +// CHECK-NEXT: linalg.matvec(%{{.*}}, %{{.*}}, %{{.*}}) : +// CHECK-SAME: memref, +// CHECK-SAME: memref, +// CHECK-SAME: memref +// CHECK-NEXT: linalg.dot(%{{.*}}, %{{.*}}, %{{.*}}) : +// CHECK-SAME: memref, +// CHECK-SAME: memref, +// CHECK-SAME: memref + +// ----- + +// CHECK-DAG: #[[strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> func @fill_view(%arg0: memref, %arg1: f32) { linalg.fill(%arg0, %arg1) : memref, f32 @@ -77,12 +123,21 @@ // CHECK: %{{.*}}: memref, %{{.*}}: f32) { // CHECK: linalg.fill(%{{.*}}, %{{.*}}) : memref, f32 +// ----- + +// CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)> + func @transpose(%arg0: memref) { %0 = linalg.transpose %arg0 (i, j, k) -> (k, j, i) : memref return } // CHECK-LABEL: func @transpose -// CHECK: linalg.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) : memref +// CHECK: linalg.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) : +// CHECK-SAME: memref + +// ----- + +// CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)> func @fill_view3(%arg0: memref, %arg1: f32) { linalg.fill(%arg0, %arg1) : memref, f32 @@ -92,15 +147,29 @@ // CHECK: %{{.*}}: memref, %{{.*}}: f32) { // CHECK: linalg.fill(%{{.*}}, %{{.*}}) : memref, f32 -func @copy_view(%arg0: memref, %arg1: memref) { - linalg.copy(%arg0, %arg1) : memref, memref +// ----- + +// CHECK-DAG: #[[strided1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> + +func @copy_view(%arg0: memref, + %arg1: memref) { + linalg.copy(%arg0, %arg1) : memref, + memref return } // CHECK-LABEL: func @copy_view( -// CHECK: %{{.*}}: memref, %{{.*}}: memref) { -// CHECK: linalg.copy(%{{.*}}, %{{.*}}) : 
memref, memref +// TODO(ntv, rriddle): Pretty-printing/parsing behavior of "named" ops wrt attribute dictionary +// CHECK: linalg.copy(%{{.*}}, %{{.*}}) {iterator_types = ["parallel"]} : +// CHECK-SAME: memref, memref -func @copy_view3(%arg0: memref, %arg1: memref) { +// ----- + +// CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)> +// CHECK-DAG: #[[map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)> +// CHECK-DAG: #[[map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1, d0)> + +func @copy_view3(%arg0: memref, + %arg1: memref) { linalg.copy(%arg0, %arg1) {inputPermutation = affine_map<(i, j, k) -> (i, k, j)>, outputPermutation = affine_map<(i, j, k) -> (k, j, i)>} : memref, memref @@ -108,28 +177,67 @@ } // CHECK-LABEL: func @copy_view3( // CHECK: %{{.*}}: memref, %{{.*}}: memref) { -// CHECK: linalg.copy(%{{.*}}, %{{.*}}) {inputPermutation = #[[map0]], outputPermutation = #[[map1]]} : memref, memref +// CHECK: linalg.copy(%{{.*}}, %{{.*}}) { +// CHECK-SAME: inputPermutation = #[[map0]], +// CHECK-SAME: outputPermutation = #[[map1]]} : +// CHECK-SAME: memref, +// CHECK-SAME: memref + +// ----- -func @conv_view3(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.conv(%arg0, %arg1, %arg2) : memref, memref, memref +// CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)> + +func @conv_view3(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.conv(%arg0, %arg1, %arg2) : memref, + memref, + memref return } // CHECK-LABEL: func @conv_view3( -// CHECK: %{{.*}}: memref, %{{.*}}: memref, %{{.*}}: memref) { -// CHECK: linalg.conv(%{{.*}}, %{{.*}}, %{{.*}}) : memref, memref, memref +// CHECK: linalg.conv(%{{.*}}, %{{.*}}, %{{.*}}) : +// CHECK-SAME: memref, +// CHECK-SAME: memref, +// CHECK-SAME: memref + +// ----- + +// CHECK-DAG: #[[strided6D:.*]] = affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5)> 
-func @conv_view6(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.conv(%arg0, %arg1, %arg2) {dilations = [4, 4, 5, 5], strides = [2, 2, 3, 3]} : memref, memref, memref +func @conv_view6(%arg0: memref, + %arg1: memref, + %arg2: memref) { + linalg.conv(%arg0, %arg1, %arg2) {dilations = [4, 4, 5, 5], strides = [2, 2, 3, 3]} : + memref, + memref, + memref return } // CHECK-LABEL: func @conv_view6( -// CHECK: %{{.*}}: memref, %{{.*}}: memref, %{{.*}}: memref) { -// CHECK: linalg.conv(%{{.*}}, %{{.*}}, %{{.*}}) {dilations = [4, 4, 5, 5], strides = [2, 2, 3, 3]} : memref, memref, memref +// TODO(ntv, rriddle): Pretty-printing/parsing behavior of "named" ops wrt attribute dictionary +// CHECK: linalg.conv(%{{.*}}, %{{.*}}, %{{.*}}) { +// CHECK-SAME: dilations = [4, 4, 5, 5] +// CHECK-SAME: strides = [2, 2, 3, 3] +// CHECK-SAME: memref, +// CHECK-SAME: memref, +// CHECK-SAME: memref + +// ----- + +// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> +// CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)> + +func @foo(%0: vector<3x4xi4>, %1: f32) -> f32 { + %f0 = constant 0.0 : f32 + return %f0 : f32 +} #accesses = [ affine_map<(i, j, k) -> (j, i)>, affine_map<(i, j, k) -> (i, k, i + j)> ] + #trait = { args_in = 1, args_out = 1, @@ -138,24 +246,47 @@ fun = @foo, library_call = "some_external_function_name_1" } -func @foo(%0: vector<3x4xi4>, %1: f32) -> f32 { - %f0 = constant 0.0 : f32 - return %f0 : f32 -} -func @generic(%arg0: memref, offset: ?, strides: [?, 1]>, %arg1: memref) { - linalg.generic #trait %arg0, %arg1 {foo = 1} : memref, offset: ?, strides: [?, 1]>, memref + +func @generic(%arg0: memref, offset: ?, strides: [?, 1]>, + %arg1: memref) { + linalg.generic #trait %arg0, %arg1 {foo = 1} : + memref, offset: ?, strides: [?, 1]>, + memref return } // CHECK-LABEL: func @foo // CHECK-LABEL: func @generic -// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, 
indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref, #[[strided2D]]>, memref +// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, +// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], +// CHECK-SAME: library_call = "some_external_function_name_1" +// CHECK-SAME: {foo = 1 : i64}: +// CHECK-SAME: memref, #[[strided2D]]>, memref -func @generic_with_tensor_input(%arg0: tensor>, %arg1: memref) { - linalg.generic #trait %arg0, %arg1 {foo = 1} : tensor>, memref +func @generic_with_tensor_input(%arg0: tensor>, + %arg1: memref) { + linalg.generic #trait %arg0, %arg1 {foo = 1} : + tensor>, + memref return } // CHECK-LABEL: func @generic_with_tensor_input -// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: tensor>, memref +// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, +// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], +// CHECK-SAME: library_call = "some_external_function_name_1"} +// CHECK-SAME: {foo = 1 : i64}: +// CHECK-SAME: tensor>, memref + +// ----- + +func @foo(%0: vector<3x4xi4>, %1: f32) -> f32 { + %f0 = constant 0.0 : f32 + return %f0 : f32 +} + +#accesses = [ + affine_map<(i, j, k) -> (j, i)>, + affine_map<(i, j, k) -> (i, k, i + j)> +] #trait2 = { args_in = 2, @@ -165,14 +296,31 @@ fun = @foo, library_call = "some_external_function_name_1" } -func @generic_with_tensor_input_and_output(%arg0: tensor>, %arg1: tensor) -> (tensor) { - %0 = linalg.generic #trait2 %arg0, %arg1 {foo = 1} : tensor>, tensor -> tensor + +func @generic_with_tensor_input_and_output( + %arg0: tensor>, %arg1: 
tensor) + -> (tensor) { + %0 = linalg.generic #trait2 %arg0, %arg1 {foo = 1} : + tensor>, tensor -> tensor return %0 : tensor } // CHECK-LABEL: func @generic_with_tensor_input_and_output -// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: tensor>, tensor -> tensor +// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, fun = @foo, +// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], +// CHECK-SAME: library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: +// CHECK-SAME: tensor>, tensor -> tensor // CHECK: return {{.*}} : tensor +// ----- + +// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> +// CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)> + +#accesses = [ + affine_map<(i, j, k) -> (j, i)>, + affine_map<(i, j, k) -> (i, k, i + j)> +] + #trait3 = { args_in = 1, args_out = 1, @@ -180,31 +328,53 @@ iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2" } -func @generic_region(%arg0: memref, offset: ?, strides: [?, 1]>, %arg1: memref) { + +func @generic_region(%arg0: memref, offset: ?, strides: [?, 1]>, + %arg1: memref) { linalg.generic #trait3 %arg0, %arg1 { ^bb(%a: vector<3x4xi4>, %b: f32) : linalg.yield %b : f32 - } {foo = 1}: memref, offset: ?, strides: [?, 1]>, memref + } {foo = 1}: memref, offset: ?, strides: [?, 1]>, + memref return } // CHECK-LABEL: func @generic_region -// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2"} %{{.*}}, %{{.*}} { -// CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: 
f32): // no predecessors +// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, +// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], +// CHECK-SAME: library_call = "some_external_function_name_2" +// CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32): // CHECK: linalg.yield %{{.*}} : f32 -// CHECK: } {foo = 1 : i64}: memref, #[[strided2D]]>, memref +// CHECK: } {foo = 1 : i64}: memref, #[[strided2D]]>, +// CHECK-SAME: memref + func @indexed_generic(%arg0: memref, offset: ?, strides: [?, 1]>, %arg1: memref) { linalg.indexed_generic #trait3 %arg0, %arg1 { ^bb(%i: index, %j: index, %k: index, %a: vector<3x4xi4>, %b: f32) : linalg.yield %b : f32 - } {foo = 1}: memref, offset: ?, strides: [?, 1]>, memref + } {foo = 1}: memref, offset: ?, strides: [?, 1]>, + memref return } // CHECK-LABEL: func @indexed_generic -// CHECK: linalg.indexed_generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_2"} %{{.*}}, %{{.*}} { +// CHECK: linalg.indexed_generic {args_in = 1 : i64, args_out = 1 : i64, +// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], +// CHECK-SAME: library_call = "some_external_function_name_2" // CHECK: ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32): // CHECK: linalg.yield %{{.*}} : f32 -// CHECK: } {foo = 1 : i64}: memref, #[[strided2D]]>, memref +// CHECK: } {foo = 1 : i64}: memref, #[[strided2D]]>, +// CHECK-SAME: memref + +// ----- + +// CHECK-DAG: #[[reshapeD01:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-DAG: #[[reshapeD2:.*]] = affine_map<(d0, d1, d2) -> (d2)> +// CHECK-DAG: #[[reshapeD0:.*]] = affine_map<(d0, d1, d2) -> (d0)> +// CHECK-DAG: #[[reshapeD12:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK-DAG: #[[reshapeD012:.*]] = affine_map<(d0, d1, 
d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[reshape5D01:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)> +// CHECK-DAG: #[[reshape5D2:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d2)> +// CHECK-DAG: #[[reshape5D34:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)> func @reshape_static(%arg0: memref<3x4x5xf32>) { // Reshapes that collapse and expand back a contiguous tensor. @@ -253,6 +423,13 @@ // CHECK: linalg.reshape {{.*}} [#[[reshape5D01]], #[[reshape5D2]], #[[reshape5D34]]] // CHECK-SAME: memref<1x3x4x1x5xf32> into memref<3x4x5xf32> +// ----- + +// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> +// CHECK-DAG: #[[strided2DOFF0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)> +// CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)> +// CHECK-DAG: #[[strided3DOFF0:.*]] = affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s0 + d1 * s1 + d2)> + func @reshape_dynamic(%arg0: memref, %arg1: memref, %arg2: memref) { diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir --- a/mlir/test/Dialect/Linalg/tile.mlir +++ b/mlir/test/Dialect/Linalg/tile.mlir @@ -41,7 +41,8 @@ // TILE-2: %[[sAi:.*]] = std.subview %{{.*}}[%[[I]], %[[C0]]][%[[C2]], %[[K]]][%[[C1]], %[[C1]]] : memref to memref // TILE-2: %[[N:.*]] = dim %{{.*}}, 1 : memref // TILE-2: %[[sCi:.*]] = std.subview %{{.*}}[%[[I]], %[[C0]]][%[[C2]], %[[N]]][%[[C1]], %[[C1]]] : memref to memref -// TILE-2: linalg.matmul(%[[sAi]], %{{.*}}, %[[sCi]]) : memref, memref, memref +// TODO(ntv, rriddle): Pretty-printing/parsing behavior of "named" ops wrt attribute dictionary +// TILE-2: linalg.matmul(%[[sAi]], %{{.*}}, %[[sCi]]) {iterator_types = ["parallel", "parallel", "reduction"]} : memref, memref, memref // TILE-02-LABEL: func @matmul( // TILE-02-DAG: %[[C0:.*]] = constant 0 : index @@ -53,7 +54,8 @@ // TILE-02: %[[sBj:.*]] = std.subview %{{.*}}[%[[C0]], %[[J]]][%[[K]], %[[C2]]][%[[C1]], %[[C1]]] : memref to memref // 
TILE-02: %[[M:.*]] = dim %{{.*}}, 0 : memref // TILE-02: %[[sCj:.*]] = std.subview %{{.*}}[%[[C0]], %[[J]]][%[[M]], %[[C2]]][%[[C1]], %[[C1]]] : memref to memref -// TILE-02: linalg.matmul(%{{.*}}, %[[sBj]], %[[sCj]]) : memref, memref, memref +// TODO(ntv, rriddle): Pretty-printing/parsing behavior of "named" ops wrt attribute dictionary +// TILE-02: linalg.matmul(%{{.*}}, %[[sBj]], %[[sCj]]) {iterator_types = ["parallel", "parallel", "reduction"]} : memref, memref, memref // TILE-002-LABEL: func @matmul( // TILE-002-DAG: %[[C0:.*]] = constant 0 : index @@ -65,7 +67,8 @@ // TILE-002: %[[sAj:.*]] = std.subview %{{.*}}[%[[C0]], %[[K]]][%[[M]], %[[C2]]][%[[C1]], %[[C1]]] : memref to memref // TILE-002: %[[N:.*]] = dim %{{.*}}, 1 : memref // TILE-002: %[[sBj:.*]] = std.subview %{{.*}}[%[[K]], %[[C0]]][%[[C2]], %[[N]]][%[[C1]], %[[C1]]] : memref to memref -// TILE-002: linalg.matmul(%[[sAj]], %[[sBj]], %{{.*}}) : memref, memref, memref +// TODO(ntv, rriddle): Pretty-printing/parsing behavior of "named" ops wrt attribute dictionary +// TILE-002: linalg.matmul(%[[sAj]], %[[sBj]], %{{.*}}) {iterator_types = ["parallel", "parallel", "reduction"]} : memref, memref, memref // TILE-234-LABEL: func @matmul( // TILE-234-DAG: %[[C0:.*]] = constant 0 : index @@ -83,7 +86,8 @@ // TILE-234: %[[sBkj:.*]] = std.subview %{{.*}}[%[[K]], %[[J]]][%[[C4]], %[[C3]]][%[[C1]], %[[C1]]] : memref to memref // TILE-234: %[[sCij:.*]] = std.subview %{{.*}}[%[[I]], %[[J]]][%[[C2]], %[[C3]]][%[[C1]], %[[C1]]] : memref to memref // -// TILE-234: linalg.matmul(%[[sAik]], %[[sBkj]], %[[sCij]]) : memref, memref, memref +// TODO(ntv, rriddle): Pretty-printing/parsing behavior of "named" ops wrt attribute dictionary +// TILE-234: linalg.matmul(%[[sAik]], %[[sBkj]], %[[sCij]]) {iterator_types = ["parallel", "parallel", "reduction"]} : memref, memref, memref func @matvec(%arg0: memref, %arg1: memref, %arg2: memref) { linalg.matvec(%arg0, %arg1, %arg2) : memref, memref, memref @@ -98,7 +102,8 @@ // TILE-2: 
%[[N:.*]] = dim %{{.*}}, 1 : memref // TILE-2: %[[sAi:.*]] = std.subview %{{.*}}[%[[I]], %[[C0]]][%[[C2]], %[[N]]][%[[C1]], %[[C1]]] : memref to memref // TILE-2: %[[sCi:.*]] = std.subview %{{.*}}[%[[I]]][%[[C2]]][%[[C1]]] : memref to memref -// TILE-2: linalg.matvec(%[[sAi]], %{{.*}}, %[[sCi]]) : memref, memref, memref +// TODO(ntv, rriddle): Pretty-printing/parsing behavior of "named" ops wrt attribute dictionary +// TILE-2: linalg.matvec(%[[sAi]], %{{.*}}, %[[sCi]]) {{{.*}}iterator_types = ["parallel", "reduction"]} : memref, memref, memref // TILE-02-LABEL: func @matvec( // TILE-02-DAG: %[[C0:.*]] = constant 0 : index @@ -109,7 +114,8 @@ // TILE-02: %[[M:.*]] = dim %{{.*}}, 0 : memref // TILE-02: %[[sAj:.*]] = std.subview %{{.*}}[%[[C0]], %[[J]]][%[[M]], %[[C2]]][%[[C1]], %[[C1]]] : memref to memref // TILE-02: %[[sBj:.*]] = std.subview %{{.*}}[%[[J]]][%[[C2]]][%[[C1]]] : memref to memref -// TILE-02: linalg.matvec(%[[sAj]], %[[sBj]], %{{.*}}) : memref, memref, memref +// TODO(ntv, rriddle): Pretty-printing/parsing behavior of "named" ops wrt attribute dictionary +// TILE-02: linalg.matvec(%[[sAj]], %[[sBj]], %{{.*}}) {{{.*}}iterator_types = ["parallel", "reduction"]} : memref, memref, memref // TILE-002-LABEL: func @matvec( // TILE-002-NOT: loop.for @@ -127,7 +133,8 @@ // TILE-234: %[[sBj:.*]] = std.subview %{{.*}}[%[[J]]][%[[C3]]][%[[C1]]] : memref to memref // TILE-234: %[[sCi:.*]] = std.subview %{{.*}}[%[[I]]][%[[C2]]][%[[C1]]] : memref to memref // -// TILE-234: linalg.matvec(%[[sAij]], %[[sBj]], %[[sCi]]) : memref, memref, memref +// TODO(ntv, rriddle): Pretty-printing/parsing behavior of "named" ops wrt attribute dictionary +// TILE-234: linalg.matvec(%[[sAij]], %[[sBj]], %[[sCi]]) {{{.*}}iterator_types = ["parallel", "reduction"]} : memref, memref, memref func @dot(%arg0: memref, %arg1: memref, %arg2: memref) { linalg.dot(%arg0, %arg1, %arg2) : memref, memref, memref @@ -141,7 +148,8 @@ // TILE-2: loop.for %[[I]] = %{{.*}}{{.*}} to %[[M]] step %{{.*}} { 
// TILE-2: %[[sAi:.*]] = std.subview %{{.*}}[%[[I]]][%[[C2]]][%[[C1]]] : memref to memref // TILE-2: %[[sBi:.*]] = std.subview %{{.*}}[%[[I]]][%[[C2]]][%[[C1]]] : memref to memref -// TILE-2: linalg.dot(%[[sAi]], %[[sBi]], {{.*}}) : memref, memref, memref +// TODO(ntv, rriddle): Pretty-printing/parsing behavior of "named" ops wrt attribute dictionary +// TILE-2: linalg.dot(%[[sAi]], %[[sBi]], {{.*}}) {iterator_types = ["reduction"]} : memref, memref, memref // TILE-02-LABEL: func @dot( // TILE-02-NOT: loop.for @@ -157,7 +165,8 @@ // TILE-234: loop.for %[[I:.*]] = %{{.*}} to %[[ubK]] step %{{.*}} { // TILE-234: %[[sAi:.*]] = std.subview %{{.*}}[%[[I]]][%[[C2]]][%[[C1]]] : memref to memref // TILE-234: %[[sBi:.*]] = std.subview %{{.*}}[%[[I]]][%[[C2]]][%[[C1]]] : memref to memref -// TILE-234: linalg.dot(%[[sAi]], %[[sBi]], %{{.*}}) : memref, memref, memref +// TODO(ntv, rriddle): Pretty-printing/parsing behavior of "named" ops wrt attribute dictionary +// TILE-234: linalg.dot(%[[sAi]], %[[sBi]], %{{.*}}) {iterator_types = ["reduction"]} : memref, memref, memref func @fill_static(%arg0: memref<127x99xf32>, %arg1: f32) { linalg.fill(%arg0, %arg1) : memref<127x99xf32>, f32 diff --git a/mlir/test/Dialect/Linalg/tile_conv.mlir b/mlir/test/Dialect/Linalg/tile_conv.mlir --- a/mlir/test/Dialect/Linalg/tile_conv.mlir +++ b/mlir/test/Dialect/Linalg/tile_conv.mlir @@ -37,4 +37,5 @@ // TILE-23004: %[[X1:.*]] = dim %{{.*}}, 3 : memref // TILE-23004: %[[OutputView:.*]] = std.subview %{{.*}}[%[[ivI]], %[[ivJ]], %[[C0]], %[[C0]]][%[[C2]], %[[C3]], %[[X0]], %[[X1]]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref to memref // -// TILE-23004: linalg.conv(%[[FilterView]], %[[InputView]], %[[OutputView]]) {dilations = [10, 20], strides = [30, 40]} : memref, memref, memref +// TODO(ntv, rriddle): Pretty-printing/parsing behavior of "named" ops wrt attribute dictionary +// TILE-23004: linalg.conv(%[[FilterView]], %[[InputView]], %[[OutputView]]) {dilations = [10, 20], iterator_types = 
["parallel", "parallel", "parallel", "parallel", "reduction", "window", "window"], strides = [30, 40]} : memref, memref, memref diff --git a/mlir/tools/mlir-tblgen/LinalgNamedOpsGen.cpp b/mlir/tools/mlir-tblgen/LinalgNamedOpsGen.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-tblgen/LinalgNamedOpsGen.cpp @@ -0,0 +1,105 @@ +//===- LinalgNamedOpsGen.cpp - MLIR Linalg op generator -------------------===// +// +// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// LinalgNamedOpsGen uses the description of structured operations in the Linalg +// dialect to generate op definitions, parsers, pretty-printers and matchers +// (e.g. linalg.matmul). +// +//===----------------------------------------------------------------------===// + +#include "DocGenUtilities.h" +#include "mlir/Support/STLExtras.h" +#include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/ODSDialectHook.h" +#include "mlir/TableGen/OpClass.h" +#include "mlir/TableGen/Operator.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/Signals.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +using namespace llvm; +using namespace mlir; +using namespace mlir::tblgen; + +using mlir::tblgen::Operator; + +// Fills a function with signature "void iterator_types()". 
+void iteratorTypesBody(const Operator &op, OpMethodBody &body) { + const auto &def = op.getDef(); + auto iteratorTypes = def.getValueAsListOfStrings("iterators_types"); + body << " return SmallVector{"; + interleaveComma(iteratorTypes, body, + [&body](StringRef str) { body << "\"" << str << "\""; }); + body << "};\n"; +} + +// Fills the body of the generated "referenceIndexingMaps" method. +void indexingMapsBody(const Operator &op, OpMethodBody &body) { + const auto &def = op.getDef(); + auto iterators = def.getValueAsListOfStrings("iterators"); + unsigned numDims = iterators.size(); + /* return AffineMap::get({...}, ctx); */ + auto ins = def.getValueAsListOfDefs("input_indexing_maps"); + auto outs = def.getValueAsListOfDefs("output_indexing_maps"); + SmallVector maps(ins.begin(), ins.end()); + maps.append(outs.begin(), outs.end()); + + /* AffineExpr ...; */ + body << " MLIRContext *ctx = getContext();\n"; + interleaveComma(iterators, body << " AffineExpr "); + body << ";\n"; + /* bindDims(ctx, ...); */ + interleaveComma(iterators, body << " bindDims(ctx, "); + body << ");\n"; + body << " return SmallVector{"; + interleaveComma(maps, body, [&body, numDims](const Record *def) { + body << "\n AffineMap::get(" << numDims << ", 0, {"; + interleaveComma(def->getValueAsListOfStrings("indexing_maps"), body, + [&body, numDims](const StringRef &str) { + body << "\n simplifyAffineExpr(" << str << ", " + << numDims << ", 0)"; + }); + body << "})"; + }); + body << "};\n"; +} + +void emitLinalgFunctions(const Operator &op, OpClass &emitClass) { + const auto &def = op.getDef(); + if (def.isSubClassOf("LinalgNamedStructured_Op")) { + { + auto &method = + emitClass.newMethod("llvm::Optional>", + "referenceIteratorTypes"); + auto &body = method.body(); + iteratorTypesBody(op, body); + } + { + auto &method = emitClass.newMethod( + "llvm::Optional>", "referenceIndexingMaps"); + auto &body = method.body(); + indexingMapsBody(op, body); + } + if (def.getValueAsBit("hasLibraryImpl")) { + 
auto &method = emitClass.newMethod("std::string", "getLibraryCallName"); + auto &body = method.body(); + body << "return generateLibraryCallName(getOperation());"; + } + } +} + +struct LinalgODSHookRegistration : public ODSDialectHookRegistration { + LinalgODSHookRegistration() + : ODSDialectHookRegistration("linalg", emitLinalgFunctions) {} +}; + +static LinalgODSHookRegistration linalgODSHookRegistration; diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -16,11 +16,13 @@ #include "mlir/Support/StringExtras.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/ODSDialectHook.h" #include "mlir/TableGen/OpClass.h" #include "mlir/TableGen/OpInterfaces.h" #include "mlir/TableGen/OpTrait.h" #include "mlir/TableGen/Operator.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" @@ -32,6 +34,32 @@ using namespace mlir; using namespace mlir::tblgen; +using llvm::CodeInit; +using llvm::DefInit; +using llvm::formatv; +using llvm::Init; +using llvm::ListInit; +using llvm::Record; +using llvm::RecordKeeper; +using llvm::StringInit; + +//===----------------------------------------------------------------------===// +// Dialect hook registration +//===----------------------------------------------------------------------===// + +static llvm::ManagedStatic> dialectHooks; + +ODSDialectHookRegistration::ODSDialectHookRegistration( + StringRef dialectName, DialectEmitFunction emitFn) { + bool inserted = dialectHooks->try_emplace(dialectName, emitFn).second; + assert(inserted && "Multiple ODS hooks for the same dialect!"); + (void)inserted; +} + +//===----------------------------------------------------------------------===// +// Static string definitions 
+//===----------------------------------------------------------------------===// + static const char *const tblgenNamePrefix = "tblgen_"; static const char *const generatedArgName = "odsArg"; static const char *const builderOpState = "odsState"; @@ -292,6 +320,7 @@ verifyCtx.withOp("(*this->getOperation())"); genTraits(); + // Generate C++ code for various op methods. The order here determines the // methods in the generated file. genOpAsmInterface(); @@ -308,6 +337,13 @@ genFolderDecls(); genOpInterfaceMethods(); generateOpFormat(op, opClass); + + // If a dialect hook is registered for this op's dialect, emit dialect + // specific content. + auto dialectHookIt = dialectHooks->find(op.getDialectName()); + if (dialectHookIt != dialectHooks->end()) { + dialectHookIt->second(op, opClass); + } } void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {