This is an archive of the discontinued LLVM Phabricator instance.

[mlir][scf] Add scf.for + tensor.cast canonicalization pattern
ClosedPublic

Authored by nicolasvasilache on Apr 16 2021, 9:21 AM.

Details

Summary

Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
tensor.cast op pair, so as to pull the tensor.cast inside the scf.for:

%0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
%1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
   -> (tensor<?x?xf32>) {
  %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
  scf.yield %2 : tensor<?x?xf32>
}
%2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
use_of(%2)

folds into:

%0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %t0)
    -> (tensor<32x1024xf32>) {
  %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
  %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
  %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
  scf.yield %4 : tensor<32x1024xf32>
}
use_of(%0)
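The match-and-rewrite above can be sketched as a toy C++ model of a single iter_arg/result pair. This is a hedged illustration, not the pattern's actual code and not the MLIR API: `Shape`, `kDynamic`, `castCompatible`, `IterArgPair`, `FoldedPair` and `foldCastPair` are hypothetical names, element types are ignored, and it assumes the pair is foldable only when the loop result's single user is a tensor.cast whose destination type equals the source type of the incoming cast (as in the example, where both are tensor<32x1024xf32>).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical toy model, not the MLIR API: a ranked tensor shape in which
// kDynamic stands for a '?' dimension. Element types are ignored.
constexpr std::int64_t kDynamic = -1;
using Shape = std::vector<std::int64_t>;

// Models the tensor.cast legality rule: equal rank, and every pair of
// static dimensions agrees.
bool castCompatible(const Shape &a, const Shape &b) {
  if (a.size() != b.size()) return false;
  for (std::size_t i = 0; i < a.size(); ++i)
    if (a[i] != kDynamic && b[i] != kDynamic && a[i] != b[i]) return false;
  return true;
}

struct CastOp {
  Shape from, to;
};

// One iter_arg/result pair of an scf.for and the casts around it.
struct IterArgPair {
  Shape initCastSource;                 // source of the incoming tensor.cast
  Shape loopType;                       // current iter_arg/result type
  std::optional<Shape> resultCastDest;  // set only if the result's single
                                        // user is a tensor.cast (assumed)
};

// Result of the fold: the loop is retyped to the static type and the two
// casts now live inside the loop body.
struct FoldedPair {
  Shape newLoopType;
  CastOp bodyEntryCast;  // applied to the region iter_arg
  CastOp bodyExitCast;   // applied to the value passed to scf.yield
};

std::optional<FoldedPair> foldCastPair(const IterArgPair &p) {
  // Both casts must exist and round-trip back to the same type.
  if (!p.resultCastDest || *p.resultCastDest != p.initCastSource)
    return std::nullopt;
  // The original casts must have been legal in the first place.
  if (!castCompatible(p.initCastSource, p.loopType)) return std::nullopt;
  return FoldedPair{p.initCastSource,
                    CastOp{p.initCastSource, p.loopType},
                    CastOp{p.loopType, p.initCastSource}};
}
```

Feeding it `{32, 1024}` as the incoming cast source, `{kDynamic, kDynamic}` as the loop type and `{32, 1024}` as the outgoing cast destination mirrors the Summary: the loop is retyped to the static shape, and the body gains a static-to-dynamic cast on entry and a dynamic-to-static cast before the yield.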

Event Timeline

nicolasvasilache requested review of this revision. Apr 16 2021, 9:21 AM
Herald added a project: Restricted Project. Apr 16 2021, 9:21 AM
ftynse accepted this revision. Apr 16 2021, 9:37 AM
ftynse added a subscriber: ftynse.

Not sure which form is more canonical though. I would suppose the one where fewer casts are performed, i.e., the original one. I see value in this when it propagates a static tensor shape to the loop result, but not necessarily in the other direction.

mlir/lib/Dialect/SCF/SCF.cpp
607

Nit: / -> .

619–620

Can't this just be newForOp.getRegionIterArgs()[operand.getOperandNumber()] without the triple indirection?

636

Conversion patterns don't like setOperand. Recreate the op instead.

mlir/test/Dialect/SCF/canonicalize.mlir
613

Please add a newline

This revision is now accepted and ready to land. Apr 16 2021, 9:37 AM
ThomasRaoux accepted this revision. Apr 16 2021, 9:40 AM

Looks good. It would be nice to have a more generic way to propagate things in and out of forOp regions, as this is a problem I have run into a few times, but I don't have a simple solution.

mlir/lib/Dialect/SCF/SCF.cpp
607

typo

mlir/test/Dialect/SCF/canonicalize.mlir
612

nit: add a newline at end of file

nicolasvasilache marked 5 inline comments as done. Apr 16 2021, 9:42 AM
nicolasvasilache marked an inline comment as done.
mlir/lib/Dialect/SCF/SCF.cpp
619–620

No, because I would also need to account for the control operands, which I view as leaking an implementation detail.
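The disagreement comes down to operand layout: an scf.for's operands are the lower bound, upper bound and step, followed by one init value per iter_arg, so an operand number is not directly an index into the region iter_args. The toy model below is a hedged sketch of that offset; `ToyForOp`, `kNumControlOperands` and `iterArgForOperand` are hypothetical names, not MLIR API (upstream, the count of control operands is, to my understanding, exposed as `ForOp::getNumControlOperands()`).

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy model, not the MLIR API, of scf.for operand layout:
// the "control" operands (lb, ub, step) come first, then one init value
// per iter_arg.
struct ToyForOp {
  static constexpr std::size_t kNumControlOperands = 3;  // lb, ub, step
  std::vector<std::string> operands;        // e.g. {"%c0", "%c1024", "%c32", "%0"}
  std::vector<std::string> regionIterArgs;  // e.g. {"%iter_t0"}

  // Maps an operand number to its region iter_arg. Indexing regionIterArgs
  // with the raw operand number would be off by kNumControlOperands.
  const std::string &iterArgForOperand(std::size_t operandNumber) const {
    assert(operandNumber >= kNumControlOperands && "not an init operand");
    std::size_t index = operandNumber - kNumControlOperands;
    assert(index < regionIterArgs.size() && "operand out of range");
    return regionIterArgs[index];
  }
};
```

For the loop in the Summary, operand 3 (`%0`) maps to iter_arg 0 (`%iter_t0`); callers of the suggested one-liner would have to subtract the control-operand count themselves, which is the implementation detail the author wanted to keep out of the pattern.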

Address review comments.

This revision was landed with ongoing or failed builds. Apr 16 2021, 9:55 AM
This revision was automatically updated to reflect the committed changes.