diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1092,6 +1092,10 @@ "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " "dialect.">, + Option<"armSME", "enable-arm-sme", + "bool", /*default=*/"false", + "Enables the use of ArmSME dialect while lowering the vector " + "dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt --- a/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt @@ -1 +1,2 @@ +add_subdirectory(IR) add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td @@ -0,0 +1,50 @@ +//===-- ArmSME.td - ArmSME dialect operation definitions ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic operations for the ArmSME dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef ArmSME +#define ArmSME + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +//===----------------------------------------------------------------------===// +// ArmSME dialect definition.
+//===----------------------------------------------------------------------===// + +def ArmSME_Dialect : Dialect { + let name = "arm_sme"; + let cppNamespace = "::mlir::arm_sme"; + let summary = "Dialect to target the Armv9 Scalable Matrix Extension (SME)"; + let description = [{ + This dialect contains the definitions necessary to target specific Arm SME + operations. + + For more details on the architecture, see the Arm documentation: + https://developer.arm.com/documentation/ddi0616 + }]; + let usePropertiesForAttributes = 1; +} + +//===----------------------------------------------------------------------===// +// LLVMIR Intrinsics +//===----------------------------------------------------------------------===// + +class ArmSME_IntrOp traits = []> : + LLVM_IntrOpBase; + +/// Create a call to the `llvm.aarch64.sme.zero` intrinsic (i32 tile mask). +def LLVM_aarch64_sme_zero + : ArmSME_IntrOp<"zero", 0>, Arguments<(ins I32:$imm)>; + +#endif // ArmSME diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEDialect.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEDialect.h @@ -0,0 +1,26 @@ +//===- ArmSMEDialect.h - MLIR Dialect for Arm SME ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the ArmSME dialect and its operations in MLIR.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARMSME_ARMSMEDIALECT_H +#define MLIR_DIALECT_ARMSME_ARMSMEDIALECT_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSME/IR/ArmSME.h.inc" + +#endif // MLIR_DIALECT_ARMSME_ARMSMEDIALECT_H diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_dialect(ArmSME arm_sme ArmSME) +add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme) + +set(LLVM_TARGET_DEFINITIONS ArmSME.td) +mlir_tablegen(ArmSMEConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRArmSMEConversionsIncGen) diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h @@ -0,0 +1,27 @@ +//===- Transforms.h - ArmSME Dialect Transformation Entrypoints -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H +#define MLIR_DIALECT_ARMSME_TRANSFORMS_H + +namespace mlir { + +class LLVMConversionTarget; +class RewritePatternSet; + +namespace arm_sme { +void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns); +} // namespace arm_sme + +/// Configure the target to support lowering ArmSME ops to ops that map to LLVM +/// intrinsics. +void configureArmSMELegalizeForExportTarget(LLVMConversionTarget &target); + +} // namespace mlir + +#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -23,6 +23,7 @@ #include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -118,6 +119,7 @@ quant::QuantizationDialect, spirv::SPIRVDialect, arm_sve::ArmSVEDialect, + arm_sme::ArmSMEDialect, vector::VectorDialect, NVVM::NVVMDialect, ROCDL::ROCDLDialect, diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -16,6 +16,7 @@ #include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h" #include
"mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" @@ -36,6 +37,7 @@ registerArmNeonDialectTranslation(registry); registerAMXDialectTranslation(registry); registerArmSVEDialectTranslation(registry); + registerArmSMEDialectTranslation(registry); registerBuiltinDialectTranslation(registry); registerGPUDialectTranslation(registry); registerLLVMDialectTranslation(registry); diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h @@ -0,0 +1,31 @@ +//=======- ArmSMEToLLVMIRTranslation.h - ArmSME to LLVM IR --*- C++ -*-=======// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides registration calls for ArmSME dialect to LLVM IR translation. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H +#define MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H + +namespace mlir { + +class DialectRegistry; +class MLIRContext; + +/// Register the ArmSME dialect and the translation from it to the LLVM IR in +/// the given registry. +void registerArmSMEDialectTranslation(DialectRegistry ®istry); + +/// Register the ArmSME dialect and the translation from it in the registry +/// associated with the given context.
+void registerArmSMEDialectTranslation(MLIRContext &context); + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -17,6 +17,8 @@ MLIRArmNeonDialect MLIRArmSVEDialect MLIRArmSVETransforms + MLIRArmSMEDialect + MLIRArmSMETransforms MLIRAMXDialect MLIRAMXTransforms MLIRLLVMCommonConversion diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,8 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h" +#include "mlir/Dialect/ArmSME/Transforms/Transforms.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -49,6 +51,8 @@ registry.insert(); if (armSVE) registry.insert(); + if (armSME) + registry.insert(); if (amx) registry.insert(); if (x86Vector) @@ -102,6 +106,10 @@ configureArmSVELegalizeForExportTarget(target); populateArmSVELegalizeForLLVMExportPatterns(converter, patterns); } + if (armSME) { + configureArmSMELegalizeForExportTarget(target); + arm_sme::populateVectorTransferLoweringPatterns(patterns); + } if (amx) { configureAMXLegalizeForExportTarget(target); populateAMXLegalizeForLLVMExportPatterns(converter, patterns); diff --git a/mlir/lib/Dialect/ArmSME/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/CMakeLists.txt --- a/mlir/lib/Dialect/ArmSME/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSME/CMakeLists.txt @@ -1 +1,2 @@ +add_subdirectory(IR) add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSMEDialect.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSMEDialect.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSMEDialect.cpp
@@ -0,0 +1,33 @@
+//===- ArmSMEDialect.cpp - MLIR ArmSME dialect implementation -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the ArmSME dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+//===----------------------------------------------------------------------===//
+// Tablegen Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.cpp.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
+
+void ArmSMEDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
+      >();
+}
diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRArmSMEDialect
+  ArmSMEDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME
+
+  DEPENDS
+  MLIRArmSMEIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMDialect
+  )
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
   EnableArmStreaming.cpp
+  LegalizeForLLVMExport.cpp
+  LowerVectorOps.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -8,6 +10,12 @@
   MLIRArmSMETransformsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRArmSMEDialect
   MLIRFuncDialect
+  MLIRLLVMDialect
+  MLIRVectorDialect
+  MLIRLLVMCommonConversion
+  MLIRIR
   MLIRPass
   )
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -0,0 +1,19 @@
+//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+void mlir::configureArmSMELegalizeForExportTarget(
+    LLVMConversionTarget &target) {
+  target.addLegalOp<arm_sme::aarch64_sme_zero>();
+}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
@@ -0,0 +1,75 @@
+//===- LowerVectorOps.cpp - Lower vector ops to SME -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrite patterns to lower vector dialect ops to ArmSME.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Matchers.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+/// Immediate operand for `arm_sme.intr.zero`. Each of the 8 bits selects one
+/// 64-bit ZA tile (ZA0.D-ZA7.D); setting all of them zeroes the whole ZA array,
+/// which is the storage covered by the ZA0.B tile.
+static constexpr unsigned kZeroZAMask = 255;
+
+namespace {
+/// Lower a `vector.transfer_write` of an all-zero value to `arm_sme.intr.zero`.
+/// Currently only supports the 2d scalable vector type `vector<[16x16]xi8>`
+/// that maps to the ZA0.B SME tile. This will be extended to support more
+/// element types.
+///
+/// Zeroing ZA can only stand in for the write if the stored value is a
+/// constant zero splat, the write is unmasked and the destination is a memref
+/// (not a tensor); otherwise the pattern fails to match.
+struct TransferWriteToArmSMEZeroLowering
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  TransferWriteToArmSMEZeroLowering(MLIRContext *context)
+      : OpRewritePattern<vector::TransferWriteOp>(context) {}
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+                                PatternRewriter &rewriter) const override {
+    auto vType = write.getVectorType();
+    if (vType.getRank() != 2)
+      return failure();
+    if (vType.getShape() != ArrayRef<int64_t>({16, 16}))
+      return failure();
+    if (vType.getElementType() != rewriter.getI8Type())
+      return failure();
+    if (vType.getNumScalableDims() != 2)
+      return failure();
+    // A tensor-semantics write returns a new tensor; the zero intrinsic has no
+    // results that could replace it.
+    if (write->getNumResults() != 0)
+      return failure();
+    // The intrinsic zeroes the entire tile and would ignore any mask.
+    if (write.getMask())
+      return failure();
+    // Only an all-zero value can be materialized by zeroing the tile.
+    if (!matchPattern(write.getVector(), m_Zero()))
+      return failure();
+
+    auto zaMask = rewriter.create<arith::ConstantOp>(
+        write.getLoc(), rewriter.getI32Type(),
+        rewriter.getI32IntegerAttr(kZeroZAMask));
+    rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_zero>(write, zaMask);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::arm_sme::populateVectorTransferLoweringPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<TransferWriteToArmSMEZeroLowering>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -44,4 +44,5 @@
   MLIRVectorDialect
   MLIRVectorInterfaces
   MLIRVectorUtils
+  MLIRArmSMEDialect
   )
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -47,6 +47,7 @@
   LINK_LIBS PUBLIC
   MLIRArmNeonToLLVMIRTranslation
   MLIRArmSVEToLLVMIRTranslation
+  MLIRArmSMEToLLVMIRTranslation
   MLIRAMXToLLVMIRTranslation
   MLIRBuiltinToLLVMIRTranslation
   MLIRGPUToLLVMIRTranslation
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp
@@ -0,0 +1,56 @@
+//======- ArmSMEToLLVMIRTranslation.cpp - Translate ArmSME to LLVM IR -=======//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the ArmSME dialect and LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+/// Implementation of the dialect interface that converts operations belonging
+/// to the ArmSME dialect to LLVM IR.
+class ArmSMEDialectLLVMIRTranslationInterface
+    : public LLVMTranslationDialectInterface {
+public:
+  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+  /// Translates the given operation to LLVM IR using the provided IR builder
+  /// and saving the state in `moduleTranslation`.
+  LogicalResult
+  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) const final {
+    Operation &opInst = *op;
+#include "mlir/Dialect/ArmSME/IR/ArmSMEConversions.inc"
+
+    return failure();
+  }
+};
+} // namespace
+
+void mlir::registerArmSMEDialectTranslation(DialectRegistry &registry) {
+  registry.insert<arm_sme::ArmSMEDialect>();
+  registry.addExtension(+[](MLIRContext *ctx, arm_sme::ArmSMEDialect *dialect) {
+    dialect->addInterfaces<ArmSMEDialectLLVMIRTranslationInterface>();
+  });
+}
+
+void mlir::registerArmSMEDialectTranslation(MLIRContext &context) {
+  DialectRegistry registry;
+  registerArmSMEDialectTranslation(registry);
+  context.appendDialectRegistry(registry);
+}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_translation_library(MLIRArmSMEToLLVMIRTranslation
+  ArmSMEToLLVMIRTranslation.cpp
+
+  DEPENDS
+  MLIRArmSMEConversionsIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRArmSMEDialect
+  MLIRLLVMDialect
+  MLIRSupport
+  MLIRTargetLLVMIRExport
+  )
diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_subdirectory(ArmNeon)
 add_subdirectory(ArmSVE)
+add_subdirectory(ArmSME)
 add_subdirectory(AMX)
 add_subdirectory(Builtin)
 add_subdirectory(GPU)
diff --git a/mlir/test/Dialect/ArmSME/vector_ops.mlir b/mlir/test/Dialect/ArmSME/vector_ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/vector_ops.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: @transfer_write_2d_zero_i8
+// CHECK: %[[C255:.*]] = arith.constant 255 : i32
+// CHECK: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
+func.func @transfer_write_2d_zero_i8() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c16, %vscale : index
+  %0 = memref.alloc(%dim, %dim) : memref<?x?xi8>
+  %cst = arith.constant dense<0> : vector<[16x16]xi8>
+  vector.transfer_write %cst, %0[%c0, %c0] {in_bounds = [true, true]} : vector<[16x16]xi8>, memref<?x?xi8>
+  memref.dealloc %0 : memref<?x?xi8>
+  return
+}
+
+// -----
+
+// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
+// lowering only occurs for vector types of correct rank, shape, element size
+// and number of scalable dims.
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_type
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_type() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c16, %vscale : index
+  %0 = memref.alloc(%dim, %dim) : memref<?x?xi4>
+  %cst = arith.constant dense<0> : vector<[16x16]xi4>
+  vector.transfer_write %cst, %0[%c0, %c0] {in_bounds = [true, true]} : vector<[16x16]xi4>, memref<?x?xi4>
+  memref.dealloc %0 : memref<?x?xi4>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_shape
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_shape() {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c8, %vscale : index
+  %0 = memref.alloc(%dim, %dim) : memref<?x?xi8>
+  %cst = arith.constant dense<0> : vector<[8x8]xi8>
+  vector.transfer_write %cst, %0[%c0, %c0] {in_bounds = [true, true]} : vector<[8x8]xi8>, memref<?x?xi8>
+  memref.dealloc %0 : memref<?x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_rank
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_rank() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c16, %vscale : index
+  %0 = memref.alloc(%dim, %dim, %dim) : memref<?x?x?xi8>
+  %cst = arith.constant dense<0> : vector<[16x16x16]xi8>
+  vector.transfer_write %cst, %0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<[16x16x16]xi8>, memref<?x?x?xi8>
+  memref.dealloc %0 : memref<?x?x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_num_scalable_dims
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_num_scalable_dims() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c16, %vscale : index
+  %0 = memref.alloc(%dim) : memref<16x?xi8>
+  %cst = arith.constant dense<0> : vector<16x[16]xi8>
+  vector.transfer_write %cst, %0[%c0, %c0] {in_bounds = [true, true]} : vector<16x[16]xi8>, memref<16x?xi8>
+  memref.dealloc %0 : memref<16x?xi8>
+  return
+}
+
+// -----
+
+// The 'vector.transfer_write' -> 'arm_sme.intr.zero' lowering only applies
+// when the value being written is a constant zero splat.
+
+// CHECK-LABEL: @transfer_write_2d__non_zero_value
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d__non_zero_value() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c16, %vscale : index
+  %0 = memref.alloc(%dim, %dim) : memref<?x?xi8>
+  %cst = arith.constant dense<1> : vector<[16x16]xi8>
+  vector.transfer_write %cst, %0[%c0, %c0] {in_bounds = [true, true]} : vector<[16x16]xi8>, memref<?x?xi8>
+  memref.dealloc %0 : memref<?x?xi8>
+  return
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-translate --mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: define void @arm_sme_zero
+// CHECK: call void @llvm.aarch64.sme.zero(i32 255)
+llvm.func @arm_sme_zero() {
+  %mask = llvm.mlir.constant(255 : i32) : i32
+  "arm_sme.intr.zero"(%mask) : (i32) -> ()
+  llvm.return
+}
+
+// -----