We should only fold tensor.casts that provide some new static information about
shapes, instead of looking for a symmetric pattern cast(for(cast)).
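For illustration only, a condition along these lines (a hypothetical sketch with a made-up helper name, assuming recent MLIR C++ APIs; this is not the code in the patch) captures when folding a tensor.cast into its consumer provides new static information: the cast's source type must be static in at least one dimension where its result type is dynamic, and must never be dynamic where the result is static.

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical helper (not part of the patch): folding `castOp` into a
// consumer provides new static shape information iff the source type is
// static in some dimension where the result type is dynamic, and the source
// is never dynamic where the result is static.
static bool foldingProvidesStaticInformation(tensor::CastOp castOp) {
  auto source = dyn_cast<RankedTensorType>(castOp.getSource().getType());
  auto result = dyn_cast<RankedTensorType>(castOp.getType());
  if (!source || !result || source.getRank() != result.getRank())
    return false;
  bool refines = false;
  for (int64_t i = 0, e = source.getRank(); i < e; ++i) {
    if (!source.isDynamicDim(i) && result.isDynamicDim(i))
      refines = true;  // the consumer learns a static size for this dimension
    else if (source.isDynamicDim(i) && !result.isDynamicDim(i))
      return false;    // the cast asserts a static size; folding would drop it
  }
  return refines;
}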
Event Timeline
This is in line with how casts are folded in other parts of the system, so LGTM.
Please add a minimal test before landing!
@pifon I apologize; this approval was a big oversight on my part. The change is incorrect: scf.for provides no guarantee that the iter_arg's type stays the same dynamically.
This is why the original pattern was looking for pairs of casts.
We need to revert this.
Indeed, the use case that exposed the bug will be fixed independently.
Here is a problematic example:
%1 = tensor.cast %0 : !static_tensor to !dynamic_tensor
%2 = scf.for ... (%iter = %1) -> {
  %3 = take_a_random_slice(%iter) : !dynamic_tensor -> !dynamic_tensor
  scf.yield %3 : !dynamic_tensor
} -> !dynamic_tensor
%3 = do_something(%2) : !dynamic_tensor
take_a_random_slice can return a slice of any size, and casting it to !static_tensor is only valid when the last yielded tensor has the expected dynamic size at runtime.
This easily propagates miscompiles that result in out-of-bounds accesses much later in the program.
This PR assumes that the dynamic type of an scf.for iter_arg remains constant; that assumption is incorrect.
Doesn't that mean that in this example we yield values of different sizes on every iteration? I thought that type(iter_arg), type(result), and the type in scf.yield should all be compatible.
Correct
I thought that type(iter_arg), type(result) and the type in scf.yield should all be compatible.
The static types are indeed compatible, but there is no guarantee about the dynamic value of the ?.
I think @springerm ran into similar assumption errors in the past.
This folding is valid only if the loop is "dynamic-shape-preserving". In your example, this cast may fail at runtime depending on what do_something is doing:
%4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
I submitted a dim(iter_arg) folding some time ago and then noticed that a blanket folding (without dynamic shape analysis) is incorrect. This is how we fixed it: https://reviews.llvm.org/D109430. The code contains a TypeSwitch for ops for which we know that they conserve the dynamic shape. We can do a bit better these days: We know that destination style ops preserve the dynamic shape, so we could query that interface. But it won't work with CallOps, because those are not destination style ops.
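As a rough sketch (assuming recent MLIR API names such as DestinationStyleOpInterface and getDpsInitOperand; this is not the actual code from D109430), such a shape-preservation check could look like this:

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"

using namespace mlir;

// Sketch: returns true if the value yielded for iter_arg `argIdx` is known to
// have the same dynamic shape as the corresponding init operand. The walk
// only follows destination-style ops, whose i-th result has the same dynamic
// shape as their i-th init operand; anything else is handled conservatively.
static bool isShapePreserving(scf::ForOp forOp, int64_t argIdx) {
  auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
  Value value = yieldOp->getOperand(argIdx);
  while (value) {
    // Reached the iter_arg itself: the shape is trivially preserved.
    if (value == forOp.getRegionIterArgs()[argIdx])
      return true;
    auto opResult = dyn_cast<OpResult>(value);
    if (!opResult)
      return false; // some other block argument; give up conservatively.
    auto dstOp = dyn_cast<DestinationStyleOpInterface>(opResult.getOwner());
    if (!dstOp)
      return false; // unknown op (e.g. a call); give up conservatively.
    // Keep following the init operand tied to this result.
    value = dstOp.getDpsInitOperand(opResult.getResultNumber())->get();
  }
  return false;
}

Ops that do not implement the interface, including calls, are still rejected conservatively, matching the CallOp caveat above.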