diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -750,95 +750,11 @@ //===----------------------------------------------------------------------===// // Patterns to convert a LinalgOp to std.call @external library implementation. //===----------------------------------------------------------------------===// -// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()` -// function. The implementation of the function can be either in the same module -// or in an externally linked library. -// This is a generic entry point for all LinalgOp, except for CopyOp and -// IndexedGenericOp, for which omre specialized patterns are provided. -class LinalgOpToLibraryCallRewrite : public RewritePattern { -public: - LinalgOpToLibraryCallRewrite() - : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override; -}; - -/// Rewrite pattern specialization for CopyOp, kicks in when both input and -/// output permutations are left unspecified or are the identity. -class CopyOpToLibraryCallRewrite : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(CopyOp op, - PatternRewriter &rewriter) const override; -}; - -/// Rewrite CopyOp with permutations into a sequence of TransposeOp and -/// permutation-free CopyOp. This interplays with TransposeOpConversion and -/// LinalgConversion to create a path to the LLVM dialect. 
-class CopyTransposeRewrite : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(CopyOp op, - PatternRewriter &rewriter) const override; -}; - -/// Conversion pattern specialization for IndexedGenericOp, has special handling -/// for the extra index operands. -class IndexedGenericOpToLibraryCallRewrite - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(IndexedGenericOp op, - PatternRewriter &rewriter) const override; -}; /// Populate the given list with patterns that convert from Linalg to Standard. void populateLinalgToStandardConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx); -//===----------------------------------------------------------------------===// -// Buffer allocation patterns. -//===----------------------------------------------------------------------===// - -/// Generic BufferizeConversionPattern that matches any Operation* and -/// dispatches internally. This avoids template instantiating one pattern for -/// each LinalgOp op. -class LinalgOpConverter : public BufferizeConversionPattern { -public: - LinalgOpConverter(MLIRContext *context, BufferizeTypeConverter &converter) - : BufferizeConversionPattern(context, converter) {} - - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final; -}; - -/// TensorConstantOp conversion inserts a linearized 1-D vector constant that is -/// stored in memory. A linalg.reshape is introduced to convert to the desired -/// n-D buffer form. -class TensorConstantOpConverter - : public BufferizeOpConversionPattern { -public: - using BufferizeOpConversionPattern::BufferizeOpConversionPattern; - - LogicalResult - matchAndRewrite(ConstantOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final; -}; - -/// TensorCastOp converts 1-1 to MemRefCastOp. 
-class TensorCastOpConverter - : public BufferizeOpConversionPattern { -public: - using BufferizeOpConversionPattern< - TensorCastOp>::BufferizeOpConversionPattern; - - LogicalResult - matchAndRewrite(TensorCastOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final; -}; - //===----------------------------------------------------------------------===// // Support for staged pattern application. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -98,89 +98,128 @@ return res; } -LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite( - Operation *op, PatternRewriter &rewriter) const { - // Only LinalgOp for which there is no specialized pattern go through this. - if (!isa(op) || isa(op) || isa(op)) - return failure(); +namespace { +// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()` +// function. The implementation of the function can be either in the same module +// or in an externally linked library. +// This is a generic entry point for all LinalgOp, except for CopyOp and +// IndexedGenericOp, for which more specialized patterns are provided. 
+class LinalgOpToLibraryCallRewrite : public RewritePattern { +public: + LinalgOpToLibraryCallRewrite() + : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} - auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); - if (!libraryCallName) - return failure(); + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final { - rewriter.replaceOpWithNewOp( - op, libraryCallName.getValue(), TypeRange(), - createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(), - op->getOperands())); - return success(); -} + // Only LinalgOps for which there is no specialized pattern go through this. + if (!isa(op) || isa(op) || isa(op)) + return failure(); -LogicalResult mlir::linalg::CopyOpToLibraryCallRewrite::matchAndRewrite( - CopyOp op, PatternRewriter &rewriter) const { - auto inputPerm = op.inputPermutation(); - if (inputPerm.hasValue() && !inputPerm->isIdentity()) - return failure(); - auto outputPerm = op.outputPermutation(); - if (outputPerm.hasValue() && !outputPerm->isIdentity()) - return failure(); + auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); + if (!libraryCallName) + return failure(); - auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); - if (!libraryCallName) - return failure(); + rewriter.replaceOpWithNewOp( + op, libraryCallName.getValue(), TypeRange(), + createTypeCanonicalizedMemRefOperands(rewriter, op->getLoc(), + op->getOperands())); + return success(); + } +}; +} // namespace - rewriter.replaceOpWithNewOp( - op, libraryCallName.getValue(), TypeRange(), - createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), - op.getOperands())); - return success(); -} +namespace { +/// Rewrite pattern specialization for CopyOp, kicks in when both input and +/// output permutations are left unspecified or are the identity. 
+class CopyOpToLibraryCallRewrite : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(CopyOp op, + PatternRewriter &rewriter) const final { + auto inputPerm = op.inputPermutation(); + if (inputPerm.hasValue() && !inputPerm->isIdentity()) + return failure(); + auto outputPerm = op.outputPermutation(); + if (outputPerm.hasValue() && !outputPerm->isIdentity()) + return failure(); -LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite( - CopyOp op, PatternRewriter &rewriter) const { - Value in = op.input(), out = op.output(); + auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); + if (!libraryCallName) + return failure(); - // If either inputPerm or outputPerm are non-identities, insert transposes. - auto inputPerm = op.inputPermutation(); - if (inputPerm.hasValue() && !inputPerm->isIdentity()) - in = rewriter.create(op.getLoc(), in, - AffineMapAttr::get(*inputPerm)); - auto outputPerm = op.outputPermutation(); - if (outputPerm.hasValue() && !outputPerm->isIdentity()) - out = rewriter.create(op.getLoc(), out, - AffineMapAttr::get(*outputPerm)); + rewriter.replaceOpWithNewOp( + op, libraryCallName.getValue(), TypeRange(), + createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), + op.getOperands())); + return success(); + } +}; +} // namespace - // If nothing was transposed, fail and let the conversion kick in. - if (in == op.input() && out == op.output()) - return failure(); +namespace { +/// Rewrite CopyOp with permutations into a sequence of TransposeOp and +/// permutation-free CopyOp. This interplays with TransposeOpConversion and +/// LinalgConversion to create a path to the LLVM dialect. 
+class CopyTransposeRewrite : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(CopyOp op, + PatternRewriter &rewriter) const final { + Value in = op.input(), out = op.output(); - rewriter.replaceOpWithNewOp(op, in, out); - return success(); -} + // If either inputPerm or outputPerm are non-identities, insert transposes. + auto inputPerm = op.inputPermutation(); + if (inputPerm.hasValue() && !inputPerm->isIdentity()) + in = rewriter.create(op.getLoc(), in, + AffineMapAttr::get(*inputPerm)); + auto outputPerm = op.outputPermutation(); + if (outputPerm.hasValue() && !outputPerm->isIdentity()) + out = rewriter.create(op.getLoc(), out, + AffineMapAttr::get(*outputPerm)); -LogicalResult -mlir::linalg::IndexedGenericOpToLibraryCallRewrite::matchAndRewrite( - IndexedGenericOp op, PatternRewriter &rewriter) const { - auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); - if (!libraryCallName) - return failure(); + // If nothing was transposed, fail and let the conversion kick in. + if (in == op.input() && out == op.output()) + return failure(); - // TODO: Use induction variables values instead of zeros, when - // IndexedGenericOp is tiled. 
- auto zero = rewriter.create( - op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); - auto indexedGenericOp = cast(op); - auto numLoops = indexedGenericOp.getNumLoops(); - SmallVector operands; - operands.reserve(numLoops + op.getNumOperands()); - for (unsigned i = 0; i < numLoops; ++i) - operands.push_back(zero); - for (auto operand : op.getOperands()) - operands.push_back(operand); - rewriter.replaceOpWithNewOp( - op, libraryCallName.getValue(), TypeRange(), - createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands)); - return success(); -} + rewriter.replaceOpWithNewOp(op, in, out); + return success(); + } +}; +} // namespace + +namespace { +/// Conversion pattern specialization for IndexedGenericOp, has special handling +/// for the extra index operands. +class IndexedGenericOpToLibraryCallRewrite + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(IndexedGenericOp op, + PatternRewriter &rewriter) const final { + auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); + if (!libraryCallName) + return failure(); + + // TODO: Use induction variable values instead of zeros, when + // IndexedGenericOp is tiled. + auto zero = rewriter.create( + op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); + auto indexedGenericOp = cast(op); + auto numLoops = indexedGenericOp.getNumLoops(); + SmallVector operands; + operands.reserve(numLoops + op.getNumOperands()); + for (unsigned i = 0; i < numLoops; ++i) + operands.push_back(zero); + for (auto operand : op.getOperands()) + operands.push_back(operand); + rewriter.replaceOpWithNewOp( + op, libraryCallName.getValue(), TypeRange(), + createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands)); + return success(); + } +}; +} // namespace /// Populate the given list with patterns that convert from Linalg to Standard. 
void mlir::linalg::populateLinalgToStandardConversionPatterns( diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/Function.h" #include "mlir/IR/Operation.h" @@ -185,105 +186,124 @@ rewriter.replaceOp(linalgOp, outputs); } -LogicalResult mlir::linalg::LinalgOpConverter::matchAndRewrite( - Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - LinalgOp linalgOp = dyn_cast(op); - if (!linalgOp) - return failure(); +//===----------------------------------------------------------------------===// +// Buffer allocation patterns. +//===----------------------------------------------------------------------===// - // We abuse the GenericOpAdaptor here. - // TODO: Manually create an Adaptor that captures inputs, output_buffers and - // init_tensors for all linalg::LinalgOp interface ops. - linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary()); +namespace { +/// Generic BufferizeConversionPattern that matches any Operation* and +/// dispatches internally. This avoids template instantiating one pattern for +/// each LinalgOp op. +class LinalgOpConverter : public BufferizeConversionPattern { +public: + LinalgOpConverter(MLIRContext *context, BufferizeTypeConverter &converter) + : BufferizeConversionPattern(context, converter) {} - // All inputs need to be turned into buffers first. Until then, bail out. 
- if (llvm::any_of(adaptor.inputs(), - [](Value in) { return !in.getType().isa(); })) - return failure(); + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { - // All init_tensors need to be turned into buffers first. Until then, bail - // out. - if (llvm::any_of(adaptor.init_tensors(), - [](Value in) { return !in.getType().isa(); })) - return failure(); + LinalgOp linalgOp = dyn_cast(op); + if (!linalgOp) + return failure(); - Location loc = linalgOp.getLoc(); - SmallVector newOutputBuffers(adaptor.output_buffers().begin(), - adaptor.output_buffers().end()); + // We abuse the GenericOpAdaptor here. + // TODO: Manually create an Adaptor that captures inputs, output_buffers and + // init_tensors for all linalg::LinalgOp interface ops. + linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary()); - if (failed(allocateBuffersForResults(loc, linalgOp, adaptor, newOutputBuffers, - rewriter))) { - linalgOp.emitOpError() << "Failed to allocate buffers for tensor results."; - return failure(); - } + // All inputs need to be turned into buffers first. Until then, bail out. + if (llvm::any_of(adaptor.inputs(), + [](Value in) { return !in.getType().isa(); })) + return failure(); - // Delegate to the linalg generic pattern. - if (auto genericOp = dyn_cast(op)) { - finalizeBufferAllocation(rewriter, genericOp, adaptor.inputs(), + // All init_tensors need to be turned into buffers first. Until then, bail + // out. 
+ if (llvm::any_of(adaptor.init_tensors(), + [](Value in) { return !in.getType().isa(); })) + return failure(); + + Location loc = linalgOp.getLoc(); + SmallVector newOutputBuffers(adaptor.output_buffers().begin(), + adaptor.output_buffers().end()); + + if (failed(allocateBuffersForResults(loc, linalgOp, adaptor, + newOutputBuffers, rewriter))) { + linalgOp.emitOpError() + << "Failed to allocate buffers for tensor results."; + return failure(); + } + + // Delegate to the linalg generic pattern. + if (auto genericOp = dyn_cast(op)) { + finalizeBufferAllocation(rewriter, genericOp, adaptor.inputs(), + newOutputBuffers); + return success(); + } + + finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(), newOutputBuffers); return success(); } +}; +} // namespace - finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(), - newOutputBuffers); - return success(); -} +namespace { +/// TensorConstantOp conversion inserts a linearized 1-D vector constant that is +/// stored in memory. A linalg.reshape is introduced to convert to the desired +/// n-D buffer form. 
+class TensorConstantOpConverter + : public BufferizeOpConversionPattern { +public: + using BufferizeOpConversionPattern::BufferizeOpConversionPattern; -LogicalResult mlir::linalg::TensorConstantOpConverter::matchAndRewrite( - ConstantOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - RankedTensorType rankedTensorType = op.getType().dyn_cast(); - if (!rankedTensorType) - return failure(); - if (llvm::any_of(rankedTensorType.getShape(), [](int64_t s) { - return s == 0 || ShapedType::isDynamic(s); - })) - return failure(); + LogicalResult + matchAndRewrite(ConstantOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { - int64_t nElements = 1; - for (int64_t s : rankedTensorType.getShape()) - nElements *= s; - Type elementType = rankedTensorType.getElementType(); - MemRefType memrefType = - converter.convertType(op.getType()).cast(); - VectorType flatVectorType = VectorType::get({nElements}, elementType); - MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType); - MemRefType flatMemrefType = MemRefType::get({nElements}, elementType); + RankedTensorType rankedTensorType = + op.getType().dyn_cast(); + if (!rankedTensorType) + return failure(); + if (llvm::any_of(rankedTensorType.getShape(), [](int64_t s) { + return s == 0 || ShapedType::isDynamic(s); + })) + return failure(); - Location loc = op.getLoc(); - auto attr = op.getValue().cast(); - Value alloc = - rewriter.create(loc, memrefOfFlatVectorType, ValueRange{}); - Value cstVec = rewriter.create(loc, flatVectorType, - attr.reshape(flatVectorType)); - rewriter.create(loc, cstVec, alloc); + int64_t nElements = 1; + for (int64_t s : rankedTensorType.getShape()) + nElements *= s; + Type elementType = rankedTensorType.getElementType(); + MemRefType memrefType = + converter.convertType(op.getType()).cast(); + VectorType flatVectorType = VectorType::get({nElements}, elementType); + MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType); + 
MemRefType flatMemrefType = MemRefType::get({nElements}, elementType); - Value memref = - rewriter.create(loc, flatMemrefType, alloc); - if (rankedTensorType.getRank() > 1) { - // Introduce a linalg.reshape to flatten the memref. - AffineMap collapseAllDims = AffineMap::getMultiDimIdentityMap( - /*numDims=*/rankedTensorType.getRank(), op.getContext()); - memref = rewriter.create( - loc, memrefType, memref, - rewriter.getAffineMapArrayAttr(collapseAllDims)); + Location loc = op.getLoc(); + auto attr = op.getValue().cast(); + Value alloc = + rewriter.create(loc, memrefOfFlatVectorType, ValueRange{}); + Value cstVec = rewriter.create(loc, flatVectorType, + attr.reshape(flatVectorType)); + rewriter.create(loc, cstVec, alloc); + + Value memref = + rewriter.create(loc, flatMemrefType, alloc); + if (rankedTensorType.getRank() > 1) { + // Introduce a linalg.reshape to expand the flat 1-D memref to n-D. + AffineMap collapseAllDims = AffineMap::getMultiDimIdentityMap( + /*numDims=*/rankedTensorType.getRank(), op.getContext()); + memref = rewriter.create( + loc, memrefType, memref, + rewriter.getAffineMapArrayAttr(collapseAllDims)); + } + rewriter.replaceOp(op, memref); + + return success(); } - rewriter.replaceOp(op, memref); - - return success(); -} - -LogicalResult mlir::linalg::TensorCastOpConverter::matchAndRewrite( - TensorCastOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - if (op.getType().hasRank()) - return failure(); - Type t = UnrankedMemRefType::get(op.getType().getElementType(), - /*memorySpace=*/0); - rewriter.replaceOpWithNewOp(op, t, operands.front()); - return success(); -} +}; +} // namespace namespace { @@ -347,6 +367,7 @@ OwningRewritePatternList patterns; populateLinalgBufferizePatterns(&context, converter, patterns); + populateStdBufferizePatterns(&context, converter, patterns); populateWithBufferizeOpConversionPatterns( &context, converter, patterns); @@ -366,7 +387,6 @@ patterns.insert< // clang-format off LinalgOpConverter, - 
TensorCastOpConverter, TensorConstantOpConverter // clang-format on >(context, converter); diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -31,6 +31,7 @@ MLIRSCFTransforms MLIRPass MLIRStandard + MLIRStandardOpsTransforms MLIRStandardToLLVM MLIRTransforms MLIRTransformUtils