diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -302,7 +302,8 @@ // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` // operations were not converted successfully. - if (failed(applyPartialConversion(getFunction(), target, patterns))) + if (failed( + applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -301,7 +301,8 @@ // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` // operations were not converted successfully. - if (failed(applyPartialConversion(getFunction(), target, patterns))) + if (failed( + applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -200,7 +200,7 @@ // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. 
auto module = getOperation(); - if (failed(applyFullConversion(module, target, patterns))) + if (failed(applyFullConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -302,7 +302,8 @@ // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` // operations were not converted successfully. - if (failed(applyPartialConversion(getFunction(), target, patterns))) + if (failed( + applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -200,7 +200,7 @@ // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. 
auto module = getOperation(); - if (failed(applyFullConversion(module, target, patterns))) + if (failed(applyFullConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -17,8 +17,8 @@ #include "llvm/ADT/SmallBitVector.h" namespace mlir { - class BufferizeTypeConverter; +class FrozenRewritePatternList; namespace linalg { @@ -844,8 +844,8 @@ //===----------------------------------------------------------------------===// /// Helper function to allow applying rewrite patterns, interleaved with more /// global transformations, in a staged fashion: -/// 1. the first stage consists of a list of OwningRewritePatternList. Each -/// OwningRewritePatternList in this list is applied once, in order. +/// 1. the first stage consists of a list of FrozenRewritePatternList. Each +/// FrozenRewritePatternList in this list is applied once, in order. /// 2. the second stage consists of a single OwningRewritePattern that is /// applied greedily until convergence. /// 3. the third stage consists of applying a lambda, generally used for @@ -853,8 +853,8 @@ /// transformations where patterns can be ordered and applied at a finer /// granularity than a sequence of traditional compiler passes. 
LogicalResult applyStagedPatterns( - Operation *op, ArrayRef stage1Patterns, - const OwningRewritePatternList &stage2Patterns, + Operation *op, ArrayRef stage1Patterns, + const FrozenRewritePatternList &stage2Patterns, function_ref stage3Lambda = nullptr); } // namespace linalg diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -440,6 +440,11 @@ PatternListT::size_type size() const { return patterns.size(); } void clear() { patterns.clear(); } + /// Take ownership of the patterns held by this list. + std::vector> takePatterns() { + return std::move(patterns); + } + //===--------------------------------------------------------------------===// // Pattern Insertion //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h @@ -0,0 +1,38 @@ +//===- FrozenRewritePatternList.h - FrozenRewritePatternList ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_REWRITE_FROZENREWRITEPATTERNLIST_H +#define MLIR_REWRITE_FROZENREWRITEPATTERNLIST_H + +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +/// This class represents a frozen set of patterns that can be processed by a +/// pattern applicator. This class is designed to enable caching pattern lists +/// such that they need not be continuously recomputed. 
+class FrozenRewritePatternList { + using PatternListT = std::vector>; + +public: + /// Freeze the patterns held in `patterns`, and take ownership. + FrozenRewritePatternList(OwningRewritePatternList &&patterns); + + /// Return the patterns held by this list. + iterator_range> + getPatterns() const { + return llvm::make_pointee_range(patterns); + } + +private: + /// The patterns held by this list. + std::vector> patterns; +}; + +} // end namespace mlir + +#endif // MLIR_REWRITE_FROZENREWRITEPATTERNLIST_H diff --git a/mlir/include/mlir/Rewrite/PatternApplicator.h b/mlir/include/mlir/Rewrite/PatternApplicator.h --- a/mlir/include/mlir/Rewrite/PatternApplicator.h +++ b/mlir/include/mlir/Rewrite/PatternApplicator.h @@ -14,7 +14,7 @@ #ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H #define MLIR_REWRITE_PATTERNAPPLICATOR_H -#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternList.h" namespace mlir { class PatternRewriter; @@ -29,8 +29,8 @@ /// `impossibleToMatch`. using CostModel = function_ref; - explicit PatternApplicator(const OwningRewritePatternList &owningPatternList) - : owningPatternList(owningPatternList) {} + explicit PatternApplicator(const FrozenRewritePatternList &frozenPatternList) + : frozenPatternList(frozenPatternList) {} /// Attempt to match and rewrite the given op with any pattern, allowing a /// predicate to decide if a pattern can be applied or not, and hooks for if @@ -71,13 +71,12 @@ function_ref onSuccess); /// The list that owns the patterns used within this applicator. - const OwningRewritePatternList &owningPatternList; - + const FrozenRewritePatternList &frozenPatternList; /// The set of patterns to match for each operation, stable sorted by benefit. - DenseMap> patterns; + DenseMap> patterns; /// The set of patterns that may match against any operation type, stable /// sorted by benefit. 
- SmallVector anyOpPatterns; + SmallVector anyOpPatterns; }; } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -13,9 +13,7 @@ #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_ #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_ -#include "mlir/IR/PatternMatch.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" +#include "mlir/Rewrite/FrozenRewritePatternList.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/StringMap.h" @@ -805,11 +803,11 @@ /// the `unconvertedOps` set will not necessarily be complete.) LLVM_NODISCARD LogicalResult applyPartialConversion(ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, DenseSet *unconvertedOps = nullptr); LLVM_NODISCARD LogicalResult applyPartialConversion(Operation *op, ConversionTarget &target, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, DenseSet *unconvertedOps = nullptr); /// Apply a complete conversion on the given operations, and all nested @@ -818,10 +816,10 @@ /// within 'ops'. LLVM_NODISCARD LogicalResult applyFullConversion(ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns); + const FrozenRewritePatternList &patterns); LLVM_NODISCARD LogicalResult applyFullConversion(Operation *op, ConversionTarget &target, - const OwningRewritePatternList &patterns); + const FrozenRewritePatternList &patterns); /// Apply an analysis conversion on the given operations, and all nested /// operations. This method analyzes which operations would be successfully @@ -833,11 +831,11 @@ /// the regions nested within 'ops'. 
LLVM_NODISCARD LogicalResult applyAnalysisConversion(ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, DenseSet &convertedOps); LLVM_NODISCARD LogicalResult applyAnalysisConversion(Operation *op, ConversionTarget &target, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, DenseSet &convertedOps); } // end namespace mlir diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -14,7 +14,7 @@ #ifndef MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_ #define MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_ -#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternList.h" namespace mlir { @@ -32,11 +32,11 @@ /// LogicalResult applyPatternsAndFoldGreedily(Operation *op, - const OwningRewritePatternList &patterns); + const FrozenRewritePatternList &patterns); /// Rewrite the given regions, which must be isolated from above. LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef regions, - const OwningRewritePatternList &patterns); + const FrozenRewritePatternList &patterns); /// Applies the specified patterns on `op` alone while also trying to fold it, /// by selecting the highest benefits patterns in a greedy manner. Returns @@ -44,7 +44,7 @@ /// was folded away or erased as a result of becoming dead. Note: This does not /// apply any patterns recursively to the regions of `op`. 
LogicalResult applyOpPatternsAndFold(Operation *op, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, bool *erased = nullptr); } // end namespace mlir diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp --- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp +++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp @@ -179,7 +179,8 @@ target.addLegalDialect(); target.addLegalDialect(); target.addIllegalDialect(); - if (failed(applyPartialConversion(getOperation(), target, patterns))) { + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { signalPassFailure(); } } diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -679,7 +679,8 @@ ConversionTarget target(getContext()); target .addLegalDialect(); - if (failed(applyPartialConversion(getOperation(), target, patterns))) + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -723,7 +723,7 @@ target.addDynamicallyLegalOp( [&](CallOp op) { return converter.isLegal(op.getResultTypes()); }); - if (failed(applyPartialConversion(module, target, patterns))) + if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } } // namespace diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ 
b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -240,7 +240,8 @@ populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation); LLVMConversionTarget target(getContext()); - if (failed(applyPartialConversion(getOperation(), target, patterns))) + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -124,17 +124,16 @@ return converter.convertType(MemRefType::Builder(type).setMemorySpace(0)); }); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns, llvmPatterns; // Apply in-dialect lowering first. In-dialect lowering will replace ops // which need to be lowered further, which is not supported by a single // conversion pass. populateGpuRewritePatterns(m.getContext(), patterns); - applyPatternsAndFoldGreedily(m, patterns); - patterns.clear(); + applyPatternsAndFoldGreedily(m, std::move(patterns)); - populateStdToLLVMConversionPatterns(converter, patterns); - populateGpuToNVVMConversionPatterns(converter, patterns); + populateStdToLLVMConversionPatterns(converter, llvmPatterns); + populateGpuToNVVMConversionPatterns(converter, llvmPatterns); LLVMConversionTarget target(getContext()); target.addIllegalDialect(); target.addIllegalOp(); // TODO: Remove once we support replacing non-root ops. 
target.addLegalOp(); - if (failed(applyPartialConversion(m, target, patterns))) + if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -59,16 +59,15 @@ /*useAlignedAlloc =*/false}; LLVMTypeConverter converter(m.getContext(), options); - OwningRewritePatternList patterns; + OwningRewritePatternList patterns, llvmPatterns; populateGpuRewritePatterns(m.getContext(), patterns); - applyPatternsAndFoldGreedily(m, patterns); - patterns.clear(); + applyPatternsAndFoldGreedily(m, std::move(patterns)); - populateVectorToLLVMConversionPatterns(converter, patterns); - populateVectorToROCDLConversionPatterns(converter, patterns); - populateStdToLLVMConversionPatterns(converter, patterns); - populateGpuToROCDLConversionPatterns(converter, patterns); + populateVectorToLLVMConversionPatterns(converter, llvmPatterns); + populateVectorToROCDLConversionPatterns(converter, llvmPatterns); + populateStdToLLVMConversionPatterns(converter, llvmPatterns); + populateGpuToROCDLConversionPatterns(converter, llvmPatterns); LLVMConversionTarget target(getContext()); target.addIllegalDialect(); target.addIllegalOp(); // TODO: Remove once we support replacing non-root ops. 
target.addLegalOp(); - if (failed(applyPartialConversion(m, target, patterns))) + if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp @@ -64,7 +64,7 @@ populateSCFToSPIRVPatterns(context, typeConverter,scfContext, patterns); populateStandardToSPIRVPatterns(context, typeConverter, patterns); - if (failed(applyFullConversion(kernelModules, *target, patterns))) + if (failed(applyFullConversion(kernelModules, *target, std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -335,7 +335,7 @@ LLVMConversionTarget target(getContext()); target.addLegalOp(); - if (failed(applyFullConversion(module, target, patterns))) + if (failed(applyFullConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp @@ -41,7 +41,7 @@ typeConverter.isLegal(&op.getBody()); }); - if (failed(applyFullConversion(module, *target, patterns))) + if (failed(applyFullConversion(module, *target, std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -211,7 +211,7 @@ 
target.addLegalOp(); OwningRewritePatternList patterns; populateLinalgToStandardConversionPatterns(patterns, &getContext()); - if (failed(applyFullConversion(module, target, patterns))) + if (failed(applyFullConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -67,7 +67,7 @@ [&](omp::ParallelOp op) { return converter.isLegal(&op.getRegion()); }); target.addLegalOp(); - if (failed(applyPartialConversion(module, target, patterns))) + if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp @@ -54,7 +54,8 @@ target.addLegalDialect(); target.addLegalDialect(); target.addIllegalOp(); - if (failed(applyPartialConversion(getOperation(), target, patterns))) + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp --- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp +++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp @@ -412,7 +412,8 @@ ConversionTarget target(getContext()); target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - if (failed(applyPartialConversion(getOperation(), target, patterns))) + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp --- 
a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -290,7 +290,7 @@ ConversionTarget target(*context); target.addLegalDialect(); - if (failed(applyPartialConversion(module, target, patterns))) + if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); // Finally, modify the kernel function in SPIR-V modules to avoid symbolic diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp @@ -52,7 +52,7 @@ // conversion. target.addLegalOp(); target.addLegalOp(); - if (failed(applyPartialConversion(module, target, patterns))) + if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -508,7 +508,7 @@ // Apply conversion. 
auto module = getOperation(); - if (failed(applyPartialConversion(module, target, patterns))) + if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -3788,7 +3788,7 @@ populateStdToLLVMConversionPatterns(typeConverter, patterns); LLVMConversionTarget target(getContext()); - if (failed(applyPartialConversion(m, target, patterns))) + if (failed(applyPartialConversion(m, target, std::move(patterns)))) signalPassFailure(); m.setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), StringAttr::get(this->dataLayout, m.getContext())); diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -40,9 +40,8 @@ populateStandardToSPIRVPatterns(context, typeConverter, patterns); populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns); - if (failed(applyPartialConversion(module, *target, patterns))) { + if (failed(applyPartialConversion(module, *target, std::move(patterns)))) return signalPassFailure(); - } } std::unique_ptr> diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -203,7 +203,8 @@ OwningRewritePatternList patterns; auto *context = &getContext(); populateStdLegalizationPatternsForSPIRVLowering(context, patterns); - applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns); + 
applyPatternsAndFoldGreedily(getOperation()->getRegions(), + std::move(patterns)); } std::unique_ptr mlir::createLegalizeStdOpsForSPIRVLoweringPass() { diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1619,7 +1619,7 @@ populateVectorToVectorCanonicalizationPatterns(patterns, &getContext()); populateVectorSlicesLoweringPatterns(patterns, &getContext()); populateVectorContractLoweringPatterns(patterns, &getContext()); - applyPatternsAndFoldGreedily(getOperation(), patterns); + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } // Convert to the LLVM IR dialect. @@ -1632,7 +1632,8 @@ populateStdToLLVMConversionPatterns(converter, patterns); LLVMConversionTarget target(getContext()); - if (failed(applyPartialConversion(getOperation(), target, patterns))) + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -173,7 +173,8 @@ LLVMConversionTarget target(getContext()); target.addLegalDialect(); - if (failed(applyPartialConversion(getOperation(), target, patterns))) + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -710,7 +710,7 @@ auto *context = getFunction().getContext(); populateVectorToSCFConversionPatterns( patterns, context, VectorTransferToSCFOptions().setUnroll(fullUnroll)); - 
applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } }; diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -109,7 +109,7 @@ target->addLegalOp(); target->addLegalOp(); - if (failed(applyFullConversion(module, *target, patterns))) + if (failed(applyFullConversion(module, *target, std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -228,6 +228,7 @@ OwningRewritePatternList patterns; AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext()); AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); + FrozenRewritePatternList frozenPatterns(std::move(patterns)); for (Operation *op : copyOps) - applyOpPatternsAndFold(op, std::move(patterns)); + applyOpPatternsAndFold(op, frozenPatterns); } diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -83,6 +83,7 @@ AffineForOp::getCanonicalizationPatterns(patterns, func.getContext()); AffineIfOp::getCanonicalizationPatterns(patterns, func.getContext()); AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext()); + FrozenRewritePatternList frozenPatterns(std::move(patterns)); func.walk([&](Operation *op) { for (auto attr : op->getAttrs()) { if (auto mapAttr = attr.second.dyn_cast()) @@ -94,6 +95,6 @@ // The simplification of the 
attribute will likely simplify the op. Try to // fold / apply canonicalization patterns when we have affine dialect ops. if (isa(op)) - applyOpPatternsAndFold(op, patterns); + applyOpPatternsAndFold(op, frozenPatterns); }); } diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -159,7 +159,8 @@ OwningRewritePatternList patterns; AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext()); bool erased; - applyOpPatternsAndFold(ifOp, patterns, &erased); + FrozenRewritePatternList frozenPatterns(std::move(patterns)); + applyOpPatternsAndFold(ifOp, frozenPatterns, &erased); if (erased) { if (folded) *folded = true; @@ -189,7 +190,7 @@ // a sequence of affine.fors that are all perfectly nested). applyPatternsAndFoldGreedily( hoistedIfOp.getParentWithTrait(), - std::move(patterns)); + frozenPatterns); return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -350,7 +350,8 @@ populateWithBufferizeOpConversionPatterns( &context, converter, patterns); - if (failed(applyFullConversion(this->getOperation(), target, patterns))) + if (failed(applyFullConversion(this->getOperation(), target, + std::move(patterns)))) this->signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp @@ -31,7 +31,7 @@ // Emplace patterns one at a time while also maintaining a simple chained // state transition. 
unsigned stepCount = 0; - SmallVector stage1Patterns; + SmallVector stage1Patterns; auto zeroState = Identifier::get(std::to_string(stepCount), context); auto currentState = zeroState; for (const std::unique_ptr &t : transformationSequence) { @@ -60,7 +60,7 @@ hoistRedundantVectorTransfers(cast(op)); return success(); }; - linalg::applyStagedPatterns(func, stage1Patterns, stage2Patterns, + linalg::applyStagedPatterns(func, stage1Patterns, std::move(stage2Patterns), stage3Transforms); //===--------------------------------------------------------------------===// @@ -73,7 +73,7 @@ OwningRewritePatternList patterns; patterns.insert( context, vectorTransformsOptions); - applyPatternsAndFoldGreedily(module, patterns); + applyPatternsAndFoldGreedily(module, std::move(patterns)); // Programmatic controlled lowering of vector.contract only. OwningRewritePatternList vectorContractLoweringPatterns; @@ -81,13 +81,14 @@ .insert( vectorTransformsOptions, context); - applyPatternsAndFoldGreedily(module, vectorContractLoweringPatterns); + applyPatternsAndFoldGreedily(module, + std::move(vectorContractLoweringPatterns)); // Programmatic controlled lowering of vector.transfer only. OwningRewritePatternList vectorToLoopsPatterns; populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context, vectorToSCFOptions); - applyPatternsAndFoldGreedily(module, vectorToLoopsPatterns); + applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns)); // Ensure we drop the marker in the end. 
module.walk([](LinalgOp op) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -518,7 +518,7 @@ FoldUnitDimLoops>(context); else populateLinalgFoldUnitExtentDimsPatterns(context, patterns); - applyPatternsAndFoldGreedily(funcOp.getBody(), patterns); + applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); } }; } // namespace diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -913,7 +913,7 @@ OwningRewritePatternList patterns; Operation *op = getOperation(); populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns); - applyPatternsAndFoldGreedily(op->getRegions(), patterns); + applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns)); } }; @@ -926,7 +926,7 @@ OwningRewritePatternList patterns; Operation *op = getOperation(); populateFoldReshapeOpsByLinearizationPatterns(op->getContext(), patterns); - applyPatternsAndFoldGreedily(op->getRegions(), patterns); + applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns)); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -593,7 +593,7 @@ AffineApplyOp::getCanonicalizationPatterns(patterns, context); patterns.insert(context); // Just apply the patterns greedily. 
- applyPatternsAndFoldGreedily(funcOp, patterns); + applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } namespace { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -595,7 +595,7 @@ MLIRContext *ctx = funcOp.getContext(); OwningRewritePatternList patterns; insertTilingPatterns(patterns, options, ctx); - applyPatternsAndFoldGreedily(funcOp, patterns); + applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); applyPatternsAndFoldGreedily(funcOp, getLinalgTilingCanonicalizationPatterns(ctx)); // Drop the marker. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -257,8 +257,8 @@ } LogicalResult mlir::linalg::applyStagedPatterns( - Operation *op, ArrayRef stage1Patterns, - const OwningRewritePatternList &stage2Patterns, + Operation *op, ArrayRef stage1Patterns, + const FrozenRewritePatternList &stage2Patterns, function_ref stage3Lambda) { unsigned iteration = 0; (void)iteration; diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp --- a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp @@ -96,7 +96,7 @@ auto func = getFunction(); auto *context = &getContext(); patterns.insert(context); - applyPatternsAndFoldGreedily(func, patterns); + applyPatternsAndFoldGreedily(func, std::move(patterns)); } std::unique_ptr> mlir::quant::createConvertConstPass() { diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp --- a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp @@ -129,7 
+129,7 @@ auto ctx = func.getContext(); patterns.insert( ctx, &hadFailure); - applyPatternsAndFoldGreedily(func, patterns); + applyPatternsAndFoldGreedily(func, std::move(patterns)); if (hadFailure) signalPassFailure(); } diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp @@ -30,7 +30,7 @@ populateBufferizeMaterializationLegality(target); populateSCFStructuralTypeConversionsAndLegality(context, typeConverter, patterns, target); - if (failed(applyPartialConversion(func, target, patterns))) + if (failed(applyPartialConversion(func, target, std::move(patterns)))) return signalPassFailure(); }; }; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp @@ -107,12 +107,10 @@ // TODO: Change the type for the indirect users such as spv.Load, spv.Store, // spv.FunctionCall and so on. 
- - for (auto spirvModule : module.getOps()) { - if (failed(applyFullConversion(spirvModule, target, patterns))) { + FrozenRewritePatternList frozenPatterns(std::move(patterns)); + for (auto spirvModule : module.getOps()) + if (failed(applyFullConversion(spirvModule, target, frozenPatterns))) signalPassFailure(); - } - } } std::unique_ptr> diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -262,7 +262,7 @@ return op->getDialect()->getNamespace() == spirv::SPIRVDialect::getDialectNamespace(); }); - if (failed(applyPartialConversion(module, target, patterns))) + if (failed(applyPartialConversion(module, target, std::move(patterns)))) return signalPassFailure(); // Walks over all the FuncOps in spirv::ModuleOp to lower the entry point diff --git a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp @@ -26,7 +26,8 @@ populateShapeStructuralTypeConversionsAndLegality(&ctx, typeConverter, patterns, target); - if (failed(applyPartialConversion(getFunction(), target, patterns))) + if (failed( + applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp --- a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp @@ -49,7 +49,7 @@ OwningRewritePatternList patterns; populateRemoveShapeConstraintsPatterns(patterns, &ctx); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } }; diff --git 
a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp --- a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp @@ -67,7 +67,8 @@ ConversionTarget target(getContext()); target.addLegalDialect(); target.addIllegalOp(); - if (failed(mlir::applyPartialConversion(getFunction(), target, patterns))) + if (failed(mlir::applyPartialConversion(getFunction(), target, + std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp @@ -148,8 +148,8 @@ populateStdBufferizePatterns(context, typeConverter, patterns); target.addIllegalOp(); - - if (failed(applyPartialConversion(getFunction(), target, patterns))) + if (failed( + applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp @@ -81,7 +81,8 @@ return op.kind() != AtomicRMWKind::maxf && op.kind() != AtomicRMWKind::minf; }); - if (failed(mlir::applyPartialConversion(getFunction(), target, patterns))) + if (failed(mlir::applyPartialConversion(getFunction(), target, + std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Rewrite/CMakeLists.txt b/mlir/lib/Rewrite/CMakeLists.txt --- a/mlir/lib/Rewrite/CMakeLists.txt +++ b/mlir/lib/Rewrite/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(MLIRRewrite + FrozenRewritePatternList.cpp PatternApplicator.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp 
b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp @@ -0,0 +1,19 @@ +//===- FrozenRewritePatternList.cpp - Frozen Pattern List -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Rewrite/FrozenRewritePatternList.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// FrozenRewritePatternList +//===----------------------------------------------------------------------===// + +FrozenRewritePatternList::FrozenRewritePatternList( + OwningRewritePatternList &&patterns) + : patterns(patterns.takePatterns()) {} diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp --- a/mlir/lib/Rewrite/PatternApplicator.cpp +++ b/mlir/lib/Rewrite/PatternApplicator.cpp @@ -22,28 +22,28 @@ // Separate patterns by root kind to simplify lookup later on. patterns.clear(); anyOpPatterns.clear(); - for (const auto &pat : owningPatternList) { + for (const auto &pat : frozenPatternList.getPatterns()) { // If the pattern is always impossible to match, just ignore it. 
- if (pat->getBenefit().isImpossibleToMatch()) { + if (pat.getBenefit().isImpossibleToMatch()) { LLVM_DEBUG({ llvm::dbgs() - << "Ignoring pattern '" << pat->getRootKind() + << "Ignoring pattern '" << pat.getRootKind() << "' because it is impossible to match (by pattern benefit)\n"; }); continue; } - if (Optional opName = pat->getRootKind()) - patterns[*opName].push_back(pat.get()); + if (Optional opName = pat.getRootKind()) + patterns[*opName].push_back(&pat); else - anyOpPatterns.push_back(pat.get()); + anyOpPatterns.push_back(&pat); } // Sort the patterns using the provided cost model. - llvm::SmallDenseMap benefits; - auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) { + llvm::SmallDenseMap benefits; + auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) { return benefits[lhs] > benefits[rhs]; }; - auto processPatternList = [&](SmallVectorImpl &list) { + auto processPatternList = [&](SmallVectorImpl &list) { // Special case for one pattern in the list, which is the most common case. if (list.size() == 1) { if (model(*list.front()).isImpossibleToMatch()) { @@ -59,7 +59,7 @@ // Collect the dynamic benefits for the current pattern list. benefits.clear(); - for (RewritePattern *pat : list) + for (const Pattern *pat : list) benefits.try_emplace(pat, model(*pat)); // Sort patterns with highest benefit first, and remove those that are @@ -81,8 +81,8 @@ void PatternApplicator::walkAllPatterns( function_ref walk) { - for (auto &it : owningPatternList) - walk(*it); + for (auto &it : frozenPatternList.getPatterns()) + walk(it); } LogicalResult PatternApplicator::matchAndRewrite( @@ -91,7 +91,7 @@ function_ref onFailure, function_ref onSuccess) { // Check to see if there are patterns matching this specific operation type. 
- MutableArrayRef opPatterns; + MutableArrayRef opPatterns; auto patternIt = patterns.find(op->getName()); if (patternIt != patterns.end()) opPatterns = patternIt->second; @@ -104,7 +104,7 @@ auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end(); while (opIt != opE && anyIt != anyE) { // Try to match the pattern providing the most benefit. - RewritePattern *pattern; + const RewritePattern *pattern; if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit()) pattern = *(opIt++); else @@ -118,7 +118,7 @@ // If we break from the loop, then only one of the ranges can still have // elements. Loop over both without checking given that we don't need to // interleave anymore. - for (RewritePattern *pattern : llvm::concat( + for (const RewritePattern *pattern : llvm::concat( llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) { if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure, onSuccess))) diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -32,7 +32,7 @@ op->getCanonicalizationPatterns(patterns, context); Operation *op = getOperation(); - applyPatternsAndFoldGreedily(op->getRegions(), patterns); + applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns)); } }; } // end anonymous namespace diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -503,7 +503,7 @@ /// canonicalization patterns. static void canonicalizeSCC(CallGraph &cg, CGUseList &useList, CallGraphSCC &currentSCC, MLIRContext *context, - const OwningRewritePatternList &canonPatterns) { + const FrozenRewritePatternList &canonPatterns) { // Collect the sets of nodes to canonicalize. SmallVector nodesToCanonicalize; for (auto *node : currentSCC) { @@ -574,7 +574,7 @@ /// the inlining of newly devirtualized calls. 
void inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC &currentSCC, MLIRContext *context, - const OwningRewritePatternList &canonPatterns); + const FrozenRewritePatternList &canonPatterns); }; } // end anonymous namespace @@ -596,13 +596,14 @@ OwningRewritePatternList canonPatterns; for (auto *op : context->getRegisteredOperations()) op->getCanonicalizationPatterns(canonPatterns, context); + FrozenRewritePatternList frozenCanonPatterns(std::move(canonPatterns)); // Run the inline transform in post-order over the SCCs in the callgraph. SymbolTableCollection symbolTable; Inliner inliner(context, cg, symbolTable); CGUseList useList(getOperation(), cg, symbolTable); runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) { - inlineSCC(inliner, useList, scc, context, canonPatterns); + inlineSCC(inliner, useList, scc, context, frozenCanonPatterns); }); // After inlining, make sure to erase any callables proven to be dead. @@ -611,7 +612,7 @@ void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC &currentSCC, MLIRContext *context, - const OwningRewritePatternList &canonPatterns) { + const FrozenRewritePatternList &canonPatterns) { // If we successfully inlined any calls, run some simplifications on the // nodes of the scc. Continue attempting to inline until we reach a fixed // point, or a maximum iteration count. We canonicalize here as it may diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1459,7 +1459,7 @@ using LegalizationAction = ConversionTarget::LegalizationAction; OperationLegalizer(ConversionTarget &targetInfo, - const OwningRewritePatternList &patterns); + const FrozenRewritePatternList &patterns); /// Returns true if the given operation is known to be illegal on the target. 
bool isIllegal(Operation *op) const; @@ -1555,7 +1555,7 @@ } // namespace OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo, - const OwningRewritePatternList &patterns) + const FrozenRewritePatternList &patterns) : target(targetInfo), applicator(patterns) { // The set of patterns that can be applied to illegal operations to transform // them into legal ones. @@ -2078,7 +2078,7 @@ // conversion mode. struct OperationConverter { explicit OperationConverter(ConversionTarget &target, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, OpConversionMode mode, DenseSet *trackedOps = nullptr) : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {} @@ -2672,7 +2672,7 @@ LogicalResult mlir::applyPartialConversion(ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, DenseSet *unconvertedOps) { OperationConverter opConverter(target, patterns, OpConversionMode::Partial, unconvertedOps); @@ -2680,7 +2680,7 @@ } LogicalResult mlir::applyPartialConversion(Operation *op, ConversionTarget &target, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, DenseSet *unconvertedOps) { return applyPartialConversion(llvm::makeArrayRef(op), target, patterns, unconvertedOps); @@ -2691,13 +2691,13 @@ /// operation fails. 
LogicalResult mlir::applyFullConversion(ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns) { + const FrozenRewritePatternList &patterns) { OperationConverter opConverter(target, patterns, OpConversionMode::Full); return opConverter.convertOperations(ops); } LogicalResult mlir::applyFullConversion(Operation *op, ConversionTarget &target, - const OwningRewritePatternList &patterns) { + const FrozenRewritePatternList &patterns) { return applyFullConversion(llvm::makeArrayRef(op), target, patterns); } @@ -2710,7 +2710,7 @@ LogicalResult mlir::applyAnalysisConversion(ArrayRef ops, ConversionTarget &target, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, DenseSet &convertedOps) { OperationConverter opConverter(target, patterns, OpConversionMode::Analysis, &convertedOps); @@ -2718,7 +2718,7 @@ } LogicalResult mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target, - const OwningRewritePatternList &patterns, + const FrozenRewritePatternList &patterns, DenseSet &convertedOps) { return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns, convertedOps); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -37,7 +37,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter { public: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, - const OwningRewritePatternList &patterns) + const FrozenRewritePatternList &patterns) : PatternRewriter(ctx), matcher(patterns), folder(ctx) { worklist.reserve(64); @@ -219,13 +219,13 @@ /// LogicalResult mlir::applyPatternsAndFoldGreedily(Operation *op, - const OwningRewritePatternList &patterns) { + const FrozenRewritePatternList &patterns) { return applyPatternsAndFoldGreedily(op->getRegions(), patterns); } /// Rewrite the given 
regions, which must be isolated from above. LogicalResult mlir::applyPatternsAndFoldGreedily(MutableArrayRef regions, - const OwningRewritePatternList &patterns) { + const FrozenRewritePatternList &patterns) { if (regions.empty()) return success(); @@ -259,7 +259,7 @@ class OpPatternRewriteDriver : public PatternRewriter { public: explicit OpPatternRewriteDriver(MLIRContext *ctx, - const OwningRewritePatternList &patterns) + const FrozenRewritePatternList &patterns) : PatternRewriter(ctx), matcher(patterns), folder(ctx) { // Apply a simple cost model based solely on pattern benefit. matcher.applyDefaultCostModel(); @@ -343,7 +343,7 @@ /// folding. `erased` is set to true if the op is erased as a result of being /// folded, replaced, or dead. LogicalResult mlir::applyOpPatternsAndFold( - Operation *op, const OwningRewritePatternList &patterns, bool *erased) { + Operation *op, const FrozenRewritePatternList &patterns, bool *erased) { // Start the pattern driver. OpPatternRewriteDriver driver(op->getContext(), patterns); bool opErased; diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp --- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp @@ -144,7 +144,7 @@ ConvertToGroupNonUniformBallot, ConvertToModule, ConvertToSubgroupBallot>(context); - if (failed(applyPartialConversion(fn, *target, patterns))) + if (failed(applyPartialConversion(fn, *target, std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -83,7 +83,7 @@ // Verify named pattern is generated with expected name. 
patterns.insert(&getContext()); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } }; } // end anonymous namespace @@ -601,7 +601,7 @@ // Handle a partial conversion. if (mode == ConversionMode::Partial) { DenseSet unlegalizedOps; - (void)applyPartialConversion(getOperation(), target, patterns, + (void)applyPartialConversion(getOperation(), target, std::move(patterns), &unlegalizedOps); // Emit remarks for each legalizable operation. for (auto *op : unlegalizedOps) @@ -616,7 +616,7 @@ return (bool)op->getAttrOfType("test.dynamically_legal"); }); - (void)applyFullConversion(getOperation(), target, patterns); + (void)applyFullConversion(getOperation(), target, std::move(patterns)); return; } @@ -625,8 +625,8 @@ // Analyze the convertible operations. DenseSet legalizedOps; - if (failed(applyAnalysisConversion(getOperation(), target, patterns, - legalizedOps))) + if (failed(applyAnalysisConversion(getOperation(), target, + std::move(patterns), legalizedOps))) return signalPassFailure(); // Emit remarks for each legalizable operation. 
@@ -704,7 +704,8 @@ return std::distance(op->operand_begin(), op->operand_end()) > 1; }); - if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) { + if (failed(mlir::applyFullConversion(getFunction(), target, + std::move(patterns)))) { signalPassFailure(); } } @@ -737,7 +738,8 @@ mlir::ConversionTarget target(getContext()); target.addIllegalDialect(); - if (failed(applyPartialConversion(getFunction(), target, patterns))) + if (failed( + applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); } }; @@ -833,7 +835,8 @@ mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), converter); - if (failed(applyPartialConversion(getOperation(), target, patterns))) + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) signalPassFailure(); } }; @@ -939,7 +942,7 @@ }); DenseSet unlegalizedOps; - (void)applyPartialConversion(getOperation(), target, patterns, + (void)applyPartialConversion(getOperation(), target, std::move(patterns), &unlegalizedOps); for (auto *op : unlegalizedOps) op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; diff --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp --- a/mlir/test/lib/Dialect/Test/TestTraits.cpp +++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp @@ -32,7 +32,7 @@ namespace { struct TestTraitFolder : public PassWrapper { void runOnFunction() override { - applyPatternsAndFoldGreedily(getFunction(), {}); + applyPatternsAndFoldGreedily(getFunction(), OwningRewritePatternList()); } }; } // end anonymous namespace diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -232,7 +232,8 @@ OwningRewritePatternList patterns; populateTensorLinalgToBufferLinalgConversionPattern(&context, converter, patterns); - if 
(failed(applyFullConversion(this->getOperation(), target, patterns))) + if (failed(applyFullConversion(this->getOperation(), target, + std::move(patterns)))) this->signalPassFailure(); }; }; diff --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp --- a/mlir/test/lib/Transforms/TestConvVectorization.cpp +++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp @@ -60,6 +60,8 @@ SmallVector stage1Patterns; linalg::populateConvVectorizationPatterns(context, stage1Patterns, tileSizes); + SmallVector frozenStage1Patterns; + llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); OwningRewritePatternList stage2Patterns = linalg::getLinalgTilingCanonicalizationPatterns(context); @@ -78,8 +80,8 @@ return success(); }; - linalg::applyStagedPatterns(module, stage1Patterns, stage2Patterns, - stage3Transforms); + linalg::applyStagedPatterns(module, frozenStage1Patterns, + std::move(stage2Patterns), stage3Transforms); //===--------------------------------------------------------------------===// // Post staged patterns transforms diff --git a/mlir/test/lib/Transforms/TestConvertCallOp.cpp b/mlir/test/lib/Transforms/TestConvertCallOp.cpp --- a/mlir/test/lib/Transforms/TestConvertCallOp.cpp +++ b/mlir/test/lib/Transforms/TestConvertCallOp.cpp @@ -58,9 +58,8 @@ target.addIllegalDialect(); target.addIllegalDialect(); - if (failed(applyPartialConversion(m, target, patterns))) { + if (failed(applyPartialConversion(m, target, std::move(patterns)))) signalPassFailure(); - } } }; diff --git a/mlir/test/lib/Transforms/TestExpandTanh.cpp b/mlir/test/lib/Transforms/TestExpandTanh.cpp --- a/mlir/test/lib/Transforms/TestExpandTanh.cpp +++ b/mlir/test/lib/Transforms/TestExpandTanh.cpp @@ -26,7 +26,7 @@ void TestExpandTanhPass::runOnFunction() { OwningRewritePatternList patterns; populateExpandTanhPattern(patterns, &getContext()); - applyPatternsAndFoldGreedily(getOperation(), patterns); + 
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } namespace mlir { diff --git a/mlir/test/lib/Transforms/TestGpuRewrite.cpp b/mlir/test/lib/Transforms/TestGpuRewrite.cpp --- a/mlir/test/lib/Transforms/TestGpuRewrite.cpp +++ b/mlir/test/lib/Transforms/TestGpuRewrite.cpp @@ -26,7 +26,7 @@ void runOnOperation() override { OwningRewritePatternList patterns; populateGpuRewritePatterns(&getContext(), patterns); - applyPatternsAndFoldGreedily(getOperation(), patterns); + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; } // namespace diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp @@ -97,7 +97,7 @@ LinalgDependenceGraph dependenceGraph = LinalgDependenceGraph::buildDependenceGraph(alias, funcOp); fillFusionPatterns(context, dependenceGraph, fusionPatterns); - applyPatternsAndFoldGreedily(funcOp, fusionPatterns); + applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)); } void TestLinalgFusionTransforms::runOnFunction() { diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -208,7 +208,7 @@ LinalgMarker(Identifier::get("_promote_views_aligned_", ctx), Identifier::get("_views_aligned_promoted_", ctx))); - applyPatternsAndFoldGreedily(funcOp, patterns); + applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); // Drop the marker. 
funcOp.walk([](LinalgOp op) { @@ -431,16 +431,18 @@ fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx), stage1Patterns); } - OwningRewritePatternList stage2Patterns = + SmallVector frozenStage1Patterns; + llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); + FrozenRewritePatternList stage2Patterns = getLinalgTilingCanonicalizationPatterns(ctx); - applyStagedPatterns(funcOp, stage1Patterns, stage2Patterns); + applyStagedPatterns(funcOp, frozenStage1Patterns, std::move(stage2Patterns)); } static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { OwningRewritePatternList forwardPattern; forwardPattern.insert(funcOp.getContext()); forwardPattern.insert(funcOp.getContext()); - applyPatternsAndFoldGreedily(funcOp, forwardPattern); + applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); } static void applyContractionToVectorPatterns(FuncOp funcOp) { @@ -451,16 +453,18 @@ LinalgVectorizationPattern, LinalgVectorizationPattern, LinalgVectorizationPattern>(funcOp.getContext()); - applyPatternsAndFoldGreedily(funcOp, patterns); + applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) { OwningRewritePatternList foldPattern; foldPattern.insert(funcOp.getContext()); + FrozenRewritePatternList frozenPatterns(std::move(foldPattern)); + // Explicitly walk and apply the pattern locally to avoid more general folding // on the rest of the IR. - funcOp.walk([&foldPattern](AffineMinOp minOp) { - applyOpPatternsAndFold(minOp, foldPattern); + funcOp.walk([&frozenPatterns](AffineMinOp minOp) { + applyOpPatternsAndFold(minOp, frozenPatterns); }); } /// Apply transformations specified as patterns. 
@@ -475,13 +479,13 @@ if (testPromotionOptions) { OwningRewritePatternList patterns; fillPromotionCallBackPatterns(&getContext(), patterns); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); return; } if (testTileAndDistributionOptions) { OwningRewritePatternList patterns; fillTileAndDistributePatterns(&getContext(), patterns); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); return; } if (testPatterns) diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -32,7 +32,7 @@ ctx, UnrollVectorOptions().setNativeShape(ArrayRef{2, 2, 2})); populateVectorToVectorCanonicalizationPatterns(patterns, ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } }; @@ -41,7 +41,7 @@ void runOnFunction() override { OwningRewritePatternList patterns; populateVectorSlicesLoweringPatterns(patterns, &getContext()); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } }; @@ -78,7 +78,7 @@ VectorTransformsOptions options{lowering}; patterns.insert(options, &getContext()); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); return; } @@ -94,7 +94,7 @@ return failure(); return success(); }); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); return; } @@ -108,7 +108,7 @@ transposeLowering = VectorTransposeLowering::Flat; VectorTransformsOptions options{contractLowering, transposeLowering}; 
populateVectorContractLoweringPatterns(patterns, &getContext(), options); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } }; @@ -145,7 +145,7 @@ } populateVectorToVectorCanonicalizationPatterns(patterns, ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } Option unrollBasedOnType{ @@ -181,7 +181,7 @@ }); patterns.insert(ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } }; @@ -199,7 +199,7 @@ ctx, UnrollVectorOptions().setNativeShape(ArrayRef{2, 2})); populateVectorToVectorCanonicalizationPatterns(patterns, ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } }; @@ -228,7 +228,7 @@ else options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); patterns.insert(ctx, options); - applyPatternsAndFoldGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } };