diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -627,6 +627,148 @@
   return offset;
 }
 
+enum VectorMemoryAccessKind {
+  // TODO: ScalarBroadcast,
+  Contiguous,
+  Gather
+};
+
+/// Check whether \p val can be used for calculating an index for a contiguous
+/// load operation, i.e. whether \p val:
+///   * is invariant with respect to \p linalgOp, i.e. whether it remains
+///     constant for all iterations, or
+///   * increments with the loop iterator (when \p strideZero is false) or is
+///     not affected by the loop indices (when \p strideZero is true).
+static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, size_t dim,
+                                bool strideZero) {
+  auto *block = linalgOp.getBlock();
+
+  // Bail out if this is a block argument for this linalg.generic Op.
+  // TODO: We could try analysing the corresponding affine map here.
+  if (val.dyn_cast<BlockArgument>())
+    return llvm::all_of(block->getArguments(),
+                        [&val](Value v) { return (v != val); });
+
+  Operation *defOp = val.getDefiningOp();
+  assert(defOp && "This is neither a block argument nor an operation result");
+
+  // Given the assumption on the shape of the target tensor, the index Op is
+  // either:
+  //   * constant (for non-trailing dims), or
+  //   * increments with stride one together with the trailing dimension.
+  // Both cases are fine for contiguous loads.
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
+    return strideZero ? (indexOp.getDim() != dim) : (indexOp.getDim() == dim);
+
+  auto *ancestor = block->findAncestorOpInBlock(*defOp);
+
+  // Values defined outside `linalgOp`.
+  if (!ancestor)
+    return true;
+
+  // Values defined inside `linalgOp`, which are constant.
+  if (dyn_cast<arith::ConstantOp>(ancestor))
+    return true;
+
+  bool result = true;
+  for (auto op : ancestor->getOperands())
+    result &= isContiguousLoadIdx(linalgOp, op, dim, strideZero);
+
+  return result;
+}
+
+/// Check whether the calculation of \p val is based on a linalg.index Op with
+/// the dim attribute matching \p dim.
+static bool isBasedOnIndexOp(LinalgOp &linalgOp, Value &val, size_t dim) {
+  auto *block = linalgOp.getBlock();
+  auto targetShape = linalgOp.getStaticLoopRanges();
+
+  if (val.isa<BlockArgument>())
+    return false;
+
+  Operation *defOp = val.getDefiningOp();
+  assert(defOp && "This is neither a block argument nor an operation result");
+
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
+    return (indexOp.getDim() == dim);
+
+  auto *ancestor = block->findAncestorOpInBlock(*defOp);
+
+  if (!ancestor)
+    return false;
+
+  bool result = false;
+  for (auto op : ancestor->getOperands())
+    result |= isBasedOnIndexOp(linalgOp, op, dim);
+
+  return result;
+}
+
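+// Illustration (the value names and shapes below are made up for this comment
+// and do not appear in the code): for a linalg.generic producing
+// tensor<1x4xf32> whose body contains
+//
+//   %i   = linalg.index 1 : index
+//   %idx = arith.addi %i, %offset : index
+//   %e   = tensor.extract %src[%c0, %idx] : tensor<8x128xf32>
+//
+// the trailing extract index %idx is linalg.index 1 (the trailing loop
+// dimension) plus a loop-invariant offset, so isContiguousLoadIdx and
+// isBasedOnIndexOp both return true and the access can be treated as
+// contiguous. If %idx were instead derived from a block argument of the
+// linalg.generic (e.g. via arith.index_cast), the analysis conservatively
+// falls back to a gather load.
+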
+/// Check whether \p extractOp would be a gather or a contiguous load Op after
+/// vectorising \p linalgOp. Note that it is always safe to use gather load
+/// operations for contiguous loads (albeit slow), but not vice-versa. When in
+/// doubt, bail out and assume that \p extractOp is a gather load.
+static VectorMemoryAccessKind
+getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
+                                    LinalgOp &linalgOp) {
+
+  auto targetShape = linalgOp.getStaticLoopRanges();
+
+  // Assume that it's a gather load when reading _into_:
+  //   * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
+  //   * a 1-D vector with the trailing dim equal to 1, e.g. `tensor<1x4x1xi32>`.
+  // TODO: Relax these conditions.
+  if ((llvm::count_if(targetShape,
+                      [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
+      targetShape.back() == 1)
+    return VectorMemoryAccessKind::Gather;
+
+  auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();
+
+  // Assume that it's a gather load when reading _from_ a tensor for which the
+  // trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
+  // TODO: Relax this condition.
+  if (inputShape.getShape().back() == 1)
+    return VectorMemoryAccessKind::Gather;
+
+  if (!llvm::all_of(linalgOp.getIndexingMapsArray(),
+                    [](AffineMap m) { return m.isMinorIdentity(); })) {
+    return VectorMemoryAccessKind::Gather;
+  }
+
+  bool isContiguous = true;
+
+  // Iterate over all indices. Analyze whether the way each index is calculated
+  // is suitable for contiguous load operations (e.g. loop invariant).
+  auto indices = extractOp.getIndices();
+  for (auto [i, indexVal] : llvm::enumerate(indices)) {
+    if (inputShape.getShape()[i] == 1) {
+      // This extractOp index must be a loop-invariant constant.
+      continue;
+    }
+
+    auto extractOpBottomIdx = indices.size() - 1;
+    auto strideOneDim = targetShape.size() - 1;
+    bool strideZero = (i != extractOpBottomIdx);
+    isContiguous &=
+        isContiguousLoadIdx(linalgOp, indexVal, strideOneDim, strideZero);
+  }
+
+  // The calculation of the trailing index must include the loop index. Given
+  // the assumption on the output tensor (which is defined by the iteration
+  // space), only the trailing dim matters.
+  auto extractOpTrailingIdx = indices.back();
+  isContiguous &=
+      isBasedOnIndexOp(linalgOp, extractOpTrailingIdx, targetShape.size() - 1);
+
+  if (isContiguous) {
+    LDBG("Found contiguous load: " << extractOp);
+    return VectorMemoryAccessKind::Contiguous;
+  }
+
+  return VectorMemoryAccessKind::Gather;
+}
+
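+// For reference, the two lowerings selected by this analysis look roughly as
+// follows (shapes follow the `tensor<3x3x3xf32>` test cases; value names are
+// illustrative):
+//
+//   Gather:
+//     %r = vector.gather %src[%c0, %c0, %c0] [%idxVec], %mask, %passThru
+//            : tensor<3x3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>,
+//              vector<1x1x3xf32> into vector<1x1x3xf32>
+//
+//   Contiguous:
+//     %r = vector.transfer_read %src[%i, %j, %c0], %pad
+//            {in_bounds = [true, true, true]}
+//            : tensor<3x3x3xf32>, vector<1x1x3xf32>
+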
 /// Helper function to vectorize the tensor.extract operations. Returns
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
@@ -658,14 +800,67 @@
       extractOp.getIndices().size(),
       rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
-  Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
+  VectorMemoryAccessKind memAccessKind =
+      getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
+
+  // 1. Handle gather access
+  if (memAccessKind == VectorMemoryAccessKind::Gather) {
+    // TODO: We need a mechanism to turn gather loads on/off.
+    // if (not vectorizeAsGatherLoads)
+    //   return VectorizationResult{VectorizationStatus::Failure, nullptr};
+
+    Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
+
+    // Generate the gather load
+    auto gatherOp = rewriter.create<vector::GatherOp>(
+        loc, resultType, extractOp.getTensor(), baseIndices, offset,
+        maskConstantOp, passThruConstantOp);
+
+    LDBG("Vectorised as gather load: " << extractOp);
+    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+  }
+
+  // 2. Handle contiguous access.
+  SmallVector<Value> transferReadIdxs;
+  auto resTrailingDim = resultType.getShape().back();
+  auto zero = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
+
+  // Collect indices for `vector.transfer_read`. At this point, the indices
+  // will either be scalars or will have been broadcast to vectors matching
+  // the result type. For indices that are vectors, there are two options:
+  //    * for non-trailing indices, all elements are identical (contiguous
+  //      loads are identified by looking for non-trailing indices that are
+  //      invariant with respect to the corresponding linalg.generic), or
+  //    * for trailing indices, the index vector will contain values with
+  //      stride one, but for `vector.transfer_read` only the first (i.e. 0th)
+  //      index is needed.
+  // This means that:
+  //    * for scalar indices - just re-use them,
+  //    * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
+  //      (0th) element and use that.
+  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
+    auto idx = bvm.lookup(extractOp.getIndices()[i]);
+    if (idx.getType().isIndex()) {
+      transferReadIdxs.push_back(idx);
+      continue;
+    }
+
+    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
+        loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
+        bvm.lookup(extractOp.getIndices()[i]));
+    transferReadIdxs.push_back(
+        rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
+  }
+
+  // `tensor.extract` is always in-bounds, hence the following holds.
+  SmallVector<bool> inBounds(resultType.getRank(), true);
 
-  // Generate the gather load
-  auto gatherOp = rewriter.create<vector::GatherOp>(
-      loc, resultType, extractOp.getTensor(), baseIndices, offset,
-      maskConstantOp, passThruConstantOp);
+  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+      loc, resultType, extractOp.getTensor(), transferReadIdxs, inBounds);
 
-  return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+  LDBG("Vectorised as contiguous load: " << extractOp);
+  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
 }
 
 /// Emit reduction operations if the shapes of the value to reduce is different
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1539,7 +1539,6 @@
     iterator_types = ["parallel", "parallel", "parallel"]
   } outs(%arg2 : tensor<1x1x3xf32>) {
   ^bb0(%arg4: f32):
-    %3 = linalg.index 2 : index
     %7 = tensor.extract %arg0[%c0, %c1] : tensor<3x3xf32>
     linalg.yield %7 : f32
   } -> tensor<1x1x3xf32>
@@ -1568,7 +1567,7 @@
 // -----
 
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @vectorize_nd_tensor_extract_idx_from_iteration_index(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+func.func @vectorize_nd_tensor_extract_transfer_read_easy(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
   %1 = linalg.generic {
     indexing_maps = [#map1],
     iterator_types = ["parallel", "parallel", "parallel"]
@@ -1583,16 +1582,19 @@
   return %1 : tensor<1x1x3xf32>
 }
 
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_idx_from_iteration_index
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_easy
 // CHECK-SAME: %[[ARG0:.*]]: tensor<3x3x3xf32>
 // CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
-// CHECK: %[[INDICES:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
-// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
-// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
+// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1x1x3xindex>
+// CHECK: %[[C0_i32:.*]] = arith.constant 0 : i32
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[B:.*]] = vector.broadcast %[[INDICES]] : vector<3xindex> to vector<1x1x3xindex>
-// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[B]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32>
-// CHECK: vector.transfer_write %[[GATHER]]
+// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[IDX_VEC0:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
+// CHECK: %[[IDX1:.*]] = vector.extractelement %[[IDX_VEC0]][%[[C0_i32]] : i32] : vector<3xindex>
+// CHECK: %[[IDX_VEC:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
+// CHECK: %[[IDX2:.*]] = vector.extractelement %[[IDX_VEC]][%[[C0_i32]] : i32] : vector<3xindex>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
+// CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
   %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
 }
 
@@ -1601,6 +1603,35 @@
+// -----
+
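+// The trailing index of the `tensor.extract` below is `linalg.index 1` (the
+// trailing loop dimension) plus loop-invariant values, and the remaining
+// indices are either constant or invariant in the trailing dimension, so the
+// load is vectorised as a contiguous `vector.transfer_read` rather than a
+// gather.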
+func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16xf32>, %arg0: index, %arg2: index, %arg1: index, %arg4: index, %extracted_slice : tensor<1x4xf32>) -> tensor<1x4xf32> {
+  %c79 = arith.constant 79 : index
+  %25 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%extracted_slice : tensor<1x4xf32>) {
+  ^bb0(%out: f32):
+    %26 = linalg.index 0 : index
+    %27 = arith.addi %arg0, %26 : index
+    %28 = arith.addi %27, %arg2 : index
+    %29 = linalg.index 1 : index
+    %30 = arith.addi %arg1, %29 : index
+    %31 = arith.addi %30, %arg4 : index
+    %extracted = tensor.extract %6[%28, %c79, %31] : tensor<45x80x16xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<1x4xf32>
+  return %25 : tensor<1x4xf32>
+}
+// CHECK: vector.transfer_read
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+}
+
 // -----
 
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
diff --git a/mlir/test/Dialect/Linalg/vectorize_tensor_extract.mlir b/mlir/test/Dialect/Linalg/vectorize_tensor_extract.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorize_tensor_extract.mlir
@@ -0,0 +1,131 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
+
+#map0 = affine_map<(d0, d1, d2) -> (d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
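+// The index of the `tensor.extract` below is computed from values loaded from
+// %arg1 (a block argument of the linalg.generic body), so the analysis cannot
+// prove the access contiguous and a `vector.gather` is generated.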
+func.func @vector_gather(%arg0: tensor<3x3x3xf32>, %arg1: tensor<3xi32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+  %2 = linalg.generic {
+    indexing_maps = [#map0, #map1],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } ins(%arg1 : tensor<3xi32>) outs(%arg2 : tensor<1x1x3xf32>) {
+  ^bb0(%arg3: i32, %arg4: f32):
+    %c0 = arith.constant 0 : index
+    %3 = arith.index_cast %arg3 : i32 to index
+    %7 = tensor.extract %arg0[%c0, %c0, %3] : tensor<3x3x3xf32>
+    linalg.yield %7 : f32
+  } -> tensor<1x1x3xf32>
+  return %2 : tensor<1x1x3xf32>
+}
+// CHECK: %[[GATHER:.*]] = vector.gather
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1: (!pdl.operation) -> !pdl.operation
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+}
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @gather(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+  %2 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%arg2 : tensor<1x1x3xf32>) {
+  ^bb0(%arg4: f32):
+    %c0 = arith.constant 1 : index
+    %c1 = arith.constant 2 : index
+    %7 = tensor.extract %arg0[%c0, %c1] : tensor<3x3xf32>
+    linalg.yield %7 : f32
+  } -> tensor<1x1x3xf32>
+  return %2 : tensor<1x1x3xf32>
+}
+// CHECK: %[[GATHER:.*]] = vector.gather
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1: (!pdl.operation) -> !pdl.operation
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+}
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @transfer_read(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+  %1 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%arg2 : tensor<1x1x3xf32>) {
+  ^bb0(%arg4: f32):
+    %2 = linalg.index 0 : index
+    %3 = linalg.index 1 : index
+    %4 = linalg.index 2 : index
+    %5 = tensor.extract %arg0[%2, %3, %4] : tensor<3x3x3xf32>
+    linalg.yield %5 : f32
+  } -> tensor<1x1x3xf32>
+  return %1 : tensor<1x1x3xf32>
+}
+// CHECK: %[[V2:.+]] = vector.transfer_read
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1: (!pdl.operation) -> !pdl.operation
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+}
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @gather(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+  %c0 = arith.constant 1 : index
+  %c1 = arith.constant 2 : index
+  %2 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%arg2 : tensor<1x1x3xf32>) {
+  ^bb0(%arg4: f32):
+    %3 = linalg.index 2 : index
+    %7 = tensor.extract %arg0[%c0, %c1] : tensor<3x3xf32>
+    linalg.yield %7 : f32
+  } -> tensor<1x1x3xf32>
+  return %2 : tensor<1x1x3xf32>
+}
+// CHECK: %[[GATHER:.*]] = vector.gather
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1: (!pdl.operation) -> !pdl.operation
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+}
+
+// -----
+
+func.func @transfer_read(%6: tensor<45x80x16xf32>, %arg0: index, %arg2: index, %arg1: index, %arg4: index, %extracted_slice : tensor<1x4xf32>) -> tensor<1x4xf32> {
+  %c79 = arith.constant 79 : index
+  %25 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%extracted_slice : tensor<1x4xf32>) {
+  ^bb0(%out: f32):
+    %26 = linalg.index 0 : index
+    %27 = arith.addi %arg0, %26 : index
+    %28 = arith.addi %27, %arg2 : index
+    %29 = linalg.index 1 : index
+    %30 = arith.addi %arg1, %29 : index
+    %31 = arith.addi %30, %arg4 : index
+    %extracted = tensor.extract %6[%28, %c79, %31] : tensor<45x80x16xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<1x4xf32>
+  return %25 : tensor<1x4xf32>
+}
+// CHECK: %[[V2:.+]] = vector.transfer_read
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1: (!pdl.operation) -> !pdl.operation
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+}