diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -308,6 +308,11 @@ let summary = "Lower the operations from the vector dialect into the LLVM " "dialect"; let constructor = "mlir::createConvertVectorToLLVMPass()"; + let options = [ + Option<"reassociateFPReductions", "reassociate-fp-reductions", + "bool", /*default=*/"false", + "Allows llvm to reassociate floating-point reductions for speed"> + ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -23,8 +23,9 @@ LLVMTypeConverter &converter, OwningRewritePatternList &patterns); /// Collect a set of patterns to convert from the Vector dialect to LLVM. -void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter, - OwningRewritePatternList &patterns); +void populateVectorToLLVMConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns, + bool reassociateFPReductions = false); /// Create a pass to convert vector operations to the LLVMIR dialect. std::unique_ptr> createConvertVectorToLLVMPass(); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -220,4 +220,29 @@ [0], [1], []>, Arguments<(ins LLVM_Type, LLVM_Type)>; +// LLVM vector reduction over a single vector, with an initial value, +// and with permission to reassociate the reduction operations. 
+class LLVM_VectorReductionV2R + : LLVM_OpBase, + Results<(outs LLVM_Type:$res)>, + Arguments<(ins LLVM_Type, LLVM_Type)> { + let llvmBuilder = [{ + llvm::Module *module = builder.GetInsertBlock()->getModule(); + llvm::Function *fn = llvm::Intrinsic::getDeclaration( + module, + llvm::Intrinsic::experimental_vector_reduce_v2_}] # mnem # [{, + { }] # StrJoin.lst, + ListIntSubst.lst)>.result # [{ + }); + auto operands = lookupValues(opInst.getOperands()); + llvm::FastMathFlags origFM = builder.getFastMathFlags(); + llvm::FastMathFlags tempFM = origFM; + tempFM.setAllowReassoc(true); + builder.setFastMathFlags(tempFM); + $res = builder.CreateCall(fn, operands); + builder.setFastMathFlags(origFM); + }]; +} + #endif // LLVMIR_OP_BASE diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -810,9 +810,14 @@ def LLVM_experimental_vector_reduce_umin : LLVM_VectorReduction<"umin">; def LLVM_experimental_vector_reduce_xor : LLVM_VectorReduction<"xor">; +// Floating-point add/mul reduction that may not be reassociated by the backend. def LLVM_experimental_vector_reduce_v2_fadd : LLVM_VectorReductionV2<"fadd">; def LLVM_experimental_vector_reduce_v2_fmul : LLVM_VectorReductionV2<"fmul">; +// Floating-point add/mul reduction that may be reassociated by the backend. +def LLVM_experimental_vector_reduce_v2r_fadd : LLVM_VectorReductionV2R<"fadd">; +def LLVM_experimental_vector_reduce_v2r_fmul : LLVM_VectorReductionV2R<"fmul">; + // // LLVM Matrix operations. 
// diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -255,9 +255,11 @@ class VectorReductionOpConversion : public ConvertToLLVMPattern { public: explicit VectorReductionOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) + LLVMTypeConverter &typeConverter, + bool reassociateFP) : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context, - typeConverter) {} + typeConverter), + reassociateFPReductions(reassociateFP) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -301,8 +303,13 @@ : rewriter.create( op->getLoc(), llvmType, rewriter.getZeroAttr(eltType)); - rewriter.replaceOpWithNewOp( - op, llvmType, acc, operands[0]); + if (reassociateFPReductions) + rewriter + .replaceOpWithNewOp( + op, llvmType, acc, operands[0]); + else + rewriter.replaceOpWithNewOp( + op, llvmType, acc, operands[0]); } else if (kind == "mul") { // Optional accumulator (or one). Value acc = operands.size() > 1 @@ -310,8 +317,13 @@ : rewriter.create( op->getLoc(), llvmType, rewriter.getFloatAttr(eltType, 1.0)); - rewriter.replaceOpWithNewOp( - op, llvmType, acc, operands[0]); + if (reassociateFPReductions) + rewriter + .replaceOpWithNewOp( + op, llvmType, acc, operands[0]); + else + rewriter.replaceOpWithNewOp( + op, llvmType, acc, operands[0]); } else if (kind == "min") rewriter.replaceOpWithNewOp( op, llvmType, operands[0]); @@ -324,6 +336,9 @@ } return failure(); } + +private: + const bool reassociateFPReductions; }; class VectorShuffleOpConversion : public ConvertToLLVMPattern { @@ -1139,16 +1154,18 @@ /// Populate the given list with patterns that convert from Vector to LLVM. 
void mlir::populateVectorToLLVMConversionPatterns( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + LLVMTypeConverter &converter, OwningRewritePatternList &patterns, + bool reassociateFPReductions) { MLIRContext *ctx = converter.getDialect()->getContext(); // clang-format off patterns.insert(ctx); + patterns.insert( + ctx, converter, reassociateFPReductions); patterns - .insert">) +// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float +// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]]) +// CHECK: llvm.return %[[V]] : !llvm.float +// +// REASSOC-LABEL: llvm.func @reduce_add_f32( +// REASSOC-SAME: %[[A:.*]]: !llvm<"<16 x float>">) +// REASSOC: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float +// REASSOC: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2r.fadd"(%[[C]], %[[A]]) +// REASSOC: llvm.return %[[V]] : !llvm.float +// +func @reduce_add_f32(%arg0: vector<16xf32>) -> f32 { + %0 = vector.reduction "add", %arg0 : vector<16xf32> into f32 + return %0 : f32 +} + +// +// CHECK-LABEL: llvm.func @reduce_mul_f32( +// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">) +// CHECK: %[[C:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float +// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fmul"(%[[C]], %[[A]]) +// CHECK: llvm.return %[[V]] : !llvm.float +// +// REASSOC-LABEL: llvm.func @reduce_mul_f32( +// REASSOC-SAME: %[[A:.*]]: !llvm<"<16 x float>">) +// REASSOC: %[[C:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float +// REASSOC: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2r.fmul"(%[[C]], %[[A]]) +// REASSOC: llvm.return %[[V]] : !llvm.float +// +func @reduce_mul_f32(%arg0: vector<16xf32>) -> f32 { + %0 = vector.reduction "mul", %arg0 : vector<16xf32> into f32 + return %0 : f32 +} diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir --- a/mlir/test/Target/llvmir-intrinsics.mlir +++ 
b/mlir/test/Target/llvmir-intrinsics.mlir @@ -161,6 +161,10 @@ "llvm.intr.experimental.vector.reduce.v2.fadd"(%arg0, %arg1) : (!llvm.float, !llvm<"<8 x float>">) -> !llvm.float // CHECK: call float @llvm.experimental.vector.reduce.v2.fmul.f32.v8f32 "llvm.intr.experimental.vector.reduce.v2.fmul"(%arg0, %arg1) : (!llvm.float, !llvm<"<8 x float>">) -> !llvm.float + // CHECK: call reassoc float @llvm.experimental.vector.reduce.v2.fadd.f32.v8f32 + "llvm.intr.experimental.vector.reduce.v2r.fadd"(%arg0, %arg1) : (!llvm.float, !llvm<"<8 x float>">) -> !llvm.float + // CHECK: call reassoc float @llvm.experimental.vector.reduce.v2.fmul.f32.v8f32 + "llvm.intr.experimental.vector.reduce.v2r.fmul"(%arg0, %arg1) : (!llvm.float, !llvm<"<8 x float>">) -> !llvm.float // CHECK: call i32 @llvm.experimental.vector.reduce.xor.v8i32 "llvm.intr.experimental.vector.reduce.xor"(%arg2) : (!llvm<"<8 x i32>">) -> !llvm.i32 llvm.return