diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -619,8 +619,9 @@
     ```
     %0 = "producer.op"() : () -> !type.A
     %1 = unrealized_conversion_cast %0 : !type.A to !type.B
-    %2 = unrealized_conversion_cast %1 : !type.B to !type.A
-    "consumer.op"(%2) : (!type.A) -> ()
+    %2 = unrealized_conversion_cast %1 : !type.B to !type.C
+    %3 = unrealized_conversion_cast %2 : !type.C to !type.A
+    "consumer.op"(%3) : (!type.A) -> ()
     ```
 
     Such situations appear when the consumer operation is converted by one pass
diff --git a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
--- a/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
+++ b/mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp
@@ -17,37 +17,99 @@
 
 namespace {
 
-/// Removes `unrealized_conversion_cast`s whose results are only used by other
-/// `unrealized_conversion_cast`s converting back to the original type. This
-/// pattern is complementary to the folder and can be used to process operations
-/// starting from the first, i.e. the usual traversal order in dialect
-/// conversion. The folder, on the other hand, can only apply to the last
-/// operation in a chain of conversions because it is not expected to walk
-/// use-def chains. One would need to declare cast ops as dynamically illegal
-/// with a complex condition in order to eliminate them using the folder alone
-/// in the dialect conversion infra.
+/// Removes DAGs of `unrealized_conversion_cast`s whose exit types are the
+/// same as the input types of the root cast.
+/// For example, the chains `A -> B -> C -> B -> A` and `A -> B -> C -> A`
+/// are no-ops within the IR, and thus the initial input values can be
+/// propagated.
+/// The same does not hold for 'open' chains of casts, such as `A -> B -> C`.
+/// In this last case there is no cycle among the types and thus the
+/// conversion is incomplete. The same holds for 'closed' chains like
+/// `A -> B -> A` when the result of type `B` is also used by some non-cast
+/// operation.
+/// Bifurcations (that is, when a chain starts in the middle of another one)
+/// are also taken into consideration, and all the above considerations remain
+/// valid.
+/// Special corner cases such as dead casts or single casts with the same input
+/// and output types are also covered.
+///
+/// A cast is only followed into a user cast whose inputs are exactly its own
+/// results, so every visited cast has a single producer and the traversed
+/// casts form a tree rooted at the matched operation.
 struct UnrealizedConversionCastPassthrough
     : public OpRewritePattern<UnrealizedConversionCastOp> {
   using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
                                 PatternRewriter &rewriter) const override {
-    // Match the casts that are _only_ used by other casts, with the overall
-    // cast being a trivial noop: A->B->A.
-    auto users = op->getUsers();
-    if (!llvm::all_of(users, [&](Operation *user) {
-          if (auto other = dyn_cast<UnrealizedConversionCastOp>(user))
-            return other.getResultTypes() == op.getInputs().getTypes() &&
-                   other.getInputs() == op.getOutputs();
-          return false;
-        })) {
-      return rewriter.notifyMatchFailure(op, "live unrealized conversion cast");
+    // Casts with at least one user that is not an unrealized cast. Their
+    // results must be replaced with the root inputs.
+    SmallVector<UnrealizedConversionCastOp> liveExitNodes;
+
+    // Casts whose results are either unused or only used by other casts of the
+    // tree. They are recorded in depth-first pre-order, so every cast appears
+    // after the cast producing its inputs.
+    SmallVector<UnrealizedConversionCastOp> erasableNodes;
+
+    // `getUsers()` reports a user once per use, so a cast consuming several
+    // results of the same cast would otherwise be visited (together with its
+    // whole subtree) once per use.
+    DenseSet<Operation *> visited;
+
+    // Stack used for the depth-first traversal of the use-def tree.
+    SmallVector<UnrealizedConversionCastOp, 2> visitStack;
+    visitStack.push_back(op);
+    visited.insert(op.getOperation());
+
+    while (!visitStack.empty()) {
+      UnrealizedConversionCastOp current = visitStack.pop_back_val();
+      bool isLive = false;
+
+      for (Operation *user : current->getUsers()) {
+        auto other = dyn_cast<UnrealizedConversionCastOp>(user);
+        if (!other) {
+          isLive = true;
+          continue;
+        }
+
+        if (other.getInputs() != current.getOutputs())
+          return rewriter.notifyMatchFailure(
+              op, "mismatching values propagation");
+
+        // Continue traversing the tree of unrealized casts.
+        if (visited.insert(user).second)
+          visitStack.push_back(other);
+      }
+
+      // Dead casts and casts only feeding other casts of the tree are erased.
+      if (!isLive) {
+        erasableNodes.push_back(current);
+        continue;
+      }
+
+      // A live cast must produce values of the same types as the root inputs
+      // (e.g. `{A -> B, B -> A}`, but also `{A -> A}`): the cycle is then a
+      // no-op and the inputs can be forwarded. If it does not (e.g.
+      // `{A -> B, B -> C}`, `{A -> B}`), then the cast chain is incomplete.
+      bool isCycle = current.getResultTypes() == op.getInputs().getTypes();
+      if (!isCycle)
+        return rewriter.notifyMatchFailure(op,
+                                           "live unrealized conversion cast");
+
+      liveExitNodes.push_back(current);
     }
 
-    for (Operation *user : users)
-      rewriter.replaceOp(user, op.getInputs());
+    // Forward the root inputs to the live exit nodes. Their result types match
+    // the root input types, so the number of replacement values matches too.
+    for (UnrealizedConversionCastOp exitNode : liveExitNodes)
+      rewriter.replaceOp(exitNode, op.getInputs());
+
+    // Erase the remaining casts in reverse pre-order, i.e. users before their
+    // producers, so that no cast is erased while its results are still used
+    // by another cast of the tree.
+    for (UnrealizedConversionCastOp castOp : llvm::reverse(erasableNodes))
+      rewriter.eraseOp(castOp);
 
-    rewriter.eraseOp(op);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir
new file mode 100644
--- /dev/null
+++
b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts-failure.mlir
@@ -0,0 +1,60 @@
+// RUN: not mlir-opt %s -split-input-file -mlir-print-ir-after-failure -reconcile-unrealized-casts 2>&1 | FileCheck %s
+
+// CHECK-LABEL: @liveSingleCast
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32
+// CHECK: %[[liveCast:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
+// CHECK: return %[[liveCast]] : i32
+
+func.func @liveSingleCast(%arg0: i64) -> i32 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  return %0 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @liveChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i32
+// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i1
+// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i1 to i32
+// CHECK: return %[[cast1]] : i32
+
+func.func @liveChain(%arg0: i64) -> i32 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i1
+  %1 = builtin.unrealized_conversion_cast %0 : i1 to i32
+  return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @liveBifurcation
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
+// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i64
+// CHECK: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i32 to i1
+// CHECK: %[[extsi:.*]] = arith.extsi %[[cast2]] : i1 to i64
+// CHECK: %[[result:.*]] = arith.addi %[[cast1]], %[[extsi]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @liveBifurcation(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  %1 = builtin.unrealized_conversion_cast %0 : i32 to i64
+  %2 = builtin.unrealized_conversion_cast %0 : i32 to i1
+  %3 = arith.extsi %2 : i1 to i64
+  %4 = arith.addi %1, %3 : i64
+  return %4 : i64
+}
+
+// -----
+
+// The second cast also consumes a value not produced by the first one.
+// CHECK-LABEL: @mismatchingValuesPropagation
+// CHECK-SAME: (%[[arg0:.*]]: i64, %[[arg1:.*]]: i32) -> i64
+// CHECK: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to i32
+// CHECK: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]], %[[arg1]] : i32, i32 to i64
+// CHECK: return %[[cast1]] : i64
+
+func.func @mismatchingValuesPropagation(%arg0: i64, %arg1: i32) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  %1 = builtin.unrealized_conversion_cast %0, %arg1 : i32, i32 to i64
+  return %1 : i64
+}
diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
@@ -0,0 +1,118 @@
+// RUN: mlir-opt %s -split-input-file -reconcile-unrealized-casts | FileCheck %s
+
+// CHECK-LABEL: @unusedCast
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @unusedCast(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  return %arg0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @sameTypes
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @sameTypes(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i64
+  return %0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @pair
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @pair(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  %1 = builtin.unrealized_conversion_cast %0 : i32 to i64
+  return %1 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @symmetricChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @symmetricChain(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+  %2 = builtin.unrealized_conversion_cast %1 : i1 to i32
+  %3 = builtin.unrealized_conversion_cast %2 : i32 to i64
+  return %3 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @asymmetricChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @asymmetricChain(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+  %2 = builtin.unrealized_conversion_cast %1 : i1 to i64
+  return %2 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @unusedChain
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: return %[[arg0]] : i64
+
+func.func @unusedChain(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+  return %arg0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @bifurcation
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: %[[result:.*]] = arith.addi %[[arg0]], %[[arg0]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @bifurcation(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+  %2 = builtin.unrealized_conversion_cast %1 : i1 to i64
+  %3 = builtin.unrealized_conversion_cast %1 : i1 to i32
+  %4 = builtin.unrealized_conversion_cast %3 : i32 to i64
+  %5 = arith.addi %2, %4 : i64
+  return %5 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @unusedBifurcation
+// CHECK-SAME: (%[[arg0:.*]]: i64) -> i64
+// CHECK: %[[result:.*]] = arith.addi %[[arg0]], %[[arg0]] : i64
+// CHECK: return %[[result]] : i64
+
+func.func @unusedBifurcation(%arg0: i64) -> i64 {
+  %0 = builtin.unrealized_conversion_cast %arg0 : i64 to i32
+  %1 = builtin.unrealized_conversion_cast %0 : i32 to i1
+  %2 = builtin.unrealized_conversion_cast %1 : i1 to i64
+  %3 = builtin.unrealized_conversion_cast %0 : i32 to i64
+  %4 = arith.addi %arg0, %3 : i64
+  return %4 : i64
+}
+
+// -----
+
+// A cast may produce several results, all consumed by the next cast.
+// CHECK-LABEL: @multipleResults
+// CHECK-SAME: (%[[arg0:.*]]: i64, %[[arg1:.*]]: i64) -> (i64, i64)
+// CHECK: return %[[arg0]], %[[arg1]] : i64, i64
+
+func.func @multipleResults(%arg0: i64, %arg1: i64) -> (i64, i64) {
+  %0:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : i64, i64 to i32, i1
+  %1:2 = builtin.unrealized_conversion_cast %0#0, %0#1 : i32, i1 to i64, i64
+  return %1#0, %1#1 : i64, i64
+}