// Patch overview (commentary only; this text sits before the first `diff --git` header).
//
// Adds an `avx512.vp2intersect` operation and its lowering. Files touched:
//   * mlir/include/mlir/Dialect/AVX512/AVX512.td
//       New `Vp2IntersectOp`: operands `a` and `b` are constrained by
//       VectorOfLengthAndType<[16, 8], [I32, I64]>; results `k1` and `k2` are
//       I16 or I8. Both result widths are derived from the type of `a`, because
//       the assembly format only spells `type($a)` (see the in-patch comment).
//   * mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h
//       Adds an include of "mlir/IR/OpImplementation.h".
//   * mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td
//       `LLVMAVX512_IntrOp` now forwards a `numResults` value to
//       `LLVM_IntrOpBase`; the four existing rndscale/scalef intrinsics pass 1,
//       and the two new `vp2intersect.{d,q}.512` intrinsics pass 2.
//   * mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
//       New `Vp2IntersectOp512Conversion`: picks the `.d.512` intrinsic when the
//       element type of `a` is a 32-bit integer, `.q.512` when it is 64-bit, and
//       returns failure() otherwise. The pattern is added to the
//       `patterns.insert` list.
//   * mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir,
//     mlir/test/Dialect/AVX512/roundtrip.mlir,
//     mlir/test/Target/avx512.mlir
//       New CHECK-based tests for the op, its conversion and its translation.
//
// NOTE(review): in this copy the patch's line breaks are collapsed into spaces
// and some angle-bracket argument lists look truncated (e.g. `cast()`,
// `ConvertOpToLLVMPattern {`, `ArrayRef operands`, `!llvm.struct`,
// `LLVM_IntrOpBase;`). Presumably an extraction artifact; restore from the
// original patch before applying. TODO confirm.
diff --git a/mlir/include/mlir/Dialect/AVX512/AVX512.td b/mlir/include/mlir/Dialect/AVX512/AVX512.td --- a/mlir/include/mlir/Dialect/AVX512/AVX512.td +++ b/mlir/include/mlir/Dialect/AVX512/AVX512.td @@ -96,4 +96,41 @@ "$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)"; } +def Vp2IntersectOp : AVX512_Op<"vp2intersect", [NoSideEffect, + AllTypesMatch<["a", "b"]>, + TypesMatchWith<"k1 has the same number of bits as elements in a", + "a", "k1", + "IntegerType::get($_self.getContext(), " + "($_self.cast().getShape()[0]))">, + TypesMatchWith<"k2 has the same number of bits as elements in b", + // Should use `b` instead of `a`, but that would require + // adding `type($b)` to assemblyFormat. + "a", "k2", + "IntegerType::get($_self.getContext(), " + "($_self.cast().getShape()[0]))">]> { + let summary = "Vp2Intersect op"; + let description = [{ + The `vp2intersect` op is an AVX512 specific op that can lower to the proper + LLVMAVX512 operation: `llvm.vp2intersect.d.512` or + `llvm.vp2intersect.q.512` depending on the type of MLIR vectors it is + applied to. + + #### From the Intel Intrinsics Guide: + + Compute intersection of packed integer vectors `a` and `b`, and store + indication of match in the corresponding bit of two mask registers + specified by `k1` and `k2`. A match in corresponding elements of `a` and + `b` is indicated by a set bit in the corresponding bit of the mask + registers. 
+ }]; + let arguments = (ins VectorOfLengthAndType<[16, 8], [I32, I64]>:$a, + VectorOfLengthAndType<[16, 8], [I32, I64]>:$b + ); + let results = (outs AnyTypeOf<[I16, I8]>:$k1, + AnyTypeOf<[I16, I8]>:$k2 + ); + let assemblyFormat = + "$a `,` $b attr-dict `:` type($a)"; +} + #endif // AVX512_OPS diff --git a/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h b/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h --- a/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h +++ b/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h @@ -16,6 +16,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Dialect/AVX512/AVX512Dialect.h.inc" diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td @@ -28,25 +28,33 @@ // MLIR LLVM AVX512 intrinsics using the MLIR LLVM Dialect type system //----------------------------------------------------------------------------// -class LLVMAVX512_IntrOp traits = []> : +class LLVMAVX512_IntrOp traits = []> : LLVM_IntrOpBase; + [], [], traits, numResults>; def LLVM_x86_avx512_mask_rndscale_ps_512 : - LLVMAVX512_IntrOp<"mask.rndscale.ps.512">, + LLVMAVX512_IntrOp<"mask.rndscale.ps.512", 1>, Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; def LLVM_x86_avx512_mask_rndscale_pd_512 : - LLVMAVX512_IntrOp<"mask.rndscale.pd.512">, + LLVMAVX512_IntrOp<"mask.rndscale.pd.512", 1>, Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; def LLVM_x86_avx512_mask_scalef_ps_512 : - LLVMAVX512_IntrOp<"mask.scalef.ps.512">, + LLVMAVX512_IntrOp<"mask.scalef.ps.512", 1>, Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; def LLVM_x86_avx512_mask_scalef_pd_512 : - LLVMAVX512_IntrOp<"mask.scalef.pd.512">, + 
LLVMAVX512_IntrOp<"mask.scalef.pd.512", 1>, Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; +def LLVM_x86_avx512_vp2intersect_d_512 : + LLVMAVX512_IntrOp<"vp2intersect.d.512", 2>, + Arguments<(ins LLVM_Type, LLVM_Type)>; + +def LLVM_x86_avx512_vp2intersect_q_512 : + LLVMAVX512_IntrOp<"vp2intersect.q.512", 2>, + Arguments<(ins LLVM_Type, LLVM_Type)>; + #endif // AVX512_OPS diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp --- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp +++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp @@ -77,6 +77,29 @@ return failure(); } }; + +struct Vp2IntersectOp512Conversion + : public ConvertOpToLLVMPattern { + explicit Vp2IntersectOp512Conversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertOpToLLVMPattern(typeConverter) {} + + LogicalResult + matchAndRewrite(Vp2IntersectOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Type elementType = + op.a().getType().template cast().getElementType(); + if (elementType.isInteger(32)) + return LLVM::detail::oneToOneRewrite( + op, LLVM::x86_avx512_vp2intersect_d_512::getOperationName(), operands, + *getTypeConverter(), rewriter); + if (elementType.isInteger(64)) + return LLVM::detail::oneToOneRewrite( + op, LLVM::x86_avx512_vp2intersect_q_512::getOperationName(), operands, + *getTypeConverter(), rewriter); + return failure(); + } +}; } // namespace /// Populate the given list with patterns that convert from AVX512 to LLVM. 
@@ -84,6 +107,8 @@ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // clang-format off patterns.insert(&converter.getContext(), converter); + ScaleFOp512Conversion, + Vp2IntersectOp512Conversion>(&converter.getContext(), + converter); // clang-format on } diff --git a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir @@ -16,3 +16,13 @@ // Keep results alive. return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64> } + +func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>) + -> (i16, i16, i8, i8) +{ + // CHECK: llvm_avx512.vp2intersect.d.512 + %0, %1 = avx512.vp2intersect %a, %a : vector<16xi32> + // CHECK: llvm_avx512.vp2intersect.q.512 + %2, %3 = avx512.vp2intersect %b, %b : vector<8xi64> + return %0, %1, %2, %3 : i16, i16, i8, i8 +} diff --git a/mlir/test/Dialect/AVX512/roundtrip.mlir b/mlir/test/Dialect/AVX512/roundtrip.mlir --- a/mlir/test/Dialect/AVX512/roundtrip.mlir +++ b/mlir/test/Dialect/AVX512/roundtrip.mlir @@ -19,3 +19,13 @@ %1 = avx512.mask.scalef %b, %b, %b, %i8, %i32 : vector<8xf64> return %0, %1: vector<16xf32>, vector<8xf64> } + +func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>) + -> (i16, i16, i8, i8) +{ + // CHECK: avx512.vp2intersect {{.*}} : vector<16xi32> + %0, %1 = avx512.vp2intersect %a, %a : vector<16xi32> + // CHECK: avx512.vp2intersect {{.*}} : vector<8xi64> + %2, %3 = avx512.vp2intersect %b, %b : vector<8xi64> + return %0, %1, %2, %3 : i16, i16, i8, i8 +} diff --git a/mlir/test/Target/avx512.mlir b/mlir/test/Target/avx512.mlir --- a/mlir/test/Target/avx512.mlir +++ b/mlir/test/Target/avx512.mlir @@ -29,3 +29,23 @@ (vector<8xf64>, vector<8xf64>, vector<8xf64>, i8, i32) -> vector<8xf64> llvm.return %1: vector<8xf64> } + +// CHECK-LABEL: define <{ i16, i16 }> @LLVM_x86_vp2intersect_d_512 
+llvm.func @LLVM_x86_vp2intersect_d_512(%a: vector<16xi32>, %b: vector<16xi32>) + -> !llvm.struct +{ + // CHECK: call { <16 x i1>, <16 x i1> } @llvm.x86.avx512.vp2intersect.d.512(<16 x i32> + %0 = "llvm_avx512.vp2intersect.d.512"(%a, %b) : + (vector<16xi32>, vector<16xi32>) -> !llvm.struct + llvm.return %0 : !llvm.struct +} + +// CHECK-LABEL: define <{ i8, i8 }> @LLVM_x86_vp2intersect_q_512 +llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>) + -> !llvm.struct +{ + // CHECK: call { <8 x i1>, <8 x i1> } @llvm.x86.avx512.vp2intersect.q.512(<8 x i64> + %0 = "llvm_avx512.vp2intersect.q.512"(%a, %b) : + (vector<8xi64>, vector<8xi64>) -> !llvm.struct + llvm.return %0 : !llvm.struct +}