Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
%0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
%1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
    -> (tensor<?x?xf32>) {
  %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
  scf.yield %2 : tensor<?x?xf32>
}
%2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
use_of(%2)

folds into:
%0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %t0)
    -> (tensor<32x1024xf32>) {
  %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
  %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
  %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
  scf.yield %4 : tensor<32x1024xf32>
}
use_of(%0)
Nit: / -> .