diff --git a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp --- a/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp +++ b/flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp @@ -1309,19 +1309,20 @@ mlir::LogicalResult hlfir::ForallIndexOp::canonicalize(hlfir::ForallIndexOp indexOp, mlir::PatternRewriter &rewriter) { + for (mlir::Operation *user : indexOp->getResult(0).getUsers()) + if (!mlir::isa(user)) + return mlir::failure(); + auto insertPt = rewriter.saveInsertionPoint(); for (mlir::Operation *user : indexOp->getResult(0).getUsers()) - if (auto loadOp = mlir::dyn_cast_or_null(user)) { + if (auto loadOp = mlir::dyn_cast(user)) { rewriter.setInsertionPoint(loadOp); rewriter.replaceOpWithNewOp( user, loadOp.getResult().getType(), indexOp.getIndex()); } rewriter.restoreInsertionPoint(insertPt); - if (indexOp.use_empty()) { - rewriter.eraseOp(indexOp); - return mlir::success(); - } - return mlir::failure(); + rewriter.eraseOp(indexOp); + return mlir::success(); } #include "flang/Optimizer/HLFIR/HLFIROpInterfaces.cpp.inc" diff --git a/flang/test/HLFIR/forall-index.fir b/flang/test/HLFIR/forall-index.fir --- a/flang/test/HLFIR/forall-index.fir +++ b/flang/test/HLFIR/forall-index.fir @@ -21,6 +21,32 @@ %xi = hlfir.designate %x(%ival) : (!fir.ref>, i64) -> !fir.ref hlfir.yield %xi : !fir.ref } + } + return +} +// CHECK-LABEL: func.func @forall_index( +// CHECK: hlfir.forall lb { +// CHECK: } ub { +// CHECK: } (%[[VAL_4:.*]]: i64) { +// CHECK: hlfir.forall_index "i" %[[VAL_4]] : (i64) -> !fir.ref + +// CANONICALIZATION-LABEL: func.func @forall_index( +// CANONICALIZATION: hlfir.forall lb { +// CANONICALIZATION: } ub { +// CANONICALIZATION: } (%[[VAL_4:.*]]: i64) { +// CANONICALIZATION-NOT: hlfir.forall_index +// CANONICALIZATION: hlfir.designate %{{.*}} (%[[VAL_4]]) : (!fir.ref>, i64) -> !fir.ref +// CANONICALIZATION: hlfir.designate %{{.*}} (%[[VAL_4]]) : (!fir.ref>, i64) -> !fir.ref + +func.func @forall_index_do_not_canonicalize(%x: !fir.ref>, %y: !fir.ref>) { + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + hlfir.forall lb { + hlfir.yield %c1 : index + } ub { + hlfir.yield %c10 : index + } (%arg0: i64) { + %i = hlfir.forall_index "i" %arg0 : (i64) -> !fir.ref hlfir.region_assign { %res = fir.call @taking_address(%i) : (!fir.ref) -> f32 hlfir.yield %res : f32 @@ -32,13 +58,14 @@ } return } +// CHECK-LABEL: func.func @forall_index_do_not_canonicalize( // CHECK: hlfir.forall lb { // CHECK: } ub { // CHECK: } (%[[VAL_4:.*]]: i64) { // CHECK: hlfir.forall_index "i" %[[VAL_4]] : (i64) -> !fir.ref +// CANONICALIZATION-LABEL: func.func @forall_index_do_not_canonicalize( // CANONICALIZATION: %[[VAL_5:.*]] = hlfir.forall_index "i" %[[VAL_4:.*]] : (i64) -> !fir.ref -// CANONICALIZATION: hlfir.designate %{{.*}} (%[[VAL_4]]) : (!fir.ref>, i64) -> !fir.ref -// CANONICALIZATION: hlfir.designate %{{.*}} (%[[VAL_4]]) : (!fir.ref>, i64) -> !fir.ref // CANONICALIZATION: fir.call @taking_address(%[[VAL_5]]) : (!fir.ref) -> f32 -// CANONICALIZATION: hlfir.designate %{{.*}} (%[[VAL_4]]) : (!fir.ref>, i64) -> !fir.ref +// CANONICALIZATION: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref +// CANONICALIZATION: hlfir.designate %{{.*}} (%[[VAL_6]]) : (!fir.ref>, i64) -> !fir.ref diff --git a/flang/test/Lower/HLFIR/forall.f90 b/flang/test/Lower/HLFIR/forall.f90 --- a/flang/test/Lower/HLFIR/forall.f90 +++ b/flang/test/Lower/HLFIR/forall.f90 @@ -91,11 +91,13 @@ ! CHECK: hlfir.yield %[[VAL_12]] : i1 ! CHECK: } do { ! CHECK: hlfir.region_assign { -! CHECK: %[[VAL_13:.*]] = hlfir.designate %[[VAL_8]]#0 (%[[VAL_9]]) : (!fir.ref>, i64) -> !fir.ref +! CHECK: %[[I_LOAD:.*]] = fir.load %[[VAL_10]] : !fir.ref +! CHECK: %[[VAL_13:.*]] = hlfir.designate %[[VAL_8]]#0 (%[[I_LOAD]]) : (!fir.ref>, i64) -> !fir.ref ! CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref ! CHECK: hlfir.yield %[[VAL_14]] : i32 ! CHECK: } to { -! CHECK: %[[VAL_15:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[VAL_9]], %[[VAL_9]]) : (!fir.ref>, i64, i64) -> !fir.ref +! CHECK: %[[I_LOAD:.*]] = fir.load %[[VAL_10]] : !fir.ref +! CHECK: %[[VAL_15:.*]] = hlfir.designate %[[VAL_5]]#0 (%[[I_LOAD]], %[[I_LOAD]]) : (!fir.ref>, i64, i64) -> !fir.ref ! CHECK: hlfir.yield %[[VAL_15]] : !fir.ref ! CHECK: } ! CHECK: }