diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td --- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td +++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td @@ -498,4 +498,29 @@ }]; } +def hlfir_DestroyOp : hlfir_Op<"destroy", []> { + let summary = "Mark the last use of an hlfir.expr"; + let description = [{ + Mark the last use of an hlfir.expr. This will be the point at which the + buffer of an hlfir.expr, if any, will be deallocated if it was heap + allocated. + It is not required to create an hlfir.destroy for an expression created + inside an hlfir.elemental and returned in the hlfir.yield_element. + The last use of such an expression is implicit and an hlfir.destroy could + not be emitted after the hlfir.yield_element since it is a terminator. + + Note that hlfir.destroy operations are currently generated by Fortran lowering that + has a good view of the expression use contexts, but this will need to be + revisited if any motion of hlfir.expr is done (like CSE) since + transformations should not introduce any hlfir.expr usages after an + hlfir.destroy. + In the future, the last use points will probably be identified automatically + in bufferization instead. 
+ }]; + + let arguments = (ins hlfir_ExprType:$expr); + + let assemblyFormat = "$expr attr-dict `:` qualified(type($expr))"; +} + #endif // FORTRAN_DIALECT_HLFIR_OPS diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp --- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp @@ -267,6 +267,21 @@ } }; +static bool allOtherUsesAreDestroys(mlir::Value value, + mlir::Operation *currentUse) { + for (mlir::Operation *useOp : value.getUsers()) + if (!mlir::isa(useOp) && useOp != currentUse) + return false; + return true; +} + +static void eraseAllUsesInDestroys(mlir::Value value, + mlir::ConversionPatternRewriter &rewriter) { + for (mlir::Operation *useOp : value.getUsers()) + if (mlir::isa(useOp)) + rewriter.eraseOp(useOp); +} + struct AssociateOpConversion : public mlir::OpConversionPattern { using mlir::OpConversionPattern::OpConversionPattern; @@ -290,10 +305,16 @@ rewriter.replaceOp(associate, {hlfirVar, firVar, flag}); }; - if (!isTrivialValue && associate.getSource().hasOneUse()) { + if (!isTrivialValue && allOtherUsesAreDestroys(associate.getSource(), + associate.getOperation())) { + // Re-use hlfir.expr buffer if this is the only use of the hlfir.expr + // outside of the hlfir.destroy. Take on the cleaning-up responsibility + // for the related hlfir.end_associate, and erase the hlfir.destroy (if + // any). 
mlir::Value mustFree = getBufferizedExprMustFreeFlag(adaptor.getSource()); mlir::Value firBase = hlfir::Entity{bufferizedExpr}.getFirBase(); replaceWith(bufferizedExpr, firBase, mustFree); + eraseAllUsesInDestroys(associate.getSource(), rewriter); return mlir::success(); } if (isTrivialValue) { @@ -310,6 +331,26 @@ } }; +static void genFreeIfMustFree(mlir::Location loc, + mlir::ConversionPatternRewriter &rewriter, + mlir::Value var, mlir::Value mustFree) { + auto genFree = [&]() { + if (var.getType().isa()) + TODO(loc, "unbox"); + if (!var.getType().isa()) + var = rewriter.create( + loc, fir::HeapType::get(fir::unwrapRefType(var.getType())), var); + rewriter.create(loc, var); + }; + if (auto cstMustFree = fir::getIntIfConstant(mustFree)) { + if (*cstMustFree != 0) + genFree(); + // else, nothing to do. + return; + } + TODO(loc, "conditional free"); +} + struct EndAssociateOpConversion : public mlir::OpConversionPattern { using mlir::OpConversionPattern::OpConversionPattern; @@ -318,22 +359,31 @@ mlir::LogicalResult matchAndRewrite(hlfir::EndAssociateOp endAssociate, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - mlir::Value mustFree = adaptor.getMustFree(); mlir::Location loc = endAssociate->getLoc(); + genFreeIfMustFree(loc, rewriter, adaptor.getVar(), adaptor.getMustFree()); rewriter.eraseOp(endAssociate); - auto genFree = [&]() { - mlir::Value var = adaptor.getVar(); - if (var.getType().isa()) - TODO(loc, "unbox"); - rewriter.create(loc, var); - }; - if (auto cstMustFree = fir::getIntIfConstant(mustFree)) { - if (*cstMustFree != 0) - genFree(); - // else, nothing to do. 
- return mlir::success(); + return mlir::success(); + } +}; + +struct DestroyOpConversion + : public mlir::OpConversionPattern { + using mlir::OpConversionPattern::OpConversionPattern; + explicit DestroyOpConversion(mlir::MLIRContext *ctx) + : mlir::OpConversionPattern{ctx} {} + mlir::LogicalResult + matchAndRewrite(hlfir::DestroyOp destroy, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + // If expr was bufferized on the heap, now is time to deallocate the buffer. + mlir::Location loc = destroy->getLoc(); + mlir::Value bufferizedExpr = getBufferizedExprStorage(adaptor.getExpr()); + if (!fir::isa_trivial(bufferizedExpr.getType())) { + mlir::Value mustFree = getBufferizedExprMustFreeFlag(adaptor.getExpr()); + mlir::Value firBase = hlfir::Entity(bufferizedExpr).getFirBase(); + genFreeIfMustFree(loc, rewriter, firBase, mustFree); } - TODO(endAssociate.getLoc(), "conditional free"); + rewriter.eraseOp(destroy); + return mlir::success(); } }; @@ -351,6 +401,14 @@ } }; +/// Was \p value created in the mlir block where \p builder is currently set ? +static bool wasCreatedInCurrentBlock(mlir::Value value, + fir::FirOpBuilder &builder) { + if (mlir::Operation *op = value.getDefiningOp()) + return op->getBlock() == builder.getBlock(); + return false; +} + /// This Listener allows setting both the builder and the rewriter as /// listeners. This is required when a pattern uses a firBuilder helper that /// may create illegal operations that will need to be translated and requires @@ -406,15 +464,26 @@ // the array temporary. An hlfir.as_expr may have been added if the // elemental is a "view" over a variable (e.g parentheses or transpose). 
if (auto asExpr = elementValue.getDefiningOp()) { - elementValue = hlfir::Entity{asExpr.getVar()}; - if (asExpr->hasOneUse()) + if (asExpr->hasOneUse() && !asExpr.isMove()) { + elementValue = hlfir::Entity{asExpr.getVar()}; rewriter.eraseOp(asExpr); + } } rewriter.eraseOp(yield); // Assign the element value to the temp element for this iteration. auto tempElement = hlfir::getElementAt(loc, builder, temp, oneBasedLoopIndices); builder.create(loc, elementValue, tempElement); + // hlfir.yield_element implicitly marks the end-of-life of its operand if + // it is an expression created in the hlfir.elemental (since it is its + // last use and an hlfir.destroy could not be created afterwards). + // Now that this node has been removed and the expression has been used in + // the assign, insert an hlfir.destroy to mark the expression end-of-life. + // If the expression creation allocated a buffer on the heap inside the + // loop, this will ensure the buffer is properly deallocated. + if (elementValue.getType().isa() && + wasCreatedInCurrentBlock(elementValue, builder)) + builder.create(loc, elementValue); builder.restoreInsertionPoint(insPt); mlir::Value bufferizedExpr = @@ -437,10 +506,11 @@ auto module = this->getOperation(); auto *context = &getContext(); mlir::RewritePatternSet patterns(context); - patterns.insert(context); + patterns + .insert(context); mlir::ConversionTarget target(*context); target.addIllegalOp>) { + %must_free = arith.constant true + %expr = hlfir.as_expr %arg0 move %must_free: (!fir.ref>, i1) -> !hlfir.expr<100xi32> + hlfir.destroy %expr : !hlfir.expr<100xi32> + return +} +// CHECK-LABEL: func.func @test_move_with_cleanup( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>) { +// CHECK: %[[VAL_5:.*]] = fir.convert %[[VAL_0]] : (!fir.ref>) -> !fir.heap> +// CHECK: fir.freemem %[[VAL_5]] : !fir.heap> + +func.func @test_move_no_cleanup(%arg0 : !fir.ref>) { + %must_free = arith.constant false + %expr = hlfir.as_expr %arg0 move %must_free: (!fir.ref>, i1) -> 
!hlfir.expr<100xi32> + hlfir.destroy %expr : !hlfir.expr<100xi32> + return +} +// CHECK-LABEL: func.func @test_move_no_cleanup( +// CHECK-NOT: fir.freemem +// CHECK: return + +func.func @test_elemental() { + %c100 = arith.constant 100 : index + %c20 = arith.constant 20 : index + %0 = fir.shape %c100 : (index) -> !fir.shape<1> + %3 = hlfir.elemental %0 typeparams %c20 : (!fir.shape<1>, index) -> !hlfir.expr<100x!fir.char<1,20>> { + ^bb0(%i: index): + %buffer = fir.allocmem !fir.char<1,20> + %must_free = arith.constant true + %expr = hlfir.as_expr %buffer move %must_free: (!fir.heap>, i1) -> !hlfir.expr> + hlfir.yield_element %expr : !hlfir.expr> + } + return +} +// CHECK-LABEL: func.func @test_elemental( +// CHECK: fir.do_loop +// CHECK: %[[VAL_9:.*]] = fir.allocmem !fir.char<1,20> +// CHECK: hlfir.assign %[[VAL_9]] to %{{.*}} : !fir.heap>, !fir.ref> +// CHECK: fir.freemem %[[VAL_9]] : !fir.heap> +// CHECK: } +// CHECK: return + +func.func @test_elemental_expr_created_outside_of_loops() { + %buffer = fir.allocmem !fir.char<1,20> + %must_free = arith.constant true + %expr = hlfir.as_expr %buffer move %must_free: (!fir.heap>, i1) -> !hlfir.expr> + %c100 = arith.constant 100 : index + %c20 = arith.constant 20 : index + %0 = fir.shape %c100 : (index) -> !fir.shape<1> + %3 = hlfir.elemental %0 typeparams %c20 : (!fir.shape<1>, index) -> !hlfir.expr<100x!fir.char<1,20>> { + ^bb0(%i: index): + // No freemem should be inserted inside the loops. 
+ hlfir.yield_element %expr : !hlfir.expr> + } + hlfir.destroy %expr : !hlfir.expr> + return +} +// CHECK-LABEL: func.func @test_elemental_expr_created_outside_of_loops() { +// CHECK: %[[VAL_9:.*]] = fir.allocmem !fir.char<1,20> +// CHECK: fir.do_loop +// CHECK: hlfir.assign %[[VAL_9]] to %{{.*}} : !fir.heap>, !fir.ref> +// CHECK-NOT: fir.freemem +// CHECK: } +// CHECK: fir.freemem %[[VAL_9]] : !fir.heap> +// CHECK: return diff --git a/flang/test/HLFIR/destroy.fir b/flang/test/HLFIR/destroy.fir new file mode 100644 --- /dev/null +++ b/flang/test/HLFIR/destroy.fir @@ -0,0 +1,11 @@ +// Test hlfir.destroy operation parse, verify (no errors), and unparse. + +// RUN: fir-opt %s | fir-opt | FileCheck %s + +func.func @test(%expr : !hlfir.expr) { + hlfir.destroy %expr : !hlfir.expr + return +} +// CHECK-LABEL: func.func @test( +// CHECK-SAME: %[[VAL_0:.*]]: !hlfir.expr) { +// CHECK: hlfir.destroy %[[VAL_0]] : !hlfir.expr