diff --git a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
--- a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
+++ b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
@@ -12,108 +12,22 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include <stack>
 
 using namespace mlir;
 
 namespace {
 
-/// A chain of casts represent a list of UnrealizedConversionCast operations.
-/// Apart from the root, each operation takes the result of its predecessor in
-/// and cast it to a different type.
-/// A chain is considered to be valid only if the last cast has at least one
-/// usage within the IR.
-class ChainOfCasts {
-public:
-  ChainOfCasts(UnrealizedConversionCastOp root) {
-    ops.push_back(root.getOperation());
-  }
-
-  /// Append a cast to the chain.
-  void append(UnrealizedConversionCastOp op) {
-    ops.push_back(op.getOperation());
-  }
-
-  /// Append another chain of casts to this chain.
-  void append(const ChainOfCasts &chain) {
-    for (const auto &op : chain.ops)
-      ops.push_back(op);
-  }
-
-  /// Get the first cast of the chain.
-  UnrealizedConversionCastOp front() const {
-    return mlir::cast<UnrealizedConversionCastOp>(ops.front());
-  }
-
-  /// Get the last cast of the chain.
-  UnrealizedConversionCastOp back() const {
-    return mlir::cast<UnrealizedConversionCastOp>(ops.back());
-  }
-
-  /// Get the length of the chain.
-  size_t size() const { return ops.size(); }
-
-  /// Determine whether the chain of unrealized casts consist in a loop with
-  /// respect to the types it receives and returns.
-  bool hasLoop() const {
-    return front().getInputs().getTypes() == back().getResultTypes();
-  }
-
-  /// The method returns 'true' if any of the casts composing to the chain is
-  /// live within the IR, that is if there is exists an use from an operations
-  /// that is not another cast (with the exception of the last cast, which is
-  /// the expected exit point of the chain). Looking from the opposite
-  /// perspective, the method returns 'false' if the whole chain is just a noop.
-  bool isLive() const {
-    for (size_t i = 0; i < ops.size() - 1; ++i)
-      if (llvm::any_of(ops[i]->getUsers(), [&](Operation *user) {
-            return !isa<UnrealizedConversionCastOp>(user);
-          }))
-        return true;
-
-    return false;
-  }
-
-private:
-  std::vector<Operation *> ops;
-};
-
-/// Discover the chains of casts given a root.
-/// During the discovery process, the dead parts of the chain (that is, the ones
-/// ending with an unused cast), are ignored.
-static std::vector<ChainOfCasts> getChains(UnrealizedConversionCastOp root) {
-  std::vector<ChainOfCasts> chains;
-  bool isRootLive = false;
-
-  for (auto *user : root->getUsers()) {
-    if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(user)) {
-      if (castOp->use_empty())
-        continue;
-
-      for (const auto &subChain : getChains(castOp)) {
-        ChainOfCasts chain(root);
-        chain.append(subChain);
-        chains.push_back(std::move(chain));
-      }
-    } else {
-      isRootLive = true;
-    }
-  }
-
-  if (isRootLive && chains.empty())
-    chains.emplace_back(root);
-
-  return chains;
-}
-
-/// Folds the chains of `unrealized_conversion_cast`s that have as exit types
+/// Folds the DAGs of `unrealized_conversion_cast`s that have as exit types
 /// the same as the input ones.
-/// For example, the chains `A -> B -> C -> B -> A` and `A -> B -> C -> A`
+/// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A`
 /// represent a noop within the IR, and thus the initial input values can be
 /// propagated.
-/// The same does not hold for 'open' chains, such as `A -> B -> C`.
-/// In this last case there is no loop among the types and thus the conversion
-/// is incomplete. The same hold for 'closed' chains like `A -> B -> A`, but
-/// with the result of type `B` being used by some non-cast operations.
+/// The same does not hold for 'open' chains of casts, such as `A -> B -> C`.
+/// In this last case there is no cycle among the types and thus the
+/// conversion is incomplete. The same holds for 'closed' chains like
+/// `A -> B -> A`, but with the result of type `B` being used by some non-cast
+/// operations.
 /// Bifurcations (that is when a chain starts in between of another one) are
 /// also taken into considerations, and all the above considerations remain
 /// valid.
@@ -125,56 +39,78 @@
 
   LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
                                 PatternRewriter &rewriter) const override {
-    // In case of a dead cast, just erase it.
-    if (op->use_empty()) {
-      rewriter.eraseOp(op);
-      return success();
-    }
-
-    // If the cast is not the root of a chain, then it is just erased. The
-    // propagation of the root inputs, in fact, is executed when the root of
-    // the chain is encountered and the whole chain processed.
-    auto parent =
-        op.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
-    bool isRoot = parent == nullptr;
-
-    if (!isRoot) {
-      rewriter.eraseOp(op);
-      return success();
-    }
-
-    // Get the chains having the current cast operation as root.
-    std::vector<ChainOfCasts> chains = getChains(op);
-
-    if (llvm::any_of(chains, [](const ChainOfCasts &chain) {
-          if (chain.isLive())
-            return true;
-
-          // Check if the chain ends with the same starting type. If it doesn't,
-          // then the overall chain of casts is live within the IR.
-          // For example, A -> B -> C -> B can't be simplified because A is not
-          // the final result type.
-          return !chain.hasLoop();
-        })) {
-      return rewriter.notifyMatchFailure(op, "live unrealized conversion cast");
-    }
-
-    bool rootReplaced = false;
-
-    for (auto &chain : chains) {
-      if (chain.size() == 1 && rootReplaced) {
-        // If the chain is composed by just one cast, then we must avoid
-        // replacing it multiple times.
-        continue;
-      }
-
-      // Propagate the initial input values.
-      rewriter.replaceOp(chain.back(), chain.front().getInputs());
-      rootReplaced |= chain.size() == 1;
-    }
-
-    if (!rootReplaced)
-      rewriter.eraseOp(op);
+    // Casts of the DAG having at least one user that is not an unrealized
+    // cast. Their results are replaced by the inputs of the root cast.
+    DenseSet<UnrealizedConversionCastOp> sinkNodes;
+
+    // Casts of the DAG that are either unused or only used by other casts of
+    // the DAG. They are erased rather than replaced, because an unused cast
+    // may have a number of results different from the number of root inputs.
+    DenseSet<UnrealizedConversionCastOp> intermediateNodes;
+
+    // Stack used for the depth-first traversal of the use-def DAG.
+    std::stack<UnrealizedConversionCastOp> visitStack;
+    visitStack.push(op);
+
+    while (!visitStack.empty()) {
+      UnrealizedConversionCastOp current = visitStack.top();
+      visitStack.pop();
+
+      // A cast is pushed once per use of the results of its parent cast, so
+      // it may be reached more than once: process it only the first time.
+      if (sinkNodes.contains(current) || intermediateNodes.contains(current))
+        continue;
+
+      bool isLive = false;
+
+      for (Operation *user : current->getUsers()) {
+        // Within graph regions the casts may form a use-def cycle, in which
+        // case there are no external values to be forwarded.
+        if (user == op.getOperation())
+          return rewriter.notifyMatchFailure(
+              op, "cyclic unrealized conversion casts");
+
+        auto other = dyn_cast<UnrealizedConversionCastOp>(user);
+
+        if (!other) {
+          isLive = true;
+          continue;
+        }
+
+        // A cast consuming only part of the results of the current one, or
+        // also other values, can't be folded together with it.
+        if (other.getInputs() != current.getOutputs())
+          return rewriter.notifyMatchFailure(op,
+                                             "mismatching input and output");
+
+        // Continue traversing the DAG of unrealized casts.
+        visitStack.push(other);
+      }
+
+      if (!isLive) {
+        intermediateNodes.insert(current);
+        continue;
+      }
+
+      // The cast is live, so we need to check if its results have the same
+      // types as the root inputs. If this is the case (e.g.
+      // `{A -> B, B -> A}`, but also `{A -> A}`), then the cycle is just a
+      // no-op and the inputs can be forwarded. If it's not (e.g.
+      // `{A -> B, B -> C}`, `{A -> B}`), then the cast chain is incomplete.
+      if (current.getResultTypes() != op.getInputs().getTypes())
+        return rewriter.notifyMatchFailure(op,
+                                           "live unrealized conversion cast");
+
+      sinkNodes.insert(current);
+    }
+
+    // Replace the sink nodes with the root input values.
+    for (UnrealizedConversionCastOp sink : sinkNodes)
+      rewriter.replaceOp(sink, op.getInputs());
+
+    // Erase all the other casts belonging to the DAG.
+    for (UnrealizedConversionCastOp castOp : intermediateNodes)
+      rewriter.eraseOp(castOp);
 
     return success();
   }