diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -15,6 +15,33 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" +def MemrefToLLVMTypeConverterOp : Op]> { + let description = [{ + This operation provides an "LLVMTypeConverter" that lowers memref types to + LLVM types. + + The type converter can be customized as follows: + - `use_aligned_alloc`: Use aligned_alloc in place of malloc for heap + allocations. + - `index_bitwidth`: Bitwidth of the index type, "0" indicates the size of a + machine word. + - `use_generic_functions`: Use generic allocation and deallocation functions + instead of the classic "malloc", "aligned_alloc" and "free" functions. + - `use_opaque_pointers`: Generate LLVM IR using opaque pointers instead of + typed pointers. + }]; + + let arguments = (ins + DefaultValuedAttr:$use_aligned_alloc, + DefaultValuedAttr:$index_bitwidth, + DefaultValuedAttr:$use_generic_functions, + DefaultValuedAttr:$use_opaque_pointers); + let assemblyFormat = "attr-dict"; +} + def ApplyExpandOpsPatternsOp : Op]> { diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -263,6 +263,39 @@ ]; } +def TypeConverterBuilderOpInterface + : OpInterface<"TypeConverterBuilderOpInterface"> { + let description = [{ + This interface should be implemented by ops that specify a type converter + for a dialect conversion. Such ops can be used with + "apply_conversion_patterns". 
+ }]; + + let cppNamespace = "::mlir::transform"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Return the type converter to be used with a dialect conversion. + }], + /*returnType=*/"std::unique_ptr<::mlir::TypeConverter>", + /*name=*/"getTypeConverter", + /*arguments=*/(ins) + >, + StaticInterfaceMethod< + /*desc=*/[{ + Return the type of type converter that this `getTypeConverter` returns. + This function is used for op verification. + }], + /*returnType=*/"StringRef", + /*name=*/"getTypeConverterType", + /*arguments=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ return "TypeConverter"; }] + >, + ]; +} + def ConversionPatternDescriptorOpInterface : OpInterface<"ConversionPatternDescriptorOpInterface"> { let description = [{ @@ -300,27 +333,16 @@ /*methodBody=*/"", /*defaultImplementation=*/"return nullptr;" >, - ]; -} - -def TypeConverterBuilderOpInterface - : OpInterface<"TypeConverterBuilderOpInterface"> { - let description = [{ - This interface should be implemented by ops that specify a type converter - for a dialect conversion. Such ops can be used with - "apply_conversion_patterns". - }]; - - let cppNamespace = "::mlir::transform"; - - let methods = [ InterfaceMethod< /*desc=*/[{ - Return the type converter to be used with a dialect conversion. + Verify the default type converter that is provided by the enclosing + "apply_conversion_patterns" op. 
}], - /*returnType=*/"std::unique_ptr<::mlir::TypeConverter>", - /*name=*/"getTypeConverter", - /*arguments=*/(ins) + /*returnType=*/"::mlir::LogicalResult", + /*name=*/"verifyTypeConverter", + /*arguments=*/(ins "TypeConverterBuilderOpInterface":$builder), + /*methodBody=*/"", + /*defaultImplementation=*/"return success();" >, ]; } diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -15,6 +15,27 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" +def ApplyVectorToLLVMConversionPatternsOp : Op]> { + let description = [{ + Collects patterns that convert vector dialect ops to LLVM dialect ops. These + patterns require an "LLVMTypeConverter". + + The patterns can be customized as follows: + - `reassociate_fp_reductions`: Allows LLVM to reassociate floating-point + reductions for speed. + - `force_32bit_vector_indices`: Allows the compiler to assume that vector + indices fit in 32-bit if that yields faster code. + }]; + + let arguments = (ins + DefaultValuedAttr:$reassociate_fp_reductions, + DefaultValuedAttr:$force_32bit_vector_indices); + let assemblyFormat = "attr-dict"; +} + def ApplyCastAwayVectorLeadingOneDimPatternsOp : Op]> { diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -37,6 +37,8 @@ /// registered using addConversion and addMaterialization, respectively. class TypeConverter { public: + virtual ~TypeConverter() = default; + /// This class provides all of the information necessary to convert a type /// signature. 
class SignatureConversion { diff --git a/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/TransformOps/CMakeLists.txt @@ -11,6 +11,8 @@ MLIRAffineDialect MLIRArithDialect MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect MLIRLoopLikeInterface MLIRMemRefDialect MLIRMemRefTransforms diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h" + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -26,6 +28,29 @@ #define DEBUG_TYPE "memref-transforms" #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +//===----------------------------------------------------------------------===// +// MemrefToLLVMTypeConverterOp +//===----------------------------------------------------------------------===// + +std::unique_ptr +transform::MemrefToLLVMTypeConverterOp::getTypeConverter() { + LowerToLLVMOptions options(getContext()); + options.allocLowering = + (getUseAlignedAlloc() ? 
LowerToLLVMOptions::AllocLowering::AlignedAlloc + : LowerToLLVMOptions::AllocLowering::Malloc); + options.useGenericFunctions = getUseGenericFunctions(); + options.useOpaquePointers = getUseOpaquePointers(); + + if (getIndexBitwidth() != kDeriveIndexBitwidthFromDataLayout) + options.overrideIndexBitwidth(getIndexBitwidth()); + + return std::make_unique(getContext(), options); +} + +StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() { + return "LLVMTypeConverter"; +} + //===----------------------------------------------------------------------===// // Apply...PatternsOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -589,14 +589,26 @@ if (!llvm::hasSingleElement(typeConverterRegion.front())) return emitOpError() << "expected exactly one op in default type converter region"; - Operation *typeConverterOp = &typeConverterRegion.front().front(); - if (!isa(typeConverterOp)) { + // Keep the raw child op: the dyn_cast below is null when it fails. + Operation *typeConverterChild = &typeConverterRegion.front().front(); + auto typeConverterOp = + dyn_cast<TypeConverterBuilderOpInterface>(typeConverterChild); + if (!typeConverterOp) { InFlightDiagnostic diag = emitOpError() << "expected default converter child op to " "implement TypeConverterBuilderOpInterface"; - diag.attachNote(typeConverterOp->getLoc()) << "op without interface"; + diag.attachNote(typeConverterChild->getLoc()) << "op without interface"; return diag; } + // Check default type converter type. 
+ if (!getPatterns().empty()) { + for (Operation &op : getPatterns().front()) { + auto descriptor = + cast(&op); + if (failed(descriptor.verifyTypeConverter(typeConverterOp))) + return failure(); + } + } } if (!getLegalOps() && !getIllegalOps() && !getLegalDialects() && !getIllegalDialects()) diff --git a/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt @@ -9,7 +9,10 @@ LINK_LIBS PUBLIC MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect MLIRVectorDialect + MLIRVectorToLLVM MLIRVectorTransforms MLIRSideEffectInterfaces MLIRTransformDialect diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -7,6 +7,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" @@ -23,6 +26,25 @@ using namespace mlir::vector; using namespace mlir::transform; +//===----------------------------------------------------------------------===// +// Apply...ConversionPatternsOp +//===----------------------------------------------------------------------===// + +void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns( + TypeConverter &typeConverter, RewritePatternSet &patterns) { + populateVectorToLLVMConversionPatterns( + static_cast(typeConverter), patterns, + getReassociateFpReductions(), getForce_32bitVectorIndices()); +} 
+LogicalResult +transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter( + transform::TypeConverterBuilderOpInterface builder) { + if (builder.getTypeConverterType() != "LLVMTypeConverter") + return emitOpError("expected LLVMTypeConverter"); + return success(); +} + //===----------------------------------------------------------------------===// // Apply...PatternsOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/transform-op-vector-to-llvm.mlir b/mlir/test/Dialect/Vector/transform-op-vector-to-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/transform-op-vector-to-llvm.mlir @@ -0,0 +1,19 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect -split-input-file | FileCheck %s + +// CHECK-LABEL: func @lower_to_llvm +// CHECK-NOT: vector.bitcast +// CHECK: llvm.bitcast +func.func @lower_to_llvm(%input: vector) -> vector { + %0 = vector.bitcast %input : vector to vector + return %0 : vector +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_conversion_patterns to %0 { + transform.apply_conversion_patterns.vector.vector_to_llvm + } with type_converter { + transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter + } {legal_dialects = ["func", "llvm"]} : !transform.any_op +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -4146,12 +4146,14 @@ ":ArithDialect", ":AsmParser", ":IR", + ":LLVMCommonConversion", ":LLVMDialect", ":SideEffectInterfaces", ":TransformDialect", ":TransformUtils", ":VectorDialect", ":VectorEnumsIncGen", + ":VectorToLLVM", ":VectorToSCF", 
":VectorTransformOpsIncGen", ":VectorTransforms", @@ -11510,6 +11512,7 @@ ":AffineDialect", ":ArithDialect", ":IR", + ":LLVMCommonConversion", ":LoopLikeInterface", ":MemRefDialect", ":MemRefTransformOpsIncGen",