diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp --- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp @@ -331,24 +331,25 @@ } }; -static void genFreeIfMustFree(mlir::Location loc, - mlir::ConversionPatternRewriter &rewriter, +static void genFreeIfMustFree(mlir::Location loc, fir::FirOpBuilder &builder, mlir::Value var, mlir::Value mustFree) { auto genFree = [&]() { - if (var.getType().isa()) - TODO(loc, "unbox"); - if (!var.getType().isa()) - var = rewriter.create( - loc, fir::HeapType::get(fir::unwrapRefType(var.getType())), var); - rewriter.create(loc, var); + // fir::FreeMemOp operand type must be a fir::HeapType. + mlir::Type heapType = fir::HeapType::get( + hlfir::getFortranElementOrSequenceType(var.getType())); + if (var.getType().isa()) + var = builder.create(loc, heapType, var); + else if (!var.getType().isa()) + var = builder.create(loc, heapType, var); + builder.create(loc, var); }; if (auto cstMustFree = fir::getIntIfConstant(mustFree)) { if (*cstMustFree != 0) genFree(); - // else, nothing to do. + // else, mustFree is false, nothing to do. return; } - TODO(loc, "conditional free"); + builder.genIfThen(loc, mustFree).genThen(genFree).end(); } struct EndAssociateOpConversion @@ -360,7 +361,9 @@ matchAndRewrite(hlfir::EndAssociateOp endAssociate, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { mlir::Location loc = endAssociate->getLoc(); - genFreeIfMustFree(loc, rewriter, adaptor.getVar(), adaptor.getMustFree()); + auto module = endAssociate->getParentOfType(); + fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); + genFreeIfMustFree(loc, builder, adaptor.getVar(), adaptor.getMustFree()); rewriter.eraseOp(endAssociate); return mlir::success(); } @@ -378,9 +381,11 @@ mlir::Location loc = destroy->getLoc(); mlir::Value bufferizedExpr = getBufferizedExprStorage(adaptor.getExpr()); if (!fir::isa_trivial(bufferizedExpr.getType())) { + auto module = destroy->getParentOfType(); + fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module)); mlir::Value mustFree = getBufferizedExprMustFreeFlag(adaptor.getExpr()); mlir::Value firBase = hlfir::Entity(bufferizedExpr).getFirBase(); - genFreeIfMustFree(loc, rewriter, firBase, mustFree); + genFreeIfMustFree(loc, builder, firBase, mustFree); } rewriter.eraseOp(destroy); return mlir::success(); diff --git a/flang/test/HLFIR/associate-codegen.fir b/flang/test/HLFIR/associate-codegen.fir --- a/flang/test/HLFIR/associate-codegen.fir +++ b/flang/test/HLFIR/associate-codegen.fir @@ -79,6 +79,43 @@ // CHECK-NOT: fir.freemem +func.func @test_end_associate_box(%var: !fir.box>) { + %true = arith.constant 1 : i1 + hlfir.end_associate %var, %true : !fir.box>, i1 + return +} +// CHECK-LABEL: func.func @test_end_associate_box( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.box>) { +// CHECK: %[[VAL_1:.*]] = arith.constant true +// CHECK: %[[VAL_2:.*]] = fir.box_addr %[[VAL_0]] : (!fir.box>) -> !fir.heap> +// CHECK: fir.freemem %[[VAL_2]] : !fir.heap> + + +func.func @test_end_associate_boxchar(%var: !fir.boxchar<2>) { + %true = arith.constant 1 : i1 + hlfir.end_associate %var, %true : !fir.boxchar<2>, i1 + return +} +// CHECK-LABEL: func.func @test_end_associate_boxchar( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.boxchar<2>) { +// CHECK: %[[VAL_1:.*]] = arith.constant true +// CHECK: %[[VAL_2:.*]] = fir.box_addr %[[VAL_0]] : (!fir.boxchar<2>) -> !fir.heap> +// CHECK: fir.freemem %[[VAL_2]] : !fir.heap> + + +func.func @test_end_associate_box_dynamic(%var: !fir.box>, %must_free: i1) { + hlfir.end_associate %var, %must_free : !fir.box>, i1 + return +} +// CHECK-LABEL: func.func @test_end_associate_box_dynamic( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.box>, +// CHECK-SAME: %[[VAL_1:.*]]: i1) { +// CHECK: fir.if %[[VAL_1]] { +// CHECK: %[[VAL_2:.*]] = fir.box_addr %[[VAL_0]] : (!fir.box>) -> !fir.heap> +// CHECK: fir.freemem %[[VAL_2]] : !fir.heap> +// CHECK: } + + func.func private @take_i4(!fir.ref) func.func private @take_r4(!fir.ref) func.func private @take_l4(!fir.ref>)