diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -1638,7 +1638,6 @@
   let arguments = (ins PDL_Operation:$target,
                    UnitAttr:$vectorize_padding,
-                   UnitAttr:$vectorize_nd_extract,
                    UnitAttr:$disable_multi_reduction_to_contract_patterns,
                    UnitAttr:$disable_transfer_permutation_map_lowering_patterns);
   let results = (outs PDL_Operation:$transformed);
@@ -1647,8 +1646,7 @@
   let builders = [
     OpBuilder<(ins "Value":$target,
-               CArg<"bool", "false">:$vectorizePadding,
-               CArg<"bool", "false">:$vectorizeNDExtract)>,
+               CArg<"bool", "false">:$vectorizePadding)>,
   ];
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -1685,7 +1683,6 @@
   let arguments = (ins PDL_Operation:$target,
                        Variadic<PDL_Operation>:$vector_sizes,
-                       UnitAttr:$vectorize_nd_extract,
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
                           $static_vector_sizes);
   let results = (outs);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -400,8 +400,7 @@
 /// `inputVectorShapes` also allows the vectorization of operations with dynamic
 /// shapes.
 LogicalResult vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
-                        ArrayRef<int64_t> inputVectorSizes = {},
-                        bool vectorizeNDExtract = false);
+                        ArrayRef<int64_t> inputVectorSizes = {});
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
@@ -429,8 +428,7 @@
 /// Return success if the operation can be vectorized.
 LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
-                                            ArrayRef<int64_t> inputVectorSizes = {},
-                                            bool vectorizeNDExtract = false);
+                                            ArrayRef<int64_t> inputVectorSizes = {});
 
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2904,17 +2904,12 @@
 //===----------------------------------------------------------------------===//
 
 void transform::VectorizeOp::build(OpBuilder &builder, OperationState &result,
-                                   Value target, bool vectorizePadding,
-                                   bool vectorizeExtract) {
+                                   Value target, bool vectorizePadding) {
   result.addOperands(target);
   if (vectorizePadding) {
     result.addAttribute(VectorizeOp::getVectorizePaddingAttrName(result.name),
                         builder.getUnitAttr());
   }
-  if (vectorizeExtract) {
-    result.addAttribute(VectorizeOp::getVectorizeNdExtractAttrName(result.name),
-                        builder.getUnitAttr());
-  }
   result.addTypes(pdl::OperationType::get(builder.getContext()));
 }
 
@@ -2924,21 +2919,14 @@
 struct VectorizationPattern : public RewritePattern {
   explicit VectorizationPattern(MLIRContext *context,
                                 bool vectorizeExtract = false)
-      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
-        vectorizeNDExtract(vectorizeExtract) {}
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
     if (!linalgOp)
       return rewriter.notifyMatchFailure(op, "expected Linalg Op");
-    return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
-                     vectorizeNDExtract);
+    return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{});
   }
-
-private:
-  /// Controls whether to vectorize `tensor.extract` when the input tensor is
-  /// rank >= 2.
-  bool vectorizeNDExtract = false;
 };
 } // namespace
 
@@ -2954,7 +2942,7 @@
 
   MLIRContext *ctx = getContext();
   RewritePatternSet patterns(ctx);
-  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract());
+  patterns.add<VectorizationPattern>(ctx);
 
   if (!getDisableTransferPermutationMapLoweringPatterns())
     vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
@@ -3037,8 +3025,7 @@
            << "cannot vectorize non-Linalg op";
   }
 
-  if (failed(linalg::vectorize(rewriter, linalgOp, vectorSizes,
-                               getVectorizeNdExtract()))) {
+  if (failed(linalg::vectorize(rewriter, linalgOp, vectorSizes))) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "failed to vectorize op";
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -491,7 +491,7 @@
 // with CustomVectorizationHook. Returns success if the corresponding custom
 // hook can vectorize the op.
 using CustomVectorizationPrecondition =
-    std::function<LogicalResult(Operation *, bool)>;
+    std::function<LogicalResult(Operation *)>;
 
 // Custom vectorization function type. Produce a vector form of Operation*
 // assuming all its vectorized operands are already in the IRMapping.
@@ -568,14 +568,11 @@
 /// Helper function to check if the tensor.extract can be vectorized by the
 /// custom hook vectorizeTensorExtract.
 static LogicalResult
-tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
+tensorExtractVectorizationPrecondition(Operation *op) {
   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
   if (!extractOp)
     return failure();
 
-  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
-    return failure();
-
   if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
     return failure();
 
@@ -611,11 +608,11 @@
 
   const size_t numIndices = extractOp.getIndices().size();
   for (size_t i = 1; i < numIndices; i++) {
+    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
+
     auto dimSize = broadcastIfNeeded(
         rewriter,
-        rewriter.create<arith::ConstantIndexOp>(
-            loc,
-            extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)),
+        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
         indexVecType.getShape());
 
     offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
@@ -630,6 +627,143 @@
   return offset;
 }
 
+enum VectorMemoryAccessKind {
+  // TODO: ScalarBroadcast,
+  Contiguous,
+  Gather
+};
+
+/// Check whether \p val can be used for calculating an index for a contiguous
+/// load operation, i.e. whether \p val:
+///   * is invariant with respect to \p linalgOp, i.e. whether it remains
+///     constant for all iterations, and
+///   * increments with the loop iterator (when \p strideZero is false) or is
+///     not affected by the loop indices (\p strideZero is true).
+static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, size_t dim,
+                                bool strideZero) {
+  auto *block = linalgOp.getBlock();
+
+  // Bail out if this is a block argument for this linalg.generic Op.
+  // TODO: We could try analysing the corresponding affine map here.
+  if (val.dyn_cast<BlockArgument>())
+    return llvm::all_of(block->getArguments(),
+                        [&val](Value v) { return (v != val); });
+
+  Operation *defOp = val.getDefiningOp();
+  assert(defOp && "This is neither a block argument nor an operation result");
+
+  // Given the assumption on the shape of the target tensor, the index Op is
+  // either:
+  //   * constant (for non-trailing dims), or
+  //   * increments with stride one together with the trailing dimension.
+  // Both cases are fine for contiguous loads.
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
+    return strideZero ? (indexOp.getDim() != dim) : (indexOp.getDim() == dim);
+
+  auto *ancestor = block->findAncestorOpInBlock(*defOp);
+
+  // Values defined outside `linalgOp`.
+  if (!ancestor)
+    return true;
+
+  // Values defined inside `linalgOp`, which are constant.
+  if (dyn_cast<arith::ConstantOp>(ancestor))
+    return true;
+
+  bool result = true;
+  for (auto op : ancestor->getOperands())
+    result &= isContiguousLoadIdx(linalgOp, op, dim, strideZero);
+
+  return result;
+}
+
+/// Check whether the calculation of \p val is based on a linalg.index Op with
+/// the dim attribute matching \p dim.
+static bool isBasedOnIndexOp(LinalgOp &linalgOp, Value &val, size_t dim) {
+  auto *block = linalgOp.getBlock();
+  auto targetShape = linalgOp.getStaticLoopRanges();
+
+  if (val.isa<BlockArgument>())
+    return false;
+
+  Operation *defOp = val.getDefiningOp();
+  assert(defOp && "This is neither a block argument nor an operation result");
+
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
+    return (indexOp.getDim() == dim);
+
+  auto *ancestor = block->findAncestorOpInBlock(*defOp);
+
+  if (!ancestor)
+    return false;
+
+  bool result = false;
+  for (auto op : ancestor->getOperands())
+    result |= isBasedOnIndexOp(linalgOp, op, dim);
+
+  return result;
+}
+
+/// Check whether \p extractOp would be a gather or a contiguous load Op after
+/// vectorising \p linalgOp.
+/// Note that it is always safe to use gather load operations for contiguous
+/// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
+/// that \p extractOp is a gather load.
+static VectorMemoryAccessKind
+getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
+                                    LinalgOp &linalgOp) {
+
+  auto targetShape = linalgOp.getStaticLoopRanges();
+
+  // Assume that it's a gather load when reading _into_:
+  //   * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
+  //   * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32>`.
+  // TODO: Relax these conditions.
+  if ((llvm::count_if(targetShape,
+                      [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
+      targetShape.back() == 1)
+    return VectorMemoryAccessKind::Gather;
+
+  auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();
+
+  // Assume that it's a gather load when reading _from_ a tensor for which the
+  // trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
+  // TODO: Relax this condition.
+  if (inputShape.getShape().back() == 1)
+    return VectorMemoryAccessKind::Gather;
+
+  bool isContiguous = true;
+
+  // Iterate over all indices. Analyze whether the way each index is calculated
+  // is suitable for contiguous load operations (e.g. loop invariant).
+  auto indices = extractOp.getIndices();
+  for (auto [i, indexVal] : llvm::enumerate(indices)) {
+    if (inputShape.getShape()[i] == 1) {
+      // This extractOp index must be a loop-invariant constant.
+      continue;
+    }
+
+    auto extractOpBottomIdx = indices.size() - 1;
+    auto strideOneDim = targetShape.size() - 1;
+    bool strideZero = (i != extractOpBottomIdx);
+    isContiguous &=
+        isContiguousLoadIdx(linalgOp, indexVal, strideOneDim, strideZero);
+  }
+
+  // The calculation of the trailing index must include the loop index. Given
+  // the assumption on the output tensor (which is defined by the iteration
+  // space), only the trailing dim matters.
+  auto extractOpTrailingIdx = indices.back();
+  isContiguous &=
+      isBasedOnIndexOp(linalgOp, extractOpTrailingIdx, targetShape.size() - 1);
+
+  if (isContiguous) {
+    LDBG("Found contiguous load: " << extractOp);
+    return VectorMemoryAccessKind::Contiguous;
+  }
+
+  return VectorMemoryAccessKind::Gather;
+}
+
 /// Helper function to vectorize the tensor.extract operations. Returns
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
@@ -660,15 +794,68 @@
       extractOp.getIndices().size(),
       rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
-  Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
+  VectorMemoryAccessKind memAccessKind =
+      getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
+
+  // 1. Handle gather access
+  if (memAccessKind == VectorMemoryAccessKind::Gather) {
+    // TODO: We need a mechanism to turn gather loads on/off.
+    // if (not vectorizeAsGatherLoads)
+    //   return VectorizationResult{VectorizationStatus::Failure, nullptr};
 
-  // Generate the gather load
-  Operation *gatherOp = rewriter.create<vector::GatherOp>(
-      loc, resultType, extractOp.getTensor(), baseIndices, offset,
-      maskConstantOp, passThruConstantOp);
-  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
+    Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
 
-  return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+    // Generate the gather load
+    Operation *gatherOp = rewriter.create<vector::GatherOp>(
+        loc, resultType, extractOp.getTensor(), baseIndices, offset,
+        maskConstantOp, passThruConstantOp);
+    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
+
+    LDBG("Vectorised as gather load: " << extractOp);
+    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+  }
+
+  // 2. Handle contiguous access.
+  SmallVector<Value> transferReadIdxs;
+  auto resTrailingDim = resultType.getShape().back();
+  auto zero = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
+
+  // Collect indices for `vector.transfer_read`. At this point, the indices
+  // will either be scalars or would have been broadcast to vectors matching
+  // the result type. For indices that are vectors, there are two options:
+  //   * for non-trailing indices, all elements are identical (contiguous
+  //     loads are identified by looking for non-trailing indices that are
+  //     invariant with respect to the corresponding linalg.generic), or
+  //   * for trailing indices, the index vector will contain values with stride
+  //     one, but for `vector.transfer_read` only the first (i.e. 0th) index is
+  //     needed.
+  // This means that:
+  //   * for scalar indices - just re-use them,
+  //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
+  //     (0th) element and use that.
+  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
+    auto idx = bvm.lookup(extractOp.getIndices()[i]);
+    if (idx.getType().isIndex()) {
+      transferReadIdxs.push_back(idx);
+      continue;
+    }
+
+    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
+        loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
+        bvm.lookup(extractOp.getIndices()[i]));
+    transferReadIdxs.push_back(
+        rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
+  }
+
+  // `tensor.extract_element` is always in-bounds, hence the following holds.
+  SmallVector<bool> inBounds(resultType.getRank(), true);
+
+  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+      loc, resultType, extractOp.getTensor(), transferReadIdxs, inBounds);
+
+  LDBG("Vectorised as contiguous load: " << extractOp);
+  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
 }
 
 /// Emit reduction operations if the shapes of the value to reduce is different
@@ -976,10 +1163,8 @@
   return success();
 }
 
-LogicalResult
-mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
-                                            ArrayRef<int64_t> inputVectorSizes,
-                                            bool vectorizeNDExtract) {
+LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(
+    LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes) {
   // tensor with dimension of 0 cannot be vectorized.
   if (llvm::any_of(linalgOp.getStaticShape(),
                    [](int64_t dim) { return dim == 0; }))
@@ -1021,7 +1206,7 @@
             customPreconditions,
             [&](const CustomVectorizationPrecondition &customPrecondition) {
               return succeeded(
-                  customPrecondition(&innerOp, vectorizeNDExtract));
+                  customPrecondition(&innerOp));
             })) {
       continue;
     }
@@ -1079,15 +1264,13 @@
 /// `inputVectorShapes` also allows the vectorization of operations with dynamic
 /// shapes.
 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
-                                      ArrayRef<int64_t> inputVectorSizes,
-                                      bool vectorizeNDExtract) {
+                                      ArrayRef<int64_t> inputVectorSizes) {
   LDBG("Attempting to vectorize:\n" << linalgOp << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
   LLVM_DEBUG(llvm::dbgs() << "\n");
 
-  if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
-                                           vectorizeNDExtract))) {
+  if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes))) {
     LDBG("Vectorization pre-conditions failed\n");
     return failure();
   }
@@ -1106,8 +1289,7 @@
     if (succeeded(convOr)) {
       llvm::append_range(results, (*convOr)->getResults());
     } else {
-      if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
-                                               vectorizeNDExtract)))
+      if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes)))
         return failure();
 
       LDBG("Vectorize generic by broadcasting to the canonical vector shape\n");
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
deleted file mode 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ /dev/null
@@ -1,29 +0,0 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics
-
-// Masked vectorisation of `tensor.extract`:
-//   * requires the `{ vectorize_nd_extract }` attribute,
-//   * has not been implemented yet (hence the attribute is absent).
-// TODO: Implement masked vectorization for `tensor.extract`
-
-#map1 = affine_map<(d0, d1) -> (d0, d1)>
-func.func @extract_masked_vectorize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %c0 = arith.constant 1 : index
-  %c1 = arith.constant 2 : index
-  // expected-error@+1 {{failed to vectorize op}}
-  %2 = linalg.generic {
-    indexing_maps = [#map1],
-    iterator_types = ["parallel", "parallel"]
-  } outs(%arg1 : tensor<?x?xf32>) {
-  ^bb0(%arg3: f32):
-    %7 = tensor.extract %arg0[%c0, %c1] : tensor<?x?xf32>
-    linalg.yield %7 : f32
-  } -> tensor<?x?xf32>
-  return %2 : tensor<?x?xf32>
-}
-
-
-transform.sequence failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
-   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
-   transform.structured.masked_vectorize %0 vector_sizes [3, 3]
- }
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -330,7 +330,7 @@
  ^bb1(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
    %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
-   %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+   %2 = transform.structured.vectorize %1
 }
 
 // -----
@@ -1584,7 +1584,6 @@
     iterator_types = ["parallel", "parallel", "parallel"]
   } outs(%arg2 : tensor<1x1x3xf32>) {
   ^bb0(%arg4: f32):
-    %3 = linalg.index 2 : index
     %7 = tensor.extract %arg0[%c0, %c1] : tensor<3x3xf32>
     linalg.yield %7 : f32
   } -> tensor<1x1x3xf32>
@@ -1607,13 +1606,13 @@
  ^bb1(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
    %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
-   %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+   %2 = transform.structured.vectorize %1
 }
 
 // -----
 
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @vectorize_nd_tensor_extract_idx_from_iteration_index(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+func.func @vectorize_nd_tensor_extract_transfer_read_easy(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
   %1 = linalg.generic {
     indexing_maps = [#map1],
     iterator_types = ["parallel", "parallel", "parallel"]
   } outs(%arg2 : tensor<1x1x3xf32>) {
@@ -1628,23 +1627,55 @@
   return %1 : tensor<1x1x3xf32>
 }
 
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_idx_from_iteration_index
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_easy
 // CHECK-SAME: %[[ARG0:.*]]: tensor<3x3x3xf32>
 // CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
-// CHECK:   %[[INDICES:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
-// CHECK:   %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
-// CHECK:   %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
+// CHECK:   %[[CST:.*]] = arith.constant dense<0> : vector<1x1x3xindex>
+// CHECK:   %[[C0_i32:.*]] = arith.constant 0 : i32
 // CHECK:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK:   %[[B:.*]] = vector.broadcast %[[INDICES]] : vector<3xindex> to vector<1x1x3xindex>
-// CHECK:   %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[B]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32>
-// CHECK:   vector.transfer_write %[[GATHER]]
+// CHECK:   %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[IDX_VEC0:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
+// CHECK:   %[[IDX1:.*]] = vector.extractelement %[[IDX_VEC0]][%[[C0_i32]] : i32] : vector<3xindex>
+// CHECK:   %[[IDX_VEC:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
+// CHECK:   %[[IDX2:.*]] = vector.extractelement %[[IDX_VEC]][%[[C0_i32]] : i32] : vector<3xindex>
+// CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
+// CHECK:   vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
 transform.sequence failures(propagate) {
  ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
-  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+  %2 = transform.structured.vectorize %1
+}
+
+ // -----
+
+func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16xf32>, %arg0: index, %arg2: index, %arg1: index, %arg4: index, %extracted_slice : tensor<1x4xf32>) -> tensor<1x4xf32> {
+  %c79 = arith.constant 79 : index
+  %25 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%extracted_slice : tensor<1x4xf32>) {
+  ^bb0(%out: f32):
+    %26 = linalg.index 0 : index
+    %27 = arith.addi %arg0, %26 : index
+    %28 = arith.addi %27, %arg2 : index
+    %29 = linalg.index 1 : index
+    %30 = arith.addi %arg1, %29 : index
+    %31 = arith.addi %30, %arg4 : index
+    %extracted = tensor.extract %6[%28, %c79, %31] : tensor<45x80x16xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<1x4xf32>
+  return %25 : tensor<1x4xf32>
 }
+// CHECK: vector.transfer_read
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1
+ }
 
 // -----
 
@@ -1691,7 +1722,7 @@
  ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
-  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+  %2 = transform.structured.vectorize %1
 }
 
 // -----
 
@@ -2119,6 +2150,58 @@
 
 // -----
 
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @extract_masked_vectorize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 1 : index
+  %c1 = arith.constant 2 : index
+  %2 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%arg1 : tensor<?x?xf32>) {
+  ^bb0(%arg3: f32):
+    %7 = tensor.extract %arg0[%c0, %c1] : tensor<?x?xf32>
+    linalg.yield %7 : f32
+  } -> tensor<?x?xf32>
+  return %2 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL:   func.func @extract_masked_vectorize(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[VAL_6]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<3x3xi1>
+// CHECK:           %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<3x3xf32> } : vector<3x3xi1> -> vector<3x3xf32>
+// CHECK:           %[[VAL_12:.*]] = arith.constant dense<true> : vector<3x3xi1>
+// CHECK:           %[[VAL_13:.*]] = arith.constant dense<0.000000e+00> : vector<3x3xf32>
+// CHECK:           %[[VAL_14:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_15:.*]] = arith.constant dense<1> : vector<3x3xindex>
+// CHECK:           %[[VAL_16:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_17:.*]] = tensor.dim %[[VAL_0]], %[[VAL_16]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_18:.*]] = vector.broadcast %[[VAL_17]] : index to vector<3x3xindex>
+// CHECK:           %[[VAL_19:.*]] = arith.muli %[[VAL_15]], %[[VAL_18]] : vector<3x3xindex>
+// CHECK:           %[[VAL_20:.*]] = arith.constant dense<2> : vector<3x3xindex>
+// CHECK:           %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_19]] : vector<3x3xindex>
+// CHECK:           %[[VAL_22:.*]] = vector.mask %[[VAL_10]] { vector.gather %[[VAL_0]]{{\[}}%[[VAL_14]], %[[VAL_14]]] {{\[}}%[[VAL_21]]], %[[VAL_12]], %[[VAL_13]] : tensor<?x?xf32>, vector<3x3xindex>, vector<3x3xi1>, vector<3x3xf32> into vector<3x3xf32> } : vector<3x3xi1> -> vector<3x3xf32>
+// CHECK:           %[[VAL_23:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_24:.*]] = vector.mask %[[VAL_10]] { vector.transfer_write %[[VAL_22]], %[[VAL_1]]{{\[}}%[[VAL_23]], %[[VAL_23]]] {in_bounds = [true, true]} : vector<3x3xf32>, tensor<?x?xf32> } : vector<3x3xi1> -> tensor<?x?xf32>
+// CHECK:           return %[[VAL_24]] : tensor<?x?xf32>
+// CHECK:         }
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+   transform.structured.masked_vectorize %0 vector_sizes [3, 3]
+ }
+
+// -----
+
 func.func @do_not_generate_masks(%arg0: tensor<8x32xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<8x32xf32>) -> tensor<8x32xf32> {