diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -729,11 +729,7 @@ return offset; } -enum VectorMemoryAccessKind { - // TODO: ScalarBroadcast, - Contiguous, - Gather -}; +enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather }; /// Checks whether /p val can be used for calculating a loop invariant index. static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) { @@ -872,9 +868,9 @@ if (inputShape.getShape().back() == 1) return VectorMemoryAccessKind::Gather; - bool isContiguous = true; + bool leadingIdxsLoopInvariant = true; - // 3a. Analyze the leading indices of `extractOp`. + // 3. Analyze the leading indices of `extractOp`. // Look at the way each index is calculated and decide whether it is suitable // for a contiguous load, i.e. whether it's loop invariant. auto indices = extractOp.getIndices(); @@ -884,20 +880,34 @@ if (inputShape.getShape()[i] == 1) continue; - isContiguous &= isLoopInvariantIdx(linalgOp, indexVal); + leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal); } - // 3b. Analyze the trailing index for `extractOp`. + if (!leadingIdxsLoopInvariant) + return VectorMemoryAccessKind::Gather; + + // 4. Analyze the trailing index for `extractOp`. auto extractOpTrailingIdx = indices.back(); - // For contiguous loads, the trailing `extractOp` index should increment with - // every loop iteration. This effectively means that it must be based on the - // trailing loop index. This is what the following bool captures. + + // 4a. Scalar broadcast load + // If the trailing index is loop invariant then this is a scalar load. + if (isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) { + LDBG("Found scalar broadcast load: " << extractOp); + + return leadingIdxsLoopInvariant ? VectorMemoryAccessKind::ScalarBroadcast + : VectorMemoryAccessKind::Gather; + } + + // 4b. Contiguous loads + // The trailing `extractOp` index should increment with every loop iteration. + // This effectively means that it must be based on the trailing loop index. + // This is what the following bool captures. bool foundIndexOp = false; - isContiguous &= + bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp); - isContiguous &= foundIndexOp; + isContiguousLoad &= foundIndexOp; - if (isContiguous) { + if (isContiguousLoad) { LDBG("Found contigous load: " << extractOp); return VectorMemoryAccessKind::Contiguous; } @@ -948,10 +958,59 @@ maskConstantOp, passThruConstantOp); gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp); - LDBG("Vectorised as gather load: " << extractOp); + LDBG("Vectorised as gather load: " << extractOp << "\n"); return VectorizationResult{VectorizationStatus::NewOp, gatherOp}; } + // 2. Handle scalar broadcast access. Similarly to the "gather load" case, + // generate a vector.gather. However, load only one element and then broadcast + // it. + if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) { + SmallVector baseIndices( + extractOp.getIndices().size(), + rewriter.create(loc, 0)); + + Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape); + + // <1x1xNxi32> --> + auto resTrailingDim = resultType.getShape().back(); + auto offsetAs1dVector = rewriter.create( + loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()), + offset); + + // There's only 1 unique offset value in the `offset` vector. Extract it: + // --> i32 + auto zero = rewriter.create( + loc, rewriter.getI32Type(), + rewriter.getZeroAttr(rewriter.getI32Type())); + auto offsetUniqueVal = + rewriter.create(loc, offsetAs1dVector, zero); + + // Cast the scalar index to a 1-element vector: + // i32 --> <1xi32> + auto resultTypeAs1dVec = + VectorType::get({1}, extractOp.getResult().getType()); + auto offsetFor1Val = broadcastIfNeeded( + rewriter, offsetUniqueVal.getResult(), resultTypeAs1dVec.getShape()); + + auto maskConstantOp = rewriter.create( + loc, + DenseIntElementsAttr::get(VectorType::get({1}, rewriter.getI1Type()), + /*value=*/true)); + + auto passThruConstantOp = rewriter.create( + loc, rewriter.getZeroAttr(resultTypeAs1dVec)); + Operation *gatherOp = rewriter.create( + loc, resultTypeAs1dVec, extractOp.getTensor(), baseIndices, + offsetFor1Val, maskConstantOp, passThruConstantOp); + + auto readValue = rewriter.create( + loc, resultType, gatherOp->getResult(0)); + + LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n"); + return VectorizationResult{VectorizationStatus::NewOp, readValue}; + } + // 2. Handle contiguous access. LDBG("Vectorised as contiguous load: " << extractOp); SmallVector transferReadIdxs; diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir --- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir @@ -51,17 +51,23 @@ return %2 : tensor<1x1x3xf32> } -// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_constant_idx -// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3xf32> -// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32> -// CHECK: %[[MASK:.*]] = arith.constant dense : vector<1x1x3xi1> -// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32> -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// Magic "5" below comes from (1 * 3 + 2) (1: index into dim 1, 2: index into dim 2) -// CHECK: %[[IDX:.*]] = arith.constant dense<5> : vector<1x1x3xindex> -// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[IDX]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32> -// CHECK: vector.transfer_write %[[GATHER]] -// CHECK: } +// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_constant_idx( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x3xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> { +// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]] = arith.constant dense<5> : vector<1x1x3xindex> +// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_5:.*]] = arith.constant dense : vector<1xi1> +// CHECK: %[[VAL_6:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32> +// CHECK: %[[VAL_7:.*]] = vector.shape_cast %[[VAL_3]] : vector<1x1x3xindex> to vector<3xindex> +// CHECK: %[[VAL_8:.*]] = vector.extractelement %[[VAL_7]]{{\[}}%[[VAL_4]] : i32] : vector<3xindex> +// CHECK: %[[VAL_9:.*]] = vector.broadcast %[[VAL_8]] : index to vector<1xindex> +// Load a scalar and broadcast it +// CHECK: %[[VAL_10:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_2]]] {{\[}}%[[VAL_9]]], %[[VAL_5]], %[[VAL_6]] : tensor<3x3xf32>, vector<1xindex>, vector<1xi1>, vector<1xf32> into vector<1xf32> +// CHECK: %[[VAL_11:.*]] = vector.broadcast %[[VAL_10]] : vector<1xf32> to vector<1x1x3xf32> +// CHECK: %[[VAL_12:.*]] = vector.transfer_write %[[VAL_11]], %[[VAL_1]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32> +// CHECK: return %[[VAL_12]] : tensor<1x1x3xf32> +// CHECK: } transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): @@ -322,38 +328,43 @@ // TODO: Don't use vector.gather for the first tensor.extract. // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_tensor_extract( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x20xi32>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<257x24xf32>, -// CHECK-SAME: -> tensor<1x1x4xf32> { -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<0> : vector<1x1x4xindex> -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant dense : vector<1x1x4xi1> -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant dense<0> : vector<1x1x4xi32> -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant dense<256> : vector<1x1x4xindex> -// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 0 : i32 -// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x20xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<257x24xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index) -> tensor<1x1x4xf32> { +// CHECK: %[[VAL_6:.*]] = arith.constant dense<0> : vector<1x1x4xindex> +// CHECK: %[[VAL_7:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> +// CHECK: %[[VAL_8:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_9:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_10:.*]] = arith.constant dense : vector<1xi1> +// CHECK: %[[VAL_11:.*]] = arith.constant dense<0> : vector<1xi32> +// CHECK: %[[VAL_12:.*]] = arith.constant dense<256> : vector<1x1x4xindex> +// CHECK: %[[VAL_13:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[VAL_14:.*]] = tensor.empty() : tensor<1x1x4xf32> -// CHECK: %[[VAL_15:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex> -// CHECK: %[[VAL_16:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex> +// CHECK: %[[VAL_15:.*]] = vector.broadcast %[[VAL_2]] : index to vector<1x1x4xindex> +// CHECK: %[[VAL_16:.*]] = vector.broadcast %[[VAL_4]] : index to vector<1x1x4xindex> // CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : vector<1x1x4xindex> -// CHECK: %[[VAL_18:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex> +// CHECK: %[[VAL_18:.*]] = vector.broadcast %[[VAL_3]] : index to vector<1x1x4xindex> // CHECK: %[[VAL_19:.*]] = vector.broadcast %[[VAL_7]] : vector<4xindex> to vector<1x1x4xindex> // CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_18]], %[[VAL_19]] : vector<1x1x4xindex> -// CHECK: %[[VAL_21:.*]] = vector.broadcast %{{.*}} : index to vector<1x1x4xindex> +// CHECK: %[[VAL_21:.*]] = vector.broadcast %[[VAL_5]] : index to vector<1x1x4xindex> // CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : vector<1x1x4xindex> -// CHECK: %[[VAL_23:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_10]], %[[VAL_10]]] {{\[}}%[[VAL_17]]], %[[VAL_8]], %[[VAL_9]] : tensor<1x20xi32>, vector<1x1x4xindex>, vector<1x1x4xi1>, vector<1x1x4xi32> into vector<1x1x4xi32> -// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : vector<1x1x4xi32> to vector<1x1x4xindex> -// CHECK: %[[VAL_25:.*]] = arith.maxsi %[[VAL_24]], %[[VAL_6]] : vector<1x1x4xindex> -// CHECK: %[[VAL_26:.*]] = arith.minsi %[[VAL_25]], %[[VAL_11]] : vector<1x1x4xindex> -// CHECK: %[[VAL_27:.*]] = vector.shape_cast %[[VAL_26]] : vector<1x1x4xindex> to vector<4xindex> -// CHECK: %[[VAL_28:.*]] = vector.extractelement %[[VAL_27]]{{\[}}%[[VAL_12]] : i32] : vector<4xindex> -// CHECK: %[[VAL_29:.*]] = vector.shape_cast %[[VAL_22]] : vector<1x1x4xindex> to vector<4xindex> -// CHECK: %[[VAL_30:.*]] = vector.extractelement %[[VAL_29]]{{\[}}%[[VAL_12]] : i32] : vector<4xindex> -// CHECK: %[[VAL_31:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_28]], %[[VAL_30]]], %[[VAL_13]] {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32> -// CHECK: %[[VAL_32:.*]] = vector.broadcast %[[VAL_31]] : vector<1x4xf32> to vector<1x1x4xf32> -// CHECK: %[[VAL_33:.*]] = vector.transfer_write %[[VAL_32]], %[[VAL_14]]{{\[}}%[[VAL_10]], %[[VAL_10]], %[[VAL_10]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, tensor<1x1x4xf32> -// CHECK: return %[[VAL_33]] : tensor<1x1x4xf32> +// CHECK: %[[VAL_23:.*]] = vector.shape_cast %[[VAL_17]] : vector<1x1x4xindex> to vector<4xindex> +// CHECK: %[[VAL_24:.*]] = vector.extractelement %[[VAL_23]]{{\[}}%[[VAL_9]] : i32] : vector<4xindex> +// CHECK: %[[VAL_25:.*]] = vector.broadcast %[[VAL_24]] : index to vector<1xindex> +// Load a scalar and broadcast it +// CHECK: %[[VAL_26:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_8]], %[[VAL_8]]] {{\[}}%[[VAL_25]]], %[[VAL_10]], %[[VAL_11]] : tensor<1x20xi32>, vector<1xindex>, vector<1xi1>, vector<1xi32> into vector<1xi32> +// CHECK: %[[VAL_27:.*]] = arith.index_cast %[[VAL_26]] : vector<1xi32> to vector<1xindex> +// CHECK: %[[VAL_28:.*]] = vector.broadcast %[[VAL_27]] : vector<1xindex> to vector<1x1x4xindex> +// CHECK: %[[VAL_29:.*]] = arith.maxsi %[[VAL_28]], %[[VAL_6]] : vector<1x1x4xindex> +// CHECK: %[[VAL_30:.*]] = arith.minsi %[[VAL_29]], %[[VAL_12]] : vector<1x1x4xindex> +// CHECK: %[[VAL_31:.*]] = vector.shape_cast %[[VAL_30]] : vector<1x1x4xindex> to vector<4xindex> +// CHECK: %[[VAL_32:.*]] = vector.extractelement %[[VAL_31]]{{\[}}%[[VAL_9]] : i32] : vector<4xindex> +// CHECK: %[[VAL_33:.*]] = vector.shape_cast %[[VAL_22]] : vector<1x1x4xindex> to vector<4xindex> +// CHECK: %[[VAL_34:.*]] = vector.extractelement %[[VAL_33]]{{\[}}%[[VAL_9]] : i32] : vector<4xindex> +// CHECK: %[[VAL_35:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_32]], %[[VAL_34]]], %[[VAL_13]] {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32> +// CHECK: %[[VAL_36:.*]] = vector.broadcast %[[VAL_35]] : vector<1x4xf32> to vector<1x1x4xf32> +// CHECK: %[[VAL_37:.*]] = vector.transfer_write %[[VAL_36]], %[[VAL_14]]{{\[}}%[[VAL_8]], %[[VAL_8]], %[[VAL_8]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, tensor<1x1x4xf32> +// CHECK: return %[[VAL_37]] : tensor<1x1x4xf32> // CHECK: } transform.sequence failures(propagate) {