diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -15,6 +15,45 @@
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def ApplyMemrefToLLVMConversionPatternsOp : Op<Transform_Dialect,
+    "apply_conversion_patterns.memref.memref_to_llvm",
+    [DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
+                               ["verifyTypeConverter"]>]> {
+  let description = [{
+    Collects patterns that convert memref dialect ops to LLVM dialect ops. These
+    patterns require an "LLVMTypeConverter".
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+def MemrefToLLVMTypeConverterOp : Op<Transform_Dialect,
+    "apply_conversion_patterns.memref.memref_to_llvm_type_converter",
+    [DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
+                               ["getTypeConverterType"]>]> {
+  let description = [{
+    This operation provides an "LLVMTypeConverter" that lowers memref types to
+    LLVM types.
+
+    The type converter can be customized as follows:
+    - `use_aligned_alloc`: Use aligned_alloc in place of malloc for heap
+      allocations.
+    - `index_bitwidth`: Bitwidth of the index type, "0" indicates the size of a
+      machine word.
+    - `use_generic_functions`: Use generic allocation and deallocation functions
+      instead of the classic "malloc", "aligned_alloc" and "free" functions.
+    - `use_opaque_pointers`: Generate LLVM IR using opaque pointers instead of
+      typed pointers.
+  }];
+
+  let arguments = (ins
+      DefaultValuedAttr<BoolAttr, "false">:$use_aligned_alloc,
+      DefaultValuedAttr<I64Attr, "64">:$index_bitwidth,
+      DefaultValuedAttr<BoolAttr, "false">:$use_generic_functions,
+      DefaultValuedAttr<BoolAttr, "false">:$use_opaque_pointers);
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyExpandOpsPatternsOp : Op<Transform_Dialect,
     "apply_patterns.memref.expand_ops",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -263,6 +263,39 @@
   ];
 }
 
+def TypeConverterBuilderOpInterface
+    : OpInterface<"TypeConverterBuilderOpInterface"> {
+  let description = [{
+    This interface should be implemented by ops that specify a type converter
+    for a dialect conversion. Such ops can be used with
+    "apply_conversion_patterns".
+  }];
+
+  let cppNamespace = "::mlir::transform";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the type converter to be used with a dialect conversion.
+      }],
+      /*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
+      /*name=*/"getTypeConverter",
+      /*arguments=*/(ins)
+    >,
+    StaticInterfaceMethod<
+      /*desc=*/[{
+        Return the type of type converter that this `getTypeConverter` returns.
+        This function is used for op verification.
+      }],
+      /*returnType=*/"StringRef",
+      /*name=*/"getTypeConverterType",
+      /*arguments=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return "TypeConverter"; }]
+    >,
+  ];
+}
+
 def ConversionPatternDescriptorOpInterface
     : OpInterface<"ConversionPatternDescriptorOpInterface"> {
   let description = [{
@@ -300,27 +333,16 @@
       /*methodBody=*/"",
       /*defaultImplementation=*/"return nullptr;"
     >,
-  ];
-}
-
-def TypeConverterBuilderOpInterface
-    : OpInterface<"TypeConverterBuilderOpInterface"> {
-  let description = [{
-    This interface should be implemented by ops that specify a type converter
-    for a dialect conversion. Such ops can be used with
-    "apply_conversion_patterns".
-  }];
-
-  let cppNamespace = "::mlir::transform";
-
-  let methods = [
     InterfaceMethod<
       /*desc=*/[{
-        Return the type converter to be used with a dialect conversion.
+        Verify the default type converter that is provided by the enclosing
+        "apply_conversion_patterns" op.
       }],
-      /*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
-      /*name=*/"getTypeConverter",
-      /*arguments=*/(ins)
+      /*returnType=*/"::mlir::LogicalResult",
+      /*name=*/"verifyTypeConverter",
+      /*arguments=*/(ins "TypeConverterBuilderOpInterface":$builder),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/"return success();"
     >,
   ];
 }
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -37,6 +37,8 @@
 /// registered using addConversion and addMaterialization, respectively.
 class TypeConverter {
 public:
+  virtual ~TypeConverter() = default;
+
   /// This class provides all of the information necessary to convert a type
   /// signature.
   class SignatureConversion {
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt
@@ -11,8 +11,11 @@
   MLIRAffineDialect
   MLIRArithDialect
   MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
   MLIRLoopLikeInterface
   MLIRMemRefDialect
+  MLIRMemRefToLLVM
   MLIRMemRefTransforms
   MLIRNVGPUDialect
   MLIRTransformDialect
diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
--- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
+++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp
@@ -7,8 +7,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
@@ -26,6 +30,43 @@
 #define DEBUG_TYPE "memref-transforms"
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 
+//===----------------------------------------------------------------------===//
+// Apply...ConversionPatternsOp
+//===----------------------------------------------------------------------===//
+
+void transform::ApplyMemrefToLLVMConversionPatternsOp::populatePatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns) {
+  populateFinalizeMemRefToLLVMConversionPatterns(
+      static_cast<LLVMTypeConverter &>(typeConverter), patterns);
+}
+
+LogicalResult
+transform::ApplyMemrefToLLVMConversionPatternsOp::verifyTypeConverter(
+    transform::TypeConverterBuilderOpInterface builder) {
+  if (builder.getTypeConverterType() != "LLVMTypeConverter")
+    return emitOpError("expected LLVMTypeConverter");
+  return success();
+}
+
+std::unique_ptr<TypeConverter>
+transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
+  LowerToLLVMOptions options(getContext());
+  options.allocLowering =
+      (getUseAlignedAlloc() ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
+                            : LowerToLLVMOptions::AllocLowering::Malloc);
+  options.useGenericFunctions = getUseGenericFunctions();
+  options.useOpaquePointers = getUseOpaquePointers();
+
+  if (getIndexBitwidth() != kDeriveIndexBitwidthFromDataLayout)
+    options.overrideIndexBitwidth(getIndexBitwidth());
+
+  return std::make_unique<LLVMTypeConverter>(getContext(), options);
+}
+
+StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
+  return "LLVMTypeConverter";
+}
+
 //===----------------------------------------------------------------------===//
 // Apply...PatternsOp
 //===----------------------------------------------------------------------===//
@@ -157,6 +198,7 @@
   void init() {
     declareGeneratedDialect<affine::AffineDialect>();
     declareGeneratedDialect<arith::ArithDialect>();
+    declareGeneratedDialect<LLVM::LLVMDialect>();
     declareGeneratedDialect<memref::MemRefDialect>();
     declareGeneratedDialect<nvgpu::NVGPUDialect>();
     declareGeneratedDialect<vector::VectorDialect>();
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -589,14 +589,27 @@
     if (!llvm::hasSingleElement(typeConverterRegion.front()))
       return emitOpError()
              << "expected exactly one op in default type converter region";
     Operation *typeConverterOp = &typeConverterRegion.front().front();
-    if (!isa<transform::TypeConverterBuilderOpInterface>(typeConverterOp)) {
+    // Keep the raw Operation* around: the diagnostic below must not
+    // dereference a null interface handle when the cast fails.
+    auto typeConverterBuilder =
+        dyn_cast<transform::TypeConverterBuilderOpInterface>(typeConverterOp);
+    if (!typeConverterBuilder) {
       InFlightDiagnostic diag = emitOpError()
                                 << "expected default converter child op to "
                                    "implement TypeConverterBuilderOpInterface";
       diag.attachNote(typeConverterOp->getLoc()) << "op without interface";
       return diag;
     }
+    // Check default type converter type.
+    if (!getPatterns().empty()) {
+      for (Operation &op : getPatterns().front()) {
+        auto descriptor =
+            cast<transform::ConversionPatternDescriptorOpInterface>(&op);
+        if (failed(descriptor.verifyTypeConverter(typeConverterBuilder)))
+          return failure();
+      }
+    }
   }
   if (!getLegalOps() && !getIllegalOps() && !getLegalDialects() &&
       !getIllegalDialects())
diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir
--- a/mlir/test/Dialect/MemRef/transform-ops.mlir
+++ b/mlir/test/Dialect/MemRef/transform-ops.mlir
@@ -256,3 +256,23 @@
   // Verify that the returned handle is usable.
   transform.test_print_remark_at_operand %1, "transformed" : !transform.any_op
 }
+
+// -----
+
+// CHECK-LABEL: func @lower_to_llvm
+// CHECK-NOT: memref.alloc
+// CHECK: llvm.call @malloc
+func.func @lower_to_llvm() {
+  %0 = memref.alloc() : memref<2048xi8>
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_conversion_patterns to %0 {
+    transform.apply_conversion_patterns.memref.memref_to_llvm
+  }, {
+    transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
+  } {legal_dialects = ["func", "llvm"]} : !transform.any_op
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -11355,8 +11355,11 @@
         ":AffineDialect",
         ":ArithDialect",
         ":IR",
+        ":LLVMCommonConversion",
+        ":LLVMDialect",
         ":LoopLikeInterface",
         ":MemRefDialect",
+        ":MemRefToLLVM",
         ":MemRefTransformOpsIncGen",
         ":MemRefTransforms",
         ":NVGPUDialect",