diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -546,6 +546,14 @@
       updateRootInPlace(op, [&]() { operand.set(to); });
     }
   }
+
+  /// Find uses of each value in `from` and replace them with the value at the
+  /// same position in `to`. Both ranges must have the same size.
+  void replaceAllUsesWith(ValueRange from, ValueRange to) {
+    assert(from.size() == to.size() && "incorrect number of replacements");
+    for (auto it : llvm::zip(from, to))
+      replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+  }
 
   /// Find uses of `from` and replace them with `to` if the `functor` returns
   /// true. It also marks every modified uses and notifies the rewriter that an
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1445,21 +1445,23 @@
         failed(foldDynamicIndexList(rewriter, mixedStep)))
       return failure();
 
-    SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
-    SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
-    dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
-                               staticLowerBound);
-    op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
-    op.setStaticLowerBound(staticLowerBound);
-
-    dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
-                               staticUpperBound);
-    op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
-    op.setStaticUpperBound(staticUpperBound);
-
-    dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
-    op.getDynamicStepMutable().assign(dynamicStep);
-    op.setStaticStep(staticStep);
+    rewriter.updateRootInPlace(op, [&]() {
+      SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
+      SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
+      dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
+                                 staticLowerBound);
+      op.getDynamicLowerBoundMutable().assign(dynamicLowerBound);
+      op.setStaticLowerBound(staticLowerBound);
+
+      dispatchIndexOpFoldResults(mixedUpperBound, dynamicUpperBound,
+                                 staticUpperBound);
+      op.getDynamicUpperBoundMutable().assign(dynamicUpperBound);
+      op.setStaticUpperBound(staticUpperBound);
+
+      dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
+      op.getDynamicStepMutable().assign(dynamicStep);
+      op.setStaticStep(staticStep);
+    });
     return success();
   }
 };
@@ -3077,7 +3079,8 @@
                 op.getLoc(), term.getCondition().getType(),
                 rewriter.getBoolAttr(true));
 
-          std::get<1>(yieldedAndBlockArgs).replaceAllUsesWith(constantTrue);
+          rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs),
+                                      constantTrue);
           replaced = true;
         }
       }
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -78,7 +78,8 @@
     // Rewrite uses of the for-loop block arguments to the new while-loop
     // "after" arguments
     for (const auto &barg : enumerate(forOp.getBody(0)->getArguments()))
-      barg.value().replaceAllUsesWith(afterBlock->getArgument(barg.index()));
+      rewriter.replaceAllUsesWith(barg.value(),
+                                  afterBlock->getArgument(barg.index()));
 
     // Inline for-loop body operations into 'after' region.
     for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
@@ -88,7 +89,8 @@
     for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
       SmallVector<Value> yieldOperands = yieldOp.getOperands();
       yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
-      yieldOp->setOperands(yieldOperands);
+      rewriter.updateRootInPlace(
+          yieldOp, [&]() { yieldOp->setOperands(yieldOperands); });
     }
 
     // We cannot do a direct replacement of the forOp since the while op returns
@@ -96,7 +98,8 @@
     // carried in the set of iterargs). Instead, rewrite uses of the forOp
     // results.
     for (const auto &arg : llvm::enumerate(forOp.getResults()))
-      arg.value().replaceAllUsesWith(whileOp.getResult(arg.index() + 1));
+      rewriter.replaceAllUsesWith(arg.value(),
+                                  whileOp.getResult(arg.index() + 1));
 
     rewriter.eraseOp(forOp);
     return success();
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -144,7 +144,7 @@
   b.setInsertionPointAfter(forOp);
   partialIteration = cast<ForOp>(b.clone(*forOp.getOperation()));
   partialIteration.getLowerBoundMutable().assign(splitBound);
-  forOp.replaceAllUsesWith(partialIteration->getResults());
+  b.replaceAllUsesWith(forOp.getResults(), partialIteration->getResults());
   partialIteration.getInitArgsMutable().assign(forOp->getResults());
 
   // Set new upper loop bound.
@@ -221,11 +221,13 @@
     if (failed(peelAndCanonicalizeForLoop(rewriter, forOp, partialIteration)))
       return failure();
     // Apply label, so that the same loop is not rewritten a second time.
-    partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
+    rewriter.updateRootInPlace(partialIteration, [&]() {
+      partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
+      partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
+    });
     rewriter.updateRootInPlace(forOp, [&]() {
       forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
     });
-    partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
     return success();
   }
 
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -302,6 +302,7 @@
 // CHECK: [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]]
 // CHECK: return [[V0]], [[C1]] : index, index
 
+// -----
 
 func.func @to_select_with_body(%cond: i1) -> index {
   %c0 = arith.constant 0 : index
@@ -323,6 +324,7 @@
 // CHECK: "test.op"() : () -> ()
 // CHECK: }
 // CHECK: return [[V0]] : index
+
 // -----
 
 func.func @to_select2(%cond: i1) -> (index, index) {
@@ -363,6 +365,10 @@
 // CHECK-NEXT: %[[R:.*]] = call @make_i32() : () -> i32
 // CHECK-NEXT: return %[[R]] : i32
 
+// -----
+
+func.func private @make_i32() -> i32
+
 func.func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) {
   %a = call @make_i32() : () -> (i32)
   %b = call @make_i32() : () -> (i32)
@@ -523,6 +529,8 @@
   return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
 }
 
+// -----
+
 // CHECK-LABEL: @merge_yielding_nested_if_nv1
 // CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
 func.func @merge_yielding_nested_if_nv1(%arg0: i1, %arg1: i1) {
@@ -547,6 +555,8 @@
   return
 }
 
+// -----
+
 // CHECK-LABEL: @merge_yielding_nested_if_nv2
 // CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
 func.func @merge_yielding_nested_if_nv2(%arg0: i1, %arg1: i1) -> i32 {
@@ -571,6 +581,8 @@
   return %r : i32
 }
 
+// -----
+
 // CHECK-LABEL: @merge_fail_yielding_nested_if
 // CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
 func.func @merge_fail_yielding_nested_if(%arg0: i1, %arg1: i1) -> (i32, f32, i32, i8) {
@@ -1128,6 +1140,8 @@
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[res]] : i32
 
+// -----
+
 // CHECK-LABEL: @while_cmp_lhs
 func.func @while_cmp_lhs(%arg0 : i32) {
   %0 = scf.while () : () -> i32 {
@@ -1155,6 +1169,8 @@
 // CHECK-NEXT: scf.yield
 // CHECK-NEXT: }
 
+// -----
+
 // CHECK-LABEL: @while_cmp_rhs
 func.func @while_cmp_rhs(%arg0 : i32) {
   %0 = scf.while () : () -> i32 {
@@ -1213,6 +1229,7 @@
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[res]]#0, %[[res]]#1 : i32, i32
 
+// -----
 
 // CHECK-LABEL: @combineIfs2
 func.func @combineIfs2(%arg0 : i1, %arg2: i64) -> i32 {
@@ -1239,6 +1256,7 @@
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[res]] : i32
 
+// -----
 
 // CHECK-LABEL: @combineIfs3
 func.func @combineIfs3(%arg0 : i1, %arg2: i64) -> i32 {
@@ -1265,6 +1283,8 @@
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[res]] : i32
 
+// -----
+
 // CHECK-LABEL: @combineIfs4
 func.func @combineIfs4(%arg0 : i1, %arg2: i64) {
   scf.if %arg0 {
@@ -1283,6 +1303,8 @@
 // CHECK-NEXT: "test.secondCodeTrue"() : () -> ()
 // CHECK-NEXT: }
 
+// -----
+
 // CHECK-LABEL: @combineIfsUsed
 // CHECK-SAME: %[[arg0:.+]]: i1
 func.func @combineIfsUsed(%arg0 : i1, %arg2: i64) -> (i32, i32) {
@@ -1313,6 +1335,8 @@
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[res]]#0, %[[res]]#1 : i32, i32
 
+// -----
+
 // CHECK-LABEL: @combineIfsNot
 // CHECK-SAME: %[[arg0:.+]]: i1
 func.func @combineIfsNot(%arg0 : i1, %arg2: i64) {
@@ -1335,6 +1359,8 @@
 // CHECK-NEXT: "test.secondCodeTrue"() : () -> ()
 // CHECK-NEXT: }
 
+// -----
+
 // CHECK-LABEL: @combineIfsNot2
 // CHECK-SAME: %[[arg0:.+]]: i1
 func.func @combineIfsNot2(%arg0 : i1, %arg2: i64) {
@@ -1356,6 +1382,7 @@
 // CHECK-NEXT: } else {
 // CHECK-NEXT: "test.firstCodeTrue"() : () -> ()
 // CHECK-NEXT: }
+
 // -----
 
 // CHECK-LABEL: func @propagate_into_execute_region
@@ -1406,7 +1433,6 @@
 // CHECK-NEXT: "test.bar"(%[[VAL]]) : (i64) -> ()
 // CHECK-NEXT: }
 
-
 // -----
 
 // CHECK-LABEL: func @func_execute_region_elim
@@ -1442,7 +1468,6 @@
 // CHECK: "test.bar"(%[[z]])
 // CHECK: return
 
-
 // -----
 
 // CHECK-LABEL: func @func_execute_region_elim_multi_yield