diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -16,6 +16,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace transform { diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -240,9 +240,13 @@ def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> { let description = [{ - This interface should be implemented by ops that select patterns of a - `transform.apply_patterns` op. It provides a method to populate a rewrite + This interface should be implemented by ops that select rewrite patterns of + a `transform.apply_patterns` op. It provides a method to populate a rewrite pattern set with patterns. + + Note: Conversion patterns are rewrite patterns in MLIR, but they should not + be populated with `PatternDescriptorOpInterface` because they cannot be + used in a greedy pattern rewrite. }]; let cppNamespace = "::mlir::transform"; @@ -250,11 +254,73 @@ let methods = [ InterfaceMethod< /*desc=*/[{ - Populate patterns into the given pattern set. + Populate rewrite patterns into the given pattern set.
}], /*returnType=*/"void", /*name=*/"populatePatterns", - /*arguments=*/(ins "RewritePatternSet &":$patterns) + /*arguments=*/(ins "::mlir::RewritePatternSet &":$patterns) + >, + ]; +} + +def ConversionPatternDescriptorOpInterface + : OpInterface<"ConversionPatternDescriptorOpInterface"> { + let description = [{ + This interface should be implemented by ops that select conversion patterns + of a `transform.apply_conversion_patterns` op. It provides a method to + populate a rewrite pattern set with conversion patterns. + + Note: Non-conversion rewrite patterns should not be populated with + `ConversionPatternDescriptorOpInterface` because it is not generally safe + to use non-conversion rewrite patterns as part of a dialect conversion. + }]; + + let cppNamespace = "::mlir::transform"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Populate conversion patterns into the given pattern set with the + given type converter. + }], + /*returnType=*/"void", + /*name=*/"populatePatterns", + /*arguments=*/(ins "::mlir::TypeConverter &":$typeConverter, + "::mlir::RewritePatternSet &":$patterns) + >, + InterfaceMethod< + /*desc=*/[{ + Return the type converter to be used with this pattern set. If no + type converter is specified, the default type converter of the enclosing + "apply_conversion_patterns" op is used. + }], + /*returnType=*/"std::unique_ptr<::mlir::TypeConverter>", + /*name=*/"getTypeConverter", + /*arguments=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"return nullptr;" + >, + ]; +} + +def TypeConverterBuilderOpInterface + : OpInterface<"TypeConverterBuilderOpInterface"> { + let description = [{ + This interface should be implemented by ops that specify a type converter + for a dialect conversion. Such ops can be used with + "apply_conversion_patterns". + }]; + + let cppNamespace = "::mlir::transform"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Return the type converter to be used with a dialect conversion.
+ }], + /*returnType=*/"std::unique_ptr<::mlir::TypeConverter>", + /*name=*/"getTypeConverter", + /*arguments=*/(ins) >, ]; } diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -156,6 +156,84 @@ }]; } +def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait] + # GraphRegionNoTerminator.traits> { + let summary = "Applies conversion patterns to the body of the targeted op"; + let description = [{ + This transform applies the specified conversion patterns to the targeted op + and all nested ops. By default, this transform applies a "full" dialect + conversion. If the `partial_conversion` unit attribute is present, this + transform applies a partial dialect conversion. + + The patterns that should be applied are specified in the first graph region + of this op. They must implement the + `ConversionPatternDescriptorOpInterface`. The order in which patterns are + applied is unspecified; i.e., the ordering of ops in the region of this op + is irrelevant. + + The second, optional graph region contains exactly one op that specifies + the default type converter for this dialect conversion. If provided, this + op must implement the `TypeConverterBuilderOpInterface`. + Type converters are a property of conversion patterns: each conversion + pattern stores the type converter that should be used in its C++ class. Each + conversion pattern descriptor can optionally specify a type converter in its + `getTypeConverter` interface method. If no type converter is specified in + this method, the default type converter of the dialect conversion is used.
+ Default type converters are useful if the same type converter should be used + for multiple sets of conversion patterns. (Patterns that should not use this + default type converter specify their own type converter.) + + The `legal_ops`, `illegal_ops`, `legal_dialects`, `illegal_dialects` + attributes specify the conversion target. At least one of those four + attributes must be specified. + + This transform consumes the `target` handle and modifies the payload. It + does not produce any handles. + + This transform produces a silenceable failure if the dialect conversion fails. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + OptionalAttr:$legal_ops, + OptionalAttr:$illegal_ops, + OptionalAttr:$legal_dialects, + OptionalAttr:$illegal_dialects, + UnitAttr:$partial_conversion); + let results = (outs); + let regions = (region VariadicRegion>:$regions); + + let assemblyFormat = [{ + `to` $target $regions attr-dict `:` type($target) + }]; + let hasVerifier = 1; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins + "Value":$target, + CArg<"function_ref", "nullptr">: + $patternsBodyBuilder, + CArg<"function_ref", "nullptr">: + $typeConverterBodyBuilder)>, + ]; + + let extraClassDeclaration = [{ + ::mlir::Region &getPatterns() { + return getRegion(0); + } + + ::mlir::transform::TypeConverterBuilderOpInterface getDefaultTypeConverter() { + if (getNumRegions() < 2) + return {}; + return ::llvm::cast<::mlir::transform::TypeConverterBuilderOpInterface>( + &getRegion(1).front().front()); + } + }]; +} + def ApplyDeadCodeEliminationOp : TransformDialectOp<"apply_dce", [TransformOpInterface, TransformEachOpTrait, DeclareOpInterfaceMethods, diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -28,10 +28,16 @@ *RegisteredOperationName::lookup(name, context);
assert((opName.hasInterface() || opName.hasInterface() || + opName.hasInterface() || + opName.hasInterface() || opName.hasTrait()) && "non-terminator ops injected into the transform dialect must " - "implement TransformOpInterface or PatternDescriptorOpInterface"); - if (!opName.hasInterface()) { + "implement TransformOpInterface, PatternDescriptorOpInterface, " + "ConversionPatternDescriptorOpInterface or " + "TypeConverterBuilderOpInterface"); + if (!opName.hasInterface() && + !opName.hasInterface() && + !opName.hasInterface()) { assert(opName.hasInterface() && "ops injected into the transform dialect must implement " "MemoryEffectsOpInterface"); diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -23,6 +23,7 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/CSE.h" +#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" #include "llvm/ADT/STLExtras.h" @@ -478,6 +479,167 @@ op.getCanonicalizationPatterns(patterns, ctx); } +//===----------------------------------------------------------------------===// +// ApplyConversionPatternsOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { + MLIRContext *ctx = getContext(); + + // Default type converter is built on demand. + std::unique_ptr defaultTypeConverter; + + // Configure conversion target.
+ ConversionTarget conversionTarget(*ctx); + if (getLegalOps()) + for (Attribute attr : cast(*getLegalOps())) + conversionTarget.addLegalOp( + OperationName(cast(attr).getValue(), ctx)); + if (getIllegalOps()) + for (Attribute attr : cast(*getIllegalOps())) + conversionTarget.addIllegalOp( + OperationName(cast(attr).getValue(), ctx)); + if (getLegalDialects()) + for (Attribute attr : cast(*getLegalDialects())) + conversionTarget.addLegalDialect(cast(attr).getValue()); + if (getIllegalDialects()) + for (Attribute attr : cast(*getIllegalDialects())) + conversionTarget.addIllegalDialect(cast(attr).getValue()); + + // Gather all specified patterns. + RewritePatternSet patterns(ctx); + // Type converters returned by pattern descriptors must outlive pattern + // application: conversion patterns refer to, but do not own, converters. + SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters; + if (!getPatterns().empty()) { + for (Operation &op : getPatterns().front()) { + auto descriptor = + cast(&op); + + // Check if this pattern set specifies a type converter. + std::unique_ptr typeConverter = + descriptor.getTypeConverter(); + TypeConverter *converter = nullptr; + if (typeConverter) { + keepAliveConverters.push_back(std::move(typeConverter)); + converter = keepAliveConverters.back().get(); + } else { + // No type converter specified: Use the default type converter. + if (!defaultTypeConverter) { + // Instantiate the default type converter. + transform::TypeConverterBuilderOpInterface typeConverterBuilder = + getDefaultTypeConverter(); + if (!typeConverterBuilder) { + auto diag = emitDefiniteFailure() + << "pattern descriptor does not specify type " + "converter and apply_conversion_patterns op has " + "no default type converter"; + diag.attachNote(op.getLoc()) << "pattern descriptor op"; + return diag; + } + defaultTypeConverter = typeConverterBuilder.getTypeConverter(); + assert(defaultTypeConverter && "expected type converter"); + } + converter = defaultTypeConverter.get(); + } + descriptor.populatePatterns(*converter, patterns); + } + } + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + + for (Operation *target : state.getPayloadOps(getTarget())) { + // Make sure that this transform is not applied to itself.
Modifying the + // transform IR while it is being interpreted is generally dangerous. + DiagnosedSilenceableFailure payloadCheck = + ensurePayloadIsSeparateFromTransform(*this, target); + if (!payloadCheck.succeeded()) + return payloadCheck; + + LogicalResult status = failure(); + if (getPartialConversion()) { + status = applyPartialConversion(target, conversionTarget, frozenPatterns); + } else { + status = applyFullConversion(target, conversionTarget, frozenPatterns); + } + + if (failed(status)) { + auto diag = emitSilenceableError() << "dialect conversion failed"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + } + + return DiagnosedSilenceableFailure::success(); +} + +LogicalResult transform::ApplyConversionPatternsOp::verify() { + if (getNumRegions() != 1 && getNumRegions() != 2) + return emitOpError() << "expected 1 or 2 regions"; + if (!getPatterns().empty()) { + for (Operation &op : getPatterns().front()) { + if (!isa(&op)) { + InFlightDiagnostic diag = + emitOpError() << "expected pattern children ops to implement " + "ConversionPatternDescriptorOpInterface"; + diag.attachNote(op.getLoc()) << "op without interface"; + return diag; + } + } + } + if (getNumRegions() == 2) { + Region &typeConverterRegion = getRegion(1); + // Guard against a region without a block before dereferencing front(). + if (typeConverterRegion.empty() || + !llvm::hasSingleElement(typeConverterRegion.front())) + return emitOpError() + << "expected exactly one op in default type converter region"; + Operation *typeConverterOp = &typeConverterRegion.front().front(); + if (!isa(typeConverterOp)) { + InFlightDiagnostic diag = emitOpError() + << "expected default converter child op to " + "implement TypeConverterBuilderOpInterface"; + diag.attachNote(typeConverterOp->getLoc()) << "op without interface"; + return diag; + } + } + if (!getLegalOps() && !getIllegalOps() && !getLegalDialects() && + !getIllegalDialects()) + return emitOpError() << "conversion target is not specified"; + return success(); +} + +void transform::ApplyConversionPatternsOp::getEffects(
+ SmallVectorImpl &effects) { + transform::consumesHandle(getTarget(), effects); + transform::modifiesPayload(effects); +} + +void transform::ApplyConversionPatternsOp::build( + OpBuilder &builder, OperationState &result, Value target, + function_ref patternsBodyBuilder, + function_ref typeConverterBodyBuilder) { + result.addOperands(target); + + { + OpBuilder::InsertionGuard g(builder); + Region *region1 = result.addRegion(); + builder.createBlock(region1); + if (patternsBodyBuilder) + patternsBodyBuilder(builder, result.location); + } + // The default type converter region is optional: only add it when a body + // builder is provided, because a region without an op fails verification. + // Guard the insertion point so that the caller's builder is not modified. + if (typeConverterBodyBuilder) { + OpBuilder::InsertionGuard g(builder); + Region *region2 = result.addRegion(); + builder.createBlock(region2); + typeConverterBodyBuilder(builder, result.location); + } +} + //===----------------------------------------------------------------------===// // ApplyLoopInvariantCodeMotionOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir --- a/mlir/test/Dialect/Transform/test-pattern-application.mlir +++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir @@ -231,3 +231,51 @@ transform.apply_patterns.canonicalization } {apply_cse} : !transform.any_op } + +// ----- + +// CHECK-LABEL: func @full_dialect_conversion +// CHECK-NEXT: %[[m:.*]] = "test.new_op"() : () -> memref<5xf32> +// CHECK-NEXT: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[m]] : memref<5xf32> to tensor<5xf32> +// CHECK-NEXT: return %[[cast]] +func.func @full_dialect_conversion() -> tensor<5xf32> { + %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (tensor<5xf32>) + return %0 : tensor<5xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_conversion_patterns to %0 { + transform.apply_conversion_patterns.transform.test_conversion_patterns
}, { + transform.apply_conversion_patterns.transform.test_type_converter + } {illegal_ops = ["test.foo"], + legal_ops = ["func.func", "func.return", "test.new_op"]} + : !transform.any_op +} + +// ----- + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error @below{{conversion target is not specified}} + transform.apply_conversion_patterns to %0 { + transform.apply_conversion_patterns.transform.test_conversion_patterns + }, { + transform.apply_conversion_patterns.transform.test_type_converter + } : !transform.any_op +} + +// ----- + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error @below{{pattern descriptor does not specify type converter and apply_conversion_patterns op has no default type converter}} + transform.apply_conversion_patterns to %0 { + // expected-note @below{{pattern descriptor op}} + transform.apply_conversion_patterns.transform.test_conversion_patterns + } {illegal_ops = ["test.foo"]} : !transform.any_op +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -915,6 +915,68 @@ transform::modifiesPayload(effects); } +namespace { +/// Test conversion pattern that replaces ops with the "replace_with_new_op" +/// attribute with "test.new_op".
+class ReplaceWithNewOpConversion : public ConversionPattern { +public: + ReplaceWithNewOpConversion(TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern(typeConverter, RewritePattern::MatchAnyOpTypeTag(), + /*benefit=*/1, context) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!op->hasAttr("replace_with_new_op")) + return failure(); + SmallVector newResultTypes; + if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), + newResultTypes))) + return failure(); + Operation *newOp = rewriter.create( + op->getLoc(), + OperationName("test.new_op", op->getContext()).getIdentifier(), + operands, newResultTypes); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; +} // namespace + +void mlir::test::ApplyTestConversionPatternsOp::populatePatterns( + TypeConverter &typeConverter, RewritePatternSet &patterns) { + patterns.insert(typeConverter, + patterns.getContext()); +} + +namespace { +/// Test type converter: ranked tensors become memrefs, other types are kept. +class TestTypeConverter : public TypeConverter { +public: + TestTypeConverter() { + // Conversions are tried in reverse order of registration, so this + // identity fallback only applies to types that are not handled below. + addConversion([](Type type) { return type; }); + addConversion([](RankedTensorType type) -> Type { + return MemRefType::get(type.getShape(), type.getElementType()); + }); + addSourceMaterialization([](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> std::optional { + if (inputs.size() != 1) + return std::nullopt; + return builder.create(loc, resultType, inputs) + .getResult(0); + }); + } +}; +} // namespace + +std::unique_ptr<::mlir::TypeConverter> +mlir::test::TestTypeConverterOp::getTypeConverter() { + return std::make_unique(); +} + namespace { /// Test extension of the Transform dialect.
Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -554,6 +554,24 @@ let cppNamespace = "::mlir::test"; } +def ApplyTestConversionPatternsOp + : Op]> { + let arguments = (ins); + let results = (outs); + let assemblyFormat = "attr-dict"; + let cppNamespace = "::mlir::test"; +} + +def TestTypeConverterOp + : Op]> { + let arguments = (ins); + let results = (outs); + let assemblyFormat = "attr-dict"; + let cppNamespace = "::mlir::test"; +} + def TestReEnterRegionOp : Op,