diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -308,6 +308,11 @@
   let summary = "Lower the operations from the vector dialect into the LLVM "
                 "dialect";
   let constructor = "mlir::createConvertVectorToLLVMPass()";
+  let options = [
+    Option<"reassociateFPReductions", "reassociate-fp-reductions",
+           "bool", /*default=*/"false",
+           "Allows llvm to reassociate floating-point reductions for speed">
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -23,8 +23,9 @@
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
 
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
-void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                            OwningRewritePatternList &patterns);
+void populateVectorToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool reassociateFPReductions = false);
 
 /// Create a pass to convert vector operations to the LLVMIR dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass();
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -214,10 +214,30 @@
     : LLVM_OneResultIntrOp<"experimental.vector.reduce." # mnem,
                            [], [0], []>,
       Arguments<(ins LLVM_Type)>;
 
-// LLVM vector reduction over a single vector, with an initial value.
+// LLVM vector reduction over a single vector, with an initial value,
+// and with permission to reassociate the reduction operations.
 class LLVM_VectorReductionV2<string mnem>
-    : LLVM_OneResultIntrOp<"experimental.vector.reduce.v2." # mnem,
-                           [0], [1], []>,
-      Arguments<(ins LLVM_Type, LLVM_Type)>;
+    : LLVM_OpBase<LLVM_Dialect, "intr.experimental.vector.reduce.v2." # mnem,
+                  [NoSideEffect]>,
+      Results<(outs LLVM_Type:$res)>,
+      Arguments<(ins LLVM_Type, LLVM_Type,
+                 DefaultValuedAttr<BoolAttr, "false">:$reassoc)> {
+  let llvmBuilder = [{
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Function *fn = llvm::Intrinsic::getDeclaration(
+        module,
+        llvm::Intrinsic::experimental_vector_reduce_v2_}] # mnem # [{,
+        { }] # StrJoin<!listconcat(ListIntSubst<LLVM_IntrPatterns.result, [0]>.lst,
+                 ListIntSubst<LLVM_IntrPatterns.operand, [1]>.lst)>.result # [{
+        });
+    auto operands = lookupValues(opInst.getOperands());
+    llvm::FastMathFlags origFM = builder.getFastMathFlags();
+    llvm::FastMathFlags tempFM = origFM;
+    tempFM.setAllowReassoc($reassoc);
+    builder.setFastMathFlags(tempFM);  // set fastmath flag
+    $res = builder.CreateCall(fn, operands);
+    builder.setFastMathFlags(origFM);  // restore fastmath flag
+  }];
+}
 
 #endif // LLVMIR_OP_BASE
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -255,9 +255,11 @@
 class VectorReductionOpConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorReductionOpConversion(MLIRContext *context,
-                                       LLVMTypeConverter &typeConverter)
+                                       LLVMTypeConverter &typeConverter,
+                                       bool reassociateFP)
       : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context,
-                             typeConverter) {}
+                             typeConverter),
+        reassociateFPReductions(reassociateFP) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -302,7 +304,8 @@
                       op->getLoc(), llvmType, rewriter.getZeroAttr(eltType));
       rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fadd>(
-          op, llvmType, acc, operands[0]);
+          op, llvmType, acc, operands[0],
+          rewriter.getBoolAttr(reassociateFPReductions));
     } else if (kind == "mul") {
       // Optional accumulator (or one).
       Value acc = operands.size() > 1
@@ -311,7 +314,8 @@
                       op->getLoc(), llvmType, rewriter.getFloatAttr(eltType, 1.0));
       rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_v2_fmul>(
-          op, llvmType, acc, operands[0]);
+          op, llvmType, acc, operands[0],
+          rewriter.getBoolAttr(reassociateFPReductions));
     } else if (kind == "min")
       rewriter.replaceOpWithNewOp<LLVM::experimental_vector_reduce_fmin>(
           op, llvmType, operands[0]);
@@ -324,6 +328,9 @@
     }
     return failure();
   }
+
+private:
+  const bool reassociateFPReductions;
 };
 
 class VectorShuffleOpConversion : public ConvertToLLVMPattern {
@@ -1139,16 +1146,18 @@
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
+    bool reassociateFPReductions) {
   MLIRContext *ctx = converter.getDialect()->getContext();
   // clang-format off
   patterns.insert<VectorFMAOpNDRewritePattern,
                   VectorInsertStridedSliceOpDifferentRankRewritePattern,
                   VectorInsertStridedSliceOpSameRankRewritePattern,
                   VectorStridedSliceOpConversion>(ctx);
+  patterns.insert<VectorReductionOpConversion>(
+      ctx, converter, reassociateFPReductions);
   patterns
-      .insert<VectorReductionOpConversion,
-              VectorShuffleOpConversion,
+      .insert<VectorShuffleOpConversion,
               VectorExtractElementOpConversion,
               VectorExtractOpConversion,
               VectorFMAOp1DConversion,
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-llvm='reassociate-fp-reductions=1' | FileCheck %s --check-prefix=REASSOC
+
+//
+// CHECK-LABEL: llvm.func @reduce_add_f32(
+// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">)
+// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
+// CHECK-SAME: {reassoc = false} : (!llvm.float, !llvm<"<16 x float>">) -> !llvm.float
+// CHECK: llvm.return %[[V]] : !llvm.float
+//
+// REASSOC-LABEL: llvm.func @reduce_add_f32(
+// REASSOC-SAME: %[[A:.*]]: !llvm<"<16 x float>">)
+// REASSOC: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
+// REASSOC: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
+// REASSOC-SAME: {reassoc = true} : (!llvm.float, !llvm<"<16 x float>">) -> !llvm.float
+// REASSOC: llvm.return %[[V]] : !llvm.float
+//
+func @reduce_add_f32(%arg0: vector<16xf32>) -> f32 {
+  %0 = vector.reduction "add", %arg0 : vector<16xf32> into f32
+  return %0 : f32
+}
+
+//
+// CHECK-LABEL: llvm.func @reduce_mul_f32(
+// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">)
+// CHECK: %[[C:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fmul"(%[[C]], %[[A]])
+// CHECK-SAME: {reassoc = false} : (!llvm.float, !llvm<"<16 x float>">) -> !llvm.float
+// CHECK: llvm.return %[[V]] : !llvm.float
+//
+// REASSOC-LABEL: llvm.func @reduce_mul_f32(
+// REASSOC-SAME: %[[A:.*]]: !llvm<"<16 x float>">)
+// REASSOC: %[[C:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
+// REASSOC: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fmul"(%[[C]], %[[A]])
+// REASSOC-SAME: {reassoc = true} : (!llvm.float, !llvm<"<16 x float>">) -> !llvm.float
+// REASSOC: llvm.return %[[V]] : !llvm.float
+//
+func @reduce_mul_f32(%arg0: vector<16xf32>) -> f32 {
+  %0 = vector.reduction "mul", %arg0 : vector<16xf32> into f32
+  return %0 : f32
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -721,6 +721,7 @@
 // CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">)
 // CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
 // CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
+// CHECK-SAME: {reassoc = false} : (!llvm.float, !llvm<"<16 x float>">) -> !llvm.float
 // CHECK: llvm.return %[[V]] : !llvm.float
 
 func @reduce_f64(%arg0: vector<16xf64>) -> f64 {
@@ -731,6 +732,7 @@
 // CHECK-SAME: %[[A:.*]]: !llvm<"<16 x double>">)
 // CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f64) : !llvm.double
 // CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.v2.fadd"(%[[C]], %[[A]])
+// CHECK-SAME: {reassoc = false} : (!llvm.double, !llvm<"<16 x double>">) -> !llvm.double
 // CHECK: llvm.return %[[V]] : !llvm.double
 
 func @reduce_i32(%arg0: vector<16xi32>) -> i32 {
diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir
--- a/mlir/test/Target/llvmir-intrinsics.mlir
+++ b/mlir/test/Target/llvmir-intrinsics.mlir
@@ -161,6 +161,10 @@
   "llvm.intr.experimental.vector.reduce.v2.fadd"(%arg0, %arg1) : (!llvm.float, !llvm<"<8 x float>">) -> !llvm.float
   // CHECK: call float @llvm.experimental.vector.reduce.v2.fmul.f32.v8f32
   "llvm.intr.experimental.vector.reduce.v2.fmul"(%arg0, %arg1) : (!llvm.float, !llvm<"<8 x float>">) -> !llvm.float
+  // CHECK: call reassoc float @llvm.experimental.vector.reduce.v2.fadd.f32.v8f32
+  "llvm.intr.experimental.vector.reduce.v2.fadd"(%arg0, %arg1) {reassoc = true} : (!llvm.float, !llvm<"<8 x float>">) -> !llvm.float
+  // CHECK: call reassoc float @llvm.experimental.vector.reduce.v2.fmul.f32.v8f32
+  "llvm.intr.experimental.vector.reduce.v2.fmul"(%arg0, %arg1) {reassoc = true} : (!llvm.float, !llvm<"<8 x float>">) -> !llvm.float
   // CHECK: call i32 @llvm.experimental.vector.reduce.xor.v8i32
   "llvm.intr.experimental.vector.reduce.xor"(%arg2) : (!llvm<"<8 x i32>">) -> !llvm.i32
   llvm.return
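Note on usage: besides the new `reassociate-fp-reductions=1` pass option exercised by the REASSOC RUN line above, downstream code can request reassociation directly when populating patterns, since this patch adds the `reassociateFPReductions` parameter to `populateVectorToLLVMConversionPatterns`. The sketch below is illustrative only: the helper name `enableReassocVectorLowering` is hypothetical, and the exact includes needed for `LLVMTypeConverter` and `OwningRewritePatternList` may differ by MLIR revision.

```cpp
// Hypothetical helper (not part of this patch) showing how a client would
// wire the new flag into its own conversion pattern set. Passing `true`
// lowers f32/f64 vector.reduction ops to the
// llvm.intr.experimental.vector.reduce.v2.* intrinsics with {reassoc = true},
// which the llvmBuilder above translates into the `reassoc` fast-math flag
// on the generated call.
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

using namespace mlir;

static void enableReassocVectorLowering(LLVMTypeConverter &converter,
                                        OwningRewritePatternList &patterns) {
  populateVectorToLLVMConversionPatterns(converter, patterns,
                                         /*reassociateFPReductions=*/true);
}
```

The default (`false`) preserves the strict left-to-right reduction order, which is why the unchanged CHECK prefixes above now expect `{reassoc = false}`.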