diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -145,6 +145,11 @@
   replaceAllUsesExcept(Value newValue,
                        const SmallPtrSetImpl<Operation *> &exceptions) const;
 
+  /// Replace all uses of 'this' value with 'newValue' if the given callback
+  /// returns true.
+  void replaceUsesWithIf(Value newValue,
+                         function_ref<bool(OpOperand &)> shouldReplace);
+
   //===--------------------------------------------------------------------===//
   // Uses
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -125,6 +125,15 @@
   }
 }
 
+/// Replace all uses of 'this' value with 'newValue' if the given callback
+/// returns true.
+void Value::replaceUsesWithIf(Value newValue,
+                              function_ref<bool(OpOperand &)> shouldReplace) {
+  for (OpOperand &use : llvm::make_early_inc_range(getUses()))
+    if (shouldReplace(use))
+      use.set(newValue);
+}
+
 //===--------------------------------------------------------------------===//
 // Uses
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -197,8 +197,6 @@
 
   /// Fully replace uses of the old arguments with the new, materializing cast
   /// operations as necessary.
-  // FIXME(riverriddle) The 'mapping' parameter is only necessary because the
-  // implementation of replaceUsesOfBlockArgument is buggy.
   void applyRewrites(ConversionValueMapping &mapping);
 
   //===--------------------------------------------------------------------===//
@@ -436,9 +434,10 @@
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
   RewriterState(unsigned numCreatedOps, unsigned numReplacements,
-                unsigned numBlockActions, unsigned numIgnoredOperations,
-                unsigned numRootUpdates)
+                unsigned numArgReplacements, unsigned numBlockActions,
+                unsigned numIgnoredOperations, unsigned numRootUpdates)
       : numCreatedOps(numCreatedOps), numReplacements(numReplacements),
+        numArgReplacements(numArgReplacements),
         numBlockActions(numBlockActions),
         numIgnoredOperations(numIgnoredOperations),
         numRootUpdates(numRootUpdates) {}
@@ -449,6 +448,9 @@
   /// The current number of replacements queued.
   unsigned numReplacements;
 
+  /// The current number of argument replacements queued.
+  unsigned numArgReplacements;
+
   /// The current number of block actions performed.
   unsigned numBlockActions;
 
@@ -624,6 +626,9 @@
   /// Ordered vector of any requested operation replacements.
   SmallVector<OpReplacement, 4> replacements;
 
+  /// Ordered vector of any requested block argument replacements.
+  SmallVector<BlockArgument, 4> argReplacements;
+
   /// Ordered list of block operations (creations, splits, motions).
   SmallVector<BlockAction, 4> blockActions;
 
@@ -654,8 +659,8 @@
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
   return RewriterState(createdOps.size(), replacements.size(),
-                       blockActions.size(), ignoredOps.size(),
-                       rootUpdates.size());
+                       argReplacements.size(), blockActions.size(),
+                       ignoredOps.size(), rootUpdates.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -664,6 +669,12 @@
     rootUpdates[i].resetOperation();
   rootUpdates.resize(state.numRootUpdates);
 
+  // Reset any replaced arguments.
+  for (BlockArgument replacedArg :
+       llvm::drop_begin(argReplacements, state.numArgReplacements))
+    mapping.erase(replacedArg);
+  argReplacements.resize(state.numArgReplacements);
+
   // Undo any block actions.
   undoBlockActions(state.numBlockActions);
 
@@ -753,6 +764,25 @@
       argConverter.notifyOpRemoved(repl.op);
   }
 
+  // Apply all of the requested argument replacements.
+  for (BlockArgument arg : argReplacements) {
+    Value repl = mapping.lookupOrDefault(arg);
+    if (repl.isa<BlockArgument>()) {
+      arg.replaceAllUsesWith(repl);
+      continue;
+    }
+
+    // If the replacement value is an operation result, only replace the uses
+    // in other blocks or after the defining operation in its own block; uses
+    // by the defining operation itself, or by ops before it, are left alone.
+    Operation *replOp = repl.cast<OpResult>().getOwner();
+    Block *replBlock = replOp->getBlock();
+    arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
+      Operation *user = operand.getOwner();
+      return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
+    });
+  }
+
   // In a second pass, erase all of the replaced operations in reverse. This
   // allows processing nested operations before their parent region is
   // destroyed.
@@ -907,11 +937,13 @@
 
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
                                                            Value to) {
-  for (auto &u : from.getUses()) {
-    if (u.getOwner() == to.getDefiningOp())
-      continue;
-    u.getOwner()->replaceUsesOfWith(from, to);
-  }
+  LLVM_DEBUG({
+    Operation *parentOp = from.getOwner()->getParentOp();
+    impl->logger.startLine() << "** Replace Argument : '" << from
+                             << "'(in region of '" << parentOp->getName()
+                             << "'(" << parentOp << "))\n";
+  });
+  impl->argReplacements.push_back(from);
   impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
 }
 
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -197,3 +197,17 @@
   }) : () -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: @undo_block_arg_replace
+func @undo_block_arg_replace() {
+  "test.undo_block_arg_replace"() ({
+  ^bb0(%arg0: i32):
+    // CHECK: ^bb0(%[[ARG:.*]]: i32):
+    // CHECK-NEXT: "test.return"(%[[ARG]]) : (i32)
+
+    "test.return"(%arg0) : (i32) -> ()
+  }) : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -238,6 +238,24 @@
   }
 };
 
+/// A simple pattern that tests the undo mechanism when replacing the uses of a
+/// block argument.
+struct TestUndoBlockArgReplace : public ConversionPattern {
+  TestUndoBlockArgReplace(MLIRContext *ctx)
+      : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto illegalOp =
+        rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
+    rewriter.replaceUsesOfBlockArgument(op->getRegion(0).front().getArgument(0),
+                                        illegalOp);
+    rewriter.updateRootInPlace(op, [] {});
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Type-Conversion Rewrite Testing
@@ -449,12 +467,14 @@
     TestTypeConverter converter;
     mlir::OwningRewritePatternList patterns;
     populateWithGenerated(&getContext(), &patterns);
-    patterns.insert<
-        TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
-        TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType,
-        TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
-        TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
-        TestNonRootReplacement, TestBoundedRecursiveRewrite>(&getContext());
+    patterns.insert<
+        TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
+        TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType,
+        TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
+        TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+        TestNonRootReplacement, TestBoundedRecursiveRewrite,
+        TestUndoBlockArgReplace>(
+        &getContext());
     patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                               converter);