diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -300,6 +300,45 @@
   return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
 }
 
+/// Helper function to vectorize the tensor.extract operations. Returns
+/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
+/// should map the produced operations. This function is meant to be used as a
+/// CustomVectorizationHook.
+static VectorizationResult
+vectorizeTensorExtract(OpBuilder &b, Operation *op, LinalgOp linalgOp,
+                       const BlockAndValueMapping &bvm) {
+  tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
+  if (!extractOp)
+    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+  auto loc = extractOp.getLoc();
+
+  // Currently only supports extraction with an 1-D index.
+  if (extractOp.getIndices().size() > 1)
+    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+
+  auto indexVec = bvm.lookup(extractOp.getIndices()[0]);
+  // Compute the static loop sizes of the extract op.
+  auto targetShape = linalgOp.computeStaticLoopSizes();
+
+  SmallVector<Value> gatherIndices;
+  gatherIndices.push_back(b.create<arith::ConstantIndexOp>(loc, 0));
+
+  auto maskConstantOp = b.create<arith::ConstantOp>(
+      loc, DenseIntElementsAttr::get(
+               VectorType::get(targetShape, b.getI1Type()), true));
+
+  auto resultType =
+      VectorType::get(targetShape, extractOp.getResult().getType());
+  auto passThruConstantOp =
+      b.create<arith::ConstantOp>(loc, b.getZeroAttr(resultType));
+
+  auto gatherOp = b.create<vector::GatherOp>(
+      loc, resultType, extractOp.getTensor(), gatherIndices, indexVec,
+      maskConstantOp, passThruConstantOp);
+
+  return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+}
+
 /// Emit reduction operations if the shapes of the value to reduce is different
 /// that the result shape.
static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, @@ -515,6 +554,14 @@ }; hooks.push_back(vectorizeIndex); + // 4c. Register CustomVectorizationHook for extractOp. + CustomVectorizationHook vectorizeExtract = + [&](Operation *op, + const BlockAndValueMapping &bvm) -> VectorizationResult { + return vectorizeTensorExtract(b, op, linalgOp, bvm); + }; + hooks.push_back(vectorizeExtract); + // 5. Iteratively call `vectorizeOneOp` to each op in the slice. for (Operation &op : block->getOperations()) { VectorizationResult result = vectorizeOneOp(b, linalgOp, &op, bvm, hooks); @@ -553,10 +600,15 @@ } static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) { + assert(op->getNumRegions() == 1); + auto &opRegion = op->getRegion(0); // All types in the body should be a supported element type for VectorType. - for (Operation &innerOp : op->getRegion(0).front()) { - if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) { - return !VectorType::isValidElementType(type); + for (Operation &innerOp : opRegion.front()) { + if (llvm::any_of(innerOp.getOperands(), [&](Value operand) { + // Ignore the operands that come from outside, since they don't need + // to be vectorized. 
+        return &opRegion == operand.getParentRegion() &&
+               !VectorType::isValidElementType(operand.getType());
         })) {
       return failure();
     }
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -161,8 +161,8 @@
   if (!llvm::hasSingleElement(r))
     return false;
   for (Operation &op : r.front()) {
-    if (!(isa<arith::ConstantOp, func::ConstantOp, linalg::YieldOp,
-          linalg::IndexOp>(op) ||
+    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
+          linalg::YieldOp, linalg::IndexOp>(op) ||
           OpTrait::hasElementwiseMappableTraits(&op)) ||
         llvm::any_of(op.getResultTypes(), [](Type type) {
           return !type.isIntOrIndexOrFloat();
         }))
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1134,3 +1134,58 @@
 // CHECK-DAG: %[[ADD:.+]] = vector.multi_reduction <add>, %[[MUL]], %[[V2]]
 // CHECK-DAG: vector.transfer_write %[[MUL]], %[[ARG2]]
 // CHECK-DAG: vector.transfer_write %[[ADD]], %[[ARG3]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @vectorize_1d_tensor_extract(%arg0: tensor<3xf32>, %arg1: tensor<4x3xi32>) -> tensor<4x7x3x2xf32> {
+  %0 = linalg.init_tensor [4, 7, 2] : tensor<4x7x2xf32>
+  %1 = linalg.init_tensor [4, 7, 3, 2] : tensor<4x7x3x2xf32>
+  %2 = linalg.generic {
+    indexing_maps = [#map0, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } ins(%arg1, %0 : tensor<4x3xi32>, tensor<4x7x2xf32>) outs(%1 : tensor<4x7x3x2xf32>) {
+  ^bb0(%arg2: i32, %arg3: f32, %arg4: f32):
+    %3 = arith.index_cast %arg2 : i32 to index
+    %7 = tensor.extract %arg0[%3] : tensor<3xf32>
+    linalg.yield %7 : f32
+  } -> tensor<4x7x3x2xf32>
+  return %2 : tensor<4x7x3x2xf32>
+}
+// CHECK-LABEL: func.func @vectorize_1d_tensor_extract
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<4x3xi32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
+// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
+// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]]
+// CHECK: %[[INDICES:.*]] = arith.index_cast %[[V0]] : vector<4x7x3x2xi32> to vector<4x7x3x2xindex>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]]] [%[[INDICES]]], %[[MASK]], %[[PASSTHRU]]
+// CHECK: vector.transfer_write %[[GATHER]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @not_vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>) -> tensor<4x7x3x2xf32> {
+  %0 = linalg.init_tensor [4, 7, 2] : tensor<4x7x2xf32>
+  %1 = linalg.init_tensor [4, 7, 3, 2] : tensor<4x7x3x2xf32>
+  %2 = linalg.generic {
+    indexing_maps = [#map0, #map0, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } ins(%arg1, %arg2, %0 : tensor<4x3xi32>, tensor<4x3xi32>, tensor<4x7x2xf32>) outs(%1 : tensor<4x7x3x2xf32>) {
+  ^bb0(%arg3: i32, %arg4: i32, %arg5: f32, %arg6: f32):
+    %3 = arith.index_cast %arg3 : i32 to index
+    %4 = arith.index_cast %arg4 : i32 to index
+    %7 = tensor.extract %arg0[%3, %4] : tensor<3x3xf32>
+    linalg.yield %7 : f32
+  } -> tensor<4x7x3x2xf32>
+  return %2 : tensor<4x7x3x2xf32>
+}
+// CHECK-LABEL: func.func @not_vectorize_nd_tensor_extract
+// CHECK: tensor.extract

// -----