
[mlir][vector] Fix CombineContractBroadcast for invalid cases
Needs Review · Public

Authored by antiagainst on Apr 12 2022, 5:41 AM.

Details

Summary

In some cases, after vector.contract and vector.broadcast are
combined, the LHS and RHS of the generated vector.contract share
no parallel or reduction dimension pair at all. Such a
vector.contract may fail verification. Explicitly guard against
these cases.
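To make the guard concrete, here is a minimal C++ sketch of a literal reading of the summary. It assumes a toy model in which each indexing map is reduced to the list of loop dimensions it reads. `ContractModel`, `sharesDimOfType`, and `canCombine` are hypothetical names, not MLIR APIs, and the exact condition checked in this revision may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Simplified, hypothetical model of a vector.contract after the broadcasts
// have been folded into it. This is not the MLIR C++ API: each indexing map is
// reduced to the loop dimensions (d0, d1, ...) its results refer to, and each
// iterator type is a plain string.
struct ContractModel {
  std::vector<std::string> iteratorTypes;  // "parallel" or "reduction", per dim
  std::vector<unsigned> lhsDims;           // dims read by the LHS indexing map
  std::vector<unsigned> rhsDims;           // dims read by the RHS indexing map
};

static bool usesDim(const std::vector<unsigned> &dims, unsigned d) {
  for (unsigned x : dims)
    if (x == d)
      return true;
  return false;
}

// True if some dim whose iterator type is `type` is indexed by both LHS and RHS.
bool sharesDimOfType(const ContractModel &op, const std::string &type) {
  assert((type == "parallel" || type == "reduction") && "unknown iterator type");
  for (std::size_t d = 0; d < op.iteratorTypes.size(); ++d) {
    unsigned dim = static_cast<unsigned>(d);
    if (op.iteratorTypes[d] == type && usesDim(op.lhsDims, dim) &&
        usesDim(op.rhsDims, dim))
      return true;
  }
  return false;
}

// Guard in the spirit of the summary: only build the combined contraction if
// LHS and RHS still share a parallel or a reduction dimension; otherwise the
// pattern bails out (return failure() in MLIR terms).
bool canCombine(const ContractModel &op) {
  return sharesDimOfType(op, "parallel") || sharesDimOfType(op, "reduction");
}
```

For a matmul-like model (iterators parallel, parallel, reduction; LHS reads d0 and d2, RHS reads d2 and d1) the shared reduction dim d2 lets the combination proceed, while a model whose LHS and RHS read disjoint parallel dims is rejected.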


Event Timeline

antiagainst created this revision on Apr 12 2022, 5:41 AM.
Herald added a project: Restricted Project.
antiagainst requested review of this revision on Apr 12 2022, 5:41 AM.
antiagainst retitled this revision from [mlir][vector] Fix CombineContractBroadcast for all parallel cases to [mlir][vector] Fix CombineContractBroadcast for invalid cases on Apr 12 2022, 6:30 AM.
antiagainst edited the summary of this revision.
mravishankar resigned from this revision on Apr 12 2022, 1:46 PM.
ThomasRaoux added inline comments on Apr 14 2022, 7:56 PM.
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
1197

Isn't this only correct for a float contract with the add kind?

Address comments

antiagainst marked an inline comment as done on Apr 15 2022, 6:24 AM.
antiagainst added inline comments.
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
1197

Ah, great catch! Fixed.

antiagainst marked an inline comment as done.

Update to fix another invalid case

antiagainst edited the summary of this revision on Apr 25 2022, 5:11 AM.
ThomasRaoux added inline comments on May 10 2022, 12:48 PM.
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
1205–1215

I don't think we can really remove reduction dimensions unless they are unit dimensions.
For instance:

func.func @contract_broadcast_fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> {
  %bcast_a = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
  %bcast_b = vector.broadcast %b : vector<4xf32> to vector<1x2x4xf32>
  %contract = vector.contract {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>],
    iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>
  } %bcast_a, %bcast_b, %c : vector<1x2x4xf32>, vector<1x2x4xf32> into vector<4xf32>
  return %contract : vector<4xf32>
}

shouldn't become:

func.func @contract_broadcast_fma(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
  %0 = vector.fma %arg0, %arg1, %arg2 : vector<4xf32>
  return %0 : vector<4xf32>
}
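Why the rewrite is wrong can be checked numerically. The following C++ sketch (plain loops, not MLIR; `contractReference` and `fmaRewrite` are illustrative names) evaluates the contraction above with its reduction sizes as parameters and compares it against `a * b + c`. Because both operands are broadcasts, every iteration of the d1/d2 reduction adds the same product, so the contraction computes c + (d1Size * d2Size) * a * b.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

using Vec4 = std::array<float, 4>;

// Plain-loop reference semantics for the contraction in the example above
// (a sketch, not MLIR code). Both operands broadcast a vector<4xf32>, so
// lhs[d1][d2][d0] == a[d0] and rhs[d1][d2][d0] == b[d0]. d0 is the parallel
// dim; d1 and d2 are reduction dims with sizes d1Size and d2Size (1 and 2 in
// the example), combined with kind add.
Vec4 contractReference(const Vec4 &a, const Vec4 &b, const Vec4 &c,
                       std::size_t d1Size, std::size_t d2Size) {
  assert(d1Size > 0 && d2Size > 0 && "reduction dims must be non-empty");
  Vec4 result = c;
  for (std::size_t d0 = 0; d0 < result.size(); ++d0)
    for (std::size_t d1 = 0; d1 < d1Size; ++d1)
      for (std::size_t d2 = 0; d2 < d2Size; ++d2)
        result[d0] += a[d0] * b[d0];  // same broadcast element at every (d1, d2)
  return result;
}

// What the proposed vector.fma rewrite computes: a * b + c, elementwise.
Vec4 fmaRewrite(const Vec4 &a, const Vec4 &b, const Vec4 &c) {
  Vec4 result{};
  for (std::size_t i = 0; i < result.size(); ++i)
    result[i] = a[i] * b[i] + c[i];
  return result;
}
```

With a = {1, 2, 3, 4}, b all ones and c all zeros, the reference gives {2, 4, 6, 8} for reduction sizes 1 and 2, while `fmaRewrite` gives {1, 2, 3, 4}. The two agree only when both reduction dims have size 1.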
1208–1213

I find it a bit odd that the pattern that combines contract and broadcast also lowers to fma. Why can't this be a separate pattern that users can opt into if they want this kind of lowering?

antiagainst edited the summary of this revision.

Remove vector.fma generation

antiagainst marked 2 inline comments as done on May 23 2022, 12:15 PM.
antiagainst added inline comments.
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
1205–1215

You are right; this is problematic. Removed now.

1208–1213

Makes sense. Removed this part now.

antiagainst edited the summary of this revision on May 23 2022, 12:16 PM.
antiagainst marked 2 inline comments as done.

Fix func.func

ThomasRaoux added inline comments on May 23 2022, 4:04 PM.
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
1190–1199

Nit: can we reverse the condition to exit early, like the rest of the function does?

1192

I don't understand why we need the second condition. As long as a reduction dimension is used, shouldn't it be okay to generate the new contraction op?

Address comments

antiagainst marked 2 inline comments as done on Jun 28 2022, 10:50 AM.
antiagainst added inline comments.
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
1190–1199

Done.

1192

It's required by the contraction op verifier: for a full reduction, a valid contraction op expects a scalar accumulator/result (https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Vector/IR/VectorOps.cpp#L674-L676), so we need to check that here. The second test case covers this.
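A minimal sketch of the full-reduction condition, under a hypothetical model where the remaining iterator types are strings and the accumulator is described only by its rank (0 meaning a scalar). `accumulatorIsValid` is an illustrative name, and the real verifier checks more than this (for example, the result shape when parallel dims remain), which is not modeled here.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical model (not the MLIR API): the iterator types left after the
// rewrite and the rank of the accumulator/result the rewrite would keep
// (0 means a scalar). When every remaining iterator is a reduction (a "full
// reduction"), a valid vector.contract must produce a scalar, so a vector
// accumulator means the pattern has to bail out.
bool accumulatorIsValid(const std::vector<std::string> &iteratorTypes,
                        unsigned accRank) {
  bool allReduction = !iteratorTypes.empty();
  for (const std::string &it : iteratorTypes) {
    assert((it == "parallel" || it == "reduction") && "unknown iterator type");
    if (it == "parallel")
      allReduction = false;
  }
  // Full reduction: only a scalar accumulator is accepted.
  if (allReduction)
    return accRank == 0;
  // Other cases are left to the remaining verifier checks (not modeled).
  return true;
}
```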

It would be nice to refactor this to call directly into a new verifyContractionOpImpl helper, the same way we have tensor::verifyInsertSliceOp.
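One possible shape for that refactoring, sketched in C++ under invented names (`ContractShape`, `verifyContractShapeImpl`, `verifierAccepts`, and `patternMayRewrite` are hypothetical, not existing MLIR functions): a single helper holds the validity rule and returns a message on violation; the op verifier turns that message into a diagnostic, while the rewrite pattern uses the same helper as a silent precondition. Only the full-reduction rule from the previous comment is modeled.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical description of the contraction a pattern is about to build.
struct ContractShape {
  std::vector<std::string> iteratorTypes;  // "parallel" or "reduction"
  unsigned accRank;                        // 0 means a scalar accumulator
};

// Shared rule: returns a message describing the first violation, if any.
std::optional<std::string> verifyContractShapeImpl(const ContractShape &s) {
  bool allReduction = !s.iteratorTypes.empty();
  for (const std::string &it : s.iteratorTypes) {
    assert((it == "parallel" || it == "reduction") && "unknown iterator type");
    if (it != "reduction")
      allReduction = false;
  }
  if (allReduction && s.accRank != 0)
    return std::string("full reduction expects a scalar accumulator/result");
  return std::nullopt;
}

// The op verifier reports the message as an error diagnostic...
bool verifierAccepts(const ContractShape &s, std::string &diag) {
  if (std::optional<std::string> msg = verifyContractShapeImpl(s)) {
    diag = *msg;
    return false;
  }
  return true;
}

// ...while the rewrite pattern just declines to fire (like returning failure()).
bool patternMayRewrite(const ContractShape &s) {
  return !verifyContractShapeImpl(s).has_value();
}
```

Keeping the rule in one place means the pattern's precondition cannot drift out of sync with what the verifier accepts.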