diff --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h --- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h +++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h @@ -9,6 +9,7 @@ #ifndef MLIR_CONVERSION_LINALGTOSTANDARD_LINALGTOSTANDARD_H_ #define MLIR_CONVERSION_LINALGTOSTANDARD_LINALGTOSTANDARD_H_ +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -16,6 +17,63 @@ template class OperationPass; +namespace linalg { + +//===----------------------------------------------------------------------===// +// Patterns to convert a LinalgOp to std.call @external library implementation. +//===----------------------------------------------------------------------===// +// These patterns are exposed individually because they are expected to be +// typically used individually. + +// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()` +// function. The implementation of the function can be either in the same module +// or in an externally linked library. +// This is a generic entry point for all LinalgOp, except for CopyOp and +// IndexedGenericOp, for which more specialized patterns are provided. +class LinalgOpToLibraryCallRewrite : public RewritePattern { +public: + LinalgOpToLibraryCallRewrite() + : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; +}; + +/// Rewrite pattern specialization for CopyOp, kicks in when both input and +/// output permutations are left unspecified or are the identity. 
+class CopyOpToLibraryCallRewrite : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(CopyOp op, + PatternRewriter &rewriter) const override; +}; + +/// Rewrite CopyOp with permutations into a sequence of TransposeOp and +/// permutation-free CopyOp. This interplays with TransposeOpConversion and +/// LinalgConversion to create a path to the LLVM dialect. +class CopyTransposeRewrite : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(CopyOp op, + PatternRewriter &rewriter) const override; +}; + +/// Conversion pattern specialization for IndexedGenericOp, has special handling +/// for the extra index operands. +class IndexedGenericOpToLibraryCallRewrite + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(IndexedGenericOp op, + PatternRewriter &rewriter) const override; +}; + +/// Populate the given list with patterns that convert from Linalg to Standard. +void populateLinalgToStandardConversionPatterns( + OwningRewritePatternList &patterns, MLIRContext *ctx); + +} // namespace linalg + /// Create a pass to convert Linalg operations to the Standard dialect. std::unique_ptr> createConvertLinalgToStandardPass(); diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -754,98 +754,6 @@ PatternRewriter &rewriter) const override; }; -//===----------------------------------------------------------------------===// -// Patterns to convert a LinalgOp to std.call @external library implementation. -//===----------------------------------------------------------------------===// -// Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()` -// function. 
The implementation of the function can be either in the same module -// or in an externally linked library. -// This is a generic entry point for all LinalgOp, except for CopyOp and -// IndexedGenericOp, for which omre specialized patterns are provided. -class LinalgOpToLibraryCallRewrite : public RewritePattern { -public: - LinalgOpToLibraryCallRewrite() - : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override; -}; - -/// Rewrite pattern specialization for CopyOp, kicks in when both input and -/// output permutations are left unspecified or are the identity. -class CopyOpToLibraryCallRewrite : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(CopyOp op, - PatternRewriter &rewriter) const override; -}; - -/// Rewrite CopyOp with permutations into a sequence of TransposeOp and -/// permutation-free CopyOp. This interplays with TransposeOpConversion and -/// LinalgConversion to create a path to the LLVM dialect. -class CopyTransposeRewrite : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(CopyOp op, - PatternRewriter &rewriter) const override; -}; - -/// Conversion pattern specialization for IndexedGenericOp, has special handling -/// for the extra index operands. -class IndexedGenericOpToLibraryCallRewrite - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(IndexedGenericOp op, - PatternRewriter &rewriter) const override; -}; - -/// Populate the given list with patterns that convert from Linalg to Standard. -void populateLinalgToStandardConversionPatterns( - OwningRewritePatternList &patterns, MLIRContext *ctx); - -//===----------------------------------------------------------------------===// -// Buffer allocation patterns. 
-//===----------------------------------------------------------------------===// - -/// Generic BufferizeConversionPattern that matches any Operation* and -/// dispatches internally. This avoids template instantiating one pattern for -/// each LinalgOp op. -class LinalgOpConverter : public BufferizeConversionPattern { -public: - LinalgOpConverter(MLIRContext *context, BufferizeTypeConverter &converter) - : BufferizeConversionPattern(context, converter) {} - - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final; -}; - -/// TensorConstantOp conversion inserts a linearized 1-D vector constant that is -/// stored in memory. A linalg.reshape is introduced to convert to the desired -/// n-D buffer form. -class TensorConstantOpConverter - : public BufferizeOpConversionPattern { -public: - using BufferizeOpConversionPattern::BufferizeOpConversionPattern; - - LogicalResult - matchAndRewrite(ConstantOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final; -}; - -/// TensorCastOp converts 1-1 to MemRefCastOp. -class TensorCastOpConverter - : public BufferizeOpConversionPattern { -public: - using BufferizeOpConversionPattern< - TensorCastOp>::BufferizeOpConversionPattern; - - LogicalResult - matchAndRewrite(TensorCastOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final; -}; - //===----------------------------------------------------------------------===// // Support for staged pattern application. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/Function.h" #include "mlir/IR/Operation.h" @@ -185,105 +186,124 @@ rewriter.replaceOp(linalgOp, outputs); } -LogicalResult mlir::linalg::LinalgOpConverter::matchAndRewrite( - Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - LinalgOp linalgOp = dyn_cast(op); - if (!linalgOp) - return failure(); - - // We abuse the GenericOpAdaptor here. - // TODO: Manually create an Adaptor that captures inputs, output_buffers and - // init_tensors for all linalg::LinalgOp interface ops. - linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary()); - - // All inputs need to be turned into buffers first. Until then, bail out. - if (llvm::any_of(adaptor.inputs(), - [](Value in) { return !in.getType().isa(); })) - return failure(); - - // All init_tensors need to be turned into buffers first. Until then, bail - // out. - if (llvm::any_of(adaptor.init_tensors(), - [](Value in) { return !in.getType().isa(); })) - return failure(); - - Location loc = linalgOp.getLoc(); - SmallVector newOutputBuffers(adaptor.output_buffers().begin(), - adaptor.output_buffers().end()); - - if (failed(allocateBuffersForResults(loc, linalgOp, adaptor, newOutputBuffers, - rewriter))) { - linalgOp.emitOpError() << "Failed to allocate buffers for tensor results."; - return failure(); - } +//===----------------------------------------------------------------------===// +// Buffer allocation patterns. 
+//===----------------------------------------------------------------------===// + +namespace { +/// Generic BufferizeConversionPattern that matches any Operation* and +/// dispatches internally. This avoids template instantiating one pattern for +/// each LinalgOp op. +class LinalgOpConverter : public BufferizeConversionPattern { +public: + LinalgOpConverter(MLIRContext *context, BufferizeTypeConverter &converter) + : BufferizeConversionPattern(context, converter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + + LinalgOp linalgOp = dyn_cast(op); + if (!linalgOp) + return failure(); + + // We abuse the GenericOpAdaptor here. + // TODO: Manually create an Adaptor that captures inputs, output_buffers and + // init_tensors for all linalg::LinalgOp interface ops. + linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary()); + + // All inputs need to be turned into buffers first. Until then, bail out. + if (llvm::any_of(adaptor.inputs(), + [](Value in) { return !in.getType().isa(); })) + return failure(); + + // All init_tensors need to be turned into buffers first. Until then, bail + // out. + if (llvm::any_of(adaptor.init_tensors(), + [](Value in) { return !in.getType().isa(); })) + return failure(); + + Location loc = linalgOp.getLoc(); + SmallVector newOutputBuffers(adaptor.output_buffers().begin(), + adaptor.output_buffers().end()); + + if (failed(allocateBuffersForResults(loc, linalgOp, adaptor, + newOutputBuffers, rewriter))) { + linalgOp.emitOpError() + << "Failed to allocate buffers for tensor results."; + return failure(); + } + + // Delegate to the linalg generic pattern. + if (auto genericOp = dyn_cast(op)) { + finalizeBufferAllocation(rewriter, genericOp, adaptor.inputs(), + newOutputBuffers); + return success(); + } - // Delegate to the linalg generic pattern. 
- if (auto genericOp = dyn_cast(op)) { - finalizeBufferAllocation(rewriter, genericOp, adaptor.inputs(), + finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(), newOutputBuffers); return success(); } +}; +} // namespace - finalizeBufferAllocation(rewriter, linalgOp, adaptor.inputs(), - newOutputBuffers); - return success(); -} - -LogicalResult mlir::linalg::TensorConstantOpConverter::matchAndRewrite( - ConstantOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - RankedTensorType rankedTensorType = op.getType().dyn_cast(); - if (!rankedTensorType) - return failure(); - if (llvm::any_of(rankedTensorType.getShape(), [](int64_t s) { - return s == 0 || ShapedType::isDynamic(s); - })) - return failure(); - - int64_t nElements = 1; - for (int64_t s : rankedTensorType.getShape()) - nElements *= s; - Type elementType = rankedTensorType.getElementType(); - MemRefType memrefType = - converter.convertType(op.getType()).cast(); - VectorType flatVectorType = VectorType::get({nElements}, elementType); - MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType); - MemRefType flatMemrefType = MemRefType::get({nElements}, elementType); - - Location loc = op.getLoc(); - auto attr = op.getValue().cast(); - Value alloc = - rewriter.create(loc, memrefOfFlatVectorType, ValueRange{}); - Value cstVec = rewriter.create(loc, flatVectorType, - attr.reshape(flatVectorType)); - rewriter.create(loc, cstVec, alloc); - - Value memref = - rewriter.create(loc, flatMemrefType, alloc); - if (rankedTensorType.getRank() > 1) { - // Introduce a linalg.reshape to flatten the memref. 
- AffineMap collapseAllDims = AffineMap::getMultiDimIdentityMap( - /*numDims=*/rankedTensorType.getRank(), op.getContext()); - memref = rewriter.create( - loc, memrefType, memref, - rewriter.getAffineMapArrayAttr(collapseAllDims)); - } - rewriter.replaceOp(op, memref); +namespace { +/// TensorConstantOp conversion inserts a linearized 1-D vector constant that is +/// stored in memory. A linalg.reshape is introduced to convert to the desired +/// n-D buffer form. +class TensorConstantOpConverter + : public BufferizeOpConversionPattern { +public: + using BufferizeOpConversionPattern::BufferizeOpConversionPattern; + + LogicalResult + matchAndRewrite(ConstantOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + + RankedTensorType rankedTensorType = + op.getType().dyn_cast(); + if (!rankedTensorType) + return failure(); + if (llvm::any_of(rankedTensorType.getShape(), [](int64_t s) { + return s == 0 || ShapedType::isDynamic(s); + })) + return failure(); - return success(); -} + int64_t nElements = 1; + for (int64_t s : rankedTensorType.getShape()) + nElements *= s; + Type elementType = rankedTensorType.getElementType(); + MemRefType memrefType = + converter.convertType(op.getType()).cast(); + VectorType flatVectorType = VectorType::get({nElements}, elementType); + MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType); + MemRefType flatMemrefType = MemRefType::get({nElements}, elementType); + + Location loc = op.getLoc(); + auto attr = op.getValue().cast(); + Value alloc = + rewriter.create(loc, memrefOfFlatVectorType, ValueRange{}); + Value cstVec = rewriter.create(loc, flatVectorType, + attr.reshape(flatVectorType)); + rewriter.create(loc, cstVec, alloc); + + Value memref = + rewriter.create(loc, flatMemrefType, alloc); + if (rankedTensorType.getRank() > 1) { + // Introduce a linalg.reshape to flatten the memref. 
+ AffineMap collapseAllDims = AffineMap::getMultiDimIdentityMap( + /*numDims=*/rankedTensorType.getRank(), op.getContext()); + memref = rewriter.create( + loc, memrefType, memref, + rewriter.getAffineMapArrayAttr(collapseAllDims)); + } + rewriter.replaceOp(op, memref); -LogicalResult mlir::linalg::TensorCastOpConverter::matchAndRewrite( - TensorCastOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - if (op.getType().hasRank()) - return failure(); - Type t = UnrankedMemRefType::get(op.getType().getElementType(), - /*memorySpace=*/0); - rewriter.replaceOpWithNewOp(op, t, operands.front()); - return success(); -} + return success(); + } +}; +} // namespace namespace { @@ -347,6 +367,7 @@ OwningRewritePatternList patterns; populateLinalgBufferizePatterns(&context, converter, patterns); + populateStdBufferizePatterns(&context, converter, patterns); populateWithBufferizeOpConversionPatterns( &context, converter, patterns); @@ -366,7 +387,6 @@ patterns.insert< // clang-format off LinalgOpConverter, - TensorCastOpConverter, TensorConstantOpConverter // clang-format on >(context, converter); diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -31,6 +31,7 @@ MLIRSCFTransforms MLIRPass MLIRStandard + MLIRStandardOpsTransforms MLIRStandardToLLVM MLIRTransforms MLIRTransformUtils