diff --git a/flang/include/flang/Optimizer/CodeGen/CGOps.td b/flang/include/flang/Optimizer/CodeGen/CGOps.td
--- a/flang/include/flang/Optimizer/CodeGen/CGOps.td
+++ b/flang/include/flang/Optimizer/CodeGen/CGOps.td
@@ -176,6 +176,16 @@
   let extraClassDeclaration = [{
     unsigned getRank();
+
+    // Shape is optional, but if it exists, it will be at offset 1.
+    unsigned shapeOffset() { return 1; }
+    unsigned shiftOffset() { return shapeOffset() + shape().size(); }
+    unsigned sliceOffset() { return shiftOffset() + shift().size(); }
+    unsigned subcomponentOffset() { return sliceOffset() + slice().size(); }
+    unsigned indicesOffset() {
+      return subcomponentOffset() + subcomponent().size();
+    }
+    unsigned lenParamsOffset() { return indicesOffset() + indices().size(); }
   }];
 }
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -66,6 +66,11 @@
     return lowerTy().convertType(ty);
   }
 
+  mlir::Type getVoidPtrType() const {
+    return mlir::LLVM::LLVMPointerType::get(
+        mlir::IntegerType::get(&lowerTy().getContext(), 8));
+  }
+
   mlir::LLVM::ConstantOp
   genI32Constant(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
                  int value) const {
@@ -126,6 +131,17 @@
     return rewriter.create<mlir::LLVM::LoadOp>(loc, ty, p);
   }
 
+  mlir::Value
+  loadStrideFromBox(mlir::Location loc, mlir::Value box, unsigned dim,
+                    mlir::ConversionPatternRewriter &rewriter) const {
+    auto idxTy = lowerTy().indexType();
+    auto c0 = genConstantOffset(loc, rewriter, 0);
+    auto cDims = genConstantOffset(loc, rewriter, kDimsPosInBox);
+    auto dimValue = genConstantIndex(loc, idxTy, rewriter, dim);
+    return loadFromOffset(loc, box, c0, cDims, dimValue, kDimStridePos, idxTy,
+                          rewriter);
+  }
+
   /// Read base address from a fir.box. Returned address has type ty.
   mlir::Value
   loadBaseAddrFromBox(mlir::Location loc, mlir::Type ty, mlir::Value box,
@@ -188,6 +204,12 @@
     return type;
   }
 
+  // Return the LLVM type of the base address given the LLVM type
+  // of the related descriptor (lowered fir.box type).
+  static mlir::Type getBaseAddrTypeFromBox(mlir::Type type) {
+    return getBoxEleTy(type, {kAddrPosInBox});
+  }
+
   template <typename... ARGS>
   mlir::LLVM::GEPOp genGEP(mlir::Location loc, mlir::Type ty,
                            mlir::ConversionPatternRewriter &rewriter,
@@ -2004,6 +2026,168 @@
   }
 };
 
+/// XArrayCoor is the address arithmetic on a dynamically shaped, sliced,
+/// and/or shifted array (see the static restriction on fir.coordinate_of).
+/// array_coor determines the coordinate (location) of a specific element.
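+/// As an illustration (a 1-d sketch, not part of the original comment): for
+/// an index `i`, lower bound `lb`, and slice triple `(slb, ub, step)`, the
+/// zero-based position in that dimension is
+///   (i - lb) * step + (slb - lb)
+/// which is then scaled by the dimension's stride: a byte stride read from
+/// the descriptor when the base is boxed, or an element stride accumulated
+/// from the extents when the base is contiguous.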
+struct XArrayCoorOpConversion
+    : public FIROpAndTypeConversion<fir::cg::XArrayCoorOp> {
+  using FIROpAndTypeConversion::FIROpAndTypeConversion;
+
+  mlir::LogicalResult
+  doRewrite(fir::cg::XArrayCoorOp coor, mlir::Type ty, OpAdaptor adaptor,
+            mlir::ConversionPatternRewriter &rewriter) const override {
+    auto loc = coor.getLoc();
+    mlir::ValueRange operands = adaptor.getOperands();
+    unsigned rank = coor.getRank();
+    assert(coor.indices().size() == rank);
+    assert(coor.shape().empty() || coor.shape().size() == rank);
+    assert(coor.shift().empty() || coor.shift().size() == rank);
+    assert(coor.slice().empty() || coor.slice().size() == 3 * rank);
+    mlir::Type idxTy = lowerTy().indexType();
+    mlir::Value one = genConstantIndex(loc, idxTy, rewriter, 1);
+    mlir::Value prevExt = one;
+    mlir::Value zero = genConstantIndex(loc, idxTy, rewriter, 0);
+    mlir::Value offset = zero;
+    const bool isShifted = !coor.shift().empty();
+    const bool isSliced = !coor.slice().empty();
+    const bool baseIsBoxed = coor.memref().getType().isa<fir::BoxType>();
+
+    auto indexOps = coor.indices().begin();
+    auto shapeOps = coor.shape().begin();
+    auto shiftOps = coor.shift().begin();
+    auto sliceOps = coor.slice().begin();
+    // For each dimension of the array, generate the offset calculation.
+    for (unsigned i = 0; i < rank;
+         ++i, ++indexOps, ++shapeOps, ++shiftOps, sliceOps += 3) {
+      mlir::Value index = integerCast(loc, rewriter, idxTy,
+                                      operands[coor.indicesOffset() + i]);
+      mlir::Value lb =
+          isShifted ? integerCast(loc, rewriter, idxTy,
+                                  operands[coor.shiftOffset() + i])
+                    : one;
+      mlir::Value step = one;
+      bool normalSlice = isSliced;
+      // Compute the zero-based index in dimension i of the element, applying
+      // potential triplets and lower bounds.
+      if (isSliced) {
+        mlir::Value ub = *(sliceOps + 1);
+        normalSlice = !mlir::isa_and_nonnull<fir::UndefOp>(ub.getDefiningOp());
+        if (normalSlice)
+          step = integerCast(loc, rewriter, idxTy, *(sliceOps + 2));
+      }
+      auto idx = rewriter.create<mlir::LLVM::SubOp>(loc, idxTy, index, lb);
+      mlir::Value diff =
+          rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, idx, step);
+      if (normalSlice) {
+        mlir::Value sliceLb = integerCast(loc, rewriter, idxTy,
+                                          operands[coor.sliceOffset() + i]);
+        auto adj = rewriter.create<mlir::LLVM::SubOp>(loc, idxTy, sliceLb, lb);
+        diff = rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, diff, adj);
+      }
+      // Update the offset given the stride and the zero-based index `diff`
+      // that was just computed.
+      if (baseIsBoxed) {
+        // Use the stride in bytes from the descriptor.
+        mlir::Value stride =
+            loadStrideFromBox(loc, adaptor.getOperands()[0], i, rewriter);
+        auto sc = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, diff, stride);
+        offset = rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, sc, offset);
+      } else {
+        // Use the stride computed at the previous iteration.
+        auto sc = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, diff, prevExt);
+        offset = rewriter.create<mlir::LLVM::AddOp>(loc, idxTy, sc, offset);
+        // Compute the next stride assuming contiguity of the base array
+        // (in element number).
+        auto nextExt = integerCast(loc, rewriter, idxTy,
+                                   operands[coor.shapeOffset() + i]);
+        prevExt =
+            rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, prevExt, nextExt);
+      }
+    }
+
+    // Add the computed offset to the base address.
+    if (baseIsBoxed) {
+      // Working with byte offsets. The base address is read from the fir.box
+      // and needs to be cast to an i8* to do the pointer arithmetic.
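+      // In LLVM-dialect pseudo-code, the emitted sequence is roughly (a
+      // sketch, not the verbatim output):
+      //   %base = llvm.load (llvm.getelementptr %box[0, kAddrPosInBox])
+      //   %p    = llvm.getelementptr (llvm.bitcast %base to i8*)[%offset]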
+      mlir::Type baseTy =
+          getBaseAddrTypeFromBox(adaptor.getOperands()[0].getType());
+      mlir::Value base =
+          loadBaseAddrFromBox(loc, baseTy, adaptor.getOperands()[0], rewriter);
+      mlir::Type voidPtrTy = getVoidPtrType();
+      base = rewriter.create<mlir::LLVM::BitcastOp>(loc, voidPtrTy, base);
+      llvm::SmallVector<mlir::Value> args{base, offset};
+      auto addr = rewriter.create<mlir::LLVM::GEPOp>(loc, voidPtrTy, args);
+      if (coor.subcomponent().empty()) {
+        rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(coor, baseTy, addr);
+        return success();
+      }
+      auto casted = rewriter.create<mlir::LLVM::BitcastOp>(loc, baseTy, addr);
+      args.clear();
+      args.push_back(casted);
+      args.push_back(zero);
+      if (!coor.lenParams().empty()) {
+        // If type parameters are present, then we don't want to use a GEPOp
+        // as below, as the LLVM struct type cannot be statically defined.
+        TODO(loc, "derived type with type parameters");
+      }
+      // TODO: array offset subcomponents must be converted to LLVM's
+      // row-major layout here.
+      for (auto i = coor.subcomponentOffset(); i != coor.indicesOffset(); ++i)
+        args.push_back(operands[i]);
+      rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(coor, baseTy, args);
+      return success();
+    }
+
+    // The array was not boxed, so it must be contiguous. `offset` is
+    // therefore an element offset, and the base type is kept in the GEP
+    // unless the element type size is itself dynamic.
+    mlir::Value base;
+    if (coor.subcomponent().empty()) {
+      // No subcomponent.
+      if (!coor.lenParams().empty()) {
+        // Type parameters. Adjust the element size explicitly.
+        auto eleTy = fir::dyn_cast_ptrEleTy(coor.getType());
+        assert(eleTy && "result must be a reference-like type");
+        if (fir::characterWithDynamicLen(eleTy)) {
+          assert(coor.lenParams().size() == 1);
+          auto bitsInChar = lowerTy().getKindMap().getCharacterBitsize(
+              eleTy.cast<fir::CharacterType>().getFKind());
+          auto scaling = genConstantIndex(loc, idxTy, rewriter, bitsInChar / 8);
+          auto scaledBySize =
+              rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, offset, scaling);
+          auto length =
+              integerCast(loc, rewriter, idxTy,
+                          adaptor.getOperands()[coor.lenParamsOffset()]);
+          offset = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, scaledBySize,
+                                                      length);
+        } else {
+          TODO(loc, "compute size of derived type with type parameters");
+        }
+      }
+      // Cast the base address to a pointer to T.
+      base = rewriter.create<mlir::LLVM::BitcastOp>(loc, ty,
+                                                    adaptor.getOperands()[0]);
+    } else {
+      // Operand #0 must have a pointer type. For subcomponent slicing, we
+      // want to cast away the array type and have a plain struct type.
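+      // For example (hypothetical types, for illustration only): a base of
+      // type !llvm.ptr<array<10 x struct<(f32, f32)>>> is rewritten to
+      // !llvm.ptr<struct<(f32, f32)>> before the subcomponent indices are
+      // appended to the GEP.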
+      mlir::Type ty0 = adaptor.getOperands()[0].getType();
+      auto ptrTy = ty0.dyn_cast<mlir::LLVM::LLVMPointerType>();
+      assert(ptrTy && "expected pointer type");
+      mlir::Type eleTy = ptrTy.getElementType();
+      while (auto arrTy = eleTy.dyn_cast<mlir::LLVM::LLVMArrayType>())
+        eleTy = arrTy.getElementType();
+      auto newTy = mlir::LLVM::LLVMPointerType::get(eleTy);
+      base = rewriter.create<mlir::LLVM::BitcastOp>(loc, newTy,
+                                                    adaptor.getOperands()[0]);
+    }
+    SmallVector<mlir::Value> args = {base, offset};
+    for (auto i = coor.subcomponentOffset(); i != coor.indicesOffset(); ++i)
+      args.push_back(operands[i]);
+    rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(coor, ty, args);
+    return success();
+  }
+};
+
 //
 // Primitive operations on Complex types
 //
@@ -2431,8 +2615,8 @@
       ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion,
       SliceOpConversion, StoreOpConversion, StringLitOpConversion,
       SubcOpConversion, UnboxCharOpConversion, UnboxProcOpConversion,
-      UndefOpConversion, UnreachableOpConversion, XEmboxOpConversion,
-      ZeroOpConversion>(typeConverter);
+      UndefOpConversion, UnreachableOpConversion, XArrayCoorOpConversion,
+      XEmboxOpConversion, ZeroOpConversion>(typeConverter);
   mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
   mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                           pattern);
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -1828,3 +1828,128 @@
 // CHECK: %[[BOX10:.*]] = llvm.insertvalue %[[ADDR_BITCAST]], %[[BOX9]][0 : i32] : !llvm.struct<(ptr<i32>, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>
 // CHECK: llvm.store %[[BOX10]], %[[ALLOCA]] : !llvm.ptr<struct<(ptr<i32>, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>>
 // CHECK: llvm.call @_QPtest_dt_callee(%1) : (!llvm.ptr<struct<(ptr<i32>, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>>) -> ()
+
+// -----
+
+// Test `fircg.ext_array_coor` conversion.
+
+// Conversion with only shape and indices.
+
+func @ext_array_coor0(%arg0: !fir.ref<!fir.array<?xi32>>) {
+  %c0 = arith.constant 0 : i64
+  %1 = fircg.ext_array_coor %arg0(%c0) <%c0> : (!fir.ref<!fir.array<?xi32>>, i64, i64) -> !fir.ref<i32>
+  return
+}
+
+// CHECK-LABEL: llvm.func @ext_array_coor0(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<i32>)
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[IDX:.*]] = llvm.sub %[[C0]], %[[C1]] : i64
+// CHECK: %[[DIFF0:.*]] = llvm.mul %[[IDX]], %[[C1]] : i64
+// CHECK: %[[SC:.*]] = llvm.mul %[[DIFF0]], %[[C1]] : i64
+// CHECK: %[[OFFSET:.*]] = llvm.add %[[SC]], %[[C0_1]] : i64
+// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG0]] : !llvm.ptr<i32> to !llvm.ptr<i32>
+// CHECK: %{{.*}} = llvm.getelementptr %[[BITCAST]][%[[OFFSET]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
+
+// Conversion with shift and slice.
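+// (All shift and slice operands below are zero, so the offset computed by
+// the pattern evaluates to ((0 - 0) * 0 + (0 - 0)) * 1 + 0 = 0; the CHECK
+// lines trace each intermediate step rather than the folded value.)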
+
+func @ext_array_coor1(%arg0: !fir.ref<!fir.array<?xi32>>) {
+  %c0 = arith.constant 0 : i64
+  %1 = fircg.ext_array_coor %arg0(%c0) origin %c0[%c0, %c0, %c0]<%c0> : (!fir.ref<!fir.array<?xi32>>, i64, i64, i64, i64, i64, i64) -> !fir.ref<i32>
+  return
+}
+
+// CHECK-LABEL: llvm.func @ext_array_coor1(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<i32>)
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[IDX:.*]] = llvm.sub %[[C0]], %[[C0]] : i64
+// CHECK: %[[DIFF0:.*]] = llvm.mul %[[IDX]], %[[C0]] : i64
+// CHECK: %[[ADJ:.*]] = llvm.sub %[[C0]], %[[C0]] : i64
+// CHECK: %[[DIFF1:.*]] = llvm.add %[[DIFF0]], %[[ADJ]] : i64
+// CHECK: %[[STRIDE:.*]] = llvm.mul %[[DIFF1]], %[[C1]] : i64
+// CHECK: %[[OFFSET:.*]] = llvm.add %[[STRIDE]], %[[C0_1]] : i64
+// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG0]] : !llvm.ptr<i32> to !llvm.ptr<i32>
+// CHECK: %{{.*}} = llvm.getelementptr %[[BITCAST]][%[[OFFSET]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
+
+// Conversion for a dynamic length char.
+
+func @ext_array_coor2(%arg0: !fir.ref<!fir.array<?x!fir.char<1,?>>>) {
+  %c0 = arith.constant 0 : i64
+  %1 = fircg.ext_array_coor %arg0(%c0) <%c0> : (!fir.ref<!fir.array<?x!fir.char<1,?>>>, i64, i64) -> !fir.ref<i32>
+  return
+}
+
+// CHECK-LABEL: llvm.func @ext_array_coor2(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<i8>)
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[IDX:.*]] = llvm.sub %[[C0]], %[[C1]] : i64
+// CHECK: %[[DIFF0:.*]] = llvm.mul %[[IDX]], %[[C1]] : i64
+// CHECK: %[[SC:.*]] = llvm.mul %[[DIFF0]], %[[C1]] : i64
+// CHECK: %[[OFFSET:.*]] = llvm.add %[[SC]], %[[C0_1]] : i64
+// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG0]] : !llvm.ptr<i8> to !llvm.ptr<i32>
+// CHECK: %{{.*}} = llvm.getelementptr %[[BITCAST]][%[[OFFSET]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
+
+// Conversion for a `fir.box`.
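+// (For a boxed base the stride is not derived from the shape: it is loaded
+// from the descriptor at [0, kDimsPosInBox (7), dim, kDimStridePos (2)], and
+// the resulting offset is in bytes, hence the bitcast of the base address to
+// an i8* before the final GEP.)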
+
+func @ext_array_coor3(%arg0: !fir.box<!fir.array<?xi32>>) {
+  %c0 = arith.constant 0 : i64
+  %1 = fircg.ext_array_coor %arg0(%c0) <%c0> : (!fir.box<!fir.array<?xi32>>, i64, i64) -> !fir.ref<i32>
+  return
+}
+
+// CHECK-LABEL: llvm.func @ext_array_coor3(
+// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr<struct<(ptr<i32>, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>>) {
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[IDX:.*]] = llvm.sub %[[C0]], %[[C1]] : i64
+// CHECK: %[[DIFF0:.*]] = llvm.mul %[[IDX]], %[[C1]] : i64
+// CHECK: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[DIMPOSINBOX:.*]] = llvm.mlir.constant(7 : i32) : i32
+// CHECK: %[[DIMOFFSET:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[STRIDEPOS:.*]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[GEPSTRIDE:.*]] = llvm.getelementptr %[[ARG0]][%[[C0_2]], %[[DIMPOSINBOX]], %[[DIMOFFSET]], %[[STRIDEPOS]]] : (!llvm.ptr<struct<(ptr<i32>, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>>, i32, i32, i64, i32) -> !llvm.ptr<i64>
+// CHECK: %[[LOADEDSTRIDE:.*]] = llvm.load %[[GEPSTRIDE]] : !llvm.ptr<i64>
+// CHECK: %[[SC:.*]] = llvm.mul %[[DIFF0]], %[[LOADEDSTRIDE]] : i64
+// CHECK: %[[OFFSET:.*]] = llvm.add %[[SC]], %[[C0_1]] : i64
+// CHECK: %[[C0_3:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[ADDRPOSINBOX:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[GEPADDR:.*]] = llvm.getelementptr %[[ARG0]][%[[C0_3]], %[[ADDRPOSINBOX]]] : (!llvm.ptr<struct<(ptr<i32>, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, i{{.*}}, array<1 x array<3 x i64>>)>>, i32, i32) -> !llvm.ptr<ptr<i32>>
+// CHECK: %[[LOADEDADDR:.*]] = llvm.load %[[GEPADDR]] : !llvm.ptr<ptr<i32>>
+// CHECK: %[[LOADEDADDRBITCAST:.*]] = llvm.bitcast %[[LOADEDADDR]] : !llvm.ptr<i32> to !llvm.ptr<i8>
+// CHECK: %[[GEPADDROFFSET:.*]] = llvm.getelementptr %[[LOADEDADDRBITCAST]][%[[OFFSET]]] : (!llvm.ptr<i8>, i64) -> !llvm.ptr<i8>
+// CHECK: %{{.*}} = llvm.bitcast %[[GEPADDROFFSET]] : !llvm.ptr<i8> to !llvm.ptr<i32>
+
+// Conversion with non-zero shift and slice.
+
+func @ext_array_coor4(%arg0: !fir.ref<!fir.array<100xi32>>) {
+  %c0 = arith.constant 0 : i64
+  %c10 = arith.constant 10 : i64
+  %c20 = arith.constant 20 : i64
+  %c1 = arith.constant 1 : i64
+  %1 = fircg.ext_array_coor %arg0(%c0) origin %c0[%c10, %c20, %c1]<%c1> : (!fir.ref<!fir.array<100xi32>>, i64, i64, i64, i64, i64, i64) -> !fir.ref<i32>
+  return
+}
+
+// CHECK-LABEL: llvm.func @ext_array_coor4(
+// CHECK: %[[ARG0:.*]]: !llvm.ptr<array<100 x i32>>) {
+// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : i64) : i64
+// CHECK: %[[C20:.*]] = llvm.mlir.constant(20 : i64) : i64
+// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[IDX:.*]] = llvm.sub %[[C1]], %[[C0]] : i64
+// CHECK: %[[DIFF0:.*]] = llvm.mul %[[IDX]], %[[C1]] : i64
+// CHECK: %[[ADJ:.*]] = llvm.sub %[[C10]], %[[C0]] : i64
+// CHECK: %[[DIFF1:.*]] = llvm.add %[[DIFF0]], %[[ADJ]] : i64
+// CHECK: %[[STRIDE:.*]] = llvm.mul %[[DIFF1]], %[[C1_1]] : i64
+// CHECK: %[[OFFSET:.*]] = llvm.add %[[STRIDE]], %[[C0_1]] : i64
+// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG0]] : !llvm.ptr<array<100 x i32>> to !llvm.ptr<i32>
+// CHECK: %{{.*}} = llvm.getelementptr %[[BITCAST]][%[[OFFSET]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
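+// (For this non-zero case, the emitted arithmetic evaluates to
+// ((1 - 0) * 1 + (10 - 0)) * 1 + 0 = 11, i.e. the GEP addresses element 11
+// of the base, in element units since the base is contiguous rather than
+// boxed.)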