diff --git a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp --- a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp +++ b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp @@ -104,6 +104,21 @@ return val; } +/// If a value comes from a fir.rebox, follow the rebox to the original source +/// of the value; otherwise return the value unchanged. +static mlir::Value unwrapReboxOp(mlir::Value val) { + // Only one level is unwrapped: reboxes of reboxes are not supported. + if (fir::ReboxOp rebox = val.getDefiningOp()) + val = rebox.getBox(); + return val; +} + +/// Normalise a value (removing fir.declare and fir.rebox) so that we can +/// more conveniently spot values which came from function arguments. +static mlir::Value normaliseVal(mlir::Value val) { + return unwrapFirDeclare(unwrapReboxOp(val)); +} + void LoopVersioningPass::runOnOperation() { LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n"); mlir::func::FuncOp func = getOperation(); @@ -112,7 +127,7 @@ /// A structure to hold an argument, the size of the argument and dimension /// information. struct ArgInfo { - mlir::Value *arg; + mlir::Value arg; size_t size; unsigned rank; fir::BoxDimsOp dims[CFI_MAX_RANK]; @@ -138,7 +153,7 @@ else if (auto cty = elementType.dyn_cast()) typeSize = 2 * cty.getEleType(kindMap).getIntOrFloatBitWidth() / 8; if (typeSize) - argsOfInterest.push_back({&arg, typeSize, rank, {}}); + argsOfInterest.push_back({arg, typeSize, rank, {}}); else LLVM_DEBUG(llvm::dbgs() << "Type not supported\n"); } @@ -166,7 +181,9 @@ return; mlir::Value operand = op->getOperand(0); for (auto a : argsOfInterest) { - if (*a.arg == unwrapFirDeclare(operand)) { + if (a.arg == normaliseVal(operand)) { + // use the reboxed value, not the block arg when re-creating the loop: + a.arg = operand; // Only add if it's not already in the list. 
if (std::find_if(argsInLoop.begin(), argsInLoop.end(), [&](auto it) { return it.arg == a.arg; @@ -211,7 +228,7 @@ for (unsigned i = 0; i < ndims; i++) { mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i); arg.dims[i] = builder.create(loc, idxTy, idxTy, idxTy, - *arg.arg, dimIdx); + arg.arg, dimIdx); } // We only care about lowest order dimension, here. mlir::Value elemSize = @@ -238,11 +255,11 @@ for (auto &arg : op.argsAndDims) { fir::SequenceType::Shape newShape; newShape.push_back(fir::SequenceType::getUnknownExtent()); - auto elementType = fir::unwrapSeqOrBoxedSeqType(arg.arg->getType()); + auto elementType = fir::unwrapSeqOrBoxedSeqType(arg.arg.getType()); mlir::Type arrTy = fir::SequenceType::get(newShape, elementType); mlir::Type boxArrTy = fir::BoxType::get(arrTy); mlir::Type refArrTy = builder.getRefType(arrTy); - auto carg = builder.create(loc, boxArrTy, *arg.arg); + auto carg = builder.create(loc, boxArrTy, arg.arg); auto caddr = builder.create(loc, refArrTy, carg); auto insPt = builder.saveInsertionPoint(); // Use caddr instead of arg. @@ -254,8 +271,7 @@ // arr(x, y, z) bedcomes arr(z * stride(2) + y * stride(1) + x) // where stride is the distance between elements in the dimensions // 0, 1 and 2 or x, y and z. 
- if (unwrapFirDeclare(coop->getOperand(0)) == *arg.arg && - coop->getOperands().size() >= 2) { + if (coop->getOperand(0) == arg.arg && coop->getOperands().size() >= 2) { builder.setInsertionPoint(coop); mlir::Value totalIndex; for (unsigned i = arg.rank - 1; i > 0; i--) { diff --git a/flang/test/Transforms/loop-versioning.fir b/flang/test/Transforms/loop-versioning.fir --- a/flang/test/Transforms/loop-versioning.fir +++ b/flang/test/Transforms/loop-versioning.fir @@ -14,6 +14,7 @@ module { func.func @sum1d(%arg0: !fir.box> {fir.bindc_name = "a"}, %arg1: !fir.ref {fir.bindc_name = "n"}) { %decl = fir.declare %arg0 {uniq_name = "a"} : (!fir.box>) -> !fir.box> + %rebox = fir.rebox %decl : (!fir.box>) -> !fir.box> %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMmoduleFsum1dEi"} %1 = fir.alloca f64 {bindc_name = "sum", uniq_name = "_QMmoduleFsum1dEsum"} %cst = arith.constant 0.000000e+00 : f64 @@ -31,7 +32,7 @@ %9 = fir.convert %8 : (i32) -> i64 %c1_i64 = arith.constant 1 : i64 %10 = arith.subi %9, %c1_i64 : i64 - %11 = fir.coordinate_of %decl, %10 : (!fir.box>, i64) -> !fir.ref + %11 = fir.coordinate_of %rebox, %10 : (!fir.box>, i64) -> !fir.ref %12 = fir.load %11 : !fir.ref %13 = arith.addf %7, %12 fastmath : f64 fir.store %13 to %1 : !fir.ref @@ -49,12 +50,13 @@ // CHECK-LABEL: func.func @sum1d( // CHECK-SAME: %[[ARG0:.*]]: !fir.box> {{.*}}) // CHECK: %[[DECL:.*]] = fir.declare %arg0 {uniq_name = "a"} : (!fir.box>) -> !fir.box> +// CHECK: %[[REBOX:.*]] = fir.rebox %[[DECL]] // CHECK: %[[ZERO:.*]] = arith.constant 0 : index -// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %[[ZERO]] : {{.*}} +// CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[REBOX]], %[[ZERO]] : {{.*}} // CHECK: %[[SIZE:.*]] = arith.constant 8 : index // CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[DIMS]]#2, %[[SIZE]] // CHECK: %[[IF_RES:.*]]:2 = fir.if %[[CMP]] -> {{.*}} -// CHECK: %[[NEWARR:.*]] = fir.convert %[[ARG0]] +// CHECK: %[[NEWARR:.*]] = fir.convert %[[REBOX]] // CHECK: %[[BOXADDR:.*]] = 
fir.box_addr %[[NEWARR]] : {{.*}} -> !fir.ref> // CHECK: %[[LOOP_RES:.*]]:2 = fir.do_loop {{.*}} // CHECK: %[[COORD:.*]] = fir.coordinate_of %[[BOXADDR]], %{{.*}} : (!fir.ref>, index) -> !fir.ref @@ -64,7 +66,7 @@ // CHECK fir.result %[[LOOP_RES]]#0, %[[LOOP_RES]]#1 // CHECK: } else { // CHECK: %[[LOOP_RES2:.*]]:2 = fir.do_loop {{.*}} -// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[DECL]], %{{.*}} : (!fir.box>, i64) -> !fir.ref +// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[REBOX]], %{{.*}} : (!fir.box>, i64) -> !fir.ref // CHECK: %{{.*}}= fir.load %[[COORD2]] : !fir.ref // CHECK: fir.result %{{.*}}, %{{.*}} // CHECK: }