diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -729,11 +729,7 @@
   return offset;
 }
 
-enum VectorMemoryAccessKind {
-  // TODO: ScalarBroadcast,
-  Contiguous,
-  Gather
-};
+enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 
 /// Checks whether /p val can be used for calculating a loop invariant index.
 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
@@ -872,36 +868,57 @@
   if (inputShape.getShape().back() == 1)
     return VectorMemoryAccessKind::Gather;
 
-  bool isContiguous = true;
+  bool leadingIdxsLoopInvariant = true;
 
-  // 3a. Analyze the leading indices of `extractOp`.
+  // 3. Analyze the leading indices of `extractOp`.
   // Look at the way each index is calculated and decide whether it is suitable
   // for a contiguous load, i.e. whether it's loop invariant.
   auto indices = extractOp.getIndices();
-  auto leadIndices = ValueRange(indices.drop_back(1));
+  auto leadIndices = indices.drop_back(1);
 
   for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
     if (inputShape.getShape()[i] == 1)
       continue;
 
-    isContiguous &= isLoopInvariantIdx(linalgOp, indexVal);
+    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
+  }
+
+  if (!leadingIdxsLoopInvariant) {
+    LDBG("Found gather load: " << extractOp);
+    return VectorMemoryAccessKind::Gather;
   }
 
-  // 3b. Analyze the trailing index for `extractOp`.
+  // 4. Analyze the trailing index for `extractOp`.
+  // At this point we know that the leading indices are loop invariant. This
+  // means that this is potentially a scalar or a contiguous load. We can
+  // decide based on the trailing idx.
   auto extractOpTrailingIdx = indices.back();
-  // For contiguous loads, the trailing `extractOp` index should increment with
-  // every loop iteration. This effectively means that it must be based on the
-  // trailing loop index. This is what the following bool captures.
+
+  // 4a. Scalar broadcast load
+  // If the trailing index is loop invariant then this is a scalar load.
+  if (leadingIdxsLoopInvariant &&
+      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
+    LDBG("Found scalar broadcast load: " << extractOp);
+
+    return VectorMemoryAccessKind::ScalarBroadcast;
+  }
+
+  // 4b. Contiguous loads
+  // The trailing `extractOp` index should increment with every loop iteration.
+  // This effectively means that it must be based on the trailing loop index.
+  // This is what the following bool captures.
   bool foundIndexOp = false;
-  isContiguous &=
+  bool isContiguousLoad =
       isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
-  isContiguous &= foundIndexOp;
+  isContiguousLoad &= foundIndexOp;
 
-  if (isContiguous) {
+  if (isContiguousLoad) {
     LDBG("Found contigous load: " << extractOp);
     return VectorMemoryAccessKind::Contiguous;
   }
 
+  // 5. Fallback case - gather load.
+  LDBG("Found gather load: " << extractOp);
   return VectorMemoryAccessKind::Gather;
 }
 
@@ -948,16 +965,14 @@
         maskConstantOp, passThruConstantOp);
     gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
 
-    LDBG("Vectorised as gather load: " << extractOp);
+    LDBG("Vectorised as gather load: " << extractOp << "\n");
     return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
   }
 
-  // 2. Handle contiguous access.
-  LDBG("Vectorised as contiguous load: " << extractOp);
-  SmallVector<Value> transferReadIdxs;
-  auto resTrailingDim = resultType.getShape().back();
-  auto zero = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
+  // 2. Handle:
+  //  a. scalar loads + broadcast,
+  //  b. contiguous loads.
+  // Both cases use vector.transfer_read.
 
   // Collect indices for `vector.transfer_read`. At this point, the indices will
   // either be scalars or would have been broadcast to vectors matching the
@@ -972,6 +987,10 @@
   // * for scalar indices - just re-use it,
   // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
   //    (0th) element and use that.
+  SmallVector<Value> transferReadIdxs;
+  auto resTrailingDim = resultType.getShape().back();
+  auto zero = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
   for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
     auto idx = bvm.lookup(extractOp.getIndices()[i]);
     if (idx.getType().isIndex()) {
@@ -988,10 +1007,24 @@
 
   // `tensor.extract_element` is always in-bounds, hence the following holds.
   auto dstRank = resultType.getRank();
+  auto srcRank = extractOp.getTensor().getType().getRank();
   SmallVector<bool> inBounds(dstRank, true);
 
-  // Create a permutation map for transfer_read Op.
-  auto srcRank = extractOp.getTensor().getType().getRank();
+  // 2a. Handle scalar broadcast access.
+  if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
+    MLIRContext *ctx = rewriter.getContext();
+    SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
+    auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);
+
+    auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+        loc, resultType, extractOp.getTensor(), transferReadIdxs,
+        permutationMap, inBounds);
+
+    LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n");
+    return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
+  }
+
+  // 2b. Handle contiguous access.
   auto permutationMap = AffineMap::getMinorIdentityMap(
       srcRank, std::min(dstRank, srcRank), rewriter.getContext());
 
@@ -1012,6 +1045,8 @@
   auto transferReadOp = rewriter.create<vector::TransferReadOp>(
       loc, resultType, extractOp.getTensor(), transferReadIdxs,
       permutationMap, inBounds);
+
+  LDBG("Vectorised as contiguous load: " << extractOp);
   return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
 }
 
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -51,17 +51,17 @@
   return %2 : tensor<1x1x3xf32>
 }
 
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_constant_idx
-// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3xf32>
-// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
-// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
-// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// Magic "5" below comes from (1 * 3 + 2) (1: index into dim 1, 2: index into dim 2)
-// CHECK: %[[IDX:.*]] = arith.constant dense<5> : vector<1x1x3xindex>
-// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[IDX]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32>
-// CHECK: vector.transfer_write %[[GATHER]]
-// CHECK: }
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_constant_idx(
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x3xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[ARG_0]]{{\[}}%[[C1]], %[[C2]]] : tensor<3x3xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : f32 to vector<1x1x3xf32>
+// CHECK: %[[VAL_7:.*]] = vector.transfer_write %[[BCAST]], %[[ARG_1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
+// CHECK: return %[[VAL_7]] : tensor<1x1x3xf32>
+// CHECK: }
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
@@ -316,44 +316,43 @@
   return %1 : tensor<1x1x4xf32>
 }
 
-// First `tensor.extract` is a loop invariant scalar load. This way, the
-// following `tensor.extract` Op becomes a contiguous load (all other Ops used
-// for address calculation also satisfy the required conditions).
-// TODO: Don't use vector.gather for the first tensor.extract.
-
 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_tensor_extract(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x20xi32>,
 // CHECK-SAME: %[[VAL_1:.*]]: tensor<257x24xf32>,
-// CHECK-SAME: -> tensor<1x1x4xf32> {
+// CHECK-SAME: %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index) -> tensor<1x1x4xf32> {
 // CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<0> : vector<1x1x4xindex>
 // CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
-// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_11:.*]] = arith.constant dense<256> : vector<1x1x4xindex>
-// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_14:.*]] = tensor.empty() : tensor<1x1x4xf32>
-// CHECK: %[[VAL_15:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex>
-// CHECK: %[[VAL_16:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex>
-// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : vector<1x1x4xindex>
-// CHECK: %[[VAL_18:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex>
-// CHECK: %[[VAL_19:.*]] = vector.broadcast %[[VAL_7]] : vector<4xindex> to vector<1x1x4xindex>
+// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[VAL_9:.*]] = arith.constant dense<256> : vector<1x1x4xindex>
+// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_12:.*]] = tensor.empty() : tensor<1x1x4xf32>
+// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_2]] : index to vector<1x1x4xindex>
+// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_4]] : index to vector<1x1x4xindex>
+// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : vector<1x1x4xindex>
+// CHECK: %[[VAL_16:.*]] = vector.broadcast %[[VAL_3]] : index to vector<1x1x4xindex>
+// CHECK: %[[VAL_17:.*]] = vector.broadcast %[[VAL_7]] : vector<4xindex> to vector<1x1x4xindex>
+// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_17]] : vector<1x1x4xindex>
+// CHECK: %[[VAL_19:.*]] = vector.broadcast %[[VAL_5]] : index to vector<1x1x4xindex>
 // CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_18]], %[[VAL_19]] : vector<1x1x4xindex>
-// CHECK: %[[VAL_21:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex>
-// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : vector<1x1x4xindex>
-// CHECK: %[[VAL_23:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_10]], %[[VAL_10]]] {{\[}}%[[VAL_17]]], %[[VAL_8]], %[[VAL_9]] : tensor<1x20xi32>, vector<1x1x4xindex>, vector<1x1x4xi1>, vector<1x1x4xi32> into vector<1x1x4xi32>
-// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : vector<1x1x4xi32> to vector<1x1x4xindex>
-// CHECK: %[[VAL_25:.*]] = arith.maxsi %[[VAL_24]], %[[VAL_6]] : vector<1x1x4xindex>
-// CHECK: %[[VAL_26:.*]] = arith.minsi %[[VAL_25]], %[[VAL_11]] : vector<1x1x4xindex>
-// CHECK: %[[VAL_27:.*]] = vector.shape_cast %[[VAL_26]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK: %[[VAL_28:.*]] = vector.extractelement %[[VAL_27]]{{\[}}%[[VAL_12]] : i32] : vector<4xindex>
-// CHECK: %[[VAL_29:.*]] = vector.shape_cast %[[VAL_22]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK: %[[VAL_30:.*]] = vector.extractelement %[[VAL_29]]{{\[}}%[[VAL_12]] : i32] : vector<4xindex>
-// CHECK: %[[VAL_31:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_28]], %[[VAL_30]]], %[[VAL_13]] {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
-// CHECK: %[[VAL_32:.*]] = vector.broadcast %[[VAL_31]] : vector<1x4xf32> to vector<1x1x4xf32>
-// CHECK: %[[VAL_33:.*]] = vector.transfer_write %[[VAL_32]], %[[VAL_14]]{{\[}}%[[VAL_10]], %[[VAL_10]], %[[VAL_10]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, tensor<1x1x4xf32>
-// CHECK: return %[[VAL_33]] : tensor<1x1x4xf32>
+// CHECK: %[[VAL_21:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK: %[[VAL_22:.*]] = vector.extractelement %[[VAL_21]][%[[VAL_8]] : i32] : vector<4xindex>
+// First `tensor.extract` from the generic Op - loop invariant scalar load.
+// CHECK: %[[VAL_23:.*]] = tensor.extract %[[VAL_0]][%[[VAL_11]], %[[VAL_22]]] : tensor<1x20xi32>
+// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : i32 to index
+// CHECK: %[[VAL_25:.*]] = vector.broadcast %[[VAL_24]] : index to vector<1x1x4xindex>
+// CHECK: %[[VAL_26:.*]] = arith.maxsi %[[VAL_25]], %[[VAL_6]] : vector<1x1x4xindex>
+// CHECK: %[[VAL_27:.*]] = arith.minsi %[[VAL_26]], %[[VAL_9]] : vector<1x1x4xindex>
+// CHECK: %[[VAL_28:.*]] = vector.shape_cast %[[VAL_27]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK: %[[VAL_29:.*]] = vector.extractelement %[[VAL_28]][%[[VAL_8]] : i32] : vector<4xindex>
+// CHECK: %[[VAL_30:.*]] = vector.shape_cast %[[VAL_20]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK: %[[VAL_31:.*]] = vector.extractelement %[[VAL_30]][%[[VAL_8]] : i32] : vector<4xindex>
+// The following `tensor.extract` from the generic Op is a contiguous load (all Ops used
+// for address calculation also satisfy the required conditions).
+// CHECK: %[[VAL_32:.*]] = vector.transfer_read %[[VAL_1]][%[[VAL_29]], %[[VAL_31]]], %[[VAL_10]] {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
+// CHECK: %[[VAL_33:.*]] = vector.broadcast %[[VAL_32]] : vector<1x4xf32> to vector<1x1x4xf32>
+// CHECK: %[[VAL_34:.*]] = vector.transfer_write %[[VAL_33]], %[[VAL_12]][%[[VAL_11]], %[[VAL_11]], %[[VAL_11]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, tensor<1x1x4xf32>
+// CHECK: return %[[VAL_34]] : tensor<1x1x4xf32>
 // CHECK: }
 
 transform.sequence failures(propagate) {