Patch note: adds an x86vector AVX rsqrt op (RsqrtOp) and its intrinsic op (RsqrtIntrOp) on 8 x f32 vectors, a RsqrtOpConversion pattern plus legality setup in LegalizeForLLVMExport.cpp, and lit tests (legalize, roundtrip, integration, LLVM IR translation).
Review note: the integration test passes --mattr="avx" to lli because the only target-specific call this patch emits is llvm.x86.avx.rsqrt.ps.256; no AVX-512 feature is referenced by the new op.
Review note: template argument lists (e.g. on AVX_Op, ConvertOpToLLVMPattern, addLegalOp) look stripped in this copy of the patch - TODO confirm against the original before applying.
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -267,4 +267,32 @@ VectorOfLengthAndType<[8], [I64]>:$b); } +//===----------------------------------------------------------------------===// +// AVX op definitions +//===----------------------------------------------------------------------===// + +class AVX_Op traits = []> : + Op {} + +class AVX_IntrOp traits = []> : + LLVM_IntrOpBase; + +//----------------------------------------------------------------------------// +// AVX Rsqrt +//----------------------------------------------------------------------------// + +def RsqrtOp : AVX_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Rsqrt"; + let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a); + let results = (outs VectorOfLengthAndType<[8], [F32]>:$b); + let assemblyFormat = "$a attr-dict `:` type($a)"; +} + +def RsqrtIntrOp : AVX_IntrOp<"rsqrt.ps.256", 1, [NoSideEffect, + SameOperandsAndResultType]> { + let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a); +} + #endif // X86VECTOR_OPS diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp @@ -90,6 +90,20 @@ } }; +struct RsqrtOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(RsqrtOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + RsqrtOp::Adaptor adaptor(operands); + + auto opType = adaptor.a().getType(); + rewriter.replaceOpWithNewOp(op, opType, adaptor.a()); + return success(); + } +}; + /// An entry associating the "main" AVX512 op with its instantiations for 
/// vectors of 32-bit and 64-bit elements. template @@ -131,7 +145,7 @@ void mlir::populateX86VectorLegalizeForLLVMExportPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { Registry::registerPatterns(converter, patterns); - patterns.add(converter); + patterns.add(converter); } void mlir::configureX86VectorLegalizeForExportTarget( @@ -139,4 +153,6 @@ Registry::configureTarget(target); target.addLegalOp(); target.addIllegalOp(); + target.addLegalOp(); + target.addIllegalOp(); } diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir --- a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir @@ -42,3 +42,11 @@ %2, %3 = x86vector.avx512.vp2intersect %b, %b : vector<8xi64> return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1> } + +// CHECK-LABEL: func @avx_rsqrt +func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>) +{ + // CHECK: x86vector.avx.intr.rsqrt.ps.256 + %0 = x86vector.avx.rsqrt %a : vector<8xf32> + return %0 : vector<8xf32> +} diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir --- a/mlir/test/Dialect/X86Vector/roundtrip.mlir +++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir @@ -46,3 +46,11 @@ %2 = x86vector.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64> return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64> } + +// CHECK-LABEL: func @avx_rsqrt +func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>) +{ + // CHECK: x86vector.avx.rsqrt {{.*}} : vector<8xf32> + %0 = x86vector.avx.rsqrt %a : vector<8xf32> + return %0 : vector<8xf32> +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-rsqrt.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-rsqrt.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-rsqrt.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt %s 
-convert-scf-to-std -convert-vector-to-llvm="enable-x86vector" -convert-std-to-llvm | \ +// RUN: mlir-translate --mlir-to-llvmir | \ +// RUN: %lli --jit-kind=mcjit --entry-function=entry --mattr="avx" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s +// TODO: drop lli's --jit-kind flag once PR#49906 (https://bugs.llvm.org/show_bug.cgi?id=49906) is fixed. + +func @entry() -> i32 { + %i0 = constant 0 : i32 + + %v = std.constant dense<[0.125, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0]> : vector<8xf32> + %r = x86vector.avx.rsqrt %v : vector<8xf32> + // CHECK: ( 2.82764, 1.99951, 1.41382, 0.999756, 0.706909, 0.499878, 0.353455, 0.249939 ) + vector.print %r : vector<8xf32> + + return %i0 : i32 +} diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir --- a/mlir/test/Target/LLVMIR/x86vector.mlir +++ b/mlir/test/Target/LLVMIR/x86vector.mlir @@ -59,3 +59,11 @@ (vector<8xi64>, vector<8xi64>) -> !llvm.struct<(vector<8 x i1>, vector<8 x i1>)> llvm.return %0 : !llvm.struct<(vector<8 x i1>, vector<8 x i1>)> } + +// CHECK-LABEL: define <8 x float> @LLVM_x86_avx_rsqrt_ps_256 +llvm.func @LLVM_x86_avx_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32> +{ + // CHECK: call <8 x float> @llvm.x86.avx.rsqrt.ps.256(<8 x float> + %0 = "x86vector.avx.intr.rsqrt.ps.256"(%a) : (vector<8xf32>) -> (vector<8xf32>) + llvm.return %0 : vector<8xf32> +}