diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -15,13 +15,17 @@
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/Support/Casting.h"
 #include <optional>
 
@@ -603,6 +607,58 @@
   return result;
 }
 
+/// Tag types used to select the mask-neutral element by overload resolution.
+class MaskNeutralFMaximum {};
+class MaskNeutralFMinimum {};
+
+/// Returns the identity element of `fmaximum`: negative infinity. Masked-off
+/// lanes are replaced with this value, so it must never exceed an active lane.
+/// Note that `APFloat::getSmallest` is the smallest *magnitude* value (a
+/// denormal close to zero) and is therefore not a valid neutral element.
+static llvm::APFloat
+getMaskNeutralValue(MaskNeutralFMaximum,
+                    const llvm::fltSemantics &floatSemantics) {
+  return llvm::APFloat::getInf(floatSemantics, /*Negative=*/true);
+}
+/// Returns the identity element of `fminimum`: positive infinity. The largest
+/// finite value is not sufficient because the active lanes may all be `+inf`.
+static llvm::APFloat
+getMaskNeutralValue(MaskNeutralFMinimum,
+                    const llvm::fltSemantics &floatSemantics) {
+  return llvm::APFloat::getInf(floatSemantics, /*Negative=*/false);
+}
+
+/// Creates an LLVM constant of type `vectorType` in which every element is the
+/// mask-neutral value selected by `MaskNeutral`. `llvmType` is the
+/// floating-point element type that provides the float semantics.
+template <typename MaskNeutral>
+static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter,
+                                    Location loc, Type llvmType,
+                                    Type vectorType) {
+  const auto &floatSemantics = cast<FloatType>(llvmType).getFloatSemantics();
+  auto value = getMaskNeutralValue(MaskNeutral{}, floatSemantics);
+  auto denseValue =
+      DenseElementsAttr::get(vectorType.cast<ShapedType>(), value);
+  return rewriter.create<LLVM::ConstantOp>(loc, vectorType, denseValue);
+}
+
+/// Lowers masked `fmaximum` and `fminimum` reductions using the non-masked
+/// intrinsics. It is a workaround to overcome the lack of masked intrinsics for
+/// `fmaximum`/`fminimum`: masked-off lanes are first replaced by the
+/// mask-neutral value with a select, then the regular reduction is emitted.
+/// More information: https://github.com/llvm/llvm-project/issues/64940
+template <class LLVMRedIntrinOp, class MaskNeutral>
+static Value lowerMaskedReductionWithRegular(
+    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
+    Value vectorOperand, Value accumulator, Value mask) {
+  const Value vectorMaskNeutral = createMaskNeutralValue<MaskNeutral>(
+      rewriter, loc, llvmType, vectorOperand.getType());
+  const Value selectedVectorByMask = rewriter.create<LLVM::SelectOp>(
+      loc, mask, vectorOperand, vectorMaskNeutral);
+  return createFPReductionComparisonOpLowering<LLVMRedIntrinOp>(
+      rewriter, loc, llvmType, selectedVectorByMask, accumulator);
+}
+
 /// Overloaded methods to lower a reduction to an llvm instrinsic that requires
 /// a start value. This start value format spans across fp reductions without
 /// mask and all the masked reduction intrinsics.
@@ -903,10 +959,16 @@
                                             ReductionNeutralFPMin>(
           rewriter, loc, llvmType, operand, acc, maskOp.getMask());
       break;
-    default:
-      return rewriter.notifyMatchFailure(
-          maskOp,
-          "lowering to LLVM is not implemented for this masked operation");
+    case CombiningKind::MAXIMUMF:
+      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fmaximum,
+                                               MaskNeutralFMaximum>(
+          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+      break;
+    case CombiningKind::MINIMUMF:
+      result = lowerMaskedReductionWithRegular<LLVM::vector_reduce_fminimum,
+                                               MaskNeutralFMinimum>(
+          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+      break;
     }
 
     // Replace `vector.mask` operation altogether.
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
@@ -101,6 +101,36 @@
 
 // -----
 
+func.func @masked_reduce_maximumf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <maximumf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: func.func @masked_reduce_maximumf_f32(
+// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
+// CHECK: %[[MASK_NEUTRAL:.*]] = llvm.mlir.constant(dense<0xFF800000> : vector<16xf32>) : vector<16xf32>
+// CHECK: %[[MASKED:.*]] = llvm.select %[[MASK]], %[[INPUT]], %[[MASK_NEUTRAL]] : vector<16xi1>, vector<16xf32>
+// CHECK: %[[RESULT:.*]] = llvm.intr.vector.reduce.fmaximum(%[[MASKED]]) : (vector<16xf32>) -> f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @masked_reduce_minimumf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <minimumf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: func.func @masked_reduce_minimumf_f32(
+// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
+// CHECK: %[[MASK_NEUTRAL:.*]] = llvm.mlir.constant(dense<0x7F800000> : vector<16xf32>) : vector<16xf32>
+// CHECK: %[[MASKED:.*]] = llvm.select %[[MASK]], %[[INPUT]], %[[MASK_NEUTRAL]] : vector<16xi1>, vector<16xf32>
+// CHECK: %[[RESULT:.*]] = llvm.intr.vector.reduce.fminimum(%[[MASKED]]) : (vector<16xf32>) -> f32
+// CHECK: return %[[RESULT]]
+
+// -----
+
 func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
   %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
   return %0 : i8