diff --git a/mlir/include/mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h b/mlir/include/mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h
@@ -0,0 +1,30 @@
+//===- ArmNeon2dToIntr.h - convert Arm Neon 2d ops to intrinsics ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARMNEON2DTOINTR_ARMNEON2DTOINTR_H_
+#define MLIR_CONVERSION_ARMNEON2DTOINTR_ARMNEON2DTOINTR_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class FuncOp;
+template <typename T>
+class OperationPass;
+
+/// Populates patterns for the lowering of Arm NEON 2D ops to intrinsics.
+/// See createConvertArmNeon2dToIntrPass.
+void populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns);
+
+/// Creates a pass to lower Arm NEON 2D ops to intrinsics, i.e.
+/// equivalent ops operating on flattened 1D vectors and mapping more
+/// directly to the corresponding Arm NEON instruction.
+std::unique_ptr<OperationPass<FuncOp>> createConvertArmNeon2dToIntrPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARMNEON2DTOINTR_ARMNEON2DTOINTR_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -10,6 +10,7 @@
 #define MLIR_CONVERSION_PASSES_H
 
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -607,4 +607,14 @@
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// ArmNeon2dToIntr
+//===----------------------------------------------------------------------===//
+
+def ConvertArmNeon2dToIntr : Pass<"arm-neon-2d-to-intr", "FuncOp"> {
+  let summary = "Convert Arm NEON structured ops to intrinsics";
+  let constructor = "mlir::createConvertArmNeon2dToIntrPass()";
+  let dependentDialects = ["arm_neon::ArmNeonDialect", "vector::VectorDialect"];
+}
+
 #endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -15,6 +15,7 @@
 
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
 
 //===----------------------------------------------------------------------===//
 // ArmNeon dialect definition
@@ -117,4 +118,57 @@
     "$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
 }
 
+/// Base class for ArmNeon ops on 2D vector operands that do not map directly
+/// to an LLVM intrinsic; -arm-neon-2d-to-intr lowers them to intrinsic ops.
+class ArmNeon_2dOp<string mnemonic, list<OpTrait> traits = []>
+    : Op</*dialect=*/ArmNeon_Dialect,
+         /*opName=*/"2d." # mnemonic,
+         /*traits=*/traits>;
+
+def Sdot2dOp : ArmNeon_2dOp<"sdot", [
+    NoSideEffect,
+    AllTypesMatch<["b", "c"]>,
+    AllTypesMatch<["a", "res"]>,
+    PredOpTrait<
+      "operand `a` should be 1-dimensional",
+      CPred<"a().getType().cast<VectorType>().getShape().size() == 1">
+    >,
+    PredOpTrait<
+      "operand `b` should be 2-dimensional",
+      CPred<"b().getType().cast<VectorType>().getShape().size() == 2">
+    >,
+    PredOpTrait<
+      "operand `b` should have 4 columns",
+      CPred<"b().getType().cast<VectorType>().getShape()[1] == 4">
+    >,
+    PredOpTrait<
+      "operand `b` should have as many rows as the size of operand `a`",
+      CPred<"b().getType().cast<VectorType>().getShape()[0] == a().getType().cast<VectorType>().getShape()[0]">
+    >,
+  ]> {
+  let summary = "sdot op";
+  let description = [{
+    The two input vectors `b` and `c` have a 2D shape, consisting of either 2
+    or 4 rows, each row having length 4. This operation computes the pair-wise
+    dot-products of the rows of `b` and `c` and accumulates them with the
+    corresponding entry of `a`:
+
+    ```
+    res[i] := a[i] + dot_product(b[i, ...], c[i, ...])
+    ```
+
+  }];
+  // Supports either:
+  //   (vector<2xi32>, vector<2x4xi8>, vector<2x4xi8>) -> vector<2xi32>
+  //   (vector<4xi32>, vector<4x4xi8>, vector<4x4xi8>) -> vector<4xi32>
+  // VectorOfLengthAndType only constrains the total number of elements; the
+  // exact 1D/2D shapes are checked by the PredOpTraits above.
+  let arguments = (ins VectorOfLengthAndType<[4, 2], [I32]>:$a,
+                       VectorOfLengthAndType<[16, 8], [I8]>:$b,
+                       VectorOfLengthAndType<[16, 8], [I8]>:$c);
+  let results = (outs VectorOfLengthAndType<[4, 2], [I32]>:$res);
+  let assemblyFormat =
+    "$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
+}
+
 #endif // ARMNEON_OPS
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -0,0 +1,77 @@
+//===- ArmNeon2dToIntr.cpp - convert Arm Neon 2d ops to intrinsics --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
+#include "../PassDetail.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+namespace {
+
+/// Lowers arm_neon.2d.sdot to arm_neon.intr.sdot.
+class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
+public:
+  using OpRewritePattern<Sdot2dOp>::OpRewritePattern;
+
+  /// Flattens the 2D operands `b` and `c` into 1D vectors with
+  /// vector.shape_cast, matching the operand types of arm_neon.intr.sdot, and
+  /// replaces `op` with that intrinsic op.
+  LogicalResult matchAndRewrite(Sdot2dOp op,
+                                PatternRewriter &rewriter) const override {
+    // `b` and `c` share one type (AllTypesMatch in the op definition), so a
+    // single flattened type holding all of their elements serves both.
+    auto vectorType2d = op.b().getType().cast<VectorType>();
+    auto flattenedVectorType = VectorType::get(
+        {vectorType2d.getNumElements()}, vectorType2d.getElementType());
+    Location loc = op.getLoc();
+    Value b1d =
+        rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, op.b());
+    Value c1d =
+        rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, op.c());
+    rewriter.replaceOpWithNewOp<SdotOp>(op, op.res().getType(), op.a(), b1d,
+                                        c1d);
+    return success();
+  }
+};
+
+/// Function pass that applies populateConvertArmNeon2dToIntrPatterns with the
+/// greedy pattern rewrite driver.
+class ConvertArmNeon2dToIntr
+    : public ConvertArmNeon2dToIntrBase<ConvertArmNeon2dToIntr> {
+  void runOnOperation() override {
+    auto func = getOperation();
+    auto *context = &getContext();
+
+    RewritePatternSet patterns(context);
+    populateConvertArmNeon2dToIntrPatterns(patterns);
+
+    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+namespace mlir {
+
+void populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns) {
+  patterns.add<Sdot2dLoweringPattern>(patterns.getContext());
+}
+
+std::unique_ptr<OperationPass<FuncOp>> createConvertArmNeon2dToIntrPass() {
+  return std::make_unique<ConvertArmNeon2dToIntr>();
+}
+
+} // namespace mlir
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/CMakeLists.txt b/mlir/lib/Conversion/ArmNeon2dToIntr/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRArmNeon2dToIntr
+  ArmNeon2dToIntr.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmNeon2dToIntr
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArmNeon
+  MLIRPass
+  MLIRTransforms
+  MLIRIR
+  MLIRVector
+  )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(AffineToStandard)
+add_subdirectory(ArmNeon2dToIntr)
 add_subdirectory(AsyncToLLVM)
 add_subdirectory(ComplexToLLVM)
 add_subdirectory(ComplexToStandard)
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -80,6 +80,10 @@
 class VectorDialect;
 } // end namespace vector
 
+namespace arm_neon {
+class ArmNeonDialect;
+} // end namespace arm_neon
+
 #define GEN_PASS_CLASSES
 #include "mlir/Conversion/Passes.h.inc"
 
diff --git a/mlir/test/Dialect/ArmNeon/invalid.mlir b/mlir/test/Dialect/ArmNeon/invalid.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/ArmNeon/invalid.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+
+func @a_is_2d(%a : vector<2x2xi32>, %b : vector<4x4xi8>) -> vector<2x2xi32> {
+  // expected-error@+1 {{operand `a` should be 1-dimensional}}
+  %0 = arm_neon.2d.sdot %a, %b, %b : vector<4x4xi8>, vector<4x4xi8> to vector<2x2xi32>
+  return %0 : vector<2x2xi32>
+}
+
+// -----
+
+func @b_is_3d(%a : vector<4xi32>, %b : vector<1x4x4xi8>) -> vector<4xi32> {
+  // expected-error@+1 {{operand `b` should be 2-dimensional}}
+  %0 = arm_neon.2d.sdot %a, %b, %b : vector<1x4x4xi8>, vector<1x4x4xi8> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+func @b_has_2_columns(%a : vector<4xi32>, %b : vector<4x2xi8>) -> vector<4xi32> {
+  // expected-error@+1 {{operand `b` should have 4 columns}}
+  %0 = arm_neon.2d.sdot %a, %b, %b : vector<4x2xi8>, vector<4x2xi8> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+func @b_has_2_rows_but_a_has_length_4(%a : vector<4xi32>, %b : vector<2x4xi8>) -> vector<4xi32> {
+  // expected-error@+1 {{operand `b` should have as many rows as the size of operand `a`}}
+  %0 = arm_neon.2d.sdot %a, %b, %b : vector<2x4xi8>, vector<2x4xi8> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-neon-2d.mlir b/mlir/test/Target/LLVMIR/arm-neon-2d.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/arm-neon-2d.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt -arm-neon-2d-to-intr %s | FileCheck %s
+
+// CHECK-LABEL: arm_neon_sdot2d_4x4_i8i8
+func @arm_neon_sdot2d_4x4_i8i8(%a: vector<4xi32>, %b: vector<4x4xi8>, %c: vector<4x4xi8>) -> vector<4xi32> {
+  // CHECK: arm_neon.intr.sdot %{{.*}}, %{{.*}}, %{{.*}} : vector<16xi8>, vector<16xi8> to vector<4xi32>
+  // CHECK-NEXT: return %{{.*}} : vector<4xi32>
+  %0 = arm_neon.2d.sdot %a, %b, %c : vector<4x4xi8>, vector<4x4xi8> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: arm_neon_sdot2d_2x4_i8i8
+func @arm_neon_sdot2d_2x4_i8i8(%a: vector<2xi32>, %b: vector<2x4xi8>, %c: vector<2x4xi8>) -> vector<2xi32> {
+  // CHECK: arm_neon.intr.sdot %{{.*}}, %{{.*}}, %{{.*}} : vector<8xi8>, vector<8xi8> to vector<2xi32>
+  // CHECK-NEXT: return %{{.*}} : vector<2xi32>
+  %0 = arm_neon.2d.sdot %a, %b, %c : vector<2x4xi8>, vector<2x4xi8> to vector<2xi32>
+  return %0 : vector<2xi32>
+}