diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -35,13 +35,22 @@
 
 /// Extends the MLIR OpBuilder to provide methods for building common FIR
 /// patterns.
-class FirOpBuilder : public mlir::OpBuilder {
+class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
 public:
   explicit FirOpBuilder(mlir::Operation *op, const fir::KindMapping &kindMap)
-      : OpBuilder{op}, kindMap{kindMap} {}
+      : OpBuilder{op, /*listener=*/this}, kindMap{kindMap} {}
   explicit FirOpBuilder(mlir::OpBuilder &builder,
                         const fir::KindMapping &kindMap)
-      : OpBuilder{builder}, kindMap{kindMap} {}
+      : OpBuilder{builder}, kindMap{kindMap} {
+    setListener(this);
+  }
+
+  // The listener self-reference has to be updated in case of copy-construction.
+  FirOpBuilder(const FirOpBuilder &other)
+      : OpBuilder{other}, kindMap{other.kindMap},
+        fastMathFlags{other.fastMathFlags} {
+    setListener(this);
+  }
 
   /// Get the current Region of the insertion point.
   mlir::Region &getRegion() { return *getBlock()->getParent(); }
@@ -393,11 +402,31 @@
                                    mlir::Value ub, mlir::Value step,
                                    mlir::Type type);
 
+  /// Set the FastMathFlags to attach to every operation implementing
+  /// mlir::arith::ArithFastMathInterface that this builder inserts from now
+  /// on. FastMathFlags::none (the default) leaves operations untouched.
+  void setFastMathFlags(mlir::arith::FastMathFlags flags) {
+    fastMathFlags = flags;
+  }
+
   /// Dump the current function. (debug)
   LLVM_DUMP_METHOD void dumpFunc();
 
 private:
+  /// Set attributes (e.g. FastMathAttr) on \p op according to this
+  /// builder's current settings.
+  void setCommonAttributes(mlir::Operation *op) const;
+
+  /// OpBuilder::Listener hook run for each operation this builder inserts.
+  void notifyOperationInserted(mlir::Operation *op) override {
+    setCommonAttributes(op);
+  }
+
   const KindMapping &kindMap;
+
+  /// FastMathFlags attached to new operations implementing
+  /// mlir::arith::ArithFastMathInterface (none = leave them unchanged).
+  mlir::arith::FastMathFlags fastMathFlags{};
 };
 
 } // namespace fir
diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
--- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp
+++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp
@@ -571,6 +571,21 @@
   return create<mlir::arith::SelectOp>(loc, cmp, div, zero);
 }
 
+void fir::FirOpBuilder::setCommonAttributes(mlir::Operation *op) const {
+  // Check the flags first: this runs for every operation the builder inserts,
+  // and with the default (none) there is nothing to attach.
+  if (fastMathFlags == mlir::arith::FastMathFlags::none)
+    return;
+  auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(op);
+  if (!fmi)
+    return;
+  // TODO: use fmi.setFastMathFlagsAttr() after D137114 is merged.
+  //       For now set the attribute by the name.
+  op->setAttr(fmi.getFastMathAttrName(),
+              mlir::arith::FastMathFlagsAttr::get(op->getContext(),
+                                                  fastMathFlags));
+}
+
 //===--------------------------------------------------------------------===//
 // ExtendedValue inquiry helper implementation
 //===--------------------------------------------------------------------===//
diff --git a/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp b/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp
--- a/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp
+++ b/flang/unittests/Optimizer/Builder/FIRBuilderTest.cpp
@@ -528,3 +528,58 @@
     EXPECT_TRUE(fir::isDerivedWithLenParameters(array));
   }
 }
+
+TEST_F(FIRBuilderTest, genArithFastMath) {
+  auto builder = getBuilder();
+  auto ctx = builder.getContext();
+  auto loc = builder.getUnknownLoc();
+
+  auto realTy = mlir::FloatType::getF32(ctx);
+  auto arg = builder.create<fir::UndefOp>(loc, realTy);
+
+  // Test that FastMathFlags is 'none' by default.
+  mlir::Operation *op1 = builder.create<mlir::arith::AddFOp>(loc, arg, arg);
+  auto op1_fmi =
+      mlir::dyn_cast_or_null<mlir::arith::ArithFastMathInterface>(op1);
+  ASSERT_TRUE(op1_fmi);
+  auto op1_fmf = op1_fmi.getFastMathFlagsAttr().getValue();
+  EXPECT_EQ(op1_fmf, arith::FastMathFlags::none);
+
+  // Test that the builder is copied properly.
+  fir::FirOpBuilder builder_copy(builder);
+
+  arith::FastMathFlags FMF1 =
+      arith::FastMathFlags::contract | arith::FastMathFlags::reassoc;
+  builder.setFastMathFlags(FMF1);
+  arith::FastMathFlags FMF2 =
+      arith::FastMathFlags::nnan | arith::FastMathFlags::ninf;
+  builder_copy.setFastMathFlags(FMF2);
+
+  // Modifying FastMathFlags for the copy must not affect the original builder.
+  mlir::Operation *op2 = builder.create<mlir::arith::AddFOp>(loc, arg, arg);
+  auto op2_fmi =
+      mlir::dyn_cast_or_null<mlir::arith::ArithFastMathInterface>(op2);
+  ASSERT_TRUE(op2_fmi);
+  auto op2_fmf = op2_fmi.getFastMathFlagsAttr().getValue();
+  EXPECT_EQ(op2_fmf, FMF1);
+
+  // Modifying FastMathFlags for the original builder must not affect the copy.
+  mlir::Operation *op3 =
+      builder_copy.create<mlir::arith::AddFOp>(loc, arg, arg);
+  auto op3_fmi =
+      mlir::dyn_cast_or_null<mlir::arith::ArithFastMathInterface>(op3);
+  ASSERT_TRUE(op3_fmi);
+  auto op3_fmf = op3_fmi.getFastMathFlagsAttr().getValue();
+  EXPECT_EQ(op3_fmf, FMF2);
+
+  // Test that the builder copy inherits FastMathFlags from the original.
+  fir::FirOpBuilder builder_copy2(builder);
+
+  mlir::Operation *op4 =
+      builder_copy2.create<mlir::arith::AddFOp>(loc, arg, arg);
+  auto op4_fmi =
+      mlir::dyn_cast_or_null<mlir::arith::ArithFastMathInterface>(op4);
+  ASSERT_TRUE(op4_fmi);
+  auto op4_fmf = op4_fmi.getFastMathFlagsAttr().getValue();
+  EXPECT_EQ(op4_fmf, FMF1);
+}