diff --git a/mlir/include/mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h b/mlir/include/mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h @@ -0,0 +1,23 @@ +//===- ArmNeonToLLVM.h - Conversion Patterns from ArmNeon to LLVM ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARMNEONTOLLVM_ARMNEONTOLLVM_H_ +#define MLIR_CONVERSION_ARMNEONTOLLVM_ARMNEONTOLLVM_H_ + +namespace mlir { + +class LLVMTypeConverter; +class OwningRewritePatternList; + +/// Collect a set of patterns to convert from the ArmNeon dialect to LLVM. +void populateArmNeonToLLVMConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns); + +} // namespace mlir + +#endif // MLIR_CONVERSION_ARMNEONTOLLVM_ARMNEONTOLLVM_H_ diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -396,12 +396,13 @@ operations. The lowering pass provides several options to control the kinds of optimizations that are allowed. It also provides options that enable the use of one or more architectural-specific dialects - (AVX512, Neon, SVE, etc.) in combination with the architectural-neutral + (AVX512, ArmNeon, SVE, etc.) in combination with the architectural-neutral vector dialect lowering. }]; let constructor = "mlir::createConvertVectorToLLVMPass()"; - let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect"]; + // Override explicitly in C++ to allow conditional dialect dependence.
+ // let dependentDialects; let options = [ Option<"reassociateFPReductions", "reassociate-fp-reductions", "bool", /*default=*/"false", @@ -413,6 +414,10 @@ Option<"enableAVX512", "enable-avx512", "bool", /*default=*/"false", "Enables the use of AVX512 dialect while lowering the vector " + "dialect.">, + Option<"enableArmNeon", "enable-arm-neon", + "bool", /*default=*/"false", + "Enables the use of ArmNeon dialect while lowering the vector " "dialect."> ]; } diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -23,7 +23,7 @@ struct LowerVectorToLLVMOptions { LowerVectorToLLVMOptions() : reassociateFPReductions(false), enableIndexOptimizations(true), - enableAVX512(false) {} + enableArmNeon(false), enableAVX512(false) {} LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) { reassociateFPReductions = b; @@ -37,9 +37,14 @@ enableAVX512 = b; return *this; } + LowerVectorToLLVMOptions &setEnableArmNeon(bool b) { + enableArmNeon = b; + return *this; + } bool reassociateFPReductions; bool enableIndexOptimizations; + bool enableArmNeon; bool enableAVX512; }; diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td @@ -0,0 +1,60 @@ +//===-- ArmNeon.td - ArmNeon dialect op definitions --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic operations for the ArmNeon dialect.
+// +//===----------------------------------------------------------------------===// + +#ifndef ARMNEON_OPS +#define ARMNEON_OPS + +include "mlir/Interfaces/SideEffectInterfaces.td" + +//===----------------------------------------------------------------------===// +// ArmNeon dialect definition +//===----------------------------------------------------------------------===// + +def ArmNeon_Dialect : Dialect { + let name = "arm_neon"; + let cppNamespace = "::mlir::arm_neon"; +} + +//===----------------------------------------------------------------------===// +// ArmNeon op definitions +//===----------------------------------------------------------------------===// + +class ArmNeon_Op traits = []> : + Op {} + +def SMullOp : ArmNeon_Op<"smull", [NoSideEffect, + AllTypesMatch<["a", "b"]>, + TypesMatchWith< + "res has same vector shape and element bitwidth scaled by 2 as a", + "a", "res", "$_self.cast().scaleElementBitwidth(2)">]> { + let summary = "smull op"; + let description = [{ + Signed Multiply Long (vector). This instruction multiplies corresponding + signed integer values in the lower or upper half of the vectors of the two + source SIMD&FP registers, places the results in a vector, and writes the + vector to the destination SIMD&FP register.
+ + Source: + https://developer.arm.com/architectures/instruction-sets/simd-isas/neon/intrinsics + }]; + // Supports either: + // (vector<8xi8>, vector<8xi8>) -> (vector<8xi16>) + // (vector<4xi16>, vector<4xi16>) -> (vector<4xi32>) + // (vector<2xi32>, vector<2xi32>) -> (vector<2xi64>) + let arguments = (ins VectorOfLengthAndType<[8, 4, 2], [I8, I16, I32]>:$a, + VectorOfLengthAndType<[8, 4, 2], [I8, I16, I32]>:$b); + let results = (outs VectorOfLengthAndType<[8, 4, 2], [I16, I32, I64]>:$res); + let assemblyFormat = + "$a `,` $b attr-dict `:` type($a) `to` type($res)"; +} + +#endif // ARMNEON_OPS diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h b/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h @@ -0,0 +1,25 @@ +//===- ArmNeonDialect.h - MLIR Dialect for ArmNeon --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Target dialect for ArmNeon in MLIR.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_ +#define MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmNeon/ArmNeon.h.inc" + +#endif // MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_ diff --git a/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmNeon/CMakeLists.txt @@ -0,0 +1,2 @@ +add_mlir_dialect(ArmNeon arm_neon) +add_mlir_doc(ArmNeon -gen-dialect-doc ArmNeon Dialects/) diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(Affine) +add_subdirectory(ArmNeon) add_subdirectory(Async) add_subdirectory(AVX512) add_subdirectory(GPU) add_subdirectory(Linalg) diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -8,25 +8,32 @@ mlir_tablegen(LLVMOpsEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRLLVMOpsIncGen) -add_mlir_dialect(NVVMOps nvvm) -add_mlir_doc(NVVMOps -gen-dialect-doc NVVMDialect Dialects/) -add_mlir_dialect(ROCDLOps rocdl) -add_mlir_doc(ROCDLOps -gen-dialect-doc ROCDLDialect Dialects/) - set(LLVM_TARGET_DEFINITIONS LLVMOps.td) mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions) mlir_tablegen(LLVMConversionEnumsToLLVM.inc -gen-enum-to-llvmir-conversions) mlir_tablegen(LLVMConversionEnumsFromLLVM.inc -gen-enum-from-llvmir-conversions) add_public_tablegen_target(MLIRLLVMConversionsIncGen) +
+add_mlir_dialect(NVVMOps nvvm) +add_mlir_doc(NVVMOps -gen-dialect-doc NVVMDialect Dialects/) set(LLVM_TARGET_DEFINITIONS NVVMOps.td) mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions) add_public_tablegen_target(MLIRNVVMConversionsIncGen) + +add_mlir_dialect(ROCDLOps rocdl) +add_mlir_doc(ROCDLOps -gen-dialect-doc ROCDLDialect Dialects/) set(LLVM_TARGET_DEFINITIONS ROCDLOps.td) mlir_tablegen(ROCDLConversions.inc -gen-llvmir-conversions) add_public_tablegen_target(MLIRROCDLConversionsIncGen) add_mlir_dialect(LLVMAVX512 llvm_avx512 LLVMAVX512) - +add_mlir_doc(LLVMAVX512 -gen-dialect-doc LLVMAVX512 Dialects/) set(LLVM_TARGET_DEFINITIONS LLVMAVX512.td) mlir_tablegen(LLVMAVX512Conversions.inc -gen-llvmir-conversions) add_public_tablegen_target(MLIRLLVMAVX512ConversionsIncGen) + +add_mlir_dialect(LLVMArmNeon llvm_arm_neon LLVMArmNeon) +add_mlir_doc(LLVMArmNeon -gen-dialect-doc LLVMArmNeon Dialects/) +set(LLVM_TARGET_DEFINITIONS LLVMArmNeon.td) +mlir_tablegen(LLVMArmNeonConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRLLVMArmNeonConversionsIncGen) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeon.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeon.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeon.td @@ -0,0 +1,43 @@ +//===-- LLVMArmNeon.td - LLVMArmNeon op definitions --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic operations for the LLVMArmNeon dialect.
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVMIR_ARMNEON_OPS +#define LLVMIR_ARMNEON_OPS + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +//===----------------------------------------------------------------------===// +// LLVMArmNeon dialect definition +//===----------------------------------------------------------------------===// + +def LLVMArmNeon_Dialect : Dialect { + let name = "llvm_arm_neon"; + let cppNamespace = "::mlir::LLVM"; +} + +//----------------------------------------------------------------------------// +// MLIR LLVMArmNeon intrinsics using the MLIR LLVM Dialect type system +//----------------------------------------------------------------------------// + +class LLVMArmNeon_IntrBinaryOverloadedOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[], // defined by result overload + /*list traits=*/traits, + /*int numResults=*/1>; + +def LLVM_aarch64_arm_neon_smull : + LLVMArmNeon_IntrBinaryOverloadedOp<"smull">, Arguments<(ins LLVM_Type, LLVM_Type)>; + +#endif // LLVMIR_ARMNEON_OPS diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h @@ -0,0 +1,24 @@ +//===- LLVMArmNeonDialect.h - MLIR Dialect for LLVMArmNeon ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Target dialect for LLVMArmNeon in MLIR.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_LLVMARMNEONDIALECT_H_ +#define MLIR_DIALECT_LLVMIR_LLVMARMNEONDIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMArmNeon.h.inc" + +#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h.inc" + +#endif // MLIR_DIALECT_LLVMIR_LLVMARMNEONDIALECT_H_ diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -150,6 +150,11 @@ /// Return true if this is an unsigned integer type. bool isUnsigned() const { return getSignedness() == Unsigned; } + /// Get or create a new IntegerType with the same signedness as `this` and a + /// bitwidth scaled by `scale`. + /// Return null if the scaled element type cannot be represented. + IntegerType scaleElementBitwidth(unsigned scale); + /// Integer representation maximal bitwidth. static constexpr unsigned kMaxWidth = 4096; }; @@ -174,6 +179,10 @@ /// Return the bitwidth of this float type. unsigned getWidth(); + /// Get or create a new FloatType with bitwidth scaled by `scale`. + /// Return null if the scaled element type cannot be represented. + FloatType scaleElementBitwidth(unsigned scale); + /// Return the floating semantics of this float type. const llvm::fltSemantics &getFloatSemantics(); }; @@ -433,6 +442,11 @@ } ArrayRef getShape() const; + + /// Get or create a new VectorType with the same shape as `this` and an + /// element type of bitwidth scaled by `scale`. + /// Return null if the scaled element type cannot be represented. 
+ VectorType scaleElementBitwidth(unsigned scale); }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -16,9 +16,11 @@ #include "mlir/Dialect/AVX512/AVX512Dialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" +#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -44,11 +46,13 @@ // clang-format off registry.insert -static Type getSrcVectorElementType(OpTy op) { - return op.src().getType().template cast().getElementType(); -} - -// TODO: Code is currently copy-pasted and adapted from the code -// 1-1 LLVM conversion. It would better if it were properly exposed in core and -// reusable. -/// Basic lowering implementation for one-to-one rewriting from AVX512 Ops to -/// LLVM Dialect Ops. Convert the type of the result to an LLVM type, pass -/// operands as is, preserve attributes. -template -static LogicalResult -matchAndRewriteOneToOne(LLVMTypeConverter &typeConverter, Operation *op, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - unsigned numResults = op->getNumResults(); - - Type packedType; - if (numResults != 0) { - packedType = typeConverter.packFunctionResults(op->getResultTypes()); - if (!packedType) - return failure(); - } - - auto newOp = rewriter.create(op->getLoc(), packedType, operands, - op->getAttrs()); - - // If the operation produced 0 or 1 result, return them immediately. 
- if (numResults == 0) - return rewriter.eraseOp(op), success(); - if (numResults == 1) - return rewriter.replaceOp(op, newOp->getResult(0)), success(); - - // Otherwise, it had been converted to an operation producing a structure. - // Extract individual results from the structure and return them as list. - SmallVector results; - results.reserve(numResults); - for (unsigned i = 0; i < numResults; ++i) { - auto type = typeConverter.convertType(op->getResult(i).getType()); - results.push_back(rewriter.create( - op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i))); - } - rewriter.replaceOp(op, results); - return success(); +static Type getSrcVectorElementType(Operation *op) { + return cast(op) + .src() + .getType() + .template cast() + .getElementType(); } namespace { -// TODO: Patterns are too verbose due to the fact that we have 1 op (e.g. -// MaskRndScaleOp) and different possible target ops. It would be better to take -// a Functor so that all these conversions become 1-liners. -struct MaskRndScaleOpPS512Conversion - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - LogicalResult - matchAndRewrite(MaskRndScaleOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - if (!getSrcVectorElementType(op).isF32()) - return failure(); - return matchAndRewriteOneToOne( - *getTypeConverter(), op, operands, rewriter); - } -}; - -struct MaskRndScaleOpPD512Conversion - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +// TODO: turn these into simpler declarative templated patterns when we've had +// enough. 
+struct MaskRndScaleOp512Conversion : public ConvertToLLVMPattern { + explicit MaskRndScaleOp512Conversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context, + typeConverter) {} - LogicalResult - matchAndRewrite(MaskRndScaleOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { + // Must take Operation *: ConvertToLLVMPattern is not op-typed. + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { - if (!getSrcVectorElementType(op).isF64()) - return failure(); - return matchAndRewriteOneToOne( - *getTypeConverter(), op, operands, rewriter); - } -}; - -struct ScaleFOpPS512Conversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(MaskScaleFOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - if (!getSrcVectorElementType(op).isF32()) - return failure(); - return matchAndRewriteOneToOne( - *getTypeConverter(), op, operands, rewriter); + Type elementType = getSrcVectorElementType(op); + if (elementType.isF32()) + return LLVM::detail::oneToOneRewrite( + op, LLVM::x86_avx512_mask_rndscale_ps_512::getOperationName(), + operands, *getTypeConverter(), rewriter); + if (elementType.isF64()) + return LLVM::detail::oneToOneRewrite( + op, LLVM::x86_avx512_mask_rndscale_pd_512::getOperationName(), + operands, *getTypeConverter(), rewriter); + return failure(); } }; -struct ScaleFOpPD512Conversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +struct ScaleFOp512Conversion : public ConvertToLLVMPattern { + explicit ScaleFOp512Conversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context, + typeConverter) {} - LogicalResult - matchAndRewrite(MaskScaleFOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { + // Must take Operation *: ConvertToLLVMPattern is not op-typed. + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { - if (!getSrcVectorElementType(op).isF64()) - return failure(); - return matchAndRewriteOneToOne( - *getTypeConverter(), op, operands, rewriter); + Type
elementType = getSrcVectorElementType(op); + if (elementType.isF32()) + return LLVM::detail::oneToOneRewrite( + op, LLVM::x86_avx512_mask_scalef_ps_512::getOperationName(), operands, + *getTypeConverter(), rewriter); + if (elementType.isF64()) + return LLVM::detail::oneToOneRewrite( + op, LLVM::x86_avx512_mask_scalef_pd_512::getOperationName(), operands, + *getTypeConverter(), rewriter); + return failure(); } }; } // namespace @@ -135,9 +85,8 @@ void mlir::populateAVX512ToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + MLIRContext *ctx = &converter.getContext(); // clang-format off - patterns.insert(converter); + patterns.insert(ctx, converter); // clang-format on } diff --git a/mlir/lib/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.cpp b/mlir/lib/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.cpp @@ -0,0 +1,29 @@ +//===- ArmNeonToLLVM.cpp - ArmNeon to the LLVM dialect --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h" + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" + +using namespace mlir; +using namespace mlir::arm_neon; + +using SMullOpLowering = + OneToOneConvertToLLVMPattern; + +/// Populate the given list with patterns that convert from ArmNeon to LLVM.
+void mlir::populateArmNeonToLLVMConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + patterns.insert(converter); +} diff --git a/mlir/lib/Conversion/ArmNeonToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArmNeonToLLVM/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArmNeonToLLVM/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(MLIRArmNeonToLLVM + ArmNeonToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmNeonToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArmNeon + MLIRLLVMArmNeon + MLIRLLVMIR + MLIRStandardToLLVM + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(AffineToStandard) +add_subdirectory(ArmNeonToLLVM) add_subdirectory(AsyncToLLVM) add_subdirectory(AVX512ToLLVM) add_subdirectory(GPUCommon) diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h --- a/mlir/lib/Conversion/PassDetail.h +++ b/mlir/lib/Conversion/PassDetail.h @@ -25,8 +25,9 @@ } // end namespace gpu namespace LLVM { -class LLVMDialect; +class LLVMArmNeonDialect; class LLVMAVX512Dialect; +class LLVMDialect; } // end namespace LLVM namespace NVVM { diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -11,10 +11,13 @@ #include "../PassDetail.h" #include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" +#include "mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include 
"mlir/Dialect/AVX512/AVX512Dialect.h" +#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" +#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -28,8 +31,17 @@ LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) { this->reassociateFPReductions = options.reassociateFPReductions; this->enableIndexOptimizations = options.enableIndexOptimizations; + this->enableArmNeon = options.enableArmNeon; this->enableAVX512 = options.enableAVX512; } + // Override explicitly to allow conditional dialect dependence. + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + if (enableArmNeon) + registry.insert(); + if (enableAVX512) + registry.insert(); + } void runOnOperation() override; }; } // namespace @@ -56,6 +68,11 @@ // Architecture specific augmentations. LLVMConversionTarget target(getContext()); + if (enableArmNeon) { + target.addLegalDialect(); + target.addIllegalDialect(); + populateArmNeonToLLVMConversionPatterns(converter, patterns); + } if (enableAVX512) { target.addLegalDialect(); target.addIllegalDialect(); diff --git a/mlir/lib/Dialect/ArmNeon/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ArmNeon/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIRArmNeon + IR/ArmNeonDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmNeon + + DEPENDS + MLIRArmNeonIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRSideEffectInterfaces + ) diff --git a/mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp b/mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp @@ -0,0 +1,28 @@ +//===- ArmNeonDialect.cpp - MLIR ArmNeon ops implementation ---------------===// +// +// Part of
the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the ArmNeon dialect and its operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" + +using namespace mlir; + +void arm_neon::ArmNeonDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/ArmNeon/ArmNeon.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmNeon/ArmNeon.cpp.inc" diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(Affine) +add_subdirectory(ArmNeon) add_subdirectory(Async) add_subdirectory(AVX512) add_subdirectory(GPU) diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -49,6 +49,27 @@ MLIRSideEffectInterfaces ) +add_mlir_dialect_library(MLIRLLVMArmNeon + IR/LLVMArmNeonDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR + + DEPENDS + MLIRLLVMArmNeonIncGen + MLIRLLVMArmNeonConversionsIncGen + intrinsics_gen + + LINK_COMPONENTS + AsmParser + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMIR + MLIRSideEffectInterfaces + ) + add_mlir_dialect_library(MLIRNVVMIR IR/NVVMDialect.cpp diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMArmNeonDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMArmNeonDialect.cpp new file mode 100644 --- /dev/null +++
b/mlir/lib/Dialect/LLVMIR/IR/LLVMArmNeonDialect.cpp @@ -0,0 +1,31 @@ +//===- LLVMArmNeonDialect.cpp - MLIR LLVMArmNeon ops implementation -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the LLVMArmNeon dialect and its operations. +// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/IntrinsicsAArch64.h" + +#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" + +using namespace mlir; + +void LLVM::LLVMArmNeonDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/LLVMIR/LLVMArmNeon.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMArmNeon.cpp.inc" diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -66,6 +66,17 @@ return getImpl()->signedness; } +IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { + if (!scale) + return IntegerType(); + // Reject widths above kMaxWidth; widen first so the product cannot wrap. + uint64_t scaledWidth = static_cast<uint64_t>(scale) * getWidth(); + if (scaledWidth > kMaxWidth) + return IntegerType(); + return IntegerType::get(static_cast<unsigned>(scaledWidth), getSignedness(), + getContext()); +} + //===----------------------------------------------------------------------===// // Float Type //===----------------------------------------------------------------------===// @@ -93,6 +104,22 @@ llvm_unreachable("non-floating point type used"); } +FloatType FloatType::scaleElementBitwidth(unsigned scale) { + if (!scale) + return FloatType(); + MLIRContext *ctx = getContext(); + if (isF16() || isBF16()) { + if (scale == 2) + return FloatType::getF32(ctx); + if (scale == 4) + return
FloatType::getF64(ctx); + } + if (isF32()) + if (scale == 2) + return FloatType::getF64(ctx); + return FloatType(); +} + //===----------------------------------------------------------------------===// // FunctionType //===----------------------------------------------------------------------===// @@ -306,6 +333,18 @@ ArrayRef VectorType::getShape() const { return getImpl()->getShape(); } +VectorType VectorType::scaleElementBitwidth(unsigned scale) { + if (!scale) + return VectorType(); + if (auto et = getElementType().dyn_cast()) + if (auto scaledEt = et.scaleElementBitwidth(scale)) + return VectorType::get(getShape(), scaledEt); + if (auto et = getElementType().dyn_cast()) + if (auto scaledEt = et.scaleElementBitwidth(scale)) + return VectorType::get(getShape(), scaledEt); + return VectorType(); +} + //===----------------------------------------------------------------------===// // TensorType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt --- a/mlir/lib/Target/CMakeLists.txt +++ b/mlir/lib/Target/CMakeLists.txt @@ -55,6 +55,25 @@ MLIRTargetLLVMIRModuleTranslation ) +add_mlir_translation_library(MLIRTargetArmNeon + LLVMIR/LLVMArmNeonIntr.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR + + DEPENDS + MLIRLLVMArmNeonConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMArmNeon + MLIRLLVMIR + MLIRTargetLLVMIRModuleTranslation + ) + add_mlir_translation_library(MLIRTargetNVVMIR LLVMIR/ConvertToNVVMIR.cpp diff --git a/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp b/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/LLVMArmNeonIntr.cpp @@ -0,0 +1,63 @@ +//===- LLVMArmNeonIntr.cpp - Convert MLIR LLVM dialect to LLVM intrinsics -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the MLIR LLVM and LLVMArmNeon +// dialects and LLVM IR with ArmNeon intrinsics. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Translation.h" +#include "llvm/IR/IntrinsicsAArch64.h" + +using namespace mlir; + +namespace { +class LLVMArmNeonModuleTranslation : public LLVM::ModuleTranslation { + friend LLVM::ModuleTranslation; + +public: + using LLVM::ModuleTranslation::ModuleTranslation; + +protected: + LogicalResult convertOperation(Operation &opInst, + llvm::IRBuilder<> &builder) override { +#include "mlir/Dialect/LLVMIR/LLVMArmNeonConversions.inc" + + return LLVM::ModuleTranslation::convertOperation(opInst, builder); + } +}; + +std::unique_ptr +translateLLVMArmNeonModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext, + StringRef name) { + return LLVM::ModuleTranslation::translateModule( + m, llvmContext, name); +} +} // end namespace + +namespace mlir { +void registerArmNeonToLLVMIRTranslation() { + TranslateFromMLIRRegistration reg( + "arm-neon-mlir-to-llvmir", + [](ModuleOp module, raw_ostream &output) { + llvm::LLVMContext llvmContext; + auto llvmModule = translateLLVMArmNeonModuleToLLVMIR( + module, llvmContext, "LLVMDialectModule"); + if (!llvmModule) + return failure(); + + llvmModule->print(output, nullptr); + return success(); + }, + [](DialectRegistry &registry) { + registry.insert(); + }); +} +} // namespace mlir diff --git a/mlir/test/Conversion/ArmNeonToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ArmNeonToLLVM/convert-to-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/ArmNeonToLLVM/convert-to-llvm.mlir @@ -0,0
+1,20 @@ +// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-neon" | mlir-opt | FileCheck %s + +// CHECK-LABEL: arm_neon_smull +func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>) + -> (vector<8xi16>, vector<4xi32>, vector<2xi64>) { + // CHECK: arm_neon.smull{{.*}}: (!llvm.vec<8 x i8>, !llvm.vec<8 x i8>) -> !llvm.vec<8 x i16> + %0 = arm_neon.smull %a, %b : vector<8xi8> to vector<8xi16> + %00 = vector.extract_strided_slice %0 {offsets = [3], sizes = [4], strides = [1]}: + vector<8xi16> to vector<4xi16> + + // CHECK: arm_neon.smull{{.*}}: (!llvm.vec<4 x i16>, !llvm.vec<4 x i16>) -> !llvm.vec<4 x i32> + %1 = arm_neon.smull %00, %00 : vector<4xi16> to vector<4xi32> + %11 = vector.extract_strided_slice %1 {offsets = [1], sizes = [2], strides = [1]}: + vector<4xi32> to vector<2xi32> + + // CHECK: arm_neon.smull{{.*}}: (!llvm.vec<2 x i32>, !llvm.vec<2 x i32>) -> !llvm.vec<2 x i64> + %2 = arm_neon.smull %11, %11 : vector<2xi32> to vector<2xi64> + + return %0, %1, %2 : vector<8xi16>, vector<4xi32>, vector<2xi64> +} diff --git a/mlir/test/Dialect/ArmNeon/roundtrip.mlir b/mlir/test/Dialect/ArmNeon/roundtrip.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/ArmNeon/roundtrip.mlir @@ -0,0 +1,20 @@ +// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s + +// CHECK-LABEL: arm_neon_smull +func @arm_neon_smull(%a: vector<8xi8>, %b: vector<8xi8>) + -> (vector<8xi16>, vector<4xi32>, vector<2xi64>) { + // CHECK: arm_neon.smull {{.*}}: vector<8xi8> to vector<8xi16> + %0 = arm_neon.smull %a, %b : vector<8xi8> to vector<8xi16> + %00 = vector.extract_strided_slice %0 {offsets = [3], sizes = [4], strides = [1]}: + vector<8xi16> to vector<4xi16> + + // CHECK: arm_neon.smull {{.*}}: vector<4xi16> to vector<4xi32> + %1 = arm_neon.smull %00, %00 : vector<4xi16> to vector<4xi32> + %11 = vector.extract_strided_slice %1 {offsets = [1], sizes = [2], strides = [1]}: + vector<4xi32> to vector<2xi32> + + // CHECK: arm_neon.smull {{.*}}: vector<2xi32> to 
vector<2xi64> + %2 = arm_neon.smull %11, %11 : vector<2xi32> to vector<2xi64> + + return %0, %1, %2 : vector<8xi16>, vector<4xi32>, vector<2xi64> +} diff --git a/mlir/test/Target/arm-neon.mlir b/mlir/test/Target/arm-neon.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/arm-neon.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | mlir-translate -arm-neon-mlir-to-llvmir | FileCheck %s + +// CHECK-LABEL: arm_neon_smull +llvm.func @arm_neon_smull(%arg0: !llvm.vec<8 x i8>, %arg1: !llvm.vec<8 x i8>) -> !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)> { + // CHECK: %[[V0:.*]] = call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %{{.*}}, <8 x i8> %{{.*}}) + // CHECK-NEXT: %[[V00:.*]] = shufflevector <8 x i16> %3, <8 x i16> %[[V0]], <4 x i32> + %0 = "llvm_arm_neon.smull"(%arg0, %arg1) : (!llvm.vec<8 x i8>, !llvm.vec<8 x i8>) -> !llvm.vec<8 x i16> + %1 = llvm.shufflevector %0, %0 [3, 4, 5, 6] : !llvm.vec<8 x i16>, !llvm.vec<8 x i16> + + // CHECK-NEXT: %[[V1:.*]] = call <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16> %[[V00]], <4 x i16> %[[V00]]) + // CHECK-NEXT: %[[V11:.*]] = shufflevector <4 x i32> %[[V1]], <4 x i32> %[[V1]], <2 x i32> + %2 = "llvm_arm_neon.smull"(%1, %1) : (!llvm.vec<4 x i16>, !llvm.vec<4 x i16>) -> !llvm.vec<4 x i32> + %3 = llvm.shufflevector %2, %2 [1, 2] : !llvm.vec<4 x i32>, !llvm.vec<4 x i32> + + // CHECK-NEXT: %[[V1:.*]] = call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %[[V11]], <2 x i32> %[[V11]]) + %4 = "llvm_arm_neon.smull"(%3, %3) : (!llvm.vec<2 x i32>, !llvm.vec<2 x i32>) -> !llvm.vec<2 x i64> + + %5 = llvm.mlir.undef : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)> + %6 = llvm.insertvalue %0, %5[0] : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)> + %7 = llvm.insertvalue %2, %6[1] : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)> + %8 = llvm.insertvalue %4, %7[2] : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)> + + // CHECK: ret { <8 x i16>, <4 x 
i32>, <2 x i64> } + llvm.return %8 : !llvm.struct<(vec<8 x i16>, vec<4 x i32>, vec<2 x i64>)> +}