diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1087,6 +1087,7 @@
   let arguments = (ins PDL_Operation:$target,
                    UnitAttr:$vectorize_padding,
+                   UnitAttr:$vectorize_extract,
                    UnitAttr:$disable_multi_reduction_to_contract_patterns,
                    UnitAttr:$disable_transfer_permutation_map_lowering_patterns);
   let results = (outs PDL_Operation:$transformed);
@@ -1094,7 +1095,9 @@
   let assemblyFormat = "$target attr-dict";

   let builders = [
-    OpBuilder<(ins "Value":$target, CArg<"bool", "false">:$vectorizePadding)>
+    OpBuilder<(ins "Value":$target,
+               CArg<"bool", "false">:$vectorizePadding,
+               CArg<"bool", "false">:$vectorizeExtract)>,
   ];
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1078,6 +1078,16 @@
     GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
     RewriterBase &rewriter);

+//===----------------------------------------------------------------------===//
+// Linalg Vectorizer CommandLine Options
+//===----------------------------------------------------------------------===//
+
+/// Register a set of useful command-line options that can be used to
+/// configure the Linalg vectorizer.
+void registerLinalgVectorizerCLOptions();
+/// Enable vectorization of the N-D tensor.extract operation.
+void enableNDExtractVectorization();
+
 } // namespace linalg
 } // namespace mlir

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1785,12 +1785,17 @@
 //===----------------------------------------------------------------------===//

 void transform::VectorizeOp::build(OpBuilder &builder, OperationState &result,
-                                   Value target, bool vectorizePadding) {
+                                   Value target, bool vectorizePadding,
+                                   bool vectorizeExtract) {
   result.addOperands(target);
   if (vectorizePadding) {
     result.addAttribute(VectorizeOp::getVectorizePaddingAttrName(result.name),
                         builder.getUnitAttr());
   }
+  if (vectorizeExtract) {
+    result.addAttribute(VectorizeOp::getVectorizeExtractAttrName(result.name),
+                        builder.getUnitAttr());
+  }
   result.addTypes(pdl::OperationType::get(builder.getContext()));
 }

@@ -1841,6 +1846,9 @@
   if (getVectorizePadding())
     linalg::populatePadOpVectorizationPatterns(patterns);

+  if (getVectorizeExtract())
+    linalg::enableNDExtractVectorization();
+
   if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
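Note: because both `CArg` parameters above default to "false", existing call sites of the `VectorizeOp` builder stay source-compatible. A minimal C++ sketch of opting in to the new flag when constructing transform IR; the helper `buildVectorize` and its arguments are illustrative only and not part of this patch:

```cpp
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
#include "mlir/IR/Builders.h"

// Hypothetical helper: creates a transform.structured.vectorize op that also
// vectorizes n-D tensor.extract. Everything except the VectorizeOp builder
// signature is assumed surrounding context.
void buildVectorize(mlir::OpBuilder &builder, mlir::Location loc,
                    mlir::Value targetHandle) {
  builder.create<mlir::transform::VectorizeOp>(
      loc, targetHandle,
      /*vectorizePadding=*/false,
      /*vectorizeExtract=*/true);
}
```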
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -33,6 +33,7 @@
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <type_traits>
@@ -45,6 +46,31 @@
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X)

+namespace {
+/// This struct contains command-line options that can be used to configure
+/// the vectorizer. The struct wrapper avoids the need for global command-line
+/// options.
+struct MLIRLinalgVectorizerOptions {
+  llvm::cl::opt<bool> vectorizeNDExtract{
+      "linalg-vectorize-n-d-extract",
+      llvm::cl::desc("Enable vectorization of tensor.extract for n-D tensors"),
+      llvm::cl::init(false)};
+};
+} // namespace
+
+static llvm::ManagedStatic<MLIRLinalgVectorizerOptions> clOptions;
+
+/// Register a set of useful command-line options that can be used to
+/// configure the behaviour of the Linalg vectorizer.
+void mlir::linalg::registerLinalgVectorizerCLOptions() {
+  // Make sure that the options struct has been constructed.
+  *clOptions;
+}
+
+void mlir::linalg::enableNDExtractVectorization() {
+  clOptions->vectorizeNDExtract = true;
+}
+
 /// Try to vectorize `convOp` as a convolution.
 static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
                                                    LinalgOp convOp);
@@ -319,9 +345,12 @@
   if (!extractOp)
     return failure();

-  // Currently only supports extraction with an 1-D index.
-  if (extractOp.getIndices().size() != 1)
-    return failure();
+  if (extractOp.getIndices().size() != 1) {
+    // n-D extracts are supported only when explicitly enabled via the flag.
+    if (!clOptions.isConstructed() || !clOptions->vectorizeNDExtract) {
+      return failure();
+    }
+  }

   if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
     return failure();
@@ -335,6 +364,51 @@
   return success();
 }

+/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
+/// generated from `tensor.extract`. The offset is calculated as follows
+/// (example using scalar values):
+///
+///   offset = extractOp.indices[0]
+///   for (i = 1; i < numIndices; i++)
+///     offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
+///
+/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
+///   offset = ( ( 1 ) * 80 + 2 ) * 15 + 3
+static Value
+calculateGatherOffset(OpBuilder &b, tensor::ExtractOp extractOp,
+                      const BlockAndValueMapping &bvm,
+                      const SmallVectorImpl<int64_t> &targetShape) {
+  // The vector of indices for GatherOp should be shaped as the output vector.
+  auto indexVecType = VectorType::get(targetShape, b.getIndexType());
+  auto loc = extractOp.getLoc();
+
+  Value offset = b.create<vector::BroadcastOp>(
+      loc, indexVecType, bvm.lookup(extractOp.getIndices()[0]));
+
+  const size_t numIndices = extractOp.getIndices().size();
+  for (size_t i = 1; i < numIndices; i++) {
+    auto dimSizeBcast = b.create<vector::BroadcastOp>(
+        loc, indexVecType,
+        b.create<arith::ConstantIndexOp>(
+            loc,
+            extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)));
+    offset = b.create<arith::MulIOp>(loc, offset, dimSizeBcast);
+
+    auto originalIndexBcast = bvm.lookup(extractOp.getIndices()[i]);
+    if (i == numIndices - 1) {
+      // We only need an additional broadcast for the trailing index. All other
+      // indices have already been broadcast by `vectorizeLinalgIndex` to match
+      // the output size.
+      originalIndexBcast = b.create<vector::BroadcastOp>(
+          loc, indexVecType, bvm.lookup(extractOp.getIndices()[i]));
+    }
+
+    offset = b.create<arith::AddIOp>(loc, originalIndexBcast, offset);
+  }
+
+  return offset;
+}
+
 /// Helper function to vectorize the tensor.extract operations. Returns
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
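As a mental model for `calculateGatherOffset` above, here is a scalar reference of the same row-major linearization, reproducing the worked example from the doc comment. This standalone snippet is illustrative only (names chosen for clarity) and is not part of the patch:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Scalar equivalent of the vectorized offset computation: dimSizes[i] and
// indices[i] mirror extractOp.dimSize[i] and extractOp.indices[i].
static int64_t linearize(const std::vector<int64_t> &dimSizes,
                         const std::vector<int64_t> &indices) {
  int64_t offset = indices[0];
  for (size_t i = 1; i < indices.size(); ++i)
    offset = offset * dimSizes[i] + indices[i];
  return offset;
}

int main() {
  // tensor<45x80x15xf32> with index [1, 2, 3]:
  // ((1) * 80 + 2) * 15 + 3 == 1233.
  assert(linearize({45, 80, 15}, {1, 2, 3}) == 1233);
  return 0;
}
```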
@@ -347,29 +421,29 @@
     return VectorizationResult{VectorizationStatus::Failure, nullptr};

   auto loc = extractOp.getLoc();
-  // Currently only supports extraction with an 1-D index. Checked in the
-  // tensorExtractVectorizationPrecondition.
-  assert(extractOp.getIndices().size() == 1);
-
-  auto indexVec = bvm.lookup(extractOp.getIndices()[0]);

   // Compute the static loop sizes of the extract op.
   auto targetShape = linalgOp.computeStaticLoopSizes();

-  SmallVector<Value> gatherIndices;
-  gatherIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
-
+  auto resultType =
+      VectorType::get(targetShape, extractOp.getResult().getType());
   auto maskConstantOp = rewriter.create<arith::ConstantOp>(
       loc, DenseIntElementsAttr::get(
                VectorType::get(targetShape, rewriter.getI1Type()),
                /*value=*/true));
-
-  auto resultType =
-      VectorType::get(targetShape, extractOp.getResult().getType());
   auto passThruConstantOp =
       rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));

+  // Base indices are currently set to 0. We will need to revisit this if more
+  // general scenarios are to be supported.
+  SmallVector<Value> baseIndices(
+      extractOp.getIndices().size(),
+      rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+  Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
+
+  // Generate the gather load.
   auto gatherOp = rewriter.create<vector::GatherOp>(
-      loc, resultType, extractOp.getTensor(), gatherIndices, indexVec,
+      loc, resultType, extractOp.getTensor(), baseIndices, offset,
       maskConstantOp, passThruConstantOp);

   return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
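For completeness, a sketch of how a downstream tool might surface the new flag on its command line. Only `registerLinalgVectorizerCLOptions()` comes from this patch; the driver scaffolding around it is an assumption:

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "llvm/Support/CommandLine.h"

int main(int argc, char **argv) {
  // Force-construct the ManagedStatic so that -linalg-vectorize-n-d-extract
  // is registered before the command line is parsed (and shows up in -help).
  mlir::linalg::registerLinalgVectorizerCLOptions();
  llvm::cl::ParseCommandLineOptions(argc, argv, "linalg vectorizer demo\n");
  // ... set up an MLIRContext and run the vectorization pipeline as usual ...
  return 0;
}
```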
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1500,7 +1500,7 @@
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @not_vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
+func.func @vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
   %2 = linalg.generic {
     indexing_maps = [#map0, #map0, #map1, #map2],
     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
@@ -1513,14 +1513,34 @@
   } -> tensor<4x7x3x2xf32>
   return %2 : tensor<4x7x3x2xf32>
 }
-// CHECK-LABEL: func.func @not_vectorize_nd_tensor_extract
-// CHECK: tensor.extract
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3xf32>
+// CHECK-SAME: %[[ARG1:arg1]]: tensor<4x3xi32>
+// CHECK-SAME: %[[ARG2:arg2]]: tensor<4x3xi32>
+// CHECK-SAME: %[[ARG3:.*]]: tensor<4x7x2xf32>
+// CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C0_i32:.*]] = arith.constant 0 : i32
+// CHECK: %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex>
+// CHECK: %[[CST_1:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
+// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
+// CHECK: %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[C0_i32]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
+// CHECK: %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[C0_i32]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
+// CHECK: %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex>
+// CHECK: %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex>
+// CHECK: %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex>
+// CHECK: %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex>
+// CHECK: %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32>
+// CHECK: vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32>

 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
-  %2 = transform.structured.vectorize %1
+  %2 = transform.structured.vectorize %1 { vectorize_extract }
 }

 // -----
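Since `vectorize_extract` is a plain UnitAttr parsed through `attr-dict`, it should compose with the existing `vectorize_padding` flag inside a single attribute dictionary. An untested sketch of such a combined invocation, mirroring the test above:

```mlir
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
  // Both unit attributes live in the op's attr-dict.
  %2 = transform.structured.vectorize %1 {vectorize_extract, vectorize_padding}
}
```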