diff --git a/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt --- a/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt @@ -1 +1,2 @@ +add_subdirectory(IR) add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h @@ -0,0 +1,27 @@ +//===- ArmSME.h - MLIR Dialect for Arm SME ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the ArmSME dialect and its intrinsic operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARMSME_IR_ARMSME_H +#define MLIR_DIALECT_ARMSME_IR_ARMSME_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSME/IR/ArmSME.h.inc" + +#endif // MLIR_DIALECT_ARMSME_IR_ARMSME_H diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td @@ -0,0 +1,122 @@ +//===-- ArmSME.td - ArmSME dialect operation definitions ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the ArmSME dialect and contains intrinsic ops to lower to +// LLVM IR. +// +//===----------------------------------------------------------------------===// + +#ifndef ARMSME_OPS +#define ARMSME_OPS + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +//===----------------------------------------------------------------------===// +// ArmSME dialect definition +//===----------------------------------------------------------------------===// + +def ArmSME_Dialect : Dialect { + let name = "arm_sme"; + let cppNamespace = "::mlir::arm_sme"; + let summary = "Basic dialect to target Arm SME architectures"; + let description = [{ + This dialect contains the definitions necessary to target Arm SME + scalable matrix operations. + + Sources: + https://developer.arm.com/documentation/ddi0616 + https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions + }]; +} + +//===----------------------------------------------------------------------===// +// ArmSME Intrinsic op definitions +//===----------------------------------------------------------------------===// + +def MOPPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2], [I1]>; +def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2], + [I8, I16, BF16, F16, F32, F64]>; +def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>; + +class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [], + list<Trait> traits = []> + : LLVM_IntrOpBase< + /*Dialect dialect=*/ArmSME_Dialect, + /*string opName=*/"intr."
# mnemonic, + /*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic), + /*list<int> overloadedResults=*/[], + /*list<int> overloadedOperands=*/overloadedOperands, + /*list<Trait> traits=*/traits, + /*int numResults=*/0>; + +// Zero (clear) the ZA tiles selected by the i32 tile mask. +def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">, + Arguments<(ins Arg<I32, "Tile mask">)>; + +// Outer products (MOPA/MOPS), overloaded on the vector operand type. +class ArmSME_IntrMopOverloadedOp<string mnemonic> + : ArmSME_IntrOp<mnemonic, [4]>, + Arguments<(ins Arg<I32, "Tile ID">, + Arg<MOPPredicate, "LHS predicate">, + Arg<MOPPredicate, "RHS predicate">, + Arg<MOPVector, "LHS vector">, + Arg<MOPVector, "RHS vector">)>; + +def LLVM_aarch64_sme_mopa : ArmSME_IntrMopOverloadedOp<"mopa">; +def LLVM_aarch64_sme_mops : ArmSME_IntrMopOverloadedOp<"mops">; +def LLVM_aarch64_sme_mopa_wide : ArmSME_IntrMopOverloadedOp<"mopa.wide">; +def LLVM_aarch64_sme_mops_wide : ArmSME_IntrMopOverloadedOp<"mops.wide">; +def LLVM_aarch64_sme_smopa_wide : ArmSME_IntrMopOverloadedOp<"smopa.wide">; +def LLVM_aarch64_sme_smops_wide : ArmSME_IntrMopOverloadedOp<"smops.wide">; +def LLVM_aarch64_sme_umopa_wide : ArmSME_IntrMopOverloadedOp<"umopa.wide">; +def LLVM_aarch64_sme_umops_wide : ArmSME_IntrMopOverloadedOp<"umops.wide">; +def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">; +def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">; +def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">; +def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">; + +// Loads into a ZA tile slice, horizontal or vertical. 
+class ArmSME_IntrLoadOp<string mnemonic> + : ArmSME_IntrOp<mnemonic>, + Arguments<(ins Arg<LDSTPredicate, "Vector predicate">, + Arg<LLVM_AnyPointer, "Load address">, + Arg<I32, "Tile ID">, + Arg<I32, "Tile slice">)>; + +def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">; +def LLVM_aarch64_sme_ld1h_horiz : ArmSME_IntrLoadOp<"ld1h.horiz">; +def LLVM_aarch64_sme_ld1w_horiz : ArmSME_IntrLoadOp<"ld1w.horiz">; +def LLVM_aarch64_sme_ld1d_horiz : ArmSME_IntrLoadOp<"ld1d.horiz">; +def LLVM_aarch64_sme_ld1q_horiz : ArmSME_IntrLoadOp<"ld1q.horiz">; +def LLVM_aarch64_sme_ld1b_vert : ArmSME_IntrLoadOp<"ld1b.vert">; +def LLVM_aarch64_sme_ld1h_vert : ArmSME_IntrLoadOp<"ld1h.vert">; +def LLVM_aarch64_sme_ld1w_vert : ArmSME_IntrLoadOp<"ld1w.vert">; +def LLVM_aarch64_sme_ld1d_vert : 
ArmSME_IntrLoadOp<"ld1d.vert">; +def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">; + +// Stores from a ZA tile slice, horizontal or vertical. +class ArmSME_IntrStoreOp<string mnemonic> + : ArmSME_IntrOp<mnemonic>, + Arguments<(ins Arg<LDSTPredicate, "Vector predicate">, + Arg<LLVM_AnyPointer, "Store address">, + Arg<I32, "Tile ID">, + Arg<I32, "Tile slice">)>; + +def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">; +def LLVM_aarch64_sme_st1h_horiz : ArmSME_IntrStoreOp<"st1h.horiz">; +def LLVM_aarch64_sme_st1w_horiz : ArmSME_IntrStoreOp<"st1w.horiz">; +def LLVM_aarch64_sme_st1d_horiz : ArmSME_IntrStoreOp<"st1d.horiz">; +def LLVM_aarch64_sme_st1q_horiz : ArmSME_IntrStoreOp<"st1q.horiz">; +def LLVM_aarch64_sme_st1b_vert : ArmSME_IntrStoreOp<"st1b.vert">; +def LLVM_aarch64_sme_st1h_vert : ArmSME_IntrStoreOp<"st1h.vert">; +def LLVM_aarch64_sme_st1w_vert : ArmSME_IntrStoreOp<"st1w.vert">; +def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">; +def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">; + +#endif // ARMSME_OPS diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_dialect(ArmSME arm_sme ArmSME) +add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme) + +set(LLVM_TARGET_DEFINITIONS ArmSME.td) +mlir_tablegen(ArmSMEConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRArmSMEConversionsIncGen) diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -23,6 +23,7 @@ #include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include 
"mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -117,6 +118,7 @@
pdl_interp::PDLInterpDialect, quant::QuantizationDialect, spirv::SPIRVDialect, + arm_sme::ArmSMEDialect, arm_sve::ArmSVEDialect, vector::VectorDialect, NVVM::NVVMDialect, diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -16,6 +16,7 @@ #include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" @@ -35,6 +36,7 @@ static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) { registerArmNeonDialectTranslation(registry); registerAMXDialectTranslation(registry); + registerArmSMEDialectTranslation(registry); registerArmSVEDialectTranslation(registry); registerBuiltinDialectTranslation(registry); registerGPUDialectTranslation(registry); diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h @@ -0,0 +1,31 @@ +//=======- ArmSMEToLLVMIRTranslation.h - ArmSME to LLVM IR --*- C++ -*-=======// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides registration calls for ArmSME dialect to LLVM IR translation.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H +#define MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H + +namespace mlir { + +class DialectRegistry; +class MLIRContext; + +/// Register the ArmSME dialect and the translation from it to the LLVM IR in +/// the given registry. +void registerArmSMEDialectTranslation(DialectRegistry &registry); + +/// Register the ArmSME dialect and the translation from it in the registry +/// associated with the given context. +void registerArmSMEDialectTranslation(MLIRContext &context); + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H diff --git a/mlir/lib/Dialect/ArmSME/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/CMakeLists.txt --- a/mlir/lib/Dialect/ArmSME/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSME/CMakeLists.txt @@ -1 +1,2 @@ +add_subdirectory(IR) add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp @@ -0,0 +1,36 @@ +//===- ArmSME.cpp - MLIR ArmSME dialect implementation --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the ArmSME dialect and its operations.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" + +using namespace mlir; +using namespace mlir::arm_sme; + +//===----------------------------------------------------------------------===// +// TableGen-generated definitions (ArmSME.td currently declares no types). +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.cpp.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/ArmSME/IR/ArmSMETypes.cpp.inc" + +void ArmSMEDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc" + >(); +} diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library(MLIRArmSMEDialect + ArmSME.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME + + DEPENDS + MLIRArmSMEIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMDialect + MLIRSideEffectInterfaces +) diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -46,6 +46,7 @@ LINK_LIBS PUBLIC MLIRArmNeonToLLVMIRTranslation + MLIRArmSMEToLLVMIRTranslation MLIRArmSVEToLLVMIRTranslation MLIRAMXToLLVMIRTranslation MLIRBuiltinToLLVMIRTranslation diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp @@ -0,0 +1,56 @@ +//======- ArmSMEToLLVMIRTranslation.cpp - Translate ArmSME to LLVM IR 
-=======// +// +// Part of the LLVM Project, under the Apache
License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the ArmSME dialect and LLVM IR. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h" +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/IR/Operation.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicsAArch64.h" + +using namespace mlir; +using namespace mlir::LLVM; + +namespace { +/// Implementation of the dialect interface that converts operations belonging +/// to the ArmSME dialect to LLVM IR. +class ArmSMEDialectLLVMIRTranslationInterface + : public LLVMTranslationDialectInterface { +public: + using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + + /// Translates the given operation to LLVM IR using the provided IR builder + /// and saving the state in `moduleTranslation`.
+ LogicalResult + convertOperation(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const final { + Operation &opInst = *op; +#include "mlir/Dialect/ArmSME/IR/ArmSMEConversions.inc" + + return failure(); + } +}; +} // namespace + +void mlir::registerArmSMEDialectTranslation(DialectRegistry &registry) { + registry.insert<arm_sme::ArmSMEDialect>(); + registry.addExtension(+[](MLIRContext *ctx, arm_sme::ArmSMEDialect *dialect) { + dialect->addInterfaces<ArmSMEDialectLLVMIRTranslationInterface>(); + }); +} + +void mlir::registerArmSMEDialectTranslation(MLIRContext &context) { + DialectRegistry registry; + registerArmSMEDialectTranslation(registry); + context.appendDialectRegistry(registry); +} diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_translation_library(MLIRArmSMEToLLVMIRTranslation + ArmSMEToLLVMIRTranslation.cpp + + DEPENDS + MLIRArmSMEConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRArmSMEDialect + MLIRLLVMDialect + MLIRSupport + MLIRTargetLLVMIRExport + ) diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(ArmNeon) +add_subdirectory(ArmSME) add_subdirectory(ArmSVE) add_subdirectory(AMX) add_subdirectory(Builtin) diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/LLVMIR/arm-sme.mlir @@ -0,0 +1,225 @@ +// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s + +// CHECK-LABEL: @arm_sme_zero +llvm.func @arm_sme_zero() { + %c0 = llvm.mlir.constant(0 : i32) : i32 + // CHECK: call void @llvm.aarch64.sme.zero(i32 0) + 
"arm_sme.intr.zero"(%c0) : (i32) -> () +
llvm.return +} + +// ----- + +// CHECK-LABEL: @arm_sme_fmopa +llvm.func @arm_sme_fmopa(%nxv2f64 : vector<[2]xf64>, + %nxv4f32 : vector<[4]xf32>, + %nxv8f16 : vector<[8]xf16>, + %nxv8bf16: vector<[8]xbf16>, + %nxv2i1 : vector<[2]xi1>, + %nxv4i1 : vector<[4]xi1>, + %nxv8i1 : vector<[8]xi1>) { + %c0 = llvm.mlir.constant(0 : i32) : i32 + // CHECK: call void @llvm.aarch64.sme.mopa.nxv2f64 + "arm_sme.intr.mopa"(%c0, %nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) : + (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> () + // CHECK: call void @llvm.aarch64.sme.mopa.nxv4f32 + "arm_sme.intr.mopa"(%c0, %nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) : + (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> () + // CHECK: call void @llvm.aarch64.sme.mopa.wide.nxv8f16 + "arm_sme.intr.mopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) : + (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> () + // CHECK: call void @llvm.aarch64.sme.mopa.wide.nxv8bf16 + "arm_sme.intr.mopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) : + (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> () + llvm.return +} + +// ----- + +// CHECK-LABEL: @arm_sme_imopa +llvm.func @arm_sme_imopa(%nxv8i16 : vector<[8]xi16>, + %nxv16i8 : vector<[16]xi8>, + %nxv8i1 : vector<[8]xi1>, + %nxv16i1 : vector<[16]xi1>) { + %c0 = llvm.mlir.constant(0 : i32) : i32 + // CHECK: call void @llvm.aarch64.sme.smopa.wide.nxv8i16 + "arm_sme.intr.smopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) : + (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> () + // CHECK: call void @llvm.aarch64.sme.umopa.wide.nxv8i16 + "arm_sme.intr.umopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) : + (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> () + // CHECK: call void @llvm.aarch64.sme.sumopa.wide.nxv8i16 + "arm_sme.intr.sumopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) : + (i32, vector<[8]xi1>, 
vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> () + // CHECK: call void @llvm.aarch64.sme.usmopa.wide.nxv8i16 + "arm_sme.intr.usmopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) : + (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> () + // CHECK: call void @llvm.aarch64.sme.smopa.wide.nxv16i8 + "arm_sme.intr.smopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) : + (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> () + // CHECK: call void @llvm.aarch64.sme.umopa.wide.nxv16i8 + "arm_sme.intr.umopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) : + (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> () + // CHECK: call void @llvm.aarch64.sme.sumopa.wide.nxv16i8 + "arm_sme.intr.sumopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) : + (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> () + // CHECK: call void @llvm.aarch64.sme.usmopa.wide.nxv16i8 + "arm_sme.intr.usmopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) : + (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> () + llvm.return +} + +// ----- + +// CHECK-LABEL: @arm_sme_fmops +llvm.func @arm_sme_fmops(%nxv2f64 : vector<[2]xf64>, + %nxv4f32 : vector<[4]xf32>, + %nxv8f16 : vector<[8]xf16>, + %nxv8bf16: vector<[8]xbf16>, + %nxv2i1 : vector<[2]xi1>, + %nxv4i1 : vector<[4]xi1>, + %nxv8i1 : vector<[8]xi1>) { + %c0 = llvm.mlir.constant(0 : i32) : i32 + // CHECK: call void @llvm.aarch64.sme.mops.nxv2f64 + "arm_sme.intr.mops"(%c0, %nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) : + (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> () + // CHECK: call void @llvm.aarch64.sme.mops.nxv4f32 + "arm_sme.intr.mops"(%c0, %nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) : + (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> () + // CHECK: call void @llvm.aarch64.sme.mops.wide.nxv8f16 + "arm_sme.intr.mops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8f16, 
%nxv8f16) : + (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> () + // CHECK: call void @llvm.aarch64.sme.mops.wide.nxv8bf16 + "arm_sme.intr.mops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) : + (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> () + llvm.return +} + +// ----- + +// CHECK-LABEL: @arm_sme_imops +llvm.func @arm_sme_imops(%nxv8i16 : vector<[8]xi16>, + %nxv16i8 : vector<[16]xi8>, + %nxv8i1 : vector<[8]xi1>, + %nxv16i1 : vector<[16]xi1>) { + %c0 = llvm.mlir.constant(0 : i32) : i32 + // CHECK: call void @llvm.aarch64.sme.smops.wide.nxv8i16 + "arm_sme.intr.smops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) : + (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> () + // CHECK: call void @llvm.aarch64.sme.umops.wide.nxv8i16 + "arm_sme.intr.umops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) : + (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> () + // CHECK: call void @llvm.aarch64.sme.sumops.wide.nxv8i16 + "arm_sme.intr.sumops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) : + (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> () + // CHECK: call void @llvm.aarch64.sme.usmops.wide.nxv8i16 + "arm_sme.intr.usmops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) : + (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> () + // CHECK: call void @llvm.aarch64.sme.smops.wide.nxv16i8 + "arm_sme.intr.smops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) : + (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> () + // CHECK: call void @llvm.aarch64.sme.umops.wide.nxv16i8 + "arm_sme.intr.umops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) : + (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> () + // CHECK: call void @llvm.aarch64.sme.sumops.wide.nxv16i8 + "arm_sme.intr.sumops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) : + (i32, 
vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> () + // CHECK: call void @llvm.aarch64.sme.usmops.wide.nxv16i8 + "arm_sme.intr.usmops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) : + (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> () + llvm.return +} + +// ----- + +// CHECK-LABEL: @arm_sme_load +llvm.func @arm_sme_load(%nxv1i1 : vector<[1]xi1>, + %nxv2i1 : vector<[2]xi1>, + %nxv4i1 : vector<[4]xi1>, + %nxv8i1 : vector<[8]xi1>, + %nxv16i1 : vector<[16]xi1>, + %p8 : !llvm.ptr, + %p16 : !llvm.ptr, + %p32 : !llvm.ptr, + %p64 : !llvm.ptr, + %p128 : !llvm.ptr) { + %c0 = llvm.mlir.constant(0 : i32) : i32 + // CHECK: call void @llvm.aarch64.sme.ld1q.horiz + "arm_sme.intr.ld1q.horiz"(%nxv1i1, %p128, %c0, %c0) : + (vector<[1]xi1>, !llvm.ptr, i32, i32) -> () + // CHECK: call void @llvm.aarch64.sme.ld1d.horiz + "arm_sme.intr.ld1d.horiz"(%nxv2i1, %p64, %c0, %c0) : + (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () + // CHECK: call void @llvm.aarch64.sme.ld1w.horiz + "arm_sme.intr.ld1w.horiz"(%nxv4i1, %p32, %c0, %c0) : + (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () + // CHECK: call void @llvm.aarch64.sme.ld1h.horiz + "arm_sme.intr.ld1h.horiz"(%nxv8i1, %p16, %c0, %c0) : + (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () + // CHECK: call void @llvm.aarch64.sme.ld1b.horiz + "arm_sme.intr.ld1b.horiz"(%nxv16i1, %p8, %c0, %c0) : + (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () + // CHECK: call void @llvm.aarch64.sme.ld1q.vert + "arm_sme.intr.ld1q.vert"(%nxv1i1, %p128, %c0, %c0) : + (vector<[1]xi1>, !llvm.ptr, i32, i32) -> () + // CHECK: call void @llvm.aarch64.sme.ld1d.vert + "arm_sme.intr.ld1d.vert"(%nxv2i1, %p64, %c0, %c0) : + (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () + // CHECK: call void @llvm.aarch64.sme.ld1w.vert + "arm_sme.intr.ld1w.vert"(%nxv4i1, %p32, %c0, %c0) : + (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () + // CHECK: call void @llvm.aarch64.sme.ld1h.vert + "arm_sme.intr.ld1h.vert"(%nxv8i1, %p16, %c0, %c0) : + 
(vector<[8]xi1>, !llvm.ptr, i32, i32) -> () + // CHECK: call void @llvm.aarch64.sme.ld1b.vert + "arm_sme.intr.ld1b.vert"(%nxv16i1, %p8, %c0, %c0) : + (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () + llvm.return +} + +// ----- + +// CHECK-LABEL: @arm_sme_store +llvm.func @arm_sme_store(%nxv1i1 : vector<[1]xi1>, + %nxv2i1 : vector<[2]xi1>, + %nxv4i1 : vector<[4]xi1>, + %nxv8i1 : vector<[8]xi1>, + %nxv16i1 : vector<[16]xi1>, + %p8 : !llvm.ptr, + %p16 : !llvm.ptr, + %p32 : !llvm.ptr, + %p64 : !llvm.ptr, + %p128 : !llvm.ptr) { + %c0 = llvm.mlir.constant(0 : i32) : i32 + // CHECK: call void @llvm.aarch64.sme.st1q.horiz + "arm_sme.intr.st1q.horiz"(%nxv1i1, %p128, %c0, %c0) : + (vector<[1]xi1>, !llvm.ptr, i32, i32) -> () + // CHECK: call void @llvm.aarch64.sme.st1d.horiz + "arm_sme.intr.st1d.horiz"(%nxv2i1, %p64, %c0, %c0) : + (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () + // CHECK: call void @llvm.aarch64.sme.st1w.horiz + "arm_sme.intr.st1w.horiz"(%nxv4i1, %p32, %c0, %c0) : + (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () + // CHECK: call void @llvm.aarch64.sme.st1h.horiz + "arm_sme.intr.st1h.horiz"(%nxv8i1, %p16, %c0, %c0) : + (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () + // CHECK: call void @llvm.aarch64.sme.st1b.horiz + "arm_sme.intr.st1b.horiz"(%nxv16i1, %p8, %c0, %c0) : + (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () + // CHECK: call void @llvm.aarch64.sme.st1q.vert + "arm_sme.intr.st1q.vert"(%nxv1i1, %p128, %c0, %c0) : + (vector<[1]xi1>, !llvm.ptr, i32, i32) -> () + // CHECK: call void @llvm.aarch64.sme.st1d.vert + "arm_sme.intr.st1d.vert"(%nxv2i1, %p64, %c0, %c0) : + (vector<[2]xi1>, !llvm.ptr, i32, i32) -> () + // CHECK: call void @llvm.aarch64.sme.st1w.vert + "arm_sme.intr.st1w.vert"(%nxv4i1, %p32, %c0, %c0) : + (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () + // CHECK: call void @llvm.aarch64.sme.st1h.vert + "arm_sme.intr.st1h.vert"(%nxv8i1, %p16, %c0, %c0) : + (vector<[8]xi1>, !llvm.ptr, i32, i32) -> () + // CHECK: call void @llvm.aarch64.sme.st1b.vert + 
"arm_sme.intr.st1b.vert"(%nxv16i1, %p8, %c0, %c0) : + (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () + llvm.return +} diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir --- a/mlir/test/mlir-opt/commandline.mlir +++ b/mlir/test/mlir-opt/commandline.mlir @@ -6,6 +6,7 @@ // CHECK-SAME: amx // CHECK-SAME: arith // CHECK-SAME: arm_neon +// CHECK-SAME: arm_sme // CHECK-SAME: arm_sve // CHECK-SAME: async // CHECK-SAME: bufferization