diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -615,8 +615,9 @@
     ```
     %0 = "producer.op"() : () -> !type.A
     %1 = unrealized_conversion_cast %0 : !type.A to !type.B
-    %2 = unrealized_conversion_cast %1 : !type.B to !type.A
-    "consumer.op"(%2) : (!type.A) -> ()
+    %2 = unrealized_conversion_cast %1 : !type.B to !type.C
+    %3 = unrealized_conversion_cast %2 : !type.C to !type.A
+    "consumer.op"(%3) : (!type.A) -> ()
     ```
 
     Such situations appear when the consumer operation is converted by one pass
diff --git a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
--- a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
+++ b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
@@ -17,37 +17,169 @@
 
 namespace {
 
-/// Removes `unrealized_conversion_cast`s whose results are only used by other
-/// `unrealized_conversion_cast`s converting back to the original type. This
-/// pattern is complementary to the folder and can be used to process operations
-/// starting from the first, i.e. the usual traversal order in dialect
-/// conversion. The folder, on the other hand, can only apply to the last
-/// operation in a chain of conversions because it is not expected to walk
-/// use-def chains. One would need to declare cast ops as dynamically illegal
-/// with a complex condition in order to eliminate them using the folder alone
-/// in the dialect conversion infra.
+/// A chain of casts represents a list of `unrealized_conversion_cast`
+/// operations. Apart from the root, each operation takes as inputs exactly the
+/// results of its predecessor and casts them to different types. The last cast
+/// of a chain is its exit point, that is a cast whose results are used by at
+/// least one operation that is not an `unrealized_conversion_cast`.
+class ChainOfCasts {
+public:
+  explicit ChainOfCasts(UnrealizedConversionCastOp root) {
+    ops.push_back(root.getOperation());
+  }
+
+  /// Append another chain of casts to this chain.
+  void append(const ChainOfCasts &chain) {
+    llvm::append_range(ops, chain.ops);
+  }
+
+  /// Get the first cast of the chain.
+  UnrealizedConversionCastOp front() const {
+    return cast<UnrealizedConversionCastOp>(ops.front());
+  }
+
+  /// Get the last cast of the chain.
+  UnrealizedConversionCastOp back() const {
+    return cast<UnrealizedConversionCastOp>(ops.back());
+  }
+
+  /// Get the length of the chain.
+  size_t size() const { return ops.size(); }
+
+  /// Determine whether the chain forms a loop with respect to the types it
+  /// receives and returns, that is whether its exit point produces values of
+  /// the same types as the inputs of its root.
+  bool hasLoop() const {
+    return front().getInputs().getTypes() == back().getResultTypes();
+  }
+
+private:
+  SmallVector<Operation *> ops;
+};
+
+/// Discover the chains of casts starting at `root`. One chain is produced for
+/// every live cast reachable from `root` (including `root` itself), that is
+/// for every cast whose results are used by at least one non-cast operation,
+/// so that each exit point can be checked and rewired on its own. Dead parts
+/// of the graph (that is, casts without uses) are ignored.
+/// Returns failure if a cast consumes the results of its predecessor only
+/// partially, in a different order, or together with other values: such a
+/// cast does not continue the chain, and folding the chain would silently drop
+/// its other inputs.
+static FailureOr<SmallVector<ChainOfCasts>>
+getChains(UnrealizedConversionCastOp root) {
+  SmallVector<ChainOfCasts> chains;
+  SmallVector<Operation *> visitedCasts;
+  bool isRootLive = false;
+
+  for (Operation *user : root->getUsers()) {
+    auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
+    if (!castOp) {
+      isRootLive = true;
+      continue;
+    }
+
+    // A cast consuming several results of `root` is listed once per use.
+    if (llvm::is_contained(visitedCasts, user))
+      continue;
+    visitedCasts.push_back(user);
+
+    if (castOp->use_empty())
+      continue;
+
+    if (castOp.getInputs() != root.getOutputs())
+      return failure();
+
+    FailureOr<SmallVector<ChainOfCasts>> subChains = getChains(castOp);
+    if (failed(subChains))
+      return failure();
+
+    for (const ChainOfCasts &subChain : *subChains) {
+      ChainOfCasts chain(root);
+      chain.append(subChain);
+      chains.push_back(std::move(chain));
+    }
+  }
+
+  // `root` is an exit point itself: emit the single-cast chain so that its
+  // results get checked and replaced too, even if longer chains cross it.
+  if (isRootLive)
+    chains.emplace_back(root);
+
+  return chains;
+}
+
+/// Folds the chains of `unrealized_conversion_cast`s whose exit types are the
+/// same as the input ones.
+/// For example, the chains `A -> B -> C -> B -> A` and `A -> B -> C -> A`
+/// represent a noop within the IR, and thus the initial input values can be
+/// propagated.
+/// The same does not hold for 'open' chains, such as `A -> B -> C`.
+/// In this last case there is no loop among the types and thus the conversion
+/// is incomplete. The same holds for 'closed' chains like `A -> B -> A`, but
+/// with the result of type `B` being used by some non-cast operations.
+/// Bifurcations (that is, when a chain branches off from the middle of
+/// another one) are also taken into consideration, and all the above
+/// considerations remain valid.
+/// Special corner cases such as dead casts or single casts with same input and
+/// output types are also covered.
+/// NOTE: non-root casts are erased even when they still have uses, relying on
+/// the root of their chain to rewire (or reject) the whole chain. This
+/// presumably requires the dialect conversion driver, which defers erasures
+/// and rolls them back on failure; it is not safe with a greedy driver.
 struct UnrealizedConversionCastPassthrough
     : public OpRewritePattern<UnrealizedConversionCastOp> {
   using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
                                 PatternRewriter &rewriter) const override {
-    // Match the casts that are _only_ used by other casts, with the overall
-    // cast being a trivial noop: A->B->A.
-    auto users = op->getUsers();
-    if (!llvm::all_of(users, [&](Operation *user) {
-          if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
-            return other.getResultTypes() == op.getInputs().getTypes() &&
-                   other.getInputs() == op.getOutputs();
-          return false;
+    // In case of a dead cast, just erase it.
+    if (op->use_empty()) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // If the cast is not the root of a chain, then it is just erased. The
+    // propagation of the root inputs, in fact, is executed when the root of
+    // the chain is encountered and the whole chain processed. All the inputs
+    // are inspected, as a cast may have several inputs or none at all.
+    bool isRoot = llvm::none_of(op.getInputs(), [](Value input) {
+      return llvm::isa_and_nonnull<UnrealizedConversionCastOp>(
+          input.getDefiningOp());
+    });
+    if (!isRoot) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // Get the chains having the current cast operation as root.
+    FailureOr<SmallVector<ChainOfCasts>> chains = getChains(op);
+    if (failed(chains))
+      return rewriter.notifyMatchFailure(
+          op, "cast partially consuming the results of another cast");
+
+    // Every exit point must produce values of the same types as the root
+    // inputs. For example, A -> B -> C -> B can't be simplified because A is
+    // not the final result type.
+    if (llvm::any_of(*chains, [](const ChainOfCasts &chain) {
+          return !chain.hasLoop();
         })) {
       return rewriter.notifyMatchFailure(op, "live unrealized conversion cast");
     }
 
-    for (Operation *user : users)
-      rewriter.replaceOp(user, op.getInputs());
+    // Propagate the root inputs to every exit point. Each cast is the exit
+    // point of at most one chain, hence no operation is replaced twice.
+    bool rootReplaced = false;
+    for (const ChainOfCasts &chain : *chains) {
+      rewriter.replaceOp(chain.back(), op.getInputs());
+      rootReplaced |= chain.size() == 1;
+    }
+
+    // Unless the root is an exit point itself (and thus already replaced),
+    // it must still be removed.
+    if (!rootReplaced)
+      rewriter.eraseOp(op);
 
-    rewriter.eraseOp(op);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir
@@ -0,0 +1,60 @@
+// RUN: not mlir-opt %s -split-input-file -mlir-print-ir-after-failure -reconcile-unrealized-casts 2>&1 | FileCheck %s
+
+// CHECK-LABEL: @liveSingleCast
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32
+// CHECK: %[[liveCast:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
+// CHECK: return %[[liveCast]] : i32
+
+func.func @liveSingleCast(%arg0: i64) -> i32 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @liveChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32
+// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i1
+// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i1 to i32
+// CHECK: return %[[cast1]] : i32
+
+func.func @liveChain(%arg0: i64) -> i32 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i1
+  %1 = builtin.unrealized_conversion_cast %0 : i1 to i32
+  return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @liveBifurcation
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
+// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i64
+// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i1
+// CHECK: %[[extsi:.*]] = arith.extsi %[[cast2]] : i1 to i64
+// CHECK: %[[result:.*]] = arith.addi %[[cast1]], %[[extsi]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @liveBifurcation(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  %1 = builtin.unrealized_conversion_cast %0 : i32 to i64
+  %2 = builtin.unrealized_conversion_cast %0 : i32 to i1
+  %3 = arith.extsi %2 : i1 to i64
+  %4 = arith.addi %1, %3 : i64
+  return %4 : i64
+}
+
+// -----
+
+// The second cast also consumes %arg1, so folding the chain would drop it.
+// CHECK-LABEL: @mixedInputs
+// CHECK-SAME: (%[[arg0:.*]]: i64, %[[arg1:.*]]: i64) -> i64
+// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
+// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]], %[[arg1]] : i32, i64 to i64
+// CHECK: return %[[cast1]] : i64
+
+func.func @mixedInputs(%arg0: i64, %arg1: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  %1 = builtin.unrealized_conversion_cast %0, %arg1 : i32, i64 to i64
+  return %1 : i64
+}
diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
@@ -0,0 +1,139 @@
+// RUN: mlir-opt %s -split-input-file -reconcile-unrealized-casts | FileCheck %s
+
+// CHECK-LABEL: @unusedCast
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @unusedCast(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  return %arg0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @sameTypes
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @sameTypes(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i64
+  return %0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @pair
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @pair(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  %1 = builtin.unrealized_conversion_cast %0 : i32 to i64
+  return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @symmetricChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @symmetricChain(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+  %2 = builtin.unrealized_conversion_cast %1 : i1 to i32
+  %3 = builtin.unrealized_conversion_cast %2 : i32 to i64
+  return %3 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @asymmetricChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @asymmetricChain(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+  %2 = builtin.unrealized_conversion_cast %1 : i1 to i64
+  return %2 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @unusedChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @unusedChain(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+  return %arg0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @bifurcation
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: %[[result:.*]] = arith.addi %[[arg0]], %[[arg0]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @bifurcation(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+  %2 = builtin.unrealized_conversion_cast %1 : i1 to i64
+  %3 = builtin.unrealized_conversion_cast %1 : i1 to i32
+  %4 = builtin.unrealized_conversion_cast %3 : i32 to i64
+  %5 = arith.addi %2, %4 : i64
+  return %5 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @unusedBifurcation
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: %[[result:.*]] = arith.addi %[[arg0]], %[[arg0]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @unusedBifurcation(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+  %2 = builtin.unrealized_conversion_cast %1 : i1 to i64
+  %3 = builtin.unrealized_conversion_cast %0 : i32 to i64
+  %4 = arith.addi %arg0, %3 : i64
+  return %4 : i64
+}
+
+// -----
+
+// %1 is used by a non-cast operation and also starts a longer foldable loop.
+// CHECK-LABEL: @liveIntermediateExit
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: %[[result:.*]] = arith.addi %[[arg0]], %[[arg0]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @liveIntermediateExit(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  %1 = builtin.unrealized_conversion_cast %0 : i32 to i64
+  %2 = builtin.unrealized_conversion_cast %1 : i64 to i32
+  %3 = builtin.unrealized_conversion_cast %2 : i32 to i64
+  %4 = arith.addi %1, %3 : i64
+  return %4 : i64
+}
+
+// -----
+
+// The second cast uses both results of the first one (two uses, one user).
+// CHECK-LABEL: @multiResultPair
+// CHECK-SAME: (%[[arg0:.*]]: i64, %[[arg1:.*]]: i32) -> i64
+// CHECK: %[[ext:.*]] = arith.extsi %[[arg1]] : i32 to i64
+// CHECK: %[[result:.*]] = arith.addi %[[arg0]], %[[ext]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @multiResultPair(%arg0: i64, %arg1: i32) -> i64 {
+  %0:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : i64, i32 to i16, i8
+  %1:2 = builtin.unrealized_conversion_cast %0#0, %0#1 : i16, i8 to i64, i32
+  %2 = arith.extsi %1#1 : i32 to i64
+  %3 = arith.addi %1#0, %2 : i64
+  return %3 : i64
+}