[mlir][AVX512] Model vp2intersect masks as vector<Nxi1>

Change the results of avx512.vp2intersect from scalar i16/i8 bitmasks to
vector<16xi1>/vector<8xi1> masks. The mask length is derived from the
element count of operand `a` via TypesMatchWith. Update the LLVM
conversion and roundtrip tests to the new result types.

diff --git a/mlir/include/mlir/Dialect/AVX512/AVX512.td b/mlir/include/mlir/Dialect/AVX512/AVX512.td
--- a/mlir/include/mlir/Dialect/AVX512/AVX512.td
+++ b/mlir/include/mlir/Dialect/AVX512/AVX512.td
@@ -100,14 +100,14 @@
   AllTypesMatch<["a", "b"]>,
   TypesMatchWith<"k1 has the same number of bits as elements in a",
                  "a", "k1",
-                 "IntegerType::get($_self.getContext(), "
-                 "($_self.cast<VectorType>().getShape()[0]))">,
+                 "VectorType::get({$_self.cast<VectorType>().getShape()[0]}, "
+                 "IntegerType::get($_self.getContext(), 1))">,
   TypesMatchWith<"k2 has the same number of bits as elements in b",
                  // Should use `b` instead of `a`, but that would require
                  // adding `type($b)` to assemblyFormat.
                  "a", "k2",
-                 "IntegerType::get($_self.getContext(), "
-                 "($_self.cast<VectorType>().getShape()[0]))">]> {
+                 "VectorType::get({$_self.cast<VectorType>().getShape()[0]}, "
+                 "IntegerType::get($_self.getContext(), 1))">]> {
   let summary = "Vp2Intersect op";
   let description = [{
     The `vp2intersect` op is an AVX512 specific op that can lower to the proper
@@ -126,8 +126,8 @@
   let arguments = (ins VectorOfLengthAndType<[16, 8], [I32, I64]>:$a,
                        VectorOfLengthAndType<[16, 8], [I32, I64]>:$b
                        );
-  let results = (outs AnyTypeOf<[I16, I8]>:$k1,
-                      AnyTypeOf<[I16, I8]>:$k2
+  let results = (outs VectorOfLengthAndType<[16, 8], [I1]>:$k1,
+                      VectorOfLengthAndType<[16, 8], [I1]>:$k2
                       );
   let assemblyFormat =
     "$a `,` $b attr-dict `:` type($a)";
diff --git a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
--- a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
+++ b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
@@ -18,11 +18,11 @@
 }
 
 func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
-  -> (i16, i16, i8, i8)
+  -> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
 {
   // CHECK: llvm_avx512.vp2intersect.d.512
   %0, %1 = avx512.vp2intersect %a, %a : vector<16xi32>
   // CHECK: llvm_avx512.vp2intersect.q.512
   %2, %3 = avx512.vp2intersect %b, %b : vector<8xi64>
-  return %0, %1, %2, %3 : i16, i16, i8, i8
+  return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
 }
diff --git a/mlir/test/Dialect/AVX512/roundtrip.mlir b/mlir/test/Dialect/AVX512/roundtrip.mlir
--- a/mlir/test/Dialect/AVX512/roundtrip.mlir
+++ b/mlir/test/Dialect/AVX512/roundtrip.mlir
@@ -21,11 +21,11 @@
 }
 
 func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
-  -> (i16, i16, i8, i8)
+  -> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
 {
   // CHECK: avx512.vp2intersect {{.*}} : vector<16xi32>
   %0, %1 = avx512.vp2intersect %a, %a : vector<16xi32>
   // CHECK: avx512.vp2intersect {{.*}} : vector<8xi64>
   %2, %3 = avx512.vp2intersect %b, %b : vector<8xi64>
-  return %0, %1, %2, %3 : i16, i16, i8, i8
+  return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
 }