diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -7,6 +7,9 @@
 //===----------------------------------------------------------------------===//
 // This file defines a pass that bufferize hlfir.expr. It translates operations
 // producing or consuming hlfir.expr into operations operating on memory.
+// An hlfir.expr is translated to a tuple<variable address, cleanupflag>
+// where cleanupflag is set to true if storage for the expression was allocated
+// on the heap.
 //===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/Builder/Character.h"
@@ -33,6 +36,62 @@
 
 namespace {
 
+/// Package a bufferized hlfir.expr as a tuple<storage, mustFree>, where
+/// mustFree is set to true when the storage was allocated on the heap and must
+/// be deallocated once the expression is no longer needed.
+static mlir::Value packageBufferizedExpr(mlir::Location loc,
+                                         fir::FirOpBuilder &builder,
+                                         mlir::Value storage,
+                                         mlir::Value mustFree) {
+  auto tupleType = mlir::TupleType::get(
+      builder.getContext(),
+      mlir::TypeRange{storage.getType(), mustFree.getType()});
+  auto undef = builder.create<fir::UndefOp>(loc, tupleType);
+  auto insert = builder.create<fir::InsertValueOp>(
+      loc, tupleType, undef, mustFree,
+      builder.getArrayAttr(
+          {builder.getIntegerAttr(builder.getIndexType(), 1)}));
+  return builder.create<fir::InsertValueOp>(
+      loc, tupleType, insert, storage,
+      builder.getArrayAttr(
+          {builder.getIntegerAttr(builder.getIndexType(), 0)}));
+}
+
+/// Same as above, with the mustFree flag given as a compile time constant.
+static mlir::Value packageBufferizedExpr(mlir::Location loc,
+                                         fir::FirOpBuilder &builder,
+                                         mlir::Value storage, bool mustFree) {
+  mlir::Value mustFreeValue = builder.createBool(loc, mustFree);
+  return packageBufferizedExpr(loc, builder, storage, mustFreeValue);
+}
+
+/// Get the storage (tuple element 0) of a bufferized expression. Values that
+/// are not tuples (e.g. trivial scalars) are returned unchanged. Only tuples
+/// produced by packageBufferizedExpr are currently handled.
+static mlir::Value getBufferizedExprStorage(mlir::Value bufferizedExpr) {
+  auto tupleType = bufferizedExpr.getType().dyn_cast<mlir::TupleType>();
+  if (!tupleType)
+    return bufferizedExpr;
+  if (auto insert = bufferizedExpr.getDefiningOp<fir::InsertValueOp>())
+    if (insert.getVal().getType() == tupleType.getType(0))
+      return insert.getVal();
+  TODO(bufferizedExpr.getLoc(), "general extract storage case");
+}
+
+/// Get the mustFree flag (tuple element 1) of a bufferized expression. Values
+/// that are not tuples are returned unchanged. Only tuples produced by
+/// packageBufferizedExpr are currently handled.
+static mlir::Value getBufferizedExprMustFreeFlag(mlir::Value bufferizedExpr) {
+  auto tupleType = bufferizedExpr.getType().dyn_cast<mlir::TupleType>();
+  if (!tupleType)
+    return bufferizedExpr;
+  if (auto insert = bufferizedExpr.getDefiningOp<fir::InsertValueOp>())
+    if (auto insert0 = insert.getAdt().getDefiningOp<fir::InsertValueOp>())
+      if (insert0.getVal().getType() == tupleType.getType(1))
+        return insert0.getVal();
+  TODO(bufferizedExpr.getLoc(), "general extract must free flag case");
+}
+
 struct AssignOpConversion : public mlir::OpConversionPattern<hlfir::AssignOp> {
   using mlir::OpConversionPattern<hlfir::AssignOp>::OpConversionPattern;
   explicit AssignOpConversion(mlir::MLIRContext *ctx)
@@ -41,7 +100,8 @@
   matchAndRewrite(hlfir::AssignOp assign, OpAdaptor adaptor,
                   mlir::ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<hlfir::AssignOp>(
-        assign, adaptor.getOperands()[0], adaptor.getOperands()[1]);
+        assign, getBufferizedExprStorage(adaptor.getOperands()[0]),
+        getBufferizedExprStorage(adaptor.getOperands()[1]));
     return mlir::success();
   }
 };
@@ -61,21 +121,94 @@
     if (adaptor.getStrings().size() > 2)
       TODO(loc, "codegen of optimized chained concatenation of more than two "
                 "strings");
-    hlfir::Entity lhs{adaptor.getStrings()[0]};
-    hlfir::Entity rhs{adaptor.getStrings()[1]};
+    hlfir::Entity lhs{getBufferizedExprStorage(adaptor.getStrings()[0])};
+    hlfir::Entity rhs{getBufferizedExprStorage(adaptor.getStrings()[1])};
     auto [lhsExv, c1] = hlfir::translateToExtendedValue(loc, builder, lhs);
     auto [rhsExv, c2] = hlfir::translateToExtendedValue(loc, builder, rhs);
     assert(!c1 && !c2 && "expected variables");
     fir::ExtendedValue res =
         fir::factory::CharacterExprHelper{builder, loc}.createConcatenate(
             *lhsExv.getCharBox(), *rhsExv.getCharBox());
+    // Ensure the memory type is the same as the result type.
+    mlir::Type addrType = fir::ReferenceType::get(
+        hlfir::getFortranElementType(concat.getResult().getType()));
+    mlir::Value cast = builder.createConvert(loc, addrType, fir::getBase(res));
+    res = fir::substBase(res, cast);
     auto hlfirTempRes = hlfir::genDeclare(loc, builder, res, "tmp",
                                           fir::FortranVariableFlagsAttr{});
-    rewriter.replaceOp(concat, hlfirTempRes);
+    mlir::Value bufferizedExpr =
+        packageBufferizedExpr(loc, builder, hlfirTempRes, false);
+    rewriter.replaceOp(concat, bufferizedExpr);
     return mlir::success();
   }
 };
 
+/// Convert hlfir.associate. When the associated hlfir.expr is only used by
+/// this operation, its bufferized storage and mustFree flag are re-used
+/// directly. Trivial values are copied into a new temporary.
+struct AssociateOpConversion
+    : public mlir::OpConversionPattern<hlfir::AssociateOp> {
+  using mlir::OpConversionPattern<hlfir::AssociateOp>::OpConversionPattern;
+  explicit AssociateOpConversion(mlir::MLIRContext *ctx)
+      : mlir::OpConversionPattern<hlfir::AssociateOp>{ctx} {}
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::AssociateOp associate, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::Location loc = associate->getLoc();
+    // If this is the last use of the expression value and this is an hlfir.expr
+    // that was bufferized, re-use the storage.
+    // Otherwise, for trivial values, create a temp and store the value in it.
+    mlir::Value bufferizedExpr = getBufferizedExprStorage(adaptor.getSource());
+    const bool isTrivialValue = fir::isa_trivial(bufferizedExpr.getType());
+
+    auto replaceWith = [&](mlir::Value hlfirVar, mlir::Value firVar,
+                           mlir::Value flag) {
+      associate.getResult(0).replaceAllUsesWith(hlfirVar);
+      associate.getResult(1).replaceAllUsesWith(firVar);
+      associate.getResult(2).replaceAllUsesWith(flag);
+      rewriter.replaceOp(associate, {hlfirVar, firVar, flag});
+    };
+
+    if (!isTrivialValue && associate.getSource().hasOneUse()) {
+      mlir::Value mustFree = getBufferizedExprMustFreeFlag(adaptor.getSource());
+      mlir::Value firBase = hlfir::Entity{bufferizedExpr}.getFirBase();
+      replaceWith(bufferizedExpr, firBase, mustFree);
+      return mlir::success();
+    }
+    if (isTrivialValue) {
+      auto module = associate->getParentOfType<mlir::ModuleOp>();
+      fir::FirOpBuilder builder(rewriter, fir::getKindMapping(module));
+      auto temp = builder.createTemporary(loc, bufferizedExpr.getType(),
+                                          associate.getUniqName());
+      builder.create<fir::StoreOp>(loc, bufferizedExpr, temp);
+      mlir::Value mustFree = builder.createBool(loc, false);
+      replaceWith(temp, temp, mustFree);
+      return mlir::success();
+    }
+    TODO(loc, "hlfir.associate of hlfir.expr with more than one use");
+  }
+};
+
+/// Convert hlfir.end_associate. Only the case where the mustFree flag is the
+/// constant false is currently supported: the operation is then simply erased.
+struct EndAssociateOpConversion
+    : public mlir::OpConversionPattern<hlfir::EndAssociateOp> {
+  using mlir::OpConversionPattern<hlfir::EndAssociateOp>::OpConversionPattern;
+  explicit EndAssociateOpConversion(mlir::MLIRContext *ctx)
+      : mlir::OpConversionPattern<hlfir::EndAssociateOp>{ctx} {}
+  mlir::LogicalResult
+  matchAndRewrite(hlfir::EndAssociateOp endAssociate, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::Value mustFree = adaptor.getMustFree();
+    if (auto cstMustFree = fir::factory::getIntIfConstant(mustFree))
+      if (*cstMustFree == 0) {
+        rewriter.eraseOp(endAssociate);
+        return mlir::success(); // nothing to do.
+      }
+    TODO(endAssociate.getLoc(), "conditional free");
+  }
+};
+
 class BufferizeHLFIR : public hlfir::impl::BufferizeHLFIRBase<BufferizeHLFIR> {
 public:
   void runOnOperation() override {
@@ -89,8 +222,10 @@
     auto module = this->getOperation();
     auto *context = &getContext();
     mlir::RewritePatternSet patterns(context);
-    patterns.insert<AssignOpConversion, ConcatOpConversion>(context);
+    patterns.insert<AssignOpConversion, AssociateOpConversion,
+                    ConcatOpConversion, EndAssociateOpConversion>(context);
     mlir::ConversionTarget target(*context);
+    target.addIllegalOp<hlfir::AssociateOp, hlfir::EndAssociateOp>();
     target.markUnknownOpDynamicallyLegal([](mlir::Operation *op) {
       return llvm::all_of(
           op->getResultTypes(),
diff --git a/flang/test/HLFIR/associate-codegen.fir b/flang/test/HLFIR/associate-codegen.fir
new file mode 100644
--- /dev/null
+++ b/flang/test/HLFIR/associate-codegen.fir
@@ -0,0 +1,85 @@
+// Test hlfir.associate/hlfir.end_associate operation code generation to FIR.
+
+// RUN: fir-opt %s -bufferize-hlfir | FileCheck %s
+
+func.func @associate_int() {
+  %c42_i32 = arith.constant 42 : i32
+  %0:3 = hlfir.associate %c42_i32 {uniq_name = "x"} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
+  fir.call @take_i4(%0#0) : (!fir.ref<i32>) -> ()
+  hlfir.end_associate %0#1, %0#2 : !fir.ref<i32>, i1
+  return
+}
+// CHECK-LABEL: func.func @associate_int() {
+// CHECK: %[[VAL_0:.*]] = fir.alloca i32 {bindc_name = "x"}
+// CHECK: %[[VAL_1:.*]] = arith.constant 42 : i32
+// CHECK: fir.store %[[VAL_1]] to %[[VAL_0]] : !fir.ref<i32>
+// CHECK: %[[VAL_2:.*]] = arith.constant false
+// CHECK: fir.call @take_i4(%[[VAL_0]]) : (!fir.ref<i32>) -> ()
+// CHECK-NOT: fir.freemem
+
+
+func.func @associate_real() {
+  %cst = arith.constant 4.200000e-01 : f32
+  %0:3 = hlfir.associate %cst {uniq_name = "x"} : (f32) -> (!fir.ref<f32>, !fir.ref<f32>, i1)
+  fir.call @take_r4(%0#0) : (!fir.ref<f32>) -> ()
+  hlfir.end_associate %0#1, %0#2 : !fir.ref<f32>, i1
+  return
+}
+// CHECK-LABEL: func.func @associate_real() {
+// CHECK: %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "x"}
+// CHECK: %[[VAL_1:.*]] = arith.constant 4.200000e-01 : f32
+// CHECK: fir.store %[[VAL_1]] to %[[VAL_0]] : !fir.ref<f32>
+// CHECK: %[[VAL_2:.*]] = arith.constant false
+// CHECK: fir.call @take_r4(%[[VAL_0]]) : (!fir.ref<f32>) -> ()
+// CHECK-NOT: fir.freemem
+
+
+func.func @associate_logical() {
+  %true = arith.constant true
+  %0 = fir.convert %true : (i1) -> !fir.logical<4>
+  %1:3 = hlfir.associate %0 {uniq_name = "x"} : (!fir.logical<4>) -> (!fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>, i1)
+  fir.call @take_l4(%1#0) : (!fir.ref<!fir.logical<4>>) -> ()
+  hlfir.end_associate %1#1, %1#2 : !fir.ref<!fir.logical<4>>, i1
+  return
+}
+// CHECK-LABEL: func.func @associate_logical() {
+// CHECK: %[[VAL_0:.*]] = fir.alloca !fir.logical<4> {bindc_name = "x"}
+// CHECK: %[[VAL_1:.*]] = arith.constant true
+// CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (i1) -> !fir.logical<4>
+// CHECK: fir.store %[[VAL_2]] to %[[VAL_0]] : !fir.ref<!fir.logical<4>>
+// CHECK: %[[VAL_3:.*]] = arith.constant false
+// CHECK: fir.call @take_l4(%[[VAL_0]]) : (!fir.ref<!fir.logical<4>>) -> ()
+// CHECK-NOT: fir.freemem
+
+
+func.func @associate_char(%arg0: !fir.boxchar<1> ) {
+  %0:2 = fir.unboxchar %arg0 : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
+  %1:2 = hlfir.declare %0#0 typeparams %0#1 {uniq_name = "x"} : (!fir.ref<!fir.char<1,?>>, index) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
+  %2 = arith.addi %0#1, %0#1 : index
+  %3 = hlfir.concat %1#0, %1#0 len %2 : (!fir.boxchar<1>, !fir.boxchar<1>, index) -> !hlfir.expr<!fir.char<1,?>>
+  %4:3 = hlfir.associate %3 typeparams %2 {uniq_name = "x"} : (!hlfir.expr<!fir.char<1,?>>, index) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>, i1)
+  fir.call @take_c(%4#0) : (!fir.boxchar<1>) -> ()
+  hlfir.end_associate %4#1, %4#2 : !fir.ref<!fir.char<1,?>>, i1
+  return
+}
+// CHECK-LABEL: func.func @associate_char(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.boxchar<1>) {
+// CHECK: %[[VAL_1:.*]]:2 = fir.unboxchar %[[VAL_0]] : (!fir.boxchar<1>) -> (!fir.ref<!fir.char<1,?>>, index)
+// CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_1]]#0 typeparams %[[VAL_1]]#1 {uniq_name = "x"} : (!fir.ref<!fir.char<1,?>>, index) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
+// CHECK: %[[VAL_3:.*]] = arith.addi %[[VAL_1]]#1, %[[VAL_1]]#1 : index
+// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_1]]#1, %[[VAL_1]]#1 : index
+// CHECK: %[[VAL_5:.*]] = fir.alloca !fir.char<1,?>(%[[VAL_4]] : index) {bindc_name = ".chrtmp"}
+// CHECK: fir.call @llvm.memmove.p0.p0.i64
+// CHECK: %[[VAL_21:.*]]:2 = hlfir.declare %[[VAL_5]] typeparams %[[VAL_4]] {uniq_name = "tmp"} : (!fir.ref<!fir.char<1,?>>, index) -> (!fir.boxchar<1>, !fir.ref<!fir.char<1,?>>)
+// CHECK: %[[VAL_22:.*]] = arith.constant false
+// CHECK: %[[VAL_23:.*]] = fir.undefined tuple<!fir.boxchar<1>, i1>
+// CHECK: %[[VAL_24:.*]] = fir.insert_value %[[VAL_23]], %[[VAL_22]], [1 : index] : (tuple<!fir.boxchar<1>, i1>, i1) -> tuple<!fir.boxchar<1>, i1>
+// CHECK: %[[VAL_25:.*]] = fir.insert_value %[[VAL_24]], %[[VAL_21]]#0, [0 : index] : (tuple<!fir.boxchar<1>, i1>, !fir.boxchar<1>) -> tuple<!fir.boxchar<1>, i1>
+// CHECK: fir.call @take_c(%[[VAL_21]]#0) : (!fir.boxchar<1>) -> ()
+// CHECK-NOT: fir.freemem
+
+
+func.func private @take_i4(!fir.ref<i32>)
+func.func private @take_r4(!fir.ref<f32>)
+func.func private @take_l4(!fir.ref<!fir.logical<4>>)
+func.func private @take_c(!fir.boxchar<1>)