diff --git a/mlir/include/mlir/Dialect/AVX512/AVX512.td b/mlir/include/mlir/Dialect/AVX512/AVX512.td --- a/mlir/include/mlir/Dialect/AVX512/AVX512.td +++ b/mlir/include/mlir/Dialect/AVX512/AVX512.td @@ -37,6 +37,11 @@ "x86_avx512_" # !subst(".", "_", mnemonic), [], [], traits, numResults>; +class AVX_IntrOp traits = []> : + LLVM_IntrOpBase; + // Defined by first result overload. May have to be extended for other // instructions in the future. class AVX512_IntrOverloadedOp:$b); } +// Rsqrt. +def RsqrtOp : AVX512_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Rsqrt"; + let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a); + let results = (outs VectorOfLengthAndType<[8], [F32]>:$b); + let assemblyFormat = "$a attr-dict `:` type($a)"; +} + +def RsqrtIntrOp : AVX_IntrOp<"rsqrt.ps.256", 1, [NoSideEffect, + SameOperandsAndResultType]> { + let arguments = (ins VectorOfLengthAndType<[8], [F32]>:$a); +} + #endif // AVX512_OPS diff --git a/mlir/lib/Dialect/AVX512/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AVX512/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/AVX512/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/AVX512/Transforms/LegalizeForLLVMExport.cpp @@ -89,6 +89,20 @@ } }; +struct RsqrtOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(RsqrtOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + RsqrtOp::Adaptor adaptor(operands); + + auto opType = adaptor.a().getType(); + rewriter.replaceOpWithNewOp(op, opType, adaptor.a()); + return success(); + } +}; + /// An entry associating the "main" AVX512 op with its instantiations for /// vectors of 32-bit and 64-bit elements. template @@ -130,7 +144,7 @@ void mlir::populateAVX512LegalizeForLLVMExportPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { Registry::registerPatterns(converter, patterns); - patterns.add(converter); + patterns.add(converter); } void mlir::configureAVX512LegalizeForExportTarget( @@ -138,4 +152,6 @@ Registry::configureTarget(target); target.addLegalOp(); target.addIllegalOp(); + target.addLegalOp(); + target.addIllegalOp(); } diff --git a/mlir/test/Dialect/AVX512/legalize-for-llvm.mlir b/mlir/test/Dialect/AVX512/legalize-for-llvm.mlir --- a/mlir/test/Dialect/AVX512/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/AVX512/legalize-for-llvm.mlir @@ -39,3 +39,10 @@ %2, %3 = avx512.vp2intersect %b, %b : vector<8xi64> return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1> } + +func @axv512_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>) +{ + // CHECK: avx512.intr.rsqrt.ps.256 + %0 = avx512.rsqrt %a : vector<8xf32> + return %0 : vector<8xf32> +} diff --git a/mlir/test/Dialect/AVX512/roundtrip.mlir b/mlir/test/Dialect/AVX512/roundtrip.mlir --- a/mlir/test/Dialect/AVX512/roundtrip.mlir +++ b/mlir/test/Dialect/AVX512/roundtrip.mlir @@ -42,3 +42,10 @@ %2 = avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64> return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64> } + +func @avx512_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>) +{ + // CHECK: avx512.rsqrt {{.*}} : vector<8xf32> + %0 = avx512.rsqrt %a : vector<8xf32> + return %0 : vector<8xf32> +} \ No newline at end of file diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-rsqrt.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-rsqrt.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-rsqrt.mlir @@ -0,0 +1,34 @@ +// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm="enable-avx512" -convert-std-to-llvm | \ +// RUN: mlir-translate --mlir-to-llvmir | \ +// RUN: %lli --entry-function=entry --mattr="avx512bw" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +func @entry() -> i32 { + %i0 = constant 0 : i32 + + %f0 = constant 0.125 : f32 + %f1 = constant 0.25 : f32 + %f2 = constant 0.5 : f32 + %f3 = constant 1.0 : f32 + %f4 = constant 2.0: f32 + %f5 = constant 4.0: f32 + %f6 = constant 8.0: f32 + %f7 = constant 16.0: f32 + + %v0 = vector.broadcast %f0 : f32 to vector<8xf32> + %v1 = vector.insert %f1, %v0[1] : f32 into vector<8xf32> + %v2 = vector.insert %f2, %v1[2] : f32 into vector<8xf32> + %v3 = vector.insert %f3, %v2[3] : f32 into vector<8xf32> + %v4 = vector.insert %f4, %v3[4] : f32 into vector<8xf32> + %v5 = vector.insert %f5, %v4[5] : f32 into vector<8xf32> + %v6 = vector.insert %f6, %v5[6] : f32 into vector<8xf32> + %v7 = vector.insert %f7, %v6[7] : f32 into vector<8xf32> + vector.print %v7 : vector<8xf32> + // CHECK: ( 0.125, 0.25, 0.5, 1, 2, 4, 8, 16 ) + + %r = avx512.rsqrt %v7 : vector<8xf32> + vector.print %r : vector<8xf32> + // CHECK: ( 2.82764, 1.99951, 1.41382, 0.999756, 0.706909, 0.499878, 0.353455, 0.249939 ) + + return %i0 : i32 +} diff --git a/mlir/test/Target/LLVMIR/avx512.mlir b/mlir/test/Target/LLVMIR/avx512.mlir --- a/mlir/test/Target/LLVMIR/avx512.mlir +++ b/mlir/test/Target/LLVMIR/avx512.mlir @@ -59,3 +59,11 @@ (vector<8xi64>, vector<8xi64>) -> !llvm.struct<(vector<8 x i1>, vector<8 x i1>)> llvm.return %0 : !llvm.struct<(vector<8 x i1>, vector<8 x i1>)> } + +// CHECK-LABEL: define <8 x float> @LLVM_x86_avx512_rsqrt_ps_256 +llvm.func @LLVM_x86_avx512_rsqrt_ps_256(%a: vector <8xf32>) -> vector<8xf32> +{ + // CHECK: call <8 x float> @llvm.x86.avx.rsqrt.ps.256(<8 x float> + %0 = "avx512.intr.rsqrt.ps.256"(%a) : (vector<8xf32>) -> (vector<8xf32>) + llvm.return %0 : vector<8xf32> +}