diff --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h --- a/flang/include/flang/Optimizer/Transforms/Passes.h +++ b/flang/include/flang/Optimizer/Transforms/Passes.h @@ -42,6 +42,7 @@ #define GEN_PASS_DECL_MEMORYALLOCATIONOPT #define GEN_PASS_DECL_SIMPLIFYREGIONLITE #define GEN_PASS_DECL_ALGEBRAICSIMPLIFICATION +#define GEN_PASS_DECL_POLYMORPHICOPCONVERSION #include "flang/Optimizer/Transforms/Passes.h.inc" std::unique_ptr createAbstractResultOnFuncOptPass(); @@ -68,6 +69,7 @@ std::unique_ptr createAlgebraicSimplificationPass(); std::unique_ptr createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config); +std::unique_ptr createPolymorphicOpConversionPass(); // declarative passes #define GEN_PASS_REGISTRATION diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -271,4 +271,18 @@ let constructor = "::fir::createAlgebraicSimplificationPass()"; } +def PolymorphicOpConversion : Pass<"fir-polymorphic-op", "::mlir::func::FuncOp"> { + let summary = + "Simplify operations on polymorphic types"; + let description = [{ + This pass breaks up the lowering of operations on polymorphic types by + introducing an intermediate FIR level that simplifies code geneation. + }]; + let constructor = "::fir::createPolymorphicOpConversionPass()"; + let dependentDialects = [ + "fir::FIROpsDialect", "mlir::func::FuncDialect" + ]; +} + + #endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc --- a/flang/include/flang/Tools/CLOptions.inc +++ b/flang/include/flang/Tools/CLOptions.inc @@ -199,6 +199,9 @@ pm.addPass(fir::createSimplifyRegionLitePass()); pm.addPass(mlir::createCSEPass()); + // Polymorphic types + pm.addPass(fir::createPolymorphicOpConversionPass()); + // convert control flow to CFG form fir::addCfgConversionPass(pm); pm.addPass(mlir::createConvertSCFToCFPass()); diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt --- a/flang/lib/Optimizer/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt @@ -14,6 +14,7 @@ AlgebraicSimplification.cpp SimplifyIntrinsics.cpp AddDebugFoundation.cpp + PolymorphicOpConversion.cpp DEPENDS FIRBuilder diff --git a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp --- a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp +++ b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp @@ -22,7 +22,6 @@ #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallSet.h" #include "llvm/Support/CommandLine.h" -#include namespace fir { #define GEN_PASS_DEF_CFGCONVERSION @@ -308,278 +307,20 @@ } }; -/// SelectTypeOp converted to an if-then-else chain -/// -/// This lowers the test conditions to calls into the runtime. -class CfgSelectTypeConv : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - CfgSelectTypeConv(mlir::MLIRContext *ctx, std::mutex *moduleMutex) - : mlir::OpConversionPattern(ctx), - moduleMutex(moduleMutex) {} - - mlir::LogicalResult - matchAndRewrite(fir::SelectTypeOp selectType, OpAdaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto operands = adaptor.getOperands(); - auto typeGuards = selectType.getCases(); - unsigned typeGuardNum = typeGuards.size(); - auto selector = selectType.getSelector(); - auto loc = selectType.getLoc(); - auto mod = selectType.getOperation()->getParentOfType(); - fir::KindMapping kindMap = fir::getKindMapping(mod); - - // Order type guards so the condition and branches are done to respect the - // Execution of SELECT TYPE construct as described in the Fortran 2018 - // standard 11.1.11.2 point 4. - // 1. If a TYPE IS type guard statement matches the selector, the block - // following that statement is executed. - // 2. Otherwise, if exactly one CLASS IS type guard statement matches the - // selector, the block following that statement is executed. - // 3. Otherwise, if several CLASS IS type guard statements match the - // selector, one of these statements will inevitably specify a type that - // is an extension of all the types specified in the others; the block - // following that statement is executed. - // 4. Otherwise, if there is a CLASS DEFAULT type guard statement, the block - // following that statement is executed. - // 5. Otherwise, no block is executed. - - llvm::SmallVector orderedTypeGuards; - llvm::SmallVector orderedClassIsGuards; - unsigned defaultGuard = typeGuardNum - 1; - - // The following loop go through the type guards in the fir.select_type - // operation and sort them into two lists. - // - All the TYPE IS type guard are added in order to the orderedTypeGuards - // list. This list is used at the end to generate the if-then-else ladder. - // - CLASS IS type guard are added in a separate list. If a CLASS IS type - // guard type extends a type already present, the type guard is inserted - // before in the list to respect point 3. above. Otherwise it is just - // added in order at the end. - for (unsigned t = 0; t < typeGuardNum; ++t) { - if (auto a = typeGuards[t].dyn_cast()) { - orderedTypeGuards.push_back(t); - continue; - } - - if (auto a = typeGuards[t].dyn_cast()) { - if (auto recTy = a.getType().dyn_cast()) { - auto dt = mod.lookupSymbol(recTy.getName()); - assert(dt && "dispatch table not found"); - llvm::SmallSet ancestors = - collectAncestors(dt, mod); - if (!ancestors.empty()) { - auto it = orderedClassIsGuards.begin(); - while (it != orderedClassIsGuards.end()) { - fir::SubclassAttr sAttr = - typeGuards[*it].dyn_cast(); - if (auto ty = sAttr.getType().dyn_cast()) { - if (ancestors.contains(ty.getName())) - break; - } - ++it; - } - if (it != orderedClassIsGuards.end()) { - // Parent type is present so place it before. - orderedClassIsGuards.insert(it, t); - continue; - } - } - } - orderedClassIsGuards.push_back(t); - } - } - orderedTypeGuards.append(orderedClassIsGuards); - orderedTypeGuards.push_back(defaultGuard); - assert(orderedTypeGuards.size() == typeGuardNum && - "ordered type guard size doesn't match number of type guards"); - - for (unsigned idx : orderedTypeGuards) { - auto *dest = selectType.getSuccessor(idx); - std::optional destOps = - selectType.getSuccessorOperands(operands, idx); - if (typeGuards[idx].dyn_cast()) - rewriter.replaceOpWithNewOp(selectType, dest); - else if (mlir::failed(genTypeLadderStep(loc, selector, typeGuards[idx], - dest, destOps, mod, rewriter, - kindMap))) - return mlir::failure(); - } - return mlir::success(); - } - - llvm::SmallSet - collectAncestors(fir::DispatchTableOp dt, mlir::ModuleOp mod) const { - llvm::SmallSet ancestors; - if (!dt.getParent().has_value()) - return ancestors; - while (dt.getParent().has_value()) { - ancestors.insert(*dt.getParent()); - dt = mod.lookupSymbol(*dt.getParent()); - } - return ancestors; - } - - // Generate comparison of type descriptor addresses. - mlir::Value genTypeDescCompare(mlir::Location loc, mlir::Value selector, - mlir::Type ty, mlir::ModuleOp mod, - mlir::PatternRewriter &rewriter) const { - assert(ty.isa() && "expect fir.record type"); - fir::RecordType recTy = ty.dyn_cast(); - std::string typeDescName = - fir::NameUniquer::getTypeDescriptorName(recTy.getName()); - auto typeDescGlobal = mod.lookupSymbol(typeDescName); - if (!typeDescGlobal) - return {}; - auto typeDescAddr = rewriter.create( - loc, fir::ReferenceType::get(typeDescGlobal.getType()), - typeDescGlobal.getSymbol()); - auto intPtrTy = rewriter.getIndexType(); - mlir::Type tdescType = - fir::TypeDescType::get(mlir::NoneType::get(rewriter.getContext())); - mlir::Value selectorTdescAddr = - rewriter.create(loc, tdescType, selector); - auto typeDescInt = - rewriter.create(loc, intPtrTy, typeDescAddr); - auto selectorTdescInt = - rewriter.create(loc, intPtrTy, selectorTdescAddr); - return rewriter.create( - loc, mlir::arith::CmpIPredicate::eq, typeDescInt, selectorTdescInt); - } - - static int getTypeCode(mlir::Type ty, fir::KindMapping &kindMap) { - if (auto intTy = ty.dyn_cast()) - return fir::integerBitsToTypeCode(intTy.getWidth()); - if (auto floatTy = ty.dyn_cast()) - return fir::realBitsToTypeCode(floatTy.getWidth()); - if (auto logicalTy = ty.dyn_cast()) - return fir::logicalBitsToTypeCode( - kindMap.getLogicalBitsize(logicalTy.getFKind())); - if (fir::isa_complex(ty)) { - if (auto cmplxTy = ty.dyn_cast()) - return fir::complexBitsToTypeCode( - cmplxTy.getElementType().cast().getWidth()); - auto cmplxTy = ty.cast(); - return fir::complexBitsToTypeCode( - kindMap.getRealBitsize(cmplxTy.getFKind())); - } - if (auto charTy = ty.dyn_cast()) - return fir::characterBitsToTypeCode( - kindMap.getCharacterBitsize(charTy.getFKind())); - return 0; - } - - mlir::LogicalResult genTypeLadderStep(mlir::Location loc, - mlir::Value selector, - mlir::Attribute attr, mlir::Block *dest, - std::optional destOps, - mlir::ModuleOp mod, - mlir::PatternRewriter &rewriter, - fir::KindMapping &kindMap) const { - mlir::Value cmp; - // TYPE IS type guard comparison are all done inlined. - if (auto a = attr.dyn_cast()) { - if (fir::isa_trivial(a.getType()) || - a.getType().isa()) { - // For type guard statement with Intrinsic type spec the type code of - // the descriptor is compared. - int code = getTypeCode(a.getType(), kindMap); - if (code == 0) - return mlir::emitError(loc) - << "type code unavailable for " << a.getType(); - mlir::Value typeCode = rewriter.create( - loc, rewriter.getI8IntegerAttr(code)); - mlir::Value selectorTypeCode = rewriter.create( - loc, rewriter.getI8Type(), selector); - cmp = rewriter.create( - loc, mlir::arith::CmpIPredicate::eq, selectorTypeCode, typeCode); - } else { - // Flang inline the kind parameter in the type descriptor so we can - // directly check if the type descriptor addresses are identical for - // the TYPE IS type guard statement. - mlir::Value res = - genTypeDescCompare(loc, selector, a.getType(), mod, rewriter); - if (!res) - return mlir::failure(); - cmp = res; - } - // CLASS IS type guard statement is done with a runtime call. - } else if (auto a = attr.dyn_cast()) { - // Retrieve the type descriptor from the type guard statement record type. - assert(a.getType().isa() && "expect fir.record type"); - fir::RecordType recTy = a.getType().dyn_cast(); - std::string typeDescName = - fir::NameUniquer::getTypeDescriptorName(recTy.getName()); - auto typeDescGlobal = mod.lookupSymbol(typeDescName); - auto typeDescAddr = rewriter.create( - loc, fir::ReferenceType::get(typeDescGlobal.getType()), - typeDescGlobal.getSymbol()); - mlir::Type typeDescTy = ReferenceType::get(rewriter.getNoneType()); - mlir::Value typeDesc = - rewriter.create(loc, typeDescTy, typeDescAddr); - - // Prepare the selector descriptor for the runtime call. - mlir::Type descNoneTy = fir::BoxType::get(rewriter.getNoneType()); - mlir::Value descSelector = - rewriter.create(loc, descNoneTy, selector); - - // Generate runtime call. - llvm::StringRef fctName = RTNAME_STRING(ClassIs); - mlir::func::FuncOp callee; - { - // Since conversion is done in parallel for each fir.select_type - // operation, the runtime function insertion must be threadsafe. - std::lock_guard lock(*moduleMutex); - callee = - fir::createFuncOp(rewriter.getUnknownLoc(), mod, fctName, - rewriter.getFunctionType({descNoneTy, typeDescTy}, - rewriter.getI1Type())); - } - cmp = rewriter - .create(loc, callee, - mlir::ValueRange{descSelector, typeDesc}) - .getResult(0); - } - - auto *thisBlock = rewriter.getInsertionBlock(); - auto *newBlock = - rewriter.createBlock(dest->getParent(), mlir::Region::iterator(dest)); - rewriter.setInsertionPointToEnd(thisBlock); - if (destOps.has_value()) - rewriter.create(loc, cmp, dest, destOps.value(), - newBlock, std::nullopt); - else - rewriter.create(loc, cmp, dest, newBlock); - rewriter.setInsertionPointToEnd(newBlock); - return mlir::success(); - } - -private: - // Mutex used to guard insertion of mlir::func::FuncOp in the module. - std::mutex *moduleMutex; -}; - /// Convert FIR structured control flow ops to CFG ops. class CfgConversion : public fir::impl::CFGConversionBase { public: - mlir::LogicalResult initialize(mlir::MLIRContext *ctx) override { - moduleMutex = new std::mutex(); - return mlir::success(); - } - void runOnOperation() override { auto *context = &getContext(); mlir::RewritePatternSet patterns(context); patterns.insert( context, forceLoopToExecuteOnce); - patterns.insert(context, moduleMutex); mlir::ConversionTarget target(*context); target.addLegalDialect(); // apply the patterns - target.addIllegalOp(); + target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) { @@ -588,9 +329,6 @@ signalPassFailure(); } } - -private: - std::mutex *moduleMutex; }; } // namespace diff --git a/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp b/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp new file mode 100644 --- /dev/null +++ b/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp @@ -0,0 +1,346 @@ +//===-- PolymorphicOpConversion.cpp ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIROpsSupport.h" +#include "flang/Optimizer/Support/FIRContext.h" +#include "flang/Optimizer/Support/InternalNames.h" +#include "flang/Optimizer/Support/KindMapping.h" +#include "flang/Optimizer/Support/TypeCode.h" +#include "flang/Optimizer/Transforms/Passes.h" +#include "flang/Runtime/derived-api.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/Support/CommandLine.h" +#include + +namespace fir { +#define GEN_PASS_DEF_POLYMORPHICOPCONVERSION +#include "flang/Optimizer/Transforms/Passes.h.inc" +} // namespace fir + +using namespace fir; +using namespace mlir; + +namespace { + +/// SelectTypeOp converted to an if-then-else chain +/// +/// This lowers the test conditions to calls into the runtime. +class SelectTypeConv : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + SelectTypeConv(mlir::MLIRContext *ctx, std::mutex *moduleMutex) + : mlir::OpConversionPattern(ctx), + moduleMutex(moduleMutex) {} + + mlir::LogicalResult + matchAndRewrite(fir::SelectTypeOp selectType, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override; + +private: + // Generate comparison of type descriptor addresses. + mlir::Value genTypeDescCompare(mlir::Location loc, mlir::Value selector, + mlir::Type ty, mlir::ModuleOp mod, + mlir::PatternRewriter &rewriter) const; + + static int getTypeCode(mlir::Type ty, fir::KindMapping &kindMap); + + mlir::LogicalResult genTypeLadderStep(mlir::Location loc, + mlir::Value selector, + mlir::Attribute attr, mlir::Block *dest, + std::optional destOps, + mlir::ModuleOp mod, + mlir::PatternRewriter &rewriter, + fir::KindMapping &kindMap) const; + + llvm::SmallSet collectAncestors(fir::DispatchTableOp dt, + mlir::ModuleOp mod) const; + + // Mutex used to guard insertion of mlir::func::FuncOp in the module. + std::mutex *moduleMutex; +}; + +/// Convert FIR structured control flow ops to CFG ops. +class PolymorphicOpConversion + : public fir::impl::PolymorphicOpConversionBase { +public: + mlir::LogicalResult initialize(mlir::MLIRContext *ctx) override { + moduleMutex = new std::mutex(); + return mlir::success(); + } + + void runOnOperation() override { + auto *context = &getContext(); + mlir::RewritePatternSet patterns(context); + patterns.insert(context, moduleMutex); + mlir::ConversionTarget target(*context); + target.addLegalDialect(); + + // apply the patterns + target.addIllegalOp(); + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + mlir::emitError(mlir::UnknownLoc::get(context), + "error in converting to CFG\n"); + signalPassFailure(); + } + } + +private: + std::mutex *moduleMutex; +}; +} // namespace + +mlir::LogicalResult SelectTypeConv::matchAndRewrite( + fir::SelectTypeOp selectType, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto operands = adaptor.getOperands(); + auto typeGuards = selectType.getCases(); + unsigned typeGuardNum = typeGuards.size(); + auto selector = selectType.getSelector(); + auto loc = selectType.getLoc(); + auto mod = selectType.getOperation()->getParentOfType(); + fir::KindMapping kindMap = fir::getKindMapping(mod); + + // Order type guards so the condition and branches are done to respect the + // Execution of SELECT TYPE construct as described in the Fortran 2018 + // standard 11.1.11.2 point 4. + // 1. If a TYPE IS type guard statement matches the selector, the block + // following that statement is executed. + // 2. Otherwise, if exactly one CLASS IS type guard statement matches the + // selector, the block following that statement is executed. + // 3. Otherwise, if several CLASS IS type guard statements match the + // selector, one of these statements will inevitably specify a type that + // is an extension of all the types specified in the others; the block + // following that statement is executed. + // 4. Otherwise, if there is a CLASS DEFAULT type guard statement, the block + // following that statement is executed. + // 5. Otherwise, no block is executed. + + llvm::SmallVector orderedTypeGuards; + llvm::SmallVector orderedClassIsGuards; + unsigned defaultGuard = typeGuardNum - 1; + + // The following loop go through the type guards in the fir.select_type + // operation and sort them into two lists. + // - All the TYPE IS type guard are added in order to the orderedTypeGuards + // list. This list is used at the end to generate the if-then-else ladder. + // - CLASS IS type guard are added in a separate list. If a CLASS IS type + // guard type extends a type already present, the type guard is inserted + // before in the list to respect point 3. above. Otherwise it is just + // added in order at the end. + for (unsigned t = 0; t < typeGuardNum; ++t) { + if (auto a = typeGuards[t].dyn_cast()) { + orderedTypeGuards.push_back(t); + continue; + } + + if (auto a = typeGuards[t].dyn_cast()) { + if (auto recTy = a.getType().dyn_cast()) { + auto dt = mod.lookupSymbol(recTy.getName()); + assert(dt && "dispatch table not found"); + llvm::SmallSet ancestors = + collectAncestors(dt, mod); + if (!ancestors.empty()) { + auto it = orderedClassIsGuards.begin(); + while (it != orderedClassIsGuards.end()) { + fir::SubclassAttr sAttr = + typeGuards[*it].dyn_cast(); + if (auto ty = sAttr.getType().dyn_cast()) { + if (ancestors.contains(ty.getName())) + break; + } + ++it; + } + if (it != orderedClassIsGuards.end()) { + // Parent type is present so place it before. + orderedClassIsGuards.insert(it, t); + continue; + } + } + } + orderedClassIsGuards.push_back(t); + } + } + orderedTypeGuards.append(orderedClassIsGuards); + orderedTypeGuards.push_back(defaultGuard); + assert(orderedTypeGuards.size() == typeGuardNum && + "ordered type guard size doesn't match number of type guards"); + + for (unsigned idx : orderedTypeGuards) { + auto *dest = selectType.getSuccessor(idx); + std::optional destOps = + selectType.getSuccessorOperands(operands, idx); + if (typeGuards[idx].dyn_cast()) + rewriter.replaceOpWithNewOp(selectType, dest); + else if (mlir::failed(genTypeLadderStep(loc, selector, typeGuards[idx], + dest, destOps, mod, rewriter, + kindMap))) + return mlir::failure(); + } + return mlir::success(); +} + +mlir::LogicalResult SelectTypeConv::genTypeLadderStep( + mlir::Location loc, mlir::Value selector, mlir::Attribute attr, + mlir::Block *dest, std::optional destOps, + mlir::ModuleOp mod, mlir::PatternRewriter &rewriter, + fir::KindMapping &kindMap) const { + mlir::Value cmp; + // TYPE IS type guard comparison are all done inlined. + if (auto a = attr.dyn_cast()) { + if (fir::isa_trivial(a.getType()) || + a.getType().isa()) { + // For type guard statement with Intrinsic type spec the type code of + // the descriptor is compared. + int code = getTypeCode(a.getType(), kindMap); + if (code == 0) + return mlir::emitError(loc) + << "type code unavailable for " << a.getType(); + mlir::Value typeCode = rewriter.create( + loc, rewriter.getI8IntegerAttr(code)); + mlir::Value selectorTypeCode = rewriter.create( + loc, rewriter.getI8Type(), selector); + cmp = rewriter.create( + loc, mlir::arith::CmpIPredicate::eq, selectorTypeCode, typeCode); + } else { + // Flang inline the kind parameter in the type descriptor so we can + // directly check if the type descriptor addresses are identical for + // the TYPE IS type guard statement. + mlir::Value res = + genTypeDescCompare(loc, selector, a.getType(), mod, rewriter); + if (!res) + return mlir::failure(); + cmp = res; + } + // CLASS IS type guard statement is done with a runtime call. + } else if (auto a = attr.dyn_cast()) { + // Retrieve the type descriptor from the type guard statement record type. + assert(a.getType().isa() && "expect fir.record type"); + fir::RecordType recTy = a.getType().dyn_cast(); + std::string typeDescName = + fir::NameUniquer::getTypeDescriptorName(recTy.getName()); + auto typeDescGlobal = mod.lookupSymbol(typeDescName); + auto typeDescAddr = rewriter.create( + loc, fir::ReferenceType::get(typeDescGlobal.getType()), + typeDescGlobal.getSymbol()); + mlir::Type typeDescTy = ReferenceType::get(rewriter.getNoneType()); + mlir::Value typeDesc = + rewriter.create(loc, typeDescTy, typeDescAddr); + + // Prepare the selector descriptor for the runtime call. + mlir::Type descNoneTy = fir::BoxType::get(rewriter.getNoneType()); + mlir::Value descSelector = + rewriter.create(loc, descNoneTy, selector); + + // Generate runtime call. + llvm::StringRef fctName = RTNAME_STRING(ClassIs); + mlir::func::FuncOp callee; + { + // Since conversion is done in parallel for each fir.select_type + // operation, the runtime function insertion must be threadsafe. + std::lock_guard lock(*moduleMutex); + callee = + fir::createFuncOp(rewriter.getUnknownLoc(), mod, fctName, + rewriter.getFunctionType({descNoneTy, typeDescTy}, + rewriter.getI1Type())); + } + cmp = rewriter + .create(loc, callee, + mlir::ValueRange{descSelector, typeDesc}) + .getResult(0); + } + + auto *thisBlock = rewriter.getInsertionBlock(); + auto *newBlock = + rewriter.createBlock(dest->getParent(), mlir::Region::iterator(dest)); + rewriter.setInsertionPointToEnd(thisBlock); + if (destOps.has_value()) + rewriter.create(loc, cmp, dest, destOps.value(), + newBlock, std::nullopt); + else + rewriter.create(loc, cmp, dest, newBlock); + rewriter.setInsertionPointToEnd(newBlock); + return mlir::success(); +} + +// Generate comparison of type descriptor addresses. +mlir::Value +SelectTypeConv::genTypeDescCompare(mlir::Location loc, mlir::Value selector, + mlir::Type ty, mlir::ModuleOp mod, + mlir::PatternRewriter &rewriter) const { + assert(ty.isa() && "expect fir.record type"); + fir::RecordType recTy = ty.dyn_cast(); + std::string typeDescName = + fir::NameUniquer::getTypeDescriptorName(recTy.getName()); + auto typeDescGlobal = mod.lookupSymbol(typeDescName); + if (!typeDescGlobal) + return {}; + auto typeDescAddr = rewriter.create( + loc, fir::ReferenceType::get(typeDescGlobal.getType()), + typeDescGlobal.getSymbol()); + auto intPtrTy = rewriter.getIndexType(); + mlir::Type tdescType = + fir::TypeDescType::get(mlir::NoneType::get(rewriter.getContext())); + mlir::Value selectorTdescAddr = + rewriter.create(loc, tdescType, selector); + auto typeDescInt = + rewriter.create(loc, intPtrTy, typeDescAddr); + auto selectorTdescInt = + rewriter.create(loc, intPtrTy, selectorTdescAddr); + return rewriter.create( + loc, mlir::arith::CmpIPredicate::eq, typeDescInt, selectorTdescInt); +} + +int SelectTypeConv::getTypeCode(mlir::Type ty, fir::KindMapping &kindMap) { + if (auto intTy = ty.dyn_cast()) + return fir::integerBitsToTypeCode(intTy.getWidth()); + if (auto floatTy = ty.dyn_cast()) + return fir::realBitsToTypeCode(floatTy.getWidth()); + if (auto logicalTy = ty.dyn_cast()) + return fir::logicalBitsToTypeCode( + kindMap.getLogicalBitsize(logicalTy.getFKind())); + if (fir::isa_complex(ty)) { + if (auto cmplxTy = ty.dyn_cast()) + return fir::complexBitsToTypeCode( + cmplxTy.getElementType().cast().getWidth()); + auto cmplxTy = ty.cast(); + return fir::complexBitsToTypeCode( + kindMap.getRealBitsize(cmplxTy.getFKind())); + } + if (auto charTy = ty.dyn_cast()) + return fir::characterBitsToTypeCode( + kindMap.getCharacterBitsize(charTy.getFKind())); + return 0; +} + +llvm::SmallSet +SelectTypeConv::collectAncestors(fir::DispatchTableOp dt, + mlir::ModuleOp mod) const { + llvm::SmallSet ancestors; + if (!dt.getParent().has_value()) + return ancestors; + while (dt.getParent().has_value()) { + ancestors.insert(*dt.getParent()); + dt = mod.lookupSymbol(*dt.getParent()); + } + return ancestors; +} + +std::unique_ptr fir::createPolymorphicOpConversionPass() { + return std::make_unique(); +} diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90 --- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 +++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90 @@ -39,6 +39,7 @@ ! CHECK-NEXT: (S) 0 num-dce'd - Number of operations DCE'd ! CHECK-NEXT: 'func.func' Pipeline +! CHECK-NEXT: PolymorphicOpConversion ! CHECK-NEXT: CFGConversion ! CHECK-NEXT: SCFToControlFlow diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90 --- a/flang/test/Driver/mlir-pass-pipeline.f90 +++ b/flang/test/Driver/mlir-pass-pipeline.f90 @@ -42,6 +42,7 @@ ! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd ! ALL-NEXT: 'func.func' Pipeline +! ALL-NEXT: PolymorphicOpConversion ! ALL-NEXT: CFGConversion ! ALL-NEXT: SCFToControlFlow diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir --- a/flang/test/Fir/basic-program.fir +++ b/flang/test/Fir/basic-program.fir @@ -42,6 +42,7 @@ // PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd // PASSES-NEXT: 'func.func' Pipeline +// PASSES-NEXT: PolymorphicOpConversion // PASSES-NEXT: CFGConversion // PASSES-NEXT: SCFToControlFlow diff --git a/flang/test/Lower/select-type.f90 b/flang/test/Lower/select-type.f90 --- a/flang/test/Lower/select-type.f90 +++ b/flang/test/Lower/select-type.f90 @@ -1,5 +1,5 @@ ! RUN: bbc -polymorphic-type -emit-fir %s -o - | FileCheck %s -! RUN: bbc -polymorphic-type -emit-fir %s -o - | fir-opt --cfg-conversion | FileCheck --check-prefix=CFG %s +! RUN: bbc -polymorphic-type -emit-fir %s -o - | fir-opt --fir-polymorphic-op | FileCheck --check-prefix=CFG %s module select_type_lower_test type p1 integer :: a