diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -381,16 +381,27 @@
 //
 
 // LLVM vector reduction over a single vector.
-class LLVM_VecReductionBase<string mnem, TypeConstraint element>
+class LLVM_VecReductionBase<string mnem, TypeConstraint element, bit requiresFastmath = 0>
     : LLVM_OneResultIntrOp<"vector.reduce." # mnem, [], [0],
-                           [Pure, SameOperandsAndResultElementType]>,
-      Arguments<(ins LLVM_VectorOf<element>)>;
+                           [Pure, SameOperandsAndResultElementType],
+                           requiresFastmath> {
+  dag commonArgs = (ins LLVM_VectorOf<element>:$in);
+}
 
 class LLVM_VecReductionF<string mnem>
-    : LLVM_VecReductionBase<mnem, AnyFloat>;
+    : LLVM_VecReductionBase<mnem, AnyFloat, /*requiresFastmath=*/1> {
+  dag fmfArg = (
+    ins DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
+  let arguments = !con(commonArgs, fmfArg);
+
+  let assemblyFormat = "`(` operands `)` custom<LLVMOpAttrs>(attr-dict) `:` "
+      "functional-type(operands, results)";
+}
 
 class LLVM_VecReductionI<string mnem>
-    : LLVM_VecReductionBase<mnem, AnyInteger>;
+    : LLVM_VecReductionBase<mnem, AnyInteger> {
+  let arguments = commonArgs;
+}
 
 // LLVM vector reduction over a single vector, with an initial value,
 // and with permission to reassociate the reduction operations.
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1374,7 +1374,7 @@
 }
 // CHECK-LABEL: @reduce_fmax_f32(
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32)
-// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fmax"(%[[A]]) : (vector<16xf32>) -> f32
+// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmax(%[[A]]) : (vector<16xf32>) -> f32
 // CHECK: %[[C0:.*]] = llvm.fcmp "ogt" %[[V]], %[[B]] : f32
 // CHECK: %[[S0:.*]] = llvm.select %[[C0]], %[[V]], %[[B]] : i1, f32
 // CHECK: %[[C1:.*]] = llvm.fcmp "uno" %[[V]], %[[B]] : f32
@@ -1390,7 +1390,7 @@
 }
 // CHECK-LABEL: @reduce_fmin_f32(
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32)
-// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fmin"(%[[A]]) : (vector<16xf32>) -> f32
+// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmin(%[[A]]) : (vector<16xf32>) -> f32
 // CHECK: %[[C0:.*]] = llvm.fcmp "olt" %[[V]], %[[B]] : f32
 // CHECK: %[[S0:.*]] = llvm.select %[[C0]], %[[V]], %[[B]] : i1, f32
 // CHECK: %[[C1:.*]] = llvm.fcmp "uno" %[[V]], %[[B]] : f32
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -516,6 +516,11 @@
   %11 = llvm.intr.sin(%arg0) {fastmathFlags = #llvm.fastmath<fast>} : (f32) -> f32
 // CHECK: {{.*}} = llvm.intr.sin(%arg0) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
   %12 = llvm.intr.sin(%arg0) {fastmathFlags = #llvm.fastmath<afn>} : (f32) -> f32
+
+// CHECK: {{.*}} = llvm.intr.vector.reduce.fmin(%arg3) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+  %13 = llvm.intr.vector.reduce.fmin(%arg3) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+// CHECK: {{.*}} = llvm.intr.vector.reduce.fmax(%arg3) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+  %14 = llvm.intr.vector.reduce.fmax(%arg3) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
   return
 }
 
diff --git a/mlir/test/Integration/Dialect/LLVMIR/CPU/test-vector-reductions-fp.mlir b/mlir/test/Integration/Dialect/LLVMIR/CPU/test-vector-reductions-fp.mlir
--- a/mlir/test/Integration/Dialect/LLVMIR/CPU/test-vector-reductions-fp.mlir
+++ b/mlir/test/Integration/Dialect/LLVMIR/CPU/test-vector-reductions-fp.mlir
@@ -23,13 +23,13 @@
   %12 = llvm.mlir.constant(3 : i64) : i64
   %v = llvm.insertelement %3, %11[%12 : i64] : vector<4xf32>
 
-  %max = "llvm.intr.vector.reduce.fmax"(%v)
+  %max = llvm.intr.vector.reduce.fmax(%v)
       : (vector<4xf32>) -> f32
   llvm.call @printF32(%max) : (f32) -> ()
   llvm.call @printNewline() : () -> ()
   // CHECK: 4
 
-  %min = "llvm.intr.vector.reduce.fmin"(%v)
+  %min = llvm.intr.vector.reduce.fmin(%v)
       : (vector<4xf32>) -> f32
   llvm.call @printF32(%min) : (f32) -> ()
   llvm.call @printNewline() : () -> ()
diff --git a/mlir/test/Target/LLVMIR/Import/fastmath.ll b/mlir/test/Target/LLVMIR/Import/fastmath.ll
--- a/mlir/test/Target/LLVMIR/Import/fastmath.ll
+++ b/mlir/test/Target/LLVMIR/Import/fastmath.ll
@@ -41,9 +41,11 @@
 declare float @llvm.powi.f32.i32(float, i32)
 declare float @llvm.pow.f32(float, float)
 declare float @llvm.fmuladd.f32(float, float, float)
+declare float @llvm.vector.reduce.fmin.v2f32(<2 x float>)
+declare float @llvm.vector.reduce.fmax.v2f32(<2 x float>)
 
 ; CHECK-LABEL: @fastmath_intr
-define void @fastmath_intr(float %arg1, i32 %arg2) {
+define void @fastmath_intr(float %arg1, i32 %arg2, <2 x float> %arg3) {
   ; CHECK: llvm.intr.exp(%{{.*}}) {fastmathFlags = #llvm.fastmath<nnan, ninf>} : (f32) -> f32
   %1 = call nnan ninf float @llvm.exp.f32(float %arg1)
   ; CHECK: llvm.intr.powi(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32, i32) -> f32
@@ -52,5 +54,10 @@
   %3 = call fast float @llvm.pow.f32(float %arg1, float %arg1)
   ; CHECK: llvm.intr.fmuladd(%{{.*}}, %{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath<fast>} : (f32, f32, f32) -> f32
   %4 = call fast float @llvm.fmuladd.f32(float %arg1, float %arg1, float %arg1)
+  ; CHECK: %{{.*}} = llvm.intr.vector.reduce.fmin({{.*}}) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+  %5 = call nnan float @llvm.vector.reduce.fmin.v2f32(<2 x float> %arg3)
+  ; CHECK: %{{.*}} = llvm.intr.vector.reduce.fmax({{.*}}) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+  %6 = call nnan float @llvm.vector.reduce.fmax.v2f32(<2 x float> %arg3)
+
   ret void
 }
diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
--- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll
+++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll
@@ -231,9 +231,9 @@
   %4 = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %2)
   ; CHECK: "llvm.intr.vector.reduce.and"(%{{.*}}) : (vector<8xi32>) -> i32
   %5 = call i32 @llvm.vector.reduce.and.v8i32(<8 x i32> %2)
-  ; CHECK: "llvm.intr.vector.reduce.fmax"(%{{.*}}) : (vector<8xf32>) -> f32
+  ; CHECK: llvm.intr.vector.reduce.fmax(%{{.*}}) : (vector<8xf32>) -> f32
   %6 = call float @llvm.vector.reduce.fmax.v8f32(<8 x float> %1)
-  ; CHECK: "llvm.intr.vector.reduce.fmin"(%{{.*}}) : (vector<8xf32>) -> f32
+  ; CHECK: llvm.intr.vector.reduce.fmin(%{{.*}}) : (vector<8xf32>) -> f32
   %7 = call float @llvm.vector.reduce.fmin.v8f32(<8 x float> %1)
   ; CHECK: "llvm.intr.vector.reduce.mul"(%{{.*}}) : (vector<8xi32>) -> i32
   %8 = call i32 @llvm.vector.reduce.mul.v8i32(<8 x i32> %2)
diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
--- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir
@@ -251,9 +251,9 @@
   // CHECK: call i32 @llvm.vector.reduce.and.v8i32
   "llvm.intr.vector.reduce.and"(%arg2) : (vector<8xi32>) -> i32
   // CHECK: call float @llvm.vector.reduce.fmax.v8f32
-  "llvm.intr.vector.reduce.fmax"(%arg1) : (vector<8xf32>) -> f32
+  llvm.intr.vector.reduce.fmax(%arg1) : (vector<8xf32>) -> f32
   // CHECK: call float @llvm.vector.reduce.fmin.v8f32
-  "llvm.intr.vector.reduce.fmin"(%arg1) : (vector<8xf32>) -> f32
+  llvm.intr.vector.reduce.fmin(%arg1) : (vector<8xf32>) -> f32
   // CHECK: call i32 @llvm.vector.reduce.mul.v8i32
   "llvm.intr.vector.reduce.mul"(%arg2) : (vector<8xi32>) -> i32
   // CHECK: call i32 @llvm.vector.reduce.or.v8i32
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -172,7 +172,7 @@
 
 llvm.func @vec_reduce_fmax_intr_wrong_type(%arg0 : vector<4xi32>) -> i32 {
   // expected-error @below{{op operand #0 must be LLVM dialect-compatible vector of floating-point}}
-  %0 = "llvm.intr.vector.reduce.fmax"(%arg0) : (vector<4xi32>) -> i32
+  %0 = llvm.intr.vector.reduce.fmax(%arg0) : (vector<4xi32>) -> i32
   llvm.return %0 : i32
 }
 
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1871,7 +1871,7 @@
 llvm.func @fastmathFlagsFunc(f32) -> f32
 
 // CHECK-LABEL: @fastmathFlags
-llvm.func @fastmathFlags(%arg0: f32) {
+llvm.func @fastmathFlags(%arg0: f32, %arg1 : vector<2xf32>) {
 // CHECK: {{.*}} = fadd nnan ninf float {{.*}}, {{.*}}
 // CHECK: {{.*}} = fsub nnan ninf float {{.*}}, {{.*}}
 // CHECK: {{.*}} = fmul nnan ninf float {{.*}}, {{.*}}
@@ -1918,6 +1918,12 @@
   %19 = "llvm.intr.powi"(%arg0, %exp) {fastmathFlags = #llvm.fastmath<fast>} : (f32, i32) -> f32
 // CHECK: call afn float @llvm.powi.f32.i32(float {{.*}}, i32 {{.*}})
   %20 = "llvm.intr.powi"(%arg0, %exp) {fastmathFlags = #llvm.fastmath<afn>} : (f32, i32) -> f32
+
+// CHECK: call nnan float @llvm.vector.reduce.fmax.v2f32(<2 x float> {{.*}})
+// CHECK: call nnan float @llvm.vector.reduce.fmin.v2f32(<2 x float> {{.*}})
+  %21 = llvm.intr.vector.reduce.fmax(%arg1) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+  %22 = llvm.intr.vector.reduce.fmin(%arg1) {fastmathFlags = #llvm.fastmath<nnan>} : (vector<2xf32>) -> f32
+
   llvm.return
 }
 