diff --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h --- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h +++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h @@ -12,15 +12,10 @@ #include "mlir/Transforms/DialectConversion.h" namespace mlir { -class MLIRContext; class ModuleOp; template class OperationPass; -/// Populate the given list with patterns that convert from Linalg to Standard. -void populateLinalgToStandardConversionPatterns( - OwningRewritePatternList &patterns, MLIRContext *ctx); - /// Create a pass to convert Linalg operations to the Standard dialect. std::unique_ptr> createConvertLinalgToStandardPass(); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -502,8 +502,9 @@ getIteratorTypesAttrName(), getSymbolSourceAttrName() }; } - StringRef getLibraryCallName() { - return library_call().hasValue() ? library_call().getValue() : ""; + std::string getLibraryCallName() { + return library_call().hasValue() ? + library_call()->str() : "op_has_no_registered_library_name"; } llvm::Optional getSymbolSource() { auto ss = symbol_source(); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td @@ -594,6 +594,19 @@ llvm::all_of(this->getOperation()->getResults(), isTensorType); }] >, + InterfaceMethod< + /*desc=*/[{ + Return the name registered for this op when lowering to an external + library call. 
+ }], + /*retTy=*/"std::string", + /*methodName=*/"getLibraryCallName", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getLibraryCallName(); + }] + >, //===------------------------------------------------------------------===// // Other static interface methods. diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -328,9 +328,7 @@ /// values must not fold away when tiling. Otherwise, use a more robust /// `tileSizeComputationFunction`. LinalgTilingOptions &setTileSizes(SmallVector ts) { - tileSizeComputationFunction = [=](OpBuilder &, Operation *) { - return ts; - }; + tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; }; return *this; } /// Convenience function to set the `tileSizeComputationFunction` to a @@ -730,6 +728,56 @@ PatternRewriter &rewriter) const override; }; +//===----------------------------------------------------------------------===// +// Patterns to convert a LinalgOp to std.call @external library implementation. +//===----------------------------------------------------------------------===// +// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()` +// function. The implementation of the function can be either in the same module +// or in an externally linked library. +// This is a generic entry point for all LinalgOp, except for CopyOp and +// IndexedGenericOp, for which more specialized patterns are provided. 
+class LinalgOpToLibraryCallRewrite : public RewritePattern { +public: + LinalgOpToLibraryCallRewrite() + : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; +}; + +/// Rewrite pattern specialization for CopyOp, kicks in when both input and +/// output permutations are left unspecified or are the identity. +class CopyOpToLibraryCallRewrite : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(CopyOp op, + PatternRewriter &rewriter) const override; +}; + +/// Rewrite CopyOp with permutations into a sequence of TransposeOp and +/// permutation-free CopyOp. This interplays with TransposeOpConversion and +/// LinalgConversion to create a path to the LLVM dialect. +class CopyTransposeRewrite : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(CopyOp op, + PatternRewriter &rewriter) const override; +}; + +/// Conversion pattern specialization for IndexedGenericOp, has special handling +/// for the extra index operands. +class IndexedGenericOpToLibraryCallRewrite + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(IndexedGenericOp op, + PatternRewriter &rewriter) const override; +}; + +/// Populate the given list with patterns that convert from Linalg to Standard. +void populateLinalgToStandardConversionPatterns( + OwningRewritePatternList &patterns, MLIRContext *ctx); + //===----------------------------------------------------------------------===// // Support for staged pattern application. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -11,6 +11,7 @@ #include "../PassDetail.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -21,10 +22,15 @@ /// generated CallOp. MemRefTypes have their layout canonicalized since the /// information is not used in signature generation. /// Note that static size information is not modified. -template static SmallVector extractOperandTypes(Operation *op) { SmallVector result; result.reserve(op->getNumOperands()); + if (auto indexedGenericOp = dyn_cast(op)) { + auto *ctx = op->getContext(); + auto numLoops = indexedGenericOp.getNumLoops(); + result.reserve(op->getNumOperands() + numLoops); + result.assign(numLoops, IndexType::get(ctx)); + } for (auto type : op->getOperandTypes()) { // The underlying descriptor type (e.g. LLVM) does not have layout // information. Canonicalizing the type at the level of std when going into @@ -37,21 +43,8 @@ return result; } -template <> -SmallVector extractOperandTypes(Operation *op) { - auto *ctx = op->getContext(); - auto indexedGenericOp = cast(op); - auto numLoops = indexedGenericOp.getNumLoops(); - - SmallVector result(numLoops, IndexType::get(ctx)); - auto canonicalizedOperands = extractOperandTypes(op); - result.append(canonicalizedOperands.begin(), canonicalizedOperands.end()); - return result; -} - // Get a SymbolRefAttr containing the library function name for the LinalgOp. // If the library function does not exist, insert a declaration. 
-template static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op, PatternRewriter &rewriter) { auto linalgOp = cast(op); @@ -68,7 +61,7 @@ return fnNameAttr; } - SmallVector inputTypes(extractOperandTypes(op)); + SmallVector inputTypes(extractOperandTypes(op)); assert(op->getNumResults() == 0 && "Library call for linalg operation can be generated only for ops that " "have void return types"); @@ -87,9 +80,7 @@ return fnNameAttr; } -namespace { - -SmallVector +static SmallVector createTypeCanonicalizedMemRefOperands(OpBuilder &b, Location loc, ValueRange operands) { SmallVector res; @@ -107,154 +98,101 @@ return res; } -// LinalgOpConversion creates a new call to the type-canonicalized -// `LinalgOp::getLibraryCallName()` function. -// The implementation of the function can be either in the same module or in an -// externally linked library. -template -class LinalgOpConversion : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(LinalgOp op, - PatternRewriter &rewriter) const override { - auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); - if (!libraryCallName) - return failure(); - - rewriter.replaceOpWithNewOp( - op, libraryCallName.getValue(), TypeRange(), - createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), - op.getOperands())); - return success(); - } -}; - -/// Conversion pattern specialization for CopyOp. This kicks in when both input -/// and output permutations are left unspecified or are the identity. 
-template <> -class LinalgOpConversion : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(CopyOp op, - PatternRewriter &rewriter) const override { - auto inputPerm = op.inputPermutation(); - if (inputPerm.hasValue() && !inputPerm->isIdentity()) - return failure(); - auto outputPerm = op.outputPermutation(); - if (outputPerm.hasValue() && !outputPerm->isIdentity()) - return failure(); - - auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); - if (!libraryCallName) - return failure(); - - rewriter.replaceOpWithNewOp( - op, libraryCallName.getValue(), TypeRange(), - createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), - op.getOperands())); - return success(); - } -}; - -/// Conversion pattern specialization for IndexedGenericOp. -template <> -class LinalgOpConversion - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(IndexedGenericOp op, - PatternRewriter &rewriter) const override { - auto libraryCallName = - getLibraryCallSymbolRef(op, rewriter); - if (!libraryCallName) - return failure(); - - // TODO: Use induction variables values instead of zeros, when - // IndexedGenericOp is tiled. - auto zero = rewriter.create( - op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); - auto indexedGenericOp = cast(op); - auto numLoops = indexedGenericOp.getNumLoops(); - SmallVector operands; - operands.reserve(numLoops + op.getNumOperands()); - for (unsigned i = 0; i < numLoops; ++i) - operands.push_back(zero); - for (auto operand : op.getOperands()) - operands.push_back(operand); - rewriter.replaceOpWithNewOp( - op, libraryCallName.getValue(), TypeRange(), - createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands)); - return success(); - } -}; - -/// A non-conversion rewrite pattern kicks in to convert CopyOp with -/// permutations into a sequence of TransposeOp and permutation-free CopyOp. 
-/// This interplays together with TransposeOpConversion and -/// LinalgConversion to create a path to the LLVM dialect. -class CopyTransposeConversion : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(CopyOp op, - PatternRewriter &rewriter) const override { - Value in = op.input(), out = op.output(); +LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite( + Operation *op, PatternRewriter &rewriter) const { + // Only LinalgOp for which there is no specialized pattern go through this. + if (!isa(op) || isa(op) || isa(op)) + return failure(); + + auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); + if (!libraryCallName) + return failure(); + + rewriter.replaceOpWithNewOp( + op, libraryCallName.getValue(), TypeRange(), + createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(), + op->getOperands())); + return success(); +} - // If either inputPerm or outputPerm are non-identities, insert transposes. 
- auto inputPerm = op.inputPermutation(); - if (inputPerm.hasValue() && !inputPerm->isIdentity()) - in = rewriter.create(op.getLoc(), in, - AffineMapAttr::get(*inputPerm)); - auto outputPerm = op.outputPermutation(); - if (outputPerm.hasValue() && !outputPerm->isIdentity()) - out = rewriter.create(op.getLoc(), out, - AffineMapAttr::get(*outputPerm)); +LogicalResult mlir::linalg::CopyOpToLibraryCallRewrite::matchAndRewrite( + CopyOp op, PatternRewriter &rewriter) const { + auto inputPerm = op.inputPermutation(); + if (inputPerm.hasValue() && !inputPerm->isIdentity()) + return failure(); + auto outputPerm = op.outputPermutation(); + if (outputPerm.hasValue() && !outputPerm->isIdentity()) + return failure(); + + auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); + if (!libraryCallName) + return failure(); + + rewriter.replaceOpWithNewOp( + op, libraryCallName.getValue(), TypeRange(), + createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), + op.getOperands())); + return success(); +} - // If nothing was transposed, fail and let the conversion kick in. - if (in == op.input() && out == op.output()) - return failure(); +LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite( + CopyOp op, PatternRewriter &rewriter) const { + Value in = op.input(), out = op.output(); + + // If either inputPerm or outputPerm are non-identities, insert transposes. + auto inputPerm = op.inputPermutation(); + if (inputPerm.hasValue() && !inputPerm->isIdentity()) + in = rewriter.create(op.getLoc(), in, + AffineMapAttr::get(*inputPerm)); + auto outputPerm = op.outputPermutation(); + if (outputPerm.hasValue() && !outputPerm->isIdentity()) + out = rewriter.create(op.getLoc(), out, + AffineMapAttr::get(*outputPerm)); + + // If nothing was transposed, fail and let the conversion kick in. 
+ if (in == op.input() && out == op.output()) + return failure(); + + rewriter.replaceOpWithNewOp(op, in, out); + return success(); +} - rewriter.replaceOpWithNewOp(op, in, out); - return success(); - } -}; -} // namespace +LogicalResult +mlir::linalg::IndexedGenericOpToLibraryCallRewrite::matchAndRewrite( + IndexedGenericOp op, PatternRewriter &rewriter) const { + auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); + if (!libraryCallName) + return failure(); + + // TODO: Use induction variables values instead of zeros, when + // IndexedGenericOp is tiled. + auto zero = rewriter.create( + op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); + auto indexedGenericOp = cast(op); + auto numLoops = indexedGenericOp.getNumLoops(); + SmallVector operands; + operands.reserve(numLoops + op.getNumOperands()); + for (unsigned i = 0; i < numLoops; ++i) + operands.push_back(zero); + for (auto operand : op.getOperands()) + operands.push_back(operand); + rewriter.replaceOpWithNewOp( + op, libraryCallName.getValue(), TypeRange(), + createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands)); + return success(); +} /// Populate the given list with patterns that convert from Linalg to Standard. -void mlir::populateLinalgToStandardConversionPatterns( +void mlir::linalg::populateLinalgToStandardConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { // TODO: ConvOp conversion needs to export a descriptor with relevant // attribute values such as kernel striding and dilation. // clang-format off patterns.insert< - CopyTransposeConversion, - LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion>(ctx); - // TODO: collect all auto-generated named ops with a tblgen directive. 
- patterns.insert< - LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion>(ctx); + CopyOpToLibraryCallRewrite, + CopyTransposeRewrite, + IndexedGenericOpToLibraryCallRewrite>(ctx); + patterns.insert(); // clang-format on }