diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -15,6 +15,27 @@
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def ApplyVectorToLLVMConversionPatternsOp : Op<Transform_Dialect,
+    "apply_conversion_patterns.vector.vector_to_llvm",
+    [DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
+                               ["verifyTypeConverter"]>]> {
+  let description = [{
+    Collects patterns that convert vector dialect ops to LLVM dialect ops. These
+    patterns require an "LLVMTypeConverter".
+
+    The patterns can be customized as follows:
+    - `reassociate_fp_reductions` (default: false): Allows LLVM to reassociate
+      floating-point reductions for speed.
+    - `force_32bit_vector_indices` (default: true): Allows the compiler to
+      assume that vector indices fit in 32 bits if that yields faster code.
+  }];
+
+  let arguments = (ins
+      DefaultValuedAttr<BoolAttr, "false">:$reassociate_fp_reductions,
+      DefaultValuedAttr<BoolAttr, "true">:$force_32bit_vector_indices);
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyCastAwayVectorLeadingOneDimPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.cast_away_vector_leading_one_dim",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt
--- a/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/TransformOps/CMakeLists.txt
@@ -9,7 +9,10 @@
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
   MLIRVectorDialect
+  MLIRVectorToLLVM
   MLIRVectorTransforms
   MLIRSideEffectInterfaces
   MLIRTransformDialect
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -7,6 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
@@ -23,6 +26,27 @@
 using namespace mlir::vector;
 using namespace mlir::transform;
 
+//===----------------------------------------------------------------------===//
+// Apply...ConversionPatternsOp
+//===----------------------------------------------------------------------===//
+
+void transform::ApplyVectorToLLVMConversionPatternsOp::populatePatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns) {
+  // The static_cast is only sound for an actual LLVMTypeConverter;
+  // verifyTypeConverter() rejects builders reporting any other type.
+  populateVectorToLLVMConversionPatterns(
+      static_cast<LLVMTypeConverter &>(typeConverter), patterns,
+      getReassociateFpReductions(), getForce_32bitVectorIndices());
+}
+
+LogicalResult
+transform::ApplyVectorToLLVMConversionPatternsOp::verifyTypeConverter(
+    transform::TypeConverterBuilderOpInterface builder) {
+  if (builder.getTypeConverterType() != "LLVMTypeConverter")
+    return emitOpError("expected LLVMTypeConverter");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Apply...PatternsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/transform-op-vector-to-llvm.mlir b/mlir/test/Dialect/Vector/transform-op-vector-to-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/transform-op-vector-to-llvm.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @lower_to_llvm
+//   CHECK-NOT:   vector.bitcast
+//       CHECK:   llvm.bitcast
+func.func @lower_to_llvm(%input: vector<f32>) -> vector<i32> {
+  %0 = vector.bitcast %input : vector<f32> to vector<i32>
+  return %0 : vector<i32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_conversion_patterns to %0 {
+    transform.apply_conversion_patterns.vector.vector_to_llvm
+  } with type_converter {
+    transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
+  } {legal_dialects = ["func", "llvm"]} : !transform.any_op
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4088,12 +4088,14 @@
         ":ArithDialect",
         ":AsmParser",
         ":IR",
+        ":LLVMCommonConversion",
         ":LLVMDialect",
         ":SideEffectInterfaces",
         ":TransformDialect",
         ":TransformUtils",
         ":VectorDialect",
         ":VectorEnumsIncGen",
+        ":VectorToLLVM",
         ":VectorToSCF",
         ":VectorTransformOpsIncGen",
         ":VectorTransforms",