diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -519,10 +519,9 @@ /// This class represents one requested operation replacement via 'replaceOp'. struct OpReplacement { OpReplacement() = default; - OpReplacement(Operation *op, ValueRange newValues) - : op(op), newValues(newValues.begin(), newValues.end()) {} + OpReplacement(ValueRange newValues) + : newValues(newValues.begin(), newValues.end()) {} - Operation *op; SmallVector newValues; }; @@ -681,8 +680,8 @@ /// Ordered vector of all of the newly created operations during conversion. std::vector createdOps; - /// Ordered vector of any requested operation replacements. - SmallVector replacements; + /// Ordered map of requested operation replacements. + llvm::MapVector replacements; /// Ordered vector of any requested block argument replacements. SmallVector argReplacements; @@ -690,18 +689,29 @@ /// Ordered list of block operations (creations, splits, motions). SmallVector blockActions; - /// A set of operations that have been erased/replaced/etc that should no - /// longer be considered for legalization. This is not meant to be an - /// exhaustive list of all operations, but the minimal set that can be used to - /// detect if a given operation should be `ignored`. For example, we may add - /// the operations that define non-empty regions to the set, but not any of - /// the others. This simplifies the amount of memory needed as we can query if - /// the parent operation was ignored. + /// A set of operations that should no longer be considered for legalization, + /// but were not directly replaced/erased/etc. by a pattern. These are + /// generally child operations of other operations that were + /// replaced/erased/etc. 
This is not meant to be an exhaustive list of all + /// operations, but the minimal set that can be used to detect if a given + /// operation should be `ignored`. For example, we may add the operations that + /// define non-empty regions to the set, but not any of the others. This + /// simplifies the amount of memory needed as we can query if the parent + /// operation was ignored. llvm::SetVector ignoredOps; /// A transaction state for each of operations that were updated in-place. SmallVector rootUpdates; + /// A vector of indices to operations that were replaced with values with + /// different result types than the original operation, e.g. 1->N conversion + /// of some kind. + SmallVector operationsWithChangedResults; + + /// A default type converter, used when block conversions do not have one + /// explicitly provided. + TypeConverter defaultTypeConverter; + #ifndef NDEBUG /// A set of operations that have pending updates. This tracking isn't /// strictly necessary, and is thus only active during debug builds for extra @@ -711,10 +721,6 @@ /// A logger used to emit diagnostics during the conversion process. llvm::ScopedPrinter logger{llvm::dbgs()}; #endif - - /// A default type converter, used when block conversions do not have one - /// explicitly provided. - TypeConverter defaultTypeConverter; }; } // end namespace detail } // end namespace mlir @@ -750,16 +756,16 @@ void ConversionPatternRewriterImpl::applyRewrites() { // Apply all of the rewrites replacements requested during conversion. for (auto &repl : replacements) { - for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) { - if (auto newValue = repl.newValues[i]) - repl.op->getResult(i).replaceAllUsesWith( + for (unsigned i = 0, e = repl.second.newValues.size(); i != e; ++i) { + if (auto newValue = repl.second.newValues[i]) + repl.first->getResult(i).replaceAllUsesWith( mapping.lookupOrDefault(newValue)); } // If this operation defines any regions, drop any pending argument // rewrites. 
- if (repl.op->getNumRegions()) - argConverter.notifyOpRemoved(repl.op); + if (repl.first->getNumRegions()) + argConverter.notifyOpRemoved(repl.first); } // Apply all of the requested argument replacements. @@ -785,7 +791,7 @@ // allows processing nested operations before their parent region is // destroyed. for (auto &repl : llvm::reverse(replacements)) - repl.op->erase(); + repl.first->erase(); argConverter.applyRewrites(mapping); @@ -819,9 +825,10 @@ // Reset any replaced operations and undo any saved mappings. for (auto &repl : llvm::drop_begin(replacements, state.numReplacements)) - for (auto result : repl.op->getResults()) + for (auto result : repl.first->getResults()) mapping.erase(result); - replacements.resize(state.numReplacements); + while (replacements.size() != state.numReplacements) + replacements.pop_back(); // Pop all of the newly created operations. while (createdOps.size() != state.numCreatedOps) { @@ -832,6 +839,11 @@ // Pop all of the recorded ignored operations that are no longer valid. while (ignoredOps.size() != state.numIgnoredOperations) ignoredOps.pop_back(); + + // Reset operations with changed results. + while (!operationsWithChangedResults.empty() && + operationsWithChangedResults.back() >= state.numReplacements) + operationsWithChangedResults.pop_back(); } void ConversionPatternRewriterImpl::eraseDanglingBlocks() { @@ -898,8 +910,8 @@ } bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { - // Check to see if this operation or its parent were ignored. - return ignoredOps.count(op) || ignoredOps.count(op->getParentOp()); + // Check to see if this operation was replaced or its parent ignored. 
+ return replacements.count(op) || ignoredOps.count(op->getParentOp()); } void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) { @@ -963,14 +975,25 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, ValueRange newValues) { assert(newValues.size() == op->getNumResults()); + assert(!replacements.count(op) && "operation was already replaced"); + + // Track if any of the results changed, e.g. erased and replaced with null. + bool resultChanged = false; // Create mappings for each of the new result values. - for (unsigned i = 0, e = newValues.size(); i < e; ++i) - if (auto repl = newValues[i]) - mapping.map(op->getResult(i), repl); + Value newValue, result; + for (auto it : llvm::zip(newValues, op->getResults())) { + std::tie(newValue, result) = it; + if (!newValue) + resultChanged = true; + else + mapping.map(result, newValue); + } + if (resultChanged) + operationsWithChangedResults.push_back(replacements.size()); // Record the requested operation replacement. - replacements.emplace_back(op, newValues); + replacements.insert(std::make_pair(op, OpReplacement(newValues))); // Mark this operation as recursively ignored so that we don't need to // convert any nested operations. @@ -1511,20 +1534,12 @@ assert(impl.pendingRootUpdates.empty() && "dangling root updates"); #endif - // Check all of the replacements to ensure that the pattern actually replaced - // the root operation. We also mark any other replaced ops as 'dead' so that - // we don't try to legalize them later. - bool replacedRoot = false; - for (unsigned i = curState.numReplacements, e = impl.replacements.size(); - i != e; ++i) { - Operation *replacedOp = impl.replacements[i].op; - if (replacedOp == op) - replacedRoot = true; - else - impl.ignoredOps.insert(replacedOp); - } - - // Check that the root was either updated or replace. + // Check that the root was either replaced or updated in place. 
+ auto replacedRoot = [&] { + return llvm::any_of( + llvm::drop_begin(impl.replacements, curState.numReplacements), + [op](auto &it) { return it.first == op; }); + }; auto updatedRootInPlace = [&] { return llvm::any_of( llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates), @@ -1532,7 +1547,7 @@ }; (void)replacedRoot; (void)updatedRootInPlace; - assert((replacedRoot || updatedRootInPlace()) && + assert((replacedRoot() || updatedRootInPlace()) && "expected pattern to replace the root operation"); // Legalize each of the actions registered during application. @@ -1856,6 +1871,10 @@ /// Converts an operation with the given rewriter. LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op); + /// This method is called after the conversion process to legalize any + /// remaining legalization artifacts and complete the conversion. + LogicalResult finalize(ConversionPatternRewriter &rewriter); + /// The legalizer to use when converting operations. OperationLegalizer opLegalizer; @@ -1916,16 +1935,56 @@ // Convert each operation and discard rewrites on failure. ConversionPatternRewriter rewriter(ops.front()->getContext()); + ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); for (auto *op : toConvert) if (failed(convert(rewriter, op))) - return rewriter.getImpl().discardRewrites(), failure(); + return rewriterImpl.discardRewrites(), failure(); - // Otherwise, the body conversion succeeded. Apply rewrites if this is not an - // analysis conversion. + // Now that all of the operations have been converted, finalize the conversion + // process to ensure any lingering conversion artifacts are cleaned up and + // legalized. + if (failed(finalize(rewriter))) + return rewriterImpl.discardRewrites(), failure(); + + // After a successful conversion, apply rewrites if this is not an analysis + // conversion. 
if (mode == OpConversionMode::Analysis) - rewriter.getImpl().discardRewrites(); + rewriterImpl.discardRewrites(); else - rewriter.getImpl().applyRewrites(); + rewriterImpl.applyRewrites(); + return success(); +} + +LogicalResult +OperationConverter::finalize(ConversionPatternRewriter &rewriter) { + ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); + auto isOpDead = [&](Operation *op) { return rewriterImpl.isOpIgnored(op); }; + + // Process the operations with changed results. + for (unsigned replIdx : rewriterImpl.operationsWithChangedResults) { + auto &repl = *(rewriterImpl.replacements.begin() + replIdx); + for (auto it : llvm::zip(repl.first->getResults(), repl.second.newValues)) { + Value result = std::get<0>(it), newValue = std::get<1>(it); + + // If the operation result was replaced with null, all of the uses of this + // value should be replaced. + if (!newValue) { + auto liveUserIt = llvm::find_if_not(result.getUsers(), isOpDead); + if (liveUserIt != result.user_end()) { + InFlightDiagnostic diag = repl.first->emitError() + << "failed to legalize operation '" + << repl.first->getName() + << "' marked as erased"; + diag.attachNote(liveUserIt->getLoc()) + << "found live user of result #" + << result.cast().getResultNumber() << ": " + << *liveUserIt; + return failure(); + } + } + } + } + return success(); } diff --git a/mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir b/mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-opt %s -test-legalize-unknown-root-patterns -verify-diagnostics + +// Test that an error is emitted when an operation is marked as "erased", but +// has users that live across the conversion. 
+func @remove_all_ops(%arg0: i32) -> i32 { + // expected-error@below {{failed to legalize operation 'test.illegal_op_a' marked as erased}} + %0 = "test.illegal_op_a"() : () -> i32 + // expected-note@below {{found live user of result #0: return %0 : i32}} + return %0 : i32 +}