diff --git a/mlir/include/mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h b/mlir/include/mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h --- a/mlir/include/mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h +++ b/mlir/include/mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h @@ -13,12 +13,14 @@ #ifndef MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H #define MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H -#include +#include "mlir/Support/LLVM.h" namespace mlir { class ModuleOp; +class Operation; template class OperationPass; +class PDLPatternConfigSet; #define GEN_PASS_DECL_CONVERTPDLTOPDLINTERP #include "mlir/Conversion/Passes.h.inc" @@ -26,6 +28,12 @@ /// Creates and returns a pass to convert PDL ops to PDL interpreter ops. std::unique_ptr> createPDLToPDLInterpPass(); +/// Creates and returns a pass to convert PDL ops to PDL interpreter ops. +/// `configMap` holds a map of the configurations for each pattern being +/// compiled. +std::unique_ptr> createPDLToPDLInterpPass( + DenseMap &configMap); + } // namespace mlir #endif // MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -600,10 +600,16 @@ class PatternRewriter : public RewriterBase { public: using RewriterBase::RewriterBase; + + /// A hook used to indicate if the pattern rewriter can recover from failure + /// during the rewrite stage of a pattern. For example, if the pattern + /// rewriter supports rollback, it may progress smoothly even if IR was + /// changed during the rewrite. 
+ virtual bool canRecoverFromRewriteFailure() const { return false; } }; //===----------------------------------------------------------------------===// -// PDLPatternModule +// PDL Patterns //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// @@ -796,6 +802,108 @@ SmallVector> allocatedValueRanges; }; +//===----------------------------------------------------------------------===// +// PDLPatternConfig + +/// An individual configuration for a pattern, which can be accessed by native +/// functions via the PDLPatternConfigSet. This allows for injecting additional +/// configuration into PDL patterns that is specific to certain compilation +/// flows. +class PDLPatternConfig { +public: + virtual ~PDLPatternConfig() = default; + + /// Hooks that are invoked at the beginning and end of a rewrite of a matched + /// pattern. These can be used to setup any specific state necessary for the + /// rewrite. + virtual void notifyRewriteBegin(PatternRewriter &rewriter) {} + virtual void notifyRewriteEnd(PatternRewriter &rewriter) {} + + /// Return the TypeID that represents this configuration. + TypeID getTypeID() const { return id; } + +protected: + PDLPatternConfig(TypeID id) : id(id) {} + +private: + TypeID id; +}; + +/// This class provides a base class for users implementing a type of pattern +/// configuration. +template +class PDLPatternConfigBase : public PDLPatternConfig { +public: + /// Support LLVM style casting. + static bool classof(const PDLPatternConfig *config) { + return config->getTypeID() == getConfigID(); + } + + /// Return the type id used for this configuration. + static TypeID getConfigID() { return TypeID::get(); } + +protected: + PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {} +}; + +/// This class contains a set of configurations for a specific pattern. 
+/// Configurations are uniqued by TypeID, meaning that only one configuration of +/// each type is allowed. +class PDLPatternConfigSet { +public: + PDLPatternConfigSet() = default; + + /// Construct a set with the given configurations. + template + PDLPatternConfigSet(ConfigsT &&...configs) { + (addConfig(std::forward(configs)), ...); + } + + /// Get the configuration defined by the given type. Asserts that the + /// configuration of the provided type exists. + template + const T &get() const { + const T *config = tryGet(); + assert(config && "configuration not found"); + return *config; + } + + /// Get the configuration defined by the given type, returns nullptr if the + /// configuration does not exist. + template + const T *tryGet() const { + for (const auto &configIt : configs) + if (const T *config = dyn_cast(configIt.get())) + return config; + return nullptr; + } + + /// Notify the configurations within this set at the beginning or end of a + /// rewrite of a matched pattern. + void notifyRewriteBegin(PatternRewriter &rewriter) { + for (const auto &config : configs) + config->notifyRewriteBegin(rewriter); + } + void notifyRewriteEnd(PatternRewriter &rewriter) { + for (const auto &config : configs) + config->notifyRewriteEnd(rewriter); + } + +protected: + /// Add a configuration to the set. + template + void addConfig(T &&config) { + assert(!tryGet>() && "configuration already exists"); + configs.emplace_back( + std::make_unique>(std::forward(config))); + } + + /// The set of configurations for this pattern. This uses a vector instead of + /// a map with the expectation that the number of configurations per set is + /// small (<= 1). + SmallVector> configs; +}; + //===----------------------------------------------------------------------===// // PDLPatternModule @@ -807,9 +915,11 @@ /// A native PDL rewrite function. This function performs a rewrite on the /// given set of values. 
Any results from this rewrite that should be passed /// back to PDL should be added to the provided result list. This method is only -/// invoked when the corresponding match was successful. -using PDLRewriteFunction = - std::function)>; +/// invoked when the corresponding match was successful. Returns failure if an +/// invariant of the rewrite was broken (certain rewriters may recover from +/// partial pattern application). +using PDLRewriteFunction = std::function)>; namespace detail { namespace pdl_function_builder { @@ -1034,6 +1144,13 @@ results.push_back(types); } }; +template +struct ProcessPDLValue> { + static void processAsResult(PatternRewriter &, PDLResultList &results, + SmallVector values) { + results.push_back(TypeRange(values)); + } +}; //===----------------------------------------------------------------------===// // Value @@ -1061,6 +1178,13 @@ results.push_back(values); } }; +template +struct ProcessPDLValue> { + static void processAsResult(PatternRewriter &, PDLResultList &results, + SmallVector values) { + results.push_back(ValueRange(values)); + } +}; //===----------------------------------------------------------------------===// // PDL Function Builder: Argument Handling @@ -1111,28 +1235,49 @@ /// Store a single result within the result list. template -static void processResults(PatternRewriter &rewriter, PDLResultList &results, - T &&value) { +static LogicalResult processResults(PatternRewriter &rewriter, + PDLResultList &results, T &&value) { ProcessPDLValue::processAsResult(rewriter, results, std::forward(value)); + return success(); } /// Store a std::pair<> as individual results within the result list. 
template -static void processResults(PatternRewriter &rewriter, PDLResultList &results, - std::pair &&pair) { - processResults(rewriter, results, std::move(pair.first)); - processResults(rewriter, results, std::move(pair.second)); +static LogicalResult processResults(PatternRewriter &rewriter, + PDLResultList &results, + std::pair &&pair) { + if (failed(processResults(rewriter, results, std::move(pair.first))) || + failed(processResults(rewriter, results, std::move(pair.second)))) + return failure(); + return success(); } /// Store a std::tuple<> as individual results within the result list. template -static void processResults(PatternRewriter &rewriter, PDLResultList &results, - std::tuple &&tuple) { +static LogicalResult processResults(PatternRewriter &rewriter, + PDLResultList &results, + std::tuple &&tuple) { auto applyFn = [&](auto &&...args) { - (processResults(rewriter, results, std::move(args)), ...); + return (succeeded(processResults(rewriter, results, std::move(args))) && + ...); }; - std::apply(applyFn, std::move(tuple)); + return success(std::apply(applyFn, std::move(tuple))); +} + +/// Handle LogicalResult propagation. +inline LogicalResult processResults(PatternRewriter &rewriter, + PDLResultList &results, + LogicalResult &&result) { + return result; +} +template +static LogicalResult processResults(PatternRewriter &rewriter, + PDLResultList &results, + FailureOr &&result) { + if (failed(result)) + return failure(); + return processResults(rewriter, results, std::move(*result)); } //===----------------------------------------------------------------------===// @@ -1192,23 +1337,26 @@ /// This overload handles the case of no return values. 
template > -std::enable_if_t::value> +std::enable_if_t::value, + LogicalResult> processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, PDLResultList &, ArrayRef values, std::index_sequence) { fn(rewriter, (ProcessPDLValue>::processAsArg( values[I]))...); + return success(); } /// This overload handles the case of return values, which need to be packaged /// into the result list. template > -std::enable_if_t::value> +std::enable_if_t::value, + LogicalResult> processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, PDLResultList &results, ArrayRef values, std::index_sequence) { - processResults( + return processResults( rewriter, results, fn(rewriter, (ProcessPDLValue>:: processAsArg(values[I]))...)); @@ -1240,14 +1388,17 @@ std::make_index_sequence::num_args - 1>(); assertArgs(rewriter, values, argIndices); - processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values, - argIndices); + return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values, + argIndices); }; } } // namespace pdl_function_builder } // namespace detail +//===----------------------------------------------------------------------===// +// PDLPatternModule + /// This class contains all of the necessary data for a set of PDL patterns, or /// pattern rewrites specified in the form of the PDL dialect. This PDL module /// contained by this pattern may contain any number of `pdl.pattern` @@ -1256,9 +1407,17 @@ public: PDLPatternModule() = default; - /// Construct a PDL pattern with the given module. - PDLPatternModule(OwningOpRef pdlModule) - : pdlModule(std::move(pdlModule)) {} + /// Construct a PDL pattern with the given module and configurations. 
+ PDLPatternModule(OwningOpRef module) + : pdlModule(std::move(module)) {} + template + PDLPatternModule(OwningOpRef module, ConfigsT &&...patternConfigs) + : PDLPatternModule(std::move(module)) { + auto configSet = std::make_unique( + std::forward(patternConfigs)...); + attachConfigToPatterns(*pdlModule, *configSet); + configs.emplace_back(std::move(configSet)); + } /// Merge the state in `other` into this pattern module. void mergeIn(PDLPatternModule &&other); @@ -1344,6 +1503,14 @@ return rewriteFunctions; } + /// Return the set of the registered pattern configs. + SmallVector> takeConfigs() { + return std::move(configs); + } + DenseMap takeConfigMap() { + return std::move(configMap); + } + /// Clear out the patterns and functions within this module. void clear() { pdlModule = nullptr; @@ -1352,9 +1519,17 @@ } private: + /// Attach the given pattern config set to the patterns defined within the + /// given module. + void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet); + /// The module containing the `pdl.pattern` operations. OwningOpRef pdlModule; + /// The set of configuration sets referenced by patterns within `pdlModule`. + SmallVector> configs; + DenseMap configMap; + /// The external functions referenced from within the PDL module. llvm::StringMap constraintFunctions; llvm::StringMap rewriteFunctions; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -574,6 +574,11 @@ // PatternRewriter Hooks //===--------------------------------------------------------------------===// + /// Indicate that the conversion rewriter can recover from rewrite failure. + /// Recovery is supported via rollback, allowing for continued processing of + /// patterns even if a failure is encountered during the rewrite step. 
+ bool canRecoverFromRewriteFailure() const override { return true; } + /// PatternRewriter hook for replacing the results of an operation when the /// given functor returns true. void replaceOpWithIf( @@ -891,6 +896,35 @@ MLIRContext &ctx; }; +//===----------------------------------------------------------------------===// +// PDL Configuration +//===----------------------------------------------------------------------===// + +/// A PDL configuration that is used to support dialect conversion +/// functionality. +class PDLConversionConfig final + : public PDLPatternConfigBase { +public: + PDLConversionConfig(TypeConverter *converter) : converter(converter) {} + ~PDLConversionConfig() final = default; + + /// Return the type converter used by this configuration, which may be nullptr + /// if no type conversions are expected. + TypeConverter *getTypeConverter() const { return converter; } + + /// Hooks that are invoked at the beginning and end of a rewrite of a matched + /// pattern. + void notifyRewriteBegin(PatternRewriter &rewriter) final; + void notifyRewriteEnd(PatternRewriter &rewriter) final; + +private: + /// An optional type converter to use for the pattern. + TypeConverter *converter; +}; + +/// Register the dialect conversion PDL functions with the given pattern set. +void registerConversionPDLFunctions(RewritePatternSet &patterns); + //===----------------------------------------------------------------------===// // Op Conversion Entry Points //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/DialectConversion.pdll b/mlir/include/mlir/Transforms/DialectConversion.pdll new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Transforms/DialectConversion.pdll @@ -0,0 +1,30 @@ +//===- DialectConversion.pdll - DialectConversion PDLL Support -*- PDLL -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines various utilities for interacting with dialect conversion +// within PDLL. +// +//===----------------------------------------------------------------------===// + +/// This rewrite returns the converted value of `value`, whose type is defined +/// by the type converter specified in the `PDLConversionConfig` of the current +/// pattern. +Rewrite convertValue(value: Value) -> Value; + +/// This rewrite returns the converted values of `values`, whose types are defined +/// by the type converter specified in the `PDLConversionConfig` of the current +/// pattern. +Rewrite convertValues(values: ValueRange) -> ValueRange; + +/// This rewrite returns the converted type of `type` as defined by the type +/// converter specified in the `PDLConversionConfig` of the current pattern. +Rewrite convertType(type: Type) -> Type; + +/// This rewrite returns the converted types of `types` as defined by the type +/// converter specified in the `PDLConversionConfig` of the current pattern. +Rewrite convertTypes(types: TypeRange) -> TypeRange; diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -37,7 +37,8 @@ /// given module containing PDL pattern operations. struct PatternLowering { public: - PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule); + PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule, + DenseMap *configMap); /// Generate code for matching and rewriting based on the pattern operations /// within the module.
@@ -140,13 +141,19 @@ /// The set of operation values whose whose location will be used for newly /// generated operations. SetVector locOps; + + /// A mapping between pattern operations and the corresponding configuration + /// set. + DenseMap *configMap; }; } // namespace -PatternLowering::PatternLowering(pdl_interp::FuncOp matcherFunc, - ModuleOp rewriterModule) +PatternLowering::PatternLowering( + pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule, + DenseMap *configMap) : builder(matcherFunc.getContext()), matcherFunc(matcherFunc), - rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {} + rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule), + configMap(configMap) {} void PatternLowering::lower(ModuleOp module) { PredicateUniquer predicateUniquer; @@ -589,10 +596,14 @@ rootKindAttr = builder.getStringAttr(*rootKind); builder.setInsertionPointToEnd(currentBlock); - builder.create( + auto matchOp = builder.create( pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(), failureBlockStack.back()); + + // Set the config of the lowered match to the parent pattern. + if (configMap) + configMap->try_emplace(matchOp, configMap->lookup(pattern)); } SymbolRefAttr PatternLowering::generateRewriter( @@ -922,7 +933,14 @@ namespace { struct PDLToPDLInterpPass : public impl::ConvertPDLToPDLInterpBase { + PDLToPDLInterpPass() = default; + PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default; + PDLToPDLInterpPass(DenseMap &configMap) + : configMap(&configMap) {} void runOnOperation() final; + + /// A map containing the configuration for each pattern. + DenseMap *configMap = nullptr; }; } // namespace @@ -946,15 +964,24 @@ module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); // Generate the code for the patterns within the module. 
- PatternLowering generator(matcherFunc, rewriterModule); + PatternLowering generator(matcherFunc, rewriterModule, configMap); generator.lower(module); // After generation, delete all of the pattern operations. for (pdl::PatternOp pattern : - llvm::make_early_inc_range(module.getOps())) + llvm::make_early_inc_range(module.getOps())) { + // Drop the now dead config mappings. + if (configMap) + configMap->erase(pattern); + pattern.erase(); + } } std::unique_ptr> mlir::createPDLToPDLInterpPass() { return std::make_unique(); } +std::unique_ptr> mlir::createPDLToPDLInterpPass( + DenseMap &configMap) { + return std::make_unique(configMap); +} diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -158,11 +158,15 @@ if (!other.pdlModule) return; - // Steal the functions of the other module. + // Steal the functions and config of the other module. for (auto &it : other.constraintFunctions) registerConstraintFunction(it.first(), std::move(it.second)); for (auto &it : other.rewriteFunctions) registerRewriteFunction(it.first(), std::move(it.second)); + for (auto &it : other.configs) + configs.emplace_back(std::move(it)); + for (auto &it : other.configMap) + configMap.insert(it); // Steal the other state if we have no patterns. if (!pdlModule) { @@ -176,6 +180,18 @@ other.pdlModule->getBody()->getOperations()); } +void PDLPatternModule::attachConfigToPatterns(ModuleOp module, + PDLPatternConfigSet &configSet) { + // Attach the configuration to the symbols within the module. We only add + // to symbols to avoid hardcoding any specific operation names here (given + // that we don't depend on any PDL dialect). We can't use + // cast here because patterns may be optional symbols. 
+ module->walk([&](Operation *op) { + if (op->hasTrait()) + configMap[op] = &configSet; + }); +} + //===----------------------------------------------------------------------===// // Function Registry diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h --- a/mlir/lib/Rewrite/ByteCode.h +++ b/mlir/lib/Rewrite/ByteCode.h @@ -38,19 +38,27 @@ class PDLByteCodePattern : public Pattern { public: static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp, + PDLPatternConfigSet *configSet, ByteCodeAddr rewriterAddr); /// Return the bytecode address of the rewriter for this pattern. ByteCodeAddr getRewriterAddr() const { return rewriterAddr; } + /// Return the configuration set for this pattern, or null if there is none. + PDLPatternConfigSet *getConfigSet() const { return configSet; } + private: template - PDLByteCodePattern(ByteCodeAddr rewriterAddr, Args &&...patternArgs) - : Pattern(std::forward(patternArgs)...), - rewriterAddr(rewriterAddr) {} + PDLByteCodePattern(ByteCodeAddr rewriterAddr, PDLPatternConfigSet *configSet, + Args &&...patternArgs) + : Pattern(std::forward(patternArgs)...), rewriterAddr(rewriterAddr), + configSet(configSet) {} /// The address of the rewriter for this pattern. ByteCodeAddr rewriterAddr; + + /// The optional config set for this pattern. + PDLPatternConfigSet *configSet; }; //===----------------------------------------------------------------------===// @@ -148,6 +156,8 @@ /// Create a ByteCode instance from the given module containing operations in /// the PDL interpreter dialect. PDLByteCode(ModuleOp module, + SmallVector> configs, + const DenseMap &configMap, llvm::StringMap constraintFns, llvm::StringMap rewriteFns); @@ -165,9 +175,9 @@ PDLByteCodeMutableState &state) const; /// Run the rewriter of the given pattern that was previously matched in - /// `match`. - void rewrite(PatternRewriter &rewriter, const MatchResult &match, - PDLByteCodeMutableState &state) const; + /// `match`. 
Returns if a failure was encountered during the rewrite. + LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, + PDLByteCodeMutableState &state) const; private: /// Execute the given byte code starting at the provided instruction `inst`. @@ -177,6 +187,9 @@ PDLByteCodeMutableState &state, SmallVectorImpl *matches) const; + /// The set of pattern configs referenced within the bytecode. + SmallVector> configs; + /// A vector containing pointers to uniqued data. The storage is intentionally /// opaque such that we can store a wide range of data types. The types of /// data stored here include: diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -34,21 +34,23 @@ //===----------------------------------------------------------------------===// PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, + PDLPatternConfigSet *configSet, ByteCodeAddr rewriterAddr) { + PatternBenefit benefit = matchOp.getBenefit(); + MLIRContext *ctx = matchOp.getContext(); + + // Collect the set of generated operations. SmallVector generatedOps; if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr()) generatedOps = llvm::to_vector<8>(generatedOpsAttr.getAsValueRange()); - PatternBenefit benefit = matchOp.getBenefit(); - MLIRContext *ctx = matchOp.getContext(); - // Check to see if this is pattern matches a specific operation type. 
if (Optional rootKind = matchOp.getRootKind()) - return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx, + return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx, generatedOps); - return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx, - generatedOps); + return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(), + benefit, ctx, generatedOps); } //===----------------------------------------------------------------------===// @@ -194,14 +196,15 @@ ByteCodeField &maxValueRangeMemoryIndex, ByteCodeField &maxLoopLevel, llvm::StringMap &constraintFns, - llvm::StringMap &rewriteFns) + llvm::StringMap &rewriteFns, + const DenseMap &configMap) : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), rewriterByteCode(rewriterByteCode), patterns(patterns), maxValueMemoryIndex(maxValueMemoryIndex), maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), - maxLoopLevel(maxLoopLevel) { + maxLoopLevel(maxLoopLevel), configMap(configMap) { for (const auto &it : llvm::enumerate(constraintFns)) constraintToMemIndex.try_emplace(it.value().first(), it.index()); for (const auto &it : llvm::enumerate(rewriteFns)) @@ -328,6 +331,9 @@ ByteCodeField &maxTypeRangeMemoryIndex; ByteCodeField &maxValueRangeMemoryIndex; ByteCodeField &maxLoopLevel; + + /// A map of pattern configurations. + const DenseMap &configMap; }; /// This class provides utilities for writing a bytecode stream. 
@@ -969,7 +975,8 @@ void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { ByteCodeField patternIndex = patterns.size(); patterns.emplace_back(PDLByteCodePattern::create( - op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()])); + op, configMap.lookup(op), + rewriterToAddr[op.getRewriter().getLeafReference().getValue()])); writer.append(OpCode::RecordMatch, patternIndex, SuccessorRange(op.getOperation()), op.getMatchedOps()); writer.appendPDLValueList(op.getInputs()); @@ -1014,13 +1021,16 @@ // PDLByteCode //===----------------------------------------------------------------------===// -PDLByteCode::PDLByteCode(ModuleOp module, - llvm::StringMap constraintFns, - llvm::StringMap rewriteFns) { +PDLByteCode::PDLByteCode( + ModuleOp module, SmallVector> configs, + const DenseMap &configMap, + llvm::StringMap constraintFns, + llvm::StringMap rewriteFns) + : configs(std::move(configs)) { Generator generator(module.getContext(), uniquedData, matcherByteCode, rewriterByteCode, patterns, maxValueMemoryIndex, maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount, - maxLoopLevel, constraintFns, rewriteFns); + maxLoopLevel, constraintFns, rewriteFns, configMap); generator.generate(module); // Initialize the external functions. @@ -1076,14 +1086,15 @@ /// Start executing the code at the current bytecode index. `matches` is an /// optional field provided when this function is executed in a matching /// context. - void execute(PatternRewriter &rewriter, - SmallVectorImpl *matches = nullptr, - Optional mainRewriteLoc = {}); + LogicalResult + execute(PatternRewriter &rewriter, + SmallVectorImpl *matches = nullptr, + Optional mainRewriteLoc = {}); private: /// Internal implementation of executing each of the bytecode commands. 
void executeApplyConstraint(PatternRewriter &rewriter); - void executeApplyRewrite(PatternRewriter &rewriter); + LogicalResult executeApplyRewrite(PatternRewriter &rewriter); void executeAreEqual(); void executeAreRangesEqual(); void executeBranch(); @@ -1345,7 +1356,7 @@ selectJump(succeeded(constraintFn(rewriter, args))); } -void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { +LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; SmallVector args; @@ -1359,7 +1370,7 @@ // Execute the rewrite function. ByteCodeField numResults = read(); ByteCodeRewriteResultList results(numResults); - rewriteFn(rewriter, results, args); + LogicalResult rewriteResult = rewriteFn(rewriter, results, args); assert(results.getResults().size() == numResults && "native PDL rewrite function returned unexpected number of results"); @@ -1395,6 +1406,13 @@ allocatedTypeRangeMemory.push_back(std::move(it)); for (auto &it : results.getAllocatedValueRanges()) allocatedValueRangeMemory.push_back(std::move(it)); + + // Process the result of the rewrite. + if (failed(rewriteResult)) { + LLVM_DEBUG(llvm::dbgs() << " - Failed"); + return failure(); + } + return success(); } void ByteCodeExecutor::executeAreEqual() { @@ -2017,10 +2035,10 @@ }); } -void ByteCodeExecutor::execute( - PatternRewriter &rewriter, - SmallVectorImpl *matches, - Optional mainRewriteLoc) { +LogicalResult +ByteCodeExecutor::execute(PatternRewriter &rewriter, + SmallVectorImpl *matches, + Optional mainRewriteLoc) { while (true) { // Print the location of the operation being executed. 
LLVM_DEBUG(llvm::dbgs() << readInline() << "\n"); @@ -2031,7 +2049,8 @@ executeApplyConstraint(rewriter); break; case ApplyRewrite: - executeApplyRewrite(rewriter); + if (failed(executeApplyRewrite(rewriter))) + return failure(); break; case AreEqual: executeAreEqual(); @@ -2078,7 +2097,7 @@ case Finalize: executeFinalize(); LLVM_DEBUG(llvm::dbgs() << "\n"); - return; + return success(); case ForEach: executeForEach(); break; @@ -2166,8 +2185,6 @@ } } -/// Run the pattern matcher on the given root operation, collecting the matched -/// patterns in `matches`. void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl &matches, PDLByteCodeMutableState &state) const { @@ -2181,7 +2198,8 @@ state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, constraintFunctions, rewriteFunctions); - executor.execute(rewriter, &matches); + LogicalResult executeResult = executor.execute(rewriter, &matches); + assert(succeeded(executeResult) && "unexpected matcher execution failure"); // Order the found matches by benefit. std::stable_sort(matches.begin(), matches.end(), @@ -2190,9 +2208,13 @@ }); } -/// Run the rewriter of the given pattern on the root operation `op`. -void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, - PDLByteCodeMutableState &state) const { +LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter, + const MatchResult &match, + PDLByteCodeMutableState &state) const { + auto *configSet = match.pattern->getConfigSet(); + if (configSet) + configSet->notifyRewriteBegin(rewriter); + // The arguments of the rewrite function are stored at the start of the // memory buffer. 
llvm::copy(match.values, state.memory.begin()); @@ -2204,5 +2226,24 @@ state.allocatedValueRangeMemory, state.loopIndex, uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns, constraintFunctions, rewriteFunctions); - executor.execute(rewriter, /*matches=*/nullptr, match.location); + LogicalResult result = + executor.execute(rewriter, /*matches=*/nullptr, match.location); + + if (configSet) + configSet->notifyRewriteEnd(rewriter); + + // If the rewrite failed, check if the pattern rewriter can recover. If it + // can, we can signal to the pattern applicator to keep trying patterns. If it + // doesn't, we need to bail. Bailing here should be fine, given that we have + // no means to propagate such a failure to the user, and it also indicates a + // bug in the user code (i.e. failable rewrites should not be used with + // pattern rewriters that don't support it). + if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) { + LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting"); + llvm::report_fatal_error( + "Native PDL Rewrite failed, but the pattern " + "rewriter doesn't support recovery. Failable pattern rewrites should " + "not be used with pattern rewriters that do not support them."); + } + return result; } diff --git a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp --- a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp +++ b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp @@ -16,7 +16,9 @@ using namespace mlir; -static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) { +static LogicalResult +convertPDLToPDLInterp(ModuleOp pdlModule, + DenseMap &configMap) { // Skip the conversion if the module doesn't contain pdl. if (pdlModule.getOps().empty()) return success(); @@ -37,7 +39,7 @@ // mode. 
pdlPipeline.enableVerifier(false); #endif - pdlPipeline.addPass(createPDLToPDLInterpPass()); + pdlPipeline.addPass(createPDLToPDLInterpPass(configMap)); if (failed(pdlPipeline.run(pdlModule))) return failure(); @@ -123,13 +125,16 @@ ModuleOp pdlModule = pdlPatterns.getModule(); if (!pdlModule) return; - if (failed(convertPDLToPDLInterp(pdlModule))) + DenseMap configMap = + pdlPatterns.takeConfigMap(); + if (failed(convertPDLToPDLInterp(pdlModule, configMap))) llvm::report_fatal_error( "failed to lower PDL pattern module to the PDL Interpreter"); // Generate the pdl bytecode. impl->pdlByteCode = std::make_unique( - pdlModule, pdlPatterns.takeConstraintFunctions(), + pdlModule, pdlPatterns.takeConfigs(), configMap, + pdlPatterns.takeConstraintFunctions(), pdlPatterns.takeRewriteFunctions()); } diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp --- a/mlir/lib/Rewrite/PatternApplicator.cpp +++ b/mlir/lib/Rewrite/PatternApplicator.cpp @@ -191,20 +191,21 @@ Operation *dumpRootOp = getDumpRootOp(op); #endif if (pdlMatch) { - bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState); - result = success(!onSuccess || succeeded(onSuccess(*bestPattern))); + result = bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState); } else { - const auto *pattern = static_cast(bestPattern); + LLVM_DEBUG(llvm::dbgs() << "Trying to match \"" + << bestPattern->getDebugName() << "\"\n"); - LLVM_DEBUG(llvm::dbgs() - << "Trying to match \"" << pattern->getDebugName() << "\"\n"); + const auto *pattern = static_cast(bestPattern); result = pattern->matchAndRewrite(op, rewriter); - LLVM_DEBUG(llvm::dbgs() << "\"" << pattern->getDebugName() << "\" result " - << succeeded(result) << "\n"); - if (succeeded(result) && onSuccess && failed(onSuccess(*pattern))) - result = failure(); + LLVM_DEBUG(llvm::dbgs() << "\"" << bestPattern->getDebugName() + << "\" result " << succeeded(result) << "\n"); } + + // Process the result of the pattern application. 
+ if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern))) + result = failure(); if (succeeded(result)) { LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp)); break; diff --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp --- a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp @@ -93,10 +93,12 @@ os << "} // end namespace\n\n"; // Emit function to add the generated matchers to the pattern list. - os << "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(" - "::mlir::RewritePatternSet &patterns) {\n"; + os << "template \n" + "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(" + "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n"; for (const auto &name : patternNames) - os << " patterns.add<" << name << ">(patterns.getContext());\n"; + os << " patterns.add<" << name + << ">(patterns.getContext(), configs...);\n"; os << "}\n"; } @@ -104,14 +106,15 @@ StringSet<> &nativeFunctions) { const char *patternClassStartStr = R"( struct {0} : ::mlir::PDLPatternModule {{ - {0}(::mlir::MLIRContext *context) + template + {0}(::mlir::MLIRContext *context, ConfigsT &&...configs) : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>( )"; os << llvm::formatv(patternClassStartStr, patternName); os << "R\"mlir("; pattern->print(os, OpPrintingFlags().enableDebugInfo()); - os << "\n )mlir\", context)) {\n"; + os << "\n )mlir\", context), std::forward(configs)...) {\n"; // Register any native functions used within the pattern. 
StringSet<> registeredNativeFunctions; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -3272,6 +3272,76 @@ return llvm::None; } +//===----------------------------------------------------------------------===// +// PDL Configuration +//===----------------------------------------------------------------------===// + +void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) { + auto &rewriterImpl = + static_cast(rewriter).getImpl(); + rewriterImpl.currentTypeConverter = getTypeConverter(); +} + +void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) { + auto &rewriterImpl = + static_cast(rewriter).getImpl(); + rewriterImpl.currentTypeConverter = nullptr; +} + +/// Remap the given value using the rewriter and the type converter in the +/// provided config. +static FailureOr> +pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) { + SmallVector mappedValues; + if (failed(rewriter.getRemappedValues(values, mappedValues))) + return failure(); + return std::move(mappedValues); +} + +void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) { + patterns.getPDLPatterns().registerRewriteFunction( + "convertValue", + [](PatternRewriter &rewriter, Value value) -> FailureOr { + auto results = pdllConvertValues( + static_cast(rewriter), value); + if (failed(results)) + return failure(); + return results->front(); + }); + patterns.getPDLPatterns().registerRewriteFunction( + "convertValues", [](PatternRewriter &rewriter, ValueRange values) { + return pdllConvertValues( + static_cast(rewriter), values); + }); + patterns.getPDLPatterns().registerRewriteFunction( + "convertType", + [](PatternRewriter &rewriter, Type type) -> FailureOr { + auto &rewriterImpl = + static_cast(rewriter).getImpl(); + if (TypeConverter *converter = 
rewriterImpl.currentTypeConverter) { + if (Type newType = converter->convertType(type)) + return newType; + return failure(); + } + return type; + }); + patterns.getPDLPatterns().registerRewriteFunction( + "convertTypes", + [](PatternRewriter &rewriter, + TypeRange types) -> FailureOr> { + auto &rewriterImpl = + static_cast(rewriter).getImpl(); + TypeConverter *converter = rewriterImpl.currentTypeConverter; + if (!converter) + return SmallVector(types); + + SmallVector remappedTypes; + if (failed(converter->convertTypes(types, remappedTypes))) + return failure(); + return std::move(remappedTypes); + }); +} + //===----------------------------------------------------------------------===// // Op Conversion Entry Points //===----------------------------------------------------------------------===// diff --git a/mlir/test/Transforms/test-dialect-conversion-pdll.mlir b/mlir/test/Transforms/test-dialect-conversion-pdll.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/test-dialect-conversion-pdll.mlir @@ -0,0 +1,18 @@ +// RUN: mlir-opt %s -test-dialect-conversion-pdll | FileCheck %s + +// CHECK-LABEL: @TestSingleConversion +func.func @TestSingleConversion() { + // CHECK: %[[CAST:.*]] = "test.cast"() : () -> f64 + // CHECK-NEXT: "test.return"(%[[CAST]]) : (f64) -> () + %result = "test.cast"() : () -> (i64) + "test.return"(%result) : (i64) -> () +} + +// CHECK-LABEL: @TestLingeringConversion +func.func @TestLingeringConversion() -> i64 { + // CHECK: %[[ORIG_CAST:.*]] = "test.cast"() : () -> f64 + // CHECK: %[[MATERIALIZE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ORIG_CAST]] : f64 to i64 + // CHECK-NEXT: return %[[MATERIALIZE_CAST]] : i64 + %result = "test.cast"() : () -> (i64) + return %result : i64 +} diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -1,8 +1,18 @@ 
+add_mlir_pdll_library(MLIRTestDialectConversionPDLLPatternsIncGen + TestDialectConversion.pdll + TestDialectConversionPDLLPatterns.h.inc + + EXTRA_INCLUDES + ${CMAKE_CURRENT_SOURCE_DIR}/../Dialect/Test + ${CMAKE_CURRENT_BINARY_DIR}/../Dialect/Test + ) + # Exclude tests from libMLIR.so add_mlir_library(MLIRTestTransforms TestCommutativityUtils.cpp TestConstantFold.cpp TestControlFlowSink.cpp + TestDialectConversion.cpp TestInlining.cpp TestIntRangeInference.cpp TestTopologicalSort.cpp @@ -12,8 +22,12 @@ ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms + DEPENDS + MLIRTestDialectConversionPDLLPatternsIncGen + LINK_LIBS PUBLIC MLIRAnalysis + MLIRFuncDialect MLIRInferIntRangeInterface MLIRTestDialect MLIRTransforms diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestDialectConversion.cpp @@ -0,0 +1,96 @@ +//===- TestDialectConversion.cpp - Test DialectConversion functionality ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace test; + +//===----------------------------------------------------------------------===// +// Test PDLL Support +//===----------------------------------------------------------------------===// + +#include "TestDialectConversionPDLLPatterns.h.inc" + +namespace { +struct PDLLTypeConverter : public TypeConverter { + PDLLTypeConverter() { + addConversion(convertType); + addArgumentMaterialization(materializeCast); + addSourceMaterialization(materializeCast); + } + + static LogicalResult convertType(Type t, SmallVectorImpl &results) { + // Convert I64 to F64. + if (t.isSignlessInteger(64)) { + results.push_back(FloatType::getF64(t.getContext())); + return success(); + } + + // Otherwise, convert the type directly. + results.push_back(t); + return success(); + } + /// Hook for materializing a conversion. 
+ static Optional materializeCast(OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) { + return builder.create(loc, resultType, inputs) + .getResult(0); + } +}; + +struct TestDialectConversionPDLLPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectConversionPDLLPass) + + StringRef getArgument() const final { return "test-dialect-conversion-pdll"; } + StringRef getDescription() const final { + return "Test DialectConversion PDLL functionality"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + LogicalResult initialize(MLIRContext *ctx) override { + // Build the pattern set within the `initialize` to avoid recompiling PDL + // patterns during each `runOnOperation` invocation. + RewritePatternSet patternList(ctx); + registerConversionPDLFunctions(patternList); + populateGeneratedPDLLPatterns(patternList, PDLConversionConfig(&converter)); + patterns = std::move(patternList); + return success(); + } + + void runOnOperation() final { + mlir::ConversionTarget target(getContext()); + target.addLegalOp(); + target.addDynamicallyLegalDialect( + [this](Operation *op) { return converter.isLegal(op); }); + + if (failed(mlir::applyFullConversion(getOperation(), target, patterns))) + signalPassFailure(); + } + + FrozenRewritePatternSet patterns; + PDLLTypeConverter converter; +}; +} // namespace + +namespace mlir { +namespace test { +void registerTestDialectConversionPasses() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/test/lib/Transforms/TestDialectConversion.pdll b/mlir/test/lib/Transforms/TestDialectConversion.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestDialectConversion.pdll @@ -0,0 +1,19 @@ +//===- TestPDLL.pdll - Test PDLL functionality ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestOps.td" +#include "mlir/Transforms/DialectConversion.pdll" + +/// Change the result type of a producer. +// FIXME: We shouldn't need to specify arguments for the result cast. +Pattern => replace op(args: ValueRange) -> (results: TypeRange) + with op(args) -> (convertTypes(results)); + +/// Pass through test.return conversion. +Pattern => replace op(args: ValueRange) + with op(convertValues(args)); diff --git a/mlir/test/lib/Transforms/lit.local.cfg b/mlir/test/lib/Transforms/lit.local.cfg new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/lit.local.cfg @@ -0,0 +1 @@ +config.suffixes.remove('.pdll') diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll --- a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll +++ b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll @@ -5,18 +5,19 @@ // check that we handle overlap. // CHECK: struct GeneratedPDLLPattern0 : ::mlir::PDLPatternModule { +// CHECK: template // CHECK: : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>( // CHECK: R"mlir( // CHECK: pdl.pattern // CHECK: operation "test.op" -// CHECK: )mlir", context)) +// CHECK: )mlir", context), std::forward(configs)...) // CHECK: struct NamedPattern : ::mlir::PDLPatternModule { // CHECK: : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>( // CHECK: R"mlir( // CHECK: pdl.pattern // CHECK: operation "test.op2" -// CHECK: )mlir", context)) +// CHECK: )mlir", context), std::forward(configs)...) // CHECK: struct GeneratedPDLLPattern1 : ::mlir::PDLPatternModule { @@ -25,13 +26,13 @@ // CHECK: R"mlir( // CHECK: pdl.pattern // CHECK: operation "test.op3" -// CHECK: )mlir", context)) +// CHECK: )mlir", context), std::forward(configs)...) 
-// CHECK: static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns) { -// CHECK-NEXT: patterns.add(patterns.getContext()); -// CHECK-NEXT: patterns.add(patterns.getContext()); -// CHECK-NEXT: patterns.add(patterns.getContext()); -// CHECK-NEXT: patterns.add(patterns.getContext()); +// CHECK: static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) { +// CHECK-NEXT: patterns.add(patterns.getContext(), configs...); +// CHECK-NEXT: patterns.add(patterns.getContext(), configs...); +// CHECK-NEXT: patterns.add(patterns.getContext(), configs...); +// CHECK-NEXT: patterns.add(patterns.getContext(), configs...); // CHECK-NEXT: } Pattern => erase op; diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -76,6 +76,7 @@ void registerTestDeadCodeAnalysisPass(); void registerTestDecomposeCallGraphTypes(); void registerTestDiagnosticsPass(); +void registerTestDialectConversionPasses(); void registerTestDominancePass(); void registerTestDynamicPipelinePass(); void registerTestExpandMathPass(); @@ -170,6 +171,7 @@ mlir::test::registerTestConstantFold(); mlir::test::registerTestControlFlowSink(); mlir::test::registerTestDiagnosticsPass(); + mlir::test::registerTestDialectConversionPasses(); #if MLIR_CUDA_CONVERSIONS_ENABLED mlir::test::registerTestGpuSerializeToCubinPass(); #endif