The reconciliation pass has been improved to support chains of casts, so the reconciliation is no longer limited to pairs of unrealized casts.
This looks broken in multiple ways and unnecessarily complex.
I can see several approaches to this:
- modify this to find a DAG of casts starting from the "top", such that all "bottom" casts have the same result types as the "top" has operand types and intermediate casts have no users other than casts that belong to the DAG, then drop the entire DAG; note that some of the bifurcation tests will have several such DAGs where the user of the "bottom" is another cast.
- add a separate pattern that iteratively propagates operands through casts: {A->B, B->C, C->A} becomes {A->B, A->C, C->A} that can be removed by DCE + the current pattern, or even {A->B, A->C, A->A} that can be removed by DCE + folding away the cast to itself.
The second approach looks significantly less complex.
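To illustrate the second approach, here is a minimal sketch (my own illustration, not code from the patch; the pattern name is hypothetical and it only handles single-input casts):

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical sketch of the operand-propagation idea: if a cast's input is
// itself produced by an unrealized cast, rebuild the cast so it starts from
// the inner cast's input. {A->B, B->C} becomes {A->B, A->C}; DCE then removes
// the now-dead A->B, and a cast whose result types equal its input types can
// be folded away.
struct PropagateThroughCasts
    : public OpRewritePattern<UnrealizedConversionCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
                                PatternRewriter &rewriter) const override {
    // Keep the sketch simple: only handle 1:1 casts.
    if (op.getInputs().size() != 1)
      return failure();
    auto inputCast =
        op.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
    if (!inputCast || inputCast.getInputs().size() != 1 ||
        inputCast.getOutputs().size() != 1)
      return failure();
    // Recreate the cast with the same result types but the earlier operand.
    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
        op, op.getOutputs().getTypes(), inputCast.getInputs());
    return success();
  }
};
```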
mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp |
---|---
38–39 | Nit: llvm::append_range is slightly less verbose and slightly more efficient (it will reserve space).
40–41 | Is this even necessary? Can't we just rely on the infra to delete dead ops?
44 | No need to prefix with mlir:: inside the MLIR codebase.
49–50 | There is no guarantee that the cast has only one input.
53–56 | So this will just indiscriminately erase intermediate casts even when we don't know yet if the entire chain can be erased. And if the chain is not erased, we will have just dropped the casts on the floor without necessarily changing their users, leaving the IR in a very broken state with dangling pointers. While passes are allowed to leave the IR in an invalid state on failure, I would argue that invalid means "doesn't pass the verifier", not "has corrupted memory". Furthermore, the pattern is used outside this specific pass in the wild.
57 | Nit: this is usually called a cycle rather than a loop.
67 | Do not evaluate .size() on every iteration: https://llvm.org/docs/CodingStandards.html#don-t-evaluate-end-every-time-through-a-loop
67 | Please add braces to non-trivial loops: https://llvm.org/docs/CodingStandards.html#don-t-use-braces-on-simple-single-statement-bodies-of-if-else-loop-statements
69 | This will also ignore unrealized casts that are not part of the chain. So if you have a cast in the middle of a chain that is used by two other casts, you will still consider the entire chain dead. I suppose the correct check would be that the cast has only one user, and this user is contained in ops. Otherwise, if you really want to check whether the entire DAG of casts is dead, you likely need to do so by following use-def chains and checking each op.
77 | Prefer SmallVector: https://llvm.org/docs/CodingStandards.html#c-standard-library
87 | Please expand auto unless the deduced type is obvious from context, e.g., there is a cast on the RHS.
92 | Prefer iteration to recursion when traversing IR, specifically use-def chains: https://mlir.llvm.org/getting_started/DeveloperGuide/
102–103 | I can't fully follow this live/dead reasoning and am not convinced it is at all necessary. What would happen if we always added the root here regardless of it being live? Is it a problem if we clean up a chain even if we know it's dead?
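For reference, a small illustrative helper (hypothetical names, not lines from the patch) showing the append_range and loop-bound nits from the table above:

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Hypothetical helper: gather the inputs of a set of casts.
static SmallVector<Value>
gatherCastInputs(ArrayRef<UnrealizedConversionCastOp> casts) {
  SmallVector<Value> inputs;
  // Evaluate the container size once instead of on every iteration.
  for (size_t i = 0, e = casts.size(); i < e; ++i) {
    UnrealizedConversionCastOp castOp = casts[i];
    // llvm::append_range reserves space and is less verbose than pushing
    // back every value in a manual loop.
    llvm::append_range(inputs, castOp.getInputs());
  }
  return inputs;
}
```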
This is the current approach.
> - add a separate pattern that iteratively propagates operands through casts: {A->B, B->C, C->A} becomes {A->B, A->C, C->A} that can be removed by DCE + the current pattern, or even {A->B, A->C, A->A} that can be removed by DCE + folding away the cast to itself.
Please tell me if I'm wrong, but since this is a conversion pass the following would happen: the operations are not really modified until the end of the conversion, so the values at the beginning of the chain would not be propagated down to the end, but only to the next cast. For example, {A -> B, B -> C, C -> D, D -> A} would become {A -> B, A -> C, B -> D, C -> A}.
mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp |
---|---
40–41 | Sure, we can assume that the CSE pass is applied before the cast reconciliation.
49–50 | The same approach is followed in the UnrealizedCastOp folder though.
53–56 | Maybe I've misunderstood how the pattern rewriter works, but erasing the intra-chain nodes should not "drop anything on the floor". They would actually be erased only if the conversion succeeds, and this means that the entire chain would have been processed.
69 | I don't get this. If a cast in the middle of a chain is used by other casts, then the chain is still possibly valid. |-> D -> E leads to two chains: A -> B -> C -> A and A -> B -> D -> E. The first one is valid, while the second one would make the conversion fail.
102–103 | This was for the corner case {A -> A}, but it can be removed if we rely on folding.
mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp |
---|---
69 | Sorry, bad formatting. The D path was meant to start from B: A -> B -> C -> A, with a branch B -> D -> E.
I can't follow the code to recover this. Could you please refactor it in a way that makes the concept visible in the code and erases all the ops simultaneously?
>> - add a separate pattern that iteratively propagates operands through casts: {A->B, B->C, C->A} becomes {A->B, A->C, C->A} that can be removed by DCE + the current pattern, or even {A->B, A->C, A->A} that can be removed by DCE + folding away the cast to itself.
> Please tell me if I'm wrong, but since this is a conversion pass the following would happen: the operations are not really modified until the end of the conversion, so the values at the beginning of the chain would not be propagated down to the end, but only to the next cast. For example, {A -> B, B -> C, C -> D, D -> A} would become {A -> B, A -> C, B -> D, C -> A}.
I incorrectly assumed that the pass was using a regular rewrite driver because the pattern is _not_ a conversion pattern. With the dialect conversion infrastructure, your interpretation is indeed correct. The pass doesn't have to use it, but it may be preferable to keep the pattern compatible for users that may be relying on it.
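For example (a hypothetical usage sketch; `populateCastReconciliationPatterns` is only a placeholder for however the patterns are exposed), the same pattern could be run under the greedy rewrite driver, which erases ops immediately rather than at the end of a conversion:

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical usage: run cast-reconciliation patterns with the greedy
// rewrite driver instead of the dialect conversion driver. The greedy driver
// deletes operations eagerly, so a pattern that relies on deletions being
// postponed would misbehave here.
static LogicalResult runUnderGreedyDriver(ModuleOp module) {
  RewritePatternSet patterns(module.getContext());
  populateCastReconciliationPatterns(patterns); // placeholder, see lead-in
  return applyPatternsAndFoldGreedily(module.getOperation(),
                                      std::move(patterns));
}
```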
mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp |
---|---
49–50 | Then the folder has a bug.
53–56 | That assumes the dialect conversion driver (the pass does use it), but this is a plain rewrite pattern which may be used with other drivers that do not necessarily postpone deletion.
69 | Maybe the entire "chain" concept is not the right word/abstraction here.
They are erased, but not simultaneously. I've been thinking about this for a few days, but I'm having a hard time figuring out how to do it. In case of "bifurcations", the casts that are shared between two or more chains would be erased multiple times. Using the conversion driver makes it possible to delay the erasure, and thus to perform it exactly once on all the casts that are not the roots (see code).
If your pattern is set up to first collect (aka match) all cast ops to erase in, e.g., a DenseSet, and then just go over it and erase them, there shouldn't be a double erase. Note that I am expecting the DenseSet to contain a DAG of casts, whereas the code currently works on chains, where multiple _overlapping_ chains form the DAG, hence the issue with double erasing. Chains aren't the right model IMO.
I am really worried about the current approach relying on the knowledge of how the specific rewriter works internally since there is absolutely no guarantee it will keep working that way. Such a change is unlikely to happen immediately, but when it does, it will be extremely hard to debug the breakage here.
Wouldn't the rewrite pattern just become an erase of the matched casts?
And wouldn't this require the DenseSet to be populated (and passed to the pattern) before the driver is executed? If someone would like to reuse the rewrite pattern to eliminate the casts, they would also have to copy the logic that first discovers the DAGs.
I understand your concerns and I agree with them, but as I said, I don't see how to implement it in a different, self-contained way.
The pattern would also discover the DAG and populate the DenseSet; that would be its "match" part, and the "rewrite" part would erase everything and replace the sink nodes of the DAG with the operands of its root. Each cast belongs to at most one such DAG under the single-operand cast assumption (which we should check, bailing out when it doesn't hold), so we should not be able to accidentally find a cast through use-def chains after it has been erased.
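To make this concrete, here is a minimal sketch of such a match/rewrite split (my own illustration, not the patch; the pattern name and traversal details are assumptions, and it bails out on multi-operand casts as suggested):

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"

using namespace mlir;

// Hypothetical sketch of the suggested design. Every cast is assumed to have
// exactly one input and one output, so each cast has a unique defining cast
// and the collected "DAG" is a tree rooted at `root`.
struct ReconcileCastDag : public OpRewritePattern<UnrealizedConversionCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(UnrealizedConversionCastOp root,
                                PatternRewriter &rewriter) const override {
    // "Match": walk the use-def chains iteratively, collecting every cast
    // reachable from the root. A cast is an exit node if it produces the
    // types the root consumes; a non-cast user anywhere else makes the match
    // fail before anything has been modified.
    llvm::SetVector<Operation *> dag;
    llvm::SmallPtrSet<Operation *, 4> exits;
    SmallVector<UnrealizedConversionCastOp> worklist;
    worklist.push_back(root);
    while (!worklist.empty()) {
      UnrealizedConversionCastOp current = worklist.pop_back_val();
      if (current.getInputs().size() != 1 || current.getOutputs().size() != 1)
        return failure();
      if (!dag.insert(current.getOperation()))
        continue;
      if (current.getOutputs().getTypes() == root.getInputs().getTypes()) {
        exits.insert(current.getOperation());
        continue;
      }
      for (Operation *user : current->getUsers()) {
        auto userCast = dyn_cast<UnrealizedConversionCastOp>(user);
        if (!userCast)
          return failure();
        worklist.push_back(userCast);
      }
    }

    // "Rewrite": erase the whole set at once. Exit casts are replaced by the
    // root's operand; everything else is simply erased. Visiting the set in
    // reverse discovery order removes users before their defining casts, so
    // no op is erased while it still has uses.
    for (Operation *op : llvm::reverse(dag.getArrayRef())) {
      if (exits.contains(op))
        rewriter.replaceOp(op, root.getInputs());
      else
        rewriter.eraseOp(op);
    }
    return success();
  }
};
```

Since the match phase only reads the IR, a failed match leaves nothing dangling, and the reverse-order erasure does not depend on the driver postponing deletions.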
I have implemented your suggestions. It turned out to be simpler than I expected: I didn't know that, in the case of dialect conversion, the driver skips the operations that have been marked as erased (hence my previous doubts). Thanks for the ideas and the small dive into the different driver options (I knew the conversion driver was not the only one, but I never had the chance to reason about the others and about sharing patterns between them).
One note though: I have opted to keep the dead-cast handling, for two reasons: the first (admittedly minor) is its simplicity (just a users.empty() check); the second, more important, is that a CSE pass would lead to other IR changes that the user may not want to perform (for whatever reason). If you feel that we should completely avoid this, please let me know.
P.S.: sorry for the double patch, I forgot to specify the base commit for the diff.
I hoped it would be simpler. Indeed, the driver doesn't consider operations marked for erasure. Otherwise, patterns would need to know in which order the driver walks over the IR and we don't want patterns to know about the driver internals.
> One note though: I have opted to keep the dead-cast handling, for two reasons: the first (admittedly minor) is its simplicity (just a users.empty() check); the second, more important, is that a CSE pass would lead to other IR changes that the user may not want to perform (for whatever reason). If you feel that we should completely avoid this, please let me know.
This is fine.
Please address the two remaining comments and this should be good to go.
mlir/lib/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.cpp |
---|---
49 | You can use SmallVector instead; it has convenient push_back/pop_back_val. std::stack is a wrapper around std::deque, which is less efficient than std::vector, which itself is less efficient than SmallVector.
68 | This is checking the cycle condition inside the lambda for all users, but the condition itself doesn't need the user, so it may end up being checked repeatedly. Could you factor it out of the lambda? It may be possible to do a single sweep over the users checking that (a) they are all UnrealizedConversionCastOp and (b) they cast back to the previous type; then the surrounding conditional will become if (b || current.getResultTypes() == op.getInputs().getTypes()) and the isSink below will be simplified to users.empty() || a.
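A rough sketch of that single sweep (my own illustration; `current` and the flag names are assumptions about the surrounding code, and "casts back to the previous type" follows my reading of the comment):

```cpp
#include "mlir/IR/BuiltinOps.h"

#include <utility>

using namespace mlir;

// Hypothetical helper: one sweep over the users of `current`, computing
// (a) all users are unrealized casts, and (b) some user casts back to the
// types `current` was cast from. The two flags can then feed the surrounding
// conditional and the exit-node check as described in the inline comment.
static std::pair<bool, bool> sweepUsers(UnrealizedConversionCastOp current) {
  bool allUsersAreCasts = true;   // (a)
  bool someUserCastsBack = false; // (b)
  for (Operation *user : current->getUsers()) {
    auto userCast = dyn_cast<UnrealizedConversionCastOp>(user);
    if (!userCast) {
      allUsersAreCasts = false;
      break;
    }
    if (userCast.getOutputs().getTypes() == current.getInputs().getTypes())
      someUserCastsBack = true;
  }
  return {allUsersAreCasts, someUserCastsBack};
}
```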
Patch updated.
To make things clearer, I've renamed the "sink" nodes to "exit" nodes.
Also, the iterations over the users for the DAG traversal have been merged, so that only one sweep is needed.
For the sake of completeness, I've also made the match fail when there is a mismatch between the input and output arguments of the casts (before, such a cast was detected as live but could still be accepted if the types formed a cycle, a behavior that can lead to wrong results). To be honest I can't imagine a situation where a cast takes multiple operands, but the corner case is covered for future-proofing. In my opinion the unrealized cast doesn't even have a reason to take more than one operand, but that is for another patch :)