diff --git a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
--- a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
+++ b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
@@ -105,6 +105,7 @@
   struct ArgInfo {
     mlir::Value *arg;
     size_t size;
+    unsigned rank;
     fir::BoxDimsOp dims[CFI_MAX_RANK];
   };
 
@@ -114,13 +115,13 @@
   mlir::Block::BlockArgListType args = func.getArguments();
   mlir::ModuleOp module = func->getParentOfType<mlir::ModuleOp>();
   fir::KindMapping kindMap = fir::getKindMapping(module);
   mlir::SmallVector<ArgInfo, 4> argsOfInterest;
   for (auto &arg : args) {
     if (auto seqTy = getAsSequenceType(&arg)) {
       unsigned rank = seqTy.getDimension();
-      // Currently limited to 1D or 2D arrays as that seems to give good
-      // improvement without excessive increase in code-size, etc.
-      if (rank > 0 && rank < 3 &&
+      // Currently limited to 1D, 2D and 3D arrays, as that seems to give a
+      // good improvement without excessive increase in code size, etc.
+      if (rank > 0 && rank < 4 &&
           seqTy.getShape()[0] == fir::SequenceType::getUnknownExtent()) {
         size_t typeSize = 0;
         mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType(arg.getType());
@@ -130,7 +131,7 @@
       else if (auto cty = elementType.dyn_cast<fir::ComplexType>())
         typeSize = 2 * cty.getEleType(kindMap).getIntOrFloatBitWidth() / 8;
       if (typeSize)
-        argsOfInterest.push_back({&arg, typeSize, {}});
+        argsOfInterest.push_back({&arg, typeSize, rank, {}});
       else
         LLVM_DEBUG(llvm::dbgs() << "Type not supported\n");
 
@@ -145,14 +146,14 @@
 
   struct OpsWithArgs {
     mlir::Operation *op;
    mlir::SmallVector<ArgInfo, 4> argsAndDims;
   };
   // Now see if those arguments are used inside any loop.
   mlir::SmallVector<OpsWithArgs, 4> loopsOfInterest;
 
   func.walk([&](fir::DoLoopOp loop) {
     mlir::Block &body = *loop.getBody();
     mlir::SmallVector<ArgInfo, 4> argsInLoop;
     body.walk([&](fir::CoordinateOp op) {
       // The current operation could be inside another loop than
       // the one we're currently processing. Skip it, we'll get
@@ -199,16 +200,16 @@
     mlir::Value allCompares = nullptr;
     // Ensure all of the arrays are unit-stride.
     for (auto &arg : op.argsAndDims) {
-
-      fir::SequenceType seqTy = getAsSequenceType(arg.arg);
-      unsigned rank = seqTy.getDimension();
-
-      // We only care about lowest order dimension.
-      for (unsigned i = 0; i < rank; i++) {
+      // Fetch all the dimensions of the array except the last one.
+      // Always fetch the first dimension, however, so set ndims = 1 if
+      // there is only one dimension.
+      unsigned ndims = (arg.rank == 1) ? 1 : arg.rank - 1;
+      for (unsigned i = 0; i < ndims; i++) {
         mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
         arg.dims[i] = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy,
                                                      *arg.arg, dimIdx);
       }
+      // We only care about the lowest-order dimension here.
       mlir::Value elemSize =
           builder.createIntegerConstant(loc, idxTy, arg.size);
       mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
@@ -248,21 +249,27 @@
       if (coop->getOperand(0) == *arg.arg &&
           coop->getOperands().size() >= 2) {
         builder.setInsertionPoint(coop);
-        mlir::Value totalIndex = builder.createIntegerConstant(loc, idxTy, 0);
-        // Operand(1) = array; Operand(2) = index1; Operand(3) = index2
-        for (unsigned i = coop->getOperands().size() - 1; i > 1; i--) {
+        mlir::Value totalIndex;
+        for (unsigned i = arg.rank - 1; i > 0; i--) {
+          // Operand(1) = array; Operand(2) = index1; Operand(3) = index2
           mlir::Value curIndex =
-              builder.createConvert(loc, idxTy, coop->getOperand(i));
-          // First arg is Operand2, so dims[i-2] is 0-based i-1!
+              builder.createConvert(loc, idxTy, coop->getOperand(i + 1));
+          // Note we scale by the extent, not the stride, because
+          // the stride is fixed up later when we index into the overall
+          // array (so there's an implied multiply by the base elementsize).
           mlir::Value scale =
-              builder.createConvert(loc, idxTy, arg.dims[i - 2].getResult(1));
-          totalIndex = builder.create<mlir::arith::AddIOp>(
-              loc, totalIndex,
-              builder.create<mlir::arith::MulIOp>(loc, scale, curIndex));
+              builder.createConvert(loc, idxTy, arg.dims[i - 1].getResult(1));
+          totalIndex = (totalIndex) ? builder.create<mlir::arith::AddIOp>(
+                                          loc, totalIndex, curIndex)
+                                    : curIndex;
+          totalIndex =
+              builder.create<mlir::arith::MulIOp>(loc, scale, totalIndex);
         }
-        totalIndex = builder.create<mlir::arith::AddIOp>(
-            loc, totalIndex,
-            builder.createConvert(loc, idxTy, coop->getOperand(1)));
+        mlir::Value finalIndex =
+            builder.createConvert(loc, idxTy, coop->getOperand(1));
+        totalIndex = (totalIndex) ? builder.create<mlir::arith::AddIOp>(
+                                        loc, totalIndex, finalIndex)
+                                  : finalIndex;
         auto newOp = builder.create<fir::CoordinateOp>(
             loc, builder.getRefType(elementType), caddr,
             totalIndex);
diff --git a/flang/test/Transforms/loop-versioning.fir b/flang/test/Transforms/loop-versioning.fir
--- a/flang/test/Transforms/loop-versioning.fir
+++ b/flang/test/Transforms/loop-versioning.fir
@@ -156,8 +156,7 @@
 // CHECK: %[[CONV:.*]] = fir.convert %[[Y]] : {{.*}}
 // CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[CONV]] : {{.*}}
 // CHECK: fir.do_loop %[[INDEX:.*]] = {{.*}}
-// CHECK: %[[IND_PLUS_1:.*]] = arith.addi %{{.*}}, %[[INDEX]]
-// CHECK: %[[YADDR:.*]] = fir.coordinate_of %[[BOX_ADDR]], %[[IND_PLUS_1]]
+// CHECK: %[[YADDR:.*]] = fir.coordinate_of %[[BOX_ADDR]], %[[INDEX]]
 // CHECK: %[[YINT:.*]] = fir.load %[[YADDR]] : {{.*}}
 // CHECK: %[[YINDEX:.*]] = fir.convert %[[YINT]]
 // CHECK: %[[XADDR:.*]] = fir.array_coor %[[X]] [%{{.*}}] %[[YINDEX]]
@@ -269,7 +268,7 @@
 // CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[CONV]]
 // CHECK: %[[RES:.*]] = fir.do_loop {{.*}} {
 // CHECK: %[[ADDR:.*]] = fir.coordinate_of %[[BOX_ADDR]], %{{.*}}
-// CHECK: %45 = fir.load %[[ADDR]] : !fir.ref
+// CHECK: %{{.*}} = fir.load %[[ADDR]] : !fir.ref
 // CHECK: }
 // CHECK: fir.result %[[RES]] : {{.*}}
 // CHECK: } else {
@@ -364,8 +363,7 @@
 // CHECK: %[[LOOP_RES:.*]]:2 = fir.do_loop {{.*}}
 // Check the 2D -> 1D coordinate conversion, should have a multiply and a final add.
 // Some other operations are checked to synch the different parts.
-// CHECK: arith.muli %[[DIMS]]#1, {{.*}}
-// CHECK: %[[OUTER_IDX:.*]] = arith.addi {{.*}}
+// CHECK: %[[OUTER_IDX:.*]] = arith.muli %[[DIMS]]#1, {{.*}}
 // CHECK: %[[INNER_IDX:.*]] = fir.convert {{.*}}
 // CHECK: %[[C2D:.*]] = arith.addi %[[OUTER_IDX]], %[[INNER_IDX]]
 // CHECK: %[[COORD:.*]] = fir.coordinate_of %[[BOXADDR]], %[[C2D]] : (!fir.ref<!fir.array<?xf64>>, index) -> !fir.ref<f64>
@@ -384,4 +382,179 @@
 // CHECK: fir.store %[[IF_RES]]#1 to %{{.*}}
 // CHECK: return
 
+// -----
+
+// subroutine sum3d(a, nx, ny, nz)
+//   real*8 :: a(:, :, :)
+//   integer :: nx, ny, nz
+//   real*8 :: sum
+//   integer :: i, j, k
+//   sum = 0
+//   do k=1,nz
+//     do j=1,ny
+//       do i=0,nx
+//         sum = sum + a(i, j, k)
+//       end do
+//     end do
+//   end do
+// end subroutine sum3d
+
+  func.func @sum3d(%arg0: !fir.box<!fir.array<?x?x?xf64>> {fir.bindc_name = "a"}, %arg1: !fir.ref<i32> {fir.bindc_name = "nx"}, %arg2: !fir.ref<i32> {fir.bindc_name = "ny"}, %arg3: !fir.ref<i32> {fir.bindc_name = "nz"}) {
+    %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMmoduleFsum3dEi"}
+    %1 = fir.alloca i32 {bindc_name = "j", uniq_name = "_QMmoduleFsum3dEj"}
+    %2 = fir.alloca i32 {bindc_name = "k", uniq_name = "_QMmoduleFsum3dEk"}
+    %3 = fir.alloca f64 {bindc_name = "sum", uniq_name = "_QMmoduleFsum3dEsum"}
+    %cst = arith.constant 0.000000e+00 : f64
+    fir.store %cst to %3 : !fir.ref<f64>
+    %c1_i32 = arith.constant 1 : i32
+    %4 = fir.convert %c1_i32 : (i32) -> index
+    %5 = fir.load %arg3 : !fir.ref<i32>
+    %6 = fir.convert %5 : (i32) -> index
+    %c1 = arith.constant 1 : index
+    %7 = fir.convert %4 : (index) -> i32
+    %8:2 = fir.do_loop %arg4 = %4 to %6 step %c1 iter_args(%arg5 = %7) -> (index, i32) {
+      fir.store %arg5 to %2 : !fir.ref<i32>
+      %c1_i32_0 = arith.constant 1 : i32
+      %9 = fir.convert %c1_i32_0 : (i32) -> index
+      %10 = fir.load %arg2 : !fir.ref<i32>
+      %11 = fir.convert %10 : (i32) -> index
+      %c1_1 = arith.constant 1 : index
+      %12 = fir.convert %9 : (index) -> i32
+      %13:2 = fir.do_loop %arg6 = %9 to %11 step %c1_1 iter_args(%arg7 = %12) -> (index, i32) {
+        fir.store %arg7 to %1 : !fir.ref<i32>
+        %c0_i32 = arith.constant 0 : i32
+        %18 = fir.convert %c0_i32 : (i32) -> index
+        %19 = fir.load %arg1 : !fir.ref<i32>
+        %20 = fir.convert %19 : (i32) -> index
+        %c1_2 = arith.constant 1 : index
+        %21 = fir.convert %18 : (index) -> i32
+        %c0 = arith.constant 0 : index
+        %22:3 = fir.box_dims %arg0, %c0 : (!fir.box<!fir.array<?x?x?xf64>>, index) -> (index, index, index)
+        %c1_3 = arith.constant 1 : index
+        %23:3 = fir.box_dims %arg0, %c1_3 : (!fir.box<!fir.array<?x?x?xf64>>, index) -> (index, index, index)
+        %c2 = arith.constant 2 : index
+        %24:3 = fir.box_dims %arg0, %c2 : (!fir.box<!fir.array<?x?x?xf64>>, index) -> (index, index, index)
+        %c8 = arith.constant 8 : index
+        %25 = arith.cmpi eq, %22#2, %c8 : index
+        %26:2 = fir.if %25 -> (index, i32) {
+          %31 = fir.convert %arg0 : (!fir.box<!fir.array<?x?x?xf64>>) -> !fir.box<!fir.array<?xf64>>
+          %32 = fir.box_addr %31 : (!fir.box<!fir.array<?xf64>>) -> !fir.ref<!fir.array<?xf64>>
+          %33:2 = fir.do_loop %arg8 = %18 to %20 step %c1_2 iter_args(%arg9 = %21) -> (index, i32) {
+            fir.store %arg9 to %0 : !fir.ref<i32>
+            %34 = fir.load %3 : !fir.ref<f64>
+            %35 = fir.load %0 : !fir.ref<i32>
+            %36 = fir.convert %35 : (i32) -> i64
+            %c1_i64 = arith.constant 1 : i64
+            %37 = arith.subi %36, %c1_i64 : i64
+            %38 = fir.load %1 : !fir.ref<i32>
+            %39 = fir.convert %38 : (i32) -> i64
+            %c1_i64_4 = arith.constant 1 : i64
+            %40 = arith.subi %39, %c1_i64_4 : i64
+            %41 = fir.load %2 : !fir.ref<i32>
+            %42 = fir.convert %41 : (i32) -> i64
+            %c1_i64_5 = arith.constant 1 : i64
+            %43 = arith.subi %42, %c1_i64_5 : i64
+            %c0_6 = arith.constant 0 : index
+            %44 = fir.convert %43 : (i64) -> index
+            %45 = arith.addi %c0_6, %44 : index
+            %46 = arith.muli %23#1, %45 : index
+            %47 = fir.convert %40 : (i64) -> index
+            %48 = arith.addi %46, %47 : index
+            %49 = arith.muli %22#1, %48 : index
+            %50 = fir.convert %37 : (i64) -> index
+            %51 = arith.addi %49, %50 : index
+            %52 = fir.coordinate_of %32, %51 : (!fir.ref<!fir.array<?xf64>>, index) -> !fir.ref<f64>
+            %53 = fir.load %52 : !fir.ref<f64>
+            %54 = arith.addf %34, %53 fastmath<contract> : f64
+            fir.store %54 to %3 : !fir.ref<f64>
+            %55 = arith.addi %arg8, %c1_2 : index
+            %56 = fir.convert %c1_2 : (index) -> i32
+            %57 = fir.load %0 : !fir.ref<i32>
+            %58 = arith.addi %57, %56 : i32
+            fir.result %55, %58 : index, i32
+          }
+          fir.result %33#0, %33#1 : index, i32
+        } else {
+          %31:2 = fir.do_loop %arg8 = %18 to %20 step %c1_2 iter_args(%arg9 = %21) -> (index, i32) {
+            fir.store %arg9 to %0 : !fir.ref<i32>
+            %32 = fir.load %3 : !fir.ref<f64>
+            %33 = fir.load %0 : !fir.ref<i32>
+            %34 = fir.convert %33 : (i32) -> i64
+            %c1_i64 = arith.constant 1 : i64
+            %35 = arith.subi %34, %c1_i64 : i64
+            %36 = fir.load %1 : !fir.ref<i32>
+            %37 = fir.convert %36 : (i32) -> i64
+            %c1_i64_4 = arith.constant 1 : i64
+            %38 = arith.subi %37, %c1_i64_4 : i64
+            %39 = fir.load %2 : !fir.ref<i32>
+            %40 = fir.convert %39 : (i32) -> i64
+            %c1_i64_5 = arith.constant 1 : i64
+            %41 = arith.subi %40, %c1_i64_5 : i64
+            %42 = fir.coordinate_of %arg0, %35, %38, %41 : (!fir.box<!fir.array<?x?x?xf64>>, i64, i64, i64) -> !fir.ref<f64>
+            %43 = fir.load %42 : !fir.ref<f64>
+            %44 = arith.addf %32, %43 fastmath<contract> : f64
+            fir.store %44 to %3 : !fir.ref<f64>
+            %45 = arith.addi %arg8, %c1_2 : index
+            %46 = fir.convert %c1_2 : (index) -> i32
+            %47 = fir.load %0 : !fir.ref<i32>
+            %48 = arith.addi %47, %46 : i32
+            fir.result %45, %48 : index, i32
+          }
+          fir.result %31#0, %31#1 : index, i32
+        }
+        fir.store %26#1 to %0 : !fir.ref<i32>
+        %27 = arith.addi %arg6, %c1_1 : index
+        %28 = fir.convert %c1_1 : (index) -> i32
+        %29 = fir.load %1 : !fir.ref<i32>
+        %30 = arith.addi %29, %28 : i32
+        fir.result %27, %30 : index, i32
+      }
+      fir.store %13#1 to %1 : !fir.ref<i32>
+      %14 = arith.addi %arg4, %c1 : index
+      %15 = fir.convert %c1 : (index) -> i32
+      %16 = fir.load %2 : !fir.ref<i32>
+      %17 = arith.addi %16, %15 : i32
+      fir.result %14, %17 : index, i32
+    }
+    fir.store %8#1 to %2 : !fir.ref<i32>
+    return
+  }
+
+// Note this only checks the expected transformation, not the entire generated code:
+// CHECK-LABEL: func.func @sum3d(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?x?x?xf64>> {{.*}})
+// Only the inner loop should be versioned.
+// CHECK: fir.do_loop
+// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
+// CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[ARG0]], %[[ZERO]] : {{.*}}
+// CHECK: %[[ONE:.*]] = arith.constant 1 : index
+// CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[ARG0]], %[[ONE]] : {{.*}}
+// CHECK: %[[SIZE:.*]] = arith.constant 8 : index
+// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[DIMS0]]#2, %[[SIZE]]
+// CHECK: %[[IF_RES:.*]]:2 = fir.if %[[CMP]] -> {{.*}}
+// CHECK: %[[NEWARR:.*]] = fir.convert %[[ARG0]]
+// CHECK: %[[BOXADDR:.*]] = fir.box_addr %[[NEWARR]] : {{.*}} -> !fir.ref<!fir.array<?xf64>>
+// CHECK: %[[LOOP_RES:.*]]:2 = fir.do_loop {{.*}}
+// Check the 3D -> 1D coordinate conversion; it should have two multiplies and a final add.
+// Some other operations are checked to synch the different parts.
+// CHECK: %[[OUTER_IDX:.*]] = arith.muli %[[DIMS1]]#1, {{.*}}
+// CHECK: %[[MIDDLE_IDX:.*]] = arith.muli %[[DIMS0]]#1, {{.*}}
+// CHECK: %[[INNER_IDX:.*]] = fir.convert {{.*}}
+// CHECK: %[[C3D:.*]] = arith.addi %[[MIDDLE_IDX]], %[[INNER_IDX]]
+// CHECK: %[[COORD:.*]] = fir.coordinate_of %[[BOXADDR]], %[[C3D]] : (!fir.ref<!fir.array<?xf64>>, index) -> !fir.ref<f64>
+// CHECK: %{{.*}} = fir.load %[[COORD]] : !fir.ref<f64>
+// CHECK: fir.result %{{.*}}, %{{.*}}
+// CHECK: }
+// CHECK: fir.result %[[LOOP_RES]]#0, %[[LOOP_RES]]#1
+// CHECK: } else {
+// CHECK: %[[LOOP_RES2:.*]]:2 = fir.do_loop {{.*}}
+// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[ARG0]], %{{.*}} : (!fir.box<!fir.array<?x?x?xf64>>, i64, i64, i64) -> !fir.ref<f64>
+// CHECK: %{{.*}} = fir.load %[[COORD2]] : !fir.ref<f64>
+// CHECK: fir.result %{{.*}}, %{{.*}}
+// CHECK: }
+// CHECK: fir.result %[[LOOP_RES2]]#0, %[[LOOP_RES2]]#1
+// CHECK: }
+// CHECK: fir.store %[[IF_RES]]#1 to %{{.*}}
+// CHECK: return
+
 } // End module
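
For clarity, the linearized index that the rewritten totalIndex computation produces for a unit-stride rank-3 array is equivalent to the standalone C++ sketch below. This is an illustration only, not part of the patch; linearIndex, indices, and extents are hypothetical names, with 0-based indices in Fortran column-major dimension order.

#include <cassert>
#include <cstddef>
#include <vector>

// Mirrors the patched loop: walk the indices from the highest dimension
// down to dimension 1, adding each index and then scaling by the extent of
// the next lower dimension; the lowest-order index is added last.  The
// multiply by the element size is implied by the final coordinate op.
static std::size_t linearIndex(const std::vector<std::size_t> &indices,
                               const std::vector<std::size_t> &extents) {
  assert(!indices.empty() && indices.size() == extents.size());
  std::size_t total = 0;
  bool haveTotal = false; // plays the role of the null totalIndex check
  for (std::size_t i = indices.size() - 1; i > 0; i--) {
    total = haveTotal ? total + indices[i] : indices[i];
    total *= extents[i - 1]; // scale by the next lower dimension's extent
    haveTotal = true;
  }
  // Final add of the lowest-order (fastest-varying) index.
  return haveTotal ? total + indices[0] : indices[0];
}

int main() {
  // a(i, j, k) with extents {nx, ny, nz} = {4, 5, 6} and 0-based indices:
  // linear index = (k * ny + j) * nx + i.
  assert(linearIndex({1, 2, 3}, {4, 5, 6}) == (3 * 5 + 2) * 4 + 1);
  // Rank 1 degenerates to the index itself, with no scaling at all.
  assert(linearIndex({7}, {10}) == 7);
  return 0;
}

For @sum3d above this is (k * ny + j) * nx + i, which is exactly the %46/%48/%49/%51 chain feeding the final fir.coordinate_of in the versioned loop.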