diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -57,6 +57,7 @@ #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" #include "mlir/Conversion/TosaToSCF/TosaToSCF.h" #include "mlir/Conversion/TosaToTensor/TosaToTensor.h" +#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h" #include "mlir/Conversion/VectorToGPU/VectorToGPU.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1076,6 +1076,20 @@ ]; } +//===----------------------------------------------------------------------===// +// VectorToArmSME +//===----------------------------------------------------------------------===// + +def ConvertVectorToArmSME : Pass<"convert-vector-to-arm-sme"> { + let summary = "Lower the operations from the vector dialect into the ArmSME " + "dialect"; + let description = [{ + Pass that converts vector dialect operations into equivalent ArmSME dialect + operations. + }]; + let dependentDialects = ["arm_sme::ArmSMEDialect"]; +} + //===----------------------------------------------------------------------===// // VectorToLLVM //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Conversion/VectorToArmSME/VectorToArmSME.h b/mlir/include/mlir/Conversion/VectorToArmSME/VectorToArmSME.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/VectorToArmSME/VectorToArmSME.h @@ -0,0 +1,26 @@ +//===- VectorToArmSME.h - Convert vector to ArmSME dialect ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_VECTORTOARMSME_VECTORTOARMSME_H_ +#define MLIR_CONVERSION_VECTORTOARMSME_VECTORTOARMSME_H_ + +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_CONVERTVECTORTOARMSME +#include "mlir/Conversion/Passes.h.inc" + +/// Collect a set of patterns to lower Vector ops to ArmSME ops that map to LLVM +/// intrinsics. +void populateVectorToArmSMEPatterns(RewritePatternSet &patterns, + MLIRContext &ctx); + +} // namespace mlir + +#endif // MLIR_CONVERSION_VECTORTOARMSME_VECTORTOARMSME_H_ diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h @@ -15,6 +15,7 @@ #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td @@ -33,7 +33,7 @@ https://developer.arm.com/documentation/ddi0616 https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions }]; - let dependentDialects = ["scf::SCFDialect"]; + let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"]; } //===----------------------------------------------------------------------===// @@ -196,6 +196,64 @@ let assemblyFormat = "attr-dict `:` type($tile_id)"; } +// +// Tile reset. +// + +def ZeroOp : ArmSME_Op<"zero", [Pure]> { + let summary = "Initialize the two-dimensional ZA array with 0s"; + let results = (outs nxnxv16i8:$res); + let description = [{ + Initialise ZA with 0. 
This operation is a convenient wrapper for the SME + `zero` intrinsic and instruction. + + NOTE: At the moment it is assumed that the element type is `i8` and that + there's only one "virtual tile". + + Example: + + ```mlir + %0 = arm_sme.zero : vector<[16]x[16]xi8> + ``` + }]; + let extraClassDeclaration = [{ + VectorType getVectorType() { + return ::llvm::cast(getRes().getType()); + } + }]; + let assemblyFormat = "attr-dict `:` type($res)"; +} + +def TileStoreOp : ArmSME_Op<"tile_store"> { + let summary = "Tile store operation"; + let description = [{ + Store a 2D SME "virtual tile" to memory. + + NOTE: At the moment it is assumed that the element type is `i8` and that + there's only one "virtual tile". + + Example: + + ```mlir + arm_sme.tile_store %0, %arg0[%c0, %c0] : memref, vector<[16]x[16]xi8> + ``` + }]; + let arguments = (ins nxnxv16i8:$valueToStore, + Arg:$base, + Variadic:$indices); + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return ::llvm::cast(getBase().getType()); + } + VectorType getVectorType() { + return ::llvm::cast(getValueToStore().getType()); + } + }]; + + let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict " + "`:` type($base) `,` type($valueToStore)"; +} + //===----------------------------------------------------------------------===// // ArmSME Intrinsic op definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -47,6 +47,7 @@ add_subdirectory(TosaToLinalg) add_subdirectory(TosaToSCF) add_subdirectory(TosaToTensor) +add_subdirectory(VectorToArmSME) add_subdirectory(VectorToLLVM) add_subdirectory(VectorToGPU) add_subdirectory(VectorToSCF) diff --git a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt new file mode 100644 --- /dev/null +++ 
b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_conversion_library(MLIRVectorToArmSME + VectorToArmSME.cpp + VectorToArmSMEPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToArmSME + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArmSMEDialect + MLIRLLVMCommonConversion + ) diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -0,0 +1,84 @@ +//===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h" + +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/IR/BuiltinTypes.h" +#include "llvm/Support/Casting.h" + +using namespace mlir; + +static constexpr unsigned kMinNumElts = 16; + +/// Returns true if 'val' is a splat of zero, false otherwise. 
+static bool isSplatZero(Type elemType, DenseElementsAttr val) { + if (llvm::isa(elemType)) + return val && val.isSplat() && val.getSplatValue().isZero(); + if (llvm::isa(elemType)) + return val && val.isSplat() && val.getSplatValue().isZero(); + return false; +} + +namespace { + +/// Look at `vector.transfer_write` operations and convert suitable candidates +/// to ArmSME operations, e.g.: +/// +/// %cst = arith.constant dense<0> : vector<[16]x[16]xi8> +/// vector.transfer_write %cst, %arg0 : vector<[16]x[16]xi8>, memref +/// +/// is converted to: +/// +/// %0 = arm_sme.zero : vector<[16]x[16]xi8> +/// arm_sme.tile_store %0, %arg0[%c0, %c0] : memref, +/// vector<[16]x[16]xi8> +/// +struct TransferWriteToArmSMELowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, + PatternRewriter &rewriter) const final { + auto vType = writeOp.getVectorType(); + if (vType.getRank() != 2) + return failure(); + if (vType.getShape() != ArrayRef({kMinNumElts, kMinNumElts})) + return failure(); + if (vType.getElementType() != rewriter.getI8Type()) + return failure(); + if (vType.getScalableDims().size() != 2) + return failure(); + + auto loc = writeOp.getLoc(); + + if (!llvm::isa(writeOp.getSource().getType())) + return failure(); + + auto constant = writeOp.getVector().getDefiningOp(); + if (!constant) + return failure(); + + auto denseAttr = dyn_cast(constant.getValueAttr()); + if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr)) + return failure(); + + auto zero = rewriter.create(loc, vType); + + rewriter.replaceOpWithNewOp( + writeOp, zero, writeOp.getSource(), writeOp.getIndices()); + return success(); + } +}; + +} // namespace + +void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns, + MLIRContext &ctx) { + patterns.add(&ctx); +} diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp 
b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp @@ -0,0 +1,36 @@ +//===- VectorToArmSMEPass.cpp - Conversion from Vector to the ArmSME dialect =// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h" + +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTVECTORTOARMSME +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::vector; + +namespace { +struct ConvertVectorToArmSMEPass + : public impl::ConvertVectorToArmSMEBase { + + void runOnOperation() override; +}; +} // namespace + +void ConvertVectorToArmSMEPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + populateVectorToArmSMEPatterns(patterns, getContext()); + + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -109,7 +109,6 @@ if (armSME) { configureArmSMELegalizeForExportTarget(target); populateArmSMELegalizeForLLVMExportPatterns(converter, patterns); - arm_sme::populateVectorTransferLoweringPatterns(converter, patterns); } if (amx) { configureAMXLegalizeForExportTarget(target); diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt --- 
a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt @@ -12,4 +12,5 @@ MLIRLLVMDialect MLIRSCFDialect MLIRSideEffectInterfaces + MLIRVectorDialect ) diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt @@ -1,7 +1,6 @@ add_mlir_dialect_library(MLIRArmSMETransforms EnableArmStreaming.cpp LegalizeForLLVMExport.cpp - LowerVectorOps.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -8,15 +8,20 @@ #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/ArmSME/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" using namespace mlir; using namespace mlir::arm_sme; +static constexpr unsigned kMinNumElts = 16; +static constexpr unsigned kZeroZAMask = 255; + namespace { /// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func' /// ops to enable the ZA storage array. @@ -58,10 +63,104 @@ }; } // namespace -void mlir::populateArmSMELegalizeForLLVMExportPatterns( - LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); -} +/// Lower 'arm_sme.zero'. Use 'arm_sme.cast_tile_to_vector' to model the return +/// value. The latter is a nop, which should be folded away (e.g. during +/// canonicalisation). 
+/// +/// BEFORE: +/// ```mlir +/// %0 = arm_sme.zero : vector<[16]x[16]xi8> +/// ``` +/// +/// AFTER: +/// ```mlir +/// %1 = arm_sme.get_tile_id : i8 +/// %2 = arm_sme.cast_tile_to_vector %1 : i8 to vector<[16]x[16]xi8> +/// "arm_sme.intr.zero"(%c255_i32) : (i32) -> () +/// ``` +struct ZeroOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(ZeroOp zero, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = zero.getLoc(); + + // Get Tile ID for the `zero` intrinsic. + // TODO: Map this to a valid `mask` for the `zero` intrinsic. + auto tileId = rewriter.create( + loc, zero.getVectorType().getElementType()); + + // Create 'arm_sme.intr.zero' intrinsic to zero ZA. + // FIXME: Replace the hard-coded mask with a valid value based + // on `tileId`. + auto mask = rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(kZeroZAMask)); + rewriter.create(loc, mask); + + // Create `CastTileToVectorOp` to use it as the output + rewriter.replaceOpWithNewOp(zero, zero.getType(), + tileId); + + return success(); + } +}; + +/// Lower 'arm_sme.tile_store' to a loop over the rows of ZA and store each row +/// using 'arm_sme.intr.str'. +/// +/// BEFORE: +/// ```mlir +/// arm_sme.tile_store %0, %arg0[%c0, %c0] : memref, +/// vector<[16]x[16]xi8> +/// ``` +/// +/// AFTER: +/// ```mlir +/// %vscale = "llvm.intr.vscale"() : () -> index +/// %c0 = arith.constant 0 : index +/// %c1 = arith.constant 1 : index +/// %c16 = arith.constant 16 : index +/// %vec_size = arith.muli %c16, %vscale : index +/// scf.for %row_idx = %c0 to %vec_size step %c1 { +/// // (...) 
+/// "arm_sme.intr.str"(%row_idx, %addr) : (i32, !llvm.ptr) -> () +/// ``` +struct TileStoreOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(TileStoreOp store, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = store.getLoc(); + + // Create loop that iterates from 0 to SVLB-1 inclusive (the number of + // vectors in ZA) and stores each ZA vector to memory. + auto step = rewriter.create(loc, 1); + auto minElems = rewriter.create(loc, kMinNumElts); + auto vscale = + rewriter.create(loc, rewriter.getIndexType()); + auto lowerBound = rewriter.create(loc, 0); + auto upperBound = rewriter.create(loc, minElems, vscale); + auto forOp = rewriter.create(loc, lowerBound, upperBound, step); + rewriter.setInsertionPointToStart(forOp.getBody()); + + // Create 'arm_sme.intr.str' intrinsic to store ZA vector. + auto vnumI64 = rewriter.create( + loc, rewriter.getI64Type(), forOp.getInductionVar()); + auto offset = + rewriter.create(loc, rewriter.getI64Type(), 0); + Value ptr = + getStridedElementPtr(loc, store.getMemRefType(), adaptor.getBase(), + ValueRange{vnumI64, offset}, rewriter); + auto vnumI32 = rewriter.create( + loc, rewriter.getI32Type(), forOp.getInductionVar()); + rewriter.create(loc, vnumI32, ptr); + + rewriter.eraseOp(store); + return success(); + } +}; void mlir::configureArmSMELegalizeForExportTarget( LLVMConversionTarget &target) { @@ -95,3 +194,9 @@ return !funcOp->hasAttr("arm_za") || hasDisableZA; }); } + +void mlir::populateArmSMELegalizeForLLVMExportPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); + patterns.add(converter); +} diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp +++ /dev/null @@ -1,111 +0,0 @@ -//===- 
LowerVectorOps.cpp - Lower vector ops to SME -----------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements rewrite patterns to lower vector dialect ops to ArmSME. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" -#include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ArmSME/IR/ArmSME.h" -#include "mlir/Dialect/ArmSME/Transforms/Transforms.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/PatternMatch.h" - -using namespace mlir; -using namespace mlir::arm_sme; - -static constexpr unsigned kMinNumElts = 16; -static constexpr unsigned kZeroZAMask = 255; - -/// Returns true if 'val' is a splat of zero, false otherwise. -static bool isSplatZero(Type elemType, DenseElementsAttr val) { - if (llvm::isa(elemType)) - return val && val.isSplat() && val.getSplatValue().isZero(); - if (llvm::isa(elemType)) - return val && val.isSplat() && val.getSplatValue().isZero(); - return false; -} - -namespace { -/// Lower 'vector.transfer_write' op to 'arm_sme.intr.zero' op. Currently only -/// supports 2d scalable vector type 'vector<[16x16]xi8>' that maps to the ZA0.B -/// SME virtual tile. This will be extended to support more element types. 
-struct TransferWriteToArmSMEZeroLowering - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::TransferWriteOp write, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto vType = write.getVectorType(); - if (vType.getRank() != 2) - return failure(); - if (vType.getShape() != ArrayRef({kMinNumElts, kMinNumElts})) - return failure(); - if (vType.getElementType() != rewriter.getI8Type()) - return failure(); - if (vType.getScalableDims().size() != 2) - return failure(); - - auto memRefType = llvm::dyn_cast(write.getSource().getType()); - if (!memRefType) - return failure(); - - auto constant = write.getVector().getDefiningOp(); - if (!constant) - return failure(); - - auto denseAttr = dyn_cast(constant.getValueAttr()); - if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr)) - return failure(); - - auto loc = write.getLoc(); - - // Create 'arm_sme.intr.zero' intrinsic to zero ZA. - auto tile = rewriter.create( - loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(kZeroZAMask)); - rewriter.create(loc, tile); - - // Create loop that iterates from 0 to SVLB-1 inclusive (the number of - // vectors in ZA) and stores each ZA vector to memory. - auto step = rewriter.create(loc, 1); - auto minElems = rewriter.create(loc, kMinNumElts); - auto vscale = - rewriter.create(loc, rewriter.getIndexType()); - auto lowerBound = rewriter.create(loc, 0); - auto upperBound = rewriter.create(loc, minElems, vscale); - auto forOp = rewriter.create(loc, lowerBound, upperBound, step); - rewriter.setInsertionPointToStart(forOp.getBody()); - - // Create 'arm_sme.intr.str' intrinsic to store ZA vector. 
- auto vnumI64 = rewriter.create( - loc, rewriter.getI64Type(), forOp.getInductionVar()); - auto offset = - rewriter.create(loc, rewriter.getI64Type(), 0); - Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getSource(), - ValueRange{vnumI64, offset}, rewriter); - auto vnumI32 = rewriter.create( - loc, rewriter.getI32Type(), forOp.getInductionVar()); - rewriter.create(loc, vnumI32, ptr); - - rewriter.eraseOp(write); - - return success(); - } -}; -} // namespace - -void mlir::arm_sme::populateVectorTransferLoweringPatterns( - LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add(converter); -} diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir --- a/mlir/test/Dialect/ArmSME/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir @@ -183,3 +183,20 @@ %0 = arm_sme.get_tile_id : i128 return %0 : i128 } + +// ----- + +func.func @arm_sme_zero() -> () { + // CHECK: arm_sme.zero : vector<[16]x[16]xi8> + %0 = arm_sme.zero : vector<[16]x[16]xi8> + return +} + +// ----- + +func.func @arm_sme_store_tile(%tile : vector<[16]x[16]xi8>, %dest : memref) -> () { + // CHECK: arm_sme.tile_store {{.*}} : memref, vector<[16]x[16]xi8> + %c0 = arith.constant 0 : index + arm_sme.tile_store %tile, %dest[%c0, %c0] : memref, vector<[16]x[16]xi8> + return +} diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir @@ -0,0 +1,32 @@ +// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s + +// CHECK-LABEL: @transfer_write_2d_zero_i8 +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-DAG: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-DAG: %[[C255:.*]] = arith.constant 255 : i32 +// CHECK-DAG: 
"arm_sme.intr.zero"(%[[C255]]) : (i32) -> () +// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8 +// CHECK-DAG: %[[CAST_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8> +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[MIN_ZA_VECTORS:.*]] = arith.constant 16 : index +// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[VNUM:.*]] = %[[C0_0]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { +// CHECK-NEXT: %[[VNUM_I64:.*]] = arith.index_castui %[[VNUM]] : index to i64 +// CHECK-NEXT: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[VNUM_I64]], %[[STRIDE0]] : i64 +// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_1]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK-NEXT: %[[VNUM_I32:.*]] = arith.index_castui %[[VNUM]] : index to i32 +// CHECK-NEXT: "arm_sme.intr.str"(%[[VNUM_I32]], %[[GEP]]) : (i32, !llvm.ptr) -> () +func.func @transfer_write_2d_zero_i8(%arg0 : memref) { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0> : vector<[16]x[16]xi8> + vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref + return +} + diff --git a/mlir/test/Dialect/ArmSME/vector-ops.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir rename from mlir/test/Dialect/ArmSME/vector-ops.mlir 
rename to mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir --- a/mlir/test/Dialect/ArmSME/vector-ops.mlir +++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir @@ -1,27 +1,13 @@ -// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s +// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file | mlir-opt | FileCheck %s -// CHECK-LABEL: @transfer_write_2d_zero_i8 -// CHECK-SAME: %[[ARG0:.*]]: memref) -// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: %[[C255:.*]] = arith.constant 255 : i32 -// CHECK-NEXT: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> () -// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index -// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 16 : index -// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 -// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index -// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index -// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index -// CHECK-NEXT: scf.for %[[VNUM:.*]] = %[[C0_0]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { -// CHECK-NEXT: %[[VNUM_I64:.*]] = arith.index_castui %[[VNUM]] : index to i64 -// CHECK-NEXT: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[VNUM_I64]], %[[STRIDE0]] : i64 -// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_1]] : i64 -// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 -// CHECK-NEXT: %[[VNUM_I32:.*]] = arith.index_castui %[[VNUM]] : index 
to i32 -// CHECK-NEXT: "arm_sme.intr.str"(%[[VNUM_I32]], %[[GEP]]) : (i32, !llvm.ptr) -> () -func.func @transfer_write_2d_zero_i8(%arg0 : memref) { + +// CHECK-LABEL: func.func @transfer_write_2d_zero( +// CHECK-SAME: %[[ARG_0:.*]]: memref) { +func.func @transfer_write_2d_zero(%arg0 : memref) { +// CHECK: %[[C_0:.*]] = arith.constant 0 : index +// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[16]x[16]xi8> +// CHECK: arm_sme.tile_store %[[ZERO]], %[[ARG_0]][%[[C_0]], %[[C_0]]] : memref, vector<[16]x[16]xi8> +// CHECK: return %c0 = arith.constant 0 : index %cst = arith.constant dense<0> : vector<[16]x[16]xi8> vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref @@ -48,7 +34,7 @@ // CHECK-LABEL: @transfer_write_2d_zero__bad_shape // CHECK: vector.transfer_write -// CHECK-NOT: arm_sme.intr.zero +// CHECK-NOT: arm_sme.tile_store func.func @transfer_write_2d_zero__bad_shape(%arg0 : memref) { %c0 = arith.constant 0 : index %cst = arith.constant dense<0> : vector<[8]x[8]xi8> @@ -60,7 +46,7 @@ // CHECK-LABEL: @transfer_write_2d_zero__bad_rank // CHECK: vector.transfer_write -// CHECK-NOT: arm_sme.intr.zero +// CHECK-NOT: arm_sme.tile_store func.func @transfer_write_2d_zero__bad_rank(%arg0 : memref) { %c0 = arith.constant 0 : index %cst = arith.constant dense<0> : vector<[16]x[16]x[16]xi8> @@ -72,7 +58,7 @@ // CHECK-LABEL: @transfer_write_2d_zero__non_memref_type // CHECK: vector.transfer_write -// CHECK-NOT: arm_sme.intr.zero +// CHECK-NOT: arm_sme.tile_store func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor) -> tensor { %c0 = arith.constant 0 : index %cst = arith.constant dense<0> : vector<[16]x[16]xi8> @@ -84,7 +70,7 @@ // CHECK-LABEL: @transfer_write_2d_zero__non_zero_value // CHECK: vector.transfer_write -// CHECK-NOT: arm_sme.intr.zero +// CHECK-NOT: arm_sme.tile_store func.func @transfer_write_2d_zero__non_zero_value(%arg0 : memref) { %c0 = arith.constant 0 : index %cst = arith.constant dense<1> : 
vector<[16]x[16]xi8> @@ -96,7 +82,7 @@ // CHECK-LABEL: @transfer_write_2d_zero__vec_unknown_defining_op // CHECK: vector.transfer_write -// CHECK-NOT: arm_sme.intr.zero +// CHECK-NOT: arm_sme.tile_store func.func @transfer_write_2d_zero__vec_unknown_defining_op(%arg0 : memref, %arg1 : vector<[16]x[16]xi8>) { %c0 = arith.constant 0 : index vector.transfer_write %arg1, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \ +// RUN: mlir-opt %s -convert-vector-to-arm-sme -enable-arm-streaming="mode=locally enable-za" \ // RUN: -convert-vector-to-llvm="enable-arm-sme" -test-lower-to-llvm | \ // RUN: mlir-translate -mlir-to-llvmir | \ // RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \