diff --git a/mlir/include/mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h b/mlir/include/mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h
@@ -0,0 +1,30 @@
+//===- ArmNeon2dToIntr.h - convert Arm Neon 2d ops to intrinsics ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARMNEON2DTOINTR_ARMNEON2DTOINTR_H_
+#define MLIR_CONVERSION_ARMNEON2DTOINTR_ARMNEON2DTOINTR_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class FuncOp;
+template <typename T>
+class OperationPass;
+
+/// Populates patterns for the lowering of Arm NEON 2D ops to intrinsics.
+/// See createConvertArmNeon2dToIntrPass.
+void populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns);
+
+/// Creates a pass to lower Arm NEON 2D ops to intrinsics, i.e.
+/// equivalent ops operating on flattened 1D vectors and mapping more
+/// directly to the corresponding Arm NEON instruction.
+std::unique_ptr<OperationPass<FuncOp>> createConvertArmNeon2dToIntrPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARMNEON2DTOINTR_ARMNEON2DTOINTR_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -10,6 +10,7 @@
 #define MLIR_CONVERSION_PASSES_H
 
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -592,4 +592,15 @@
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// ArmNeon2dToIntr
+//===----------------------------------------------------------------------===//
+
+def ConvertArmNeon2dToIntr : Pass<"arm-neon-2d-to-intr", "FuncOp"> {
+  let summary = "Convert Arm NEON structured ops to intrinsics";
+  let constructor = "mlir::createConvertArmNeon2dToIntrPass()";
+  let dependentDialects = ["arm_neon::ArmNeonDialect", "vector::VectorDialect"];
+}
+
+
 #endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -15,6 +15,7 @@
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
 
 //===----------------------------------------------------------------------===//
 // ArmNeon dialect definition
 //===----------------------------------------------------------------------===//
@@ -117,4 +118,43 @@
     "$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
 }
 
+class ArmNeon_2dOp<string mnemonic, list<OpTrait> traits = []>
+    : Op</*dialect=*/ArmNeon_Dialect,
+         /*opName=*/"2d." # mnemonic,
+         /*traits=*/traits>;
+
+def Sdot2dOp : ArmNeon_2dOp<"sdot", [
+    NoSideEffect,
+    AllTypesMatch<["b", "c"]>,
+    AllTypesMatch<["a", "res"]>,
+    TypesMatchWith<"res has the same number of elements as operand b",
+                   "b", "res",
+                   "VectorType::get({$_self.cast<VectorType>().getShape()[0]},"
+                   "IntegerType::get($_self.getContext(), 32))">]
+  > {
+  let summary = "sdot op";
+  let description = [{
+    The two input vectors `b` and `c` have a 2D shape, consisting of either 2
+    or 4 rows, each row having length 4. This operation computes the pair-wise
+    dot-products of the rows of `b` and `c` and accumulates them with the
+    corresponding entry of `a`:
+
+    ```
+    res[i] := a[i] + dot_product(b[i, ...], c[i, ...])
+    ```
+
+  }];
+  // Supports either:
+  //   (vector<2xi32>, vector<2x4xi8>, vector<2x4xi8>) -> vector<2xi32>
+  //   (vector<4xi32>, vector<4x4xi8>, vector<4x4xi8>) -> vector<4xi32>
+  // TODO: how do we express 2D shape requirements here?
+  let arguments = (ins VectorOfLengthAndType<[4, 2], [I32]>:$a,
+                       VectorOfLengthAndType<[16, 8], [I8]>:$b,
+                       VectorOfLengthAndType<[16, 8], [I8]>:$c);
+  let results = (outs VectorOfLengthAndType<[4, 2], [I32]>:$res);
+  let assemblyFormat =
+    "$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
+  let verifier = [{ return ::verify(*this); }];
+}
+
 #endif // ARMNEON_OPS
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -0,0 +1,80 @@
+//===- ArmNeon2dToIntr.cpp - convert Arm Neon 2d ops to intrinsics --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
+#include "../PassDetail.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+namespace {
+
+/// Lowers `arm_neon.2d.sdot` to `arm_neon.intr.sdot` by shape-casting the 2-D
+/// `b` and `c` operands to flat 1-D vectors.
+class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  /// Convert to 1-dimensional vector type to match the requirements of
+  /// arm.neon.intr.sdot
+  LogicalResult matchAndRewrite(Sdot2dOp op,
+                                PatternRewriter &rewriter) const override {
+    auto elemType = op.b().getType().cast<VectorType>().getElementType();
+    // Every row of `b` has 4 elements (checked by the op verifier), so the
+    // flattened vector holds rows * 4 elements.
+    int64_t length = op.b().getType().cast<VectorType>().getShape()[0] * 4;
+    auto flattenedVectorType = VectorType::get({length}, elemType);
+    auto b2d = op.b();
+    auto c2d = op.c();
+    auto loc = op.getLoc();
+    auto b1d =
+        rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, b2d);
+    auto c1d =
+        rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, c2d);
+    Value newOp = rewriter.create<SdotOp>(loc, op->getResult(0).getType(),
+                                          op.a(), b1d, c1d);
+    rewriter.replaceOp(op, {newOp});
+    return success();
+  }
+};
+
+/// Pass that greedily applies the Arm NEON 2D-to-intrinsic patterns to the
+/// operation it runs on, signalling failure if the rewrite driver fails.
+class ConvertArmNeon2dToIntr
+    : public ConvertArmNeon2dToIntrBase<ConvertArmNeon2dToIntr> {
+  void runOnOperation() override {
+    auto func = getOperation();
+    auto *context = &getContext();
+
+    RewritePatternSet patterns(context);
+    populateConvertArmNeon2dToIntrPatterns(patterns);
+
+    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+namespace mlir {
+
+void populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns) {
+  patterns.add<Sdot2dLoweringPattern>(patterns.getContext());
+}
+
+std::unique_ptr<OperationPass<FuncOp>> createConvertArmNeon2dToIntrPass() {
+  return std::make_unique<ConvertArmNeon2dToIntr>();
+}
+
+} // namespace mlir
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/CMakeLists.txt b/mlir/lib/Conversion/ArmNeon2dToIntr/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRArmNeon2dToIntr
+  ArmNeon2dToIntr.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmNeon2dToIntr
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArmNeon
+  MLIRPass
+  MLIRTransforms
+  MLIRIR
+  MLIRVector
+  )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(AffineToStandard)
+add_subdirectory(ArmNeon2dToIntr)
 add_subdirectory(AsyncToLLVM)
 add_subdirectory(ComplexToLLVM)
 add_subdirectory(ComplexToStandard)
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -76,6 +76,10 @@
 class VectorDialect;
 } // end namespace vector
 
+namespace arm_neon {
+class ArmNeonDialect;
+} // end namespace arm_neon
+
 #define GEN_PASS_CLASSES
 #include "mlir/Conversion/Passes.h.inc"
 
diff --git a/mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp b/mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp
--- a/mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp
+++ b/mlir/lib/Dialect/ArmNeon/IR/ArmNeonDialect.cpp
@@ -25,5 +25,34 @@
       >();
 }
 
+/// Verifies the shapes of an `arm_neon.2d.sdot` op: `a` must be a 1-D vector
+/// of length 2 or 4, and `b` a 2-D vector whose outer size equals the length
+/// of `a` and whose inner size is 4.
+static LogicalResult verify(arm_neon::Sdot2dOp op) {
+  auto shapeA = op.a().getType().cast<VectorType>().getShape();
+  auto shapeB = op.b().getType().cast<VectorType>().getShape();
+
+  if (shapeA.size() != 1)
+    return emitError(op.getLoc(), "Operand a should be a 1-dimensional vector");
+
+  if (shapeA[0] != 2 && shapeA[0] != 4)
+    return emitError(op.getLoc(), "Operand a should have length 2 or 4");
+
+  if (shapeB.size() != 2)
+    return emitError(op.getLoc(), "Operand b should be a 2-dimensional vector");
+
+  if (shapeB[1] != 4)
+    return emitError(
+        op.getLoc(),
+        "The inner size of the 2-dimensional operand b should be 4");
+
+  if (shapeB[0] != shapeA[0])
+    return emitError(op.getLoc(),
+                     "The outer size of the 2-dimensional operand b should "
+                     "equal the size of 1-dimensional operand a");
+
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/ArmNeon/ArmNeon.cpp.inc"
diff --git a/mlir/test/Target/LLVMIR/arm-neon-2d.mlir b/mlir/test/Target/LLVMIR/arm-neon-2d.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/arm-neon-2d.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt -arm-neon-2d-to-intr %s | FileCheck %s
+
+// CHECK-LABEL: arm_neon_sdot2d_4x4_i8i8
+func @arm_neon_sdot2d_4x4_i8i8(%a: vector<4xi32>, %b: vector<4x4xi8>, %c: vector<4x4xi8>) -> vector<4xi32> {
+  // CHECK: arm_neon.intr.sdot %{{.*}}, %{{.*}}, %{{.*}} : vector<16xi8>, vector<16xi8> to vector<4xi32>
+  // CHECK-NEXT: return %{{.*}} : vector<4xi32>
+  %0 = arm_neon.2d.sdot %a, %b, %c : vector<4x4xi8>, vector<4x4xi8> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: arm_neon_sdot2d_2x4_i8i8
+func @arm_neon_sdot2d_2x4_i8i8(%a: vector<2xi32>, %b: vector<2x4xi8>, %c: vector<2x4xi8>) -> vector<2xi32> {
+  // CHECK: arm_neon.intr.sdot %{{.*}}, %{{.*}}, %{{.*}} : vector<8xi8>, vector<8xi8> to vector<2xi32>
+  // CHECK-NEXT: return %{{.*}} : vector<2xi32>
+  %0 = arm_neon.2d.sdot %a, %b, %c : vector<2x4xi8>, vector<2x4xi8> to vector<2xi32>
+  return %0 : vector<2xi32>
+}