[mlir][Linalg] Enhance Linalg fusion on generic op and tensor_reshape op.
Closed, Public

Authored by hanchung on Aug 20 2020, 12:45 PM.

Details

Summary

Previously, the tensor_reshape op was fusible only in the collapsing case. Now we
propagate the op to all the operands, so there is a further chance to fuse it
with the generic op. The preconditions are:

  1. The producer is not an indexed_generic op.
  2. All the shapes of the operands are the same.
  3. All the indexing maps are identity.
  4. All the loops are parallel loops.
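
The four preconditions above amount to a single predicate on the producer. The sketch below models that predicate in plain C++. It is an illustration under assumptions, not the patch's code: `ProducerModel` and `isExpandingReshapeFusible` are hypothetical names, and the plain fields stand in for what the real pattern reads from the MLIR op (indexing maps, iterator types, operand types). `std::all_of` plays the role of `llvm::all_of`, which is a range-based wrapper around it.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical, simplified model of a linalg.generic producer. The real
// pattern queries the MLIR op; these fields only stand in for that data.
struct ProducerModel {
  bool isIndexedGeneric;
  std::vector<std::vector<long>> operandShapes;  // one shape per operand
  std::vector<bool> indexingMapIsIdentity;       // one flag per indexing map
  std::vector<std::string> iteratorTypes;        // "parallel", "reduction", ...
};

// Hypothetical helper: true only if all four preconditions hold.
bool isExpandingReshapeFusible(const ProducerModel &p) {
  // 1. The producer is not an indexed_generic op.
  if (p.isIndexedGeneric) return false;
  // 2. All operand shapes are the same. An empty operand list is rejected
  //    defensively (a generic op is assumed to have at least one operand).
  if (p.operandShapes.empty()) return false;
  const std::vector<long> &first = p.operandShapes.front();
  if (!std::all_of(p.operandShapes.begin(), p.operandShapes.end(),
                   [&](const std::vector<long> &s) { return s == first; }))
    return false;
  // 3. All indexing maps are identity.
  if (!std::all_of(p.indexingMapIsIdentity.begin(),
                   p.indexingMapIsIdentity.end(), [](bool id) { return id; }))
    return false;
  // 4. All loops are parallel.
  return std::all_of(p.iteratorTypes.begin(), p.iteratorTypes.end(),
                     [](const std::string &t) { return t == "parallel"; });
}
```

Violating any single precondition (an indexed_generic producer, a reduction loop, a transposed map, or mismatched operand shapes) makes the predicate return false.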

Fusion is also possible when the producer is an indexed_generic op, because we
can still compute the original indices. E.g., if the reshape op collapses d0
and d1, we can use DimOp to get the width of d1, compute the index
d0 * width + d1, and replace all uses of the original index with it. However,
this pattern is not implemented in this patch.
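
The index recovery described above is a row-major linearization of the collapsed dimensions. A minimal sketch, assuming the extents are already known as integers (in the IR they would come from DimOp); `linearizeGroup` is a hypothetical name and this is not code from the patch. For the two-dimensional group {d0, d1} with extents {_, width} it reduces to d0 * width + d1; the extent of the outermost dimension is never used.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: linearizes the indices of one collapsed group in
// row-major order. `indices` and `extents` must have the same length.
int64_t linearizeGroup(const std::vector<int64_t> &indices,
                       const std::vector<int64_t> &extents) {
  int64_t linear = 0;
  for (std::size_t i = 0; i < indices.size(); ++i)
    linear = linear * extents[i] + indices[i];  // Horner-style accumulation
  return linear;
}
```

For example, with extents {8, 33} the index pair (2, 5) maps to 2 * 33 + 5 = 71, and the last pair (7, 32) maps to 263.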

Event Timeline

hanchung created this revision. Aug 20 2020, 12:45 PM
Herald added a project: Restricted Project. Aug 20 2020, 12:45 PM
hanchung requested review of this revision. Aug 20 2020, 12:45 PM
mravishankar requested changes to this revision. Aug 20 2020, 11:32 PM
mravishankar added inline comments.
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
845

You could combine all the checks into a single if with (cond1) || (cond2) || ...

850

I think you can assert !types.empty() here.

855

This could be

llvm::any_of(producer.getIndexingMaps(), [](AffineMap map) { return !map.isIdentity(); })

861

This could be

llvm::any_of(producer.iterator_types(), [](Attribute attr) { return attr.cast<StringAttr>().getValue() != getParallelIteratorTypeName(); })

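
Both suggestions are bail-out checks: fusion is rejected as soon as any element violates the precondition. Written that way, any_of must take the negated predicate, because "no element violates P" is the same as "every element satisfies P". The sketch below checks that equivalence on plain iterator-type strings; `allParallel` and `hasNonParallel` are hypothetical names, and `std::all_of`/`std::any_of` stand in for the `llvm::` range wrappers.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Positive form: every loop is parallel.
bool allParallel(const std::vector<std::string> &iters) {
  return std::all_of(iters.begin(), iters.end(),
                     [](const std::string &t) { return t == "parallel"; });
}

// Bail-out form: some loop is not parallel, so fusion should be rejected.
bool hasNonParallel(const std::vector<std::string> &iters) {
  return std::any_of(iters.begin(), iters.end(),
                     [](const std::string &t) { return t != "parallel"; });
}
```

For any input, allParallel(v) equals !hasNonParallel(v), including the empty case where all_of is true and any_of is false. The same pairing applies to the identity-map check.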
873

I don't think you need to do this. There is a separate pattern that will fold constant -> tensor_reshape into a constant.

mlir/test/Dialect/Linalg/fusion-tensor.mlir
248

I think it would be better to check the shape, indexing maps, etc. here as well, because those are generated by the pattern being applied.

This revision now requires changes to proceed. Aug 20 2020, 11:32 PM
hanchung updated this revision to Diff 287312. Aug 24 2020, 2:10 AM
hanchung marked 5 inline comments as done.

Address comments

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
845

I actually prefer to keep them separate because they are independent. I usually put checks into a single if only when they are logically related; e.g., to check whether a value is in the range [l, r], I would write if (l <= val && val <= r).

What do you think?

873

This required updating ReshapeOp::fold and using createOrFold. I fixed that in this patch as well.

mravishankar requested changes to this revision. Aug 24 2020, 12:58 PM
mravishankar added inline comments.
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
840–841

The way this was set up, isFusible was supposed to be the collection of all checks that tell you whether the producer-consumer pair is fusible (i.e., if it returns true, the pair can be fused). It's OK to do this in a future change, but it would be good to retain this structure. This could be refactored to:

static bool isFusibleCase1(...);

static LinalgOp fuseCase1(...);

static bool isFusibleCase2(...);

static LinalgOp fuseCase2(...);

static bool isFusible(...) {
  return isFusibleCase1(...) || isFusibleCase2(...);
}

static LinalgOp fuse(...) {
  if (isFusibleCase1(...)) {
    return fuseCase1(...);
  }
  if (isFusibleCase2(...)) {
    return fuseCase2(...);
  }
  return nullptr;
}

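
The proposed structure keeps one invariant: isFusible returns true exactly when fuse produces an op. The sketch below makes that concrete with stand-in types so it compiles on its own. `FusionCandidate` and `FusedOp` are hypothetical placeholders for the producer/consumer pair and the fused LinalgOp, and `std::optional` plays the role of returning nullptr.

```cpp
#include <optional>
#include <string>

// Hypothetical stand-ins for the real MLIR types.
struct FusionCandidate {
  bool collapsing;  // the pattern's first case applies
  bool expanding;   // the pattern's second case applies
};
using FusedOp = std::string;

static bool isFusibleCase1(const FusionCandidate &c) { return c.collapsing; }
static FusedOp fuseCase1(const FusionCandidate &) { return "fused-collapsing"; }
static bool isFusibleCase2(const FusionCandidate &c) { return c.expanding; }
static FusedOp fuseCase2(const FusionCandidate &) { return "fused-expanding"; }

// The single place that answers "can this pair be fused?".
static bool isFusible(const FusionCandidate &c) {
  return isFusibleCase1(c) || isFusibleCase2(c);
}

// Dispatches to the first applicable case; std::nullopt means "not fused".
static std::optional<FusedOp> fuse(const FusionCandidate &c) {
  if (isFusibleCase1(c)) return fuseCase1(c);
  if (isFusibleCase2(c)) return fuseCase2(c);
  return std::nullopt;
}
```

Because fuse checks the cases in the same order as isFusible, the two functions cannot disagree, and adding a third case means touching both in one place.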
845

I think it is better to combine these; it makes the code less verbose. Combining all the checks, with a comment describing each check, should be clear enough.

864

Maybe we need to add a couple more checks to this.

  1. The producer linalg.generic op has a single user (the linalg.tensor_reshape op). When these operations are converted to buffers, the reshape ideally becomes just a view modifier. So
%0 = linalg.generic ... : ... -> tensor<typeA>
%1 = linalg.tensor_reshape %0 -> tensor<typeA> to tensor<typeB>

in buffer world would be

linalg.generic %0 .... : ...., memref<typeA>
%1 = linalg.reshape %0 ... : memref<typeA> into memref<typeB>

With a single use in the tensor world, there won't be an increase in "memory usage" when converting to buffers, since the modified code would just lower to

linalg.generic %0 .... : ...., memref<typeB>

If the generic op had two uses

%0 = linalg.generic ... : ... -> tensor<typeA>
%1 = linalg.tensor_reshape %0 -> tensor<typeA> to tensor<typeB>
%2 = linalg.generic %0 ... : tensor<typeA> ...

Fusion would result in

%0 = linalg.generic ... : ... -> tensor<typeA>
%1 = linalg.generic ... : .... -> tensor<typeB>
%2 = linalg.generic %0 ... : tensor<typeA>

which, when converted to buffers, becomes

linalg.generic .... %0 : ... memref<typeA>
linalg.generic .... %1 : ... memref<typeB>
linalg.generic %0 ... : memref<typeA>

This has an extra memref<typeA>.

  2. The working theory has been that it is better to convert to "higher" dimensionality. The test case below converts the producer op to a higher dimensionality. Maybe make that requirement explicit, i.e. check that the tensor_reshape result has a higher rank than the tensor_reshape source.
This revision now requires changes to proceed. Aug 24 2020, 12:58 PM
hanchung updated this revision to Diff 288259. Thu, Aug 27, 3:09 AM
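
The two extra checks proposed above can be modeled as one small predicate. This is a sketch under assumptions: `ReshapeFusionCandidate` and `passesExtraChecks` are hypothetical names, and the use count and ranks are plain numbers standing in for what the pattern would read from the IR. The rank comparison is strict because a tensor_reshape with equal source and result ranks is illegal by the op definition.

```cpp
#include <cstddef>

// Hypothetical summary of a producer generic op feeding a tensor_reshape.
struct ReshapeFusionCandidate {
  std::size_t producerResultUses;  // users of the generic op's result
  unsigned reshapeSrcRank;         // rank of the tensor_reshape source
  unsigned reshapeResultRank;      // rank of the tensor_reshape result
};

bool passesExtraChecks(const ReshapeFusionCandidate &c) {
  // Check 1: the reshape must be the only user; otherwise the unreshaped
  // result is still needed and, after bufferization, an extra buffer of
  // the original type remains.
  if (c.producerResultUses != 1) return false;
  // Check 2: only fuse when the reshape expands to a higher rank.
  return c.reshapeSrcRank < c.reshapeResultRank;
}
```

With the 264x4 to 8x33x4 reshape from the tests (rank 2 to rank 3) and a single use, both checks pass; a second user of the producer, or a collapsing reshape, rejects the candidate.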
hanchung marked 3 inline comments as done.

Address comments

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
845

I tried this in isFusibleCase2, but it looks bad to me. :(

A reader has to count the conditions and map each one to its comment.

I've seen the advice that code should be self-descriptive instead of relying on comments elsewhere, but maybe that is not the style to follow here. That is why I initially put the conditions in separate checks at the beginning.

864

Added the second check.

Regarding the first check, I think it would make us lose a chance to fuse these two generic ops. The reshape op would be propagated, giving

tensor_reshape
generic
generic

The reshape would then fuse into the first generic op, and in the end the two generic ops have a chance to fuse.

I tested with this case

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

func @generic_op_reshape_consumer_expanding(%arg0: tensor<264x4xf32>)
                                            -> tensor<8x33x4xf32> {
  %cst = constant dense<2.000000e+00> : tensor<264x4xf32>
  %0 = linalg.generic
    {args_in = 2 : i64, args_out = 1 : i64,
     indexing_maps = [#map0, #map0, #map0],
     iterator_types = ["parallel", "parallel"]}
    %arg0, %cst {
    ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
      %2 = mulf %arg1, %arg2 : f32
      linalg.yield %2 : f32
    }: tensor<264x4xf32>, tensor<264x4xf32> -> tensor<264x4xf32>
  %1 = linalg.tensor_reshape %0 [#map1, #map2] :
    tensor<264x4xf32> into tensor<8x33x4xf32>
  %2 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
     indexing_maps = [#map3, #map3],
     iterator_types = ["parallel", "parallel", "parallel"]}
    %1 {
    ^bb0(%arg1: f32):  // no predecessors
      %2 = mulf %arg1, %arg1 : f32
      linalg.yield %2 : f32
    }: tensor<8x33x4xf32> -> tensor<8x33x4xf32>

  return %2 : tensor<8x33x4xf32>
}

And it would be fused to

#map0 = affine_map<(d0, d1, d2) -> (d0 * 33 + d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>


module {
  func @generic_op_reshape_consumer_expanding(%arg0: tensor<264x4xf32>) -> tensor<8x33x4xf32> {
    %cst = constant 2.000000e+00 : f32
    %0 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} %arg0 {
    ^bb0(%arg1: f32):  // no predecessors
      %1 = mulf %arg1, %cst : f32
      %2 = mulf %1, %1 : f32
      linalg.yield %2 : f32
    }: tensor<264x4xf32> -> tensor<8x33x4xf32>
    return %0 : tensor<8x33x4xf32>
  }
}
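
The coefficient 33 in #map0 follows from the reshape's reassociation: the source dimension of extent 264 is split into the group (d0, d1) with extents 8 and 33, so the fused op reads %arg0 at the row-major linearization of that group. As a worked check of the fused output above:

```latex
\begin{aligned}
&\texttt{tensor<264x4xf32>} \;\to\; \texttt{tensor<8x33x4xf32>}, \qquad 264 = 8 \cdot 33,\\
&(d_0, d_1, d_2) \;\mapsto\; (33\,d_0 + d_1,\; d_2),
  \qquad 0 \le d_0 \le 7,\; 0 \le d_1 \le 32,\; 0 \le d_2 \le 3,\\
&0 \;\le\; 33\,d_0 + d_1 \;\le\; 33 \cdot 7 + 32 \;=\; 263 .
\end{aligned}
```

Since d_1 < 33, each (d_0, d_1) pair hits a distinct row of the 264x4 input, so every input element is read exactly once.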

This looks good to me. Even if not all of the ops are fused, wouldn't it result in

linalg.reshape
linalg.generic
linalg.generic

in the buffer world?

hanchung added inline comments. Thu, Aug 27, 7:05 AM
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
864

Oh, I was using the wrong example.

It should go from

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

func @generic_op_reshape_consumer_expanding(%arg0: tensor<264x4xf32>)
                                            -> (tensor<8x33x4xf32>, tensor<264x4xf32>) {
  %cst = constant dense<2.000000e+00> : tensor<264x4xf32>
  %0 = linalg.generic
    {args_in = 2 : i64, args_out = 1 : i64,
     indexing_maps = [#map0, #map0, #map0],
     iterator_types = ["parallel", "parallel"]}
    %arg0, %cst {
    ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
      %2 = mulf %arg1, %arg2 : f32
      linalg.yield %2 : f32
    }: tensor<264x4xf32>, tensor<264x4xf32> -> tensor<264x4xf32>
  %1 = linalg.tensor_reshape %0 [#map1, #map2] :
    tensor<264x4xf32> into tensor<8x33x4xf32>
  %2 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
     indexing_maps = [#map0, #map0],
     iterator_types = ["parallel", "parallel"]}
    %0 {
    ^bb0(%arg1: f32):  // no predecessors
      %2 = mulf %arg1, %arg1 : f32
      linalg.yield %2 : f32
    }: tensor<264x4xf32> -> tensor<264x4xf32>

  return %1, %2 : tensor<8x33x4xf32>, tensor<264x4xf32>
}

to

#map0 = affine_map<(d0, d1, d2) -> (d0 * 33 + d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1) -> (d0, d1)>


module {
  func @generic_op_reshape_consumer_expanding(%arg0: tensor<264x4xf32>) -> (tensor<8x33x4xf32>, tensor<264x4xf32>) {
    %cst = constant 2.000000e+00 : f32
    %0 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} %arg0 {
    ^bb0(%arg1: f32):  // no predecessors
      %2 = mulf %arg1, %cst : f32
      linalg.yield %2 : f32
    }: tensor<264x4xf32> -> tensor<8x33x4xf32>
    %1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel"]} %arg0 {
    ^bb0(%arg1: f32):  // no predecessors
      %2 = mulf %arg1, %cst : f32
      %3 = mulf %2, %2 : f32
      linalg.yield %3 : f32
    }: tensor<264x4xf32> -> tensor<264x4xf32>
    return %0, %1 : tensor<8x33x4xf32>, tensor<264x4xf32>
  }
}

I think it would still be

linalg.generic
linalg.generic

in the buffer world?

Why would

%0 = linalg.generic ... : ... -> tensor<typeA>
%1 = linalg.tensor_reshape %0 -> tensor<typeA> to tensor<typeB>
%2 = linalg.generic %0 ... : tensor<typeA> ...

become

%0 = linalg.generic ... : ... -> tensor<typeA>
%1 = linalg.generic ... : .... -> tensor<typeB>
%2 = linalg.generic %0 ... : tensor<typeA>

I think the tensor_reshape op would be propagated up and eventually either be lowered to linalg.reshape or fuse with the next generic (i.e. %0)?

mravishankar accepted this revision. Thu, Aug 27, 10:51 AM

Looks good. A few minor comments; please address them before submitting.

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
832–834

Nit: I was hoping you would be able to pick a better name than isFusibleCase1 (I used it because I couldn't think of one). Same for isFusibleCase2.

Maybe make *Case1 -> *WhenCollapsing and *Case2 -> *WhenExpanding?

864

I don't disagree with what you are saying, but I think more experimentation/data is needed here, and it's better to go incrementally instead of solving a general case that might have unintended consequences. For the current use cases, checking that the tensor_reshape has a single use and only then applying the transformation is safer.

The example you gave is fine, and in that case the tensor_reshape has a single use. But it is easy to adapt that example to a case where the tensor_reshape has multiple uses. You are right that this case wouldn't be handled right now. Let's revisit that if we need to?

Regarding your question,

%0 = linalg.generic ... : ... -> tensor<typeA>
%1 = linalg.generic ... : .... -> tensor<typeB>
%2 = linalg.generic %0 ... : tensor<typeA>

I am not referring to the linalg.reshape that exists above the snippet. We can discuss this offline. FWIW, if your use case currently has multiple uses of the reshape op, then it's OK not to add that check.

888

Maybe a matter of preference, but this looks clean to me :)

901

Nit: This should be consumer.getSrcType().getRank() < consumer.getResultType().getRank(), since equal ranks are illegal by the op definition.

This revision is now accepted and ready to land. Thu, Aug 27, 10:51 AM
hanchung updated this revision to Diff 288534. Thu, Aug 27, 11:13 PM
hanchung marked 5 inline comments as done.

Address comments

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
832–834

Oh, I missed this. Let's use isCollapsingAndFusible, fuseCollapsingCase, etc.

864

Thanks for the detailed explanation. I agree with you; let's do it incrementally.