diff --git a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp --- a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp +++ b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp @@ -73,7 +73,6 @@ class LoopVersioningPass : public fir::impl::LoopVersioningBase { - public: void runOnOperation() override; }; @@ -105,6 +104,7 @@ struct ArgInfo { mlir::Value *arg; size_t size; + unsigned rank; // Array rank; dims[0..rank-1] get filled in when versioning. fir::BoxDimsOp dims[CFI_MAX_RANK]; }; @@ -114,13 +114,11 @@ mlir::Block::BlockArgListType args = func.getArguments(); mlir::ModuleOp module = func->getParentOfType(); fir::KindMapping kindMap = fir::getKindMapping(module); - mlir::SmallVector argsOfInterest; + mlir::SmallVector argsOfInterest; for (auto &arg : args) { if (auto seqTy = getAsSequenceType(&arg)) { unsigned rank = seqTy.getDimension(); - // Currently limited to 1D or 2D arrays as that seems to give good - // improvement without excessive increase in code-size, etc. - if (rank > 0 && rank < 3 && + if (rank > 0 && seqTy.getShape()[0] == fir::SequenceType::getUnknownExtent()) { size_t typeSize = 0; mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType(arg.getType()); @@ -130,12 +128,9 @@ else if (auto cty = elementType.dyn_cast()) typeSize = 2 * cty.getEleType(kindMap).getIntOrFloatBitWidth() / 8; if (typeSize) - argsOfInterest.push_back({&arg, typeSize, {}}); + argsOfInterest.push_back({&arg, typeSize, rank, {}}); else LLVM_DEBUG(llvm::dbgs() << "Type not supported\n"); - - } else { - LLVM_DEBUG(llvm::dbgs() << "Too many dimensions\n"); } } } @@ -145,14 +140,14 @@ struct OpsWithArgs { mlir::Operation *op; - mlir::SmallVector argsAndDims; + mlir::SmallVector argsAndDims; }; // Now see if those arguments are used inside any loop.
mlir::SmallVector loopsOfInterest; func.walk([&](fir::DoLoopOp loop) { mlir::Block &body = *loop.getBody(); - mlir::SmallVector argsInLoop; + mlir::SmallVector argsInLoop; body.walk([&](fir::CoordinateOp op) { // The current operation could be inside another loop than // the one we're currently processing. Skip it, we'll get @@ -199,16 +194,16 @@ mlir::Value allCompares = nullptr; // Ensure all of the arrays are unit-stride. for (auto &arg : op.argsAndDims) { - - fir::SequenceType seqTy = getAsSequenceType(arg.arg); - unsigned rank = seqTy.getDimension(); - - // We only care about lowest order dimension. - for (unsigned i = 0; i < rank; i++) { + // Fetch the box dimensions of every dimension of the array: the byte + // stride of dims[0] is used for the unit-stride check, and the byte + // strides of dims[1..rank-1] scale the indices in the rewritten coordinate. + unsigned ndims = arg.rank; + for (unsigned i = 0; i < ndims; i++) { mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i); arg.dims[i] = builder.create(loc, idxTy, idxTy, idxTy, *arg.arg, dimIdx); } + // Only the lowest order dimension is checked for unit stride, here. mlir::Value elemSize = builder.createIntegerConstant(loc, idxTy, arg.size); mlir::Value cmp = builder.create( @@ -245,25 +240,41 @@ // Reduce the multi-dimensioned index to a single index. // This is required becase fir arrays do not support multiple dimensions // with unknown dimensions at compile time. + // We then calculate the linear index of the array like this: + // arr(x, y, z) becomes arr((z * stride(2) + y * stride(1)) / elemsize + x) + // where stride(n) is the byte stride of dimension n (x, y, z = 0, 1, 2) + // and elemsize is the element size in bytes.
if (coop->getOperand(0) == *arg.arg && coop->getOperands().size() >= 2) { builder.setInsertionPoint(coop); - mlir::Value totalIndex = builder.createIntegerConstant(loc, idxTy, 0); - // Operand(1) = array; Operand(2) = index1; Operand(3) = index2 - for (unsigned i = coop->getOperands().size() - 1; i > 1; i--) { + mlir::Value totalIndex; + for (unsigned i = arg.rank - 1; i > 0; i--) { + // Operand(0) = array; Operand(i + 1) = index for dims[i] (assumes one index operand per dimension - TODO confirm, the guard above only checks size() >= 2) mlir::Value curIndex = - builder.createConvert(loc, idxTy, coop->getOperand(i)); - // First arg is Operand2, so dims[i-2] is 0-based i-1! + builder.createConvert(loc, idxTy, coop->getOperand(i + 1)); + // Multiply by the byte stride of this dimension. Later we'll divide + // the sum by the element size. mlir::Value scale = - builder.createConvert(loc, idxTy, arg.dims[i - 2].getResult(1)); + builder.createConvert(loc, idxTy, arg.dims[i].getResult(2)); + curIndex = + builder.create(loc, scale, curIndex); + totalIndex = (totalIndex) ? builder.create( + loc, curIndex, totalIndex) + : curIndex; + } + mlir::Value elemSize = + builder.createIntegerConstant(loc, idxTy, arg.size); + // This is the lowest (unit-stride) dimension - which doesn't need scaling + mlir::Value finalIndex = + builder.createConvert(loc, idxTy, coop->getOperand(1)); + if (totalIndex) { totalIndex = builder.create( - loc, totalIndex, - builder.create(loc, scale, curIndex)); + loc, + builder.create(loc, totalIndex, elemSize), + finalIndex); + } else { + totalIndex = finalIndex; } - totalIndex = builder.create( - loc, totalIndex, - builder.createConvert(loc, idxTy, coop->getOperand(1))); - auto newOp = builder.create( loc, builder.getRefType(elementType), caddr, mlir::ValueRange{totalIndex}); diff --git a/flang/test/Transforms/loop-versioning.fir b/flang/test/Transforms/loop-versioning.fir --- a/flang/test/Transforms/loop-versioning.fir +++ b/flang/test/Transforms/loop-versioning.fir @@ -156,8 +156,7 @@ // CHECK: %[[CONV:.*]] = fir.convert %[[Y]] : {{.*}} // CHECK: %[[BOX_ADDR:.*]] =
fir.box_addr %[[CONV]] : {{.*}} // CHECK: fir.do_loop %[[INDEX:.*]] = {{.*}} -// CHECK: %[[IND_PLUS_1:.*]] = arith.addi %{{.*}}, %[[INDEX]] -// CHECK: %[[YADDR:.*]] = fir.coordinate_of %[[BOX_ADDR]], %[[IND_PLUS_1]] +// CHECK: %[[YADDR:.*]] = fir.coordinate_of %[[BOX_ADDR]], %[[INDEX]] // CHECK: %[[YINT:.*]] = fir.load %[[YADDR]] : {{.*}} // CHECK: %[[YINDEX:.*]] = fir.convert %[[YINT]] // CHECK: %[[XADDR:.*]] = fir.array_coor %[[X]] [%{{.*}}] %[[YINDEX]] @@ -269,7 +268,7 @@ // CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[CONV]] // CHECK: %[[RES:.*]] = fir.do_loop {{.*}} { // CHECK: %[[ADDR:.*]] = fir.coordinate_of %[[BOX_ADDR]], %{{.*}} -// CHECK: %45 = fir.load %[[ADDR]] : !fir.ref +// CHECK: %{{.*}} = fir.load %[[ADDR]] : !fir.ref // CHECK: } // CHECK: fir.result %[[RES]] : {{.*}} // CHECK: } else { @@ -355,19 +354,22 @@ // Only inner loop should be verisoned. // CHECK: fir.do_loop // CHECK: %[[ZERO:.*]] = arith.constant 0 : index -// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %[[ZERO]] : {{.*}} +// CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[ARG0]], %[[ZERO]] : {{.*}} +// CHECK: %[[ONE:.*]] = arith.constant 1 : index +// CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[ARG0]], %[[ONE]] : {{.*}} // CHECK: %[[SIZE:.*]] = arith.constant 8 : index -// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[DIMS]]#2, %[[SIZE]] +// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[DIMS0]]#2, %[[SIZE]] // CHECK: %[[IF_RES:.*]]:2 = fir.if %[[CMP]] -> {{.*}} // CHECK: %[[NEWARR:.*]] = fir.convert %[[ARG0]] // CHECK: %[[BOXADDR:.*]] = fir.box_addr %[[NEWARR]] : {{.*}} -> !fir.ref> // CHECK: %[[LOOP_RES:.*]]:2 = fir.do_loop {{.*}} // Check the 2D -> 1D coordinate conversion, should have a multiply and a final add. // Some other operations are checked to synch the different parts.
-// CHECK: arith.muli %[[DIMS]]#1, {{.*}} -// CHECK: %[[OUTER_IDX:.*]] = arith.addi {{.*}} +// CHECK: %[[OUTER_IDX:.*]] = arith.muli %[[DIMS1]]#2, {{.*}} +// CHECK: %[[ITEMSIZE:.*]] = arith.constant 8 : index // CHECK: %[[INNER_IDX:.*]] = fir.convert {{.*}} -// CHECK: %[[C2D:.*]] = arith.addi %[[OUTER_IDX]], %[[INNER_IDX]] +// CHECK: %[[OUTER_DIV:.*]] = arith.divsi %[[OUTER_IDX]], %[[ITEMSIZE]] +// CHECK: %[[C2D:.*]] = arith.addi %[[OUTER_DIV]], %[[INNER_IDX]] // CHECK: %[[COORD:.*]] = fir.coordinate_of %[[BOXADDR]], %[[C2D]] : (!fir.ref>, index) -> !fir.ref // CHECK: %{{.*}} = fir.load %[[COORD]] : !fir.ref // CHECK: fir.result %{{.*}}, %{{.*}} @@ -384,4 +386,136 @@ // CHECK: fir.store %[[IF_RES]]#1 to %{{.*}} // CHECK: return +// ----- + +// subroutine sum3d(a, nx, ny, nz) +// real*8 :: a(:, :, :) +// integer :: nx, ny, nz +// real*8 :: sum +// integer :: i, j, k +// sum = 0 +// do k=1,nz +// do j=1,ny +// do i=0,nx +// sum = sum + a(i, j, k) +// end do +// end do +// end do +// end subroutine sum3d + + + func.func @sum3d(%arg0: !fir.box> {fir.bindc_name = "a"}, %arg1: !fir.ref {fir.bindc_name = "nx"}, %arg2: !fir.ref {fir.bindc_name = "ny"}, %arg3: !fir.ref {fir.bindc_name = "nz"}) { + %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMmoduleFsum3dEi"} + %1 = fir.alloca i32 {bindc_name = "j", uniq_name = "_QMmoduleFsum3dEj"} + %2 = fir.alloca i32 {bindc_name = "k", uniq_name = "_QMmoduleFsum3dEk"} + %3 = fir.alloca f64 {bindc_name = "sum", uniq_name = "_QMmoduleFsum3dEsum"} + %cst = arith.constant 0.000000e+00 : f64 + fir.store %cst to %3 : !fir.ref + %c1_i32 = arith.constant 1 : i32 + %4 = fir.convert %c1_i32 : (i32) -> index + %5 = fir.load %arg3 : !fir.ref + %6 = fir.convert %5 : (i32) -> index + %c1 = arith.constant 1 : index + %7 = fir.convert %4 : (index) -> i32 + %8:2 = fir.do_loop %arg4 = %4 to %6 step %c1 iter_args(%arg5 = %7) -> (index, i32) { + fir.store %arg5 to %2 : !fir.ref + %c1_i32_0 = arith.constant 1 : i32 + %9 = fir.convert %c1_i32_0 :
(i32) -> index + %10 = fir.load %arg2 : !fir.ref + %11 = fir.convert %10 : (i32) -> index + %c1_1 = arith.constant 1 : index + %12 = fir.convert %9 : (index) -> i32 + %13:2 = fir.do_loop %arg6 = %9 to %11 step %c1_1 iter_args(%arg7 = %12) -> (index, i32) { + fir.store %arg7 to %1 : !fir.ref + %c0_i32 = arith.constant 0 : i32 + %18 = fir.convert %c0_i32 : (i32) -> index + %19 = fir.load %arg1 : !fir.ref + %20 = fir.convert %19 : (i32) -> index + %c1_2 = arith.constant 1 : index + %21 = fir.convert %18 : (index) -> i32 + %22:2 = fir.do_loop %arg8 = %18 to %20 step %c1_2 iter_args(%arg9 = %21) -> (index, i32) { + fir.store %arg9 to %0 : !fir.ref + %27 = fir.load %3 : !fir.ref + %28 = fir.load %0 : !fir.ref + %29 = fir.convert %28 : (i32) -> i64 + %c1_i64 = arith.constant 1 : i64 + %30 = arith.subi %29, %c1_i64 : i64 + %31 = fir.load %1 : !fir.ref + %32 = fir.convert %31 : (i32) -> i64 + %c1_i64_3 = arith.constant 1 : i64 + %33 = arith.subi %32, %c1_i64_3 : i64 + %34 = fir.load %2 : !fir.ref + %35 = fir.convert %34 : (i32) -> i64 + %c1_i64_4 = arith.constant 1 : i64 + %36 = arith.subi %35, %c1_i64_4 : i64 + %37 = fir.coordinate_of %arg0, %30, %33, %36 : (!fir.box>, i64, i64, i64) -> !fir.ref + %38 = fir.load %37 : !fir.ref + %39 = arith.addf %27, %38 fastmath : f64 + fir.store %39 to %3 : !fir.ref + %40 = arith.addi %arg8, %c1_2 : index + %41 = fir.convert %c1_2 : (index) -> i32 + %42 = fir.load %0 : !fir.ref + %43 = arith.addi %42, %41 : i32 + fir.result %40, %43 : index, i32 + } + fir.store %22#1 to %0 : !fir.ref + %23 = arith.addi %arg6, %c1_1 : index + %24 = fir.convert %c1_1 : (index) -> i32 + %25 = fir.load %1 : !fir.ref + %26 = arith.addi %25, %24 : i32 + fir.result %23, %26 : index, i32 + } + fir.store %13#1 to %1 : !fir.ref + %14 = arith.addi %arg4, %c1 : index + %15 = fir.convert %c1 : (index) -> i32 + %16 = fir.load %2 : !fir.ref + %17 = arith.addi %16, %15 : i32 + fir.result %14, %17 : index, i32 + } + fir.store %8#1 to %2 : !fir.ref + return + } + +// Note
this only checks the expected transformation, not the entire generated code: +// CHECK-LABEL: func.func @sum3d( +// CHECK-SAME: %[[ARG0:.*]]: !fir.box> {{.*}}) +// Only the innermost loop should be versioned. +// CHECK: fir.do_loop +// CHECK: %[[ZERO:.*]] = arith.constant 0 : index +// CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[ARG0]], %[[ZERO]] : {{.*}} +// CHECK: %[[ONE:.*]] = arith.constant 1 : index +// CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[ARG0]], %[[ONE]] : {{.*}} +// CHECK: %[[TWO:.*]] = arith.constant 2 : index +// CHECK: %[[DIMS2:.*]]:3 = fir.box_dims %[[ARG0]], %[[TWO]] : {{.*}} +// CHECK: %[[SIZE:.*]] = arith.constant 8 : index +// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[DIMS0]]#2, %[[SIZE]] +// CHECK: %[[IF_RES:.*]]:2 = fir.if %[[CMP]] -> {{.*}} +// CHECK: %[[NEWARR:.*]] = fir.convert %[[ARG0]] +// CHECK: %[[BOXADDR:.*]] = fir.box_addr %[[NEWARR]] : {{.*}} -> !fir.ref> +// CHECK: %[[LOOP_RES:.*]]:2 = fir.do_loop {{.*}} +// Check the 3D -> 1D coordinate conversion, should have two multiplies, an add, a divide by the item size and a final add. +// Some other operations are checked to synch the different parts.
+// CHECK: %[[OUTER_IDX:.*]] = arith.muli %[[DIMS2]]#2, {{.*}} +// CHECK: %[[MIDDLE_IDX:.*]] = arith.muli %[[DIMS1]]#2, {{.*}} +// CHECK: %[[MIDDLE_SUM:.*]] = arith.addi %[[MIDDLE_IDX]], %[[OUTER_IDX]] +// CHECK: %[[ITEMSIZE:.*]] = arith.constant 8 : index +// CHECK: %[[INNER_IDX:.*]] = fir.convert {{.*}} +// CHECK: %[[MIDDLE_DIV:.*]] = arith.divsi %[[MIDDLE_SUM]], %[[ITEMSIZE]] +// CHECK: %[[C3D:.*]] = arith.addi %[[MIDDLE_DIV]], %[[INNER_IDX]] +// CHECK: %[[COORD:.*]] = fir.coordinate_of %[[BOXADDR]], %[[C3D]] : (!fir.ref>, index) -> !fir.ref +// CHECK: %{{.*}} = fir.load %[[COORD]] : !fir.ref +// CHECK: fir.result %{{.*}}, %{{.*}} +// CHECK: } +// CHECK fir.result %[[LOOP_RES]]#0, %[[LOOP_RES]]#1 +// CHECK: } else { +// CHECK: %[[LOOP_RES2:.*]]:2 = fir.do_loop {{.*}} +// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[ARG0]], %{{.*}} : (!fir.box>, i64, i64, i64) -> !fir.ref +// CHECK: %{{.*}}= fir.load %[[COORD2]] : !fir.ref +// CHECK: fir.result %{{.*}}, %{{.*}} +// CHECK: } +// CHECK fir.result %[[LOOP_RES2]]#0, %[[LOOP_RES2]]#1 +// CHECK: } +// CHECK: fir.store %[[IF_RES]]#1 to %{{.*}} +// CHECK: return + } // End module