diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -573,35 +573,31 @@ return result; } -/// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum -/// with vector types. -static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, - Value rhs, bool isMin) { - auto floatType = cast(getElementTypeOrSelf(lhs.getType())); - Type i1Type = builder.getI1Type(); - if (auto vecType = dyn_cast(lhs.getType())) - i1Type = VectorType::get(vecType.getShape(), i1Type); - Value cmp = builder.create( - loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, - lhs, rhs); - Value sel = builder.create(loc, cmp, lhs, rhs); - Value isNan = builder.create( - loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs); - Value nan = builder.create( - loc, lhs.getType(), - builder.getFloatAttr(floatType, - APFloat::getQNaN(floatType.getFloatSemantics()))); - return builder.create(loc, isNan, nan, sel); -} +namespace { +template +struct VectorToScalarMapper; +template <> +struct VectorToScalarMapper { + using Type = LLVM::MaximumOp; +}; +template <> +struct VectorToScalarMapper { + using Type = LLVM::MinimumOp; +}; +} // namespace template -static Value createFPReductionComparisonOpLowering( - ConversionPatternRewriter &rewriter, Location loc, Type llvmType, - Value vectorOperand, Value accumulator, bool isMin) { +static Value +createFPReductionComparisonOpLowering(ConversionPatternRewriter &rewriter, + Location loc, Type llvmType, + Value vectorOperand, Value accumulator) { Value result = rewriter.create(loc, llvmType, vectorOperand); - if (accumulator) - result = createMinMaxF(rewriter, loc, result, accumulator, /*isMin=*/isMin); + if (accumulator) { + result = + rewriter.create::Type>( + loc, result, accumulator); + } return result; } @@ -774,17 +770,13 @@ ReductionNeutralFPOne>( rewriter, loc, llvmType, operand, acc, reassociateFPReductions); } else if (kind == vector::CombiningKind::MINF) { - // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle - // NaNs/-0.0/+0.0 in the same way. - result = createFPReductionComparisonOpLowering( - rewriter, loc, llvmType, operand, acc, - /*isMin=*/true); + result = + createFPReductionComparisonOpLowering( + rewriter, loc, llvmType, operand, acc); } else if (kind == vector::CombiningKind::MAXF) { - // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle - // NaNs/-0.0/+0.0 in the same way. - result = createFPReductionComparisonOpLowering( - rewriter, loc, llvmType, operand, acc, - /*isMin=*/false); + result = + createFPReductionComparisonOpLowering( + rewriter, loc, llvmType, operand, acc); } else return failure(); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1374,12 +1374,8 @@ } // CHECK-LABEL: @reduce_fmax_f32( // CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32) -// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmax(%[[A]]) : (vector<16xf32>) -> f32 -// CHECK: %[[C0:.*]] = llvm.fcmp "ogt" %[[V]], %[[B]] : f32 -// CHECK: %[[S0:.*]] = llvm.select %[[C0]], %[[V]], %[[B]] : i1, f32 -// CHECK: %[[C1:.*]] = llvm.fcmp "uno" %[[V]], %[[B]] : f32 -// CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32 -// CHECK: %[[R:.*]] = llvm.select %[[C1]], %[[NAN]], %[[S0]] : i1, f32 +// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmaximum(%[[A]]) : (vector<16xf32>) -> f32 +// CHECK: %[[R:.*]] = llvm.intr.maximum(%[[V]], %[[B]]) : (f32, f32) -> f32 // CHECK: return %[[R]] : f32 // ----- @@ -1390,12 +1386,8 @@ } // CHECK-LABEL: @reduce_fmin_f32( // CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32) -// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fmin(%[[A]]) : (vector<16xf32>) -> f32 -// CHECK: %[[C0:.*]] = llvm.fcmp "olt" %[[V]], %[[B]] : f32 -// CHECK: %[[S0:.*]] = llvm.select %[[C0]], %[[V]], %[[B]] : i1, f32 -// CHECK: %[[C1:.*]] = llvm.fcmp "uno" %[[V]], %[[B]] : f32 -// CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32 -// CHECK: %[[R:.*]] = llvm.select %[[C1]], %[[NAN]], %[[S0]] : i1, f32 +// CHECK: %[[V:.*]] = llvm.intr.vector.reduce.fminimum(%[[A]]) : (vector<16xf32>) -> f32 +// CHECK: %[[R:.*]] = llvm.intr.minimum(%[[V]], %[[B]]) : (f32, f32) -> f32 // CHECK: return %[[R]] : f32 // -----