diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1091,6 +1091,7 @@ let arguments = (ins PDL_Operation:$target, UnitAttr:$vectorize_padding, + UnitAttr:$vectorize_nd_extract, UnitAttr:$disable_multi_reduction_to_contract_patterns, UnitAttr:$disable_transfer_permutation_map_lowering_patterns); let results = (outs PDL_Operation:$transformed); @@ -1098,7 +1099,9 @@ let assemblyFormat = "$target attr-dict"; let builders = [ - OpBuilder<(ins "Value":$target, CArg<"bool", "false">:$vectorizePadding)> + OpBuilder<(ins "Value":$target, + CArg<"bool", "false">:$vectorizePadding, + CArg<"bool", "false">:$vectorizeNDExtract)>, ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -345,7 +345,8 @@ const LinalgPromotionOptions &options); /// Emit a suitable vector form for a Linalg op with fully static shape. -LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp); +LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp, + bool vectorizeNDExtract = false); /// Emit a suitable vector form for a Copy op with fully static shape. LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp); @@ -371,7 +372,8 @@ LinalgPromotionOptions options); /// Return success if the operation can be vectorized. 
-LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp); +LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp, + bool vectorizeNDExtract = false); //===----------------------------------------------------------------------===// // Transformations exposed as rewrite patterns. @@ -865,6 +867,9 @@ void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit = 1); +void populateExtractOpVectorizationPatterns(RewritePatternSet &patterns, + PatternBenefit baseBenefit = 1); + /// Match and rewrite for the pattern: /// ``` /// %alloc = ... diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1781,12 +1781,17 @@ //===----------------------------------------------------------------------===// void transform::VectorizeOp::build(OpBuilder &builder, OperationState &result, - Value target, bool vectorizePadding) { + Value target, bool vectorizePadding, + bool vectorizeExtract) { result.addOperands(target); if (vectorizePadding) { result.addAttribute(VectorizeOp::getVectorizePaddingAttrName(result.name), builder.getUnitAttr()); } + if (vectorizeExtract) { + result.addAttribute(VectorizeOp::getVectorizeNdExtractAttrName(result.name), + builder.getUnitAttr()); + } result.addTypes(pdl::OperationType::get(builder.getContext())); } @@ -1794,15 +1799,22 @@ /// This is an helper only to call vectorize via a pattern inside of /// VectorizeOp::applyToOne. 
struct VectorizationPattern : public RewritePattern { - explicit VectorizationPattern(MLIRContext *context) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + explicit VectorizationPattern(MLIRContext *context, + bool vectorizeExtract = false) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), + vectorizeNDExtract(vectorizeExtract) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { LinalgOp linalgOp = dyn_cast<LinalgOp>(op); if (!linalgOp) return rewriter.notifyMatchFailure(op, "expected Linalg Op"); - return vectorize(rewriter, linalgOp); + return vectorize(rewriter, linalgOp, vectorizeNDExtract); } + +private: + /// Controls whether to vectorize `tensor.extract` when the input tensor is + /// rank >= 2. + bool vectorizeNDExtract = false; }; } // namespace @@ -1818,7 +1830,7 @@ MLIRContext *ctx = getContext(); RewritePatternSet patterns(ctx); - patterns.add<VectorizationPattern>(ctx); + patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract()); if (!getDisableTransferPermutationMapLoweringPatterns()) vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -242,7 +242,7 @@ // with CustomVectorizationHook. Returns success if the corresponding custom // hook can vectorize the op. using CustomVectorizationPrecondition = - std::function<LogicalResult(Operation *)>; + std::function<LogicalResult(Operation *, bool)>; // Custom vectorization function type. Produce a vector form of Operation* // assuming all its vectorized operands are already in the BlockAndValueMapping. @@ -314,13 +314,13 @@ /// Helper function to check if the tensor.extract can be vectorized by the /// custom hook vectorizeTensorExtract. 
-static LogicalResult tensorExtractVectorizationPrecondition(Operation *op) { +static LogicalResult +tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) { tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op); if (!extractOp) return failure(); - // Currently only supports extraction with an 1-D index. - if (extractOp.getIndices().size() != 1) + if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract) return failure(); if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType())) @@ -335,6 +335,51 @@ return success(); } +/// Calculates the offsets (`$index_vec`) for `vector.gather` operations +/// generated from `tensor.extract`. The offset is calculated as follows +/// (example using scalar values): +/// +/// offset = extractOp.indices[0] +/// for (i = 1; i < numIndices; i++) +/// offset = extractOp.dimSize[i] * offset + extractOp.indices[i]; +/// +/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to: +/// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3 +static Value +calculateGatherOffset(OpBuilder &b, tensor::ExtractOp extractOp, + const BlockAndValueMapping &bvm, + const SmallVectorImpl<int64_t> &targetShape) { + // The vector of indices for GatherOp should be shaped as the output vector + auto indexVecType = VectorType::get(targetShape, b.getIndexType()); + auto loc = extractOp.getLoc(); + + Value offset = b.create<vector::BroadcastOp>( + loc, indexVecType, bvm.lookup(extractOp.getIndices()[0])); + + const size_t numIndices = extractOp.getIndices().size(); + for (size_t i = 1; i < numIndices; i++) { + auto dimSizeBcast = b.create<vector::BroadcastOp>( + loc, indexVecType, + b.create<arith::ConstantIndexOp>( + loc, + extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i))); + offset = b.create<arith::MulIOp>(loc, offset, dimSizeBcast); + + auto originalIndexBcast = bvm.lookup(extractOp.getIndices()[i]); + if (i == numIndices - 1) { + // We only need an additional broadcast for the trailing index. All other + // indices have already been broadcast by `vectorizeLinalgIndex` to match + // the output size. 
+ originalIndexBcast = b.create<vector::BroadcastOp>( + loc, indexVecType, bvm.lookup(extractOp.getIndices()[i])); + } + + offset = b.create<arith::AddIOp>(loc, originalIndexBcast, offset); + } + + return offset; +} + /// Helper function to vectorize the tensor.extract operations. Returns /// VectorizationStatus::NewOp to signal the vectorization algorithm that it /// should map the produced operations. This function is meant to be used as a @@ -347,29 +392,29 @@ return VectorizationResult{VectorizationStatus::Failure, nullptr}; auto loc = extractOp.getLoc(); - // Currently only supports extraction with an 1-D index. Checked in the - // tensorExtractVectorizationPrecondition. - assert(extractOp.getIndices().size() == 1); - - auto indexVec = bvm.lookup(extractOp.getIndices()[0]); // Compute the static loop sizes of the extract op. auto targetShape = linalgOp.computeStaticLoopSizes(); - SmallVector<Value> gatherIndices; - gatherIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0)); - + auto resultType = + VectorType::get(targetShape, extractOp.getResult().getType()); auto maskConstantOp = rewriter.create<arith::ConstantOp>( loc, DenseIntElementsAttr::get( VectorType::get(targetShape, rewriter.getI1Type()), /*value=*/true)); - - auto resultType = - VectorType::get(targetShape, extractOp.getResult().getType()); auto passThruConstantOp = rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType)); + // Base indices are currently set to 0. We will need to re-visit if more + // generic scenarios are to be supported. + SmallVector<Value> baseIndices( + extractOp.getIndices().size(), + rewriter.create<arith::ConstantIndexOp>(loc, 0)); + + Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape); + + // Generate the gather load auto gatherOp = rewriter.create<vector::GatherOp>( - loc, resultType, extractOp.getTensor(), gatherIndices, indexVec, + loc, resultType, extractOp.getTensor(), baseIndices, offset, maskConstantOp, passThruConstantOp); return VectorizationResult{VectorizationStatus::NewOp, gatherOp}; @@ -586,7 +631,7 @@ }; hooks.push_back(vectorizeYield); - // 4rewriter. 
Register CustomVectorizationHook for indexOp. + // 4b. Register CustomVectorizationHook for indexOp. CustomVectorizationHook vectorizeIndex = [&](Operation *op, const BlockAndValueMapping &bvm) -> VectorizationResult { @@ -642,7 +687,8 @@ static LogicalResult vectorizeStaticLinalgOpPrecondition( linalg::LinalgOp op, - ArrayRef<CustomVectorizationPrecondition> customPreconditions) { + ArrayRef<CustomVectorizationPrecondition> customPreconditions, + bool vectorizeNDExtract) { // All types in the body should be a supported element type for VectorType. for (Operation &innerOp : op->getRegion(0).front()) { @@ -650,7 +696,8 @@ if (llvm::any_of( customPreconditions, [&](const CustomVectorizationPrecondition &customPrecondition) { - return succeeded(customPrecondition(&innerOp)); + return succeeded( + customPrecondition(&innerOp, vectorizeNDExtract)); })) { continue; } @@ -686,7 +733,9 @@ return success(); } -LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp) { +LogicalResult +mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp, + bool vectorizeNDExtract) { // All types must be static shape to go to vector. if (linalgOp.hasDynamicShape()) { LDBG("precondition failed: dynamic shape"); @@ -698,12 +747,13 @@ // Register CustomVectorizationPrecondition for extractOp. 
customPreconditions.push_back(tensorExtractVectorizationPrecondition); - return vectorizeStaticLinalgOpPrecondition(linalgOp, customPreconditions); + return vectorizeStaticLinalgOpPrecondition(linalgOp, customPreconditions, + vectorizeNDExtract); } -LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, - LinalgOp linalgOp) { - if (failed(vectorizeLinalgOpPrecondition(linalgOp))) +LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp, + bool vectorizeNDExtract) { + if (failed(vectorizeLinalgOpPrecondition(linalgOp, vectorizeNDExtract))) return failure(); SmallVector<Value> results; @@ -713,7 +763,7 @@ if (succeeded(convOr)) { llvm::append_range(results, (*convOr)->getResults()); } else { - if (failed(vectorizeLinalgOpPrecondition(linalgOp))) + if (failed(vectorizeLinalgOpPrecondition(linalgOp, vectorizeNDExtract))) return failure(); LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp); if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results))) diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -1500,7 +1500,7 @@ #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)> #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -func.func @not_vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> { +func.func @vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> { %2 = linalg.generic { indexing_maps = [#map0, #map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"] @@ -1513,14 +1513,34 @@ } -> tensor<4x7x3x2xf32> return %2 : tensor<4x7x3x2xf32> } 
-// CHECK-LABEL: func.func @not_vectorize_nd_tensor_extract -// CHECK: tensor.extract +// CHECK-LABEL: func.func @vectorize_nd_tensor_extract +// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3xf32> +// CHECK-SAME: %[[ARG1:arg1]]: tensor<4x3xi32> +// CHECK-SAME: %[[ARG2:arg2]]: tensor<4x3xi32> +// CHECK-SAME: %[[ARG3:.*]]: tensor<4x7x2xf32> +// CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C0_i32:.*]] = arith.constant 0 : i32 +// CHECK: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex> +// CHECK: %[[CST_1:.*]] = arith.constant dense : vector<4x7x3x2xi1> +// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32> +// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[C0_i32]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32> +// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[C0_i32]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32> +// CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex> +// CHECK: %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex> +// CHECK: %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex> +// CHECK: %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex> +// CHECK: %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex> +// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex> +// CHECK: %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex> +// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32> +// CHECK: vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} 
: vector<4x7x3x2xf32>, tensor<4x7x3x2xf32> transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation - %2 = transform.structured.vectorize %1 + %2 = transform.structured.vectorize %1 { vectorize_nd_extract } } // -----