diff --git a/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h b/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h deleted file mode 100644 --- a/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h +++ /dev/null @@ -1,24 +0,0 @@ -//===- ArmSVEToLLVM.h - Conversion Patterns from ArmSVE to LLVM -----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_ -#define MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_ - -namespace mlir { - -class LLVMTypeConverter; -class RewritePatternSet; -using OwningRewritePatternList = RewritePatternSet; - -/// Collect a set of patterns to convert from the ArmSVE dialect to LLVM. -void populateArmSVEToLLVMConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns); - -} // namespace mlir - -#endif // MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_ diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td @@ -14,6 +14,8 @@ #define ARMSVE_OPS include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Dialect/ArmSVE/ArmSVEOpBase.td" //===----------------------------------------------------------------------===// // ArmSVE dialect definition @@ -93,35 +95,6 @@ }]; } -//===----------------------------------------------------------------------===// -// ArmSVE type traits -//===----------------------------------------------------------------------===// - -def IsScalableVectorTypePred : - CPred<"$_self.isa<::mlir::arm_sve::ScalableVectorType>()">; - -class ScalableVectorOf allowedTypes> : - ContainerType, IsScalableVectorTypePred, - "$_self.cast<::mlir::arm_sve::ScalableVectorType>().getElementType()", - "scalable vector">; - -class IsScalableVectorOfLengthPred allowedLengths> : - And<[IsScalableVectorTypePred, - Or().getNumElements() == }] - # allowedlength>)>]>; - -class ScalableVectorOfLength allowedLengths> : Type< - IsScalableVectorOfLengthPred, - " of length " # !interleave(allowedLengths, "/")>; - -class ScalableVectorOfLengthAndType allowedLengths, - list allowedTypes> : Type< - And<[ScalableVectorOf.predicate, - ScalableVectorOfLength.predicate]>, - ScalableVectorOf.summary # - ScalableVectorOfLength.summary>; - //===----------------------------------------------------------------------===// // ArmSVE op definitions //===----------------------------------------------------------------------===// @@ -129,6 +102,26 @@ class ArmSVE_Op traits = []> : Op {} +class ArmSVE_NonSVEIntrUnaryOverloadedOp traits =[]> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[], // defined by result overload + /*list traits=*/traits, + /*int numResults=*/1>; + +class ArmSVE_IntrBinaryOverloadedOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[], // defined by result overload + /*list traits=*/traits, + /*int numResults=*/1>; + def SdotOp : ArmSVE_Op<"sdot", [NoSideEffect, AllTypesMatch<["src1", "src2"]>, @@ -273,4 +266,23 @@ "attr-dict `:` type($res)"; } +def UmmlaIntrOp : + ArmSVE_IntrBinaryOverloadedOp<"ummla">, + Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>; + +def SmmlaIntrOp : + ArmSVE_IntrBinaryOverloadedOp<"smmla">, + Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>; + +def SdotIntrOp : + ArmSVE_IntrBinaryOverloadedOp<"sdot">, + Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>; + +def UdotIntrOp : + ArmSVE_IntrBinaryOverloadedOp<"udot">, + Arguments<(ins LLVM_AnyVector, LLVM_AnyVector, LLVM_AnyVector)>; + +def VectorScaleIntrOp: + ArmSVE_NonSVEIntrUnaryOverloadedOp<"vscale">; + #endif // ARMSVE_OPS diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEOpBase.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEOpBase.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEOpBase.td @@ -0,0 +1,45 @@ +//===-- ArmSVEOpBase.td - Base op definitions for ArmSVE ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the base operation definition file for ArmSVE scalable vector types. +// +//===----------------------------------------------------------------------===// + +#ifndef ARMSVE_OP_BASE +#define ARMSVE_OP_BASE + +//===----------------------------------------------------------------------===// +// ArmSVE scalable vector type constraints +//===----------------------------------------------------------------------===// + +def IsScalableVectorTypePred : + CPred<"$_self.isa<::mlir::arm_sve::ScalableVectorType>()">; + +class ScalableVectorOf allowedTypes> : + ContainerType, IsScalableVectorTypePred, + "$_self.cast<::mlir::arm_sve::ScalableVectorType>().getElementType()", + "scalable vector">; + +class IsScalableVectorOfLengthPred allowedLengths> : + And<[IsScalableVectorTypePred, + Or().getNumElements() == }] + # allowedlength>)>]>; + +class ScalableVectorOfLength allowedLengths> : Type< + IsScalableVectorOfLengthPred, + " of length " # !interleave(allowedLengths, "/")>; + +class ScalableVectorOfLengthAndType allowedLengths, + list allowedTypes> : Type< + And<[ScalableVectorOf.predicate, + ScalableVectorOfLength.predicate]>, + ScalableVectorOf.summary # + ScalableVectorOfLength.summary>; + +#endif // ARMSVE_OP_BASE \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt --- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt @@ -1,2 +1,6 @@ add_mlir_dialect(ArmSVE arm_sve ArmSVE) add_mlir_doc(ArmSVE -gen-dialect-doc ArmSVE Dialects/) + +set(LLVM_TARGET_DEFINITIONS ArmSVE.td) +mlir_tablegen(ArmSVEConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRArmSVEConversionsIncGen) diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms.h @@ -0,0 +1,30 @@ +//===- Transforms.h - ArmSVE Dialect Transformation Entrypoints -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARMSVE_TRANSFORMS_H +#define MLIR_DIALECT_ARMSVE_TRANSFORMS_H + +namespace mlir { + +class LLVMConversionTarget; +class LLVMTypeConverter; +class RewritePatternSet; +using OwningRewritePatternList = RewritePatternSet; + +/// Collect a set of patterns to lower ArmSVE ops to ops that map to LLVM +/// intrinsics. +void populateArmSVELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + +/// Configure the target to support lowering ArmSVE ops to ops that map to LLVM +/// intrinsics. +void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target); + +} // namespace mlir + +#endif // MLIR_DIALECT_ARMSVE_TRANSFORMS_H diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -33,10 +33,4 @@ add_mlir_doc(ROCDLOps -gen-dialect-doc ROCDLDialect Dialects/) set(LLVM_TARGET_DEFINITIONS ROCDLOps.td) mlir_tablegen(ROCDLConversions.inc -gen-llvmir-conversions) -add_public_tablegen_target(MLIRROCDLConversionsIncGen) - -add_mlir_dialect(LLVMArmSVE llvm_arm_sve LLVMArmSVE) -add_mlir_doc(LLVMArmSVE -gen-dialect-doc LLVMArmSve Dialects/) -set(LLVM_TARGET_DEFINITIONS LLVMArmSVE.td) -mlir_tablegen(LLVMArmSVEConversions.inc -gen-llvmir-conversions) -add_public_tablegen_target(MLIRLLVMArmSVEConversionsIncGen) +add_public_tablegen_target(MLIRROCDLConversionsIncGen) \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td deleted file mode 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td +++ /dev/null @@ -1,70 +0,0 @@ -//===-- LLVMArmSVE.td - LLVMARMSVE dialect op definitions --*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the basic operations for the LLVMArmSVE dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVMIR_ARMSVE_OPS -#define LLVMIR_ARMSVE_OPS - -include "mlir/Dialect/LLVMIR/LLVMOpBase.td" - -//===----------------------------------------------------------------------===// -// LLVMArmSVE dialect definition -//===----------------------------------------------------------------------===// - -def LLVMArmSVE_Dialect : Dialect { - let name = "llvm_arm_sve"; - let cppNamespace = "::mlir::LLVM"; -} - -//----------------------------------------------------------------------------// -// MLIR LLVM Arm SVE intrinsics using the MLIR LLVM Dialect type system -//----------------------------------------------------------------------------// - -class LLVMArmSVE_NonSVEIntrUnaryOverloadedOp traits =[]> : - LLVM_IntrOpBase overloadedResults=*/[0], - /*list overloadedOperands=*/[], // defined by result overload - /*list traits=*/traits, - /*int numResults=*/1>; - -class LLVMArmSVE_IntrBinaryOverloadedOp traits = []> : - LLVM_IntrOpBase overloadedResults=*/[0], - /*list overloadedOperands=*/[], // defined by result overload - /*list traits=*/traits, - /*int numResults=*/1>; - -def LLVM_aarch64_arm_sve_ummla : - LLVMArmSVE_IntrBinaryOverloadedOp<"ummla">, - Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; - -def LLVM_aarch64_arm_sve_smmla : - LLVMArmSVE_IntrBinaryOverloadedOp<"smmla">, - Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; - -def LLVM_aarch64_arm_sve_sdot : - LLVMArmSVE_IntrBinaryOverloadedOp<"sdot">, - Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; - -def LLVM_aarch64_arm_sve_udot : - LLVMArmSVE_IntrBinaryOverloadedOp<"udot">, - Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; - -def LLVM_vector_scale : - LLVMArmSVE_NonSVEIntrUnaryOverloadedOp<"vscale">; - -#endif // ARMSVE_OPS diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h +++ /dev/null @@ -1,24 +0,0 @@ -//===- LLVMSVEDialect.h - MLIR Dialect for LLVMSVE --------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file declares the Target dialect for LLVMArmSVE in MLIR. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_ -#define MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_ - -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" - -#define GET_OP_CLASSES -#include "mlir/Dialect/LLVMIR/LLVMArmSVE.h.inc" - -#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h.inc" - -#endif // MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_ diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -23,7 +23,6 @@ #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/GPU/GPUDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -60,7 +59,6 @@ DLTIDialect, gpu::GPUDialect, LLVM::LLVMDialect, - LLVM::LLVMArmSVEDialect, linalg::LinalgDialect, math::MathDialect, memref::MemRefDialect, diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -17,7 +17,7 @@ #include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/LLVMArmSVE/LLVMArmSVEToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h" @@ -32,7 +32,7 @@ registerArmNeonDialectTranslation(registry); registerAMXDialectTranslation(registry); registerAVX512DialectTranslation(registry); - registerLLVMArmSVEDialectTranslation(registry); + registerArmSVEDialectTranslation(registry); registerLLVMDialectTranslation(registry); registerNVVMDialectTranslation(registry); registerOpenMPDialectTranslation(registry); diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h @@ -0,0 +1,31 @@ +//=======- ArmSVEToLLVMIRTranslation.h - ArmSVE to LLVM IR --*- C++ -*-=======// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides registration calls for ArmSVE dialect to LLVM IR translation. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_DIALECT_ARMSVE_ARMSVETOLLVMIRTRANSLATION_H +#define MLIR_TARGET_LLVMIR_DIALECT_ARMSVE_ARMSVETOLLVMIRTRANSLATION_H + +namespace mlir { + +class DialectRegistry; +class MLIRContext; + +/// Register the ArmSVE dialect and the translation from it to the LLVM IR in +/// the given registry; +void registerArmSVEDialectTranslation(DialectRegistry ®istry); + +/// Register the ArmSVE dialect and the translation from it in the registry +/// associated with the given context. +void registerArmSVEDialectTranslation(MLIRContext &context); + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_DIALECT_ARMSVE_ARMSVETOLLVMIRTRANSLATION_H diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMArmSVE/LLVMArmSVEToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMArmSVE/LLVMArmSVEToLLVMIRTranslation.h deleted file mode 100644 --- a/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMArmSVE/LLVMArmSVEToLLVMIRTranslation.h +++ /dev/null @@ -1,32 +0,0 @@ -//===- LLVMArmSVEToLLVMIRTranslation.h - LLVMArmSVE to LLVM IR --*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This provides registration calls for LLVMArmSVE dialect to LLVM IR -// translation. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TARGET_LLVMIR_DIALECT_LLVMARMSVE_LLVMARMSVETOLLVMIRTRANSLATION_H -#define MLIR_TARGET_LLVMIR_DIALECT_LLVMARMSVE_LLVMARMSVETOLLVMIRTRANSLATION_H - -namespace mlir { - -class DialectRegistry; -class MLIRContext; - -/// Register the LLVMArmSVE dialect and the translation from it to the LLVM IR -/// in the given registry; -void registerLLVMArmSVEDialectTranslation(DialectRegistry ®istry); - -/// Register the LLVMArmSVE dialect and the translation from it in the registry -/// associated with the given context. -void registerLLVMArmSVEDialectTranslation(MLIRContext &context); - -} // namespace mlir - -#endif // MLIR_TARGET_LLVMIR_DIALECT_LLVMARMSVE_LLVMARMSVETOLLVMIRTRANSLATION_H diff --git a/mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt deleted file mode 100644 --- a/mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -add_mlir_conversion_library(MLIRArmSVEToLLVM - ArmSVEToLLVM.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmSVEToLLVM - - DEPENDS - MLIRConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRArmSVE - MLIRLLVMArmSVE - MLIRLLVMIR - MLIRStandardToLLVM - MLIRTransforms - ) diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -22,7 +22,6 @@ add_subdirectory(TosaToLinalg) add_subdirectory(TosaToSCF) add_subdirectory(TosaToStandard) -add_subdirectory(ArmSVEToLLVM) add_subdirectory(VectorToROCDL) add_subdirectory(VectorToLLVM) add_subdirectory(VectorToSCF) diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h --- a/mlir/lib/Conversion/PassDetail.h +++ b/mlir/lib/Conversion/PassDetail.h @@ -29,7 +29,6 @@ } // end namespace gpu namespace LLVM { -class LLVMArmSVEDialect; class LLVMDialect; } // end namespace LLVM diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -14,13 +14,12 @@ LINK_LIBS PUBLIC MLIRArmNeon + MLIRArmSVE + MLIRArmSVETransforms MLIRAMX MLIRAMXTransforms MLIRAVX512 MLIRAVX512Transforms - MLIRArmSVE - MLIRArmSVEToLLVM - MLIRLLVMArmSVE MLIRLLVMIR MLIRMemRef MLIRStandardToLLVM diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -10,7 +10,6 @@ #include "../PassDetail.h" -#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/AMX/AMXDialect.h" @@ -19,7 +18,7 @@ #include "mlir/Dialect/AVX512/Transforms.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -47,7 +46,7 @@ if (enableArmNeon) registry.insert(); if (enableArmSVE) - registry.insert(); + registry.insert(); if (enableAMX) registry.insert(); if (enableAVX512) @@ -90,26 +89,8 @@ target.addLegalDialect(); } if (enableArmSVE) { - target.addLegalDialect(); - target.addIllegalDialect(); - auto hasScalableVectorType = [](TypeRange types) { - for (Type type : types) - if (type.isa()) - return true; - return false; - }; - // Remove any ArmSVE-specific types from function signatures and results. - populateFuncOpTypeConversionPattern(patterns, converter); - target.addDynamicallyLegalOp([hasScalableVectorType](FuncOp op) { - return !hasScalableVectorType(op.getType().getInputs()) && - !hasScalableVectorType(op.getType().getResults()); - }); - target.addDynamicallyLegalOp( - [hasScalableVectorType](Operation *op) { - return !hasScalableVectorType(op->getOperandTypes()) && - !hasScalableVectorType(op->getResultTypes()); - }); - populateArmSVEToLLVMConversionPatterns(converter, patterns); + configureArmSVELegalizeForExportTarget(target); + populateArmSVELegalizeForLLVMExportPatterns(converter, patterns); } if (enableAMX) { configureAMXLegalizeForExportTarget(target); diff --git a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt --- a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt @@ -1,13 +1,2 @@ -add_mlir_dialect_library(MLIRArmSVE - IR/ArmSVEDialect.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE - - DEPENDS - MLIRArmSVEIncGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRSideEffectInterfaces - ) +add_subdirectory(IR) +add_subdirectory(Transforms) \ No newline at end of file diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp --- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp +++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" diff --git a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt copy from mlir/lib/Dialect/ArmSVE/CMakeLists.txt copy to mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt --- a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSVE/IR/CMakeLists.txt @@ -1,5 +1,5 @@ add_mlir_dialect_library(MLIRArmSVE - IR/ArmSVEDialect.cpp + ArmSVEDialect.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE @@ -9,5 +9,6 @@ LINK_LIBS PUBLIC MLIRIR + MLIRLLVMIR MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_dialect_library(MLIRArmSVETransforms + LegalizeForLLVMExport.cpp + + DEPENDS + MLIRArmSVEConversionsIncGen + + LINK_LIBS PUBLIC + MLIRArmSVE + MLIRIR + MLIRLLVMIR + MLIRStandardToLLVM + ) diff --git a/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp rename from mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp rename to mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -1,4 +1,4 @@ -//===- ArmSVEToLLVM.cpp - Convert ArmSVE to the LLVM dialect --------------===// +//===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,34 +6,16 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" +#include "mlir/Dialect/ArmSVE/Transforms.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" using namespace mlir; using namespace mlir::arm_sve; -using namespace mlir::vector; - -using SdotOpLowering = - OneToOneConvertToLLVMPattern; - -using SmmlaOpLowering = - OneToOneConvertToLLVMPattern; - -using UdotOpLowering = - OneToOneConvertToLLVMPattern; - -using UmmlaOpLowering = - OneToOneConvertToLLVMPattern; - -using VectorScaleOpLowering = - OneToOneConvertToLLVMPattern; // Extract an LLVM IR type from the LLVM IR dialect type. static Type unwrap(Type type) { @@ -95,9 +77,19 @@ .getResult(0); } +using SdotOpLowering = OneToOneConvertToLLVMPattern; +using SmmlaOpLowering = OneToOneConvertToLLVMPattern; +using UdotOpLowering = OneToOneConvertToLLVMPattern; +using UmmlaOpLowering = OneToOneConvertToLLVMPattern; +using VectorScaleOpLowering = + OneToOneConvertToLLVMPattern; + /// Populate the given list with patterns that convert from ArmSVE to LLVM. -void mlir::populateArmSVEToLLVMConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns) { +void mlir::populateArmSVELegalizeForLLVMExportPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + // Populate conversion patterns + // Remove any ArmSVE-specific types from function signatures and results. + populateFuncOpTypeConversionPattern(patterns, converter); converter.addConversion([&converter](ScalableVectorType svType) { return convertScalableVectorTypeToLLVM(svType, converter); }); @@ -105,13 +97,42 @@ // clang-format off patterns.add, - ForwardOperands, - ForwardOperands>(converter, - &converter.getContext()); + ForwardOperands, + ForwardOperands>(converter, + &converter.getContext()); patterns.add(converter); + SmmlaOpLowering, + UdotOpLowering, + UmmlaOpLowering, + VectorScaleOpLowering>(converter); // clang-format on } + +void mlir::configureArmSVELegalizeForExportTarget( + LLVMConversionTarget &target) { + target.addLegalOp(); + target.addIllegalOp(); + target.addLegalOp(); + target.addIllegalOp(); + target.addLegalOp(); + target.addIllegalOp(); + target.addLegalOp(); + target.addIllegalOp(); + target.addLegalOp(); + target.addIllegalOp(); + auto hasScalableVectorType = [](TypeRange types) { + for (Type type : types) + if (type.isa()) + return true; + return false; + }; + target.addDynamicallyLegalOp([hasScalableVectorType](FuncOp op) { + return !hasScalableVectorType(op.getType().getInputs()) && + !hasScalableVectorType(op.getType().getResults()); + }); + target.addDynamicallyLegalOp( + [hasScalableVectorType](Operation *op) { + return !hasScalableVectorType(op->getOperandTypes()) && + !hasScalableVectorType(op->getResultTypes()); + }); +} diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -29,27 +29,6 @@ MLIRSupport ) -add_mlir_dialect_library(MLIRLLVMArmSVE - IR/LLVMArmSVEDialect.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR - - DEPENDS - MLIRLLVMArmSVEIncGen - MLIRLLVMArmSVEConversionsIncGen - intrinsics_gen - - LINK_COMPONENTS - AsmParser - Core - - LINK_LIBS PUBLIC - MLIRIR - MLIRLLVMIR - MLIRSideEffectInterfaces - ) - add_mlir_dialect_library(MLIRNVVMIR IR/NVVMDialect.cpp diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp +++ /dev/null @@ -1,31 +0,0 @@ -//===- LLVMArmSVEDialect.cpp - MLIR LLVMSVE ops implementation ------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements the LLVMArmSVE dialect and its operations. -// -//===----------------------------------------------------------------------===// - -#include "llvm/IR/IntrinsicsAArch64.h" - -#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/TypeUtilities.h" - -using namespace mlir; - -void LLVM::LLVMArmSVEDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc" - >(); -} - -#define GET_OP_CLASSES -#include "mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc" diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -37,9 +37,9 @@ LINK_LIBS PUBLIC MLIRArmNeonToLLVMIRTranslation + MLIRArmSVEToLLVMIRTranslation MLIRAMXToLLVMIRTranslation MLIRAVX512ToLLVMIRTranslation - MLIRLLVMArmSVEToLLVMIRTranslation MLIRLLVMToLLVMIRTranslation MLIRNVVMToLLVMIRTranslation MLIROpenMPToLLVMIRTranslation diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMArmSVE/LLVMArmSVEToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp rename from mlir/lib/Target/LLVMIR/Dialect/LLVMArmSVE/LLVMArmSVEToLLVMIRTranslation.cpp rename to mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMArmSVE/LLVMArmSVEToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.cpp @@ -1,4 +1,4 @@ -//===- LLVMArmSVEToLLVMIRTranslation.cpp - Translate LLVMArmSVE to LLVM IR-===// +//======- ArmSVEToLLVMIRTranslation.cpp - Translate ArmSVE to LLVM IR -=======// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,13 +6,12 @@ // //===----------------------------------------------------------------------===// // -// This file implements a translation between the MLIR LLVMArmSVE dialect and -// LLVM IR. +// This file implements a translation between the ArmSVE dialect and LLVM IR. // //===----------------------------------------------------------------------===// -#include "mlir/Target/LLVMIR/Dialect/LLVMArmSVE/LLVMArmSVEToLLVMIRTranslation.h" -#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" +#include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h" +#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/IR/Operation.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" @@ -24,8 +23,8 @@ namespace { /// Implementation of the dialect interface that converts operations belonging -/// to the LLVMArmSVE dialect to LLVM IR. -class LLVMArmSVEDialectLLVMIRTranslationInterface +/// to the ArmSVE dialect to LLVM IR. +class ArmSVEDialectLLVMIRTranslationInterface : public LLVMTranslationDialectInterface { public: using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; @@ -36,21 +35,21 @@ convertOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) const final { Operation &opInst = *op; -#include "mlir/Dialect/LLVMIR/LLVMArmSVEConversions.inc" +#include "mlir/Dialect/ArmSVE/ArmSVEConversions.inc" return failure(); } }; } // end namespace -void mlir::registerLLVMArmSVEDialectTranslation(DialectRegistry ®istry) { - registry.insert(); - registry.addDialectInterface(); +void mlir::registerArmSVEDialectTranslation(DialectRegistry ®istry) { + registry.insert(); + registry.addDialectInterface(); } -void mlir::registerLLVMArmSVEDialectTranslation(MLIRContext &context) { +void mlir::registerArmSVEDialectTranslation(MLIRContext &context) { DialectRegistry registry; - registerLLVMArmSVEDialectTranslation(registry); + registerArmSVEDialectTranslation(registry); context.appendDialectRegistry(registry); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/ArmSVE/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_translation_library(MLIRArmSVEToLLVMIRTranslation + ArmSVEToLLVMIRTranslation.cpp + + DEPENDS + MLIRArmSVEConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRArmSVE + MLIRLLVMIR + MLIRSupport + MLIRTargetLLVMIRExport + ) diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -1,7 +1,7 @@ add_subdirectory(ArmNeon) +add_subdirectory(ArmSVE) add_subdirectory(AMX) add_subdirectory(AVX512) -add_subdirectory(LLVMArmSVE) add_subdirectory(LLVMIR) add_subdirectory(NVVM) add_subdirectory(OpenMP) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMArmSVE/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/LLVMArmSVE/CMakeLists.txt deleted file mode 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMArmSVE/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -add_mlir_translation_library(MLIRLLVMArmSVEToLLVMIRTranslation - LLVMArmSVEToLLVMIRTranslation.cpp - - DEPENDS - MLIRLLVMArmSVEConversionsIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRIR - MLIRLLVMArmSVE - MLIRLLVMIR - MLIRSupport - MLIRTargetLLVMIRExport - ) diff --git a/mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir rename from mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir rename to mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir --- a/mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir @@ -4,7 +4,7 @@ %b: !arm_sve.vector<16xi8>, %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: llvm_arm_sve.sdot + // CHECK: arm_sve.intr.sdot %0 = arm_sve.sdot %c, %a, %b : !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> return %0 : !arm_sve.vector<4xi32> @@ -14,7 +14,7 @@ %b: !arm_sve.vector<16xi8>, %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: llvm_arm_sve.smmla + // CHECK: arm_sve.intr.smmla %0 = arm_sve.smmla %c, %a, %b : !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> return %0 : !arm_sve.vector<4xi32> @@ -24,7 +24,7 @@ %b: !arm_sve.vector<16xi8>, %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: llvm_arm_sve.udot + // CHECK: arm_sve.intr.udot %0 = arm_sve.udot %c, %a, %b : !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> return %0 : !arm_sve.vector<4xi32> @@ -34,14 +34,14 @@ %b: !arm_sve.vector<16xi8>, %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> { - // CHECK: llvm_arm_sve.ummla + // CHECK: arm_sve.intr.ummla %0 = arm_sve.ummla %c, %a, %b : !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32> return %0 : !arm_sve.vector<4xi32> } func @get_vector_scale() -> index { - // CHECK: llvm_arm_sve.vscale + // CHECK: arm_sve.vscale %0 = arm_sve.vector_scale : index return %0 : index } diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir --- a/mlir/test/Target/LLVMIR/arm-sve.mlir +++ b/mlir/test/Target/LLVMIR/arm-sve.mlir @@ -6,7 +6,7 @@ %arg2: !llvm.vec) -> !llvm.vec { // CHECK: call @llvm.aarch64.sve.sdot.nxv4i32(, !llvm.vec, !llvm.vec) -> !llvm.vec llvm.return %0 : !llvm.vec @@ -18,7 +18,7 @@ %arg2: !llvm.vec) -> !llvm.vec { // CHECK: call @llvm.aarch64.sve.smmla.nxv4i32(, !llvm.vec, !llvm.vec) -> !llvm.vec llvm.return %0 : !llvm.vec @@ -30,7 +30,7 @@ %arg2: !llvm.vec) -> !llvm.vec { // CHECK: call @llvm.aarch64.sve.udot.nxv4i32(, !llvm.vec, !llvm.vec) -> !llvm.vec llvm.return %0 : !llvm.vec @@ -42,7 +42,7 @@ %arg2: !llvm.vec) -> !llvm.vec { // CHECK: call @llvm.aarch64.sve.ummla.nxv4i32(, !llvm.vec, !llvm.vec) -> !llvm.vec llvm.return %0 : !llvm.vec @@ -51,6 +51,6 @@ // CHECK-LABEL: define i64 @get_vector_scale() llvm.func @get_vector_scale() -> i64 { // CHECK: call i64 @llvm.vscale.i64() - %0 = "llvm_arm_sve.vscale"() : () -> i64 + %0 = "arm_sve.vscale"() : () -> i64 llvm.return %0 : i64 } diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir --- a/mlir/test/mlir-opt/commandline.mlir +++ b/mlir/test/mlir-opt/commandline.mlir @@ -12,7 +12,6 @@ // CHECK-NEXT: gpu // CHECK-NEXT: linalg // CHECK-NEXT: llvm -// CHECK-NEXT: llvm_arm_sve // CHECK-NEXT: math // CHECK-NEXT: memref // CHECK-NEXT: nvvm