diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -24,6 +24,38 @@
   return false;
 }
 
+/// Generates a loop over ZA tile slices, with the slice index as induction
+/// variable, and sets the insertion point to the start of the loop body.
+static scf::ForOp getLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
+                                        Type eltType) {
+  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+      loc, arm_sme::getSMETileSliceMinNumElts(eltType));
+  auto vscale =
+      rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto numTileSlices =
+      rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+  auto forOp =
+      rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+  rewriter.setInsertionPointToStart(forOp.getBody());
+  return forOp;
+}
+
+/// Creates an 'arm_sme.get_tile_id' op and casts the tile ID to `type`.
+static arm_sme::CastTileToVector
+getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc,
+                          VectorType type) {
+  unsigned tileElementWidth = type.getElementType().getIntOrFloatBitWidth();
+
+  // Create 'arm_sme.get_tile_id' op.
+  auto tileId = rewriter.create<arm_sme::GetTileID>(
+      loc, rewriter.getIntegerType(tileElementWidth));
+
+  // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type.
+  return rewriter.create<arm_sme::CastTileToVector>(loc, type, tileId);
+}
+
 namespace {
 
 /// Conversion pattern for vector.transfer_write.
@@ -122,29 +154,10 @@
         tileSliceType, denseAttr.getSplatValue<Attribute>());
     auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
 
-    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
-
-    // Create 'arm_sme.get_tile' op.
-    auto tileId = rewriter.create<arm_sme::GetTileID>(
-        loc, rewriter.getIntegerType(tileElementWidth));
-
-    // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
-    // use as input tile to 'arm_sme.move_vector_to_tile_slice' ops.
-    auto tile =
-        rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
-
-    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
-        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
-    auto vscale =
-        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
-    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto numTileSlices =
-        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-    // Create a loop that broadcasts the constant to each ZA tile slice.
-    auto forOp =
-        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
-    rewriter.setInsertionPointToStart(forOp.getBody());
+    arm_sme::CastTileToVector tile =
+        getSMETileAndCastToVector(rewriter, loc, tileType);
+
+    auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
     auto tileSliceIndex = forOp.getInductionVar();
 
     // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile slice.
@@ -159,10 +172,88 @@
   }
 };
 
+/// Conversion pattern for vector.broadcast.
+///
+/// Example:
+///
+///   %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32>
+///
+/// is converted to:
+///
+///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
+///   scf.for %tile_slice_index = %c0 to %num_tile_slices step %c1 {
+///     arm_sme.move_vector_to_tile_slice %broadcast_to_1d, %tile,
+///       %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+///   }
+///
+/// Supports scalar, 0-d vector, and 1-d vector broadcasts. A 1-d source is
+/// only supported if it already has the tile-slice type (e.g.
+/// vector<[4]xi32> for a vector<[4]x[4]xi32> result), because it is passed
+/// unchanged to 'arm_sme.move_vector_to_tile_slice'.
+struct BroadcastOpToArmSMELowering
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const final {
+    auto tileType = broadcastOp.getResultVectorType();
+    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
+      return failure();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = broadcastOp.getLoc();
+
+    auto srcType = broadcastOp.getSourceType();
+    auto srcVectorType = dyn_cast<VectorType>(srcType);
+    auto tileElementType = tileType.getElementType();
+
+    // Type of a single ZA tile slice: the result type with the leading
+    // dimension dropped, e.g. vector<[4]xi32> for vector<[4]x[4]xi32>.
+    auto tileSliceType =
+        VectorType::get(tileType.getShape().drop_front(), tileElementType,
+                        /*scalableDims=*/{true});
+
+    Value broadcastOp1D;
+    if (srcType.isIntOrFloat() ||
+        (srcVectorType && (srcVectorType.getRank() == 0))) {
+      // Broadcast scalar or 0-d vector to 1-d vector.
+      broadcastOp1D = rewriter.create<vector::BroadcastOp>(
+          loc, tileSliceType, broadcastOp.getSource());
+    } else if (srcVectorType == tileSliceType) {
+      // Value to broadcast is already a tile slice, nothing to do.
+      broadcastOp1D = broadcastOp.getSource();
+    } else {
+      // Other 1-d sources (e.g. vector<1xi32>, accepted by vector.broadcast
+      // through dimension stretching) do not match the tile-slice type of
+      // 'arm_sme.move_vector_to_tile_slice', so they are not lowered here.
+      return failure();
+    }
+
+    arm_sme::CastTileToVector tile =
+        getSMETileAndCastToVector(rewriter, loc, tileType);
+
+    // Create a loop over ZA tile slices.
+    auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
+    auto tileSliceIndex = forOp.getInductionVar();
+
+    // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value to each
+    // tile slice.
+    rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+        loc, tileType, broadcastOp1D, tile, tileSliceIndex);
+
+    rewriter.setInsertionPointAfter(forOp);
+
+    rewriter.replaceOp(broadcastOp, tile);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
   patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
-               VectorStoreToArmSMELowering, ConstantOpToArmSMELowering>(&ctx);
+               VectorStoreToArmSMELowering, ConstantOpToArmSMELowering,
+               BroadcastOpToArmSMELowering>(&ctx);
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -153,3 +153,65 @@
   %0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, tensor<?x?xi8>
   return %0 : tensor<?x?xi8>
 }
+
+// =============================================================================
+// vector.broadcast
+// =============================================================================
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_vec2d_from_i32(
+// CHECK-SAME:                                        %[[SRC:.*]]: i32) {
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C4:.*]] = arith.constant 4 : index
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
+// CHECK:           %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK:           %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK:           %[[VSCALE:.*]] = vector.vscale
+// CHECK:           %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+// CHECK:           scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK:             %[[C10:.*]] = arm_sme.move_vector_to_tile_slice %[[SRC_1D]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+// CHECK:           "prevent.dce"(%[[TILE]]) : (vector<[4]x[4]xi32>) -> ()
+func.func @broadcast_vec2d_from_i32(%arg0: i32) {
+  %0 = vector.broadcast %arg0 : i32 to vector<[4]x[4]xi32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_vec2d_from_vec0d(
+// CHECK-SAME:                                          %[[SRC:.*]]: vector<f32>) {
+// CHECK:           %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : vector<f32> to vector<[4]xf32>
+// CHECK:           scf.for
+// CHECK:             arm_sme.move_vector_to_tile_slice %[[SRC_1D]], {{.*}}
+func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) {
+  %0 = vector.broadcast %arg0 : vector<f32> to vector<[4]x[4]xf32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_vec2d_from_vec1d(
+// CHECK-SAME:                                          %[[SRC:.*]]: vector<[8]xi16>) {
+// CHECK-NOT:       vector.broadcast
+// CHECK:           scf.for
+// CHECK:             arm_sme.move_vector_to_tile_slice %[[SRC]], {{.*}}
+func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
+  %0 = vector.broadcast %arg0 : vector<[8]xi16> to vector<[8]x[8]xi16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_vec2d_from_vec1d__non_tile_slice_type(
+// CHECK:           vector.broadcast {{.*}} : vector<1xi32> to vector<[4]x[4]xi32>
+// CHECK-NOT:       arm_sme.move_vector_to_tile_slice
+func.func @broadcast_vec2d_from_vec1d__non_tile_slice_type(%arg0: vector<1xi32>) {
+  %0 = vector.broadcast %arg0 : vector<1xi32> to vector<[4]x[4]xi32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
+  return
+}