This is an archive of the discontinued LLVM Phabricator instance.

[mlir][tensor] Canonicalize scalar `tensor.insert_slice(tensor.insert, _)` to `tensor.insert`
Needs Review (Public)

Authored by christopherbate on Aug 23 2023, 10:38 AM.



Canonicalizes the pattern

%0 = tensor.insert %scalar into %t1[...] : (scalar tensor type)
%1 = tensor.insert_slice %0 into %t2[<indices>]

to

%1 = tensor.insert %scalar into %t2[<indices>]
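The matching condition and the new indices can be modeled outside MLIR with a small Python sketch. The helper name fold_scalar_insert_into_insert_slice is hypothetical, it works only on static shapes and integer offsets, and it is not the C++ pattern from this revision. It assumes "scalar tensor type" means a tensor with exactly one element (e.g. tensor<f32> or tensor<1xf32>), and that the new tensor.insert indices are the insert_slice offsets, because a slice whose sizes are all 1 places its single element at the offsets regardless of strides.

```python
from math import prod
from typing import List, Optional, Sequence


def fold_scalar_insert_into_insert_slice(
    producer_shape: Sequence[int],
    slice_offsets: Sequence[int],
    slice_sizes: Sequence[int],
) -> Optional[List[int]]:
    """Hypothetical toy model of
    insert_slice(insert(%scalar into %t1[...]), %t2[offsets])
      -> insert(%scalar into %t2[offsets]).

    Returns the indices for the new tensor.insert, or None if the
    rewrite does not apply. Static shapes only.
    """
    # The producer must hold exactly one element; otherwise insert_slice
    # copies more than the inserted scalar and the fold would be wrong.
    # (math.prod of an empty shape, i.e. a 0-d tensor, is 1.)
    if prod(producer_shape) != 1:
        return None
    # The slice written into the destination must be a single element too.
    if any(size != 1 for size in slice_sizes):
        return None
    # With all sizes equal to 1, the element lands exactly at the offsets.
    return list(slice_offsets)


# %2 in the example below: tensor<1xf32> inserted at offset [2] of tensor<4xf32>
print(fold_scalar_insert_into_insert_slice([1], [2], [1]))  # [2]
# %3 in the example below: rank-reducing tensor<f32> inserted at offset [3]
print(fold_scalar_insert_into_insert_slice([], [3], [1]))   # [3]
```

Only the destination offsets are used, so rank-reducing slices (a 0-d producer inserted into a 1-d destination) fall out of the same rule.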

This has a side effect on bufferization: prior to this canonicalization, the
IR below results in two allocations (even with empty tensor elimination),
whereas afterwards it results only in the creation of two ops:

func.func @func(%arg0 : f32, %arg1: f32, %arg2: tensor<4xf32>) -> (tensor<4xf32>) {
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %e1 = tensor.empty() : tensor<1xf32>
  %e2 = tensor.empty() : tensor<f32>
  %0 = tensor.insert %arg0 into %e1[%c0] : tensor<1xf32>
  %1 = tensor.insert %arg1 into %e2[] : tensor<f32>
  %2 = tensor.insert_slice %0 into %arg2[%c2][1][1] : tensor<1xf32> into tensor<4xf32>
  %3 = tensor.insert_slice %1 into %2[%c3][1][1] : tensor<f32> into tensor<4xf32>
  return %3 : tensor<4xf32>
}
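To check that the rewrite preserves the function's result, the sketch below models each tensor as a Python list with value semantics (every op returns a fresh copy). The names original_ir and canonicalized_ir are illustrative only, and the canonicalized form (two tensor.insert ops writing into %arg2 at %c2 and %c3) is an assumption about what the pattern produces for this input. The sketch compares results only; it does not model bufferization or count allocations.

```python
# Toy model of the example's value semantics (not MLIR, not bufferization):
# each tensor is a Python list, and every op works on a fresh copy.


def original_ir(arg0, arg1, arg2):
    # %e1 = tensor.empty() : tensor<1xf32>  (contents undefined -> None)
    # %0 = tensor.insert %arg0 into %e1[%c0]
    t0 = [None]
    t0[0] = arg0
    # %e2 = tensor.empty() : tensor<f32>; %1 = tensor.insert %arg1 into %e2[]
    # A 0-d tensor is modeled as a bare value.
    t1 = arg1
    # %2 = tensor.insert_slice %0 into %arg2[%c2][1][1]
    t2 = list(arg2)
    t2[2:3] = t0
    # %3 = tensor.insert_slice %1 into %2[%c3][1][1]  (rank-reducing)
    t3 = list(t2)
    t3[3] = t1
    return t3


def canonicalized_ir(arg0, arg1, arg2):
    # %2 = tensor.insert %arg0 into %arg2[%c2]
    t2 = list(arg2)
    t2[2] = arg0
    # %3 = tensor.insert %arg1 into %2[%c3]
    t3 = list(t2)
    t3[3] = arg1
    return t3


print(original_ir(1.0, 2.0, [0.0] * 4))       # [0.0, 0.0, 1.0, 2.0]
print(canonicalized_ir(1.0, 2.0, [0.0] * 4))  # [0.0, 0.0, 1.0, 2.0]
```

Both versions also leave the caller's list untouched, mirroring the fact that tensor ops produce new SSA values rather than mutating their operands.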


Event Timeline

Herald added a project: Restricted Project. Aug 23 2023, 10:38 AM
christopherbate requested review of this revision. Aug 23 2023, 10:38 AM

It seems to me that you either need to check that the tensor.insert indices are all 0, or combine them with the insert_slice offsets, no?

Additionally, have you looked at mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp?
It seems you could add a new pattern there.

Note that those patterns are not canonicalizations because they are not always desirable, so we made them opt-in.

christopherbate added inline comments. Sep 8 2023, 3:12 PM

The producer (tensor.insert) is producing a scalar (single-element) tensor; that's what line 2500 is checking. The consumer (tensor.insert_slice) then inserts that scalar into some position. So I don't think we need to check the indices of tensor.insert? Or did I get something wrong here?
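The question above comes down to whether a single-element tensor can be written at a nonzero index. A small enumeration sketch (Python, static shapes with positive dimensions only; the helper names in_bounds_indices and only_zero_indices are hypothetical) shows that every in-bounds index of a tensor with exactly one element is all zeros, which supports the author's reading for in-bounds accesses. It does not model dynamic indices that may be out of bounds.

```python
from itertools import product
from math import prod


def in_bounds_indices(shape):
    """Every in-bounds index tuple for a static shape (hypothetical helper)."""
    # product() with no ranges yields one empty tuple: the single 0-d index.
    return list(product(*(range(dim) for dim in shape)))


def only_zero_indices(shape):
    """True if every in-bounds index tuple of `shape` is all zeros."""
    return all(all(i == 0 for i in idx) for idx in in_bounds_indices(shape))


# For positive static dims, "all in-bounds indices are 0" holds exactly
# when the tensor has a single element.
for shape in ([], [1], [1, 1], [2], [1, 3]):
    print(shape, prod(shape) == 1, only_zero_indices(shape))
```

For non-scalar producers the reviewer's point applies instead: insert_slice copies the whole source, so a single tensor.insert could not represent it regardless of how the indices are combined.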