diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -1266,6 +1266,28 @@
     return CFI_attribute_other;
   }
 
+  /// Compute the size in bytes, as an i64 value, of an entity of type
+  /// \p charTy. For a constant length, the stride of the converted type
+  /// already covers the whole string. Otherwise that stride is the size of a
+  /// single character and is multiplied by the runtime length, taken from the
+  /// last entry of \p lenParams (which must not be empty in that case).
+  mlir::Value getCharacterByteSize(mlir::Location loc,
+                                   mlir::ConversionPatternRewriter &rewriter,
+                                   fir::CharacterType charTy,
+                                   mlir::ValueRange lenParams) const {
+    auto i64Ty = mlir::IntegerType::get(rewriter.getContext(), 64);
+    mlir::Value size =
+        genTypeStrideInBytes(loc, i64Ty, rewriter, this->convertType(charTy));
+    if (charTy.hasConstantLen())
+      return size; // Length accounted for in the genTypeStrideInBytes GEP.
+    // Otherwise, multiply the single character size by the length.
+    assert(!lenParams.empty() &&
+           "dynamic-length character requires a length parameter");
+    auto len64 = FIROpConversion<OP>::integerCast(loc, rewriter, i64Ty,
+                                                  lenParams.back());
+    return rewriter.create<mlir::LLVM::MulOp>(loc, i64Ty, size, len64);
+  }
+
   // Get the element size and CFI type code of the boxed value.
   std::tuple<mlir::Value, mlir::Value> getSizeAndTypeCode(
       mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
@@ -1286,18 +1308,9 @@
       return {genTypeStrideInBytes(loc, i64Ty, rewriter,
                                    this->convertType(boxEleTy)),
               typeCodeVal};
-    if (auto charTy = boxEleTy.dyn_cast<fir::CharacterType>()) {
-      mlir::Value size =
-          genTypeStrideInBytes(loc, i64Ty, rewriter, this->convertType(charTy));
-      if (charTy.getLen() == fir::CharacterType::unknownLen()) {
-        // Multiply the single character size by the length.
-        assert(!lenParams.empty());
-        auto len64 = FIROpConversion<OP>::integerCast(loc, rewriter, i64Ty,
-                                                      lenParams.back());
-        size = rewriter.create<mlir::LLVM::MulOp>(loc, i64Ty, size, len64);
-      }
-      return {size, typeCodeVal};
-    };
+    if (auto charTy = boxEleTy.dyn_cast<fir::CharacterType>())
+      return {getCharacterByteSize(loc, rewriter, charTy, lenParams),
+              typeCodeVal};
     if (fir::isa_ref_type(boxEleTy)) {
       auto ptrTy = mlir::LLVM::LLVMPointerType::get(
           mlir::LLVM::LLVMVoidType::get(rewriter.getContext()));
@@ -1691,7 +1704,7 @@
       sourceBox = operands[xbox.getSourceBoxOffset()];
       sourceBoxType = xbox.getSourceBox().getType();
     }
-    auto [boxTy, dest, eleSize] = consDescriptorPrefix(
+    auto [boxTy, dest, resultEleSize] = consDescriptorPrefix(
         xbox, fir::unwrapRefType(xbox.getMemref().getType()), rewriter,
         xbox.getOutRank(), adaptor.getSubstr(), adaptor.getLenParams(),
         sourceBox, sourceBoxType);
@@ -1720,7 +1733,8 @@
     // Adjust the element scaling factor if the element is a dependent type.
     if (fir::hasDynamicSize(seqEleTy)) {
       if (auto charTy = seqEleTy.dyn_cast<fir::CharacterType>()) {
-        prevPtrOff = eleSize;
+        prevPtrOff =
+            getCharacterByteSize(loc, rewriter, charTy, adaptor.getLenParams());
       } else if (seqEleTy.isa<fir::RecordType>()) {
         // prevPtrOff = ;
         TODO(loc, "generate call to calculate size of PDT");
@@ -1734,8 +1748,10 @@
     const auto hasSubcomp = !xbox.getSubcomponent().empty();
     const bool hasSubstr = !xbox.getSubstr().empty();
-    // Initial element stride that will be use to compute the step in
-    // each dimension.
-    mlir::Value prevDimByteStride = eleSize;
+    // Initial element stride that will be used to compute the step in
+    // each dimension. Initially, this is the size of the input element.
+    // Note that when there are no components/substring, the resultEleSize
+    // that was previously computed matches the input element size.
+    mlir::Value prevDimByteStride = resultEleSize;
     if (hasSubcomp) {
       // We have a subcomponent. The step value needs to be the number of
       // bytes per element (which is a derived type).
diff --git a/flang/test/Fir/embox-substring.fir b/flang/test/Fir/embox-substring.fir
--- a/flang/test/Fir/embox-substring.fir
+++ b/flang/test/Fir/embox-substring.fir
@@ -11,3 +11,28 @@
   %3 = fir.embox %addr (%1) [%2] : (!fir.ref>>, !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box>>
   return
 }
+
+// CHARACTER(*) :: C(2)
+// CALL DUMP(C(:)(1:1))
+// Test that the stride stored in the result descriptor (inserted at [7, 0, 2]) is scaled by the base length %base_len, not by the substring length.
+func.func @substring_dyn_base(%base_addr: !fir.ref>>, %base_len: index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %2 = fircg.ext_embox %base_addr(%c2)[%c1, %c2, %c1] substr %c0, %c1 typeparams %base_len : (!fir.ref>>, index, index, index, index, index, index, index) -> !fir.box>>
+  fir.call @dump(%2) : (!fir.box>>) -> ()
+  return
+}
+func.func private @dump(!fir.box>>)
+
+// CHECK-LABEL: llvm.func @substring_dyn_base(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr,
+// CHECK-SAME: %[[VAL_1:.*]]: i64) {
+// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: llvm.getelementptr
+// CHECK: %[[VAL_28:.*]] = llvm.mlir.null : !llvm.ptr
+// CHECK: %[[VAL_29:.*]] = llvm.getelementptr %[[VAL_28]][1] : (!llvm.ptr) -> !llvm.ptr
+// CHECK: %[[VAL_30:.*]] = llvm.ptrtoint %[[VAL_29]] : !llvm.ptr to i64
+// CHECK: %[[VAL_31:.*]] = llvm.mul %[[VAL_30]], %[[VAL_1]] : i64
+// CHECK: %[[VAL_42:.*]] = llvm.mul %[[VAL_31]], %[[VAL_5]] : i64
+// CHECK: %[[VAL_43:.*]] = llvm.insertvalue %[[VAL_42]], %{{.*}}[7, 0, 2] : !llvm.struct<(ptr>>, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>