diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -234,6 +234,25 @@
   }];
 }
 
+def ApplyToLLVMConversionPatternsOp : Op<Transform_Dialect,
+    "apply_conversion_patterns.dialect_to_llvm",
+    [DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
+                               ["verifyTypeConverter"]>]> {
+  let description = [{
+    Collects patterns that convert ops from the specified dialect to LLVM
+    dialect ops. These patterns require an "LLVMTypeConverter".
+
+    Note: Only dialects that implement the `ConvertToLLVMPatternInterface` are
+    supported. Any conversion target modifications by interface implementations
+    are currently ignored. The conversion target is fully specified by the
+    enclosing "apply_conversion_patterns" op.
+  }];
+
+  let arguments = (ins StrAttr:$dialect_name);
+  let assemblyFormat = "$dialect_name attr-dict";
+  let hasVerifier = 1;
+}
+
 def ApplyDeadCodeEliminationOp : TransformDialectOp<"apply_dce",
     [TransformOpInterface, TransformEachOpTrait,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
@@ -14,6 +14,8 @@
   LINK_LIBS PUBLIC
   MLIRCastInterfaces
   MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
   MLIRLoopLikeInterface
   MLIRParser
   MLIRPass
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -8,6 +8,8 @@
 
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
 
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
@@ -642,6 +644,50 @@
   }
 }
 
+//===----------------------------------------------------------------------===//
+// ApplyToLLVMConversionPatternsOp
+//===----------------------------------------------------------------------===//
+
+/// Adds the to-LLVM conversion patterns of the dialect named by the
+/// `dialect_name` attribute to `patterns`. The unchecked casts below rely on
+/// the checks performed by `verify()` and `verifyTypeConverter()`.
+void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns) {
+  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
+  assert(dialect && "expected that dialect is loaded");
+  auto iface = cast<ConvertToLLVMPatternInterface>(dialect);
+  // ConversionTarget is currently ignored because the enclosing
+  // apply_conversion_patterns op sets up its own ConversionTarget.
+  ConversionTarget target(*getContext());
+  iface->populateConvertToLLVMConversionPatterns(
+      target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
+}
+
+/// Emits an error unless `builder` reports "LLVMTypeConverter" as its type
+/// converter type.
+LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
+    transform::TypeConverterBuilderOpInterface builder) {
+  if (builder.getTypeConverterType() != "LLVMTypeConverter")
+    return emitOpError("expected LLVMTypeConverter");
+  return success();
+}
+
+/// Emits an error unless the dialect named by `dialect_name` is loaded in the
+/// context and implements `ConvertToLLVMPatternInterface`.
+LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify() {
+  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
+  if (!dialect)
+    return emitOpError("unknown dialect or dialect not loaded: ")
+           << getDialectName();
+  auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+  if (!iface)
+    return emitOpError(
+               "dialect does not implement ConvertToLLVMPatternInterface or "
+               "extension was not loaded: ")
+           << getDialectName();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ApplyLoopInvariantCodeMotionOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir
--- a/mlir/test/Dialect/MemRef/transform-ops.mlir
+++ b/mlir/test/Dialect/MemRef/transform-ops.mlir
@@ -256,3 +256,24 @@
   // Verify that the returned handle is usable.
   transform.test_print_remark_at_operand %1, "transformed" : !transform.any_op
 }
+
+// -----
+
+// Verifies that "dialect_to_llvm" with the "memref" dialect lowers memref.alloc.
+// CHECK-LABEL: func @lower_to_llvm
+//   CHECK-NOT:   memref.alloc
+//       CHECK:   llvm.call @malloc
+func.func @lower_to_llvm() {
+  %0 = memref.alloc() : memref<2048xi8>
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_conversion_patterns to %0 {
+    transform.apply_conversion_patterns.dialect_to_llvm "memref"
+  } with type_converter {
+    transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
+  } {legal_dialects = ["func", "llvm"]} : !transform.any_op
+}
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -279,3 +279,51 @@
     transform.apply_conversion_patterns.transform.test_conversion_patterns
   } {illegal_ops = ["test.foo"]} : !transform.any_op
 }
+
+// -----
+
+// Negative test: the type converter is not an LLVMTypeConverter.
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_conversion_patterns to %0 {
+    // expected-error @below{{expected LLVMTypeConverter}}
+    transform.apply_conversion_patterns.dialect_to_llvm "test"
+  } with type_converter {
+    transform.apply_conversion_patterns.transform.test_type_converter
+  } {illegal_ops = ["test.foo"],
+     legal_ops = ["func.func", "func.return", "test.new_op"]}
+      : !transform.any_op
+}
+
+// -----
+
+// Negative test: the named dialect is unknown or not loaded.
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_conversion_patterns to %0 {
+    // expected-error @below{{unknown dialect or dialect not loaded: this_dialect_does_not_exist}}
+    transform.apply_conversion_patterns.dialect_to_llvm "this_dialect_does_not_exist"
+  } with type_converter {
+    transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
+  } {illegal_ops = ["test.foo"],
+     legal_ops = ["func.func", "func.return", "test.new_op"]}
+      : !transform.any_op
+}
+
+// -----
+
+// Negative test: the dialect provides no ConvertToLLVMPatternInterface.
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_conversion_patterns to %0 {
+    // expected-error @below{{dialect does not implement ConvertToLLVMPatternInterface or extension was not loaded: transform}}
+    transform.apply_conversion_patterns.dialect_to_llvm "transform"
+  } with type_converter {
+    transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
+  } {illegal_ops = ["test.foo"],
+     legal_ops = ["func.func", "func.return", "test.new_op"]}
+      : !transform.any_op
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -10594,6 +10594,8 @@
         ":ControlFlowInterfaces",
         ":IR",
+        ":LLVMCommonConversion",
+        ":LLVMDialect",
         ":LoopLikeInterface",
         ":Pass",
         ":Rewrite",
         ":SideEffectInterfaces",