diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -611,11 +611,11 @@
   const size_t numIndices = extractOp.getIndices().size();
   for (size_t i = 1; i < numIndices; i++) {
+    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
+
     auto dimSize = broadcastIfNeeded(
         rewriter,
-        rewriter.create<arith::ConstantIndexOp>(
-            loc,
-            extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)),
+        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
         indexVecType.getShape());
 
     offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
@@ -630,6 +630,159 @@
   return offset;
 }
 
+enum VectorMemoryAccessKind {
+  // TODO: ScalarBroadcast,
+  Contiguous,
+  Gather
+};
+
+/// Check whether \p val can be used for calculating an index for a contiguous
+/// load operation. This means that \p val should either:
+///   * be invariant with respect to \p linalgOp, or
+///   * increment by 1 with every loop iteration (when \p shouldBeConstant is
+///     false).
+/// Parameters \p trailingLoopDim and \p shouldBeConstant are used to analyze
+/// `linalg.index` ops.
+static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
+                                size_t trailingLoopDim, bool shouldBeConstant) {
+  auto *block = linalgOp.getBlock();
+
+  // Bail out if this is a block argument for this linalg.generic Op.
+  // TODO: We could try analysing the corresponding affine map here.
+  if (val.dyn_cast<BlockArgument>())
+    return llvm::all_of(block->getArguments(),
+                        [&val](Value v) { return (v != val); });
+
+  Operation *defOp = val.getDefiningOp();
+  assert(defOp && "This is neither a block argument nor an operation result");
+
+  // We know that we are reading into a 1-D tensor like this:
+  // `tensor<1x1x4xi32>`. Given this assumption, the following Op:
+  //   * `%idx = linalg.index dim : index`,
+  // will either:
+  //   1. produce a constant when `dim` _is not_ the trailing loop dim, or
+  //   2. increment with stride one when `dim` _is_ the trailing loop dim.
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
+    return shouldBeConstant ? (indexOp.getDim() != trailingLoopDim)
+                            : (indexOp.getDim() == trailingLoopDim);
+
+  auto *ancestor = block->findAncestorOpInBlock(*defOp);
+
+  // Values defined outside `linalgOp`.
+  if (!ancestor)
+    return true;
+
+  // Values defined inside `linalgOp`, which are constant.
+  if (dyn_cast<arith::ConstantOp>(ancestor))
+    return true;
+
+  // Conservatively reject Ops that could lead to non-contiguous accesses.
+  if (!isa<arith::AddIOp, arith::SubIOp, linalg::IndexOp>(ancestor))
+    return false;
+
+  bool result = true;
+  for (auto op : ancestor->getOperands())
+    result &=
+        isContiguousLoadIdx(linalgOp, op, trailingLoopDim, shouldBeConstant);
+
+  return result;
+}
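+
+// For example (IR shown for illustration only, it is a simplified version of
+// the @vectorize_nd_tensor_extract_transfer_read_complex test added in this
+// patch): with `trailingLoopDim == 1` and `shouldBeConstant == false`, `%28`
+// below satisfies this check - it is the sum of `%arg0`, which is defined
+// outside the linalg Op (and is hence loop invariant), and `linalg.index 1`,
+// so it increments by 1 with every iteration of the trailing loop:
+//
+//   %26 = linalg.index 1 : index
+//   %28 = arith.addi %arg0, %26 : index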
+
+/// Check whether the calculation of \p val is based on a linalg.index Op with
+/// the dim attribute matching \p dim.
+static bool isBasedOnIndexOp(LinalgOp &linalgOp, Value &val, size_t dim) {
+  auto *block = linalgOp.getBlock();
+  auto targetShape = linalgOp.getStaticLoopRanges();
+
+  if (val.isa<BlockArgument>())
+    return false;
+
+  Operation *defOp = val.getDefiningOp();
+  assert(defOp && "This is neither a block argument nor an operation result");
+
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
+    return (indexOp.getDim() == dim);
+
+  auto *ancestor = block->findAncestorOpInBlock(*defOp);
+
+  if (!ancestor)
+    return false;
+
+  bool result = false;
+  for (auto op : ancestor->getOperands())
+    result |= isBasedOnIndexOp(linalgOp, op, dim);
+
+  return result;
+}
+
+/// Check whether \p extractOp would be a gather or a contiguous load Op after
+/// vectorising \p linalgOp. Note that it is always safe to use gather load
+/// operations for contiguous loads (albeit slow), but not vice-versa. When in
+/// doubt, bail out and assume that \p extractOp is a gather load.
+static VectorMemoryAccessKind
+getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
+                                    LinalgOp &linalgOp) {
+
+  auto targetShape = linalgOp.getStaticLoopRanges();
+
+  // Assume that it's a gather load when reading _into_:
+  //   * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
+  //   * a 1-D vector with the trailing dim equal to 1, e.g. `tensor<1x4x1xi32>`.
+  // TODO: Relax these conditions.
+  if ((llvm::count_if(targetShape,
+                      [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
+      targetShape.back() == 1)
+    return VectorMemoryAccessKind::Gather;
+
+  auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();
+
+  // Assume that it's a gather load when reading _from_ a tensor for which the
+  // trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
+  // TODO: Relax this condition.
+  if (inputShape.getShape().back() == 1)
+    return VectorMemoryAccessKind::Gather;
+
+  // The trailing loop dim is needed when analyzing ops like:
+  //   * `%idx = linalg.index dim : index`.
+  auto trailingLoopDim = targetShape.size() - 1;
+
+  bool isContiguous = true;
+
+  // Iterate over all indices. Analyze the way each index is calculated and
+  // decide whether it is suitable for a contiguous load (e.g. loop invariant).
+  auto indices = extractOp.getIndices();
+  for (auto [i, indexVal] : llvm::enumerate(indices)) {
+    if (inputShape.getShape()[i] == 1) {
+      // This index will always be equal to 0, so it is a loop-invariant
+      // constant.
+      continue;
+    }
+
+    // Should this index be loop invariant?
+    //   * _no_ if this is the trailing index,
+    //   * _yes_ otherwise.
+    auto extractOpBottomIdx = indices.size() - 1;
+    bool loopInvariantIndex = (i != extractOpBottomIdx);
+
+    isContiguous &= isContiguousLoadIdx(linalgOp, indexVal, trailingLoopDim,
+                                        loopInvariantIndex);
+  }
+
+  // The trailing index in the extract Op must increment with every iteration,
+  // which means that it must be based on a loop index. Given the assumption
+  // on the output tensor, only the trailing loop index is not constant, so
+  // that's what we need to check against.
+  auto extractOpTrailingIdx = indices.back();
+  isContiguous &=
+      isBasedOnIndexOp(linalgOp, extractOpTrailingIdx, trailingLoopDim);
+
+  if (isContiguous) {
+    LDBG("Found contiguous load: " << extractOp);
+    return VectorMemoryAccessKind::Contiguous;
+  }
+
+  return VectorMemoryAccessKind::Gather;
+}
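+
+// For example (IR shown for illustration only, it mirrors the
+// @vectorize_nd_tensor_extract_contiguous_and_gather test added in this
+// patch): in the body below, `tensor.extract %src[%idx]` is a contiguous
+// load (its only index is the trailing loop index), whereas
+// `tensor.extract %data[%offset]` is a gather load (its index is loaded from
+// another tensor, so it can be arbitrary):
+//
+//   %idx = linalg.index 0 : index
+//   %e = tensor.extract %src[%idx] : tensor<5xi32>
+//   %offset = arith.index_cast %e : i32 to index
+//   %v = tensor.extract %data[%offset] : tensor<6xf32>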
+
 /// Helper function to vectorize the tensor.extract operations. Returns
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
@@ -660,15 +813,64 @@
       extractOp.getIndices().size(),
       rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
-  Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
+  VectorMemoryAccessKind memAccessKind =
+      getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
+
+  // 1. Handle gather access
+  if (memAccessKind == VectorMemoryAccessKind::Gather) {
+    Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
+
+    // Generate the gather load
+    Operation *gatherOp = rewriter.create<vector::GatherOp>(
+        loc, resultType, extractOp.getTensor(), baseIndices, offset,
+        maskConstantOp, passThruConstantOp);
+    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
+
+    LDBG("Vectorised as gather load: " << extractOp);
+    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+  }
+
+  // 2. Handle contiguous access.
+  SmallVector<Value> transferReadIdxs;
+  auto resTrailingDim = resultType.getShape().back();
+  auto zero = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
+
+  // Collect indices for `vector.transfer_read`. At this point, the indices
+  // will either be scalars or would have been broadcast to vectors matching
+  // the result type. For indices that are vectors, there are two options:
+  //   * for non-trailing indices, all elements are identical (contiguous
+  //     loads are identified by looking for non-trailing indices that are
+  //     invariant with respect to the corresponding linalg.generic), or
+  //   * for trailing indices, the index vector will contain values with
+  //     stride one, but for `vector.transfer_read` only the first (i.e. 0th)
+  //     index is needed.
+  // This means that:
+  //   * for scalar indices - just re-use them,
+  //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
+  //     (0th) element and use that.
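+  // For example (values shown for illustration only): with a trailing index
+  // vector holding `[%i, %i + 1, %i + 2, %i + 3]`, only the 0th element,
+  // `%i`, is passed on to `vector.transfer_read` - the op itself reads the
+  // remaining elements contiguously starting from that base index.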
+  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
+    auto idx = bvm.lookup(extractOp.getIndices()[i]);
+    if (idx.getType().isIndex()) {
+      transferReadIdxs.push_back(idx);
+      continue;
+    }
+
+    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
+        loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
+        bvm.lookup(extractOp.getIndices()[i]));
+    transferReadIdxs.push_back(
+        rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
+  }
+
+  // `tensor.extract` is always in-bounds, hence the following holds.
+  SmallVector<bool> inBounds(resultType.getRank(), true);
 
-  // Generate the gather load
-  Operation *gatherOp = rewriter.create<vector::GatherOp>(
-      loc, resultType, extractOp.getTensor(), baseIndices, offset,
-      maskConstantOp, passThruConstantOp);
-  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
+  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+      loc, resultType, extractOp.getTensor(), transferReadIdxs, inBounds);
 
-  return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+  LDBG("Vectorised as contiguous load: " << extractOp);
+  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
 }
 
 /// Emit reduction operations if the shapes of the value to reduce is different
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
deleted file mode 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ /dev/null
@@ -1,29 +0,0 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics
-
-// Masked vectorisation of `tensor.extract`:
-//   * requires the `{ vectorize_nd_extract }` attribute,
-//   * has not been implemented yet (hence the attribute is absent).
-// TODO: Implement masked vectorization for `tensor.extract`
-
-#map1 = affine_map<(d0, d1) -> (d0, d1)>
-func.func @extract_masked_vectorize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %c0 = arith.constant 1 : index
-  %c1 = arith.constant 2 : index
-  // expected-error@+1 {{failed to vectorize op}}
-  %2 = linalg.generic {
-    indexing_maps = [#map1],
-    iterator_types = ["parallel", "parallel"]
-  } outs(%arg1 : tensor<?x?xf32>) {
-  ^bb0(%arg3: f32):
-    %7 = tensor.extract %arg0[%c0, %c1] : tensor<?x?xf32>
-    linalg.yield %7 : f32
-  } -> tensor<?x?xf32>
-  return %2 : tensor<?x?xf32>
-}
-
-
-transform.sequence failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
-  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
-  transform.structured.masked_vectorize %0 vector_sizes [3, 3]
- }
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1584,7 +1584,6 @@
     iterator_types = ["parallel", "parallel", "parallel"]
   } outs(%arg2 : tensor<1x1x3xf32>) {
   ^bb0(%arg4: f32):
-    %3 = linalg.index 2 : index
     %7 = tensor.extract %arg0[%c0, %c1] : tensor<3x3xf32>
     linalg.yield %7 : f32
   } -> tensor<1x1x3xf32>
@@ -1613,7 +1612,7 @@
 // -----
 
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @vectorize_nd_tensor_extract_idx_from_iteration_index(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
   %1 = linalg.generic {
     indexing_maps = [#map1],
     iterator_types = ["parallel", "parallel", "parallel"]
@@ -1628,16 +1627,19 @@
   return %1 : tensor<1x1x3xf32>
 }
 
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_idx_from_iteration_index
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic
 // CHECK-SAME: %[[ARG0:.*]]: tensor<3x3x3xf32>
 // CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
-// CHECK: %[[INDICES:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
-// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
-// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
+// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1x1x3xindex>
+// CHECK: %[[C0_i32:.*]] = arith.constant 0 : i32
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[B:.*]] = vector.broadcast %[[INDICES]] : vector<3xindex> to vector<1x1x3xindex>
-// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[B]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32>
-// CHECK: vector.transfer_write %[[GATHER]]
+// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[IDX_VEC0:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
+// CHECK: %[[IDX1:.*]] = vector.extractelement %[[IDX_VEC0]][%[[C0_i32]] : i32] : vector<3xindex>
+// CHECK: %[[IDX_VEC:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
+// CHECK: %[[IDX2:.*]] = vector.extractelement %[[IDX_VEC]][%[[C0_i32]] : i32] : vector<3xindex>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
+// CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
 transform.sequence failures(propagate) {
  ^bb1(%arg1: !pdl.operation):
@@ -1646,6 +1648,56 @@
   %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
 }
 
+// -----
+
+func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16xf32>, %arg0: index, %arg2: index, %arg1: index, %arg4: index, %extracted_slice : tensor<1x4xf32>) -> tensor<1x4xf32> {
+  %c79 = arith.constant 79 : index
+  %25 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%extracted_slice : tensor<1x4xf32>) {
+  ^bb0(%out: f32):
+    %26 = linalg.index 0 : index
+    %27 = arith.addi %arg0, %26 : index
+    %28 = arith.addi %27, %arg2 : index
+    %29 = linalg.index 1 : index
+    %30 = arith.addi %arg1, %29 : index
+    %31 = arith.addi %30, %arg4 : index
+    %extracted = tensor.extract %6[%28, %c79, %31] : tensor<45x80x16xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<1x4xf32>
+  return %25 : tensor<1x4xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_complex
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<45x80x16xf32>,
+// CHECK-SAME: {{.*}}: index,
+// CHECK-SAME: %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
+// CHECK: %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: %[[VAL_7:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_10:.*]] = arith.constant 79 : index
+// CHECK: %[[VAL_11:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
+// CHECK: %[[VAL_12:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
+// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : vector<1x4xindex>
+// CHECK: %[[VAL_14:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
+// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_6]] : vector<4xindex>
+// CHECK: %[[VAL_16:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
+// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : vector<4xindex>
+// CHECK: %[[VAL_18:.*]] = vector.shape_cast %[[VAL_13]] : vector<1x4xindex> to vector<4xindex>
+// CHECK: %[[VAL_19:.*]] = vector.extractelement %[[VAL_18]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
+// CHECK: %[[VAL_20:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
+// CHECK: %[[VAL_21:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_19]], %[[VAL_10]], %[[VAL_20]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
+// CHECK: %[[VAL_22:.*]] = vector.transfer_write %[[VAL_21]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+ }
+
 // -----
 
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
@@ -1693,6 +1745,51 @@
   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
   %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
 }
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+func.func @vectorize_nd_tensor_extract_contiguous_and_gather(%arg0: tensor<6xf32>, %arg1: tensor<5xi32>) -> tensor<5xf32> {
+  %c5 = arith.constant 5 : index
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<5xf32>
+  %1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%0 : tensor<5xf32>) {
+  ^bb0(%out: f32):
+    %2 = linalg.index 0 : index
+    %extracted = tensor.extract %arg1[%2] : tensor<5xi32>
+    %3 = arith.index_cast %extracted : i32 to index
+    %4 = arith.maxsi %3, %c0 : index
+    %5 = arith.minsi %4, %c5 : index
+    %extracted_0 = tensor.extract %arg0[%5] : tensor<6xf32>
+    linalg.yield %extracted_0 : f32
+  } -> tensor<5xf32>
+  return %1 : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_contiguous_and_gather(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<6xf32>
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<5xi32>
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_4:.*]] = arith.constant dense<0> : vector<5xindex>
+// CHECK: %[[VAL_5:.*]] = arith.constant dense<5> : vector<5xindex>
+// CHECK: %[[VAL_6:.*]] = arith.constant dense<true> : vector<5xi1>
+// CHECK: %[[VAL_7:.*]] = arith.constant dense<0.000000e+00> : vector<5xf32>
+// CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<5xf32>
+// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_2]]], %[[VAL_3]] {in_bounds = [true]} : tensor<5xi32>, vector<5xi32>
+// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : vector<5xi32> to vector<5xindex>
+// CHECK: %[[VAL_11:.*]] = arith.maxsi %[[VAL_10]], %[[VAL_4]] : vector<5xindex>
+// CHECK: %[[VAL_12:.*]] = arith.minsi %[[VAL_11]], %[[VAL_5]] : vector<5xindex>
+// CHECK: %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_2]]] {{\[}}%[[VAL_12]]], %[[VAL_6]], %[[VAL_7]] : tensor<6xf32>, vector<5xindex>, vector<5xi1>, vector<5xf32> into vector<5xf32>
+// CHECK: %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_8]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<5xf32>, tensor<5xf32>
+// CHECK: return %[[VAL_14]] : tensor<5xf32>
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+ }
+
 // -----
 
@@ -2119,6 +2216,56 @@
 
 // -----
 
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @extract_masked_vectorize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 1 : index
+  %c1 = arith.constant 2 : index
+  %2 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%arg1 : tensor<?x?xf32>) {
+  ^bb0(%arg3: f32):
+    %7 = tensor.extract %arg0[%c0, %c1] : tensor<?x?xf32>
+    linalg.yield %7 : f32
+  } -> tensor<?x?xf32>
+  return %2 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @extract_masked_vectorize(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x?xf32>
+// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[VAL_6]] : tensor<?x?xf32>
+// CHECK: %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<3x3xi1>
+// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<3x3xf32> } : vector<3x3xi1> -> vector<3x3xf32>
+// CHECK: %[[VAL_12:.*]] = arith.constant dense<true> : vector<3x3xi1>
+// CHECK: %[[VAL_13:.*]] = arith.constant dense<0.000000e+00> : vector<3x3xf32>
+// CHECK: %[[VAL_14:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_15:.*]] = arith.constant dense<1> : vector<3x3xindex>
+// CHECK: %[[VAL_16:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_17:.*]] = tensor.dim %[[VAL_0]], %[[VAL_16]] : tensor<?x?xf32>
+// CHECK: %[[VAL_18:.*]] = vector.broadcast %[[VAL_17]] : index to vector<3x3xindex>
+// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_15]], %[[VAL_18]] : vector<3x3xindex>
+// CHECK: %[[VAL_20:.*]] = arith.constant dense<2> : vector<3x3xindex>
+// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_19]] : vector<3x3xindex>
+// CHECK: %[[VAL_22:.*]] = vector.mask %[[VAL_10]] { vector.gather %[[VAL_0]]{{\[}}%[[VAL_14]], %[[VAL_14]]] {{\[}}%[[VAL_21]]], %[[VAL_12]], %[[VAL_13]] : tensor<?x?xf32>, vector<3x3xindex>, vector<3x3xi1>, vector<3x3xf32> into vector<3x3xf32> } : vector<3x3xi1> -> vector<3x3xf32>
+// CHECK: %[[VAL_23:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_24:.*]] = vector.mask %[[VAL_10]] { vector.transfer_write %[[VAL_22]], %[[VAL_1]]{{\[}}%[[VAL_23]], %[[VAL_23]]] {in_bounds = [true, true]} : vector<3x3xf32>, tensor<?x?xf32> } : vector<3x3xi1> -> tensor<?x?xf32>
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  transform.structured.masked_vectorize %0 vector_sizes [3, 3] { vectorize_nd_extract }
+ }
+
+// -----
+
 func.func @do_not_generate_masks(%arg0: tensor<8x32xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<8x32xf32>) -> tensor<8x32xf32> {