diff --git a/flang/include/flang/Optimizer/Builder/Character.h b/flang/include/flang/Optimizer/Builder/Character.h --- a/flang/include/flang/Optimizer/Builder/Character.h +++ b/flang/include/flang/Optimizer/Builder/Character.h @@ -187,6 +187,39 @@ mlir::FuncOp getLlvmMemset(FirOpBuilder &builder); mlir::FuncOp getRealloc(FirOpBuilder &builder); +//===----------------------------------------------------------------------===// +// Tools to work with character dummy procedures +//===----------------------------------------------------------------------===// + +/// Create a tuple type to pass character functions +/// as arguments along with their result length. The function type set in the +/// tuple is \p funcPointerType and the length type is i64. +mlir::Type getCharacterProcedureTupleType(mlir::Type funcPointerType); + +/// Is \p type a tuple holding a character function and its result length? +bool isCharacterProcedureTuple(mlir::Type type); + +/// Is \p tuple a value holding a character function address and its result +/// length? +inline bool isCharacterProcedureTuple(mlir::Value tuple) { + return isCharacterProcedureTuple(tuple.getType()); +} + +/// Create a tuple given \p addr and \p len as well as the tuple +/// type \p tupleType. \p addr can be any function address, and \p len can be +/// any integer. Conversions are inserted as needed when the types of \p addr +/// and \p len differ from the corresponding element types of \p tupleType. +mlir::Value createCharacterProcedureTuple(fir::FirOpBuilder &builder, + mlir::Location loc, + mlir::Type tupleType, + mlir::Value addr, mlir::Value len); + +/// Given a tuple containing a character function address and its result length, +/// extract the tuple into a pair of values (address, length).
+std::pair +extractCharacterProcedureTuple(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Value tuple); + } // namespace fir::factory #endif // FORTRAN_OPTIMIZER_BUILDER_CHARACTER_H diff --git a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h --- a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h +++ b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h @@ -68,6 +68,12 @@ /// Attribute to mark Fortran entities with the TARGET attribute. static constexpr llvm::StringRef getTargetAttrName() { return "fir.target"; } +/// Attribute to mark that a function argument is a character dummy procedure. +/// Character dummy procedures have special ABI constraints (split by target rewrite). +static constexpr llvm::StringRef getCharacterProcedureDummyAttrName() { + return "fir.char_proc"; +} + /// Tell if \p value is: /// - a function argument that has attribute \p attributeName /// - or, the result of fir.alloca/fir.allocamem op that has attribute \p diff --git a/flang/lib/Optimizer/Builder/Character.cpp b/flang/lib/Optimizer/Builder/Character.cpp --- a/flang/lib/Optimizer/Builder/Character.cpp +++ b/flang/lib/Optimizer/Builder/Character.cpp @@ -725,3 +725,51 @@ // Length cannot be deduced from memref.
return {}; } + +std::pair +fir::factory::extractCharacterProcedureTuple(fir::FirOpBuilder &builder, + mlir::Location loc, + mlir::Value tuple) { + mlir::TupleType tupleType = tuple.getType().cast(); + mlir::Value addr = builder.create( + loc, tupleType.getType(0), tuple, + builder.getArrayAttr( + {builder.getIntegerAttr(builder.getIndexType(), 0)})); + mlir::Value len = builder.create( + loc, tupleType.getType(1), tuple, + builder.getArrayAttr( + {builder.getIntegerAttr(builder.getIndexType(), 1)})); + return {addr, len}; +} + +mlir::Value fir::factory::createCharacterProcedureTuple( + fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type argTy, + mlir::Value addr, mlir::Value len) { + mlir::TupleType tupleType = argTy.cast(); + addr = builder.createConvert(loc, tupleType.getType(0), addr); + len = builder.createConvert(loc, tupleType.getType(1), len); + mlir::Value tuple = builder.create(loc, tupleType); + tuple = builder.create( + loc, tupleType, tuple, addr, + builder.getArrayAttr( + {builder.getIntegerAttr(builder.getIndexType(), 0)})); + tuple = builder.create( + loc, tupleType, tuple, len, + builder.getArrayAttr( + {builder.getIntegerAttr(builder.getIndexType(), 1)})); + return tuple; +} + +bool fir::factory::isCharacterProcedureTuple(mlir::Type ty) { + mlir::TupleType tuple = ty.dyn_cast(); + return tuple && tuple.size() == 2 && + tuple.getType(0).isa() && + fir::isa_integer(tuple.getType(1)); +} + +mlir::Type +fir::factory::getCharacterProcedureTupleType(mlir::Type funcPointerType) { + mlir::MLIRContext *context = funcPointerType.getContext(); + mlir::Type lenType = mlir::IntegerType::get(context, 64); + return mlir::TupleType::get(context, {funcPointerType, lenType}); +} diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt --- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt +++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt @@ -6,12 +6,14 @@ TargetRewrite.cpp DEPENDS + FIRBuilder FIRDialect FIRSupport
FIROptCodeGenPassIncGen CGOpsIncGen LINK_LIBS + FIRBuilder FIRDialect FIRSupport MLIROpenMPToLLVM diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -17,9 +17,11 @@ #include "PassDetail.h" #include "Target.h" #include "flang/Lower/Todo.h" +#include "flang/Optimizer/Builder/Character.h" #include "flang/Optimizer/CodeGen/CodeGen.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROps.h" +#include "flang/Optimizer/Dialect/FIROpsSupport.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/Support/FIRContext.h" #include "mlir/Transforms/DialectConversion.h" @@ -42,7 +44,8 @@ ReturnAsStore, ReturnType, Split, - Trailing + Trailing, + TrailingCharProc }; FixupTy(Codes code, std::size_t index, std::size_t second = 0) @@ -266,6 +269,41 @@ .template Case([&](mlir::ComplexType cmplx) { rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); }) + .template Case([&](mlir::TupleType tuple) { + if (factory::isCharacterProcedureTuple(tuple)) { + mlir::ModuleOp module = getModule(); + if constexpr (std::is_same_v, fir::CallOp>) { + if (callOp.callee()) { + llvm::StringRef charProcAttr = + fir::getCharacterProcedureDummyAttrName(); + // The charProcAttr attribute is only used as a safety check to + // confirm that this is a dummy procedure and should be split. + // It cannot be used for matching because attributes are not + // available in case of indirect calls.
+ auto funcOp = + module.lookupSymbol(*callOp.callee()); + if (funcOp && + !funcOp.template getArgAttrOfType( + index, charProcAttr)) + mlir::emitError(loc, "tuple argument will be split even " + "though it does not have the `" + + charProcAttr + "` attribute"); + } + } + mlir::Type funcPointerType = tuple.getType(0); + mlir::Type lenType = tuple.getType(1); + FirOpBuilder builder(*rewriter, getKindMapping(module)); + auto [funcPointer, len] = + factory::extractCharacterProcedureTuple(builder, loc, oper); + newInTys.push_back(funcPointerType); + newOpers.push_back(funcPointer); + trailingInTys.push_back(lenType); + trailingOpers.push_back(len); + } else { + newInTys.push_back(tuple); + newOpers.push_back(oper); + } + }) .Default([&](mlir::Type ty) { newInTys.push_back(ty); newOpers.push_back(oper); @@ -360,6 +398,14 @@ .Case([&](mlir::ComplexType ty) { lowerComplexSignatureArg(ty, newInTys); }) + .Case([&](mlir::TupleType tuple) { + if (factory::isCharacterProcedureTuple(tuple)) { + newInTys.push_back(tuple.getType(0)); + trailingInTys.push_back(tuple.getType(1)); + } else { + newInTys.push_back(ty); + } + }) .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); } // append trailing input types @@ -394,7 +440,8 @@ return false; } for (auto ty : func.getInputs()) - if ((ty.isa() && !noCharacterConversion) || + if (((ty.isa() || factory::isCharacterProcedureTuple(ty)) && + !noCharacterConversion) || (isa_complex(ty) && !noComplexConversion)) { LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); return false; @@ -476,6 +523,16 @@ else doComplexArg(func, cmplx, newInTys, fixups); }) + .Case([&](mlir::TupleType tuple) { + if (factory::isCharacterProcedureTuple(tuple)) { + fixups.emplace_back(FixupTy::Codes::TrailingCharProc, + newInTys.size(), trailingTys.size()); + newInTys.push_back(tuple.getType(0)); + trailingTys.push_back(tuple.getType(1)); + } else { + newInTys.push_back(ty); + } + }) .Default([&](mlir::Type ty) { newInTys.push_back(ty);
}); } @@ -604,6 +661,23 @@ func.getArgument(fixup.index + 1).replaceAllUsesWith(box); func.front().eraseArgument(fixup.index + 1); } break; + case FixupTy::Codes::TrailingCharProc: { + // The FIR character procedure argument tuple is split into a pair of + // arguments: the function address takes the original argument position + // and the length is appended after all the original arguments. A tuple + // is rebuilt from both at function entry to replace the old argument's uses. + auto newProcPointerArg = func.front().insertArgument( + fixup.index, newInTys[fixup.index], loc); + auto newLenArg = + func.front().addArgument(trailingTys[fixup.second], loc); + auto tupleType = oldArgTys[fixup.index - offset]; + rewriter->setInsertionPointToStart(&func.front()); + FirOpBuilder builder(*rewriter, getKindMapping(getModule())); + auto tuple = factory::createCharacterProcedureTuple( + builder, loc, tupleType, newProcPointerArg, newLenArg); + func.getArgument(fixup.index + 1).replaceAllUsesWith(tuple); + func.front().eraseArgument(fixup.index + 1); + } break; } } } diff --git a/flang/test/Fir/target-rewrite-char-proc.fir b/flang/test/Fir/target-rewrite-char-proc.fir new file mode 100644 --- /dev/null +++ b/flang/test/Fir/target-rewrite-char-proc.fir @@ -0,0 +1,69 @@ +// Test rewrite of character procedure pointer tuple argument into two different +// arguments: one for the function address, and one for the length. The length
+// argument is appended at the end of the argument list, like character lengths. +// RUN: fir-opt --target-rewrite="target=x86_64-unknown-linux-gnu" %s | FileCheck %s + +// CHECK: func private @takes_char_proc(() -> () {fir.char_proc}, i64) +func private @takes_char_proc(tuple<() -> (), i64> {fir.char_proc}) + +func private @takes_char(!fir.boxchar<1>) +func private @char_proc(!fir.ref>, index) -> !fir.boxchar<1> + +func @_QPcst_len() { + %0 = fir.address_of(@char_proc) : (!fir.ref>, index) -> !fir.boxchar<1> + %c7_i64 = arith.constant 7 : i64 + %1 = fir.convert %0 : ((!fir.ref>, index) -> !fir.boxchar<1>) -> (() -> ()) + %2 = fir.undefined tuple<() -> (), i64> + %3 = fir.insert_value %2, %1, [0 : index] : (tuple<() -> (), i64>, () -> ()) -> tuple<() -> (), i64> + %4 = fir.insert_value %3, %c7_i64, [1 : index] : (tuple<() -> (), i64>, i64) -> tuple<() -> (), i64> + + // CHECK: %[[PROC_ADDR:.*]] = fir.extract_value %{{.*}}, [0 : index] : (tuple<() -> (), i64>) -> (() -> ()) + // CHECK: %[[LEN:.*]] = fir.extract_value %{{.*}}, [1 : index] : (tuple<() -> (), i64>) -> i64 + // CHECK: fir.call @takes_char_proc(%[[PROC_ADDR]], %[[LEN]]) : (() -> (), i64) -> () + fir.call @takes_char_proc(%4) : (tuple<() -> (), i64>) -> () + return +} + +// CHECK: func @test_dummy_proc_that_takes_dummy_char_proc( +// CHECK-SAME: %[[ARG0:.*]]: () -> ()) { +func @test_dummy_proc_that_takes_dummy_char_proc(%arg0: () -> ()) { + %0 = fir.address_of(@char_proc) : (!fir.ref>, index) -> !fir.boxchar<1> + %c7_i64 = arith.constant 7 : i64 + %1 = fir.convert %0 : ((!fir.ref>, index) -> !fir.boxchar<1>) -> (() -> ()) + %2 = fir.undefined tuple<() -> (), i64> + %3 = fir.insert_value %2, %1, [0 : index] : (tuple<() -> (), i64>, () -> ()) -> tuple<() -> (), i64> + %4 = fir.insert_value %3, %c7_i64, [1 : index] : (tuple<() -> (), i64>, i64) -> tuple<() -> (), i64> + %5 = fir.convert %arg0 : (() -> ()) -> ((tuple<() -> (), i64>) -> ()) + + // CHECK: %[[ARG_CAST:.*]] = fir.convert %[[ARG0]] : (() -> ()) -> ((() -> (), i64) -> ()) + // CHECK: %[[PROC_ADDR:.*]] = fir.extract_value %4, [0
: index] : (tuple<() -> (), i64>) -> (() -> ()) + // CHECK: %[[PROC_LEN:.*]] = fir.extract_value %4, [1 : index] : (tuple<() -> (), i64>) -> i64 + // CHECK: fir.call %[[ARG_CAST]](%[[PROC_ADDR]], %[[PROC_LEN]]) : (() -> (), i64) -> () + fir.call %5(%4) : (tuple<() -> (), i64>) -> () + return +} + +// CHECK: func @takes_dummy_char_proc_impl( +// CHECK-SAME: %[[PROC_ADDR:.*]]: () -> () {fir.char_proc}, +// CHECK-SAME: %[[C_ADDR:.*]]: !fir.ref>, +// CHECK-SAME: %[[PROC_LEN:.*]]: i64, +// CHECK-SAME: %[[C_LEN:.*]]: i64) { +func @takes_dummy_char_proc_impl(%arg0: tuple<() -> (), i64> {fir.char_proc}, %arg1: !fir.boxchar<1>) { + // CHECK: %[[UNDEF:.*]] = fir.undefined tuple<() -> (), i64> + // CHECK: %[[TUPLE0:.*]] = fir.insert_value %[[UNDEF]], %[[PROC_ADDR]], [0 : index] : (tuple<() -> (), i64>, () -> ()) -> tuple<() -> (), i64> + // CHECK: %[[TUPLE1:.*]] = fir.insert_value %[[TUPLE0]], %[[PROC_LEN]], [1 : index] : (tuple<() -> (), i64>, i64) -> tuple<() -> (), i64> + %0 = fir.alloca !fir.char<1,7> {bindc_name = ".result"} + %1:2 = fir.unboxchar %arg1 : (!fir.boxchar<1>) -> (!fir.ref>, index) + %c5 = arith.constant 5 : index + %2 = fir.emboxchar %1#0, %c5 : (!fir.ref>, index) -> !fir.boxchar<1> + %3 = fir.extract_value %arg0, [0 : index] : (tuple<() -> (), i64>) -> (() -> ()) + %c7_i64 = arith.constant 7 : i64 + %4 = fir.convert %c7_i64 : (i64) -> index + %6 = fir.convert %3 : (() -> ()) -> ((!fir.ref>, index, !fir.boxchar<1>) -> !fir.boxchar<1>) + %7 = fir.call %6(%0, %4, %2) : (!fir.ref>, index, !fir.boxchar<1>) -> !fir.boxchar<1> + %8 = fir.convert %0 : (!fir.ref>) -> !fir.ref> + %9 = fir.emboxchar %8, %4 : (!fir.ref>, index) -> !fir.boxchar<1> + fir.call @takes_char(%9) : (!fir.boxchar<1>) -> () + return +} +