diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -375,4 +375,31 @@ class LLVM_VPBinaryF<string mnem> : LLVM_VPBinaryBase<mnem, AnyFloat>; +// Unary VP op base; operands: vector $op, i1-vector $mask, i32 $evl. +class LLVM_VPUnaryBase<string mnem, Type element> + : LLVM_OneResultIntrOp<"vp." # mnem, [0], [], [NoSideEffect]>, + Arguments<(ins LLVM_VectorOf<element>:$op, + LLVM_VectorOf<I1>:$mask, I32:$evl)>; + +class LLVM_VPUnaryF<string mnem> : LLVM_VPUnaryBase<mnem, AnyFloat>; + +// Ternary VP op base; operands: vectors $op1..$op3, i1-vector $mask, i32 $evl. +class LLVM_VPTernaryBase<string mnem, Type element> + : LLVM_OneResultIntrOp<"vp." # mnem, [0], [], [NoSideEffect]>, + Arguments<(ins LLVM_VectorOf<element>:$op1, LLVM_VectorOf<element>:$op2, + LLVM_VectorOf<element>:$op3, LLVM_VectorOf<I1>:$mask, + I32:$evl)>; + +class LLVM_VPTernaryF<string mnem> : LLVM_VPTernaryBase<mnem, AnyFloat>; + +// VP reduction op base; operands: $start_value, vector $val, $mask, i32 $evl. +class LLVM_VPReductionBase<string mnem, Type element> + : LLVM_OneResultIntrOp<"vp.reduce." # mnem, [], [1], [NoSideEffect]>, + Arguments<(ins element:$start_value, LLVM_VectorOf<element>:$val, + LLVM_VectorOf<I1>:$mask, I32:$evl)>; + +class LLVM_VPReductionI<string mnem> : LLVM_VPReductionBase<mnem, AnyInteger>; + +class LLVM_VPReductionF<string mnem> : LLVM_VPReductionBase<mnem, AnyFloat>; + #endif // LLVMIR_OP_BASE diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1948,7 +1948,28 @@ def LLVM_VPFDivOp : LLVM_VPBinaryF<"fdiv">; def LLVM_VPFRemOp : LLVM_VPBinaryF<"frem">; - +// Float Unary +def LLVM_VPFNegOp : LLVM_VPUnaryF<"fneg">; + +// Float Ternary +def LLVM_VPFmaOp : LLVM_VPTernaryF<"fma">; + +// Integer Reduction +def LLVM_VPReduceAddOp : LLVM_VPReductionI<"add">; +def LLVM_VPReduceMulOp : LLVM_VPReductionI<"mul">; +def LLVM_VPReduceAndOp : LLVM_VPReductionI<"and">; +def LLVM_VPReduceOrOp : LLVM_VPReductionI<"or">; +def LLVM_VPReduceXorOp : LLVM_VPReductionI<"xor">; +def LLVM_VPReduceSMaxOp : LLVM_VPReductionI<"smax">; +def LLVM_VPReduceSMinOp : LLVM_VPReductionI<"smin">; +def LLVM_VPReduceUMaxOp : LLVM_VPReductionI<"umax">; +def LLVM_VPReduceUMinOp : LLVM_VPReductionI<"umin">; +
+// Float Reduction +def LLVM_VPReduceFAddOp : LLVM_VPReductionF<"fadd">; +def LLVM_VPReduceFMulOp : LLVM_VPReductionF<"fmul">; +def LLVM_VPReduceFMaxOp : LLVM_VPReductionF<"fmax">; +def LLVM_VPReduceFMinOp : LLVM_VPReductionF<"fmin">; #endif // LLVMIR_OPS diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir --- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir @@ -518,6 +518,7 @@ // CHECK-LABEL: @vector_predication_intrinsics llvm.func @vector_predication_intrinsics(%A: vector<8xi32>, %B: vector<8xi32>, %C: vector<8xf32>, %D: vector<8xf32>, + %i: i32, %f: f32, %mask: vector<8xi1>, %evl: i32) { // CHECK: call <8 x i32> @llvm.vp.add.v8i32 "llvm.intr.vp.add" (%A, %B, %mask, %evl) : @@ -574,6 +575,55 @@ // CHECK: call <8 x float> @llvm.vp.frem.v8f32 "llvm.intr.vp.frem" (%C, %D, %mask, %evl) : (vector<8xf32>, vector<8xf32>, vector<8xi1>, i32) -> vector<8xf32> + // CHECK: call <8 x float> @llvm.vp.fneg.v8f32 + "llvm.intr.vp.fneg" (%C, %mask, %evl) : + (vector<8xf32>, vector<8xi1>, i32) -> vector<8xf32> + // CHECK: call <8 x float> @llvm.vp.fma.v8f32 + "llvm.intr.vp.fma" (%C, %D, %D, %mask, %evl) : + (vector<8xf32>, vector<8xf32>, vector<8xf32>, vector<8xi1>, i32) -> vector<8xf32> + + // CHECK: call i32 @llvm.vp.reduce.add.v8i32 + "llvm.intr.vp.reduce.add" (%i, %A, %mask, %evl) : + (i32, vector<8xi32>, vector<8xi1>, i32) -> i32 + // CHECK: call i32 @llvm.vp.reduce.mul.v8i32 + "llvm.intr.vp.reduce.mul" (%i, %A, %mask, %evl) : + (i32, vector<8xi32>, vector<8xi1>, i32) -> i32 + // CHECK: call i32 @llvm.vp.reduce.and.v8i32 + "llvm.intr.vp.reduce.and" (%i, %A, %mask, %evl) : + (i32, vector<8xi32>, vector<8xi1>, i32) -> i32 + // CHECK: call i32 @llvm.vp.reduce.or.v8i32 + "llvm.intr.vp.reduce.or" (%i, %A, %mask, %evl) : + (i32, vector<8xi32>, vector<8xi1>, i32) -> i32 + // CHECK: call i32 @llvm.vp.reduce.xor.v8i32 + "llvm.intr.vp.reduce.xor" (%i, %A, %mask, %evl) : +
(i32, vector<8xi32>, vector<8xi1>, i32) -> i32 + // CHECK: call i32 @llvm.vp.reduce.smax.v8i32 + "llvm.intr.vp.reduce.smax" (%i, %A, %mask, %evl) : + (i32, vector<8xi32>, vector<8xi1>, i32) -> i32 + // CHECK: call i32 @llvm.vp.reduce.smin.v8i32 + "llvm.intr.vp.reduce.smin" (%i, %A, %mask, %evl) : + (i32, vector<8xi32>, vector<8xi1>, i32) -> i32 + // CHECK: call i32 @llvm.vp.reduce.umax.v8i32 + "llvm.intr.vp.reduce.umax" (%i, %A, %mask, %evl) : + (i32, vector<8xi32>, vector<8xi1>, i32) -> i32 + // CHECK: call i32 @llvm.vp.reduce.umin.v8i32 + "llvm.intr.vp.reduce.umin" (%i, %A, %mask, %evl) : + (i32, vector<8xi32>, vector<8xi1>, i32) -> i32 + + // CHECK: call float @llvm.vp.reduce.fadd.v8f32 + "llvm.intr.vp.reduce.fadd" (%f, %C, %mask, %evl) : + (f32, vector<8xf32>, vector<8xi1>, i32) -> f32 + // CHECK: call float @llvm.vp.reduce.fmul.v8f32 + "llvm.intr.vp.reduce.fmul" (%f, %C, %mask, %evl) : + (f32, vector<8xf32>, vector<8xi1>, i32) -> f32 + // CHECK: call float @llvm.vp.reduce.fmax.v8f32 + "llvm.intr.vp.reduce.fmax" (%f, %C, %mask, %evl) : + (f32, vector<8xf32>, vector<8xi1>, i32) -> f32 + // CHECK: call float @llvm.vp.reduce.fmin.v8f32 + "llvm.intr.vp.reduce.fmin" (%f, %C, %mask, %evl) : + (f32, vector<8xf32>, vector<8xi1>, i32) -> f32 + + llvm.return } @@ -650,3 +700,18 @@ // CHECK-DAG: declare <8 x float> @llvm.vp.fmul.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32) #0 // CHECK-DAG: declare <8 x float> @llvm.vp.fdiv.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32) #0 // CHECK-DAG: declare <8 x float> @llvm.vp.frem.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32) #0 +// CHECK-DAG: declare <8 x float> @llvm.vp.fneg.v8f32(<8 x float>, <8 x i1>, i32) #0 +// CHECK-DAG: declare <8 x float> @llvm.vp.fma.v8f32(<8 x float>, <8 x float>, <8 x float>, <8 x i1>, i32) #0 +// CHECK-DAG: declare i32 @llvm.vp.reduce.add.v8i32(i32, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare i32 @llvm.vp.reduce.mul.v8i32(i32, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare i32
@llvm.vp.reduce.and.v8i32(i32, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare i32 @llvm.vp.reduce.or.v8i32(i32, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare i32 @llvm.vp.reduce.xor.v8i32(i32, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare i32 @llvm.vp.reduce.smax.v8i32(i32, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare i32 @llvm.vp.reduce.smin.v8i32(i32, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare i32 @llvm.vp.reduce.umax.v8i32(i32, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare i32 @llvm.vp.reduce.umin.v8i32(i32, <8 x i32>, <8 x i1>, i32) #0 +// CHECK-DAG: declare float @llvm.vp.reduce.fadd.v8f32(float, <8 x float>, <8 x i1>, i32) #0 +// CHECK-DAG: declare float @llvm.vp.reduce.fmul.v8f32(float, <8 x float>, <8 x i1>, i32) #0 +// CHECK-DAG: declare float @llvm.vp.reduce.fmax.v8f32(float, <8 x float>, <8 x i1>, i32) #0 +// CHECK-DAG: declare float @llvm.vp.reduce.fmin.v8f32(float, <8 x float>, <8 x i1>, i32) #0