diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -626,23 +626,19 @@
 
   const size_t numIndices = extractOp.getIndices().size();
   for (size_t i = 1; i < numIndices; i++) {
-    auto dimSizeBcast = b.create<vector::BroadcastOp>(
-        loc, indexVecType,
+    auto dimSize = broadcastIfNeeded(
+        b,
         b.create<arith::ConstantIndexOp>(
             loc,
-            extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)));
-    offset = b.create<arith::MulIOp>(loc, offset, dimSizeBcast);
-
-    auto originalIndexBcast = bvm.lookup(extractOp.getIndices()[i]);
-    if (i == numIndices - 1) {
-      // We only need an additional broadcast for the trailing index. All other
-      // indices have already been broadcast by `vectorizeLinalgIndex` to match
-      // the output size.
-      originalIndexBcast = b.create<vector::BroadcastOp>(
-          loc, indexVecType, bvm.lookup(extractOp.getIndices()[i]));
-    }
+            extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)),
+        indexVecType.getShape());
+
+    offset = b.create<arith::MulIOp>(loc, offset, dimSize);
+
+    auto extractOpIndex = broadcastIfNeeded(
+        b, bvm.lookup(extractOp.getIndices()[i]), indexVecType.getShape());
 
-    offset = b.create<arith::AddIOp>(loc, originalIndexBcast, offset);
+    offset = b.create<arith::AddIOp>(loc, extractOpIndex, offset);
   }
 
   return offset;
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1494,10 +1494,83 @@
 
 // -----
 
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+  %c0 = arith.constant 1 : index
+  %c1 = arith.constant 2 : index
+  %2 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%arg2 : tensor<1x1x3xf32>) {
+  ^bb0(%arg4: f32):
+    %3 = linalg.index 2 : index
+    %7 = tensor.extract %arg0[%c0, %c1] : tensor<3x3xf32>
+    linalg.yield %7 : f32
+  } -> tensor<1x1x3xf32>
+  return %2 : tensor<1x1x3xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_constant_idx
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
+// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// Magic "5" below comes from (1 * 3 + 2) (1: index into dim 1, 2: index into dim 2)
+// CHECK: %[[IDX:.*]] = arith.constant dense<5> : vector<1x1x3xindex>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[IDX]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32>
+// CHECK: vector.transfer_write %[[GATHER]]
+// CHECK: }
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+}
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @vectorize_nd_tensor_extract_idx_from_iteration_index(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+  %1 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%arg2 : tensor<1x1x3xf32>) {
+  ^bb0(%arg4: f32):
+    %2 = linalg.index 0 : index
+    %3 = linalg.index 1 : index
+    %4 = linalg.index 2 : index
+    %5 = tensor.extract %arg0[%2, %3, %4] : tensor<3x3x3xf32>
+    linalg.yield %5 : f32
+  } -> tensor<1x1x3xf32>
+  return %1 : tensor<1x1x3xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_idx_from_iteration_index
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3x3xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
+// CHECK: %[[INDICES:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
+// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[B:.*]] = vector.broadcast %[[INDICES]] : vector<3xindex> to vector<1x1x3xindex>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[B]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32>
+// CHECK: vector.transfer_write %[[GATHER]]
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+}
+
+// -----
+
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
+func.func @vectorize_nd_tensor_extract_index_from_tensor(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
   %2 = linalg.generic {
     indexing_maps = [#map0, #map0, #map1, #map2],
     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
@@ -1510,7 +1583,7 @@
   } -> tensor<4x7x3x2xf32>
   return %2 : tensor<4x7x3x2xf32>
 }
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_index_from_tensor
 // CHECK-SAME: %[[ARG0:.*]]: tensor<3x3xf32>
 // CHECK-SAME: %[[ARG1:arg1]]: tensor<4x3xi32>
 // CHECK-SAME: %[[ARG2:arg2]]: tensor<4x3xi32>
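
Note on the offset arithmetic exercised by @vectorize_nd_tensor_extract_constant_idx: the rewritten C++ loop above folds each trailing dimension into the gather offset as offset = offset * dimSize + index, i.e. a plain row-major linearization of the extract indices; for non-constant indices the same arithmetic is applied element-wise after broadcastIfNeeded aligns the index vectors to the output shape. The sketch below is illustration only and not part of the patch; linearizedGatherOffset is a hypothetical standalone helper that mirrors that scalar arithmetic and reproduces the "magic 5" (1 * 3 + 2) checked in the test.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Hypothetical helper (illustration only): row-major linearization of the
    // extract indices, mirroring the rewritten loop above where each trailing
    // dimension contributes offset = offset * dimSize + index.
    static int64_t linearizedGatherOffset(const std::vector<int64_t> &indices,
                                          const std::vector<int64_t> &dimSizes) {
      int64_t offset = indices[0];
      for (size_t i = 1; i < indices.size(); ++i)
        offset = offset * dimSizes[i] + indices[i];
      return offset;
    }

    int main() {
      // tensor<3x3xf32> extracted at [1, 2]: 1 * 3 + 2 == 5, the constant that
      // @vectorize_nd_tensor_extract_constant_idx expects in vector.gather.
      assert(linearizedGatherOffset({1, 2}, {3, 3}) == 5);
      return 0;
    }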