diff --git a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp --- a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp +++ b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp @@ -105,6 +105,7 @@ struct ArgInfo { mlir::Value *arg; size_t size; + unsigned rank; fir::BoxDimsOp dims[CFI_MAX_RANK]; }; @@ -120,7 +121,7 @@ unsigned rank = seqTy.getDimension(); // Currently limited to 1D or 2D arrays as that seems to give good // improvement without excessive increase in code-size, etc. - if (rank > 0 && rank < 3 && + if (rank > 0 && rank < 4 && seqTy.getShape()[0] == fir::SequenceType::getUnknownExtent()) { size_t typeSize = 0; mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType(arg.getType()); @@ -130,7 +131,7 @@ else if (auto cty = elementType.dyn_cast()) typeSize = 2 * cty.getEleType(kindMap).getIntOrFloatBitWidth() / 8; if (typeSize) - argsOfInterest.push_back({&arg, typeSize, {}}); + argsOfInterest.push_back({&arg, typeSize, rank, {}}); else LLVM_DEBUG(llvm::dbgs() << "Type not supported\n"); @@ -199,16 +200,13 @@ mlir::Value allCompares = nullptr; // Ensure all of the arrays are unit-stride. for (auto &arg : op.argsAndDims) { - - fir::SequenceType seqTy = getAsSequenceType(arg.arg); - unsigned rank = seqTy.getDimension(); - - // We only care about lowest order dimension. - for (unsigned i = 0; i < rank; i++) { + // Fetch all the dimensions of the array. + for (unsigned i = 0; i < arg.rank; i++) { mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i); arg.dims[i] = builder.create(loc, idxTy, idxTy, idxTy, *arg.arg, dimIdx); } + // We only care about lowest order dimension, here. 
mlir::Value elemSize = builder.createIntegerConstant(loc, idxTy, arg.size); mlir::Value cmp = builder.create( @@ -249,16 +247,19 @@ coop->getOperands().size() >= 2) { builder.setInsertionPoint(coop); mlir::Value totalIndex = builder.createIntegerConstant(loc, idxTy, 0); - // Operand(1) = array; Operand(2) = index1; Operand(3) = index2 - for (unsigned i = coop->getOperands().size() - 1; i > 1; i--) { + for (unsigned i = arg.rank - 1; i > 0; i--) { + // Operand(1) = array; Operand(2) = index1; Operand(3) = index2 mlir::Value curIndex = - builder.createConvert(loc, idxTy, coop->getOperand(i)); - // First arg is Operand2, so dims[i-2] is 0-based i-1! + builder.createConvert(loc, idxTy, coop->getOperand(i + 1)); + // Note we scale by the extent, not the stride, because + // the stride is fixed up later when we index into the overall + // array (so there's an implied multiply by the base elementsize) mlir::Value scale = - builder.createConvert(loc, idxTy, arg.dims[i - 2].getResult(1)); - totalIndex = builder.create( - loc, totalIndex, - builder.create(loc, scale, curIndex)); + builder.createConvert(loc, idxTy, arg.dims[i - 1].getResult(1)); + totalIndex = + builder.create(loc, totalIndex, curIndex); + totalIndex = + builder.create(loc, scale, totalIndex); } totalIndex = builder.create( loc, totalIndex, diff --git a/flang/test/Transforms/loop-versioning.fir b/flang/test/Transforms/loop-versioning.fir --- a/flang/test/Transforms/loop-versioning.fir +++ b/flang/test/Transforms/loop-versioning.fir @@ -364,8 +364,7 @@ // CHECK: %[[LOOP_RES:.*]]:2 = fir.do_loop {{.*}} // Check the 2D -> 1D coordinate conversion, should have a multiply and a final add. // Some other operations are checked to synch the different parts. 
-// CHECK: arith.muli %[[DIMS]]#1, {{.*}} -// CHECK: %[[OUTER_IDX:.*]] = arith.addi {{.*}} +// CHECK: %[[OUTER_IDX:.*]] = arith.muli %[[DIMS]]#1, {{.*}} // CHECK: %[[INNER_IDX:.*]] = fir.convert {{.*}} // CHECK: %[[C2D:.*]] = arith.addi %[[OUTER_IDX]], %[[INNER_IDX]] // CHECK: %[[COORD:.*]] = fir.coordinate_of %[[BOXADDR]], %[[C2D]] : (!fir.ref>, index) -> !fir.ref @@ -384,4 +383,181 @@ // CHECK: fir.store %[[IF_RES]]#1 to %{{.*}} // CHECK: return +// ----- + +// subroutine sum3d(a, nx, ny, nz) +// real*8 :: a(:, :, :) +// integer :: nx, ny, nz +// real*8 :: sum +// integer :: i, j, k +// sum = 0 +// do k=1,nz +// do j=1,ny +// do i=0,nx +// sum = sum + a(i, j, k) +// end do +// end do +// end do +// end subroutine sum3d + + func.func @sum3d(%arg0: !fir.box> {fir.bindc_name = "a"}, %arg1: !fir.ref {fir.bindc_name = "nx"}, %arg2: !fir.ref {fir.bindc_name = "ny"}, %arg3: !fir.ref {fir.bindc_name = "nz"}) { + %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMmoduleFsum3dEi"} + %1 = fir.alloca i32 {bindc_name = "j", uniq_name = "_QMmoduleFsum3dEj"} + %2 = fir.alloca i32 {bindc_name = "k", uniq_name = "_QMmoduleFsum3dEk"} + %3 = fir.alloca f64 {bindc_name = "sum", uniq_name = "_QMmoduleFsum3dEsum"} + %cst = arith.constant 0.000000e+00 : f64 + fir.store %cst to %3 : !fir.ref + %c1_i32 = arith.constant 1 : i32 + %4 = fir.convert %c1_i32 : (i32) -> index + %5 = fir.load %arg3 : !fir.ref + %6 = fir.convert %5 : (i32) -> index + %c1 = arith.constant 1 : index + %7 = fir.convert %4 : (index) -> i32 + %8:2 = fir.do_loop %arg4 = %4 to %6 step %c1 iter_args(%arg5 = %7) -> (index, i32) { + fir.store %arg5 to %2 : !fir.ref + %c1_i32_0 = arith.constant 1 : i32 + %9 = fir.convert %c1_i32_0 : (i32) -> index + %10 = fir.load %arg2 : !fir.ref + %11 = fir.convert %10 : (i32) -> index + %c1_1 = arith.constant 1 : index + %12 = fir.convert %9 : (index) -> i32 + %13:2 = fir.do_loop %arg6 = %9 to %11 step %c1_1 iter_args(%arg7 = %12) -> (index, i32) { + fir.store %arg7 to %1 : !fir.ref 
+ %c0_i32 = arith.constant 0 : i32 + %18 = fir.convert %c0_i32 : (i32) -> index + %19 = fir.load %arg1 : !fir.ref + %20 = fir.convert %19 : (i32) -> index + %c1_2 = arith.constant 1 : index + %21 = fir.convert %18 : (index) -> i32 + %c0 = arith.constant 0 : index + %22:3 = fir.box_dims %arg0, %c0 : (!fir.box>, index) -> (index, index, index) + %c1_3 = arith.constant 1 : index + %23:3 = fir.box_dims %arg0, %c1_3 : (!fir.box>, index) -> (index, index, index) + %c2 = arith.constant 2 : index + %24:3 = fir.box_dims %arg0, %c2 : (!fir.box>, index) -> (index, index, index) + %c8 = arith.constant 8 : index + %25 = arith.cmpi eq, %22#2, %c8 : index + %26:2 = fir.if %25 -> (index, i32) { + %31 = fir.convert %arg0 : (!fir.box>) -> !fir.box> + %32 = fir.box_addr %31 : (!fir.box>) -> !fir.ref> + %33:2 = fir.do_loop %arg8 = %18 to %20 step %c1_2 iter_args(%arg9 = %21) -> (index, i32) { + fir.store %arg9 to %0 : !fir.ref + %34 = fir.load %3 : !fir.ref + %35 = fir.load %0 : !fir.ref + %36 = fir.convert %35 : (i32) -> i64 + %c1_i64 = arith.constant 1 : i64 + %37 = arith.subi %36, %c1_i64 : i64 + %38 = fir.load %1 : !fir.ref + %39 = fir.convert %38 : (i32) -> i64 + %c1_i64_4 = arith.constant 1 : i64 + %40 = arith.subi %39, %c1_i64_4 : i64 + %41 = fir.load %2 : !fir.ref + %42 = fir.convert %41 : (i32) -> i64 + %c1_i64_5 = arith.constant 1 : i64 + %43 = arith.subi %42, %c1_i64_5 : i64 + %c0_6 = arith.constant 0 : index + %44 = fir.convert %43 : (i64) -> index + %45 = arith.addi %c0_6, %44 : index + %46 = arith.muli %23#1, %45 : index + %47 = fir.convert %40 : (i64) -> index + %48 = arith.addi %46, %47 : index + %49 = arith.muli %22#1, %48 : index + %50 = fir.convert %37 : (i64) -> index + %51 = arith.addi %49, %50 : index + %52 = fir.coordinate_of %32, %51 : (!fir.ref>, index) -> !fir.ref + %53 = fir.load %52 : !fir.ref + %54 = arith.addf %34, %53 fastmath : f64 + fir.store %54 to %3 : !fir.ref + %55 = arith.addi %arg8, %c1_2 : index + %56 = fir.convert %c1_2 : (index) -> i32 + %57 = 
fir.load %0 : !fir.ref + %58 = arith.addi %57, %56 : i32 + fir.result %55, %58 : index, i32 + } + fir.result %33#0, %33#1 : index, i32 + } else { + %31:2 = fir.do_loop %arg8 = %18 to %20 step %c1_2 iter_args(%arg9 = %21) -> (index, i32) { + fir.store %arg9 to %0 : !fir.ref + %32 = fir.load %3 : !fir.ref + %33 = fir.load %0 : !fir.ref + %34 = fir.convert %33 : (i32) -> i64 + %c1_i64 = arith.constant 1 : i64 + %35 = arith.subi %34, %c1_i64 : i64 + %36 = fir.load %1 : !fir.ref + %37 = fir.convert %36 : (i32) -> i64 + %c1_i64_4 = arith.constant 1 : i64 + %38 = arith.subi %37, %c1_i64_4 : i64 + %39 = fir.load %2 : !fir.ref + %40 = fir.convert %39 : (i32) -> i64 + %c1_i64_5 = arith.constant 1 : i64 + %41 = arith.subi %40, %c1_i64_5 : i64 + %42 = fir.coordinate_of %arg0, %35, %38, %41 : (!fir.box>, i64, i64, i64) -> !fir.ref + %43 = fir.load %42 : !fir.ref + %44 = arith.addf %32, %43 fastmath : f64 + fir.store %44 to %3 : !fir.ref + %45 = arith.addi %arg8, %c1_2 : index + %46 = fir.convert %c1_2 : (index) -> i32 + %47 = fir.load %0 : !fir.ref + %48 = arith.addi %47, %46 : i32 + fir.result %45, %48 : index, i32 + } + fir.result %31#0, %31#1 : index, i32 + } + fir.store %26#1 to %0 : !fir.ref + %27 = arith.addi %arg6, %c1_1 : index + %28 = fir.convert %c1_1 : (index) -> i32 + %29 = fir.load %1 : !fir.ref + %30 = arith.addi %29, %28 : i32 + fir.result %27, %30 : index, i32 + } + fir.store %13#1 to %1 : !fir.ref + %14 = arith.addi %arg4, %c1 : index + %15 = fir.convert %c1 : (index) -> i32 + %16 = fir.load %2 : !fir.ref + %17 = arith.addi %16, %15 : i32 + fir.result %14, %17 : index, i32 + } + fir.store %8#1 to %2 : !fir.ref + return + } + +// Note this only checks the expected transformation, not the entire generated code: +// CHECK-LABEL: func.func @sum3d( +// CHECK-SAME: %[[ARG0:.*]]: !fir.box> {{.*}}) +// Only inner loop should be versioned. 
+// CHECK: fir.do_loop +// CHECK: %[[ZERO:.*]] = arith.constant 0 : index +// CHECK: %[[DIMS0:.*]]:3 = fir.box_dims %[[ARG0]], %[[ZERO]] : {{.*}} +// CHECK: %[[ONE:.*]] = arith.constant 1 : index +// CHECK: %[[DIMS1:.*]]:3 = fir.box_dims %[[ARG0]], %[[ONE]] : {{.*}} +// CHECK: %[[TWO:.*]] = arith.constant 2 : index +// CHECK: %[[DIMS2:.*]]:3 = fir.box_dims %[[ARG0]], %[[TWO]] : {{.*}} +// CHECK: %[[SIZE:.*]] = arith.constant 8 : index +// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[DIMS0]]#2, %[[SIZE]] +// CHECK: %[[IF_RES:.*]]:2 = fir.if %[[CMP]] -> {{.*}} +// CHECK: %[[NEWARR:.*]] = fir.convert %[[ARG0]] +// CHECK: %[[BOXADDR:.*]] = fir.box_addr %[[NEWARR]] : {{.*}} -> !fir.ref> +// CHECK: %[[LOOP_RES:.*]]:2 = fir.do_loop {{.*}} +// Check the 3D -> 1D coordinate conversion, should have one multiply per outer dimension (two in total) and a final add. +// Some other operations are checked to synch the different parts. +// CHECK: %[[OUTER_IDX:.*]] = arith.muli %[[DIMS1]]#1, {{.*}} +// CHECK: %[[MIDDLE_IDX:.*]] = arith.muli %[[DIMS0]]#1, {{.*}} +// CHECK: %[[INNER_IDX:.*]] = fir.convert {{.*}} +// CHECK: %[[C3D:.*]] = arith.addi %[[MIDDLE_IDX]], %[[INNER_IDX]] +// CHECK: %[[COORD:.*]] = fir.coordinate_of %[[BOXADDR]], %[[C3D]] : (!fir.ref>, index) -> !fir.ref +// CHECK: %{{.*}} = fir.load %[[COORD]] : !fir.ref +// CHECK: fir.result %{{.*}}, %{{.*}} +// CHECK: } +// CHECK fir.result %[[LOOP_RES]]#0, %[[LOOP_RES]]#1 +// CHECK: } else { +// CHECK: %[[LOOP_RES2:.*]]:2 = fir.do_loop {{.*}} +// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[ARG0]], %{{.*}} : (!fir.box>, i64, i64, i64) -> !fir.ref +// CHECK: %{{.*}}= fir.load %[[COORD2]] : !fir.ref +// CHECK: fir.result %{{.*}}, %{{.*}} +// CHECK: } +// CHECK fir.result %[[LOOP_RES2]]#0, %[[LOOP_RES2]]#1 +// CHECK: } +// CHECK: fir.store %[[IF_RES]]#1 to %{{.*}} +// CHECK: return + } // End module