There are cases where, after combining vector.contract and
vector.broadcast, the operands of the generated vector.contract
have no parallel or reduction dimension pair between LHS and
RHS at all. Such cases may fail vector.contract verification.
Explicitly guard against such cases.
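One way to phrase such a guard is shown below as a plain-Python model, not the patch's actual C++. It assumes the verifier rule that a contraction needs at least one reduction dimension read by both LHS and RHS (a "contracting pair"). Each indexing map is modeled as the tuple of iteration-dim positions it reads, so `affine_map<(d0, d1, d2) -> (d1, d2, d0)>` becomes `(1, 2, 0)`. The helper names `contracting_pairs` and `fold_is_safe` are hypothetical.

```python
from typing import List, Sequence, Tuple


def contracting_pairs(lhs_map: Sequence[int], rhs_map: Sequence[int],
                      iterator_types: Sequence[str]) -> List[Tuple[int, int]]:
    """Return (lhs_result_pos, rhs_result_pos) for each reduction dim read by both operands."""
    pairs = []
    for dim, kind in enumerate(iterator_types):
        if kind == "reduction" and dim in lhs_map and dim in rhs_map:
            pairs.append((lhs_map.index(dim), rhs_map.index(dim)))
    return pairs


def fold_is_safe(lhs_map: Sequence[int], rhs_map: Sequence[int],
                 iterator_types: Sequence[str]) -> bool:
    """Hypothetical guard: only build the folded contract if a contracting pair survives."""
    return len(contracting_pairs(lhs_map, rhs_map, iterator_types)) > 0


# Matmul-like maps: d2 is a reduction read by both sides, so the fold may proceed.
print(fold_is_safe((0, 2), (2, 1), ["parallel", "parallel", "reduction"]))  # True
# The reduction dim d1 only survives on the LHS, so the pattern should bail out.
print(fold_is_safe((0, 1), (0,), ["parallel", "reduction"]))  # False
```

Returning failure from the pattern in the second case leaves the original broadcast and contract untouched instead of producing an op the verifier would reject.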
Repository: rG LLVM Github Monorepo
Event Timeline
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Line 1058: Isn't this only correct for float contractions of the add kind?
Line 1058: Ah, great catch! Fixed.
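A small plain-Python model of why the kind matters. This is a sketch, not MLIR's implementation, and the `COMBINE` table and helpers are hypothetical. It assumes that, with a single unit-size reduction term, a contraction computes `combine(acc, lhs * rhs)` where `combine` is the operation named by the kind, so only the add kind matches `a * b + c`. The float half of the question is separate: vector.fma operates on floating-point vectors, so integer contractions cannot be rewritten to it at all.

```python
import operator

# Hypothetical stand-ins for combining kinds; MLIR's attribute spellings differ.
COMBINE = {
    "add": operator.add,
    "mul": operator.mul,
    "max": max,
    "min": min,
}


def contract_one_term(kind, a, b, acc):
    """Model of a contraction whose only reduction term has size 1: combine(acc, a * b)."""
    return [COMBINE[kind](z, x * y) for x, y, z in zip(a, b, acc)]


def fma(a, b, c):
    """Elementwise a * b + c (ignores the single-rounding detail of a real fma)."""
    return [x * y + z for x, y, z in zip(a, b, c)]


a, b, acc = [1.0, 2.0], [3.0, 4.0], [10.0, 1.0]
for kind in COMBINE:
    # Only "add" agrees with fma for these inputs.
    print(kind, contract_one_term(kind, a, b, acc) == fma(a, b, acc))
```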
Lines 1066–1076: I don't think we can really remove the reduction dimension unless it is a unit dimension. For example, this:

    func.func @contract_broadcast_fma(%a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>) -> vector<4xf32> {
      %bcast_a = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
      %bcast_b = vector.broadcast %b : vector<4xf32> to vector<1x2x4xf32>
      %contract = vector.contract {
          indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
                           affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
                           affine_map<(d0, d1, d2) -> (d0)>],
          iterator_types = ["parallel", "reduction", "reduction"],
          kind = #vector.kind<add>}
        %bcast_a, %bcast_b, %c : vector<1x2x4xf32>, vector<1x2x4xf32> into vector<4xf32>
      return %contract : vector<4xf32>
    }

shouldn't become:

    func.func @contract_broadcast_fma(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
      %0 = vector.fma %arg0, %arg1, %arg2 : vector<4xf32>
      return %0 : vector<4xf32>
    }
Lines 1069–1074: I find it a bit odd that the lowering to fma lives in the pattern that combines contract and broadcast. Why can't this be a separate pattern that users call if they want this kind of lowering?
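To make the concern about the 1066–1076 example concrete, here is a plain-Python model of its semantics (not MLIR code; the helper names are made up for illustration). The broadcast reduction dimension d2 has size 2, so the contraction adds each product twice and yields `c + 2*a*b`, while the proposed vector.fma yields `c + a*b`. With unit broadcast dimensions the two agree.

```python
def broadcast_leading(v, d1, d2):
    """Model of vector.broadcast vector<4xf32> -> vector<d1 x d2 x 4xf32>: copy v along new leading dims."""
    return [[list(v) for _ in range(d2)] for _ in range(d1)]


def contract_add(lhs, rhs, acc):
    """result[d0] = acc[d0] + sum over reduction dims d1, d2 of lhs[d1][d2][d0] * rhs[d1][d2][d0]."""
    out = list(acc)
    for d1 in range(len(lhs)):
        for d2 in range(len(lhs[0])):
            for d0 in range(len(acc)):
                out[d0] += lhs[d1][d2][d0] * rhs[d1][d2][d0]
    return out


def fma(a, b, c):
    """Elementwise a * b + c, standing in for vector.fma."""
    return [x * y + z for x, y, z in zip(a, b, c)]


a = [1.0, 2.0, 3.0, 4.0]
b = [5.0, 6.0, 7.0, 8.0]
c = [0.5, 0.5, 0.5, 0.5]

contracted = contract_add(broadcast_leading(a, 1, 2), broadcast_leading(b, 1, 2), c)
folded = fma(a, b, c)
print(contracted)  # [10.5, 24.5, 42.5, 64.5]  (c + 2*a*b)
print(folded)      # [5.5, 12.5, 21.5, 32.5]   (c + a*b)

# With a unit-size d2 the broadcast adds nothing to the reduction, so the fold is sound.
unit = contract_add(broadcast_leading(a, 1, 1), broadcast_leading(b, 1, 1), c)
print(unit == folded)  # True
```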
Lines 1051–1060: Done.
Line 1053: It's required by contraction op verification. For a full reduction, a valid contraction op expects a scalar accumulator and result (https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Vector/IR/VectorOps.cpp#L674-L676), so we need to check that here. The second test case covers this.
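The scalar-accumulator rule for a full reduction can be modeled as follows. This is a plain-Python sketch under simplifying assumptions, not the verifier's code: `check_output_shape` is a hypothetical name, maps are tuples of iteration-dim positions, and a contraction's result dimensions are assumed to be the parallel dims read by LHS or RHS. If none remain, the accumulator map must have no results, i.e. the accumulator must be a scalar.

```python
from typing import Optional, Sequence


def check_output_shape(lhs_map: Sequence[int], rhs_map: Sequence[int],
                       acc_map: Sequence[int], iterator_types: Sequence[str]) -> Optional[str]:
    """Return an error message if the accumulator shape is inconsistent, else None."""
    used = set(lhs_map) | set(rhs_map)
    # Assumption: result dims are the parallel dims that LHS or RHS still read.
    result_dims = [d for d in sorted(used) if iterator_types[d] == "parallel"]
    if not result_dims and len(acc_map) != 0:
        return "full reduction requires a scalar accumulator/result"
    return None


# Dot product with a scalar accumulator: valid.
print(check_output_shape((0,), (0,), (), ["reduction"]))  # None
# After a fold, LHS/RHS only read the reduction dim d1, but the accumulator is
# still a vector indexed by d0: this is the case the pattern must reject.
print(check_output_shape((1,), (1,), (0,), ["parallel", "reduction"]))
```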
It would be nice to refactor this to call directly into a new verifyContractionOpImpl helper, in the same way that we have tensor::verifyInsertSliceOp.
I don't understand why we need the second condition. As long as a reduction is being used, shouldn't it be okay to generate the new contraction op?