Index: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td =================================================================== --- mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td +++ mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td @@ -224,6 +224,35 @@ let assemblyFormat = "attr-dict `:` type($res)"; } +def TileLoadOp : ArmSME_Op<"tile_load"> { + let summary = "Tile load operation"; + let description = [{ + Load a 2D SME "virtual tile" from memory. + + Example: + + ```mlir + %tile = arm_sme.tile_load %base[%c0, %c0] : memref, vector<[16]x[16]xi8> + ``` + }]; + let arguments = (ins + Arg:$base, + Variadic:$indices); + let results = (outs SMETile:$result); + + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return ::llvm::cast(getBase().getType()); + } + VectorType getVectorType() { + return ::llvm::cast(getResult().getType()); + } + }]; + + let assemblyFormat = + "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)"; +} + def TileStoreOp : ArmSME_Op<"tile_store"> { let summary = "Tile store operation"; let description = [{ @@ -238,7 +267,7 @@ arm_sme.tile_store %0, %arg0[%c0, %c0] : vector<[16]x[16]xi8>, memref ``` }]; - let arguments = (ins nxnxv16i8:$valueToStore, + let arguments = (ins SMETile:$valueToStore, Arg:$base, Variadic:$indices); let extraClassDeclaration = [{ @@ -301,10 +330,14 @@ def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">; // Loads +// +// Like all of the other SME intrinsics, loads have no result, use MemWrite +// side-effect to indicate writing to "ZA" resource to prevent the intrinsics +// from being removed by dead-code elimination. class ArmSME_IntrLoadOp : ArmSME_IntrOp, Arguments<(ins Arg, - Arg, + Arg, Arg, Arg)>; Index: mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h @@ -0,0 +1,38 @@ +//===- Utils.h - General ArmSME transformation utilities --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes for various transformation utilities for +// the ArmSME dialect. These are not passes by themselves but are used +// either by passes, optimization sequences, or in turn by other transformation +// utilities. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARMSME_UTILS_UTILS_H_ +#define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_ + +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" + +namespace mlir { +namespace arm_sme { + +/// Return minimum number of elements for the given element `type` in +/// a vector of SVL bits. +unsigned getSMETileSliceMinNumElts(Type type); + +/// Returns true if `type` is a valid element type for an SME tile or false +/// otherwise. +bool isValidTileElementType(Type type); + +/// Returns true if vector type `vType` is SME tile-like or false otherwise. 
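+/// A tile-like vector type has rank 2 with both dimensions scalable, and its
+/// shape equals the minimum tile shape for the element type (for example
+/// vector<[16]x[16]xi8> or vector<[4]x[4]xi32> with a 128-bit minimum SVL).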
+bool isSMETileLikeVectorType(VectorType vType); + +} // namespace arm_sme +} // namespace mlir + +#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_ Index: mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt =================================================================== --- mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt +++ mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt @@ -10,5 +10,6 @@ LINK_LIBS PUBLIC MLIRArmSMEDialect + MLIRArmSMEUtils MLIRLLVMCommonConversion ) Index: mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp =================================================================== --- mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/ArmSME/Utils/Utils.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/Support/Casting.h" @@ -76,9 +77,55 @@ } }; +/// Overloaded utility that replaces a vector.load or vector.store with their +/// respective SME counterparts. +static void replaceLoadOrStoreOp(vector::LoadOp load, + PatternRewriter &rewriter) { + rewriter.replaceOpWithNewOp( + load, load.getVectorType(), load.getBase(), load.getIndices()); +} + +static void replaceLoadOrStoreOp(vector::StoreOp store, + PatternRewriter &rewriter) { + rewriter.replaceOpWithNewOp( + store, store.getValueToStore(), store.getBase(), store.getIndices()); +} + +/// Conversion pattern for vector.load and vector.store. +template +struct VectorLoadStoreToArmSMELowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LoadOrStoreOp loadOrStoreOp, + PatternRewriter &rewriter) const override { + if (!arm_sme::isSMETileLikeVectorType(loadOrStoreOp.getVectorType())) + return failure(); + + replaceLoadOrStoreOp(loadOrStoreOp, rewriter); + + return success(); + } +}; + +/// Conversion pattern for vector.load. +struct VectorLoadToArmSMELowering + : public VectorLoadStoreToArmSMELowering { + using VectorLoadStoreToArmSMELowering< + vector::LoadOp>::VectorLoadStoreToArmSMELowering; +}; + +/// Conversion pattern for vector.store. 
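+///
+/// A minimal sketch of the rewrite (types chosen for illustration): a
+/// vector.store of an SME tile-like vector such as
+///
+///   vector.store %tile, %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+///
+/// is replaced with
+///
+///   arm_sme.tile_store %tile, %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>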
+struct VectorStoreToArmSMELowering + : public VectorLoadStoreToArmSMELowering { + using VectorLoadStoreToArmSMELowering< + vector::StoreOp>::VectorLoadStoreToArmSMELowering; +}; + } // namespace void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx) { - patterns.add(&ctx); + patterns.add(&ctx); } Index: mlir/lib/Dialect/ArmSME/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/ArmSME/CMakeLists.txt +++ mlir/lib/Dialect/ArmSME/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(Utils) Index: mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt +++ mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt @@ -12,6 +12,7 @@ LINK_LIBS PUBLIC MLIRArmSMEDialect + MLIRArmSMEUtils MLIRFuncDialect MLIRLLVMCommonConversion MLIRVectorDialect Index: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp =================================================================== --- mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp +++ mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/ArmSME/Transforms/Transforms.h" +#include "mlir/Dialect/ArmSME/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -19,7 +20,6 @@ using namespace mlir; using namespace mlir::arm_sme; -static constexpr unsigned kMinNumElts = 16; static constexpr unsigned kZeroZAMask = 255; namespace { @@ -50,7 +50,6 @@ return success(); } }; -} // namespace /// Lower 'arm_sme.zero'. Use 'arm_sme.cast_tile_to_vector' to model the return /// value. The latter is a nop, which should be folded away (e.g. during @@ -95,38 +94,95 @@ } }; -/// Lower 'arm_sme.store_tile' to a loop over the rows of ZA and store each row -/// using 'arm_sme.intr.str'. +/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or +/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar +/// integer, to an i32 that can be passed as the `tile` parameter to the SME +/// intrinsics. Or returns `tile` if already i32. +Value castTileIDToI32(Value tile, Location loc, + ConversionPatternRewriter &rewriter) { + assert((isa( + tile.getDefiningOp())) && + "expected ArmSME GetTileID or CastVectorToTile op!"); + unsigned width = tile.getType().getIntOrFloatBitWidth(); + if (width < 32) + return rewriter.create(loc, rewriter.getI32Type(), tile); + if (width > 32) + return rewriter.create(loc, rewriter.getI32Type(), tile); + return tile; +} + +/// Returns `offset` if memref is rank 2, otherwise adjusts `offset` by the +/// number of elements in a vector of SVL bits. 
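+///
+/// For a rank-1 memref the tile is assumed to be laid out contiguously, so the
+/// returned offset for tile slice `offset` is `offset * minElems * vscale`
+/// (one SVL-sized slice per row); for a rank-2 memref `offset` is returned
+/// unchanged and the row stride is applied later by `getStridedElementPtr`.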
+Value getOffset(MemRefType memRefType, Value offset, Value vscale, + Value minElems, Location loc, + ConversionPatternRewriter &rewriter) { + unsigned rank = memRefType.getRank(); + assert((rank == 1 || rank == 2) && "memref has unexpected rank!"); + + auto offsetI64 = + rewriter.create(loc, rewriter.getI64Type(), offset); + if (rank == 1) { + auto vscaleI64 = rewriter.create( + loc, rewriter.getI64Type(), vscale); + auto minElemsI64 = rewriter.create( + loc, rewriter.getI64Type(), minElems); + auto numElems = rewriter.create(loc, minElemsI64, vscaleI64); + return rewriter.create(loc, offsetI64, numElems); + } + + if (rank == 2) + return offsetI64; + + llvm_unreachable("memref has unexpected rank!"); +} + +/// Conversion pattern for `arm_sme.tile_load` to SME intrinsics. +/// +/// Lower `arm_sme.tile_load` to a loop over the rows of ZA and load each row +/// using `arm_sme.intr.ld1*.horiz`. /// /// BEFORE: /// ```mlir -/// arm_sme.tile_store %arg0[%c0, %c0], %0 : memref, -/// vector<[16]x[16]xi8 +/// %tile = arm_sme.tile_load %base[%c0, %c0] : +/// memref, vector<[4]x[4]xi32> /// ``` /// /// AFTER: /// ```mlir -/// %vscale = "llvm.intr.vscale"() : () -> index -/// %c0 = arith.constant 0 : index -/// %c1 = arith.constant 1 : index -/// %c16 = arith.constant 16 : index -/// %vec_size = arith.muli %c16, %vscale : index -/// scf.for %row_idx = %c0 to %vec_size step %c1 { -/// // (...) -/// "arm_sme.intr.str"(%row_idx, %addr) : (i32, !llvm.ptr) -> () +/// %tile_id = arm_sme.get_tile_id : i32 +/// %vscale = vector.vscale +/// %c0 = arith.constant 0 : index +/// %c1 = arith.constant 1 : index +/// %min_svl_s = arith.constant 4 : index +/// %num_vectors = arith.muli %min_svl_s, %vscale : index +/// scf.for %tile_slice = %c0 to %num_vectors step %c1 { +/// // (...) +/// "arm_sme.intr.ld1w.horiz"(%ptrue_s, %ptr, %tile_id, %tile_slice) : +/// (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () +/// } +/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32> /// ``` -struct TileStoreOpConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +struct TileLoadToArmSMELowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(TileStoreOp store, OpAdaptor adaptor, + matchAndRewrite(arm_sme::TileLoadOp tileLoadOp, + arm_sme::TileLoadOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = store.getLoc(); + auto vType = tileLoadOp.getVectorType(); + auto loc = tileLoadOp.getLoc(); + + // Create 'arm_sme.get_tile_id' op. + unsigned width = vType.getElementType().getIntOrFloatBitWidth(); + auto tile = rewriter.create( + loc, rewriter.getIntegerType(width)); - // Create loop that iterates from 0 to SVLB-1 inclusive (the number of - // vectors in ZA) and stores each ZA vector to memory. + // Create loop that iterates over the number of ZA vectors (0 to SVL-1 + // inclusive) and stores each ZA vector to memory. auto step = rewriter.create(loc, 1); - auto minElems = rewriter.create(loc, kMinNumElts); + auto minElems = rewriter.create( + loc, arm_sme::getSMETileSliceMinNumElts(vType.getElementType())); auto vscale = rewriter.create(loc, rewriter.getIndexType()); auto lowerBound = rewriter.create(loc, 0); @@ -134,29 +190,165 @@ auto forOp = rewriter.create(loc, lowerBound, upperBound, step); rewriter.setInsertionPointToStart(forOp.getBody()); - // Create 'arm_sme.intr.str' intrinsic to store ZA vector. 
- auto vnumI64 = rewriter.create( - loc, rewriter.getI64Type(), forOp.getInductionVar()); - auto offset = - rewriter.create(loc, rewriter.getI64Type(), 0); - Value ptr = - getStridedElementPtr(loc, store.getMemRefType(), adaptor.getBase(), - ValueRange{vnumI64, offset}, rewriter); - auto vnumI32 = rewriter.create( - loc, rewriter.getI32Type(), forOp.getInductionVar()); - rewriter.create(loc, vnumI32, ptr); - - rewriter.eraseOp(store); + // Create 'arm_sme.intr.ld1*.horiz' intrinsic to load ZA vector. + auto memRefType = tileLoadOp.getMemRefType(); + auto tileSlice = forOp.getInductionVar(); + Value offset = + getOffset(memRefType, tileSlice, vscale, minElems, loc, rewriter); + Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.getBase(), + {offset}, rewriter); + auto tileSliceI32 = rewriter.create( + loc, rewriter.getI32Type(), tileSlice); + auto one = rewriter.create( + loc, rewriter.getI1Type(), + rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); + auto predTy = VectorType::get(vType.getShape()[0], rewriter.getI1Type(), + /*scalableDims=*/{true}); + auto mask = rewriter.create(loc, predTy, one); + + auto tileI32 = castTileIDToI32(tile, loc, rewriter); + switch (width) { + default: + llvm_unreachable("unexpected element type!"); + case 8: + rewriter.create(loc, mask, ptr, tileI32, + tileSliceI32); + break; + case 16: + rewriter.create(loc, mask, ptr, tileI32, + tileSliceI32); + break; + case 32: + rewriter.create(loc, mask, ptr, tileI32, + tileSliceI32); + break; + case 64: + rewriter.create(loc, mask, ptr, tileI32, + tileSliceI32); + break; + } + + rewriter.setInsertionPointAfter(forOp); + + // The load intrinsics have no result, replace 'arm_sme.tile_load' with + // 'arm_sme.cast_tile_to_vector' to preserve dataflow. + rewriter.replaceOpWithNewOp(tileLoadOp, vType, + tile); + return success(); } }; +/// Conversion pattern for `arm_sme.tile_store` to SME intrinsics. +/// +/// Lower `arm_sme.tile_store` to a loop over the rows of ZA and store each row +/// using `arm_sme.intr.st1*.horiz`. +/// +/// BEFORE: +/// ```mlir +/// arm_sme.tile_store %value, %base[%c0, %c0] : memref, +/// vector<[4]x[4]xi32 +/// ``` +/// +/// AFTER: +/// ```mlir +/// %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32 +/// %vscale = vector.vscale +/// %c0 = arith.constant 0 : index +/// %c1 = arith.constant 1 : index +/// %min_svl_s = arith.constant 4 : index +/// %num_vectors = arith.muli %min_svl_s, %vscale : index +/// scf.for %tile_slice = %c0 to %num_vectors step %c1 { +/// // (...) +/// "arm_sme.intr.st1w.horiz"(%ptrue_s, %ptr, %tile_id, %tile_slice) : +/// (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () +/// } +/// ``` +struct TileStoreToArmSMELowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(arm_sme::TileStoreOp tileStoreOp, + arm_sme::TileStoreOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto vType = tileStoreOp.getVectorType(); + auto elemType = vType.getElementType(); + auto loc = tileStoreOp.getLoc(); + + unsigned width = elemType.getIntOrFloatBitWidth(); + // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the vector + // being stored. + auto tile = rewriter.create( + loc, rewriter.getIntegerType(width), tileStoreOp.getValueToStore()); + + // Create loop that iterates over the number of ZA vectors (0 to SVL-1 + // inclusive) and stores each ZA vector to memory. 
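+    // The loop trip count is `minElems * vscale`, i.e. the number of tile
+    // slices (rows) in the tile for the given element type; e.g. a
+    // vector<[4]x[4]xi32> tile is stored as `4 * vscale` row slices.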
+ auto step = rewriter.create(loc, 1); + auto minElems = rewriter.create( + loc, arm_sme::getSMETileSliceMinNumElts(elemType)); + auto vscale = + rewriter.create(loc, rewriter.getIndexType()); + auto lowerBound = rewriter.create(loc, 0); + auto upperBound = rewriter.create(loc, minElems, vscale); + auto forOp = rewriter.create(loc, lowerBound, upperBound, step); + rewriter.setInsertionPointToStart(forOp.getBody()); + + // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA vector. + auto memRefType = tileStoreOp.getMemRefType(); + auto tileSlice = forOp.getInductionVar(); + Value offset = + getOffset(memRefType, tileSlice, vscale, minElems, loc, rewriter); + Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.getBase(), + {offset}, rewriter); + auto tileSliceI32 = rewriter.create( + loc, rewriter.getI32Type(), tileSlice); + auto one = rewriter.create( + loc, rewriter.getI1Type(), + rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); + auto predTy = VectorType::get(vType.getShape()[0], rewriter.getI1Type(), + /*scalableDims=*/{true}); + auto mask = rewriter.create(loc, predTy, one); + + auto tileI32 = castTileIDToI32(tile, loc, rewriter); + switch (width) { + default: + llvm_unreachable("unexpected element type!"); + case 8: + rewriter.replaceOpWithNewOp( + tileStoreOp, mask, ptr, tileI32, tileSliceI32); + break; + case 16: + rewriter.replaceOpWithNewOp( + tileStoreOp, mask, ptr, tileI32, tileSliceI32); + break; + case 32: + rewriter.replaceOpWithNewOp( + tileStoreOp, mask, ptr, tileI32, tileSliceI32); + break; + case 64: + rewriter.replaceOpWithNewOp( + tileStoreOp, mask, ptr, tileI32, tileSliceI32); + break; + } + + return success(); + } +}; + +} // namespace + void mlir::configureArmSMELegalizeForExportTarget( LLVMConversionTarget &target) { - target.addLegalOp(); + target.addLegalOp< + scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector, + arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero, + arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz, + arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz, + arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_st1b_horiz, + arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz, + arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_za_enable, + arm_sme::aarch64_sme_za_disable>(); target.addLegalOp(); // Mark 'func.func' ops as legal if either: @@ -187,5 +379,6 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.add(patterns.getContext()); - patterns.add(converter); + patterns.add(converter); } Index: mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt @@ -0,0 +1,11 @@ +add_mlir_dialect_library(MLIRArmSMEUtils + Utils.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Utils + + LINK_LIBS PUBLIC + MLIRArmSMEDialect + MLIRDialect + MLIRIR + ) Index: mlir/lib/Dialect/ArmSME/Utils/Utils.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/ArmSME/Utils/Utils.cpp @@ -0,0 +1,48 @@ +//===- Utils.cpp - Utilities to support the ArmSME dialect ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utilities for the ArmSME dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+static constexpr unsigned MinStreamingVectorLengthInBits = 128;
+
+unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) {
+  assert(isValidTileElementType(type) && "invalid tile type!");
+  return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
+}
+
+bool mlir::arm_sme::isValidTileElementType(Type type) {
+  // TODO: add support for i128.
+  return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) ||
+         type.isInteger(64) || type.isF16() || type.isBF16() || type.isF32() ||
+         type.isF64();
+}
+
+bool mlir::arm_sme::isSMETileLikeVectorType(VectorType vType) {
+  if ((vType.getRank() != 2) || !vType.allDimsScalable())
+    return false;
+
+  // TODO: add support for i128.
+  auto elemType = vType.getElementType();
+  if (!isValidTileElementType(elemType))
+    return false;
+
+  unsigned minNumElts = arm_sme::getSMETileSliceMinNumElts(elemType);
+  if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
+    return false;
+
+  return true;
+}
Index: mlir/test/Dialect/ArmSME/roundtrip.mlir
===================================================================
--- mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -194,6 +194,87 @@

 // -----

+func.func @arm_sme_tile_load_i8(%memref : memref<?x?xi8>) -> () {
+  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %memref[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_i16(%memref : memref<?x?xi16>) -> () {
+  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %memref[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_i32(%memref : memref<?x?xi32>) -> () {
+  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %memref[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_i64(%memref : memref<?x?xi64>) -> () {
+  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %memref[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_i128(%memref : memref<?x?xi128>) -> () {
+  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %memref[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_f16(%memref : memref<?x?xf16>) -> () {
+  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %memref[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_bf16(%memref : memref<?x?xbf16>) -> () {
+  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %memref[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_f32(%memref : memref<?x?xf32>) -> () {
+ // CHECK: arm_sme.tile_load {{.*}} : memref, vector<[4]x[4]xf32> + %c0 = arith.constant 0 : index + %tile = arm_sme.tile_load %memref[%c0, %c0] : memref, vector<[4]x[4]xf32> + return +} + +// ----- + +func.func @arm_sme_tile_load_f64(%memref : memref) -> () { + // CHECK: arm_sme.tile_load {{.*}} : memref, vector<[2]x[2]xf64> + %c0 = arith.constant 0 : index + %tile = arm_sme.tile_load %memref[%c0, %c0] : memref, vector<[2]x[2]xf64> + return +} + +// ----- + func.func @arm_sme_store_tile(%tile : vector<[16]x[16]xi8>, %dest : memref) -> () { // CHECK: arm_sme.tile_store {{.*}} : memref, vector<[16]x[16]xi8> %c0 = arith.constant 0 : index Index: mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir =================================================================== --- mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir +++ mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir @@ -1,28 +1,34 @@ // RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s +// ----- + // CHECK-LABEL: @transfer_write_2d_zero_i8 // CHECK-SAME: %[[ARG0:.*]]: memref) // CHECK-DAG: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-DAG: %[[C255:.*]] = arith.constant 255 : i32 // CHECK-DAG: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> () // CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8 -// CHECK-DAG: %[[CAST_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8> +// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8> +// CHECK-DAG: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> to i8 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[MIN_ZA_VECTORS:.*]] = arith.constant 16 : index // CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index -// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index // CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index -// CHECK-NEXT: scf.for %[[VNUM:.*]] = %[[C0_0]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { -// CHECK-NEXT: %[[VNUM_I64:.*]] = arith.index_castui %[[VNUM]] : index to i64 -// CHECK-NEXT: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 // CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[VNUM_I64]], %[[STRIDE0]] : i64 -// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_1]] : i64 -// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 -// CHECK-NEXT: %[[VNUM_I32:.*]] = arith.index_castui %[[VNUM]] : index to i32 -// CHECK-NEXT: "arm_sme.intr.str"(%[[VNUM_I32]], %[[GEP]]) : (i32, !llvm.ptr) -> () +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] 
: (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[16]xi1> +// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32 +// CHECK-NEXT: "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: return func.func @transfer_write_2d_zero_i8(%arg0 : memref) { %c0 = arith.constant 0 : index %cst = arith.constant dense<0> : vector<[16]x[16]xi8> @@ -30,3 +36,561 @@ return } +// ----- + +// CHECK-LABEL: @vector_load_i8( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8 +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 16 : index +// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[16]xi1> +// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32 +// CHECK-NEXT: "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8> +// CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> +func.func @vector_load_i8(%arg0 : memref) -> vector<[16]x[16]xi8> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0, %c0] : memref, vector<[16]x[16]xi8> + return %tile : vector<[16]x[16]xi8> +} + +// ----- + +// CHECK-LABEL: @vector_load_i8_rank_1_memref( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8 +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 16 : index +// CHECK-NEXT: %[[VSCALE:.*]] = 
"llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK-NEXT: %[[VSCALE_IDX_I64:.*]] = arith.index_castui %[[VSCALE_IDX]] : index to i64 +// CHECK-NEXT: %[[MIN_ZA_VECTORS_I64:.*]] = arith.index_castui %[[MIN_ZA_VECTORS]] : index to i64 +// CHECK-NEXT: %[[NUM_ELTS:.*]] = arith.muli %[[MIN_ZA_VECTORS_I64]], %[[VSCALE_IDX_I64]] : i64 +// CHECK-NEXT: %[[STRIDE_SVLB_BYTES:.*]] = arith.muli %[[TILE_SLICE_I64]], %[[NUM_ELTS]] : i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[STRIDE_SVLB_BYTES]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[16]xi1> +// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32 +// CHECK-NEXT: "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8> +// CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> +func.func @vector_load_i8_rank_1_memref(%arg0 : memref) -> vector<[16]x[16]xi8> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0] : memref, vector<[16]x[16]xi8> + return %tile : vector<[16]x[16]xi8> +} + + +// ----- + +// CHECK-LABEL: @vector_load_i16( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16 +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 8 : index +// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i16 +// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 +// 
CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[8]xi1> +// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32 +// CHECK-NEXT: "arm_sme.intr.ld1h.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xi16> +// CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[8]x[8]xi16> +func.func @vector_load_i16(%arg0 : memref) -> vector<[8]x[8]xi16> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0, %c0] : memref, vector<[8]x[8]xi16> + return %tile : vector<[8]x[8]xi16> +} + +// ----- + +// CHECK-LABEL: @vector_load_i32( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32 +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 4 : index +// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32 +// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[4]xi1> +// CHECK-NEXT: "arm_sme.intr.ld1w.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID]], %[[TILE_SLICE_I32]]) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32> +// CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[4]x[4]xi32> +func.func @vector_load_i32(%arg0 : memref) -> vector<[4]x[4]xi32> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0, %c0] : memref, vector<[4]x[4]xi32> + return %tile : vector<[4]x[4]xi32> +} + +// ----- + +// CHECK-LABEL: @vector_load_i64( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64 +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 2 : index +// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: 
%[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i64 +// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[2]xi1> +// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32 +// CHECK-NEXT: "arm_sme.intr.ld1d.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xi64> +// CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[2]x[2]xi64> +func.func @vector_load_i64(%arg0 : memref) -> vector<[2]x[2]xi64> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0, %c0] : memref, vector<[2]x[2]xi64> + return %tile : vector<[2]x[2]xi64> +} + +// ----- + +// CHECK-LABEL: @vector_load_f16( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16 +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 8 : index +// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, f16 +// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[8]xi1> +// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32 +// CHECK-NEXT: 
"arm_sme.intr.ld1h.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16> +// CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[8]x[8]xf16> +func.func @vector_load_f16(%arg0 : memref) -> vector<[8]x[8]xf16> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0, %c0] : memref, vector<[8]x[8]xf16> + return %tile : vector<[8]x[8]xf16> +} + +// ----- + +// CHECK-LABEL: @vector_load_bf16( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16 +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 8 : index +// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, bf16 +// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[8]xi1> +// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32 +// CHECK-NEXT: "arm_sme.intr.ld1h.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xbf16> +// CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[8]x[8]xbf16> +func.func @vector_load_bf16(%arg0 : memref) -> vector<[8]x[8]xbf16> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0, %c0] : memref, vector<[8]x[8]xbf16> + return %tile : vector<[8]x[8]xbf16> +} + +// ----- + +// CHECK-LABEL: @vector_load_f32( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32 +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 4 : index +// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_1:.*]] = 
arith.constant 0 : index +// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[4]xi1> +// CHECK-NEXT: "arm_sme.intr.ld1w.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID]], %[[TILE_SLICE_I32]]) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xf32> +// CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[4]x[4]xf32> +func.func @vector_load_f32(%arg0 : memref) -> vector<[4]x[4]xf32> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0, %c0] : memref, vector<[4]x[4]xf32> + return %tile : vector<[4]x[4]xf32> +} + +// ----- + +// CHECK-LABEL: @vector_load_f64( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64 +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 2 : index +// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, f64 +// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[2]xi1> +// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32 +// CHECK-NEXT: "arm_sme.intr.ld1d.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector 
%[[TILE_ID]] : i64 to vector<[2]x[2]xf64> +// CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[2]x[2]xf64> +func.func @vector_load_f64(%arg0 : memref) -> vector<[2]x[2]xf64> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0, %c0] : memref, vector<[2]x[2]xf64> + return %tile : vector<[2]x[2]xf64> +} + +// ----- + +// CHECK-LABEL: @vector_store_i8( +// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>, +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8 +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 16 : index +// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[16]xi1> +// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32 +// CHECK-NEXT: "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: return +func.func @vector_store_i8(%tile : vector<[16]x[16]xi8>, %arg0 : memref) { + %c0 = arith.constant 0 : index + vector.store %tile, %arg0[%c0, %c0] : memref, vector<[16]x[16]xi8> + return +} + +// ----- + +// CHECK-LABEL: @vector_store_i16( +// CHECK-SAME: %[[TILE:.*]]: vector<[8]x[8]xi16>, +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xi16> to i16 +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 8 : index +// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[NUM_ZA_VECTORS]] step 
%[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i16 +// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[8]xi1> +// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32 +// CHECK-NEXT: "arm_sme.intr.st1h.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: return +func.func @vector_store_i16(%tile : vector<[8]x[8]xi16>, %arg0 : memref) { + %c0 = arith.constant 0 : index + vector.store %tile, %arg0[%c0, %c0] : memref, vector<[8]x[8]xi16> + return +} + +// ----- + +// CHECK-LABEL: @vector_store_i32( +// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>, +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32 +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 4 : index +// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32 +// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[4]xi1> +// CHECK-NEXT: "arm_sme.intr.st1w.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[CAST_VECTOR_TO_TILE]], %[[TILE_SLICE_I32]]) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: return +func.func @vector_store_i32(%tile : vector<[4]x[4]xi32>, %arg0 : memref) { + %c0 = arith.constant 0 : index + vector.store %tile, %arg0[%c0, %c0] : memref, vector<[4]x[4]xi32> + return +} + +// ----- + +// CHECK-LABEL: @vector_store_i64( +// CHECK-SAME: %[[TILE:.*]]: vector<[2]x[2]xi64>, +// CHECK-SAME: 
%[[ARG0:.*]]: memref) +// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xi64> to i64 +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 2 : index +// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i64 +// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[2]xi1> +// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32 +// CHECK-NEXT: "arm_sme.intr.st1d.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: return +func.func @vector_store_i64(%tile : vector<[2]x[2]xi64>, %arg0 : memref) { + %c0 = arith.constant 0 : index + vector.store %tile, %arg0[%c0, %c0] : memref, vector<[2]x[2]xi64> + return +} + +// ----- + +// CHECK-LABEL: @vector_store_f16( +// CHECK-SAME: %[[TILE:.*]]: vector<[8]x[8]xf16>, +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xf16> to i16 +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 8 : index +// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], 
+// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, f16
+// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
+// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[8]xi1>
+// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
+// CHECK-NEXT: "arm_sme.intr.st1h.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+func.func @vector_store_f16(%tile : vector<[8]x[8]xf16>, %arg0 : memref<?x?xf16>) {
+  %c0 = arith.constant 0 : index
+  vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @vector_store_bf16(
+// CHECK-SAME: %[[TILE:.*]]: vector<[8]x[8]xbf16>,
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xbf16>)
+// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xbf16> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xbf16> to i16
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 8 : index
+// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index
+// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[NUM_ZA_VECTORS]] step %[[C1]] {
+// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
+// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, bf16
+// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
+// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[8]xi1>
+// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
+// CHECK-NEXT: "arm_sme.intr.st1h.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+func.func @vector_store_bf16(%tile : vector<[8]x[8]xbf16>, %arg0 : memref<?x?xbf16>) {
+  %c0 = arith.constant 0 : index
+  vector.store %tile, %arg0[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  return
+}
+// -----
+
+// CHECK-LABEL: @vector_store_f32(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>,
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>)
+// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 4 : index
+// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index
+// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[NUM_ZA_VECTORS]] step %[[C1]] {
+// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
+// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
+// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK-NEXT: "arm_sme.intr.st1w.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[CAST_VECTOR_TO_TILE]], %[[TILE_SLICE_I32]]) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+func.func @vector_store_f32(%tile : vector<[4]x[4]xf32>, %arg0 : memref<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @vector_store_f64(
+// CHECK-SAME: %[[TILE:.*]]: vector<[2]x[2]xf64>,
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf64>)
+// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xf64> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xf64> to i64
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 2 : index
+// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index
+// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[NUM_ZA_VECTORS]] step %[[C1]] {
+// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
+// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, f64
+// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
+// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true
+// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[2]xi1>
+// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
+// CHECK-NEXT: "arm_sme.intr.st1d.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+func.func @vector_store_f64(%tile : vector<[2]x[2]xf64>, %arg0 : memref<?x?xf64>) {
+  %c0 = arith.constant 0 : index
+  vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
Index: mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
===================================================================
--- /dev/null
+++ mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -0,0 +1,100 @@
+// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
+// RUN:   -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" \
+// RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN: mlir-translate -mlir-to-llvmir | \
+// RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \
+// RUN:   --entry-function=za0_d_f64 \
+// RUN:   --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s --check-prefix=CHECK-ZA0_D
+
+// Integration test demonstrating load/store to/from SME ZA tile.
+
+func.func @za0_d_f64() -> i32 {
+  %c0 = arith.constant 0 : index
+  %c0_f64 = arith.constant 0.0 : f64
+  %c1_f64 = arith.constant 1.0 : f64
+  %c1_index = arith.constant 1 : index
+
+  %min_elts_d = arith.constant 2 : index
+  %vscale = vector.vscale
+
+  // "svl" refers to the Streaming Vector Length and "svl_d" to the number of
+  // 64-bit elements in a vector of SVL bits.
+  %svl_d = arith.muli %min_elts_d, %vscale : index
+
+  // Allocate "mem1" and fill each "row" with the row number (plus 0.1).
+  //
+  // For example, assuming an SVL of 256 bits:
+  //
+  //   0.1, 0.1, 0.1, 0.1
+  //   1.1, 1.1, 1.1, 1.1
+  //   2.1, 2.1, 2.1, 2.1
+  //   3.1, 3.1, 3.1, 3.1
+  //
+  %tilesize = arith.muli %svl_d, %svl_d : index
+  %mem1 = memref.alloca(%tilesize) : memref<?xf64>
+  %init_0 = arith.constant 0.1 : f64
+  scf.for %i = %c0 to %tilesize step %svl_d iter_args(%val = %init_0) -> (f64) {
+    %splat_val = vector.broadcast %val : f64 to vector<[2]xf64>
+    vector.store %splat_val, %mem1[%i] : memref<?xf64>, vector<[2]xf64>
+    %val_next = arith.addf %val, %c1_f64 : f64
+    scf.yield %val_next : f64
+  }
+
+  // Load ZA0.D from "mem1"
+  %za0_d = vector.load %mem1[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
+
+  // Allocate "mem2" to store ZA0.D to
+  %mem2 = memref.alloca(%tilesize) : memref<?xf64>
+
+  // Zero "mem2"
+  scf.for %i = %c0 to %tilesize step %c1_index {
+    memref.store %c0_f64, %mem2[%i] : memref<?xf64>
+  }
+
+  // Verify "mem2" is zeroed by doing an add reduction with an initial value
+  // of zero.
+  %init_0_f64 = arith.constant 0.0 : f64
+  %add_reduce = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_0_f64) -> (f64) {
+    %row = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>
+
+    %inner_add_reduce = scf.for %offset = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_0_f64) -> (f64) {
+      %t = vector.extractelement %row[%offset : index] : vector<[2]xf64>
+      %inner_add_reduce_next = arith.addf %inner_iter, %t : f64
+      scf.yield %inner_add_reduce_next : f64
+    }
+
+    %add_reduce_next = arith.addf %iter, %inner_add_reduce : f64
+    scf.yield %add_reduce_next : f64
+  }
+
+  // CHECK-ZA0_D: 0
+  vector.print %add_reduce : f64
+
+  // Store ZA0.D to "mem2"
+  vector.store %za0_d, %mem2[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
+
+  // Verify "mem1" == "mem2"
+  %init_1 = arith.constant 1 : i64
+  %mul_reduce = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_1) -> (i64) {
+    %row_1 = vector.load %mem1[%vnum] : memref<?xf64>, vector<[2]xf64>
+    %row_2 = vector.load %mem2[%vnum] : memref<?xf64>, vector<[2]xf64>
+    %cmp = arith.cmpf oeq, %row_1, %row_2 : vector<[2]xf64>
+
+    %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) {
+      %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1>
+      %t_i64 = arith.extui %t : i1 to i64
+      %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
+      scf.yield %inner_mul_reduce_next : i64
+    }
+
+    %mul_reduce_next = arith.muli %iter, %inner_mul_reduce : i64
+    scf.yield %mul_reduce_next : i64
+  }
+
+  // CHECK-ZA0_D-NEXT: 1
+  vector.print %mul_reduce : i64
+
+  %c0_i32 = arith.constant 0 : i32
+  return %c0_i32 : i32
+}