diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -313,10 +313,6 @@
   if (!extractOp)
     return failure();
 
-  // Currently only supports extraction with an 1-D index.
-  if (extractOp.getIndices().size() != 1)
-    return failure();
-
   if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
     return failure();
 
@@ -341,29 +337,99 @@
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
   auto loc = extractOp.getLoc();
 
-  // Currently only supports extraction with an 1-D index. Checked in the
-  // tensorExtractVectorizationPrecondition.
-  assert(extractOp.getIndices().size() == 1);
-
-  auto indexVec = bvm.lookup(extractOp.getIndices()[0]);
   // Compute the static loop sizes of the extract op.
   auto targetShape = linalgOp.computeStaticLoopSizes();
 
-  SmallVector<Value> gatherIndices;
-  gatherIndices.push_back(b.create<arith::ConstantIndexOp>(loc, 0));
-
+  auto resultType =
+      VectorType::get(targetShape, extractOp.getResult().getType());
   auto maskConstantOp = b.create<arith::ConstantOp>(
       loc,
       DenseIntElementsAttr::get(VectorType::get(targetShape, b.getI1Type()),
                                 /*value=*/true));
-
-  auto resultType =
-      VectorType::get(targetShape, extractOp.getResult().getType());
   auto passThruConstantOp =
       b.create<arith::ConstantOp>(loc, b.getZeroAttr(resultType));
 
+  SmallVector<Value> gatherIndices;
+  int numIndices = extractOp.getIndices().size();
+
+  // The 1-D index case is straightforward and processed separately.
+  if (numIndices == 1) {
+    auto indexVec = bvm.lookup(extractOp.getIndices()[0]);
+
+    gatherIndices.push_back(b.create<arith::ConstantIndexOp>(loc, 0));
+    auto gatherOp = b.create<vector::GatherOp>(
+        loc, resultType, extractOp.getTensor(), gatherIndices, indexVec,
+        maskConstantOp, passThruConstantOp);
+    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+  }
+
+  // Sizes of slices of the input tensor. For a rank-3 tensor, this vector will
+  // contain the sizes of:
+  //   [ rank-2 slice, rank-1 slice, element size]
+  // where each slice is taken starting from the fastest moving (i.e. inner
+  // most) dimension. For example, for tensor<45 x 80 x 15 x f32>, it will look
+  // like this:
+  //   [ 80 x 15 x 1, 15 x 1, 1]
+  SmallVector<Value> sliceSizes;
+  sliceSizes.resize(numIndices);
+  sliceSizes[numIndices - 1] = b.create<arith::ConstantIndexOp>(loc, 1);
+
+  // Every entry in this vector corresponds to an index passed to
+  // `tensor.extract`. It is additionally:
+  //   * multiplied by the corresponding slice size from `sliceSizes`,
+  //   * vectorised, i.e. broadcast to match the output size.
+  // We will use it to calculate the offsets for `vector.gather`.
+  SmallVector<Value> indicesVec;
+  indicesVec.resize(numIndices);
+
+  // `indexVecType` is basically the output type.
+  auto indexVecType = VectorType::get(targetShape, b.getIndexType());
+
+  // 1. Compute the indices corresponding to the trailing dim.
+  // All indices apart from the trailing one have already been broadcast by
+  // `vectorizeLinalgIndex` to match the output size. Time to broadcast the
+  // trailing one.
+  indicesVec[numIndices - 1] = b.create<vector::BroadcastOp>(
+      loc, indexVecType, bvm.lookup(extractOp.getIndices()[numIndices - 1]));
+
+  // 2. Compute indices corresponding to dimensions from 0 to ... (last - 1)
+  for (int dim = numIndices - 2; dim >= 0; dim--) {
+    // Calculate the size of each slice. This will be used to multiply indices
+    // to get the right offsets. For example, for tensor<45 x 80 x 15 x f32>,
+    // we should have this
+    //   80 * 15 * 1 for dim = 0
+    //   15 * 1      for dim = 1
+    //   1           for dim = 2
+    auto dimSize = b.getIndexAttr(
+        op->getOperandTypes()[0].cast<ShapedType>().getDimSize(dim + 1));
+    sliceSizes[dim] = b.create<arith::MulIOp>(
+        loc, sliceSizes[dim + 1], b.create<arith::ConstantOp>(loc, dimSize));
+
+    // Take the original index value corresponding to dimension `dim` and scale
+    // it using `sliceSizes`.
+    auto dimSizeBcast =
+        b.create<vector::BroadcastOp>(loc, indexVecType, sliceSizes[dim]);
+    auto originalIndexBcast = b.create<vector::BroadcastOp>(
+        loc, indexVecType, bvm.lookup(extractOp.getIndices()[dim]));
+    indicesVec[dim] =
+        b.create<arith::MulIOp>(loc, dimSizeBcast, originalIndexBcast);
+  }
+
+  int lastIdx = extractOp.getIndices().size() - 1;
+
+  // 3. Calculate the overall offset to be passed to `vector.gather`.
+  auto offset = b.create<arith::AddIOp>(loc, indicesVec[lastIdx],
                                        indicesVec[lastIdx - 1]);
+  for (int dim = lastIdx - 2; dim >= 0; dim--) {
+    offset = b.create<arith::AddIOp>(loc, offset, indicesVec[dim]);
+  }
+
+  for (unsigned i = 0; i < extractOp.getIndices().size(); i++) {
+    gatherIndices.push_back(b.create<arith::ConstantIndexOp>(loc, 0));
+  }
+
   auto gatherOp = b.create<vector::GatherOp>(
-      loc, resultType, extractOp.getTensor(), gatherIndices, indexVec,
+      loc, resultType, extractOp.getTensor(), gatherIndices, offset,
       maskConstantOp, passThruConstantOp);
 
   return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1500,7 +1500,7 @@
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @not_vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
+func.func @vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
   %2 = linalg.generic {
     indexing_maps = [#map0, #map0, #map1, #map2],
     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
@@ -1513,8 +1513,28 @@
   } -> tensor<4x7x3x2xf32>
   return %2 : tensor<4x7x3x2xf32>
 }
-// CHECK-LABEL: func.func @not_vectorize_nd_tensor_extract
-// CHECK: tensor.extract
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3xf32>
+// CHECK-SAME: %[[ARG1:arg1]]: tensor<4x3xi32>
+// CHECK-SAME: %[[ARG2:arg2]]: tensor<4x3xi32>
+// CHECK-SAME: %[[ARG3:.*]]: tensor<4x7x2xf32>
+// CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32>
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[C0_i32:.*]] = arith.constant 0 : i32
+// CHECK:   %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex>
+// CHECK:   %[[CST_1:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
+// CHECK:   %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
+// CHECK:   %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[C0_i32]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
+// CHECK:   %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[C0_i32]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
+// CHECK:   %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex>
+// CHECK:   %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex>
+// CHECK:   %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex>
+// CHECK:   %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex>
+// CHECK:   %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex>
+// CHECK:   %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex>
+// CHECK:   %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
+// CHECK:   %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32>
+// CHECK:   vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32>
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):