diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td @@ -414,6 +414,51 @@ }]; } +def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [ + AllTypesMatch<["tile", "result"]>, + TypesMatchWith< + "type of 'vector' matches type of 'tile' slice", + "tile", "vector", + "VectorType::get(" + "::llvm::cast($_self).getShape().drop_front()," + "::llvm::cast($_self).getElementType()," + "/*scalableDims=*/{true})">, +]> { + let summary = "Move 1-D scalable vector to slice of 2-D tile"; + let description = [{ + The vector to tile slice operation moves a 1-D scalable vector to a slice + of a 2-D scalable vector tile at the given index. The type of the 1-D + scalable vector to be moved must match the type of the tile slice. A tile + slice is a 1-D vector of horizontally or vertically contiguous elements + within a ZA tile. Horizontal tile slices are currently assumed when + lowering to intrinsics. The updated tile is returned as the result. + + Example 1: Move a vector<[16]xi8> into tile at given index. + ```mlir + %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8> + ``` + + Example 2: Move a vector<[2]xf64> into tile at given index. 
+ ```mlir + %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64> + ``` + }]; + let arguments = (ins + SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index); + let results = (outs SMETile:$result); + + let extraClassDeclaration = [{ + VectorType getTileType() { + return ::llvm::cast(getTile().getType()); + } + }]; + + let assemblyFormat = [{ + $vector `,` $tile `,` $tile_slice_index + attr-dict `:` type($vector) `into` type($result) + }]; +} + //===----------------------------------------------------------------------===// // ArmSME Intrinsic op definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -93,15 +93,67 @@ LogicalResult matchAndRewrite(arith::ConstantOp constantOp, PatternRewriter &rewriter) const final { - auto vType = dyn_cast(constantOp.getType()); - if (!vType || !arm_sme::isValidSMETileVectorType(vType)) + auto tileType = dyn_cast(constantOp.getType()); + if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) return failure(); auto denseAttr = dyn_cast(constantOp.getValueAttr()); - if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr)) + if (!denseAttr || !denseAttr.isSplat()) return failure(); - rewriter.replaceOpWithNewOp(constantOp, vType); + auto tileElementType = tileType.getElementType(); + + // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op. + if (isSplatZero(tileElementType, denseAttr)) { + rewriter.replaceOpWithNewOp(constantOp, tileType); + return success(); + } + + // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice' + // ops that broadcast the constant to each tile slice. 
+ OpBuilder::InsertionGuard g(rewriter); + auto loc = constantOp.getLoc(); + + // Derive the 1-D tile slice vector type from the 2-D tile vector type. + auto tileSliceType = + VectorType::get(tileType.getShape().drop_front(), tileElementType, + /*scalableDims=*/{true}); + auto denseAttr1D = DenseElementsAttr::get( + tileSliceType, denseAttr.getSplatValue()); + auto constantOp1D = rewriter.create(loc, denseAttr1D); + + unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth(); + + // Create 'arm_sme.get_tile_id' op to obtain an integer tile ID. + auto tileId = rewriter.create( + loc, rewriter.getIntegerType(tileElementWidth)); + + // Create 'arm_sme.cast_tile_to_vector' to cast tile ID to a vector type to + // use as input tile to 'arm_sme.move_vector_to_tile_slice' ops. + auto tile = + rewriter.create(loc, tileType, tileId); + + auto step = rewriter.create(loc, 1); + auto minTileSlices = rewriter.create( + loc, arm_sme::getSMETileSliceMinNumElts(tileElementType)); + auto vscale = + rewriter.create(loc, rewriter.getIndexType()); + auto lowerBound = rewriter.create(loc, 0); + auto numTileSlices = + rewriter.create(loc, minTileSlices, vscale); + // Create a loop that broadcasts the constant to each ZA tile slice. + auto forOp = + rewriter.create(loc, lowerBound, numTileSlices, step); + rewriter.setInsertionPointToStart(forOp.getBody()); + auto tileSliceIndex = forOp.getInductionVar(); + + // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile slice. + rewriter.create( + loc, tileType, constantOp1D, tile, tileSliceIndex); + + rewriter.setInsertionPointAfter(forOp); + + rewriter.replaceOp(constantOp, tile); return success(); } diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -308,6 +308,58 @@ } }; +/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. 
Only horizontal +/// tile slices are currently supported. +struct MoveVectorToTileSliceToArmSMELowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + arm_sme::MoveVectorToTileSliceOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(arm_sme::MoveVectorToTileSliceOp moveVectorToTileSliceOp, + arm_sme::MoveVectorToTileSliceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = moveVectorToTileSliceOp.getLoc(); + auto tileType = moveVectorToTileSliceOp.getTileType(); + auto tileElementType = tileType.getElementType(); + unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth(); + + // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the tile being + // written to. + auto tile = rewriter.create( + loc, rewriter.getIntegerType(tileElementWidth), + moveVectorToTileSliceOp.getTile()); + + auto tileSlice = moveVectorToTileSliceOp.getTileSliceIndex(); + + // Cast tile slice from index to i32 for intrinsic. + auto tileSliceI32 = rewriter.create( + loc, rewriter.getI32Type(), tileSlice); + + // Create an all-active predicate mask. + auto one = rewriter.create( + loc, rewriter.getI1Type(), + rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); + auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(), + /*scalableDims=*/{true}); + auto allActiveMask = rewriter.create(loc, predTy, one); + + auto tileI32 = castTileIDToI32(tile, loc, rewriter); + + // Create 'arm_sme.intr.write.horiz' to write vector to tile slice. + rewriter.create( + loc, tileI32, tileSliceI32, allActiveMask, + moveVectorToTileSliceOp.getVector()); + + // Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with + // 'arm_sme.cast_tile_to_vector' to preserve dataflow. 
+ rewriter.replaceOpWithNewOp( + moveVectorToTileSliceOp, tileType, tile); + + return success(); + } +}; + } // namespace void mlir::configureArmSMELegalizeForExportTarget( @@ -320,8 +372,8 @@ arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz, - arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_za_enable, - arm_sme::aarch64_sme_za_disable>(); + arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_write_horiz, + arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>(); target.addLegalOp(); // Mark 'func.func' ops as legal if either: @@ -353,5 +405,6 @@ LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.add(patterns.getContext()); patterns.add(converter); + LoadTileSliceToArmSMELowering, + MoveVectorToTileSliceToArmSMELowering>(converter); } diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir --- a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir +++ b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir @@ -83,3 +83,47 @@ "prevent.dce"(%zero) : (vector<[2]x[2]xf64>) -> () return } + +// ============================================================================= +// Non-zero arith.constant dense to SME +// ============================================================================= + +// ----- + +// CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_i8() { +// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8> +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[GET_TILE_ID:.*]] = arm_sme.get_tile_id : i8 +// CHECK: %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[GET_TILE_ID]] : i8 to vector<[16]x[16]xi8> +// CHECK: %[[VSCALE:.*]] = vector.vscale +// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], 
%[[C16]] : index +// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] { +// CHECK: arm_sme.move_vector_to_tile_slice %[[C2_SPLAT]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[16]xi8> into vector<[16]x[16]xi8> +// CHECK: "prevent.dce"(%[[TILE]]) : (vector<[16]x[16]xi8>) -> () +func.func @arith_constant_dense_2d_nonzero_i8() { + %two = arith.constant dense<2> : vector<[16]x[16]xi8> + "prevent.dce"(%two) : (vector<[16]x[16]xi8>) -> () + return +} + +// ----- + +// CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_f64() { +// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64> +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[GET_TILE_ID:.*]] = arm_sme.get_tile_id : i64 +// CHECK: %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[GET_TILE_ID]] : i64 to vector<[2]x[2]xf64> +// CHECK: %[[VSCALE:.*]] = vector.vscale +// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C2]] : index +// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] { +// CHECK: arm_sme.move_vector_to_tile_slice %[[C2_SPLAT]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[2]xf64> into vector<[2]x[2]xf64> +// CHECK: "prevent.dce"(%[[TILE]]) : (vector<[2]x[2]xf64>) -> () +func.func @arith_constant_dense_2d_nonzero_f64() { + %two = arith.constant dense<2.0> : vector<[2]x[2]xf64> + "prevent.dce"(%two) : (vector<[2]x[2]xf64>) -> () + return +} diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir --- a/mlir/test/Dialect/ArmSME/invalid.mlir +++ b/mlir/test/Dialect/ArmSME/invalid.mlir @@ -71,3 +71,21 @@ %0 = arm_sme.get_tile_id : i1 return %0 : i1 } + +// ----- + +func.func @arm_sme_move_vector_to_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> { + %c0 = 
arith.constant 0 : index + // expected-error@+1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}} + %0 = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[8]xi8> into vector<[16]x[16]xi8> + return %0 : vector<[16]x[16]xi8> +} + +// ----- + +func.func @arm_sme_move_vector_to_tile_slice_f32__bad_vector_type(%vector : vector<[8]xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]x[4]xf32> { + %c0 = arith.constant 0 : index + // expected-error@+1 {{op failed to verify that type of 'vector' matches type of 'tile' slice}} + %0 = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[8]xf32> into vector<[4]x[4]xf32> + return %0 : vector<[4]x[4]xf32> +} diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir --- a/mlir/test/Dialect/ArmSME/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir @@ -577,3 +577,84 @@ arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref, vector<[2]x[2]xf64> return } + +// ----- + +func.func @arm_sme_move_vector_to_tile_slice_i8(%vector : vector<[16]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> () { + // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[16]xi8> into vector<[16]x[16]xi8> + %c0 = arith.constant 0 : index + arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8> + return +} + +// ----- + +func.func @arm_sme_move_vector_to_tile_slice_i16(%vector : vector<[8]xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) -> () { + // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[8]xi16> into vector<[8]x[8]xi16> + %c0 = arith.constant 0 : index + arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[8]xi16> into vector<[8]x[8]xi16> + return +} + +// ----- + +func.func @arm_sme_move_vector_to_tile_slice_i32(%vector : vector<[4]xi32>, %tile : 
vector<[4]x[4]xi32>, %tile_slice_index : index) -> () { + // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[4]xi32> into vector<[4]x[4]xi32> + %c0 = arith.constant 0 : index + arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32> + return +} + +// ----- + +func.func @arm_sme_move_vector_to_tile_slice_i64(%vector : vector<[2]xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) -> () { + // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[2]xi64> into vector<[2]x[2]xi64> + %c0 = arith.constant 0 : index + arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xi64> into vector<[2]x[2]xi64> + return +} + +// ----- + +func.func @arm_sme_move_vector_to_tile_slice_i128(%vector : vector<[1]xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> () { + // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[1]xi128> into vector<[1]x[1]xi128> + %c0 = arith.constant 0 : index + arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[1]xi128> into vector<[1]x[1]xi128> + return +} + +// ----- + +func.func @arm_sme_move_vector_to_tile_slice_f16(%vector : vector<[8]xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) -> () { + // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16> + %c0 = arith.constant 0 : index + arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[8]xf16> into vector<[8]x[8]xf16> + return +} + +// ----- + +func.func @arm_sme_move_vector_to_tile_slice_bf16(%vector : vector<[8]xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> () { + // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[8]xbf16> into vector<[8]x[8]xbf16> + %c0 = arith.constant 0 : index + arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[8]xbf16> into vector<[8]x[8]xbf16> + return +} + +// ----- + +func.func 
@arm_sme_move_vector_to_tile_slice_f32(%vector : vector<[4]xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> () { + // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[4]xf32> into vector<[4]x[4]xf32> + %c0 = arith.constant 0 : index + arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xf32> into vector<[4]x[4]xf32> + return +} + +// ----- + +func.func @arm_sme_move_vector_to_tile_slice_f64(%vector : vector<[2]xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> () { + // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[2]xf64> into vector<[2]x[2]xf64> + %c0 = arith.constant 0 : index + arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64> + return +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir @@ -0,0 +1,78 @@ +// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \ +// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ +// RUN: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \ +// RUN: -allocate-arm-sme-tiles -test-lower-to-llvm | \ +// RUN: %mcr_aarch64_cmd \ +// RUN: -march=aarch64 -mattr=+sve,+sme \ +// RUN: -e entry -entry-point-result=i32 \ +// RUN: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils | \ +// RUN: FileCheck %s + +// Integration test demonstrating filling a 32-bit element ZA tile with a +// non-zero constant via vector to tile (MOVA) ops. 
+ +llvm.func @printCString(!llvm.ptr) + +func.func @printTileBegin() { + %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr> + %1 = llvm.mlir.constant(0 : index) : i64 + %2 = llvm.getelementptr %0[%1, %1] + : (!llvm.ptr>, i64, i64) -> !llvm.ptr + llvm.call @printCString(%2) : (!llvm.ptr) -> () + return +} + +func.func @printTileEnd() { + %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr> + %1 = llvm.mlir.constant(0 : index) : i64 + %2 = llvm.getelementptr %0[%1, %1] + : (!llvm.ptr>, i64, i64) -> !llvm.ptr + llvm.call @printCString(%2) : (!llvm.ptr) -> () + return +} + +func.func @entry() -> i32 { + %c0 = arith.constant 0 : index + %c1_index = arith.constant 1 : index + + %min_elts_s = arith.constant 4 : index + %vscale = vector.vscale + + // "svl" refers to the Streaming Vector Length and "svl_s" the number of + // 32-bit elements in a vector of SVL bits. + %svl_s = arith.muli %min_elts_s, %vscale : index + + // Allocate memory. + %tilesize = arith.muli %svl_s, %svl_s : index + %mem = memref.alloca(%tilesize) : memref + + // Fill a tile with '123'. This will get lowered to a 1-d vector splat of + // '123' and a loop that writes this vector to each tile slice in the ZA + // tile. + %tile = arith.constant dense<123> : vector<[4]x[4]xi32> + + // Store tile to memory so it can be dumped. + vector.store %tile, %mem[%c0] : memref, vector<[4]x[4]xi32> + + // Dump "mem". The smallest SVL is 128-bits so the tile will be at least + // 4x4xi32. 
+ // + // CHECK: TILE BEGIN + // CHECK-NEXT: ( 123, 123, 123, 123 + // CHECK-NEXT: ( 123, 123, 123, 123 + // CHECK-NEXT: ( 123, 123, 123, 123 + // CHECK-NEXT: ( 123, 123, 123, 123 + // CHECK: TILE END + func.call @printTileBegin() : () -> () + scf.for %i = %c0 to %tilesize step %svl_s { + %tileslice = vector.load %mem[%i] : memref, vector<[4]xi32> + vector.print %tileslice : vector<[4]xi32> + } + func.call @printTileEnd() : () -> () + + %c0_i32 = arith.constant 0 : i32 + return %c0_i32 : i32 +} + +llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A") +llvm.mlir.global internal constant @str_tile_end("TILE END\0A")