diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -313,10 +313,6 @@
   if (!extractOp)
     return failure();
 
-  // Currently only supports extraction with an 1-D index.
-  if (extractOp.getIndices().size() != 1)
-    return failure();
-
   if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
     return failure();
 
@@ -341,29 +337,57 @@
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
   auto loc = extractOp.getLoc();
 
-  // Currently only supports extraction with an 1-D index. Checked in the
-  // tensorExtractVectorizationPrecondition.
-  assert(extractOp.getIndices().size() == 1);
-
-  auto indexVec = bvm.lookup(extractOp.getIndices()[0]);
   // Compute the static loop sizes of the extract op.
   auto targetShape = linalgOp.computeStaticLoopSizes();
 
-  SmallVector<Value> gatherIndices;
-  gatherIndices.push_back(b.create<arith::ConstantIndexOp>(loc, 0));
-
+  auto resultType =
+      VectorType::get(targetShape, extractOp.getResult().getType());
   auto maskConstantOp = b.create<arith::ConstantOp>(
       loc, DenseIntElementsAttr::get(
               VectorType::get(targetShape, b.getI1Type()),
               /*value=*/true));
-
-  auto resultType =
-      VectorType::get(targetShape, extractOp.getResult().getType());
   auto passThruConstantOp =
       b.create<arith::ConstantOp>(loc, b.getZeroAttr(resultType));
 
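+  // Note: the base indices of the gather are all 0; each gathered element is
+  // addressed solely by its entry in the index vector (`indexVec`/`offset`
+  // below), i.e. a linearised offset from the start of the source tensor.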
+  const size_t numIndices = extractOp.getIndices().size();
+  SmallVector<Value> gatherIndices(numIndices,
+                                   b.create<arith::ConstantIndexOp>(loc, 0));
+
+  // The 1-D index case is straightforward and processed separately.
+  if (numIndices == 1) {
+    auto indexVec = bvm.lookup(extractOp.getIndices()[0]);
+
+    auto gatherOp = b.create<vector::GatherOp>(
+        loc, resultType, extractOp.getTensor(), gatherIndices, indexVec,
+        maskConstantOp, passThruConstantOp);
+    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+  }
+
+  // `indexVecType` matches the shape of the output vector, but with `index`
+  // elements.
+  auto indexVecType = VectorType::get(targetShape, b.getIndexType());
+
+  // Calculate the offsets for the gather load:
+  //    offset = indices[0]
+  //    for (i = 1; i < numIndices; i++)
+  //      offset = dimSize[i] * offset + indices[i];
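+  // E.g. for a tensor<3x3xf32> source indexed with (i, j), this computes
+  // offset = 3 * i + j, i.e. the row-major position of element (i, j)
+  // (cf. the dense<3> multiplier in the updated test below).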
+  Value offset = b.create<vector::BroadcastOp>(
+      loc, indexVecType, bvm.lookup(extractOp.getIndices()[0]));
+
+  for (size_t i = 1; i < numIndices; i++) {
+    auto dimSizeBcast = b.create<vector::BroadcastOp>(
+        loc, indexVecType,
+        b.create<arith::ConstantIndexOp>(
+            loc, op->getOperandTypes()[0].cast<ShapedType>().getDimSize(i)));
+    offset = b.create<arith::MulIOp>(loc, offset, dimSizeBcast);
+
+    auto originalIndexBcast = b.create<vector::BroadcastOp>(
+        loc, indexVecType, bvm.lookup(extractOp.getIndices()[i]));
+    offset = b.create<arith::AddIOp>(loc, originalIndexBcast, offset);
+  }
+
+  // Generate the gather load.
   auto gatherOp = b.create<vector::GatherOp>(
-      loc, resultType, extractOp.getTensor(), gatherIndices, indexVec,
+      loc, resultType, extractOp.getTensor(), gatherIndices, offset,
       maskConstantOp, passThruConstantOp);
   return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1500,7 +1500,7 @@
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @not_vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
+func.func @vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
   %2 = linalg.generic {
     indexing_maps = [#map0, #map0, #map1, #map2],
     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
@@ -1513,8 +1513,28 @@
   } -> tensor<4x7x3x2xf32>
   return %2 : tensor<4x7x3x2xf32>
 }
-// CHECK-LABEL: func.func @not_vectorize_nd_tensor_extract
-// CHECK: tensor.extract
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3xf32>
+// CHECK-SAME: %[[ARG1:arg1]]: tensor<4x3xi32>
+// CHECK-SAME: %[[ARG2:arg2]]: tensor<4x3xi32>
+// CHECK-SAME: %[[ARG3:.*]]: tensor<4x7x2xf32>
+// CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C0_i32:.*]] = arith.constant 0 : i32
+// CHECK: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex>
+// CHECK: %[[CST_1:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
+// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
+// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[C0_i32]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
+// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[C0_i32]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
+// CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex>
+// CHECK: %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex>
+// CHECK: %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex>
+// CHECK: %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex>
+// CHECK: %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32>
+// CHECK: vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32>
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):