diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -603,7 +603,7 @@ }; //===----------------------------------------------------------------------===// -// PDLPatternModule +// PDL Patterns //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// @@ -796,6 +796,108 @@ SmallVector> allocatedValueRanges; }; +//===----------------------------------------------------------------------===// +// PDLPatternConfig + +/// An individual configuration for a pattern, which can be accessed by native +/// functions via the PDLPatternConfigSet. This allows for injecting additional +/// configuration into PDL patterns that is specific to certain compilation +/// flows. +class PDLPatternConfig { +public: + virtual ~PDLPatternConfig() = default; + + /// Hooks that are invoked at the beginning and end of a rewrite of a matched + /// pattern. These can be used to set up any specific state necessary for the + /// rewrite. + virtual void notifyRewriteBegin(PatternRewriter &rewriter) {} + virtual void notifyRewriteEnd(PatternRewriter &rewriter) {} + + /// Return the TypeID that represents this configuration. + TypeID getTypeID() const { return id; } + +protected: + explicit PDLPatternConfig(TypeID id) : id(id) {} + +private: + TypeID id; +}; + +/// This class provides a base class for users implementing a type of pattern +/// configuration. +template +class PDLPatternConfigBase : public PDLPatternConfig { +public: + /// Support LLVM style casting. + static bool classof(const PDLPatternConfig *config) { + return config->getTypeID() == getConfigID(); + } + + /// Return the type id used for this configuration.
+ static TypeID getConfigID() { return TypeID::get(); } + +protected: + PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {} +}; + +/// This class contains a set of configurations for a specific pattern. +/// Configurations are uniqued by TypeID, meaning that only one configuration of +/// each type is allowed. +class PDLPatternConfigSet { +public: + PDLPatternConfigSet() = default; + + /// Construct a set with the given configurations. + template + PDLPatternConfigSet(ConfigsT &&...configs) { + (addConfig(std::forward(configs)), ...); + } + + /// Get the configuration defined by the given type. Asserts that the + /// configuration of the provided type exists. + template + const T &get() const { + const T *config = tryGet(); + assert(config && "configuration not found"); + return *config; + } + + /// Get the configuration defined by the given type, returns nullptr if the + /// configuration does not exist. + template + const T *tryGet() const { + for (const auto &configIt : configs) + if (const T *config = dyn_cast(configIt.get())) + return config; + return nullptr; + } + + /// Notify the configurations within this set at the beginning and end of a + /// rewrite of a matched pattern. + void notifyRewriteBegin(PatternRewriter &rewriter) { + for (const auto &config : configs) + config->notifyRewriteBegin(rewriter); + } + void notifyRewriteEnd(PatternRewriter &rewriter) { + for (const auto &config : configs) + config->notifyRewriteEnd(rewriter); + } + +protected: + /// Add `config` to the set; `T` is decayed so const configs are accepted. + template + void addConfig(T &&config) { + assert(!tryGet<std::decay_t<T>>() && "configuration already exists"); + configs.emplace_back( + std::make_unique<std::decay_t<T>>(std::forward<T>(config))); + } + + /// The set of configurations for this pattern. This uses a vector instead of + /// a map with the expectation that the number of configurations per set is + /// small (<= 1).
+ SmallVector> configs; +}; + //===----------------------------------------------------------------------===// // PDLPatternModule @@ -1034,6 +1136,13 @@ results.push_back(types); } }; +template +struct ProcessPDLValue> { + static void processAsResult(PatternRewriter &, PDLResultList &results, + SmallVector values) { + results.push_back(TypeRange(values)); + } +}; //===----------------------------------------------------------------------===// // Value @@ -1061,6 +1170,13 @@ results.push_back(values); } }; +template +struct ProcessPDLValue> { + static void processAsResult(PatternRewriter &, PDLResultList &results, + SmallVector values) { + results.push_back(ValueRange(values)); + } +}; //===----------------------------------------------------------------------===// // PDL Function Builder: Argument Handling @@ -1246,6 +1362,9 @@ } // namespace pdl_function_builder } // namespace detail +//===----------------------------------------------------------------------===// +// PDLPatternModule + /// This class contains all of the necessary data for a set of PDL patterns, or /// pattern rewrites specified in the form of the PDL dialect. This PDL module /// contained by this pattern may contain any number of `pdl.pattern` @@ -1254,9 +1373,20 @@ public: PDLPatternModule() = default; - /// Construct a PDL pattern with the given module. - PDLPatternModule(OwningOpRef pdlModule) - : pdlModule(std::move(pdlModule)) {} + /// Construct a PDL pattern with the given module and configurations. + PDLPatternModule(OwningOpRef module) + : pdlModule(std::move(module)) {} + template + PDLPatternModule(OwningOpRef module, ConfigsT &&...patternConfigs) + : PDLPatternModule(std::move(module)) { + // A null module (e.g. a failed parse) has no patterns to configure. + if (!pdlModule) + return; + auto configSet = std::make_unique( + std::forward(patternConfigs)...); + attachConfigToPatterns(*pdlModule, *configSet); + configs.emplace_back(std::move(configSet)); + } /// Merge the state in `other` into this pattern module.
void mergeIn(PDLPatternModule &&other); @@ -1342,6 +1472,11 @@ return rewriteFunctions; } + /// Move out and return the registered pattern config sets. + SmallVector> takeConfigs() { + return std::move(configs); + } + /// Clear out the patterns and functions within this module. void clear() { pdlModule = nullptr; @@ -1350,9 +1485,16 @@ } private: + /// Attach the given pattern config set to the patterns defined within the + /// given module. + void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet); + /// The module containing the `pdl.pattern` operations. OwningOpRef pdlModule; + /// The set of configuration sets referenced by patterns within `pdlModule`. + SmallVector> configs; + /// The external functions referenced from within the PDL module. llvm::StringMap constraintFunctions; llvm::StringMap rewriteFunctions; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -888,6 +888,35 @@ MLIRContext &ctx; }; +//===----------------------------------------------------------------------===// +// PDL Configuration +//===----------------------------------------------------------------------===// + +/// A PDL configuration that is used to support dialect conversion +/// functionality. +class PDLConversionConfig final + : public PDLPatternConfigBase { +public: + PDLConversionConfig(TypeConverter *converter) : converter(converter) {} + ~PDLConversionConfig() final = default; + + /// Return the type converter used by this configuration, which may be nullptr + /// if no type conversions are expected. + TypeConverter *getTypeConverter() const { return converter; } + + /// Hooks that are invoked at the beginning and end of a rewrite of a matched + /// pattern. `rewriter` must be a ConversionPatternRewriter.
+ void notifyRewriteBegin(PatternRewriter &rewriter) final; + void notifyRewriteEnd(PatternRewriter &rewriter) final; + +private: + /// An optional type converter to use for the pattern. + TypeConverter *converter; +}; + +/// Register the dialect conversion PDL functions with the given pattern set. +void registerConversionPDLFunctions(RewritePatternSet &patterns); + //===----------------------------------------------------------------------===// // Op Conversion Entry Points //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/DialectConversion.pdll b/mlir/include/mlir/Transforms/DialectConversion.pdll new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Transforms/DialectConversion.pdll @@ -0,0 +1,30 @@ +//===- DialectConversion.pdll - DialectConversion PDLL Support --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines various utilities for interacting with dialect conversion +// within PDLL. +// +//===----------------------------------------------------------------------===// + +/// This rewrite returns the converted value of `value`, whose type is defined +/// by the type converter specified in the `PDLConversionConfig` of the current +/// pattern. +Rewrite convertValue(value: Value) -> Value; + +/// This rewrite returns the converted values of `values`, whose type is defined +/// by the type converter specified in the `PDLConversionConfig` of the current +/// pattern. +Rewrite convertValues(values: ValueRange) -> ValueRange; + +/// This rewrite returns the converted type of `type` as defined by the type +/// converter specified in the `PDLConversionConfig` of the current pattern.
+Rewrite convertType(type: Type) -> Type; + +/// This rewrite returns the converted types of `types` as defined by the type +/// converter specified in the `PDLConversionConfig` of the current pattern. +Rewrite convertTypes(types: TypeRange) -> TypeRange; diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -158,11 +158,13 @@ if (!other.pdlModule) return; - // Steal the functions of the other module. + // Steal the functions and configs of the other module. for (auto &it : other.constraintFunctions) registerConstraintFunction(it.first(), std::move(it.second)); for (auto &it : other.rewriteFunctions) registerRewriteFunction(it.first(), std::move(it.second)); + for (auto &it : other.configs) + configs.emplace_back(std::move(it)); // Steal the other state if we have no patterns. if (!pdlModule) { @@ -176,6 +178,21 @@ other.pdlModule->getBody()->getOperations()); } +void PDLPatternModule::attachConfigToPatterns(ModuleOp module, + PDLPatternConfigSet &configSet) { + // Attach the configuration as an opaque location to the symbols within the + // module. We use locations to avoid other more expensive and complex means of + // bookkeeping. + module->walk([&](Operation *op) { + // We only attach to symbols to avoid hardcoding any specific operation names + // here (given that we don't depend on any PDL dialect). We can't use + // cast here because patterns may be optional symbols. + if (!op->hasTrait()) + return; + op->setLoc(OpaqueLoc::get(&configSet, op->getLoc())); + }); +} + //===----------------------------------------------------------------------===// // Function Registry diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h --- a/mlir/lib/Rewrite/ByteCode.h +++ b/mlir/lib/Rewrite/ByteCode.h @@ -43,14 +43,21 @@ /// Return the bytecode address of the rewriter for this pattern.
ByteCodeAddr getRewriterAddr() const { return rewriterAddr; } + /// Return the configuration set for this pattern, or null if there is none. + PDLPatternConfigSet *getConfigSet() const { return configSet; } + private: template - PDLByteCodePattern(ByteCodeAddr rewriterAddr, Args &&...patternArgs) - : Pattern(std::forward(patternArgs)...), - rewriterAddr(rewriterAddr) {} + PDLByteCodePattern(ByteCodeAddr rewriterAddr, PDLPatternConfigSet *configSet, + Args &&...patternArgs) + : Pattern(std::forward(patternArgs)...), rewriterAddr(rewriterAddr), + configSet(configSet) {} /// The address of the rewriter for this pattern. ByteCodeAddr rewriterAddr; + + /// The optional config set for this pattern. + PDLPatternConfigSet *configSet; }; //===----------------------------------------------------------------------===// @@ -148,6 +155,7 @@ /// Create a ByteCode instance from the given module containing operations in /// the PDL interpreter dialect. PDLByteCode(ModuleOp module, + SmallVector> configs, llvm::StringMap constraintFns, llvm::StringMap rewriteFns); @@ -177,6 +185,9 @@ PDLByteCodeMutableState &state, SmallVectorImpl *matches) const; + /// The set of pattern configs referenced within the bytecode. + SmallVector> configs; + /// A vector containing pointers to uniqued data. The storage is intentionally /// opaque such that we can store a wide range of data types. The types of /// data stored here include: diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -35,20 +35,26 @@ PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, ByteCodeAddr rewriterAddr) { + PatternBenefit benefit = matchOp.getBenefit(); + MLIRContext *ctx = matchOp.getContext(); + + // Collect the set of generated operations.
SmallVector generatedOps; if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr()) generatedOps = llvm::to_vector<8>(generatedOpsAttr.getAsValueRange()); - PatternBenefit benefit = matchOp.getBenefit(); - MLIRContext *ctx = matchOp.getContext(); + // Check to see if this pattern has a config set. + PDLPatternConfigSet *configSet = + OpaqueLoc::getUnderlyingLocationOrNull( + matchOp.getLoc()); // Check to see if this is pattern matches a specific operation type. if (Optional rootKind = matchOp.getRootKind()) - return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx, + return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx, generatedOps); - return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx, - generatedOps); + return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(), + benefit, ctx, generatedOps); } //===----------------------------------------------------------------------===// @@ -1014,9 +1020,11 @@ // PDLByteCode //===----------------------------------------------------------------------===// -PDLByteCode::PDLByteCode(ModuleOp module, - llvm::StringMap constraintFns, - llvm::StringMap rewriteFns) { +PDLByteCode::PDLByteCode( + ModuleOp module, SmallVector> configs, + llvm::StringMap constraintFns, + llvm::StringMap rewriteFns) + : configs(std::move(configs)) { Generator generator(module.getContext(), uniquedData, matcherByteCode, rewriterByteCode, patterns, maxValueMemoryIndex, maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount, @@ -2193,6 +2201,10 @@ /// Run the rewriter of the given pattern on the root operation `op`. void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const { + auto *configSet = match.pattern->getConfigSet(); + if (configSet) + configSet->notifyRewriteBegin(rewriter); + // The arguments of the rewrite function are stored at the start of the // memory buffer.
llvm::copy(match.values, state.memory.begin()); @@ -2205,4 +2217,7 @@ rewriterByteCode, state.currentPatternBenefits, patterns, constraintFunctions, rewriteFunctions); executor.execute(rewriter, /*matches=*/nullptr, match.location); + + if (configSet) + configSet->notifyRewriteEnd(rewriter); } diff --git a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp --- a/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp +++ b/mlir/lib/Rewrite/FrozenRewritePatternSet.cpp @@ -129,7 +129,8 @@ // Generate the pdl bytecode. impl->pdlByteCode = std::make_unique( - pdlModule, pdlPatterns.takeConstraintFunctions(), + pdlModule, pdlPatterns.takeConfigs(), + pdlPatterns.takeConstraintFunctions(), pdlPatterns.takeRewriteFunctions()); } diff --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp --- a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp @@ -93,10 +93,12 @@ os << "} // end namespace\n\n"; // Emit function to add the generated matchers to the pattern list.
- os << "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(" - "::mlir::RewritePatternSet &patterns) {\n"; + os << "template \n" + "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(" + "::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n"; for (const auto &name : patternNames) - os << " patterns.add<" << name << ">(patterns.getContext());\n"; + os << " patterns.add<" << name + << ">(patterns.getContext(), configs...);\n"; os << "}\n"; } @@ -104,14 +106,15 @@ StringSet<> &nativeFunctions) { const char *patternClassStartStr = R"( struct {0} : ::mlir::PDLPatternModule {{ - {0}(::mlir::MLIRContext *context) + template + {0}(::mlir::MLIRContext *context, ConfigsT &&...configs) : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>( )"; os << llvm::formatv(patternClassStartStr, patternName); os << "R\"mlir("; pattern->print(os, OpPrintingFlags().enableDebugInfo()); - os << "\n )mlir\", context)) {\n"; + os << "\n )mlir\", context), std::forward(configs)...) {\n"; // Register any native functions used within the pattern. StringSet<> registeredNativeFunctions; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -904,6 +904,22 @@ // Type Conversion //===--------------------------------------------------------------------===// + /// Push a new type converter scope onto the rewriter. + void pushTypeConverter(TypeConverter *converter) { + typeConverterStack.push_back(converter); + } + + /// Pop the current type converter from the rewriter. + void popTypeConverter() { + assert(!typeConverterStack.empty() && "type converter stack is empty"); + typeConverterStack.pop_back(); + } + + /// Return the current type converter. + TypeConverter *getCurrentTypeConverter() const { + return typeConverterStack.empty() ?
nullptr : typeConverterStack.back(); + } + /// Convert the signature of the given block. FailureOr convertBlockSignature( Block *block, TypeConverter *converter, @@ -1004,9 +1020,9 @@ /// 1->N conversion of some kind. SmallVector operationsWithChangedResults; - /// The current type converter, or nullptr if no type converter is currently - /// active. - TypeConverter *currentTypeConverter = nullptr; + /// The current stack of type converters, which may contain null entries if no + /// type converter is currently active for that level of the stack. + SmallVector typeConverterStack; /// This allows the user to collect the match failure message. function_ref notifyCallback; @@ -1256,6 +1272,7 @@ StringRef valueDiagTag, Optional inputLoc, PatternRewriter &rewriter, ValueRange values, SmallVectorImpl &remapped) { + TypeConverter *currentTypeConverter = getCurrentTypeConverter(); remapped.reserve(llvm::size(values)); SmallVector legalTypes; @@ -1420,7 +1437,8 @@ operationsWithChangedResults.push_back(replacements.size()); // Record the requested operation replacement. - replacements.insert(std::make_pair(op, OpReplacement(currentTypeConverter))); + replacements.insert( + std::make_pair(op, OpReplacement(getCurrentTypeConverter()))); // Mark this operation as recursively ignored so that we don't need to // convert any nested operations. @@ -1690,16 +1708,19 @@ auto &rewriterImpl = dialectRewriter.getImpl(); // Track the current conversion pattern type converter in the rewriter. - llvm::SaveAndRestore currentConverterGuard( - rewriterImpl.currentTypeConverter, getTypeConverter()); + rewriterImpl.pushTypeConverter(getTypeConverter()); - // Remap the operands of the operation. + // Remap the operands of the operation and attempt to match this pattern.
SmallVector operands; - if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, - op->getOperands(), operands))) { - return failure(); - } - return matchAndRewrite(op, operands, dialectRewriter); + LogicalResult result = failure(); + if (succeeded(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, + op->getOperands(), operands))) + result = matchAndRewrite(op, operands, dialectRewriter); + + // Pop the current converter on every path, including failure, so that the + // type converter stack stays balanced across pattern applications. + rewriterImpl.popTypeConverter(); + return result; } //===----------------------------------------------------------------------===// @@ -3249,6 +3270,82 @@ return llvm::None; } +//===----------------------------------------------------------------------===// +// PDL Configuration +//===----------------------------------------------------------------------===// + +void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) { + auto &rewriterImpl = + static_cast(rewriter).getImpl(); + rewriterImpl.pushTypeConverter(getTypeConverter()); +} + +void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) { + auto &rewriterImpl = + static_cast(rewriter).getImpl(); + rewriterImpl.popTypeConverter(); +} + +/// Remap the given values using the rewriter, erroring out if any value +/// cannot be remapped. +static SmallVector pdllConvertValues(ConversionPatternRewriter &rewriter, + ValueRange values) { + // If remapping fails, the best we can do for now is to error out. This + // is because PDL rewrites aren't failable, though maybe one day they + // could be (given that dialect conversion supports rollback).
+ SmallVector mappedValues; + if (failed(rewriter.getRemappedValues(values, mappedValues))) + llvm::report_fatal_error("dialect conversion failed to remap values"); + return mappedValues; +} + +void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) { + patterns.getPDLPatterns().registerRewriteFunction( + "convertValue", [](PatternRewriter &rewriter, Value value) { + auto results = pdllConvertValues( + static_cast(rewriter), value); + return results.front(); + }); + patterns.getPDLPatterns().registerRewriteFunction( + "convertValues", [](PatternRewriter &rewriter, ValueRange values) { + return pdllConvertValues( + static_cast(rewriter), values); + }); + patterns.getPDLPatterns().registerRewriteFunction( + "convertType", [](PatternRewriter &rewriter, Type type) { + auto &rewriterImpl = + static_cast(rewriter).getImpl(); + TypeConverter *converter = rewriterImpl.getCurrentTypeConverter(); + if (!converter) + return type; + + // As with `convertTypes`, PDL rewrites aren't failable, so the best we + // can do if the type fails to convert is to error out. + Type convertedType = converter->convertType(type); + if (!convertedType) + llvm::report_fatal_error("dialect conversion failed to convert type"); + return convertedType; + }); + patterns.getPDLPatterns().registerRewriteFunction( + "convertTypes", + [](PatternRewriter &rewriter, TypeRange types) -> SmallVector { + auto &rewriterImpl = + static_cast(rewriter).getImpl(); + TypeConverter *converter = rewriterImpl.getCurrentTypeConverter(); + if (!converter) + return types; + + // If conversion fails, the best we can do for now is to error out. This + // is because PDL rewrites aren't failable, though maybe one day they + // could be (given that dialect conversion supports rollback).
+ SmallVector remappedTypes; + if (failed(converter->convertTypes(types, remappedTypes))) + llvm::report_fatal_error( + "dialect conversion failed to convert types"); + return remappedTypes; + }); +} + //===----------------------------------------------------------------------===// // Op Conversion Entry Points //===----------------------------------------------------------------------===// diff --git a/mlir/test/Transforms/test-dialect-conversion-pdll.mlir b/mlir/test/Transforms/test-dialect-conversion-pdll.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/test-dialect-conversion-pdll.mlir @@ -0,0 +1,18 @@ +// RUN: mlir-opt %s -test-dialect-conversion-pdll | FileCheck %s + +// CHECK-LABEL: @TestSingleConversion +func.func @TestSingleConversion() { + // CHECK: %[[CAST:.*]] = "test.cast"() : () -> f64 + // CHECK-NEXT: "test.return"(%[[CAST]]) : (f64) -> () + %result = "test.cast"() : () -> (i64) + "test.return"(%result) : (i64) -> () +} + +// CHECK-LABEL: @TestLingeringConversion +func.func @TestLingeringConversion() -> i64 { + // CHECK: %[[ORIG_CAST:.*]] = "test.cast"() : () -> f64 + // CHECK: %[[MATERIALIZE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ORIG_CAST]] : f64 to i64 + // CHECK-NEXT: return %[[MATERIALIZE_CAST]] : i64 + %result = "test.cast"() : () -> (i64) + return %result : i64 +} diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -1,8 +1,18 @@ +add_mlir_pdll_library(MLIRTestDialectConversionPDLLPatternsIncGen + TestDialectConversion.pdll + TestDialectConversionPDLLPatterns.h.inc + + EXTRA_INCLUDES + ${CMAKE_CURRENT_SOURCE_DIR}/../Dialect/Test + ${CMAKE_CURRENT_BINARY_DIR}/../Dialect/Test + ) + # Exclude tests from libMLIR.so add_mlir_library(MLIRTestTransforms TestCommutativityUtils.cpp TestConstantFold.cpp TestControlFlowSink.cpp + TestDialectConversion.cpp TestInlining.cpp
TestIntRangeInference.cpp TestTopologicalSort.cpp @@ -12,8 +22,12 @@ ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms + DEPENDS + MLIRTestDialectConversionPDLLPatternsIncGen + LINK_LIBS PUBLIC MLIRAnalysis + MLIRFuncDialect MLIRInferIntRangeInterface MLIRTestDialect MLIRTransforms diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestDialectConversion.cpp @@ -0,0 +1,96 @@ +//===- TestDialectConversion.cpp - Test DialectConversion functionality ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace test; + +//===----------------------------------------------------------------------===// +// Test PDLL Support +//===----------------------------------------------------------------------===// + +#include "TestDialectConversionPDLLPatterns.h.inc" + +namespace { +struct PDLLTypeConverter : public TypeConverter { + PDLLTypeConverter() { + addConversion(convertType); + addArgumentMaterialization(materializeCast); + addSourceMaterialization(materializeCast); + } + + static LogicalResult convertType(Type t, SmallVectorImpl &results) { + // Convert I64 to F64. + if (t.isSignlessInteger(64)) { + results.push_back(FloatType::getF64(t.getContext())); + return success(); + } + + // Otherwise, keep the type unchanged.
+ results.push_back(t); + return success(); + } + /// Hook for materializing a conversion. + static Optional materializeCast(OpBuilder &builder, Type resultType, + ValueRange inputs, Location loc) { + return builder.create(loc, resultType, inputs) + .getResult(0); + } +}; + +struct TestDialectConversionPDLLPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectConversionPDLLPass) + + StringRef getArgument() const final { return "test-dialect-conversion-pdll"; } + StringRef getDescription() const final { + return "Test DialectConversion PDLL functionality"; + } + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + LogicalResult initialize(MLIRContext *ctx) override { + // Build the pattern set within the `initialize` to avoid recompiling PDL + // patterns during each `runOnOperation` invocation. + RewritePatternSet patternList(ctx); + registerConversionPDLFunctions(patternList); + populateGeneratedPDLLPatterns(patternList, PDLConversionConfig(&converter)); + patterns = std::move(patternList); + return success(); + } + + void runOnOperation() final { + mlir::ConversionTarget target(getContext()); + target.addLegalOp(); + target.addDynamicallyLegalDialect( + [this](Operation *op) { return converter.isLegal(op); }); + + if (failed(mlir::applyFullConversion(getOperation(), target, patterns))) + signalPassFailure(); + } + + FrozenRewritePatternSet patterns; + PDLLTypeConverter converter; +}; +} // namespace + +namespace mlir { +namespace test { +void registerTestDialectConversionPasses() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/test/lib/Transforms/TestDialectConversion.pdll b/mlir/test/lib/Transforms/TestDialectConversion.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestDialectConversion.pdll @@ -0,0 +1,19 @@ +//===- TestDialectConversion.pdll - Test DialectConversion PDLL -----------===// +// +// Part of the LLVM Project,
under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestOps.td" +#include "mlir/Transforms/DialectConversion.pdll" + +/// Change the result type of a producer. +// FIXME: We shouldn't need to specify arguments for the result cast. +Pattern => replace op(args: ValueRange) -> (results: TypeRange) + with op(args) -> (convertTypes(results)); + +/// Pass through test.return conversion. +Pattern => replace op(args: ValueRange) + with op(convertValues(args)); diff --git a/mlir/test/lib/Transforms/lit.local.cfg b/mlir/test/lib/Transforms/lit.local.cfg new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/lit.local.cfg @@ -0,0 +1 @@ +config.suffixes.remove('.pdll') diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll --- a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll +++ b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll @@ -5,18 +5,19 @@ // check that we handle overlap. // CHECK: struct GeneratedPDLLPattern0 : ::mlir::PDLPatternModule { +// CHECK: template // CHECK: : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>( // CHECK: R"mlir( // CHECK: pdl.pattern // CHECK: operation "test.op" -// CHECK: )mlir", context)) +// CHECK: )mlir", context), std::forward(configs)...) // CHECK: struct NamedPattern : ::mlir::PDLPatternModule { // CHECK: : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>( // CHECK: R"mlir( // CHECK: pdl.pattern // CHECK: operation "test.op2" -// CHECK: )mlir", context)) +// CHECK: )mlir", context), std::forward(configs)...)
// CHECK: struct GeneratedPDLLPattern1 : ::mlir::PDLPatternModule { @@ -25,13 +26,13 @@ // CHECK: R"mlir( // CHECK: pdl.pattern // CHECK: operation "test.op3" -// CHECK: )mlir", context)) +// CHECK: )mlir", context), std::forward(configs)...) -// CHECK: static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns) { -// CHECK-NEXT: patterns.add(patterns.getContext()); -// CHECK-NEXT: patterns.add(patterns.getContext()); -// CHECK-NEXT: patterns.add(patterns.getContext()); -// CHECK-NEXT: patterns.add(patterns.getContext()); +// CHECK: static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) { +// CHECK-NEXT: patterns.add(patterns.getContext(), configs...); +// CHECK-NEXT: patterns.add(patterns.getContext(), configs...); +// CHECK-NEXT: patterns.add(patterns.getContext(), configs...); +// CHECK-NEXT: patterns.add(patterns.getContext(), configs...); // CHECK-NEXT: } Pattern => erase op; diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -75,6 +75,7 @@ void registerTestDeadCodeAnalysisPass(); void registerTestDecomposeCallGraphTypes(); void registerTestDiagnosticsPass(); +void registerTestDialectConversionPasses(); void registerTestDominancePass(); void registerTestDynamicPipelinePass(); void registerTestExpandMathPass(); @@ -165,6 +166,7 @@ mlir::test::registerTestConstantFold(); mlir::test::registerTestControlFlowSink(); mlir::test::registerTestDiagnosticsPass(); + mlir::test::registerTestDialectConversionPasses(); #if MLIR_CUDA_CONVERSIONS_ENABLED mlir::test::registerTestGpuSerializeToCubinPass(); #endif