NOTE(review): Summary of this patch. It adds 128-bit tile load/store support to
the ArmSME lowering:
  * LegalizeForLLVMExport.cpp: adds a `case 128:` to the tile-slice load and
    store switches and marks arm_sme::aarch64_sme_ld1q_horiz and
    arm_sme::aarch64_sme_st1q_horiz as legal ops. The op type passed to
    rewriter.create / replaceOpWithNewOp is missing from this copy; the tests
    below expect arm_sme.intr.ld1q.horiz / arm_sme.intr.st1q.horiz.
  * Utils.cpp: isValidSMETileElementType() now also accepts i128 and f128.
    The "TODO: add support for i128" comments are removed.
  * vector-ops-to-llvm.mlir: adds @vector_load_i128 / @vector_store_i128 cases
    that check the tile id is truncated to i32 and that the LD1Q/ST1Q
    horizontal intrinsics are emitted.
  * load-store-128-bit-tile.mlir (new): integration test that
      - fills two byte buffers sized for one ZA{n}.q tile (7s and 64s);
      - copies A to B through a vector<[1]x[1]xi128> load/store;
      - prints both buffers before and after the copy.
    It is marked XFAIL because of the QEMU LD1Q/ST1Q bug linked in the test.
NOTE(review): This copy of the patch appears damaged. It almost certainly will
not apply as-is:
  * Newlines have been collapsed into a few physical lines, and runs of
    whitespace appear to have been squeezed (possibly also inside string
    literals such as " FINAL TILE A:\0A\00"; TODO confirm against upstream).
  * Text between angle brackets has been stripped in places. Examples:
    `rewriter.create(loc, ...)`, `rewriter.replaceOpWithNewOp(`,
    `target.addLegalOp();`, `memref)` and `!llvm.ptr>` have lost their
    template/type arguments.
  Regenerate it from the original commit rather than hand-repairing it.
NOTE(review): Pre-existing issue visible in the Utils.cpp context lines. This
patch does not change it. isValidSMETileVectorType() tests
    if ((vType.getRank() != 2) && vType.allDimsScalable()) return false;
That condition only rejects vectors that are BOTH non-rank-2 AND all-scalable.
For example, a rank-2 vector with fixed (non-scalable) dims passes this check.
The intent presumably is
    if ((vType.getRank() != 2) || !vType.allDimsScalable()) return false;
TODO: confirm and fix in a follow-up.

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -223,6 +223,10 @@ rewriter.create(loc, allActiveMask, ptr, tileI32, tileSliceI32); break; + case 128: + rewriter.create(loc, allActiveMask, ptr, + tileI32, tileSliceI32); + break; } // The load intrinsics have no result, replace 'arm_sme.tile_load' with @@ -294,6 +298,10 @@ rewriter.replaceOpWithNewOp( storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); break; + case 128: + rewriter.replaceOpWithNewOp( + storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32); + break; } return success(); @@ -309,9 +317,10 @@ arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz, - arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_st1b_horiz, - arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz, - arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_za_enable, + arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz, + arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz, + arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz, + arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>(); target.addLegalOp(); diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp --- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp +++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp @@ -25,17 +25,15 @@ } bool mlir::arm_sme::isValidSMETileElementType(Type type) { - // TODO: add support for i128. 
return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) || - type.isInteger(64) || type.isF16() || type.isBF16() || type.isF32() || - type.isF64(); + type.isInteger(64) || type.isInteger(128) || type.isF16() || + type.isBF16() || type.isF32() || type.isF64() || type.isF128(); } bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) { if ((vType.getRank() != 2) && vType.allDimsScalable()) return false; - // TODO: add support for i128. auto elemType = vType.getElementType(); if (!isValidSMETileElementType(elemType)) return false; diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir --- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir +++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir @@ -220,6 +220,20 @@ // ----- +// CHECK-LABEL: @vector_load_i128( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i128 +// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i128 to vector<[1]x[1]xi128> +// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i128 to i32 +// CHECK: arm_sme.intr.ld1q.horiz +func.func @vector_load_i128(%arg0 : memref) -> vector<[1]x[1]xi128> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0, %c0] : memref, vector<[1]x[1]xi128> + return %tile : vector<[1]x[1]xi128> +} + +// ----- + // CHECK-LABEL: @vector_store_i8( // CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>, // CHECK-SAME: %[[ARG0:.*]]: memref) @@ -363,3 +377,17 @@ vector.store %tile, %arg0[%c0, %c0] : memref, vector<[2]x[2]xf64> return } + +// ----- + +// CHECK-LABEL: @vector_store_i128( +// CHECK-SAME: %[[TILE:.*]]: vector<[1]x[1]xi128>, +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[1]x[1]xi128> to i128 +// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i128 to i32 +// CHECK: arm_sme.intr.st1q.horiz +func.func 
@vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref) { + %c0 = arith.constant 0 : index + vector.store %tile, %arg0[%c0, %c0] : memref, vector<[1]x[1]xi128> + return +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir @@ -0,0 +1,113 @@ +// DEFINE: %{entry_point} = test_load_store_zaq0 +// DEFINE: %{compile} = mlir-opt %s \ +// DEFINE: -enable-arm-streaming="mode=locally enable-za" \ +// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ +// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \ +// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm +// DEFINE: %{run} = %mcr_aarch64_cmd \ +// DEFINE: -march=aarch64 -mattr=+sve,+sme \ +// DEFINE: -e %{entry_point} -entry-point-result=void \ +// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils + +// RUN: %{compile} | %{run} | FileCheck %s + +/// Note: The SME ST1Q/LD1Q instructions are currently broken in QEMU; +/// see: https://gitlab.com/qemu-project/qemu/-/issues/1833 +/// This test is expected to fail until a fixed version of QEMU can be used. + +/// FIXME: Remove the 'XFAIL' below once a fixed QEMU version is available +/// (and installed on CI buildbot). 
+// XFAIL: {{.*}} + +func.func @print_i8s(%bytes: memref, %len: index) { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + scf.for %i = %c0 to %len step %c16 { + %v = vector.load %bytes[%i] : memref, vector<16xi8> + vector.print %v : vector<16xi8> + } + return +} + +llvm.func @printCString(!llvm.ptr) + +func.func @print_str(%str: !llvm.ptr>) { + %c0 = llvm.mlir.constant(0 : index) : i64 + %str_bytes = llvm.getelementptr %str[%c0, %c0] + : (!llvm.ptr>, i64, i64) -> !llvm.ptr + llvm.call @printCString(%str_bytes) : (!llvm.ptr) -> () + return +} + +func.func @vector_copy_i128(%src: memref, %dst: memref) { + %c0 = arith.constant 0 : index + %tile = vector.load %src[%c0, %c0] : memref, vector<[1]x[1]xi128> + vector.store %tile, %dst[%c0, %c0] : memref, vector<[1]x[1]xi128> + return +} + +func.func @test_load_store_zaq0() { + %init_a_str = llvm.mlir.addressof @init_tile_a : !llvm.ptr> + %init_b_str = llvm.mlir.addressof @init_tile_b : !llvm.ptr> + %final_a_str = llvm.mlir.addressof @final_tile_a : !llvm.ptr> + %final_b_str = llvm.mlir.addressof @final_tile_b : !llvm.ptr> + + %c0 = arith.constant 0 : index + %min_elts_q = arith.constant 1 : index + %bytes_per_128_bit = arith.constant 16 : index + + /// Calculate the size of a 128-bit tile, e.g. ZA{n}.q, in bytes: + %vscale = vector.vscale + %svl_q = arith.muli %min_elts_q, %vscale : index + %zaq_size = arith.muli %svl_q, %svl_q : index + %zaq_size_bytes = arith.muli %zaq_size, %bytes_per_128_bit : index + + /// Allocate memory for two 128-bit tiles (A and B) and fill them with a constant. + /// The tiles are allocated as bytes so we can fill and print them, as there's + /// very little that can be done with 128-bit types directly. 
+ %tile_a_bytes = memref.alloca(%zaq_size_bytes) {alignment = 16} : memref + %tile_b_bytes = memref.alloca(%zaq_size_bytes) {alignment = 16} : memref + %fill_a_i8 = arith.constant 7 : i8 + %fill_b_i8 = arith.constant 64 : i8 + linalg.fill ins(%fill_a_i8 : i8) outs(%tile_a_bytes : memref) + linalg.fill ins(%fill_b_i8 : i8) outs(%tile_b_bytes : memref) + + /// Get a 128-bit view of the memory for tiles A and B: + %tile_a = memref.view %tile_a_bytes[%c0][%svl_q, %svl_q] : + memref to memref + %tile_b = memref.view %tile_b_bytes[%c0][%svl_q, %svl_q] : + memref to memref + + // CHECK-LABEL: INITIAL TILE A: + // CHECK: ( 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 ) + func.call @print_str(%init_a_str) : (!llvm.ptr>) -> () + func.call @print_i8s(%tile_a_bytes, %zaq_size_bytes) : (memref, index) -> () + vector.print punctuation + + // CHECK-LABEL: INITIAL TILE B: + // CHECK: ( 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64 ) + func.call @print_str(%init_b_str) : (!llvm.ptr>) -> () + func.call @print_i8s(%tile_b_bytes, %zaq_size_bytes) : (memref, index) -> () + vector.print punctuation + + /// Load tile A and store it to tile B: + func.call @vector_copy_i128(%tile_a, %tile_b) : (memref, memref) -> () + + // CHECK-LABEL: FINAL TILE A: + // CHECK: ( 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 ) + func.call @print_str(%final_a_str) : (!llvm.ptr>) -> () + func.call @print_i8s(%tile_a_bytes, %zaq_size_bytes) : (memref, index) -> () + vector.print punctuation + + // CHECK-LABEL: FINAL TILE B: + // CHECK: ( 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 ) + func.call @print_str(%final_b_str) : (!llvm.ptr>) -> () + func.call @print_i8s(%tile_b_bytes, %zaq_size_bytes) : (memref, index) -> () + + return +} + +llvm.mlir.global internal constant @init_tile_a ("INITIAL TILE A:\0A\00") +llvm.mlir.global internal constant @init_tile_b ("INITIAL TILE B:\0A\00") +llvm.mlir.global internal constant @final_tile_a(" FINAL TILE A:\0A\00") +llvm.mlir.global internal 
constant @final_tile_b(" FINAL TILE B:\0A\00")