[flang] Release character SELECT CASE selector temporaries during the tests

A character selector of a SELECT CASE construct may be a temporary (for
example the result of trim(s)). Instead of releasing it at the end of the
construct exit block, release it on each edge that leaves the chain of case
value comparisons: on the branch to a matching case block, and at the last
comparison before the default block. When the last case value is a closed
interval, the temporary is also released on the false edge of its lower
bound test, which otherwise reaches the default block without freeing it.

StatementContext::finalize(bool popScope) is replaced by finalizeAndKeep(),
finalizeAndPop() and finalize(), and workListIsEmpty() is added.

diff --git a/flang/include/flang/Lower/StatementContext.h b/flang/include/flang/Lower/StatementContext.h
--- a/flang/include/flang/Lower/StatementContext.h
+++ b/flang/include/flang/Lower/StatementContext.h
@@ -35,7 +35,7 @@
 
   ~StatementContext() {
     if (!cufs.empty())
-      finalize(/*popScope=*/true);
+      finalizeAndPop();
     assert(cufs.empty() && "invalid StatementContext destructor call");
   }
 
@@ -61,15 +61,30 @@
     }
   }
 
-  /// Make cleanup calls. Pop or reset the stack top list.
-  void finalize(bool popScope = false) {
+  /// Make cleanup calls. Retain the stack top list for a repeat call.
+  void finalizeAndKeep() {
     assert(!cufs.empty() && "invalid finalize statement context");
     if (cufs.back())
       (*cufs.back())();
-    if (popScope)
-      cufs.pop_back();
-    else
-      cufs.back().reset();
+  }
+
+  /// Make cleanup calls. Pop the stack top list.
+  void finalizeAndPop() {
+    finalizeAndKeep();
+    cufs.pop_back();
+  }
+
+  /// Make cleanup calls. Clear the stack top list.
+  void finalize() {
+    finalizeAndKeep();
+    cufs.back().reset();
+  }
+
+  /// Return true if no scope has a pending cleanup.
+  bool workListIsEmpty() const {
+    return cufs.empty() || llvm::all_of(cufs, [](auto &opt) -> bool {
+             return !opt.hasValue();
+           });
   }
 
 private:
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -1749,8 +1749,13 @@
     // Generate a sequence of case value comparisons and branches.
     auto caseValue = valueList.begin();
     auto caseBlock = blockList.begin();
-    for (mlir::Attribute attr : attrList) {
-      if (attr.isa<mlir::UnitAttr>()) {
+    // True while generating the lower bound test of a closed interval.
+    bool isIntervalLowerBound = false;
+    for (const auto &attr : llvm::enumerate(attrList)) {
+      if (attr.value().isa<mlir::UnitAttr>()) {
+        // With no other case values, no comparison frees the selector.
+        if (attrList.size() == 1)
+          stmtCtx.finalize();
         genFIRBranch(*caseBlock++);
         break;
       }
@@ -1767,16 +1772,51 @@
             charHelper.createUnboxChar(rhs);
         mlir::Value &rhsAddr = rhsVal.first;
         mlir::Value &rhsLen = rhsVal.second;
-        return fir::runtime::genCharCompare(*builder, loc, pred, lhsAddr,
-                                            lhsLen, rhsAddr, rhsLen);
+        mlir::Value result = fir::runtime::genCharCompare(
+            *builder, loc, pred, lhsAddr, lhsLen, rhsAddr, rhsLen);
+        if (stmtCtx.workListIsEmpty())
+          return result;
+        // Free the selector temporary exactly once on every path that leaves
+        // the comparison chain. This assumes the default (UnitAttr) entry is
+        // last in attrList, so index size() - 2 holds the final case value.
+        bool isLastCaseValue = attr.index() == attrList.size() - 2;
+        if (isIntervalLowerBound) {
+          // The true edge of an interval lower bound test goes on to the
+          // upper bound test, which still reads the selector. For the last
+          // interval, the false edge goes straight to the default block, so
+          // the selector must be freed on that edge.
+          if (isLastCaseValue) {
+            fir::IfOp ifOp = builder->create<fir::IfOp>(
+                loc, result, /*withElseRegion=*/true);
+            builder->setInsertionPointToStart(&ifOp.getElseRegion().front());
+            stmtCtx.finalizeAndKeep();
+            builder->setInsertionPointAfter(ifOp);
+          }
+          return result;
+        }
+        if (isLastCaseValue) {
+          // Both edges leave the chain here: free unconditionally.
+          stmtCtx.finalize();
+          return result;
+        }
+        // Free only on the edge that branches to this case's block; the
+        // other edge still needs the selector for later comparisons.
+        fir::IfOp ifOp = builder->create<fir::IfOp>(loc, result,
+                                                    /*withElseRegion=*/false);
+        builder->setInsertionPointToStart(&ifOp.getThenRegion().front());
+        stmtCtx.finalizeAndKeep();
+        builder->setInsertionPointAfter(ifOp);
+        return result;
       };
       mlir::Block *newBlock = insertBlock(*caseBlock);
-      if (attr.isa<fir::ClosedIntervalAttr>()) {
+      if (attr.value().isa<fir::ClosedIntervalAttr>()) {
         mlir::Block *newBlock2 = insertBlock(*caseBlock);
+        isIntervalLowerBound = true;
         mlir::Value cond =
             genCond(*caseValue++, mlir::arith::CmpIPredicate::sge);
         genFIRConditionalBranch(cond, newBlock, newBlock2);
         builder->setInsertionPointToEnd(newBlock);
+        isIntervalLowerBound = false;
         mlir::Value cond2 =
             genCond(*caseValue++, mlir::arith::CmpIPredicate::sle);
         genFIRConditionalBranch(cond2, *caseBlock++, newBlock2);
@@ -1784,12 +1824,13 @@
         continue;
       }
       mlir::arith::CmpIPredicate pred;
-      if (attr.isa<fir::PointIntervalAttr>()) {
+      if (attr.value().isa<fir::PointIntervalAttr>()) {
         pred = mlir::arith::CmpIPredicate::eq;
-      } else if (attr.isa<fir::LowerBoundAttr>()) {
+      } else if (attr.value().isa<fir::LowerBoundAttr>()) {
         pred = mlir::arith::CmpIPredicate::sge;
       } else {
-        assert(attr.isa<fir::UpperBoundAttr>() && "unexpected predicate");
+        assert(attr.value().isa<fir::UpperBoundAttr>() &&
+               "unexpected predicate");
         pred = mlir::arith::CmpIPredicate::sle;
       }
       mlir::Value cond = genCond(*caseValue++, pred);
@@ -1798,12 +1839,7 @@
     }
     assert(caseValue == valueList.end() && caseBlock == blockList.end() &&
            "select case list mismatch");
-    // Clean-up the selector at the end of the construct if it is a temporary
-    // (which is possible with characters).
-    mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
-    builder->setInsertionPointToEnd(eval.parentConstruct->constructExit->block);
-    stmtCtx.finalize();
-    builder->restoreInsertionPoint(insertPt);
+    assert(stmtCtx.workListIsEmpty() && "statement context must be empty");
   }
 
   fir::ExtendedValue
diff --git a/flang/lib/Lower/ConvertExpr.cpp b/flang/lib/Lower/ConvertExpr.cpp
--- a/flang/lib/Lower/ConvertExpr.cpp
+++ b/flang/lib/Lower/ConvertExpr.cpp
@@ -3813,7 +3813,7 @@
     // be needed afterwards.
     stmtCtx.pushScope();
     [[maybe_unused]] ExtValue loopRes = lowerArrayExpression(expr);
-    stmtCtx.finalize(/*popScope=*/true);
+    stmtCtx.finalizeAndPop();
     assert(fir::getBase(loopRes));
   }
 
@@ -4719,7 +4719,7 @@
   /// fir::ResultOp at the end of the innermost loop.
   void finalizeElementCtx() {
     if (elementCtx) {
-      stmtCtx.finalize(/*popScope=*/true);
+      stmtCtx.finalizeAndPop();
       elementCtx = false;
     }
   }
@@ -6433,7 +6433,7 @@
         builder.create<fir::StoreOp>(loc, castLen, charLen.value());
       }
     }
-    stmtCtx.finalize(/*popScope=*/true);
+    stmtCtx.finalizeAndPop();
     builder.create<fir::ResultOp>(loc, mem);
     builder.restoreInsertionPoint(insPt);
 
diff --git a/flang/lib/Lower/IO.cpp b/flang/lib/Lower/IO.cpp
--- a/flang/lib/Lower/IO.cpp
+++ b/flang/lib/Lower/IO.cpp
@@ -196,7 +196,7 @@
                                           mlir::ValueRange{cookie});
   mlir::Value iostat = call.getResult(0);
   if (csi.bigUnitIfOp) {
-    stmtCtx.finalize(/*popScope=*/true);
+    stmtCtx.finalizeAndPop();
     builder.create<fir::ResultOp>(loc, iostat);
     builder.setInsertionPointAfter(csi.bigUnitIfOp);
     iostat = csi.bigUnitIfOp.getResult(0);
diff --git a/flang/test/Lower/select-case-statement.f90 b/flang/test/Lower/select-case-statement.f90
--- a/flang/test/Lower/select-case-statement.f90
+++ b/flang/test/Lower/select-case-statement.f90
@@ -158,54 +158,217 @@
     print*, nn
   end
 
-  ! CHECK-LABEL: func @_QPtest_char_temp_selector
-  subroutine test_char_temp_selector()
-    ! Test that character selector that are temps are deallocated
-    ! only after they have been used in the select case comparisons.
-    interface
-      function gen_char_temp_selector()
-        character(:), allocatable :: gen_char_temp_selector
-      end function
-    end interface
-    select case (gen_char_temp_selector())
-    case ('case1')
-      call foo1()
-    case ('case2')
-      call foo2()
-    case ('case3')
-      call foo3()
+  ! CHECK-LABEL: func @_QPscharacter1
+  subroutine scharacter1(s)
+    ! CHECK-DAG: %[[V_0:[0-9]+]] = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>>
+    character(len=3) :: s
+    ! CHECK-DAG: %[[V_1:[0-9]+]] = fir.alloca i32 {bindc_name = "n", uniq_name = "_QFscharacter1En"}
+    ! CHECK: fir.store %c0{{.*}} to %[[V_1]] : !fir.ref<i32>
+    n = 0
+
+    ! CHECK: %[[V_8:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+    ! CHECK: %[[V_9:[0-9]+]] = arith.cmpi sge, %[[V_8]], %c0{{.*}} : i32
+    ! CHECK: cond_br %[[V_9]], ^bb1, ^bb15
+    ! CHECK: ^bb1:  // pred: ^bb0
+    if (lge(s,'00')) then
+
+      ! CHECK: %[[V_18:[0-9]+]] = fir.load %[[V_0]] : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
+      ! CHECK: %[[V_20:[0-9]+]] = fir.box_addr %[[V_18]] : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
+      ! CHECK: %[[V_42:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+      ! CHECK: %[[V_43:[0-9]+]] = arith.cmpi eq, %[[V_42]], %c0{{.*}} : i32
+      ! CHECK: fir.if %[[V_43]] {
+      ! CHECK:   fir.freemem %[[V_20]] : !fir.heap<!fir.char<1,?>>
+      ! CHECK: }
+      ! CHECK: cond_br %[[V_43]], ^bb3, ^bb2
+      ! CHECK: ^bb2:  // pred: ^bb1
+      select case(trim(s))
+      case('11')
+        n = 1
+
+      case default
+        continue
+
+        ! CHECK: %[[V_48:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+        ! CHECK: %[[V_49:[0-9]+]] = arith.cmpi eq, %[[V_48]], %c0{{.*}} : i32
+        ! CHECK: fir.if %[[V_49]] {
+        ! CHECK:   fir.freemem %[[V_20]] : !fir.heap<!fir.char<1,?>>
+        ! CHECK: }
+        ! CHECK: cond_br %[[V_49]], ^bb6, ^bb5
+        ! CHECK: ^bb3:  // pred: ^bb1
+        ! CHECK: fir.store %c1{{.*}} to %[[V_1]] : !fir.ref<i32>
+        ! CHECK: ^bb4:  // pred: ^bb13
+        ! CHECK: ^bb5:  // pred: ^bb2
+      case('22')
+        n = 2
+
+        ! CHECK: %[[V_54:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+        ! CHECK: %[[V_55:[0-9]+]] = arith.cmpi eq, %[[V_54]], %c0{{.*}} : i32
+        ! CHECK: fir.if %[[V_55]] {
+        ! CHECK:   fir.freemem %[[V_20]] : !fir.heap<!fir.char<1,?>>
+        ! CHECK: }
+        ! CHECK: cond_br %[[V_55]], ^bb8, ^bb7
+        ! CHECK: ^bb6:  // pred: ^bb2
+        ! CHECK: fir.store %c2{{.*}} to %[[V_1]] : !fir.ref<i32>
+        ! CHECK: ^bb7:  // pred: ^bb5
+      case('33')
+        n = 3
+
+      case('44':'55','66':'77','88':)
+        n = 4
+        ! CHECK: %[[V_60:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+        ! CHECK: %[[V_61:[0-9]+]] = arith.cmpi sge, %[[V_60]], %c0{{.*}} : i32
+        ! CHECK: cond_br %[[V_61]], ^bb9, ^bb10
+        ! CHECK: ^bb8:  // pred: ^bb5
+        ! CHECK: fir.store %c3{{.*}} to %[[V_1]] : !fir.ref<i32>
+        ! CHECK: ^bb9:  // pred: ^bb7
+        ! CHECK: %[[V_66:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+        ! CHECK: %[[V_67:[0-9]+]] = arith.cmpi sle, %[[V_66]], %c0{{.*}} : i32
+        ! CHECK: fir.if %[[V_67]] {
+        ! CHECK:   fir.freemem %[[V_20]] : !fir.heap<!fir.char<1,?>>
+        ! CHECK: }
+        ! CHECK: cond_br %[[V_67]], ^bb14, ^bb10
+        ! CHECK: ^bb10:  // 2 preds: ^bb7, ^bb9
+        ! CHECK: %[[V_72:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+        ! CHECK: %[[V_73:[0-9]+]] = arith.cmpi sge, %[[V_72]], %c0{{.*}} : i32
+        ! CHECK: cond_br %[[V_73]], ^bb11, ^bb12
+        ! CHECK: ^bb11:  // pred: ^bb10
+        ! CHECK: %[[V_78:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+        ! CHECK: %[[V_79:[0-9]+]] = arith.cmpi sle, %[[V_78]], %c0{{.*}} : i32
+        ! CHECK: fir.if %[[V_79]] {
+        ! CHECK:   fir.freemem %[[V_20]] : !fir.heap<!fir.char<1,?>>
+        ! CHECK: }
+        ! CHECK: ^bb12:  // 2 preds: ^bb10, ^bb11
+        ! CHECK: %[[V_84:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+        ! CHECK: %[[V_85:[0-9]+]] = arith.cmpi sge, %[[V_84]], %c0{{.*}} : i32
+        ! CHECK: fir.freemem %[[V_20]] : !fir.heap<!fir.char<1,?>>
+        ! CHECK: cond_br %[[V_85]], ^bb14, ^bb13
+        ! CHECK: ^bb13:  // pred: ^bb12
+        ! CHECK: ^bb14:  // 3 preds: ^bb9, ^bb11, ^bb12
+        ! CHECK: fir.store %c4{{.*}} to %[[V_1]] : !fir.ref<i32>
+        ! CHECK: ^bb15:  // 6 preds: ^bb0, ^bb3, ^bb4, ^bb6, ^bb8, ^bb14
+      end select
+    end if
+    ! CHECK: %[[V_89:[0-9]+]] = fir.load %[[V_1]] : !fir.ref<i32>
+    print*, n
+  end subroutine
+
+
+  ! CHECK-LABEL: func @_QPscharacter2
+  subroutine scharacter2(s)
+    ! CHECK-DAG: %[[V_0:[0-9]+]] = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>>
+    ! CHECK: %[[V_1:[0-9]+]] = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>>
+    character(len=3) :: s
+    n = 0
+
+    ! CHECK: %[[V_12:[0-9]+]] = fir.load %[[V_1]] : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
+    ! CHECK: %[[V_13:[0-9]+]] = fir.box_addr %[[V_12]] : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
+    ! CHECK: fir.freemem %[[V_13]] : !fir.heap<!fir.char<1,?>>
+    ! CHECK: br ^bb1
+    ! CHECK: ^bb1:  // pred: ^bb0
+    ! CHECK: br ^bb2
+    n = -10
+    select case(trim(s))
     case default
-      call foo_default()
+      n = 9
+    end select
+    print*, n
+
+    ! CHECK: ^bb2:  // pred: ^bb1
+    ! CHECK: %[[V_28:[0-9]+]] = fir.load %[[V_0]] : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
+    ! CHECK: %[[V_29:[0-9]+]] = fir.box_addr %[[V_28]] : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
+    ! CHECK: fir.freemem %[[V_29]] : !fir.heap<!fir.char<1,?>>
+    ! CHECK: br ^bb3
+    ! CHECK: ^bb3:  // pred: ^bb2
+    n = -2
+    select case(trim(s))
     end select
-    ! CHECK: %[[VAL_0:.*]] = fir.alloca !fir.box<!fir.heap<!fir.char<1,?>>> {bindc_name = ".result"}
-    ! CHECK: %[[VAL_1:.*]] = fir.call @_QPgen_char_temp_selector() : () -> !fir.box<!fir.heap<!fir.char<1,?>>>
-    ! CHECK: fir.save_result %[[VAL_1]] to %[[VAL_0]] : !fir.box<!fir.heap<!fir.char<1,?>>>, !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
-    ! CHECK: cond_br %{{.*}}, ^bb2, ^bb1
-    ! CHECK: ^bb1:
-    ! CHECK: cond_br %{{.*}}, ^bb4, ^bb3
-    ! CHECK: ^bb2:
-    ! CHECK: fir.call @_QPfoo1() : () -> ()
-    ! CHECK: br ^bb8
-    ! CHECK: ^bb3:
-    ! CHECK: cond_br %{{.*}}, ^bb6, ^bb5
-    ! CHECK: ^bb4:
-    ! CHECK: fir.call @_QPfoo2() : () -> ()
-    ! CHECK: br ^bb8
-    ! CHECK: ^bb5:
-    ! CHECK: br ^bb7
-    ! CHECK: ^bb6:
-    ! CHECK: fir.call @_QPfoo3() : () -> ()
-    ! CHECK: br ^bb8
-    ! CHECK: ^bb7:
-    ! CHECK: fir.call @_QPfoo_default() : () -> ()
-    ! CHECK: br ^bb8
-    ! CHECK: ^bb8:
-    ! CHECK: %[[VAL_36:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.heap<!fir.char<1,?>>>>
-    ! CHECK: %[[VAL_37:.*]] = fir.box_addr %[[VAL_36]] : (!fir.box<!fir.heap<!fir.char<1,?>>>) -> !fir.heap<!fir.char<1,?>>
-    ! CHECK: %[[VAL_38:.*]] = fir.convert %[[VAL_37]] : (!fir.heap<!fir.char<1,?>>) -> i64
-    ! CHECK: %[[VAL_39:.*]] = arith.constant 0 : i64
-    ! CHECK: %[[VAL_40:.*]] = arith.cmpi ne, %[[VAL_38]], %[[VAL_39]] : i64
-    ! CHECK: fir.if %[[VAL_40]] {
-    ! CHECK: fir.freemem %[[VAL_37]]
-    ! CHECK: }
+    print*, n
   end subroutine
+
+  ! CHECK-LABEL: func @_QPscharacter3
+  subroutine scharacter3(s)
+    ! The last case value is a closed interval: the selector temp must also
+    ! be freed on the path that fails the interval's lower bound test.
+    character(len=3) :: s
+    n = 0
+    ! CHECK: %[[V_0:[0-9]+]] = fir.call @_FortranACharacterCompareScalar1
+    ! CHECK: %[[V_1:[0-9]+]] = arith.cmpi sge, %[[V_0]], %c0{{.*}} : i32
+    ! CHECK: fir.if %[[V_1]] {
+    ! CHECK: } else {
+    ! CHECK:   fir.freemem %{{.*}} : !fir.heap<!fir.char<1,?>>
+    ! CHECK: }
+    ! CHECK: cond_br %[[V_1]]
+    ! CHECK: fir.call @_FortranACharacterCompareScalar1
+    ! CHECK: arith.cmpi sle
+    ! CHECK: fir.freemem %{{.*}} : !fir.heap<!fir.char<1,?>>
+    ! CHECK: cond_br
+    select case(trim(s))
+    case('11':'22')
+      n = 1
+    end select
+    print*, n
+  end subroutine
+
+  ! CHECK-LABEL: main
+  program p
+    integer sinteger, v(10)
+
+    n = -10
+    do j = 1, 4
+      do k = 1, 10
+        n = n + 1
+        v(k) = sinteger(n)
+      enddo
+      ! expected output:  1 1 1 1 1 1 1 1 1 1
+      !                   1 2 3 4 4 6 7 7 7 7
+      !                   7 7 7 7 7 0 0 0 0 0
+      !                   7 7 7 7 7 7 7 7 7 7
+      print*, v
+    enddo
+
+    print*
+    call slogical(.false.)     ! expected output: 0 1 0 3 1 1 3 1
+    call slogical(.true.)      ! expected output: 0 0 2 3 2 3 2 2
+
+    print*
+    call scharacter('aa')      ! expected output: 10
+    call scharacter('d')       ! expected output: 10
+    call scharacter('f')       ! expected output: -1
+    call scharacter('ff')      ! expected output: 20
+    call scharacter('fff')     ! expected output: 20
+    call scharacter('ffff')    ! expected output: 20
+    call scharacter('fffff')   ! expected output: -1
+    call scharacter('jj')      ! expected output: -1
+    call scharacter('m')       ! expected output: 30
+    call scharacter('q')       ! expected output: -1
+    call scharacter('qq')      ! expected output: 40
+    call scharacter('qqq')     ! expected output: -1
+    call scharacter('vv')      ! expected output: -1
+    call scharacter('xx')      ! expected output: 50
+    call scharacter('zz')      ! expected output: 50
+
+    print*
+    call scharacter1('99 ')    ! expected output: 4
+    call scharacter1('88 ')    ! expected output: 4
+    call scharacter1('77 ')    ! expected output: 4
+    call scharacter1('66 ')    ! expected output: 4
+    call scharacter1('55 ')    ! expected output: 4
+    call scharacter1('44 ')    ! expected output: 4
+    call scharacter1('33 ')    ! expected output: 3
+    call scharacter1('22 ')    ! expected output: 2
+    call scharacter1('11 ')    ! expected output: 1
+    call scharacter1('00 ')    ! expected output: 0
+    call scharacter1('.  ')    ! expected output: 0
+    call scharacter1('   ')    ! expected output: 0
+
+    print*
+    call scharacter2('99 ')    ! expected output: 9 -2
+    call scharacter2('22 ')    ! expected output: 9 -2
+    call scharacter2('.  ')    ! expected output: 9 -2
+    call scharacter2('   ')    ! expected output: 9 -2
+
+    print*
+    call scharacter3('00 ')    ! expected output: 0
+    call scharacter3('15 ')    ! expected output: 1
+    call scharacter3('33 ')    ! expected output: 0
+  end