diff --git a/mlir/include/mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h b/mlir/include/mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h
@@ -0,0 +1,29 @@
+//===- ArmSMEToSCF.h - Convert ArmSME to SCF dialect ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARMSMETOSCF_ARMSMETOSCF_H_
+#define MLIR_CONVERSION_ARMSMETOSCF_ARMSMETOSCF_H_
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_CONVERTARMSMETOSCF
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Collect a set of patterns to convert from the ArmSME dialect to SCF.
+void populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns);
+
+/// Create a pass to convert a subset of ArmSME ops to SCF.
+std::unique_ptr<Pass> createConvertArmSMEToSCFPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARMSMETOSCF_ARMSMETOSCF_H_
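As context for reviewers, a minimal sketch of how the two entry points declared above would be driven from C++ (the helper name `buildLoweringPipeline` is hypothetical; only `createConvertArmSMEToSCFPass` comes from this patch):

```cpp
// Minimal usage sketch. Only createConvertArmSMEToSCFPass() is added by this
// patch; the surrounding pipeline helper is a hypothetical example.
#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
#include "mlir/Pass/PassManager.h"

static void buildLoweringPipeline(mlir::PassManager &pm) {
  // Lower arm_sme.tile_load/tile_store to scf.for loops over tile slices.
  pm.addPass(mlir::createConvertArmSMEToSCFPass());
}
```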
a 2D SME "virtual tile". The tile + slice is defined by the dimension of the 2D scalable vector type pointed by + the index. A tile slice index describes where in the input tile the tile + slice is loaded to. The updated tile is returned as the result. + + The slice of memory read is defined by a base and indices and must be + contiguous. The memref must be either rank 1 or rank 2, have dynamic + dimensions since the operation is scalable, and the element type must be a + scalar that matches the element type of the result. + + Example 1: Load a vector<[16]xi8> tile slice from memory into tile at given index. + ```mlir + %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref, vector<[16]x[16]xi8> + ``` + + Example 2: Load a vector<[4]xf32> tile slice from memory into tile at given index. + ```mlir + %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref, vector<[4]x[4]xf32> + ``` + + Example 3: Load a vector<[1]xi128> tile slice from memory into tile at given index. + ```mlir + %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref, vector<[1]x[1]xi128> + ``` + }]; + let arguments = (ins + Arg:$base, + SMETile:$tile, Variadic:$indices, Index:$tile_slice_index); + let results = (outs SMETile:$result); + + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return ::llvm::cast(getBase().getType()); + } + VectorType getVectorType() { + return ::llvm::cast(getResult().getType()); + } + }]; + + let assemblyFormat = [{ + $base `[` $indices `]` `,` $tile `,` $tile_slice_index + attr-dict `:` type($base) `,` type($result) + }]; +} + +def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> { + let summary = "Tile slice store operation"; + let description = [{ + Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile + slice is defined by the dimension of the 2D scalable vector type pointed by + the index. A tile slice index describes where in the input tile the tile + slice is stored from. + + The slice of memory written is defined by a base and indices and must be + contiguous. The memref must be either rank 1 or rank 2, have dynamic + dimensions since the operation is scalable, and the element type must be a + scalar that matches the element type of the input tile. + + Example 1: Store vector<[16]xi8> tile slice from tile at given index to memory. + ```mlir + arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref + ``` + + Example 2: Store vector<[4]xf32> tile slice from tile at given index to memory. + ```mlir + arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[4]x[4]xf32>, memref + ``` + + Example 3: Store a vector<[1]xi128> tile slice from tile at given index to memory. 
+ ```mlir + arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[1]x[1]xi128>, memref + ``` + }]; + let arguments = (ins SMETile:$tile, Index:$tile_slice_index, + Arg:$base, + Variadic:$indices); + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return ::llvm::cast(getBase().getType()); + } + VectorType getVectorType() { + return ::llvm::cast(getTile().getType()); + } + }]; + + let assemblyFormat = [{ + $tile `,` $tile_slice_index `,` $base `[` $indices `]` + attr-dict `:` type($base) `,` type($tile) + }]; +} + //===----------------------------------------------------------------------===// // ArmSME Intrinsic op definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -0,0 +1,187 @@ +//===- ArmSMEToSCF.cpp - Convert ArmSME to SCF dialect ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements lowering of ArmSME operations to SCF. +// +//===----------------------------------------------------------------------===// +#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/ArmSME/Utils/Utils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTARMSMETOSCF +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { + +/// Lower `arm_sme.tile_load` to a loop over the tile slices and load each slice +/// using `arm_sme.load_tile_slice`. +/// +/// BEFORE: +/// ```mlir +/// %tile = arm_sme.tile_load %src[%c0, %c0] : +/// memref, vector<[4]x[4]xi32> +/// ``` +/// +/// AFTER: +/// ```mlir +/// %tile_id = arm_sme.get_tile_id : i32 +/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32> +/// %vscale = vector.vscale +/// %c0 = arith.constant 0 : index +/// %c1 = arith.constant 1 : index +/// %min_svl_s = arith.constant 4 : index +/// %svl_s = arith.muli %min_svl_s, %vscale : index +/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 { +/// %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx], +/// %tile, %tile_slice_idx : memref, vector<[4]x[4]xi32> +/// } +/// ``` +struct TileLoadOpConversion : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp, + PatternRewriter &rewriter) const override { + OpBuilder::InsertionGuard g(rewriter); + auto loc = tileLoadOp.getLoc(); + auto tileType = tileLoadOp.getVectorType(); + auto tileElementType = tileType.getElementType(); + unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth(); + + // Create 'arm_sme.get_tile' op. + auto tileId = rewriter.create( + loc, rewriter.getIntegerType(tileElementWidth)); + + // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to + // use as input tile to 'arm_sme.load_tile_slice' ops. 
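For reference, a sketch of building the two new ops from C++, assuming the ODS-generated builders and the operand order declared above (all values and the helper name are placeholders, not part of this patch):

```cpp
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Builds (assumed builder signatures from the ODS definitions above):
//   %updated = arm_sme.load_tile_slice %base[%idx], %tile, %slice
//   arm_sme.store_tile_slice %updated, %slice, %base[%idx]
static void buildLoadStoreTileSlice(OpBuilder &b, Location loc, Value base,
                                    Value tile, Value idx, Value slice) {
  Value updated = b.create<arm_sme::LoadTileSliceOp>(
      loc, tile.getType(), base, tile, ValueRange{idx}, slice);
  b.create<arm_sme::StoreTileSliceOp>(loc, updated, slice, base,
                                      ValueRange{idx});
}
```

Note the `AllTypesMatch<["tile", "result"]>` constraint on the load: the result type is always the type of the input tile, which is what lets the loop-based lowering below thread the same tile value through each slice update.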
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -0,0 +1,187 @@
+//===- ArmSMEToSCF.cpp - Convert ArmSME to SCF dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of ArmSME operations to SCF.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTARMSMETOSCF
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Lower `arm_sme.tile_load` to a loop over the tile slices and load each
+/// slice using `arm_sme.load_tile_slice`.
+///
+/// BEFORE:
+/// ```mlir
+/// %tile = arm_sme.tile_load %src[%c0, %c0] :
+///   memref<?x?xi32>, vector<[4]x[4]xi32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %tile_id = arm_sme.get_tile_id : i32
+/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+/// %vscale = vector.vscale
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %min_svl_s = arith.constant 4 : index
+/// %svl_s = arith.muli %min_svl_s, %vscale : index
+/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+///   %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
+///     %tile, %tile_slice_idx : memref<?x?xi32>, vector<[4]x[4]xi32>
+/// }
+/// ```
+struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
+  using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+                                PatternRewriter &rewriter) const override {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = tileLoadOp.getLoc();
+    auto tileType = tileLoadOp.getVectorType();
+    auto tileElementType = tileType.getElementType();
+    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+
+    // Create 'arm_sme.get_tile_id' op.
+    auto tileId = rewriter.create<arm_sme::GetTileID>(
+        loc, rewriter.getIntegerType(tileElementWidth));
+
+    // Create 'arm_sme.cast_tile_to_vector' to cast tile ID to a vector type to
+    // use as input tile to 'arm_sme.load_tile_slice' ops.
+    auto tile =
+        rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+
+    // Create a loop that loads each ZA tile slice from memory.
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+    auto vscale =
+        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto numTileSlices =
+        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+    auto forOp =
+        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+
+    rewriter.setInsertionPointToStart(forOp.getBody());
+
+    auto tileSliceIndex = forOp.getInductionVar();
+    // TODO: use indices
+    // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
+    // tile.
+    rewriter.create<arm_sme::LoadTileSliceOp>(
+        loc, tileType, tileLoadOp.getBase(), tile, tileSliceIndex,
+        tileSliceIndex);
+
+    rewriter.setInsertionPointAfter(forOp);
+
+    // Replace 'arm_sme.tile_load' with the tile.
+    rewriter.replaceOp(tileLoadOp, tile);
+
+    return success();
+  }
+};
+
+/// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
+/// slice using `arm_sme.store_tile_slice`.
+///
+/// BEFORE:
+/// ```mlir
+/// arm_sme.tile_store %tile, %dest[%c0, %c0]
+///   : memref<?x?xi32>, vector<[4]x[4]xi32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %vscale = vector.vscale
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %min_svl_s = arith.constant 4 : index
+/// %svl_s = arith.muli %min_svl_s, %vscale : index
+/// scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+///   arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx]
+///     : memref<?x?xi32>, vector<[4]x[4]xi32>
+/// }
+/// ```
+struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
+  using OpRewritePattern<arm_sme::TileStoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
+                                PatternRewriter &rewriter) const override {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = tileStoreOp.getLoc();
+    auto tileType = tileStoreOp.getVectorType();
+    auto tileElementType = tileType.getElementType();
+
+    // Create a loop that stores each ZA tile slice to memory.
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+    auto vscale =
+        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto numTileSlices =
+        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+    auto forOp =
+        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+
+    rewriter.setInsertionPointToStart(forOp.getBody());
+
+    auto tileSliceIndex = forOp.getInductionVar();
+    // TODO: use indices
+    rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
+        tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
+        tileStoreOp.getBase(), tileSliceIndex);
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
+  patterns.add<TileLoadOpConversion, TileStoreOpConversion>(
+      patterns.getContext());
+}
+
+namespace {
+
+struct ConvertArmSMEToSCFPass
+    : public impl::ConvertArmSMEToSCFBase<ConvertArmSMEToSCFPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    ConversionTarget target(getContext());
+    populateArmSMEToSCFConversionPatterns(patterns);
+    target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
+                           arith::ArithDialect, scf::SCFDialect>();
+    target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createConvertArmSMEToSCFPass() {
+  return std::make_unique<ConvertArmSMEToSCFPass>();
+}
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt b/mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_conversion_library(MLIRArmSMEToSCF
+  ArmSMEToSCF.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmSMEToSCF
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArmSMEDialect
+  MLIRArmSMEUtils
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -4,6 +4,7 @@
 add_subdirectory(ArithToLLVM)
 add_subdirectory(ArithToSPIRV)
 add_subdirectory(ArmNeon2dToIntr)
+add_subdirectory(ArmSMEToSCF)
 add_subdirectory(AsyncToLLVM)
 add_subdirectory(BufferizationToMemRef)
 add_subdirectory(ComplexToLibm)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -138,68 +138,40 @@
   llvm_unreachable("memref has unexpected rank!");
 }
 
-/// Conversion pattern for `arm_sme.tile_load` to SME intrinsics.
-///
-/// Lower `arm_sme.tile_load` to a loop over the rows of ZA and load each row
-/// using `arm_sme.intr.ld1*.horiz`.
-///
-/// BEFORE:
-/// ```mlir
-/// %tile = arm_sme.tile_load %base[%c0, %c0] :
-///   memref<?x?xi32>, vector<[4]x[4]xi32>
-/// ```
-///
-/// AFTER:
-/// ```mlir
-/// %tile_id = arm_sme.get_tile_id : i32
-/// %vscale = vector.vscale
-/// %c0 = arith.constant 0 : index
-/// %c1 = arith.constant 1 : index
-/// %min_svl_s = arith.constant 4 : index
-/// %svl_s = arith.muli %min_svl_s, %vscale : index
-/// scf.for %tile_slice = %c0 to %svl_s step %c1 {
-///   // (...)
-///   "arm_sme.intr.ld1w.horiz"(%ptrue_s, %ptr, %tile_id, %tile_slice) :
-///     (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-/// }
-/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
-/// ```
-struct TileLoadToArmSMELowering
-    : public ConvertOpToLLVMPattern<arm_sme::TileLoadOp> {
-  using ConvertOpToLLVMPattern<arm_sme::TileLoadOp>::ConvertOpToLLVMPattern;
+/// Lower `arm_sme.load_tile_slice` to SME intrinsics.
+struct LoadTileSliceToArmSMELowering
+    : public ConvertOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
+  using ConvertOpToLLVMPattern<
+      arm_sme::LoadTileSliceOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
-                  arm_sme::TileLoadOp::Adaptor adaptor,
+  matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
+                  arm_sme::LoadTileSliceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = tileLoadOp.getLoc();
-    auto tileType = tileLoadOp.getVectorType();
+    auto loc = loadTileSliceOp.getLoc();
+    auto tileType = loadTileSliceOp.getVectorType();
     auto tileElementType = tileType.getElementType();
     unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
 
-    // Create 'arm_sme.get_tile_id' op.
-    auto tile = rewriter.create<arm_sme::GetTileID>(
-        loc, rewriter.getIntegerType(tileElementWidth));
+    // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the tile being
+    // loaded to.
+    auto tile = rewriter.create<arm_sme::CastVectorToTile>(
+        loc, rewriter.getIntegerType(tileElementWidth),
+        loadTileSliceOp.getTile());
 
-    // Create a loop that loads each ZA tile slice from memory.
-    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
     auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
         loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
     auto vscale =
         rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
-    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     // This describes both the number of ZA tile slices and the number of
     // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
     // ..., SVL_Q).
     auto numTileSlices =
         rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-    auto forOp =
-        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
-    rewriter.setInsertionPointToStart(forOp.getBody());
 
     // Create 'arm_sme.intr.ld1*.horiz' intrinsic to load ZA tile slice.
-    auto memRefType = tileLoadOp.getMemRefType();
-    auto tileSlice = forOp.getInductionVar();
+    auto memRefType = loadTileSliceOp.getMemRefType();
+    auto tileSlice = loadTileSliceOp.getTileSliceIndex();
     // TODO: The 'indices' argument for the 'base' memref is currently ignored,
    // 'tileSliceIndex' should be added to 'indices[0]'.
     Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice,
@@ -241,52 +213,27 @@
       break;
     }
 
-    rewriter.setInsertionPointAfter(forOp);
-
     // The load intrinsics have no result, replace 'arm_sme.tile_load' with
     // 'arm_sme.cast_tile_to_vector' to preserve dataflow.
-    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(tileLoadOp, tileType,
-                                                           tile);
+    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(loadTileSliceOp,
+                                                           tileType, tile);
 
     return success();
   }
 };
 
-/// Conversion pattern for `arm_sme.tile_store` to SME intrinsics.
-///
-/// Lower `arm_sme.tile_store` to a loop over the rows of ZA and store each row
-/// using `arm_sme.intr.st1*.horiz`.
-///
-/// BEFORE:
-/// ```mlir
-/// arm_sme.tile_store %value, %base[%c0, %c0] : memref<?x?xi32>,
-///   vector<[4]x[4]xi32>
-/// ```
-///
-/// AFTER:
-/// ```mlir
-/// %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32
-/// %vscale = vector.vscale
-/// %c0 = arith.constant 0 : index
-/// %c1 = arith.constant 1 : index
-/// %min_svl_s = arith.constant 4 : index
-/// %svl_s = arith.muli %min_svl_s, %vscale : index
-/// scf.for %tile_slice = %c0 to %svl_s step %c1 {
-///   // (...)
-///   "arm_sme.intr.st1w.horiz"(%ptrue_s, %ptr, %tile_id, %tile_slice) :
-///     (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-/// }
-/// ```
-struct TileStoreToArmSMELowering
-    : public ConvertOpToLLVMPattern<arm_sme::TileStoreOp> {
-  using ConvertOpToLLVMPattern<arm_sme::TileStoreOp>::ConvertOpToLLVMPattern;
+/// Lower `arm_sme.store_tile_slice` to SME intrinsics.
+struct StoreTileSliceToArmSMELowering
+    : public ConvertOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
+  using ConvertOpToLLVMPattern<
+      arm_sme::StoreTileSliceOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
-                  arm_sme::TileStoreOp::Adaptor adaptor,
+  matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
+                  arm_sme::StoreTileSliceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = tileStoreOp.getLoc();
-    auto tileType = tileStoreOp.getVectorType();
+    auto loc = storeTileSliceOp.getLoc();
+    auto tileType = storeTileSliceOp.getVectorType();
     auto tileElementType = tileType.getElementType();
     unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
 
@@ -294,27 +241,21 @@
     // being stored.
     auto tile = rewriter.create<arm_sme::CastVectorToTile>(
         loc, rewriter.getIntegerType(tileElementWidth),
-        tileStoreOp.getValueToStore());
+        storeTileSliceOp.getTile());
 
-    // Create a loop that stores each ZA tile slice to memory.
-    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
     auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
         loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
     auto vscale =
         rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
-    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     // This describes both the number of ZA tile slices and the number of
     // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
     // ..., SVL_Q).
     auto numTileSlices =
         rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-    auto forOp =
-        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
-    rewriter.setInsertionPointToStart(forOp.getBody());
 
     // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
-    auto memRefType = tileStoreOp.getMemRefType();
-    auto tileSlice = forOp.getInductionVar();
+    auto memRefType = storeTileSliceOp.getMemRefType();
+    auto tileSlice = storeTileSliceOp.getTileSliceIndex();
     // TODO: The 'indices' argument for the 'base' memref is currently ignored,
     // 'tileSliceIndex' should be added to 'indices[0]'.
     Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice,
@@ -340,19 +281,19 @@
       llvm_unreachable("unexpected element type!");
     case 8:
       rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
-          tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+          storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
       break;
     case 16:
       rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
-          tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+          storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
       break;
     case 32:
      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
-          tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+          storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
      break;
    case 64:
      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
-          tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32);
+          storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
       break;
     }
 
@@ -403,6 +344,6 @@
 void mlir::populateArmSMELegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
-  patterns.add<ZeroOpConversion, TileLoadToArmSMELowering,
-               TileStoreToArmSMELowering>(converter);
+  patterns.add<ZeroOpConversion, LoadTileSliceToArmSMELowering,
+               StoreTileSliceToArmSMELowering>(converter);
 }
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @arm_sme_tile_load(
+// CHECK-SAME:    %[[SRC:.*]]: memref<?x?xi32>) {
+// CHECK-NEXT:    %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-NEXT:    %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
+// CHECK-NEXT:    %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[TILE_SLICE_INDEX]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load(%src : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_store(
+// CHECK-SAME:    %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME:    %[[DEST:.*]]: memref<?x?xi32>) {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
+// CHECK:         %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK:         scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK:           arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[TILE_SLICE_INDEX]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_store(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
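One note before the next test file: arm-sme-to-llvm-casts.mlir below checks the temporary `cast_tile_to_vector`/`cast_vector_to_tile` pairs that preserve data flow, which are expected to fold away during canonicalization. As an illustrative sketch of why (accessor names are assumed from the ODS operand names; this fold is not part of this patch), a round-trip cast pair can simply yield the original tile ID:

```cpp
// Illustrative only, not part of this patch. A cast_vector_to_tile of a
// cast_tile_to_vector folds back to the original tile ID, so the temporary
// cast pairs inserted for dataflow disappear after canonicalization.
// Accessor names (getVector, getTileId) are assumptions.
OpFoldResult arm_sme::CastVectorToTile::fold(FoldAdaptor) {
  if (auto castTileToVector =
          getVector().getDefiningOp<arm_sme::CastTileToVector>())
    return castTileToVector.getTileId();
  return {};
}
```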
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -split-input-file | FileCheck %s
+
+// This test verifies that the temporary casts that are emitted when lowering
+// to intrinsics to preserve data flow are correct. Canonicalization will
+// remove these.
+
+// CHECK-LABEL: @arm_sme_zero
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK: arm_sme.intr.zero
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK: scf.for
+// CHECK:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> to i8
+// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
+// CHECK:   "arm_sme.intr.st1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_zero(%dest : memref<?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.zero : vector<[16]x[16]xi8>
+  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_tile_load
+// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK: scf.for
+// CHECK:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> to i8
+// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
+// CHECK:   "arm_sme.intr.ld1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK: }
+// CHECK: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
+func.func @arm_sme_tile_load(%dest : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return %tile : vector<[16]x[16]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_tile_store(
+// CHECK-SAME:    %[[TILE:.*]]: vector<[16]x[16]xi8>,
+// CHECK: scf.for
+// CHECK:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
+// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
+// CHECK:   "arm_sme.intr.st1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+func.func @arm_sme_tile_store(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -351,3 +351,165 @@
   arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
+
+// -----
+
+func.func @arm_sme_load_tile_slice_i8(%src : memref<?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?xi8>, vector<[16]x[16]xi8>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_i16(%src : memref<?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?xi16>, vector<[8]x[8]xi16>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xi16>, vector<[8]x[8]xi16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_i32(%src : memref<?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?xi32>, vector<[4]x[4]xi32>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_i64(%src : memref<?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?xi64>, vector<[2]x[2]xi64>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xi64>, vector<[2]x[2]xi64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_i128(%src : memref<?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?xi128>, vector<[1]x[1]xi128>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xi128>, vector<[1]x[1]xi128>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_f16(%src : memref<?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?xf16>, vector<[8]x[8]xf16>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xf16>, vector<[8]x[8]xf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_bf16(%src : memref<?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?xbf16>, vector<[8]x[8]xbf16>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xbf16>, vector<[8]x[8]xbf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_f32(%src : memref<?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?xf32>, vector<[4]x[4]xf32>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_f64(%src : memref<?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?xf64>, vector<[2]x[2]xf64>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?xi8>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?xi8>, vector<[16]x[16]xi8>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?xi16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?xi16>, vector<[8]x[8]xi16>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xi16>, vector<[8]x[8]xi16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?xi32>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?xi32>, vector<[4]x[4]xi32>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?xi64>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?xi64>, vector<[2]x[2]xi64>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xi64>, vector<[2]x[2]xi64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?xi128>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?xi128>, vector<[1]x[1]xi128>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xi128>, vector<[1]x[1]xi128>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?xf16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?xf16>, vector<[8]x[8]xf16>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xf16>, vector<[8]x[8]xf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?xbf16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?xbf16>, vector<[8]x[8]xbf16>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xbf16>, vector<[8]x[8]xbf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?xf32>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?xf32>, vector<[4]x[4]xf32>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?xf64>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?xf64>, vector<[2]x[2]xf64>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
+  return
+}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file | FileCheck %s
 
 // CHECK-LABEL: @transfer_write_2d_zero_i8(
 // CHECK-SAME:      %[[ARG0:.*]]: memref<?x?xi8>)
@@ -8,13 +8,13 @@
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
 // CHECK-DAG:  %[[C255:.*]] = arith.constant 255 : i32
 // CHECK-DAG:  %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
-// CHECK-DAG:  "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
+// CHECK-DAG:  "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
 // CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
 // CHECK-DAG:  %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
-// CHECK-NEXT:   %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK:        %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
@@ -35,6 +35,7 @@
 // CHECK-SAME:  %[[ARG0:.*]]: memref<?x?xi8>)
 // CHECK-DAG:  %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
@@ -43,7 +44,7 @@
 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
-// CHECK-NEXT:   %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK:        %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
@@ -52,7 +53,6 @@
 // CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
 // CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
 // CHECK-NEXT: }
-// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
 // CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
 func.func @vector_load_i8(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
   %c0 = arith.constant 0 : index
@@ -66,6 +66,7 @@
 // CHECK-SAME:  %[[ARG0:.*]]: memref<?xi8>)
 // CHECK-DAG:  %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?xi8> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
 // CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
+// CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
@@ -74,8 +75,11 @@
 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
+// CHECK-NEXT:   %[[VSCALE_1:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK-NEXT:   %[[VSCALE_IDX_1:.*]] = builtin.unrealized_conversion_cast %[[VSCALE_1]] : i64 to index
+// CHECK-NEXT:   %[[SVL_B_1:.*]] = arith.muli %[[VSCALE_IDX_1]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT:   %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
-// CHECK-NEXT:   %[[SVL_B_I64:.*]] = arith.index_castui %[[SVL_B]] : index to i64
+// CHECK-NEXT:   %[[SVL_B_I64:.*]] = arith.index_castui %[[SVL_B_1]] : index to i64
 // CHECK-NEXT:   %[[TILE_SLICE_IDX:.*]] = arith.muli %[[TILE_SLICE_I64]], %[[SVL_B_I64]] : i64
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
 // CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[TILE_SLICE_IDX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
@@ -83,7 +87,6 @@
 // CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
 // CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
 // CHECK-NEXT: }
-// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
 // CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
 func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16]x[16]xi8> {
   %c0 = arith.constant 0 : index
@@ -97,11 +100,11 @@
 // CHECK-LABEL: @vector_load_i16(
 // CHECK-SAME:  %[[ARG0:.*]]: memref<?x?xi16>)
 // CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xi16>
 // CHECK-DAG:  %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:      %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
 // CHECK:      %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
 // CHECK:      arm_sme.intr.ld1h.horiz
-// CHECK:      %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xi16>
 func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
@@ -113,12 +116,12 @@
 // CHECK-LABEL: @vector_load_i32(
 // CHECK-SAME:  %[[ARG0:.*]]: memref<?x?xi32>)
 // CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
 // CHECK-DAG:  %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:      %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
 // CHECK-NOT:  arith.extui %[[TILE_ID]]
 // CHECK-NOT:  arith.trunci %[[TILE_ID]]
 // CHECK:      arm_sme.intr.ld1w.horiz
-// CHECK:      %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
 func.func @vector_load_i32(%arg0 : memref<?x?xi32>) -> vector<[4]x[4]xi32> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -130,11 +133,11 @@
 // CHECK-LABEL: @vector_load_i64(
 // CHECK-SAME:  %[[ARG0:.*]]: memref<?x?xi64>)
 // CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64
+// CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xi64>
 // CHECK-DAG:  %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:      %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
 // CHECK:      %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32
 // CHECK:      arm_sme.intr.ld1d.horiz
-// CHECK:      %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xi64>
 func.func @vector_load_i64(%arg0 : memref<?x?xi64>) -> vector<[2]x[2]xi64> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
@@ -146,11 +149,11 @@
 // CHECK-LABEL: @vector_load_f16(
 // CHECK-SAME:  %[[ARG0:.*]]: memref<?x?xf16>)
 // CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16>
 // CHECK-DAG:  %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:      %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
 // CHECK:      %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
 // CHECK:      arm_sme.intr.ld1h.horiz
-// CHECK:      %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16>
 func.func @vector_load_f16(%arg0 : memref<?x?xf16>) -> vector<[8]x[8]xf16> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
@@ -162,11 +165,11 @@
 // CHECK-LABEL: @vector_load_bf16(
 // CHECK-SAME:  %[[ARG0:.*]]: memref<?x?xbf16>)
 // CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
+// CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xbf16>
 // CHECK-DAG:  %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:      %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
 // CHECK:      %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
 // CHECK:      arm_sme.intr.ld1h.horiz
-// CHECK:      %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xbf16>
 func.func @vector_load_bf16(%arg0 : memref<?x?xbf16>) -> vector<[8]x[8]xbf16> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
@@ -178,12 +181,12 @@
 // CHECK-LABEL: @vector_load_f32(
 // CHECK-SAME:  %[[ARG0:.*]]: memref<?x?xf32>)
 // CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xf32>
 // CHECK-DAG:  %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:      %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
 // CHECK-NOT:  arith.extui %[[TILE_ID]]
 // CHECK-NOT:  arith.trunci %[[TILE_ID]]
 // CHECK:      arm_sme.intr.ld1w.horiz
-// CHECK:      %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xf32>
 func.func @vector_load_f32(%arg0 : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
@@ -195,11 +198,11 @@
 // CHECK-LABEL: @vector_load_f64(
 // CHECK-SAME:  %[[ARG0:.*]]: memref<?x?xf64>)
 // CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64
+// CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xf64>
 // CHECK-DAG:  %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:      %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
 // CHECK:      %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32
 // CHECK:      arm_sme.intr.ld1d.horiz
-// CHECK:      %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xf64>
 func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
@@ -212,7 +215,6 @@
 // CHECK-SAME:  %[[TILE:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME:  %[[ARG0:.*]]: memref<?x?xi8>)
 // CHECK-DAG:  %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-DAG:  %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
@@ -221,7 +223,8 @@
 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
-// CHECK-NEXT:   %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
+// CHECK-NEXT:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
+// CHECK:        %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64
@@ -242,9 +245,9 @@
 // CHECK-LABEL: @vector_store_i16(
 // CHECK-SAME:  %[[TILE:.*]]: vector<[8]x[8]xi16>,
 // CHECK-SAME:  %[[ARG0:.*]]: memref<?x?xi16>)
-// CHECK-DAG:  %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xi16> to i16
-// CHECK-DAG:  %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK:      %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:      %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
+// CHECK:      %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xi16> to i16
 // CHECK:      %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
 // CHECK:      arm_sme.intr.st1h.horiz
 func.func @vector_store_i16(%tile : vector<[8]x[8]xi16>, %arg0 : memref<?x?xi16>) {
@@ -258,9 +261,9 @@
 // CHECK-LABEL: @vector_store_i32(
 // CHECK-SAME:  %[[TILE:.*]]: vector<[4]x[4]xi32>,
 // CHECK-SAME:  %[[ARG0:.*]]: memref<?x?xi32>)
-// CHECK-DAG:  %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
-// CHECK-DAG:  %[[MIN_SVL_S:.*]] = arith.constant 4 : index
+// CHECK:      %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:      %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
+// CHECK:      %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
 // CHECK-NOT:  arith.extui %[[CAST_VECTOR_TO_TILE]]
 // CHECK-NOT:  arith.trunci %[[CAST_VECTOR_TO_TILE]]
 // CHECK:      arm_sme.intr.st1w.horiz
@@ -275,9 +278,9 @@
 // CHECK-LABEL: @vector_store_i64(
 // CHECK-SAME:  %[[TILE:.*]]: vector<[2]x[2]xi64>,
 // CHECK-SAME:  %[[ARG0:.*]]: memref<?x?xi64>)
-// CHECK-DAG:  %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xi64> to i64
-// CHECK-DAG:  %[[MIN_SVL_D:.*]] = arith.constant 2 : index
+// CHECK:      %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:      %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
+// CHECK:      %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xi64> to i64
 // CHECK:      %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
 // CHECK:      arm_sme.intr.st1d.horiz
 func.func @vector_store_i64(%tile : vector<[2]x[2]xi64>, %arg0 : memref<?x?xi64>) {
@@ -291,9 +294,9 @@
 // CHECK-LABEL: @vector_store_f16(
 // CHECK-SAME:  %[[TILE:.*]]: vector<[8]x[8]xf16>,
 // CHECK-SAME:  %[[ARG0:.*]]: memref<?x?xf16>)
-// CHECK-DAG:  %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xf16> to i16
-// CHECK-DAG:  %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK:      %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:      %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
+// CHECK:      %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xf16> to i16
 // CHECK:      %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
 // CHECK:      arm_sme.intr.st1h.horiz
 func.func @vector_store_f16(%tile : vector<[8]x[8]xf16>, %arg0 : memref<?x?xf16>) {
@@ -307,9 +310,9 @@
 // CHECK-LABEL: @vector_store_bf16(
 // CHECK-SAME:  %[[TILE:.*]]: vector<[8]x[8]xbf16>,
 // CHECK-SAME:  %[[ARG0:.*]]: memref<?x?xbf16>)
-// CHECK-DAG:  %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xbf16> to i16
-// CHECK-DAG:  %[[MIN_SVL_H:.*]] = arith.constant 8 : index
+// CHECK:      %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:      %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
+// CHECK:      %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xbf16> to i16
 // CHECK:      %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
 // CHECK:      arm_sme.intr.st1h.horiz
 func.func @vector_store_bf16(%tile : vector<[8]x[8]xbf16>, %arg0 : memref<?x?xbf16>) {
@@ -322,9 +325,9 @@
 // CHECK-LABEL: @vector_store_f32(
 // CHECK-SAME:  %[[TILE:.*]]: vector<[4]x[4]xf32>,
 // CHECK-SAME:  %[[ARG0:.*]]: memref<?x?xf32>)
-// CHECK-DAG:  %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
-// CHECK-DAG:  %[[MIN_SVL_S:.*]] = arith.constant 4 : index
+// CHECK:      %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:      %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
+// CHECK:      %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
 // CHECK-NOT:  arith.extui %[[CAST_VECTOR_TO_TILE]]
 // CHECK-NOT:  arith.trunci %[[CAST_VECTOR_TO_TILE]]
 // CHECK:      arm_sme.intr.st1w.horiz
@@ -339,9 +342,9 @@
 // CHECK-LABEL: @vector_store_f64(
 // CHECK-SAME:  %[[TILE:.*]]: vector<[2]x[2]xf64>,
 // CHECK-SAME:  %[[ARG0:.*]]: memref<?x?xf64>)
-// CHECK-DAG:  %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xf64> to i64
-// CHECK-DAG:  %[[MIN_SVL_D:.*]] = arith.constant 2 : index
+// CHECK:      %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:      %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
+// CHECK:      %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xf64> to i64
 // CHECK:      %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
 // CHECK:      arm_sme.intr.st1d.horiz
 func.func @vector_store_f64(%tile : vector<[2]x[2]xf64>, %arg0 : memref<?x?xf64>) {
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
-// RUN:   -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" \
+// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
 // RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
 // RUN: mlir-translate -mlir-to-llvmir | \
 // RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
-// RUN:   -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" \
+// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN:   -convert-vector-to-llvm="enable-arm-sme" \
 // RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
 // RUN: mlir-translate -mlir-to-llvmir | \
 // RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \
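Taken together, the RUN lines above pin down the intended pass ordering. A rough C++ equivalent of that pipeline (a sketch only: of the create* entry points referenced below, this patch adds just `createConvertArmSMEToSCFPass`; the remaining names and flags are assumptions read off the RUN lines):

```cpp
#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
#include "mlir/Pass/PassManager.h"

// Sketch of the ordering exercised by the integration tests above. Only
// createConvertArmSMEToSCFPass() is added by this patch; the other passes are
// referenced by their textual pipeline names to avoid assuming C++ entry
// points.
static void buildArmSMETestPipeline(mlir::PassManager &pm) {
  // -convert-vector-to-arm-sme would run first (create* name not assumed
  // here), then:
  pm.addPass(mlir::createConvertArmSMEToSCFPass()); // -convert-arm-sme-to-scf
  // ...followed by -convert-vector-to-llvm="enable-arm-sme", -cse,
  // -canonicalize, -allocate-arm-sme-tiles and -test-lower-to-llvm, exactly
  // as in the RUN lines.
}
```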