Lower vector.reduction <minf>/<maxf> with an accumulator operand to LLVM.

The result of the llvm.intr.vector.reduce.fmin/fmax intrinsic is combined with
the accumulator by a new createMinMaxF helper: an ordered fcmp + select picks
the min/max, and an unordered fcmp + select replaces the result with a quiet
NaN when either value is NaN. Adds FileCheck coverage for the f32 cases.

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -393,6 +394,27 @@
   return result;
 }
 
+/// Create lowering of minf/maxf op. We cannot use llvm.maximum/llvm.minimum
+/// with vector types.
+static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
+                           Value rhs, bool isMin) {
+  auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
+  Type i1Type = builder.getI1Type();
+  if (auto vecType = lhs.getType().dyn_cast<VectorType>())
+    i1Type = VectorType::get(vecType.getShape(), i1Type);
+  Value cmp = builder.create<LLVM::FCmpOp>(
+      loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
+      lhs, rhs);
+  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
+  Value isNan = builder.create<LLVM::FCmpOp>(
+      loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
+  Value nan = builder.create<LLVM::ConstantOp>(
+      loc, lhs.getType(),
+      builder.getFloatAttr(floatType,
+                           APFloat::getQNaN(floatType.getFloatSemantics())));
+  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
+}
+
 /// Conversion pattern for all vector reductions.
 class VectorReductionOpConversion
     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
@@ -497,18 +519,25 @@
       rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
           reductionOp, llvmType, acc, operand,
           rewriter.getBoolAttr(reassociateFPReductions));
-    } else if (kind == vector::CombiningKind::MINF)
+    } else if (kind == vector::CombiningKind::MINF) {
       // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
       // NaNs/-0.0/+0.0 in the same way.
-      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmin>(reductionOp,
-                                                            llvmType, operand);
-    else if (kind == vector::CombiningKind::MAXF)
+      Value result =
+          rewriter.create<LLVM::vector_reduce_fmin>(loc, llvmType, operand);
+      if (acc)
+        result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true);
+      rewriter.replaceOp(reductionOp, result);
+    } else if (kind == vector::CombiningKind::MAXF) {
       // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
       // NaNs/-0.0/+0.0 in the same way.
-      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmax>(reductionOp,
-                                                            llvmType, operand);
-    else
+      Value result =
+          rewriter.create<LLVM::vector_reduce_fmax>(loc, llvmType, operand);
+      if (acc)
+        result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false);
+      rewriter.replaceOp(reductionOp, result);
+    } else
       return failure();
+
     return success();
   }
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1213,6 +1213,38 @@
 
 // -----
 
+func.func @reduce_fmax_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
+  %0 = vector.reduction <maxf>, %arg0, %arg1 : vector<16xf32> into f32
+  return %0 : f32
+}
+// CHECK-LABEL: @reduce_fmax_f32(
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fmax"(%[[A]]) : (vector<16xf32>) -> f32
+// CHECK: %[[C0:.*]] = llvm.fcmp "ogt" %[[V]], %[[B]] : f32
+// CHECK: %[[S0:.*]] = llvm.select %[[C0]], %[[V]], %[[B]] : i1, f32
+// CHECK: %[[C1:.*]] = llvm.fcmp "uno" %[[V]], %[[B]] : f32
+// CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
+// CHECK: %[[R:.*]] = llvm.select %[[C1]], %[[NAN]], %[[S0]] : i1, f32
+// CHECK: return %[[R]] : f32
+
+// -----
+
+func.func @reduce_fmin_f32(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
+  %0 = vector.reduction <minf>, %arg0, %arg1 : vector<16xf32> into f32
+  return %0 : f32
+}
+// CHECK-LABEL: @reduce_fmin_f32(
+// CHECK-SAME: %[[A:.*]]: vector<16xf32>, %[[B:.*]]: f32)
+// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fmin"(%[[A]]) : (vector<16xf32>) -> f32
+// CHECK: %[[C0:.*]] = llvm.fcmp "olt" %[[V]], %[[B]] : f32
+// CHECK: %[[S0:.*]] = llvm.select %[[C0]], %[[V]], %[[B]] : i1, f32
+// CHECK: %[[C1:.*]] = llvm.fcmp "uno" %[[V]], %[[B]] : f32
+// CHECK: %[[NAN:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
+// CHECK: %[[R:.*]] = llvm.select %[[C1]], %[[NAN]], %[[S0]] : i1, f32
+// CHECK: return %[[R]] : f32
+
+// -----
+
 func.func @reduce_minui_i32(%arg0: vector<16xi32>) -> i32 {
   %0 = vector.reduction <minui>, %arg0 : vector<16xi32> into i32
   return %0 : i32