Index: mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp =================================================================== --- mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -26,6 +26,25 @@ using namespace mlir; namespace { +/// Adjusts `indices` for tile slice `tileSliceIndex`, emitting the required +/// index arithmetic via `rewriter`, and stores the result in `outIndices`: +/// rank 1: (indices[0] + (tileSliceIndex * tileSliceNumElts)) +/// rank 2: (indices[0] + tileSliceIndex, indices[1]) +void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex, + Value tileSliceNumElts, SmallVector &outIndices, + Location loc, PatternRewriter &rewriter) { + assert((rank == 1 || rank == 2) && "memref has unexpected rank!"); + + outIndices = indices; + + auto tileSliceOffset = tileSliceIndex; + if (rank == 1) + tileSliceOffset = + rewriter.create(loc, tileSliceOffset, tileSliceNumElts); + auto baseIndexPlusTileSliceOffset = + rewriter.create(loc, indices[0], tileSliceOffset); + outIndices[0] = baseIndexPlusTileSliceOffset; +} /// Lower `arm_sme.tile_load` to a loop over the tile slices and load each slice /// using `arm_sme.load_tile_slice`. @@ -78,6 +97,9 @@ auto vscale = rewriter.create(loc, rewriter.getIndexType()); auto lowerBound = rewriter.create(loc, 0); + // This describes both the number of ZA tile slices and the number of + // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H, + // ..., SVL_Q). auto numTileSlices = rewriter.create(loc, minTileSlices, vscale); auto forOp = @@ -85,12 +107,16 @@ rewriter.setInsertionPointToStart(forOp.getBody()); + // Create 'arm_sme.load_tile_slice' to load the tile slice at the adjusted + // memref indices from memory into the tile. + SmallVector indices; auto tileSliceIndex = forOp.getInductionVar(); - // TODO: use indices - // Create 'arm_sme.load_tile_slice' to load tile slice from - // memory into tile. 
+ getMemrefIndices(tileLoadOp.getIndices(), + tileLoadOp.getMemRefType().getRank(), tileSliceIndex, + numTileSlices, indices, loc, rewriter); + rewriter.create( - loc, tileType, tileLoadOp.getBase(), tileSliceIndex, tileInit, + loc, tileType, tileLoadOp.getBase(), indices, tileInit, tileSliceIndex); rewriter.setInsertionPointAfter(forOp); @@ -143,6 +169,9 @@ auto vscale = rewriter.create(loc, rewriter.getIndexType()); auto lowerBound = rewriter.create(loc, 0); + // This describes both the number of ZA tile slices and the number of + // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H, + // ..., SVL_Q). auto numTileSlices = rewriter.create(loc, minTileSlices, vscale); auto forOp = @@ -150,11 +179,14 @@ rewriter.setInsertionPointToStart(forOp.getBody()); + SmallVector indices; auto tileSliceIndex = forOp.getInductionVar(); - // TODO: use indices + getMemrefIndices(tileStoreOp.getIndices(), + tileStoreOp.getMemRefType().getRank(), tileSliceIndex, + numTileSlices, indices, loc, rewriter); rewriter.replaceOpWithNewOp( tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex, - tileStoreOp.getBase(), tileSliceIndex); + tileStoreOp.getBase(), indices); return success(); } Index: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp =================================================================== --- mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp +++ mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -111,33 +111,6 @@ return tile; } -/// Returns the following -/// * for rank 2 memrefs `tileSliceIndex`, since `getStridedElementPtr` does -/// the arithmetic. -/// * for rank 1 memrefs `tileSliceIndex * tileSliceNumElts`, adjusting the -/// index by the number of elements in a vector of SVL bits. -/// * otherwise throws an unreachable error. 
-Value getTileSlicePtrIndex(unsigned rank, Value tileSliceIndex, - Value tileSliceNumElts, Location loc, - ConversionPatternRewriter &rewriter) { - assert((rank == 1 || rank == 2) && "memref has unexpected rank!"); - - auto tileSliceIndexI64 = rewriter.create( - loc, rewriter.getI64Type(), tileSliceIndex); - - if (rank == 1) { - auto tileSliceNumEltsI64 = rewriter.create( - loc, rewriter.getI64Type(), tileSliceNumElts); - return rewriter.create(loc, tileSliceIndexI64, - tileSliceNumEltsI64); - } - - if (rank == 2) - return tileSliceIndexI64; - - llvm_unreachable("memref has unexpected rank!"); -} - /// Lower `arm_sme.load_tile_slice` to SME intrinsics. struct LoadTileSliceToArmSMELowering : public ConvertOpToLLVMPattern { @@ -159,25 +132,12 @@ loc, rewriter.getIntegerType(tileElementWidth), loadTileSliceOp.getTile()); - auto minTileSlices = rewriter.create( - loc, arm_sme::getSMETileSliceMinNumElts(tileElementType)); - auto vscale = - rewriter.create(loc, rewriter.getIndexType()); - // This describes both the number of ZA tile slices and the number of - // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H, - // ..., SVL_Q). - auto numTileSlices = - rewriter.create(loc, minTileSlices, vscale); // Create 'arm_sme.intr.ld1*.horiz' intrinsic to load ZA tile slice. - auto memRefType = loadTileSliceOp.getMemRefType(); + Value ptr = this->getStridedElementPtr( + loc, loadTileSliceOp.getMemRefType(), adaptor.getBase(), + adaptor.getIndices(), rewriter); + auto tileSlice = loadTileSliceOp.getTileSliceIndex(); - // TODO: The 'indices' argument for the 'base' memref is currently ignored, - // 'tileSliceIndex' should be added to 'indices[0]'. - Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice, - numTileSlices, loc, rewriter); - Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.getBase(), - {tileSliceIndex}, rewriter); // Cast tile slice to i32 for intrinsic. 
auto tileSliceI32 = rewriter.create( @@ -243,25 +203,12 @@ loc, rewriter.getIntegerType(tileElementWidth), storeTileSliceOp.getTile()); - auto minTileSlices = rewriter.create( - loc, arm_sme::getSMETileSliceMinNumElts(tileElementType)); - auto vscale = - rewriter.create(loc, rewriter.getIndexType()); - // This describes both the number of ZA tile slices and the number of - // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H, - // ..., SVL_Q). - auto numTileSlices = - rewriter.create(loc, minTileSlices, vscale); // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice. - auto memRefType = storeTileSliceOp.getMemRefType(); + Value ptr = this->getStridedElementPtr( + loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(), + adaptor.getIndices(), rewriter); + auto tileSlice = storeTileSliceOp.getTileSliceIndex(); - // TODO: The 'indices' argument for the 'base' memref is currently ignored, - // 'tileSliceIndex' should be added to 'indices[0]'. - Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice, - numTileSlices, loc, rewriter); - Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.getBase(), - {tileSliceIndex}, rewriter); // Cast tile slice to i32 for intrinsic.
auto tileSliceI32 = rewriter.create( Index: mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir =================================================================== --- mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir +++ mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir @@ -2,15 +2,16 @@ // CHECK-LABEL: func.func @arm_sme_tile_load( // CHECK-SAME: %[[SRC:.*]]: memref) { -// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32 -// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32> +// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32 +// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale // CHECK-NEXT: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index // CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] { -// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[TILE_SLICE_INDEX]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref, vector<[4]x[4]xi32> +// CHECK-NEXT: %[[TILE_SLICE_OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index +// CHECK-NEXT: arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[TILE_SLICE_OFFSET]], %[[C0]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref, vector<[4]x[4]xi32> func.func @arm_sme_tile_load(%src : memref) { %c0 = arith.constant 0 : index %tile = arm_sme.tile_load %src[%c0, %c0] : memref, vector<[4]x[4]xi32> @@ -28,7 +29,8 @@ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index // CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] { -// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], 
%[[DEST]]{{\[}}%[[TILE_SLICE_INDEX]]] : memref, vector<[4]x[4]xi32> +// CHECK: %[[TILE_SLICE_OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index +// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[TILE_SLICE_OFFSET]], %[[C0]]] : memref, vector<[4]x[4]xi32> func.func @arm_sme_tile_store(%tile : vector<[4]x[4]xi32>, %dest : memref) { %c0 = arith.constant 0 : index arm_sme.tile_store %tile, %dest[%c0, %c0] : memref, vector<[4]x[4]xi32> Index: mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir =================================================================== --- mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir +++ mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir @@ -8,17 +8,19 @@ // CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index // CHECK-DAG: %[[C255:.*]] = arith.constant 255 : i32 // CHECK-DAG: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[16]xi1> +// CHECK-DAG: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64 // CHECK-DAG: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> () // CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8 // CHECK-DAG: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] { -// CHECK: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK: %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64 // CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 // 
CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 // CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 // CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32 // CHECK-NEXT: "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () @@ -40,15 +42,17 @@ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index // CHECK-DAG: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[16]xi1> +// CHECK-DAG: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64 // CHECK-DAG: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] { -// CHECK: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK: %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64 // CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 -// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = 
llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 // CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 // CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32 // CHECK-NEXT: "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () @@ -75,14 +79,10 @@ // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] { -// CHECK-NEXT: %[[VSCALE_1:.*]] = "llvm.intr.vscale"() : () -> i64 -// CHECK-NEXT: %[[VSCALE_IDX_1:.*]] = builtin.unrealized_conversion_cast %[[VSCALE_1]] : i64 to index -// CHECK-NEXT: %[[SVL_B_1:.*]] = arith.muli %[[VSCALE_IDX_1]], %[[MIN_SVL_B]] : index -// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 -// CHECK-NEXT: %[[SVL_B_I64:.*]] = arith.index_castui %[[SVL_B_1]] : index to i64 -// CHECK-NEXT: %[[TILE_SLICE_IDX:.*]] = arith.muli %[[TILE_SLICE_I64]], %[[SVL_B_I64]] : i64 +// CHECK-NEXT: %[[TILE_SLICE_IDX:.*]] = arith.muli %[[TILE_SLICE]], %[[SVL_B]] : index +// CHECK-NEXT: %[[TILE_SLICE_IDX_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE_IDX]] : index to i64 // CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> -// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[TILE_SLICE_IDX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[TILE_SLICE_IDX_I64]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 // CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 // CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32 // CHECK-NEXT: 
"arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () @@ -218,17 +218,19 @@ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index +// CHECK-DAG: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64 // CHECK-DAG: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[16]xi1> // CHECK-DAG: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] { +// CHECK: %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64 // CHECK-NEXT: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8 -// CHECK: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 // CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 -// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 // CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 // CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32 // CHECK-NEXT: 
"arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () Index: mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir =================================================================== --- mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir +++ mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir @@ -5,16 +5,30 @@ // RUN: mlir-translate -mlir-to-llvmir | \ // RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \ // RUN: --entry-function=za0_d_f64 \ -// RUN: --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext \ +// RUN: --dlopen=%mlir_native_utils_lib_dir/libmlir_runner_utils%shlibext | \ // RUN: FileCheck %s --check-prefix=CHECK-ZA0_D -// Integration test demonstrating load/store to/from SME ZA tile. +// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \ +// RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ +// RUN: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \ +// RUN: -allocate-arm-sme-tiles -test-lower-to-llvm | \ +// RUN: mlir-translate -mlir-to-llvmir | \ +// RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \ +// RUN: --entry-function=load_store_two_za_s_tiles \ +// RUN: --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext \ +// RUN: --dlopen=%mlir_native_utils_lib_dir/libmlir_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// Integration tests demonstrating load/store to/from SME ZA tiles.
llvm.func @printF64(f64) +llvm.func @printI64(i64) llvm.func @printOpen() llvm.func @printClose() llvm.func @printComma() llvm.func @printNewline() +llvm.func @printCString(!llvm.ptr) func.func @za0_d_f64() -> i32 { %c0 = arith.constant 0 : index @@ -191,3 +205,172 @@ %c0_i32 = arith.constant 0 : i32 return %c0_i32 : i32 } + +func.func @printTileBegin() { + %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr> + %1 = llvm.mlir.constant(0 : index) : i64 + %2 = llvm.getelementptr %0[%1, %1] + : (!llvm.ptr>, i64, i64) -> !llvm.ptr + llvm.call @printCString(%2) : (!llvm.ptr) -> () + return +} + +func.func @printTileEnd() { + %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr> + %1 = llvm.mlir.constant(0 : index) : i64 + %2 = llvm.getelementptr %0[%1, %1] + : (!llvm.ptr>, i64, i64) -> !llvm.ptr + llvm.call @printCString(%2) : (!llvm.ptr) -> () + return +} + +// This test loads two 32-bit element ZA tiles from memory and stores them back +// to memory in reverse order to verify the memref indices. +func.func @load_store_two_za_s_tiles() -> i32 { + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c1_index = arith.constant 1 : index + %c2_index = arith.constant 2 : index + + %min_elts_s = arith.constant 4 : index + %vscale = vector.vscale + + // "svl" refers to the Streaming Vector Length and "svl_s" can mean either: + // * the number of 32-bit elements in a vector of SVL bits. + // * the number of tile slices (1d vectors) in a 32-bit element tile. + %svl_s = arith.muli %min_elts_s, %vscale : index + + // Allocate memory for two 32-bit element tiles. + %tilesize = arith.muli %svl_s, %svl_s : index + %size = arith.muli %tilesize, %c2_index : index + %mem1 = memref.alloca(%size) : memref + + // Fill the memory tile 1 will be loaded from with '1', and the memory tile 2 will be loaded from with '2'. 
+ // + // For example, assuming an SVL of 128-bits and two 4x4xi32 tiles: + // + // tile 1 + // + // 1, 1, 1, 1 + // 1, 1, 1, 1 + // 1, 1, 1, 1 + // 1, 1, 1, 1 + // + // tile 2 + // + // 2, 2, 2, 2 + // 2, 2, 2, 2 + // 2, 2, 2, 2 + // 2, 2, 2, 2 + // + scf.for %i = %c0 to %size step %svl_s { + %isFirstTile = arith.cmpi ult, %i, %tilesize : index + %val = scf.if %isFirstTile -> i32 { + scf.yield %c1_i32 : i32 + } else { + scf.yield %c2_i32 : i32 + } + %splat_val = vector.broadcast %val : i32 to vector<[4]xi32> + vector.store %splat_val, %mem1[%i] : memref, vector<[4]xi32> + } + + // Dump "mem1". The smallest SVL is 128-bits so each tile will be at least + // 4x4xi32. + // + // CHECK: ( 1, 1, 1, 1 + // CHECK-NEXT: ( 1, 1, 1, 1 + // CHECK-NEXT: ( 1, 1, 1, 1 + // CHECK-NEXT: ( 1, 1, 1, 1 + // CHECK: ( 2, 2, 2, 2 + // CHECK-NEXT: ( 2, 2, 2, 2 + // CHECK-NEXT: ( 2, 2, 2, 2 + // CHECK-NEXT: ( 2, 2, 2, 2 + scf.for %i = %c0 to %size step %svl_s { + %tileslice = vector.load %mem1[%i] : memref, vector<[4]xi32> + + llvm.call @printOpen() : () -> () + scf.for %i2 = %c0 to %svl_s step %c1_index { + %elem = vector.extractelement %tileslice[%i2 : index] : vector<[4]xi32> + %elem_i64 = llvm.zext %elem : i32 to i64 + llvm.call @printI64(%elem_i64) : (i64) -> () + %last_i = arith.subi %svl_s, %c1_index : index + %isNotLastIter = arith.cmpi ult, %i2, %last_i : index + scf.if %isNotLastIter { + llvm.call @printComma() : () -> () + } + } + llvm.call @printClose() : () -> () + llvm.call @printNewline() : () -> () + } + + // Load tile 1 from memory + %za0_s = vector.load %mem1[%c0] : memref, vector<[4]x[4]xi32> + + // Load tile 2 from memory + %za1_s = vector.load %mem1[%tilesize] : memref, vector<[4]x[4]xi32> + + // Allocate new memory to store the tiles to + %mem2 = memref.alloca(%size) : memref + + // Zero new memory + scf.for %i = %c0 to %size step %c1_index { + memref.store %c0_i32, %mem2[%i] : memref + } + + // Store the tiles back to (new) memory in reverse order + + // Store tile 2 to 
memory + vector.store %za1_s, %mem2[%c0] : memref, vector<[4]x[4]xi32> + + // Store tile 1 to memory + vector.store %za0_s, %mem2[%tilesize] : memref, vector<[4]x[4]xi32> + + // Dump "mem2" and check the tiles were stored in reverse order. The smallest + // SVL is 128-bits so the tiles will be at least 4x4xi32. + // + // CHECK: TILE BEGIN + // CHECK-NEXT: ( 2, 2, 2, 2 + // CHECK-NEXT: ( 2, 2, 2, 2 + // CHECK-NEXT: ( 2, 2, 2, 2 + // CHECK-NEXT: ( 2, 2, 2, 2 + // CHECK: TILE END + // CHECK-NEXT: TILE BEGIN + // CHECK-NEXT: ( 1, 1, 1, 1 + // CHECK-NEXT: ( 1, 1, 1, 1 + // CHECK-NEXT: ( 1, 1, 1, 1 + // CHECK-NEXT: ( 1, 1, 1, 1 + // CHECK: TILE END + func.call @printTileBegin() : () -> () + scf.for %i = %c0 to %size step %svl_s { + %av = vector.load %mem2[%i] : memref, vector<[4]xi32> + + llvm.call @printOpen() : () -> () + scf.for %i2 = %c0 to %svl_s step %c1_index { + %elem = vector.extractelement %av[%i2 : index] : vector<[4]xi32> + %elem_i64 = llvm.zext %elem : i32 to i64 + llvm.call @printI64(%elem_i64) : (i64) -> () + %last_i = arith.subi %svl_s, %c1_index : index + %isNotLastIter = arith.cmpi ult, %i2, %last_i : index + scf.if %isNotLastIter { + llvm.call @printComma() : () -> () + } + } + llvm.call @printClose() : () -> () + llvm.call @printNewline() : () -> () + + %tileSizeMinusStep = arith.subi %tilesize, %svl_s : index + %isNextTile = arith.cmpi eq, %i, %tileSizeMinusStep : index + scf.if %isNextTile { + func.call @printTileEnd() : () -> () + func.call @printTileBegin() : () -> () + } + } + func.call @printTileEnd() : () -> () + + return %c0_i32 : i32 +} + +llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A") +llvm.mlir.global internal constant @str_tile_end("TILE END\0A")