diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -313,10 +313,6 @@ if (!extractOp) return failure(); - // Currently only supports extraction with an 1-D index. - if (extractOp.getIndices().size() != 1) - return failure(); - if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType())) return failure(); @@ -329,6 +325,51 @@ return success(); } +/// Calculates the offsets (`$index_vec`) for `vector.gather` operations +/// generated from `tensor.extract`. The offset is calculated as follows +/// (example using scalar values): +/// +/// offset = extractOp.indices[0] +/// for (i = 1; i < numIndices; i++) +/// offset = extractOp.dimSize[i] * offset + extractOp.indices[i]; +/// +/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to: +/// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3 +static Value +calculateOffsetForGLoad(OpBuilder &b, tensor::ExtractOp extractOp, + const BlockAndValueMapping &bvm, + const SmallVector<int64_t> &targetShape) { + // The vector of indices for GatherOp should be shaped as the output vector + auto indexVecType = VectorType::get(targetShape, b.getIndexType()); + auto loc = extractOp.getLoc(); + + Value offset = b.create<vector::BroadcastOp>( + loc, indexVecType, bvm.lookup(extractOp.getIndices()[0])); + + const size_t numIndices = extractOp.getIndices().size(); + for (size_t i = 1; i < numIndices; i++) { + auto dimSizeBcast = b.create<vector::BroadcastOp>( + loc, indexVecType, + b.create<arith::ConstantIndexOp>( + loc, + extractOp->getOperandTypes()[0].cast<ShapedType>().getDimSize(i))); + offset = b.create<arith::MulIOp>(loc, offset, dimSizeBcast); + + auto originalIndexBcast = bvm.lookup(extractOp.getIndices()[i]); + if (i == numIndices - 1) { + // We only need an additional broadcast for the trailing index. All other + // indices have already been broadcast by `vectorizeLinalgIndex` to match + // the output size.
+ originalIndexBcast = b.create<vector::BroadcastOp>( + loc, indexVecType, bvm.lookup(extractOp.getIndices()[i])); + } + + offset = b.create<arith::AddIOp>(loc, originalIndexBcast, offset); + } + + return offset; +} + /// Helper function to vectorize the tensor.extract operations. Returns /// VectorizationStatus::NewOp to signal the vectorization algorithm that it /// should map the produced operations. This function is meant to be used as a @@ -341,29 +382,28 @@ return VectorizationResult{VectorizationStatus::Failure, nullptr}; auto loc = extractOp.getLoc(); - // Currently only supports extraction with an 1-D index. Checked in the - // tensorExtractVectorizationPrecondition. - assert(extractOp.getIndices().size() == 1); - - auto indexVec = bvm.lookup(extractOp.getIndices()[0]); // Compute the static loop sizes of the extract op. auto targetShape = linalgOp.computeStaticLoopSizes(); - SmallVector<Value> gatherIndices; - gatherIndices.push_back(b.create<arith::ConstantIndexOp>(loc, 0)); - + auto resultType = + VectorType::get(targetShape, extractOp.getResult().getType()); auto maskConstantOp = b.create<arith::ConstantOp>( loc, DenseIntElementsAttr::get(VectorType::get(targetShape, b.getI1Type()), /*value=*/true)); - - auto resultType = - VectorType::get(targetShape, extractOp.getResult().getType()); auto passThruConstantOp = b.create<arith::ConstantOp>(loc, b.getZeroAttr(resultType)); + // Base indices are currently set to 0. We will need to re-visit if more + // generic scenarios are to be supported.
+ SmallVector<Value> baseIndices(extractOp.getIndices().size(), + b.create<arith::ConstantIndexOp>(loc, 0)); + + Value offset = calculateOffsetForGLoad(b, extractOp, bvm, targetShape); + + // Generate the gather load auto gatherOp = b.create<vector::GatherOp>( - loc, resultType, extractOp.getTensor(), gatherIndices, indexVec, + loc, resultType, extractOp.getTensor(), baseIndices, offset, maskConstantOp, passThruConstantOp); return VectorizationResult{VectorizationStatus::NewOp, gatherOp}; diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -1500,7 +1500,7 @@ #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -func.func @not_vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> { +func.func @vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> { %2 = linalg.generic { indexing_maps = [#map0, #map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"] @@ -1513,8 +1513,28 @@ } -> tensor<4x7x3x2xf32> return %2 : tensor<4x7x3x2xf32> } -// CHECK-LABEL: func.func @not_vectorize_nd_tensor_extract -// CHECK: tensor.extract +// CHECK-LABEL: func.func @vectorize_nd_tensor_extract +// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3xf32> +// CHECK-SAME: %[[ARG1:arg1]]: tensor<4x3xi32> +// CHECK-SAME: %[[ARG2:arg2]]: tensor<4x3xi32> +// CHECK-SAME: %[[ARG3:.*]]: tensor<4x7x2xf32> +// CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C0_i32:.*]] = arith.constant 0 : i32 +// CHECK: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex> +// CHECK: 
%[[CST_1:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1> +// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32> +// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[C0_i32]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32> +// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[C0_i32]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32> +// CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex> +// CHECK: %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex> +// CHECK: %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex> +// CHECK: %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex> +// CHECK: %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex> +// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex> +// CHECK: %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex> +// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32> +// CHECK: vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32> transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation):