diff --git a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp --- a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp +++ b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp @@ -94,6 +94,16 @@ return argTy.dyn_cast(); } +/// If a value is produced by a fir.declare, return the declared memref; +/// otherwise return the value unchanged. +static mlir::Value unwrapFirDeclare(mlir::Value val) { + // fir.declare is only emitted for source code variables and is never + // applied to another fir.declare, so one level of unwrapping suffices. + if (fir::DeclareOp declare = val.getDefiningOp<fir::DeclareOp>()) + return declare.getMemref(); + return val; +} + void LoopVersioningPass::runOnOperation() { LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n"); mlir::func::FuncOp func = getOperation(); @@ -154,9 +164,9 @@ // to it later. if (op->getParentOfType() != loop) return; - const mlir::Value &operand = op->getOperand(0); + mlir::Value operand = op->getOperand(0); for (auto a : argsOfInterest) { - if (*a.arg == operand) { + if (*a.arg == unwrapFirDeclare(operand)) { // Only add if it's not already in the list. if (std::find_if(argsInLoop.begin(), argsInLoop.end(), [&](auto it) { return it.arg == a.arg; @@ -244,7 +254,7 @@ // arr(x, y, z) bedcomes arr(z * stride(2) + y * stride(1) + x) // where stride is the distance between elements in the dimensions // 0, 1 and 2 or x, y and z. 
- if (coop->getOperand(0) == *arg.arg && + if (unwrapFirDeclare(coop->getOperand(0)) == *arg.arg && coop->getOperands().size() >= 2) { builder.setInsertionPoint(coop); mlir::Value totalIndex; diff --git a/flang/test/Transforms/loop-versioning.fir b/flang/test/Transforms/loop-versioning.fir --- a/flang/test/Transforms/loop-versioning.fir +++ b/flang/test/Transforms/loop-versioning.fir @@ -13,6 +13,7 @@ // end subroutine sum1d module { func.func @sum1d(%arg0: !fir.box> {fir.bindc_name = "a"}, %arg1: !fir.ref {fir.bindc_name = "n"}) { + %decl = fir.declare %arg0 {uniq_name = "a"} : (!fir.box>) -> !fir.box> %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMmoduleFsum1dEi"} %1 = fir.alloca f64 {bindc_name = "sum", uniq_name = "_QMmoduleFsum1dEsum"} %cst = arith.constant 0.000000e+00 : f64 @@ -30,7 +31,7 @@ %9 = fir.convert %8 : (i32) -> i64 %c1_i64 = arith.constant 1 : i64 %10 = arith.subi %9, %c1_i64 : i64 - %11 = fir.coordinate_of %arg0, %10 : (!fir.box>, i64) -> !fir.ref + %11 = fir.coordinate_of %decl, %10 : (!fir.box>, i64) -> !fir.ref %12 = fir.load %11 : !fir.ref %13 = arith.addf %7, %12 fastmath : f64 fir.store %13 to %1 : !fir.ref @@ -47,6 +48,7 @@ // Note this only checks the expected transformation, not the entire generated code: // CHECK-LABEL: func.func @sum1d( // CHECK-SAME: %[[ARG0:.*]]: !fir.box> {{.*}}) +// CHECK: %[[DECL:.*]] = fir.declare %[[ARG0]] {uniq_name = "a"} : (!fir.box>) -> !fir.box> // CHECK: %[[ZERO:.*]] = arith.constant 0 : index // CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[ARG0]], %[[ZERO]] : {{.*}} // CHECK: %[[SIZE:.*]] = arith.constant 8 : index @@ -62,7 +64,7 @@ // CHECK fir.result %[[LOOP_RES]]#0, %[[LOOP_RES]]#1 // CHECK: } else { // CHECK: %[[LOOP_RES2:.*]]:2 = fir.do_loop {{.*}} -// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[ARG0]], %{{.*}} : (!fir.box>, i64) -> !fir.ref +// CHECK: %[[COORD2:.*]] = fir.coordinate_of %[[DECL]], %{{.*}} : (!fir.box>, i64) -> !fir.ref // CHECK: %{{.*}}= fir.load %[[COORD2]] : !fir.ref // CHECK: 
fir.result %{{.*}}, %{{.*}} // CHECK: }