# Review notes (not part of the patch): SCF transform-op and header changes.
#
# SCFTransformOps.td
#  - Adds ApplySCFStructuralConversionPatternsOp, a transform op with an
#    `attr-dict` assembly format. Its description says it collects patterns
#    for structural conversions of SCF operations.
#  - NOTE(review): in this copy the op's `Op<...>` argument list has been lost;
#    only `Op]> {` remains. Some text is also missing between
#    `def GetParentForOp : Op` and the scf.if description. That missing text
#    includes the surrounding context and what looks like a second hunk header.
#    Text inside `<...>` appears to have been stripped during extraction, so
#    restore it from the original patch before applying.
#    Presumably the argument list:
#      * names the mnemonic
#        `apply_conversion_patterns.scf.structural_conversions` (the new test
#        uses it with a `transform.` prefix), and
#      * lists `updateConversionTarget` in its DeclareOpInterfaceMethods list.
#        SCFTransformOps.cpp defines an override of that method, while the
#        interface gives it a default implementation. Verify.
#  - The change to the scf.if description text appears to be whitespace-only.
#    Consider dropping it to keep the patch focused (confirm).
#
# Patterns.h
#  - Declares two helpers split out of
#    populateSCFStructuralTypeConversionsAndLegality:
#      * populateSCFStructuralTypeConversions (patterns only)
#      * populateSCFStructuralTypeConversionTarget (legality only)
#  - NOTE(review): the implementation registers legality callbacks that capture
#    `typeConverter` by reference. The TypeConverter must therefore outlive
#    every use of the ConversionTarget. Consider stating that in the doc
#    comment of populateSCFStructuralTypeConversionTarget.
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -27,6 +27,17 @@ let assemblyFormat = "attr-dict"; } +def ApplySCFStructuralConversionPatternsOp : Op]> { + let description = [{ + Collects patterns for performing structural conversions of SCF operations. + }]; + + let assemblyFormat = "attr-dict"; +} + def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">; def GetParentForOp : Op { let description = [{ Given an scf.if conditional, inject user-defined information that it is - always safe to execute only the if or else branch. - + always safe to execute only the if or else branch. + This is achieved by just replacing the scf.if by the content of one of its branches. diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h --- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h @@ -53,6 +53,16 @@ TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target); +/// Similar to `populateSCFStructuralTypeConversionsAndLegality` but does not +/// populate the conversion target. +void populateSCFStructuralTypeConversions(TypeConverter &typeConverter, + RewritePatternSet &patterns); + +/// Updates the ConversionTarget with dynamic legality of SCF operations based +/// on the provided type converter. +void populateSCFStructuralTypeConversionTarget(TypeConverter &typeConverter, + ConversionTarget &target); + /// Populates the provided pattern set with patterns that do 1:N type /// conversions on (some) SCF ops. This is intended to be used with /// applyPartialOneToNConversion. 
# Review notes (not part of the patch): interface, implementations and tests.
#
# TransformInterfaces.td
#  - Adds interface method updateConversionTarget(TypeConverter &,
#    ConversionTarget &). Its default implementation only does `return;`, so
#    descriptors that do not override it are unaffected.
#
# SCFTransformOps.cpp / StructuralTypeConversions.cpp
#  - The op forwards populatePatterns and updateConversionTarget to the new
#    scf:: helpers.
#  - populateSCFStructuralTypeConversionsAndLegality is kept and now calls both
#    helpers, so existing callers keep the same entry point.
#  - NOTE(review): the legality lambdas passed to addDynamicallyLegalOp capture
#    `typeConverter` by reference ([&]) and are stored in the ConversionTarget.
#    Confirm that the converter handed to updateConversionTarget in
#    TransformOps.cpp stays alive until the conversion driver is done with
#    `conversionTarget`. One such converter is the object owned by
#    `defaultTypeConverter`.
#
# TransformOps.cpp
#  - updateConversionTarget is called after the final converter is selected
#    and before populatePatterns, as the added comment describes.
#
# transform-ops.mlir
#  - NOTE(review): the new test has no CHECK lines, so it only verifies that
#    the transform sequence runs without failing. Add FileCheck assertions on
#    the rewritten scf.for, covering:
#      * iter_args and result types after conversion, and
#      * any builtin.unrealized_conversion_cast ops that were introduced.
#    Take the exact expected IR from an actual run.
#  - Two blank lines precede the new `// -----`, and the file ends without a
#    trailing newline (`\ No newline at end of file`). Consider fixing both.
#
# TestTransformDialectExtension.cpp
#  - The identity conversion `addConversion([](Type t) { return t; })` is
#    registered first. Per the TypeConverter::addConversion documentation,
#    the most recently added conversion is tried first. As a result,
#    RankedTensorType still maps to MemRefType, and every other type converts
#    to itself.
#  - The shared materialization lambda uses [&] but reads only its own
#    parameters. Because it is stored in the converter, `[]` would be safer.
#
# NOTE(review): text inside `<...>` appears to have been stripped throughout
# this copy. Examples:
#   * `patterns.add(` and `addDynamicallyLegalOp(` are missing their op/pattern
#     lists;
#   * `std::optional {` and `builder.create(` are missing their template
#     arguments;
#   * the test's `tensor` types have no shape or element type.
# The original line breaks were also collapsed. Restore from the original
# patch before applying.
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -321,6 +321,18 @@ /*arguments=*/(ins "::mlir::TypeConverter &":$typeConverter, "::mlir::RewritePatternSet &":$patterns) >, + InterfaceMethod< + /*desc=*/[{ + Update the ConversionTarget using the final TypeConverter. The default + implementation is to do nothing. + }], + /*returnType=*/"void", + /*name=*/"updateConversionTarget", + /*arguments=*/(ins "::mlir::TypeConverter &":$typeConverter, + "::mlir::ConversionTarget &":$conversionTarget), + /*methodBody=*/"", + /*defaultImplementation=*/"return;" + >, InterfaceMethod< /*desc=*/[{ Return the type converter to be used with this pattern set. If no diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -32,6 +32,17 @@ scf::populateSCFForLoopCanonicalizationPatterns(patterns); } +void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns( + TypeConverter &typeConverter, RewritePatternSet &patterns) { + scf::populateSCFStructuralTypeConversions(typeConverter, patterns); +} + +void transform::ApplySCFStructuralConversionPatternsOp::updateConversionTarget( + TypeConverter &typeConverter, ConversionTarget &conversionTarget) { + scf::populateSCFStructuralTypeConversionTarget(typeConverter, + conversionTarget); +} + //===----------------------------------------------------------------------===// // GetParentForOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp --- 
a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -247,12 +247,15 @@ }; } // namespace -void mlir::scf::populateSCFStructuralTypeConversionsAndLegality( - TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target) { +void mlir::scf::populateSCFStructuralTypeConversions( + TypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add( typeConverter, patterns.getContext()); +} + +void mlir::scf::populateSCFStructuralTypeConversionTarget( + TypeConverter &typeConverter, ConversionTarget &target) { target.addDynamicallyLegalOp([&](Operation *op) { return typeConverter.isLegal(op->getResultTypes()); }); @@ -266,3 +269,10 @@ target.addDynamicallyLegalOp( [&](Operation *op) { return typeConverter.isLegal(op); }); } + +void mlir::scf::populateSCFStructuralTypeConversionsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target) { + populateSCFStructuralTypeConversions(typeConverter, patterns); + populateSCFStructuralTypeConversionTarget(typeConverter, target); +} diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -545,6 +545,12 @@ } converter = defaultTypeConverter.get(); } + + // Add descriptor-specific updates to the conversion target, which may + // depend on the final type converter. In structural converters, the + // legality of types dictates the dynamic legality of an operation. 
+ descriptor.updateConversionTarget(*converter, conversionTarget); + descriptor.populatePatterns(*converter, patterns); } } diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir --- a/mlir/test/Dialect/SCF/transform-ops.mlir +++ b/mlir/test/Dialect/SCF/transform-ops.mlir @@ -280,3 +280,27 @@ %0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op transform.loop.promote_if_one_iteration %0 : !transform.any_op } + + +// ----- + +func.func @test_structural_conversion_patterns(%a: tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %0 = scf.for %j = %c0 to %c10 step %c1 iter_args(%arg0 = %a) -> tensor { + %1 = "test.foo"(%arg0) : (tensor) -> (tensor) + scf.yield %1 : tensor + } + return %0 : tensor +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_conversion_patterns to %0 { + transform.apply_conversion_patterns.scf.structural_conversions + } with type_converter { + transform.apply_conversion_patterns.transform.test_type_converter + } { partial_conversion } : !transform.any_op +} \ No newline at end of file diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -954,17 +954,20 @@ class TestTypeConverter : public TypeConverter { public: TestTypeConverter() { + addConversion([](Type t) { return t; }); addConversion([](RankedTensorType type) -> Type { return MemRefType::get(type.getShape(), type.getElementType()); }); - addSourceMaterialization([&](OpBuilder &builder, Type resultType, - ValueRange inputs, - Location loc) -> 
std::optional { + auto unrealizedCastConverter = [&](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> std::optional { if (inputs.size() != 1) return std::nullopt; return builder.create(loc, resultType, inputs) .getResult(0); - }); + }; + addSourceMaterialization(unrealizedCastConverter); + addTargetMaterialization(unrealizedCastConverter); } }; } // namespace