Fix the insertion point when expanding affine.apply and handle cases with
symbols. Also add a missing precondition to dynamic shape vectorization.
Repository: rG LLVM Github Monorepo
mlir/test/Dialect/Linalg/vectorization.mlir, line 320:

Missing d0 + s0 expansion :). It would be good to add more tests if the bug is not specific to symbols.
mlir/test/Dialect/Linalg/vectorization.mlir, line 320:

I think that it's actually there.

INPUT

    #map0 = affine_map<(d0) -> (d0)>
    func.func @vectorize_affine_apply(%arg0: tensor<32xf32>, %arg3: index) -> tensor<32xi32> {
      %0 = tensor.empty() : tensor<32xi32>
      %1 = linalg.generic {indexing_maps = [#map0, #map0],
                           iterator_types = ["parallel"]}
          ins(%arg0 : tensor<32xf32>)
          outs(%0 : tensor<32xi32>) {
      ^bb0(%arg1: f32, %arg2: i32):
        %2 = linalg.index 0 : index
        %12 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg3)
        %13 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%12)[%arg3]
        %3 = arith.index_cast %13 : index to i32
        linalg.yield %3 : i32
      } -> tensor<32xi32>
      return %1 : tensor<32xi32>
    }

OUTPUT (with comments where the expansions happen):

    module {
      func.func @vectorize_affine_apply_3(%arg0: tensor<32xf32>, %arg1: index) -> tensor<32xi32> {
        %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]> : vector<32xindex>
        %c0 = arith.constant 0 : index
        %0 = tensor.empty() : tensor<32xi32>
        %1 = vector.broadcast %arg1 : index to vector<32xindex>
        // EXPANSION 1: %12 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg3)
        %2 = arith.addi %1, %cst : vector<32xindex>
        %3 = vector.broadcast %arg1 : index to vector<32xindex>
        // EXPANSION 2: %13 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%12)[%arg3]
        %4 = arith.addi %2, %3 : vector<32xindex>
        %5 = arith.index_cast %4 : vector<32xindex> to vector<32xi32>
        %6 = vector.transfer_write %5, %0[%c0] {in_bounds = [true]} : vector<32xi32>, tensor<32xi32>
        return %6 : tensor<32xi32>
      }
      transform.sequence failures(propagate) {
      ^bb0(%arg0: !pdl.operation):
        %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!pdl.operation) -> !pdl.operation
        %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
        %2 = transform.structured.vectorize %1 {vectorize_nd_extract}
      }
    }

Added in https://reviews.llvm.org/D143429. Here's one other case that's not yet tested, though present in your example in https://reviews.llvm.org/D142371: (d0, d1, d2) -> (d1 + d2 + d3). That also seems to work 🤔 .
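One way to check that both expansions are present in the OUTPUT is to model the ops element-wise and compare against the scalar body of the INPUT. The following is only a rough sketch in plain Python, not MLIR and not part of the patch. The function names (`scalar_body`, `vectorized_body`, `expansions_match`) are hypothetical, lists stand in for `vector<32xindex>` values, and the final `arith.index_cast` is ignored.

```python
# Hypothetical model for illustration only: Python lists stand in for
# vector<32xindex> values; nothing here calls into MLIR.
VECTOR_SIZE = 32


def scalar_body(i, arg3):
    """Scalar semantics of the linalg.generic body in INPUT for index i."""
    v12 = i + arg3    # affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg3)
    v13 = v12 + arg3  # affine_map<(d0)[s0] -> (d0 + s0)>(%12)[%arg3]
    return v13


def vectorized_body(arg1):
    """Element-wise model of the vector ops in OUTPUT."""
    cst = list(range(VECTOR_SIZE))                # %cst
    bcast1 = [arg1] * VECTOR_SIZE                 # %1 = vector.broadcast %arg1
    add1 = [a + b for a, b in zip(bcast1, cst)]   # %2 = arith.addi %1, %cst
    bcast2 = [arg1] * VECTOR_SIZE                 # %3 = vector.broadcast %arg1
    add2 = [a + b for a, b in zip(add1, bcast2)]  # %4 = arith.addi %2, %3
    return add2


def expansions_match(arg):
    """True if the vector model agrees with the scalar body on every lane."""
    expected = [scalar_body(i, arg) for i in range(VECTOR_SIZE)]
    return vectorized_body(arg) == expected
```

Under this model, lane `i` ends up holding `i + 2 * arg`, which is what applying d0 + d1 followed by d0 + s0 with the same operand should give. If the second expansion were missing, each lane would hold only `i + arg`.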
I would also add cases for something like (d0...) -> (0) or even (d0...) -> ()