diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h --- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h +++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h @@ -20,11 +20,13 @@ #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIROpsSupport.h" #include "flang/Optimizer/Dialect/FIRType.h" +#include "flang/Optimizer/Dialect/Support/FIRContext.h" #include "flang/Optimizer/Dialect/Support/KindMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "llvm/ADT/DenseMap.h" #include +#include namespace fir { class AbstractArrayBox; @@ -40,11 +42,15 @@ /// patterns. class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener { public: - explicit FirOpBuilder(mlir::Operation *op, const fir::KindMapping &kindMap) - : OpBuilder{op, /*listener=*/this}, kindMap{kindMap} {} - explicit FirOpBuilder(mlir::OpBuilder &builder, - const fir::KindMapping &kindMap) - : OpBuilder(builder), OpBuilder::Listener(), kindMap{kindMap} { + explicit FirOpBuilder(mlir::Operation *op, fir::KindMapping kindMap) + : OpBuilder{op, /*listener=*/this}, kindMap{std::move(kindMap)} {} + explicit FirOpBuilder(mlir::OpBuilder &builder, fir::KindMapping kindMap) + : OpBuilder(builder), OpBuilder::Listener(), kindMap{std::move(kindMap)} { + setListener(this); + } + explicit FirOpBuilder(mlir::OpBuilder &builder, mlir::ModuleOp mod) + : OpBuilder(builder), OpBuilder::Listener(), + kindMap{getKindMapping(mod)} { setListener(this); } @@ -55,6 +61,12 @@ setListener(this); } + FirOpBuilder(FirOpBuilder &&other) + : OpBuilder(other), OpBuilder::Listener(), + kindMap{std::move(other.kindMap)}, fastMathFlags{other.fastMathFlags} { + setListener(this); + } + /// Get the current Region of the insertion point. mlir::Region &getRegion() { return *getBlock()->getParent(); } @@ -457,7 +469,7 @@ /// based on the current attributes setting. void setCommonAttributes(mlir::Operation *op) const; - const KindMapping &kindMap; + KindMapping kindMap; /// FastMathFlags that need to be set for operations that support /// mlir::arith::FastMathAttr. diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp --- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp +++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp @@ -228,8 +228,7 @@ if (embox.getHost()) { // Create the thunk. auto module = embox->getParentOfType(); - fir::KindMapping kindMap = getKindMapping(module); - FirOpBuilder builder(rewriter, kindMap); + FirOpBuilder builder(rewriter, module); auto loc = embox.getLoc(); mlir::Type i8Ty = builder.getI8Type(); mlir::Type i8Ptr = builder.getRefType(i8Ty); diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -93,7 +93,7 @@ // We may need to call stacksave/stackrestore later, so // create the FuncOps beforehand. - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(mod)); + fir::FirOpBuilder builder(rewriter, mod); builder.setInsertionPointToStart(mod.getBody()); stackSaveFn = fir::factory::getLlvmStackSave(builder); stackRestoreFn = fir::factory::getLlvmStackRestore(builder); @@ -340,8 +340,7 @@ } mlir::Type funcPointerType = tuple.getType(0); mlir::Type lenType = tuple.getType(1); - fir::KindMapping kindMap = fir::getKindMapping(module); - fir::FirOpBuilder builder(*rewriter, kindMap); + fir::FirOpBuilder builder(*rewriter, module); auto [funcPointer, len] = fir::factory::extractCharacterProcedureTuple(builder, loc, oper); @@ -848,8 +847,7 @@ func.front().addArgument(trailingTys[fixup.second], loc); auto tupleType = oldArgTys[fixup.index - offset]; rewriter->setInsertionPointToStart(&func.front()); - fir::KindMapping kindMap = fir::getKindMapping(getModule()); - fir::FirOpBuilder builder(*rewriter, kindMap); + fir::FirOpBuilder builder(*rewriter, getModule()); auto tuple = fir::factory::createCharacterProcedureTuple( builder, loc, tupleType, newProcPointerArg, newLenArg); func.getArgument(fixup.index + 1).replaceAllUsesWith(tuple); diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp --- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp @@ -151,7 +151,7 @@ mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = asExpr->getLoc(); auto module = asExpr->getParentOfType(); - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); + fir::FirOpBuilder builder(rewriter, module); if (asExpr.isMove()) { // Move variable storage for the hlfir.expr buffer. mlir::Value bufferizedExpr = packageBufferizedExpr( @@ -179,7 +179,7 @@ mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = shapeOf.getLoc(); mlir::ModuleOp mod = shapeOf->getParentOfType(); - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(mod)); + fir::FirOpBuilder builder(rewriter, mod); mlir::Value shape; hlfir::Entity bufferizedExpr{getBufferizedExprStorage(adaptor.getExpr())}; diff --git a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp --- a/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/ConvertToFIR.cpp @@ -67,7 +67,7 @@ hlfir::Entity lhs(assignOp.getLhs()); hlfir::Entity rhs(assignOp.getRhs()); auto module = assignOp->getParentOfType(); - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); + fir::FirOpBuilder builder(rewriter, module); if (rhs.getType().isa()) { mlir::emitError(loc, "hlfir must be bufferized with --bufferize-hlfir " diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp --- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp @@ -922,7 +922,7 @@ static void lower(hlfir::OrderedAssignmentTreeOpInterface root, mlir::PatternRewriter &rewriter, hlfir::Schedule &schedule) { auto module = root->getParentOfType(); - fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); + fir::FirOpBuilder builder(rewriter, module); OrderedAssignmentRewriter assignmentRewriter(builder, root); for (auto &run : schedule) assignmentRewriter.lowerRun(run); diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -173,8 +173,7 @@ if (isResultBuiltinCPtr) { mlir::Value save = saveResult.getMemref(); auto module = op->template getParentOfType(); - fir::KindMapping kindMap = fir::getKindMapping(module); - FirOpBuilder builder(rewriter, kindMap); + FirOpBuilder builder(rewriter, module); mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr( builder, loc, save, result.getType()); rewriter.create(loc, newOp->getResult(0), saveAddr); @@ -226,8 +225,7 @@ if (fir::isa_builtin_cptr_type(returnedValue.getType())) { rewriter.eraseOp(load); auto module = ret->getParentOfType(); - fir::KindMapping kindMap = fir::getKindMapping(module); - FirOpBuilder builder(rewriter, kindMap); + FirOpBuilder builder(rewriter, module); mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr( builder, loc, resultStorage, returnedValue.getType()); mlir::Value retValue = rewriter.create( diff --git a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp --- a/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp +++ b/flang/lib/Optimizer/Transforms/ArrayValueCopy.cpp @@ -850,8 +850,7 @@ auto triples = sliceOp.getTriples(); const std::size_t tripleSize = triples.size(); auto module = arrLoad->getParentOfType(); - fir::KindMapping kindMap = getKindMapping(module); - FirOpBuilder builder(rewriter, kindMap); + FirOpBuilder builder(rewriter, module); size = builder.genExtentFromTriplet(loc, triples[tripleSize - 3], triples[tripleSize - 2], triples[tripleSize - 1], idxTy); @@ -937,8 +936,7 @@ assert(seqTy && seqTy.isa()); const auto dimension = seqTy.cast().getDimension(); auto module = load->getParentOfType(); - fir::KindMapping kindMap = getKindMapping(module); - FirOpBuilder builder(rewriter, kindMap); + FirOpBuilder builder(rewriter, module); auto typeparams = getTypeParamsIfRawData(loc, builder, load, alloc.getType()); mlir::Value result = rewriter.create( loc, eleTy, alloc, shape, slice, @@ -1002,8 +1000,7 @@ // Reverse the indices so they are in column-major order. std::reverse(indices.begin(), indices.end()); auto module = arrLoad->getParentOfType(); - fir::KindMapping kindMap = getKindMapping(module); - FirOpBuilder builder(rewriter, kindMap); + FirOpBuilder builder(rewriter, module); auto fromAddr = rewriter.create( loc, getEleTy(src.getType()), src, shapeOp, CopyIn && copyUsingSlice ? sliceOp : mlir::Value{}, @@ -1041,8 +1038,7 @@ if (auto charTy = eleTy.dyn_cast()) { assert(load.getMemref().getType().isa()); auto module = load->getParentOfType(); - fir::KindMapping kindMap = getKindMapping(module); - FirOpBuilder builder(rewriter, kindMap); + FirOpBuilder builder(rewriter, module); return {getCharacterLen(loc, builder, load, charTy)}; } TODO(loc, "unhandled dynamic type parameters"); @@ -1094,14 +1090,12 @@ loc, fir::BoxType::get(allocmem.getType()), allocmem, shape, /*slice=*/mlir::Value{}, typeParams); auto module = load->getParentOfType(); - fir::KindMapping kindMap = getKindMapping(module); - FirOpBuilder builder(rewriter, kindMap); + FirOpBuilder builder(rewriter, module); runtime::genDerivedTypeInitialize(builder, loc, box); // Any allocatable component that may have been allocated must be // deallocated during the clean-up. auto cleanup = [=](mlir::PatternRewriter &r) { - fir::KindMapping kindMap = getKindMapping(module); - FirOpBuilder builder(r, kindMap); + FirOpBuilder builder(r, module); runtime::genDerivedTypeDestroy(builder, loc, box); r.create(loc, allocmem); }; diff --git a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp --- a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp +++ b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp @@ -178,7 +178,7 @@ return; // If we get here, there are loops to process. - fir::FirOpBuilder builder{module, kindMap}; + fir::FirOpBuilder builder{module, std::move(kindMap)}; mlir::Location loc = builder.getUnknownLoc(); mlir::IndexType idxTy = builder.getIndexType(); diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp --- a/flang/lib/Optimizer/Transforms/StackArrays.cpp +++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp @@ -680,8 +680,7 @@ fir::AllocMemOp &oldAlloc, mlir::PatternRewriter &rewriter) const { auto oldPoint = rewriter.saveInsertionPoint(); auto mod = oldAlloc->getParentOfType(); - fir::KindMapping kindMap = fir::getKindMapping(mod); - fir::FirOpBuilder builder{rewriter, kindMap}; + fir::FirOpBuilder builder{rewriter, mod}; mlir::func::FuncOp stackSaveFn = fir::factory::getLlvmStackSave(builder); mlir::SymbolRefAttr stackSaveSym =