diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -232,6 +232,12 @@
   return Value();
 }
 
+// Custom vectorization precondition function type. This is intended to be used
+// with CustomVectorizationHook. Returns success if the corresponding custom
+// hook can vectorize the op.
+using CustomVectorizationPrecondition =
+    std::function<LogicalResult(Operation *)>;
+
 // Custom vectorization function type. Produce a vector form of Operation*
 // assuming all its vectorized operands are already in the BlockAndValueMapping.
 // Return nullptr if the Operation cannot be vectorized.
@@ -300,6 +306,78 @@
   return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
 }
 
+/// Helper function to check if the tensor.extract can be vectorized by the
+/// custom hook vectorizeTensorExtract.
+static LogicalResult tensorExtractVectorizationPrecondition(Operation *op) {
+  auto extractOp = dyn_cast<tensor::ExtractOp>(op);
+  if (!extractOp)
+    return failure();
+
+  // Currently only supports extraction with a 1-D index.
+  if (extractOp.getIndices().size() != 1)
+    return failure();
+
+  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
+    return failure();
+
+  if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
+        return !VectorType::isValidElementType(type);
+      })) {
+    return failure();
+  }
+
+  return success();
+}
+
+/// Helper function to vectorize the tensor.extract operations. The extract is
+/// rewritten as a vector.gather from the source tensor, using the vectorized
+/// index as the per-element offset, an all-true mask and a zero pass-through
+/// value. Returns VectorizationStatus::NewOp to signal the vectorization
+/// algorithm that it should map the produced operations. This function is
+/// meant to be used as a CustomVectorizationHook.
+static VectorizationResult
+vectorizeTensorExtract(OpBuilder &b, Operation *op, LinalgOp linalgOp,
+                       const BlockAndValueMapping &bvm) {
+  auto extractOp = dyn_cast<tensor::ExtractOp>(op);
+  if (!extractOp)
+    return VectorizationResult{VectorizationStatus::Failure, nullptr};
+  auto loc = extractOp.getLoc();
+
+  // Currently only supports extraction with a 1-D index. Checked in the
+  // tensorExtractVectorizationPrecondition.
+  assert(extractOp.getIndices().size() == 1);
+
+  // Compute the static loop sizes of the extract op.
+  auto targetShape = linalgOp.computeStaticLoopSizes();
+
+  // vector.gather expects one index per result element, i.e. an index vector
+  // with the full iteration-space shape. If the mapped index is not already
+  // such a vector (e.g. it does not depend on the loop dimensions and was left
+  // scalar), broadcast it.
+  Value indexVec = bvm.lookup(extractOp.getIndices()[0]);
+  auto indexVecType = VectorType::get(targetShape, b.getIndexType());
+  if (indexVec.getType() != indexVecType)
+    indexVec = b.create<vector::BroadcastOp>(loc, indexVecType, indexVec);
+
+  SmallVector<Value> gatherIndices;
+  gatherIndices.push_back(b.create<arith::ConstantIndexOp>(loc, 0));
+
+  auto maskConstantOp = b.create<arith::ConstantOp>(
+      loc, DenseIntElementsAttr::get(
+               VectorType::get(targetShape, b.getI1Type()), true));
+
+  auto resultType =
+      VectorType::get(targetShape, extractOp.getResult().getType());
+  auto passThruConstantOp =
+      b.create<arith::ConstantOp>(loc, b.getZeroAttr(resultType));
+
+  auto gatherOp = b.create<vector::GatherOp>(
+      loc, resultType, extractOp.getTensor(), gatherIndices, indexVec,
+      maskConstantOp, passThruConstantOp);
+
+  return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+}
+
 /// Emit reduction operations if the shapes of the value to reduce is different
 /// that the result shape.
 static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
@@ -515,6 +593,14 @@
   };
   hooks.push_back(vectorizeIndex);
 
+  // 4c. Register CustomVectorizationHook for extractOp.
+  CustomVectorizationHook vectorizeExtract =
+      [&](Operation *op,
+          const BlockAndValueMapping &bvm) -> VectorizationResult {
+    return vectorizeTensorExtract(b, op, linalgOp, bvm);
+  };
+  hooks.push_back(vectorizeExtract);
+
   // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
   for (Operation &op : block->getOperations()) {
     VectorizationResult result = vectorizeOneOp(b, linalgOp, &op, bvm, hooks);
@@ -552,9 +638,21 @@
   return success();
 }
 
-static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
+static LogicalResult vectorizeStaticLinalgOpPrecondition(
+    linalg::LinalgOp op,
+    ArrayRef<CustomVectorizationPrecondition> customPreconditions) {
   // All types in the body should be a supported element type for VectorType.
   for (Operation &innerOp : op->getRegion(0).front()) {
+    // Check if any custom hook can vectorize the inner op. Such ops skip the
+    // generic element-type checks below (e.g. tensor.extract takes a tensor
+    // operand, which is not a valid vector element type).
+    if (llvm::any_of(
+            customPreconditions,
+            [&](const CustomVectorizationPrecondition &customPrecondition) {
+              return succeeded(customPrecondition(&innerOp));
+            })) {
+      continue;
+    }
     if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
           return !VectorType::isValidElementType(type);
         })) {
@@ -566,16 +664,8 @@
       return failure();
     }
   }
-  if (isElementwise(op)) {
-    // Some operations in the body cannot be vectorized.
-    for (Operation &payloadOp : *op.getBlock()) {
-      if (isa<tensor::ExtractOp>(payloadOp)) {
-        LDBG("precondition failed: `tensor.extract` not vectorizable");
-        return failure();
-      }
-    }
+  if (isElementwise(op))
     return success();
-  }
   // TODO: isaConvolutionOpInterface that can also infer from generic features.
   // But we will still need stride/dilation attributes that will be annoying to
   // reverse-engineer...
@@ -601,7 +691,13 @@
     LDBG("precondition failed: dynamic shape");
     return failure();
   }
-  return vectorizeStaticLinalgOpPrecondition(linalgOp);
+
+  SmallVector<CustomVectorizationPrecondition> customPreconditions;
+
+  // Register CustomVectorizationPrecondition for extractOp.
+  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
+
+  return vectorizeStaticLinalgOpPrecondition(linalgOp, customPreconditions);
 }
 
 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -161,8 +161,8 @@
   if (!llvm::hasSingleElement(r))
     return false;
   for (Operation &op : r.front()) {
-    if (!(isa<arith::ConstantOp, func::ConstantOp, linalg::YieldOp,
-              linalg::IndexOp>(op) ||
+    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
+              linalg::YieldOp, linalg::IndexOp>(op) ||
           OpTrait::hasElementwiseMappableTraits(&op)) ||
         llvm::any_of(op.getResultTypes(),
                      [](Type type) { return !type.isIntOrIndexOrFloat(); }))
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1457,3 +1457,68 @@
   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
   %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns }
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @vectorize_1d_tensor_extract(%arg0: tensor<3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x7x2xf32>, %arg3: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
+  %2 = linalg.generic {
+    indexing_maps = [#map0, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } ins(%arg1, %arg2 : tensor<4x3xi32>, tensor<4x7x2xf32>) outs(%arg3 : tensor<4x7x3x2xf32>) {
+  ^bb0(%arg4: i32, %arg5: f32, %arg6: f32):
+    %3 = arith.index_cast %arg4 : i32 to index
+    %7 = tensor.extract %arg0[%3] : tensor<3xf32>
+    linalg.yield %7 : f32
+  } -> tensor<4x7x3x2xf32>
+  return %2 : tensor<4x7x3x2xf32>
+}
+// CHECK-LABEL: func.func @vectorize_1d_tensor_extract
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<4x3xi32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
+// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
+// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]]
+// CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]]
+// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[CAST]]
+// CHECK: %[[INDICES:.*]] = vector.transpose %[[BROADCAST]]
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]]] [%[[INDICES]]], %[[MASK]], %[[PASSTHRU]]
+// CHECK: vector.transfer_write %[[GATHER]]
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @not_vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
+  %2 = linalg.generic {
+    indexing_maps = [#map0, #map0, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } ins(%arg1, %arg2, %arg3 : tensor<4x3xi32>, tensor<4x3xi32>, tensor<4x7x2xf32>) outs(%arg4 : tensor<4x7x3x2xf32>) {
+  ^bb0(%arg5: i32, %arg6: i32, %arg7: f32, %arg8: f32):
+    %3 = arith.index_cast %arg5 : i32 to index
+    %4 = arith.index_cast %arg6 : i32 to index
+    %7 = tensor.extract %arg0[%3, %4] : tensor<3x3xf32>
+    linalg.yield %7 : f32
+  } -> tensor<4x7x3x2xf32>
+  return %2 : tensor<4x7x3x2xf32>
+}
+// CHECK-LABEL: func.func @not_vectorize_nd_tensor_extract
+// CHECK: tensor.extract
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1
+}