diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -365,4 +365,19 @@ }]; } +// LLVM vector predication (VP) intrinsics. Each binary op takes two vectors +// of `element` type, a vector<i1> mask and an i32 explicit vector length +// (evl), and yields one vector result; the intrinsic name is mangled with the +// result type (e.g. llvm.vp.add.v8i32). +class LLVM_VPBinaryBase<string mnem, Type element> + : LLVM_OneResultIntrOp<"vp." # mnem, [0], [], [NoSideEffect]>, + Arguments<(ins LLVM_VectorOf<element>:$lhs, LLVM_VectorOf<element>:$rhs, + LLVM_VectorOf<I1>:$mask, I32:$evl)>; + +// Integer VP binary op (vp.add, vp.sdiv, vp.shl, ...). +class LLVM_VPBinaryI<string mnem> : LLVM_VPBinaryBase<mnem, AnyInteger>; + +// Floating-point VP binary op (vp.fadd, vp.fdiv, ...). +class LLVM_VPBinaryF<string mnem> : LLVM_VPBinaryBase<mnem, AnyFloat>; + #endif // LLVMIR_OP_BASE diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1921,4 +1921,31 @@ } }]; } + +// +// LLVM Vector Predication operations. +// + +// Integer Binary +def LLVM_VPAddOp : LLVM_VPBinaryI<"add">; +def LLVM_VPSubOp : LLVM_VPBinaryI<"sub">; +def LLVM_VPMulOp : LLVM_VPBinaryI<"mul">; +def LLVM_VPSDivOp : LLVM_VPBinaryI<"sdiv">; +def LLVM_VPUDivOp : LLVM_VPBinaryI<"udiv">; +def LLVM_VPSRemOp : LLVM_VPBinaryI<"srem">; +def LLVM_VPURemOp : LLVM_VPBinaryI<"urem">; +def LLVM_VPAShrOp : LLVM_VPBinaryI<"ashr">; +def LLVM_VPLShrOp : LLVM_VPBinaryI<"lshr">; +def LLVM_VPShlOp : LLVM_VPBinaryI<"shl">; +def LLVM_VPOrOp : LLVM_VPBinaryI<"or">; +def LLVM_VPAndOp : LLVM_VPBinaryI<"and">; +def LLVM_VPXorOp : LLVM_VPBinaryI<"xor">; + +// Float Binary +def LLVM_VPFAddOp : LLVM_VPBinaryF<"fadd">; +def LLVM_VPFSubOp : LLVM_VPBinaryF<"fsub">; +def LLVM_VPFMulOp : LLVM_VPBinaryF<"fmul">; +def LLVM_VPFDivOp : LLVM_VPBinaryF<"fdiv">; +def LLVM_VPFRemOp : LLVM_VPBinaryF<"frem">; + #endif // LLVMIR_OPS diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir --- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir @@ -515,6 +515,68 @@ llvm.return } +// CHECK-LABEL: 
@vector_predication_intrinsics +llvm.func @vector_predication_intrinsics(%ia: vector<8xi32>, %ib: vector<8xi32>, + %fa: vector<8xf32>, %fb: vector<8xf32>, + %mask: vector<8xi1>, %evl: i32) { + // CHECK: call <8 x i32> @llvm.vp.add.v8i32 + "llvm.intr.vp.add"(%ia, %ib, %mask, %evl) : + (vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32> + // CHECK: call <8 x i32> @llvm.vp.sub.v8i32 + "llvm.intr.vp.sub"(%ia, %ib, %mask, %evl) : + (vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32> + // CHECK: call <8 x i32> @llvm.vp.mul.v8i32 + "llvm.intr.vp.mul"(%ia, %ib, %mask, %evl) : + (vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32> + // CHECK: call <8 x i32> @llvm.vp.sdiv.v8i32 + "llvm.intr.vp.sdiv"(%ia, %ib, %mask, %evl) : + (vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32> + // CHECK: call <8 x i32> @llvm.vp.udiv.v8i32 + "llvm.intr.vp.udiv"(%ia, %ib, %mask, %evl) : + (vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32> + // CHECK: call <8 x i32> @llvm.vp.srem.v8i32 + "llvm.intr.vp.srem"(%ia, %ib, %mask, %evl) : + (vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32> + // CHECK: call <8 x i32> @llvm.vp.urem.v8i32 + "llvm.intr.vp.urem"(%ia, %ib, %mask, %evl) : + (vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32> + // CHECK: call <8 x i32> @llvm.vp.ashr.v8i32 + "llvm.intr.vp.ashr"(%ia, %ib, %mask, %evl) : + (vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32> + // CHECK: call <8 x i32> @llvm.vp.lshr.v8i32 + "llvm.intr.vp.lshr"(%ia, %ib, %mask, %evl) : + (vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32> + // CHECK: call <8 x i32> @llvm.vp.shl.v8i32 + "llvm.intr.vp.shl"(%ia, %ib, %mask, %evl) : + (vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32> + // CHECK: call <8 x i32> @llvm.vp.or.v8i32 + "llvm.intr.vp.or"(%ia, %ib, %mask, %evl) : + (vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32> + // CHECK: call <8 x i32> @llvm.vp.and.v8i32 + "llvm.intr.vp.and"(%ia, %ib, %mask, %evl) : + (vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32> + // CHECK: call <8 x i32> @llvm.vp.xor.v8i32 + "llvm.intr.vp.xor"(%ia, %ib, %mask, %evl) : + (vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32> + + // CHECK: call <8 x float> @llvm.vp.fadd.v8f32 + "llvm.intr.vp.fadd"(%fa, %fb, %mask, %evl) : + (vector<8xf32>, vector<8xf32>, vector<8xi1>, i32) -> vector<8xf32> + // CHECK: call <8 x float> @llvm.vp.fsub.v8f32 + "llvm.intr.vp.fsub"(%fa, %fb, %mask, %evl) : + (vector<8xf32>, vector<8xf32>, vector<8xi1>, i32) -> vector<8xf32> + // CHECK: call <8 x float> @llvm.vp.fmul.v8f32 + "llvm.intr.vp.fmul"(%fa, %fb, %mask, %evl) : + (vector<8xf32>, vector<8xf32>, vector<8xi1>, i32) -> vector<8xf32> + // CHECK: call <8 x float> @llvm.vp.fdiv.v8f32 + "llvm.intr.vp.fdiv"(%fa, %fb, %mask, %evl) : + (vector<8xf32>, vector<8xf32>, vector<8xi1>, i32) -> vector<8xf32> + // CHECK: call <8 x float> @llvm.vp.frem.v8f32 + "llvm.intr.vp.frem"(%fa, %fb, %mask, %evl) : + (vector<8xf32>, vector<8xf32>, vector<8xi1>, i32) -> vector<8xf32> + llvm.return +} + // Check that intrinsics are declared with appropriate types.
// CHECK-DAG: declare float @llvm.fma.f32(float, float, float) // CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0 @@ -570,3 +632,21 @@ // CHECK-DAG: declare i1 @llvm.coro.end(i8*, i1) // CHECK-DAG: declare i8* @llvm.coro.free(token, i8* nocapture readonly) // CHECK-DAG: declare void @llvm.coro.resume(i8*) +// CHECK-DAG: declare <8 x i32> @llvm.vp.add.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare <8 x i32> @llvm.vp.sub.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare <8 x i32> @llvm.vp.mul.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare <8 x i32> @llvm.vp.shl.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare <8 x i32> @llvm.vp.or.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare <8 x i32> @llvm.vp.and.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare <8 x i32> @llvm.vp.xor.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #2 +// CHECK-DAG: declare <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #2 +// CHECK-DAG: declare <8 x i32> @llvm.vp.srem.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #2 +// CHECK-DAG: declare <8 x i32> @llvm.vp.urem.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #2 +// CHECK-DAG: declare <8 x float> @llvm.vp.fadd.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32) #0 +// CHECK-DAG: declare <8 x float> @llvm.vp.fsub.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32) #0 +// CHECK-DAG: declare <8 x float> @llvm.vp.fmul.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32) #0 +// CHECK-DAG: declare <8 x float> @llvm.vp.fdiv.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32) #0 +// CHECK-DAG: declare <8 x float> @llvm.vp.frem.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32) #0