diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -416,12 +416,9 @@ void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp); }; -//===----------------------------------------------------------------------===// -// Pattern-driven rewriters -//===----------------------------------------------------------------------===// - //===----------------------------------------------------------------------===// // OwningRewritePatternList +//===----------------------------------------------------------------------===// class OwningRewritePatternList { using PatternListT = std::vector>; @@ -481,98 +478,6 @@ PatternListT patterns; }; -//===----------------------------------------------------------------------===// -// PatternApplicator - -/// This class manages the application of a group of rewrite patterns, with a -/// user-provided cost model. -class PatternApplicator { -public: - /// The cost model dynamically assigns a PatternBenefit to a particular - /// pattern. Users can query contained patterns and pass analysis results to - /// applyCostModel. Patterns to be discarded should have a benefit of - /// `impossibleToMatch`. - using CostModel = function_ref; - - explicit PatternApplicator(const OwningRewritePatternList &owningPatternList) - : owningPatternList(owningPatternList) {} - - /// Attempt to match and rewrite the given op with any pattern, allowing a - /// predicate to decide if a pattern can be applied or not, and hooks for if - /// the pattern match was a success or failure. - /// - /// canApply: called before each match and rewrite attempt; return false to - /// skip pattern. - /// onFailure: called when a pattern fails to match to perform cleanup. - /// onSuccess: called when a pattern match succeeds; return failure() to - /// invalidate the match and try another pattern. 
- LogicalResult - matchAndRewrite(Operation *op, PatternRewriter &rewriter, - function_ref canApply = {}, - function_ref onFailure = {}, - function_ref onSuccess = {}); - - /// Apply a cost model to the patterns within this applicator. - void applyCostModel(CostModel model); - - /// Apply the default cost model that solely uses the pattern's static - /// benefit. - void applyDefaultCostModel() { - applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); }); - } - - /// Walk all of the patterns within the applicator. - void walkAllPatterns(function_ref walk); - -private: - /// Attempt to match and rewrite the given op with the given pattern, allowing - /// a predicate to decide if a pattern can be applied or not, and hooks for if - /// the pattern match was a success or failure. - LogicalResult - matchAndRewrite(Operation *op, const RewritePattern &pattern, - PatternRewriter &rewriter, - function_ref canApply, - function_ref onFailure, - function_ref onSuccess); - - /// The list that owns the patterns used within this applicator. - const OwningRewritePatternList &owningPatternList; - - /// The set of patterns to match for each operation, stable sorted by benefit. - DenseMap> patterns; - /// The set of patterns that may match against any operation type, stable - /// sorted by benefit. - SmallVector anyOpPatterns; -}; - -//===----------------------------------------------------------------------===// -// applyPatternsGreedily -//===----------------------------------------------------------------------===// - -/// Rewrite the regions of the specified operation, which must be isolated from -/// above, by repeatedly applying the highest benefit patterns in a greedy -/// work-list driven manner. Return success if no more patterns can be matched -/// in the result operation regions. -/// Note: This does not apply patterns to the top-level operation itself. 
Note: -/// These methods also perform folding and simple dead-code elimination -/// before attempting to match any of the provided patterns. -/// -LogicalResult -applyPatternsAndFoldGreedily(Operation *op, - const OwningRewritePatternList &patterns); -/// Rewrite the given regions, which must be isolated from above. -LogicalResult -applyPatternsAndFoldGreedily(MutableArrayRef regions, - const OwningRewritePatternList &patterns); - -/// Applies the specified patterns on `op` alone while also trying to fold it, -/// by selecting the highest benefits patterns in a greedy manner. Returns -/// success if no more patterns can be matched. `erased` is set to true if `op` -/// was folded away or erased as a result of becoming dead. Note: This does not -/// apply any patterns recursively to the regions of `op`. -LogicalResult applyOpPatternsAndFold(Operation *op, - const OwningRewritePatternList &patterns, - bool *erased = nullptr); } // end namespace mlir #endif // MLIR_PATTERN_MATCH_H diff --git a/mlir/include/mlir/Rewrite/PatternApplicator.h b/mlir/include/mlir/Rewrite/PatternApplicator.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Rewrite/PatternApplicator.h @@ -0,0 +1,85 @@ +//===- PatternApplicator.h - PatternApplicator ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares an applicator that applies pattern rewrites based upon a +// user-defined cost model.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H +#define MLIR_REWRITE_PATTERNAPPLICATOR_H + +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +class PatternRewriter; + +/// This class manages the application of a group of rewrite patterns, with a +/// user-provided cost model. +class PatternApplicator { +public: + /// The cost model dynamically assigns a PatternBenefit to a particular + /// pattern. Users can query contained patterns and pass analysis results to + /// applyCostModel. Patterns to be discarded should have a benefit of + /// `impossibleToMatch`. + using CostModel = function_ref; + + explicit PatternApplicator(const OwningRewritePatternList &owningPatternList) + : owningPatternList(owningPatternList) {} + + /// Attempt to match and rewrite the given op with any pattern, allowing a + /// predicate to decide if a pattern can be applied or not, and hooks for if + /// the pattern match was a success or failure. + /// + /// canApply: called before each match and rewrite attempt; return false to + /// skip pattern. + /// onFailure: called when a pattern fails to match to perform cleanup. + /// onSuccess: called when a pattern match succeeds; return failure() to + /// invalidate the match and try another pattern. + LogicalResult + matchAndRewrite(Operation *op, PatternRewriter &rewriter, + function_ref canApply = {}, + function_ref onFailure = {}, + function_ref onSuccess = {}); + + /// Apply a cost model to the patterns within this applicator. + void applyCostModel(CostModel model); + + /// Apply the default cost model that solely uses the pattern's static + /// benefit. + void applyDefaultCostModel() { + applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); }); + } + + /// Walk all of the patterns within the applicator. 
+ void walkAllPatterns(function_ref walk); + +private: + /// Attempt to match and rewrite the given op with the given pattern, allowing + /// a predicate to decide if a pattern can be applied or not, and hooks for if + /// the pattern match was a success or failure. + LogicalResult + matchAndRewrite(Operation *op, const RewritePattern &pattern, + PatternRewriter &rewriter, + function_ref canApply, + function_ref onFailure, + function_ref onSuccess); + + /// The list that owns the patterns used within this applicator. + const OwningRewritePatternList &owningPatternList; + + /// The set of patterns to match for each operation, stable sorted by benefit. + DenseMap> patterns; + /// The set of patterns that may match against any operation type, stable + /// sorted by benefit. + SmallVector anyOpPatterns; +}; + +} // end namespace mlir + +#endif // MLIR_REWRITE_PATTERNAPPLICATOR_H diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -0,0 +1,52 @@ +//===- GreedyPatternRewriteDriver.h - Greedy Pattern Driver -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares methods for applying a set of patterns greedily, choosing +// the patterns with the highest local benefit, until a fixed point is reached.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_ +#define MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_ + +#include "mlir/IR/PatternMatch.h" + +namespace mlir { + +//===----------------------------------------------------------------------===// +// applyPatternsAndFoldGreedily +//===----------------------------------------------------------------------===// + +/// Rewrite the regions of the specified operation, which must be isolated from +/// above, by repeatedly applying the highest benefit patterns in a greedy +/// work-list driven manner. Return success if no more patterns can be matched +/// in the result operation regions. +/// Note: This does not apply patterns to the top-level operation itself. +/// These methods also perform folding and simple dead-code elimination +/// before attempting to match any of the provided patterns. +/// +LogicalResult +applyPatternsAndFoldGreedily(Operation *op, + const OwningRewritePatternList &patterns); +/// Rewrite the given regions, which must be isolated from above. +LogicalResult +applyPatternsAndFoldGreedily(MutableArrayRef regions, + const OwningRewritePatternList &patterns); + +/// Applies the specified patterns on `op` alone while also trying to fold it, +/// by selecting the highest benefit patterns in a greedy manner. Returns +/// success if no more patterns can be matched. `erased` is set to true if `op` +/// was folded away or erased as a result of becoming dead. Note: This does not +/// apply any patterns recursively to the regions of `op`.
+LogicalResult applyOpPatternsAndFold(Operation *op, + const OwningRewritePatternList &patterns, + bool *erased = nullptr); + +} // end namespace mlir + +#endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_ diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt --- a/mlir/lib/CMakeLists.txt +++ b/mlir/lib/CMakeLists.txt @@ -12,6 +12,7 @@ add_subdirectory(Parser) add_subdirectory(Pass) add_subdirectory(Reducer) +add_subdirectory(Rewrite) add_subdirectory(Support) add_subdirectory(TableGen) add_subdirectory(Target) diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/FormatVariadic.h" #include "../GPUCommon/GPUOpsLowering.h" diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -22,6 +22,7 @@ #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/FormatVariadic.h" #include "../GPUCommon/GPUOpsLowering.h" diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" 
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; @@ -123,7 +124,7 @@ OwningRewritePatternList patterns; populateConvertShapeConstraintsConversionPatterns(patterns, context); - if (failed(applyPatternsAndFoldGreedily(func, patterns))) + if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -17,8 +17,8 @@ #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -15,16 +15,12 @@ #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/AffineMap.h" -#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" -#include "mlir/IR/Types.h" #include "mlir/Target/LLVMIR/TypeTranslation.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Module.h" diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- 
a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -24,14 +24,10 @@ #include "mlir/Dialect/Vector/VectorUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" -#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/Location.h" #include "mlir/IR/Matchers.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" using namespace mlir; diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -24,7 +24,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopUtils.h" #include "llvm/ADT/MapVector.h" #include "llvm/Support/CommandLine.h" diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -15,7 +15,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Passes.h" #include "mlir/IR/IntegerSet.h" -#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Utils.h" #define DEBUG_TYPE "simplify-affine-structure" diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ 
b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -16,7 +16,7 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Function.h" #include "mlir/IR/IntegerSet.h" -#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/Vector/VectorTransforms.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -20,9 +20,8 @@ #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Support/LLVM.h" #include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -23,9 +23,9 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Dominance.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SetVector.h" #include 
"llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace mlir::linalg; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -22,6 +22,7 @@ #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -22,8 +22,8 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" -#include "mlir/Support/LLVM.h" #include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/CommandLine.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -21,9 +21,9 @@ #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" 
#include "llvm/Support/raw_ostream.h" #include diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp --- a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp @@ -12,10 +12,9 @@ #include "mlir/Dialect/Quant/QuantizeUtils.h" #include "mlir/Dialect/Quant/UniformSupport.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/Attributes.h" #include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace mlir::quant; diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp --- a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp @@ -11,9 +11,8 @@ #include "mlir/Dialect/Quant/Passes.h" #include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/Quant/UniformSupport.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace mlir::quant; diff --git a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp --- a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/Transforms/Passes.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -8,14 +8,9 @@ #include "mlir/IR/PatternMatch.h" #include 
"mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Value.h" -#include "llvm/Support/Debug.h" using namespace mlir; -#define DEBUG_TYPE "pattern-match" - //===----------------------------------------------------------------------===// // PatternBenefit //===----------------------------------------------------------------------===// @@ -205,135 +200,3 @@ cloneRegionBefore(region, *before->getParent(), before->getIterator()); } -//===----------------------------------------------------------------------===// -// PatternApplicator -//===----------------------------------------------------------------------===// - -void PatternApplicator::applyCostModel(CostModel model) { - // Separate patterns by root kind to simplify lookup later on. - patterns.clear(); - anyOpPatterns.clear(); - for (const auto &pat : owningPatternList) { - // If the pattern is always impossible to match, just ignore it. - if (pat->getBenefit().isImpossibleToMatch()) { - LLVM_DEBUG({ - llvm::dbgs() - << "Ignoring pattern '" << pat->getRootKind() - << "' because it is impossible to match (by pattern benefit)\n"; - }); - continue; - } - if (Optional opName = pat->getRootKind()) - patterns[*opName].push_back(pat.get()); - else - anyOpPatterns.push_back(pat.get()); - } - - // Sort the patterns using the provided cost model. - llvm::SmallDenseMap benefits; - auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) { - return benefits[lhs] > benefits[rhs]; - }; - auto processPatternList = [&](SmallVectorImpl &list) { - // Special case for one pattern in the list, which is the most common case. - if (list.size() == 1) { - if (model(*list.front()).isImpossibleToMatch()) { - LLVM_DEBUG({ - llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind() - << "' because it is impossible to match or cannot lead " - "to legal IR (by cost model)\n"; - }); - list.clear(); - } - return; - } - - // Collect the dynamic benefits for the current pattern list. 
- benefits.clear(); - for (RewritePattern *pat : list) - benefits.try_emplace(pat, model(*pat)); - - // Sort patterns with highest benefit first, and remove those that are - // impossible to match. - std::stable_sort(list.begin(), list.end(), cmp); - while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) { - LLVM_DEBUG({ - llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind() - << "' because it is impossible to match or cannot lead to " - "legal IR (by cost model)\n"; - }); - list.pop_back(); - } - }; - for (auto &it : patterns) - processPatternList(it.second); - processPatternList(anyOpPatterns); -} - -void PatternApplicator::walkAllPatterns( - function_ref walk) { - for (auto &it : owningPatternList) - walk(*it); -} - -LogicalResult PatternApplicator::matchAndRewrite( - Operation *op, PatternRewriter &rewriter, - function_ref canApply, - function_ref onFailure, - function_ref onSuccess) { - // Check to see if there are patterns matching this specific operation type. - MutableArrayRef opPatterns; - auto patternIt = patterns.find(op->getName()); - if (patternIt != patterns.end()) - opPatterns = patternIt->second; - - // Process the patterns for that match the specific operation type, and any - // operation type in an interleaved fashion. - // FIXME: It'd be nice to just write an llvm::make_merge_range utility - // and pass in a comparison function. That would make this code trivial. - auto opIt = opPatterns.begin(), opE = opPatterns.end(); - auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end(); - while (opIt != opE && anyIt != anyE) { - // Try to match the pattern providing the most benefit. - RewritePattern *pattern; - if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit()) - pattern = *(opIt++); - else - pattern = *(anyIt++); - - // Otherwise, try to match the generic pattern. 
- if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure, - onSuccess))) - return success(); - } - // If we break from the loop, then only one of the ranges can still have - // elements. Loop over both without checking given that we don't need to - // interleave anymore. - for (RewritePattern *pattern : llvm::concat( - llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) { - if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure, - onSuccess))) - return success(); - } - return failure(); -} - -LogicalResult PatternApplicator::matchAndRewrite( - Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter, - function_ref canApply, - function_ref onFailure, - function_ref onSuccess) { - // Check that the pattern can be applied. - if (canApply && !canApply(pattern)) - return failure(); - - // Try to match and rewrite this pattern. The patterns are sorted by - // benefit, so if we match we can immediately rewrite. - rewriter.setInsertionPoint(op); - if (succeeded(pattern.matchAndRewrite(op, rewriter))) - return success(!onSuccess || succeeded(onSuccess(pattern))); - - if (onFailure) - onFailure(pattern); - return failure(); -} diff --git a/mlir/lib/Rewrite/CMakeLists.txt b/mlir/lib/Rewrite/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Rewrite/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_library(MLIRRewrite + PatternApplicator.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Rewrite + + DEPENDS + mlir-generic-headers + + LINK_LIBS PUBLIC + MLIRIR + ) diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Rewrite/PatternApplicator.cpp @@ -0,0 +1,148 @@ +//===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements an applicator that applies pattern rewrites based upon a +// user defined cost model. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Rewrite/PatternApplicator.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; + +#define DEBUG_TYPE "pattern-match" + +void PatternApplicator::applyCostModel(CostModel model) { + // Separate patterns by root kind to simplify lookup later on. + patterns.clear(); + anyOpPatterns.clear(); + for (const auto &pat : owningPatternList) { + // If the pattern is always impossible to match, just ignore it. + if (pat->getBenefit().isImpossibleToMatch()) { + LLVM_DEBUG({ + llvm::dbgs() + << "Ignoring pattern '" << pat->getRootKind() + << "' because it is impossible to match (by pattern benefit)\n"; + }); + continue; + } + if (Optional opName = pat->getRootKind()) + patterns[*opName].push_back(pat.get()); + else + anyOpPatterns.push_back(pat.get()); + } + + // Sort the patterns using the provided cost model. + llvm::SmallDenseMap benefits; + auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) { + return benefits[lhs] > benefits[rhs]; + }; + auto processPatternList = [&](SmallVectorImpl &list) { + // Special case for one pattern in the list, which is the most common case. + if (list.size() == 1) { + if (model(*list.front()).isImpossibleToMatch()) { + LLVM_DEBUG({ + llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind() + << "' because it is impossible to match or cannot lead " + "to legal IR (by cost model)\n"; + }); + list.clear(); + } + return; + } + + // Collect the dynamic benefits for the current pattern list. 
+ benefits.clear(); + for (RewritePattern *pat : list) + benefits.try_emplace(pat, model(*pat)); + + // Sort patterns with highest benefit first, and remove those that are + // impossible to match. + std::stable_sort(list.begin(), list.end(), cmp); + while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) { + LLVM_DEBUG({ + llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind() + << "' because it is impossible to match or cannot lead to " + "legal IR (by cost model)\n"; + }); + list.pop_back(); + } + }; + for (auto &it : patterns) + processPatternList(it.second); + processPatternList(anyOpPatterns); +} + +void PatternApplicator::walkAllPatterns( + function_ref walk) { + for (auto &it : owningPatternList) + walk(*it); +} + +LogicalResult PatternApplicator::matchAndRewrite( + Operation *op, PatternRewriter &rewriter, + function_ref canApply, + function_ref onFailure, + function_ref onSuccess) { + // Check to see if there are patterns matching this specific operation type. + MutableArrayRef opPatterns; + auto patternIt = patterns.find(op->getName()); + if (patternIt != patterns.end()) + opPatterns = patternIt->second; + + // Process the patterns that match the specific operation type, and any + // operation type in an interleaved fashion. + // FIXME: It'd be nice to just write an llvm::make_merge_range utility + // and pass in a comparison function. That would make this code trivial. + auto opIt = opPatterns.begin(), opE = opPatterns.end(); + auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end(); + while (opIt != opE && anyIt != anyE) { + // Select the pattern providing the most benefit. + RewritePattern *pattern; + if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit()) + pattern = *(opIt++); + else + pattern = *(anyIt++); + + // Try to match and rewrite the selected pattern.
+ if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure, + onSuccess))) + return success(); + } + // If we break from the loop, then only one of the ranges can still have + // elements. Loop over both without checking given that we don't need to + // interleave anymore. + for (RewritePattern *pattern : llvm::concat( + llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) { + if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure, + onSuccess))) + return success(); + } + return failure(); +} + +LogicalResult PatternApplicator::matchAndRewrite( + Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter, + function_ref canApply, + function_ref onFailure, + function_ref onSuccess) { + // Check that the pattern can be applied. + if (canApply && !canApply(pattern)) + return failure(); + + // Try to match and rewrite this pattern. The patterns are sorted by + // benefit, so if we match we can immediately rewrite. + rewriter.setInsertionPoint(op); + if (succeeded(pattern.matchAndRewrite(op, rewriter))) + return success(!onSuccess || succeeded(onSuccess(pattern))); + + if (onFailure) + onFailure(pattern); + return failure(); +} diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -7,7 +7,6 @@ Canonicalizer.cpp CopyRemoval.cpp CSE.cpp - DialectConversion.cpp Inliner.cpp LocationSnapshot.cpp LoopCoalescing.cpp diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -12,8 +12,8 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" using namespace mlir; diff --git 
a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -17,6 +17,7 @@ #include "mlir/Analysis/CallGraph.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/InliningUtils.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/SCCIterator.h" diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt --- a/mlir/lib/Transforms/Utils/CMakeLists.txt +++ b/mlir/lib/Transforms/Utils/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(MLIRTransformUtils + DialectConversion.cpp FoldUtils.cpp GreedyPatternRewriteDriver.cpp InliningUtils.cpp @@ -19,5 +20,6 @@ MLIRLoopAnalysis MLIRSCF MLIRPass + MLIRRewrite MLIRStandard ) diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp rename from mlir/lib/Transforms/DialectConversion.cpp rename to mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" +#include "mlir/Rewrite/PatternApplicator.h" #include "mlir/Transforms/Utils.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" @@ -74,8 +75,7 @@ /// A utility function to log a successful result for the given reason. template -static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, - Args &&... args) { +static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { LLVM_DEBUG({ os.unindent(); os.startLine() << "} -> SUCCESS"; @@ -88,8 +88,7 @@ /// A utility function to log a failure result for the given reason. template -static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, - Args &&... 
args) { +static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { LLVM_DEBUG({ os.unindent(); os.startLine() << "} -> FAILURE : " @@ -2033,21 +2032,21 @@ return minDepth; // Sort the patterns by those likely to be the most beneficial. - llvm::array_pod_sort( - patternsByDepth.begin(), patternsByDepth.end(), - [](const std::pair *lhs, - const std::pair *rhs) { - // First sort by the smaller pattern legalization depth. - if (lhs->second != rhs->second) - return llvm::array_pod_sort_comparator(&lhs->second, - &rhs->second); - - // Then sort by the larger pattern benefit. - auto lhsBenefit = lhs->first->getBenefit(); - auto rhsBenefit = rhs->first->getBenefit(); - return llvm::array_pod_sort_comparator(&rhsBenefit, - &lhsBenefit); - }); + llvm::array_pod_sort(patternsByDepth.begin(), patternsByDepth.end(), + [](const std::pair *lhs, + const std::pair *rhs) { + // First sort by the smaller pattern legalization + // depth. + if (lhs->second != rhs->second) + return llvm::array_pod_sort_comparator( + &lhs->second, &rhs->second); + + // Then sort by the larger pattern benefit. + auto lhsBenefit = lhs->first->getBenefit(); + auto rhsBenefit = rhs->first->getBenefit(); + return llvm::array_pod_sort_comparator( + &rhsBenefit, &lhsBenefit); + }); // Update the legalization pattern to use the new sorted list. 
patterns.clear(); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -10,8 +10,9 @@ // //===----------------------------------------------------------------------===// -#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Rewrite/PatternApplicator.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/DenseMap.h" diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -23,8 +23,8 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Function.h" #include "mlir/IR/IntegerSet.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Support/MathExtras.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/RegionUtils.h" #include "mlir/Transforms/Utils.h" #include "llvm/ADT/DenseMap.h" diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp --- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp +++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp @@ -13,8 +13,8 @@ #include "mlir/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Passes.h" diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -10,10 +10,10 @@ #include 
"mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" #include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; diff --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp --- a/mlir/test/lib/Dialect/Test/TestTraits.cpp +++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp @@ -7,9 +7,8 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; @@ -25,9 +24,9 @@ OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold( ArrayRef operands) { - auto argument_op = getOperand(); + auto argumentOp = getOperand(); // The success case should cause the trait fold to be supressed. - return argument_op.getDefiningOp() ? argument_op : OpFoldResult{}; + return argumentOp.getDefiningOp() ? 
argumentOp : OpFoldResult{}; } namespace { diff --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp --- a/mlir/test/lib/Transforms/TestConvVectorization.cpp +++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp @@ -14,6 +14,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Passes.h" @@ -93,7 +94,7 @@ // VectorTransforms.cpp vectorTransferPatterns.insert( context, vectorTransformsOptions); - applyPatternsAndFoldGreedily(module, vectorTransferPatterns); + applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns)); // Programmatic controlled lowering of linalg.copy and linalg.fill. PassManager pm(context); @@ -105,13 +106,14 @@ OwningRewritePatternList vectorContractLoweringPatterns; populateVectorContractLoweringPatterns(vectorContractLoweringPatterns, context, vectorTransformsOptions); - applyPatternsAndFoldGreedily(module, vectorContractLoweringPatterns); + applyPatternsAndFoldGreedily(module, + std::move(vectorContractLoweringPatterns)); // Programmatic controlled lowering of vector.transfer only. OwningRewritePatternList vectorToLoopsPatterns; populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context, VectorTransferToSCFOptions()); - applyPatternsAndFoldGreedily(module, vectorToLoopsPatterns); + applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns)); // Ensure we drop the marker in the end. 
module.walk([](linalg::LinalgOp op) { diff --git a/mlir/test/lib/Transforms/TestExpandTanh.cpp b/mlir/test/lib/Transforms/TestExpandTanh.cpp --- a/mlir/test/lib/Transforms/TestExpandTanh.cpp +++ b/mlir/test/lib/Transforms/TestExpandTanh.cpp @@ -11,8 +11,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/StandardOps/Transforms/Passes.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; diff --git a/mlir/test/lib/Transforms/TestGpuRewrite.cpp b/mlir/test/lib/Transforms/TestGpuRewrite.cpp --- a/mlir/test/lib/Transforms/TestGpuRewrite.cpp +++ b/mlir/test/lib/Transforms/TestGpuRewrite.cpp @@ -12,8 +12,8 @@ #include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace mlir::linalg; diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -17,8 +17,8 @@ #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include 
"mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SetVector.h" diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -14,9 +14,8 @@ #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/Vector/VectorTransforms.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace mlir::vector;