diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -26,6 +26,28 @@ using namespace mlir;
 
 namespace {
 
+/// Adjusts `indices` as follows for a given tile slice and returns them in
+/// `outIndices`:
+///   rank 1: (indices[0] + (tileSliceIndex * tileSliceNumElts))
+///   rank 2: (indices[0] + tileSliceIndex, indices[1])
+void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
+                      Value tileSliceNumElts,
+                      SmallVectorImpl<Value> &outIndices, Location loc,
+                      PatternRewriter &rewriter) {
+  assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
+
+  auto tileSliceOffset = tileSliceIndex;
+  if (rank == 1)
+    tileSliceOffset =
+        rewriter.create<arith::MulIOp>(loc, tileSliceOffset, tileSliceNumElts);
+
+  auto baseIndexPlusTileSliceOffset =
+      rewriter.create<arith::AddIOp>(loc, indices[0], tileSliceOffset);
+  outIndices.push_back(baseIndexPlusTileSliceOffset);
+
+  if (rank == 2)
+    outIndices.push_back(indices[1]);
+}
+
 /// Lower `arm_sme.tile_load` to a loop over the tile slices and load each slice
 /// using `arm_sme.load_tile_slice`.
@@ -77,6 +99,9 @@
   auto vscale =
       rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
   auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  // This describes both the number of ZA tile slices and the number of
+  // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
+  // ..., SVL_Q).
   auto numTileSlices =
       rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
   auto forOp =
@@ -84,13 +109,16 @@
 
   rewriter.setInsertionPointToStart(forOp.getBody());
 
+  // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
+  // tile.
+  SmallVector<Value> memrefIndices;
   auto tileSliceIndex = forOp.getInductionVar();
-  // TODO: use indices
-  // Create 'arm_sme.load_tile_slice' to load tile slice from
-  // memory into tile.
-  rewriter.create<arm_sme::LoadTileSliceOp>(
-      loc, tileType, tileLoadOp.getBase(), tile, tileSliceIndex,
-      tileSliceIndex);
+  getMemrefIndices(tileLoadOp.getIndices(),
+                   tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
+                   numTileSlices, memrefIndices, loc, rewriter);
+  rewriter.create<arm_sme::LoadTileSliceOp>(loc, tileType,
+                                            tileLoadOp.getBase(), tile,
+                                            memrefIndices, tileSliceIndex);
 
   rewriter.setInsertionPointAfter(forOp);
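As a sketch of what `getMemrefIndices` materializes per tile slice at the `arith` level (illustrative SSA names, not lines taken from this patch):

```mlir
// rank 1: outIndices = (indices[0] + tileSliceIndex * tileSliceNumElts)
%slice_offset = arith.muli %tile_slice, %num_tile_slices : index
%idx0 = arith.addi %base0, %slice_offset : index

// rank 2: outIndices = (indices[0] + tileSliceIndex, indices[1])
%row = arith.addi %base0, %tile_slice : index
// the second index (%base1) is passed through unchanged
```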
@@ -139,6 +167,9 @@
   auto vscale =
       rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
   auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  // This describes both the number of ZA tile slices and the number of
+  // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
+  // ..., SVL_Q).
   auto numTileSlices =
       rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
   auto forOp =
@@ -146,11 +177,14 @@
 
   rewriter.setInsertionPointToStart(forOp.getBody());
 
+  SmallVector<Value> memrefIndices;
   auto tileSliceIndex = forOp.getInductionVar();
-  // TODO: use indices
+  getMemrefIndices(tileStoreOp.getIndices(),
+                   tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
+                   numTileSlices, memrefIndices, loc, rewriter);
   rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
       tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
-      tileStoreOp.getBase(), tileSliceIndex);
+      tileStoreOp.getBase(), memrefIndices);
 
   return success();
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -111,33 +111,6 @@
   return tile;
 }
 
-/// Returns the following
-/// * for rank 2 memrefs `tileSliceIndex`, since `getStridedElementPtr` does
-///   the arithmetic.
-/// * for rank 1 memrefs `tileSliceIndex * tileSliceNumElts`, adjusting the
-///   index by the number of elements in a vector of SVL bits.
-/// * otherwise throws an unreachable error.
-Value getTileSlicePtrIndex(unsigned rank, Value tileSliceIndex,
-                           Value tileSliceNumElts, Location loc,
-                           ConversionPatternRewriter &rewriter) {
-  assert((rank == 1 || rank == 2) && "memref has unexpected rank!");
-
-  auto tileSliceIndexI64 = rewriter.create<arith::IndexCastUIOp>(
-      loc, rewriter.getI64Type(), tileSliceIndex);
-
-  if (rank == 1) {
-    auto tileSliceNumEltsI64 = rewriter.create<arith::IndexCastUIOp>(
-        loc, rewriter.getI64Type(), tileSliceNumElts);
-    return rewriter.create<arith::MulIOp>(loc, tileSliceIndexI64,
-                                          tileSliceNumEltsI64);
-  }
-
-  if (rank == 2)
-    return tileSliceIndexI64;
-
-  llvm_unreachable("memref has unexpected rank!");
-}
-
 /// Lower `arm_sme.load_tile_slice` to SME intrinsics.
 struct LoadTileSliceToArmSMELowering
     : public ConvertOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
@@ -159,25 +132,11 @@
         loc, rewriter.getIntegerType(tileElementWidth),
         loadTileSliceOp.getTile());
 
-    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
-        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
-    auto vscale =
-        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
-    // This describes both the number of ZA tile slices and the number of
-    // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
-    // ..., SVL_Q).
-    auto numTileSlices =
-        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+    Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
+                                           adaptor.getBase(),
+                                           adaptor.getIndices(), rewriter);
 
-    // Create 'arm_sme.intr.ld1*.horiz' intrinsic to load ZA tile slice.
-    auto memRefType = loadTileSliceOp.getMemRefType();
     auto tileSlice = loadTileSliceOp.getTileSliceIndex();
-    // TODO: The 'indices' argument for the 'base' memref is currently ignored,
-    // 'tileSliceIndex' should be added to 'indices[0]'.
-    Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice,
-                                                numTileSlices, loc, rewriter);
-    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.getBase(),
-                                           {tileSliceIndex}, rewriter);
 
     // Cast tile slice to i32 for intrinsic.
     auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
@@ -192,6 +151,7 @@
     auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
 
     auto tileI32 = castTileIDToI32(tile, loc, rewriter);
+    // Create 'arm_sme.intr.ld1*.horiz' intrinsic to load ZA tile slice.
     switch (tileElementWidth) {
     default:
       llvm_unreachable("unexpected element type!");
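For context, `getStridedElementPtr` emits the address computation the deleted helper used to do by hand. For a rank-2 memref the result looks roughly like the following at the LLVM dialect level (a sketch mirroring the CHECK lines in the tests further down; SSA names are illustrative):

```mlir
%base    = llvm.extractvalue %desc[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%stride0 = llvm.extractvalue %desc[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%off0 = llvm.mul %idx0, %stride0 : i64
%off1 = llvm.add %off0, %idx1 : i64
%ptr  = llvm.getelementptr %base[%off1] : (!llvm.ptr, i64) -> !llvm.ptr, i8
```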
@@ -243,25 +203,12 @@
         loc, rewriter.getIntegerType(tileElementWidth),
         storeTileSliceOp.getTile());
 
-    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
-        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
-    auto vscale =
-        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
-    // This describes both the number of ZA tile slices and the number of
-    // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
-    // ..., SVL_Q).
-    auto numTileSlices =
-        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-
-    // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
-    auto memRefType = storeTileSliceOp.getMemRefType();
+    Value ptr = this->getStridedElementPtr(
+        loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
+        adaptor.getIndices(), rewriter);
+
     auto tileSlice = storeTileSliceOp.getTileSliceIndex();
-    // TODO: The 'indices' argument for the 'base' memref is currently ignored,
-    // 'tileSliceIndex' should be added to 'indices[0]'.
-    Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice,
-                                                numTileSlices, loc, rewriter);
-    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.getBase(),
-                                           {tileSliceIndex}, rewriter);
 
     // Cast tile slice to i32 for intrinsic.
     auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -2,15 +2,16 @@
 
 // CHECK-LABEL: func.func @arm_sme_tile_load(
 // CHECK-SAME:    %[[SRC:.*]]: memref<?x?xi32>) {
-// CHECK-NEXT:    %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-NEXT:    %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG:     %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-DAG:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
 // CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
 // CHECK-NEXT:    %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
 // CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[TILE_SLICE_INDEX]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
 func.func @arm_sme_tile_load(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
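The overall shape of the new lowering, sketched for a hypothetical `arm_sme.tile_load` with non-zero indices (`%row`, `%col`, `%tile_vec`, and `%num_tile_slices` are illustrative; the test above uses `%c0` for both indices):

```mlir
%tile = arm_sme.tile_load %src[%row, %col] : memref<?x?xi32>, vector<[4]x[4]xi32>
// becomes, roughly:
scf.for %slice = %c0 to %num_tile_slices step %c1 {
  %off = arith.addi %row, %slice : index
  arm_sme.load_tile_slice %src[%off, %col], %tile_vec, %slice
    : memref<?x?xi32>, vector<[4]x[4]xi32>
}
```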
@@ -28,7 +29,8 @@
 // CHECK-DAG:  %[[VSCALE:.*]] = vector.vscale
 // CHECK:      %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
 // CHECK:      scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK:        arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[TILE_SLICE_INDEX]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK:        %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK:        arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
 func.func @arm_sme_tile_store(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -8,17 +8,19 @@
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
 // CHECK-DAG:  %[[C255:.*]] = arith.constant 255 : i32
 // CHECK-DAG:  %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK-DAG:  %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
 // CHECK-DAG:  "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
 // CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
 // CHECK-DAG:  %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
-// CHECK:        %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK:        %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
-// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT:   %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]] : i64
+// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
 // CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
 // CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
@@ -31,32 +33,41 @@
 
 // -----
 
-// CHECK-LABEL: @vector_load_i8(
-// CHECK-SAME:   %[[ARG0:.*]]: memref<?x?xi8>)
+// Load an 8-bit tile from a rank 2 memref with a non-zero offset for the first
+// memref index. This verifies the offset is preserved when materializing the
+// loop of tile slice loads.
+
+// CHECK-LABEL: @vector_load_i8_with_offset(
+// CHECK-SAME:   %[[ARG0:.*]]: memref<?x?xi8>)
 // CHECK-DAG:  %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
 // CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:  %[[C123:.*]] = arith.constant 123 : index
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
 // CHECK-DAG:  %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK-DAG:  %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
 // CHECK-DAG:  %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
-// CHECK:        %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK-NEXT:   %[[TILE_SLICE_PLUS_OFF0:.*]] = arith.addi %[[TILE_SLICE]], %[[C123]] : index
+// CHECK-NEXT:   %[[TILE_SLICE_PLUS_OFF0_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE_PLUS_OFF0]] : index to i64
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
-// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_PLUS_OFF0_I64]], %[[STRIDE0]] : i64
+// CHECK-NEXT:   %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]] : i64
+// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
 // CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
 // CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
-func.func @vector_load_i8(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
+func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
   %c0 = arith.constant 0 : index
-  %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %c123 = arith.constant 123 : index
+  %tile = vector.load %arg0[%c123, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
   return %tile : vector<[16]x[16]xi8>
 }
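Putting the per-slice address computation together for the offset load above (a sketch that simply chains the CHECK lines; SSA names are illustrative, and the `addi` operand order reflects what this test's pipeline produces):

```mlir
%row     = arith.addi %tile_slice, %c123 : index
%row_i64 = builtin.unrealized_conversion_cast %row : index to i64
%off0    = llvm.mul %row_i64, %stride0 : i64
%off1    = llvm.add %off0, %c0_i64 : i64
%ptr     = llvm.getelementptr %aligned_base[%off1] : (!llvm.ptr, i64) -> !llvm.ptr, i8
"arm_sme.intr.ld1b.horiz"(%ptrue, %ptr, %tile_id_i32, %slice_i32)
  : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
```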
@@ -75,14 +86,10 @@
 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
-// CHECK-NEXT:   %[[VSCALE_1:.*]] = "llvm.intr.vscale"() : () -> i64
-// CHECK-NEXT:   %[[VSCALE_IDX_1:.*]] = builtin.unrealized_conversion_cast %[[VSCALE_1]] : i64 to index
-// CHECK-NEXT:   %[[SVL_B_1:.*]] = arith.muli %[[VSCALE_IDX_1]], %[[MIN_SVL_B]] : index
-// CHECK-NEXT:   %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
-// CHECK-NEXT:   %[[SVL_B_I64:.*]] = arith.index_castui %[[SVL_B_1]] : index to i64
-// CHECK-NEXT:   %[[TILE_SLICE_IDX:.*]] = arith.muli %[[TILE_SLICE_I64]], %[[SVL_B_I64]] : i64
+// CHECK-NEXT:   %[[TILE_SLICE_IDX:.*]] = arith.muli %[[TILE_SLICE]], %[[SVL_B]] : index
+// CHECK-NEXT:   %[[TILE_SLICE_IDX_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE_IDX]] : index to i64
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[TILE_SLICE_IDX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[TILE_SLICE_IDX_I64]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
 // CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
 // CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
@@ -218,17 +225,19 @@
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
+// CHECK-DAG:  %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
 // CHECK-DAG:  %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
 // CHECK-DAG:  %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
+// CHECK:        %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64
 // CHECK-NEXT:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
-// CHECK:        %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
-// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT:   %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]] : i64
+// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
 // CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
 // CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
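For the rank-1 case, the simplification is that the slice offset is now computed once in `index` arithmetic and converted a single time, instead of recomputing vscale in the loop and casting each operand to i64 separately (a sketch with illustrative names):

```mlir
%idx     = arith.muli %tile_slice, %svl_b : index
%idx_i64 = builtin.unrealized_conversion_cast %idx : index to i64
%gep     = llvm.getelementptr %aligned_base[%idx_i64] : (!llvm.ptr, i64) -> !llvm.ptr, i8
```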
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -1,21 +1,31 @@
-// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
-// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// RUN:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
-// RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
-// RUN: mlir-translate -mlir-to-llvmir | \
-// RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \
-// RUN:   --entry-function=za0_d_f64 \
-// RUN:   --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s --check-prefix=CHECK-ZA0_D
-
-// Integration test demonstrating load/store to/from SME ZA tile.
+// DEFINE: %{entry_point} = za0_d_f64
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:   -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:   -e %{entry_point} -entry-point-result=i32 \
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=CHECK-ZA0_D
+
+// REDEFINE: %{entry_point} = load_store_two_za_s_tiles
+// RUN: %{compile} | %{run} | FileCheck %s
+
+// Integration tests demonstrating load/store to/from SME ZA tile.
 
 llvm.func @printF64(f64)
+llvm.func @printI64(i64)
 llvm.func @printOpen()
 llvm.func @printClose()
 llvm.func @printComma()
 llvm.func @printNewline()
+llvm.func @printCString(!llvm.ptr<i8>)
 
+// This test verifies a 64-bit element ZA with FP64 data is correctly
+// loaded/stored to/from memory.
 func.func @za0_d_f64() -> i32 {
   %c0 = arith.constant 0 : index
   %c0_f64 = arith.constant 0.0 : f64
@@ -191,3 +201,174 @@
   %c0_i32 = arith.constant 0 : i32
   return %c0_i32 : i32
 }
+
+func.func @printTileBegin() {
+  %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
+  %1 = llvm.mlir.constant(0 : index) : i64
+  %2 = llvm.getelementptr %0[%1, %1]
+    : (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+func.func @printTileEnd() {
+  %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
+  %1 = llvm.mlir.constant(0 : index) : i64
+  %2 = llvm.getelementptr %0[%1, %1]
+    : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+// This test loads two 32-bit element ZA tiles from memory and stores them back
+// to memory in reverse order. This verifies the memref indices for the vector
+// load and store are correctly preserved since the second tile is offset from
+// the first tile.
+func.func @load_store_two_za_s_tiles() -> i32 {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c2_i32 = arith.constant 2 : i32
+  %c1_index = arith.constant 1 : index
+  %c2_index = arith.constant 2 : index
+
+  %min_elts_s = arith.constant 4 : index
+  %vscale = vector.vscale
+
+  // "svl" refers to the Streaming Vector Length and "svl_s" can mean either:
+  // * the number of 32-bit elements in a vector of SVL bits.
+  // * the number of tile slices (1d vectors) in a 32-bit element tile.
+  %svl_s = arith.muli %min_elts_s, %vscale : index
+
+  // Allocate memory for two 32-bit element tiles.
+  %size_of_tile = arith.muli %svl_s, %svl_s : index
+  %size_of_two_tiles = arith.muli %size_of_tile, %c2_index : index
+  %mem1 = memref.alloca(%size_of_two_tiles) : memref<?xi32>
+
+  // Fill the memory tile 1 will be loaded from with '1' and the memory tile 2
+  // will be loaded from with '2'.
+  //
+  // For example, assuming an SVL of 128-bits and two 4x4xi32 tiles:
+  //
+  // tile 1
+  //
+  //   1, 1, 1, 1
+  //   1, 1, 1, 1
+  //   1, 1, 1, 1
+  //   1, 1, 1, 1
+  //
+  // tile 2
+  //
+  //   2, 2, 2, 2
+  //   2, 2, 2, 2
+  //   2, 2, 2, 2
+  //   2, 2, 2, 2
+  //
+  scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
+    %isFirstTile = arith.cmpi ult, %i, %size_of_tile : index
+    %val = scf.if %isFirstTile -> i32 {
+      scf.yield %c1_i32 : i32
+    } else {
+      scf.yield %c2_i32 : i32
+    }
+    %splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
+    vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+  }
+
+  // Dump "mem1". The smallest SVL is 128-bits so each tile will be at least
+  // 4x4xi32.
+  //
+  // CHECK:      ( 1, 1, 1, 1
+  // CHECK-NEXT: ( 1, 1, 1, 1
+  // CHECK-NEXT: ( 1, 1, 1, 1
+  // CHECK-NEXT: ( 1, 1, 1, 1
+  // CHECK:      ( 2, 2, 2, 2
+  // CHECK-NEXT: ( 2, 2, 2, 2
+  // CHECK-NEXT: ( 2, 2, 2, 2
+  // CHECK-NEXT: ( 2, 2, 2, 2
+  scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
+    %tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+
+    llvm.call @printOpen() : () -> ()
+    scf.for %i2 = %c0 to %svl_s step %c1_index {
+      %elem = vector.extractelement %tileslice[%i2 : index] : vector<[4]xi32>
+      %elem_i64 = llvm.zext %elem : i32 to i64
+      llvm.call @printI64(%elem_i64) : (i64) -> ()
+      %last_i = arith.subi %svl_s, %c1_index : index
+      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
+      scf.if %isNotLastIter {
+        llvm.call @printComma() : () -> ()
+      }
+    }
+    llvm.call @printClose() : () -> ()
+    llvm.call @printNewline() : () -> ()
+  }
+
+  // Load tile 1 from memory.
+  %za0_s = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+  // Load tile 2 from memory.
+  %za1_s = vector.load %mem1[%size_of_tile] : memref<?xi32>, vector<[4]x[4]xi32>
+
+  // Allocate new memory to store the tiles to.
+  %mem2 = memref.alloca(%size_of_two_tiles) : memref<?xi32>
+
+  // Zero the new memory.
+  scf.for %i = %c0 to %size_of_two_tiles step %c1_index {
+    memref.store %c0_i32, %mem2[%i] : memref<?xi32>
+  }
+
+  // Store the tiles back to (new) memory in reverse order.
+
+  // Store tile 2 to memory.
+  vector.store %za1_s, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+  // Store tile 1 to memory.
+  vector.store %za0_s, %mem2[%size_of_tile] : memref<?xi32>, vector<[4]x[4]xi32>
+
+  // Dump "mem2" and check the tiles were stored in reverse order. The smallest
+  // SVL is 128-bits so the tiles will be at least 4x4xi32.
+  //
+  // CHECK:      TILE BEGIN
+  // CHECK-NEXT: ( 2, 2, 2, 2
+  // CHECK-NEXT: ( 2, 2, 2, 2
+  // CHECK-NEXT: ( 2, 2, 2, 2
+  // CHECK-NEXT: ( 2, 2, 2, 2
+  // CHECK:      TILE END
+  // CHECK-NEXT: TILE BEGIN
+  // CHECK-NEXT: ( 1, 1, 1, 1
+  // CHECK-NEXT: ( 1, 1, 1, 1
+  // CHECK-NEXT: ( 1, 1, 1, 1
+  // CHECK-NEXT: ( 1, 1, 1, 1
+  // CHECK:      TILE END
+  func.call @printTileBegin() : () -> ()
+  scf.for %i = %c0 to %size_of_two_tiles step %svl_s {
+    %av = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
+
+    llvm.call @printOpen() : () -> ()
+    scf.for %i2 = %c0 to %svl_s step %c1_index {
+      %elem = vector.extractelement %av[%i2 : index] : vector<[4]xi32>
+      %elem_i64 = llvm.zext %elem : i32 to i64
+      llvm.call @printI64(%elem_i64) : (i64) -> ()
+      %last_i = arith.subi %svl_s, %c1_index : index
+      %isNotLastIter = arith.cmpi ult, %i2, %last_i : index
+      scf.if %isNotLastIter {
+        llvm.call @printComma() : () -> ()
+      }
+    }
+    llvm.call @printClose() : () -> ()
+    llvm.call @printNewline() : () -> ()
+
+    %tileSizeMinusStep = arith.subi %size_of_tile, %svl_s : index
+    %isNextTile = arith.cmpi eq, %i, %tileSizeMinusStep : index
+    scf.if %isNextTile {
+      func.call @printTileEnd() : () -> ()
+      func.call @printTileBegin() : () -> ()
+    }
+  }
+  func.call @printTileEnd() : () -> ()
+
+  return %c0_i32 : i32
+}
+
+llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
+llvm.mlir.global internal constant @str_tile_end("TILE END\0A")
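For reference, this integration test exercises the new indexing because tile 2 lives at linear offset `%size_of_tile` (= `svl_s * svl_s`) into the rank-1 buffer, so each of its slices must be addressed at base + slice * svl_s. A sketch of how the second tile's load lowers, assuming the slice-loop form shown earlier (names are illustrative, and the rank-1 `load_tile_slice` syntax is assumed by analogy with the rank-2 tests):

```mlir
scf.for %slice = %c0 to %svl_s step %c1 {
  %off = arith.muli %slice, %svl_s : index
  %idx = arith.addi %size_of_tile, %off : index
  arm_sme.load_tile_slice %mem1[%idx], %tile_vec, %slice
    : memref<?xi32>, vector<[4]x[4]xi32>
}
```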