// Patch review notes. Text placed before the first `diff --git` header is
// ignored by `git apply` and `patch`, so these notes do not change the patch.
//
// Purpose: extend the flang AbstractResult pass so that fir.dispatch ops with
// an abstract result are rewritten the same way fir.call ops already are.
//   * CallOpConversion becomes a class template, CallConversion, parameterized
//     on the op class. Op-specific handling is selected with
//     `if constexpr (std::is_same_v...)` blocks labelled "fir::CallOp specific
//     handling" and "fir::DispatchOp specific handling". Two `patterns.insert`
//     lines register it, presumably once per op class (the template arguments
//     are not visible in this copy).
//   * fir.dispatch branch:
//       - The result argument `arg` is prepended to the call operands unless
//         the result is a builtin C pointer.
//       - pass_arg_pos is shifted by `passArgShift`, the number of operands
//         prepended.
//       - The new @dispatch test expects pass_arg_pos 0 to become 1.
//   * The dynamic legality check for fir.dispatch now returns
//     !hasAbstractResult(dispatch.getFunctionType()). Before, it emitted
//     TODO("dispatchOp with abstract results").
//
// NOTE(review): this copy of the patch is damaged and cannot be applied as-is.
//   - Newlines are collapsed.
//   - Text between angle brackets appears to have been stripped, e.g.
//     `resultType.isa()`, `rewriter.create(loc, ...)`, `std::is_same_v)`,
//     `patterns.insert>(...)`.
//   Regenerate it from the source branch rather than hand-repairing it.
// NOTE(review): in the fir::DispatchOp branch, `fir::DispatchOp newDispatchOp;`
//   is declared but never used (both create() results are assigned to
//   `newOp`). Consider deleting that line.
// NOTE(review): the file-scope `using namespace mlir;` (the rest of the file
//   qualifies mlir:: names) looks like it was added so that the unqualified
//   base initializer `OpRewritePattern(context, 1)` resolves inside the class
//   template. Unqualified lookup does not search dependent base classes.
//   Writing `mlir::OpRewritePattern<Op>(context, 1)` would make the
//   using-directive unnecessary. Confirm this against the unstripped source.
diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -960,7 +960,7 @@ if (ty.getNumResults() == 0) return false; auto resultType = ty.getResult(0); - return resultType.isa(); + return resultType.isa(); } /// Convert llvm::Type::TypeID to mlir::Type. \p kind is provided for error diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -28,6 +28,8 @@ #define DEBUG_TYPE "flang-abstract-result-opt" +using namespace mlir; + namespace fir { namespace { @@ -40,7 +42,7 @@ return fir::BoxType::get(type); return fir::ReferenceType::get(type); }) - .Case([](mlir::Type type) -> mlir::Type { + .Case([](mlir::Type type) -> mlir::Type { return fir::ReferenceType::get(type); }) .Default([](mlir::Type) -> mlir::Type { @@ -75,16 +77,18 @@ shouldBoxResult; } -class CallOpConversion : public mlir::OpRewritePattern { +template +class CallConversion : public mlir::OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) - : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} + using mlir::OpRewritePattern::OpRewritePattern; + + CallConversion(mlir::MLIRContext *context, bool shouldBoxResult) + : OpRewritePattern(context, 1), shouldBoxResult{shouldBoxResult} {} + mlir::LogicalResult - matchAndRewrite(fir::CallOp callOp, - mlir::PatternRewriter &rewriter) const override { - auto loc = callOp.getLoc(); - auto result = callOp->getResult(0); + matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto result = op->getResult(0); if (!result.hasOneUse()) { mlir::emitError(loc, "calls with abstract result must have exactly one user"); @@ -109,50 
+113,74 @@ // TODO: This should be generalized for derived types, and it is // architecture and OS dependent. bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType()); - fir::CallOp newCallOp; + Op newOp; if (isResultBuiltinCPtr) { - auto recTy = result.getType().dyn_cast(); + auto recTy = result.getType().template dyn_cast(); newResultTypes.emplace_back(recTy.getTypeList()[0].second); } - if (callOp.getCallee()) { + + // fir::CallOp specific handling. + if constexpr (std::is_same_v) { + if (op.getCallee()) { + llvm::SmallVector newOperands; + if (!isResultBuiltinCPtr) + newOperands.emplace_back(arg); + newOperands.append(op.getOperands().begin(), op.getOperands().end()); + newOp = rewriter.create(loc, *op.getCallee(), + newResultTypes, newOperands); + } else { + // Indirect calls. + llvm::SmallVector newInputTypes; + if (!isResultBuiltinCPtr) + newInputTypes.emplace_back(argType); + for (auto operand : op.getOperands().drop_front()) + newInputTypes.push_back(operand.getType()); + auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes, + newResultTypes); + + llvm::SmallVector newOperands; + newOperands.push_back( + rewriter.create(loc, newFuncTy, op.getOperand(0))); + if (!isResultBuiltinCPtr) + newOperands.push_back(arg); + newOperands.append(op.getOperands().begin() + 1, + op.getOperands().end()); + newOp = rewriter.create(loc, mlir::SymbolRefAttr{}, + newResultTypes, newOperands); + } + } + + // fir::DispatchOp specific handling. + if constexpr (std::is_same_v) { llvm::SmallVector newOperands; if (!isResultBuiltinCPtr) newOperands.emplace_back(arg); - newOperands.append(callOp.getOperands().begin(), - callOp.getOperands().end()); - newCallOp = rewriter.create(loc, *callOp.getCallee(), - newResultTypes, newOperands); - } else { - // Indirect calls. 
- llvm::SmallVector newInputTypes; - if (!isResultBuiltinCPtr) - newInputTypes.emplace_back(argType); - for (auto operand : callOp.getOperands().drop_front()) - newInputTypes.push_back(operand.getType()); - auto newFuncTy = mlir::FunctionType::get(callOp.getContext(), - newInputTypes, newResultTypes); + unsigned passArgShift = newOperands.size(); + newOperands.append(op.getOperands().begin() + 1, op.getOperands().end()); - llvm::SmallVector newOperands; - newOperands.push_back(rewriter.create( - loc, newFuncTy, callOp.getOperand(0))); - if (!isResultBuiltinCPtr) - newOperands.push_back(arg); - newOperands.append(callOp.getOperands().begin() + 1, - callOp.getOperands().end()); - newCallOp = rewriter.create(loc, mlir::SymbolRefAttr{}, - newResultTypes, newOperands); + fir::DispatchOp newDispatchOp; + if (op.getPassArgPos()) + newOp = rewriter.create( + loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), + op.getOperands()[0], newOperands, + rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift)); + else + newOp = rewriter.create( + loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), + op.getOperands()[0], newOperands, nullptr); } + if (isResultBuiltinCPtr) { mlir::Value save = saveResult.getMemref(); - auto module = callOp->getParentOfType(); + auto module = op->template getParentOfType(); fir::KindMapping kindMap = fir::getKindMapping(module); FirOpBuilder builder(rewriter, kindMap); mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr( builder, loc, save, result.getType()); - rewriter.create(loc, newCallOp->getResult(0), saveAddr); + rewriter.create(loc, newOp->getResult(0), saveAddr); } - callOp->dropAllReferences(); - rewriter.eraseOp(callOp); + op->dropAllReferences(); + rewriter.eraseOp(op); return mlir::success(); } @@ -289,17 +317,11 @@ return true; }); target.addDynamicallyLegalOp([](fir::DispatchOp dispatch) { - if (dispatch->getNumResults() != 1) - return true; - auto resultType = dispatch->getResult(0).getType(); - if 
(resultType.isa()) { - TODO(dispatch.getLoc(), "dispatchOp with abstract results"); - return false; - } - return true; + return !hasAbstractResult(dispatch.getFunctionType()); }); - patterns.insert(context, shouldBoxResult); + patterns.insert>(context, shouldBoxResult); + patterns.insert>(context, shouldBoxResult); patterns.insert(context); patterns.insert(context, shouldBoxResult); if (mlir::failed( diff --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir --- a/flang/test/Fir/abstract-results.fir +++ b/flang/test/Fir/abstract-results.fir @@ -244,6 +244,25 @@ // FUNC-BOX: fir.store %[[VAL]] to %[[ADDR]] : !fir.ref } +// FUNC-REF-LABEL: func @dispatch( +// FUNC-REF-SAME: %[[ARG0:.*]]: !fir.class> {fir.bindc_name = "a"} +// FUNC-BOX-LABEL: func @dispatch( +// FUNC-BOX-SAME: %[[ARG0:.*]]: !fir.class> {fir.bindc_name = "a"} +func.func @dispatch(%arg0: !fir.class> {fir.bindc_name = "a"}) { + %buffer = fir.alloca !fir.type + %res = fir.dispatch "ret_array"(%arg0 : !fir.class>) (%arg0 : !fir.class>) -> !fir.type {pass_arg_pos = 0 : i32} + fir.save_result %res to %buffer : !fir.type, !fir.ref> + return + // FUNC-REF: %[[buffer:.*]] = fir.alloca !fir.type + // FUNC-REF: fir.dispatch "ret_array"(%[[ARG0]] : !fir.class>) (%[[buffer]], %[[ARG0]] : !fir.ref>, !fir.class>) {pass_arg_pos = 1 : i32} + // FUNC-REF-NOT: fir.save_result + + // FUNC-BOX: %[[buffer:.*]] = fir.alloca !fir.type + // FUNC-BOX: %[[box:.*]] = fir.embox %[[buffer]] : (!fir.ref>) -> !fir.box> + // FUNC-BOX: fir.dispatch "ret_array"(%[[ARG0]] : !fir.class>) (%[[box]], %[[ARG0]] : !fir.box>, !fir.class>) {pass_arg_pos = 1 : i32} + // FUNC-BOX-NOT: fir.save_result +} + // ------------------------ Test fir.address_of rewrite ------------------------ func.func private @takesfuncarray((i32) -> !fir.array)