diff --git a/mlir/include/mlir/Conversion/ArmNeonStructuredToIntr/ArmNeonStructuredToIntr.h b/mlir/include/mlir/Conversion/ArmNeonStructuredToIntr/ArmNeonStructuredToIntr.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArmNeonStructuredToIntr/ArmNeonStructuredToIntr.h
@@ -0,0 +1,22 @@
+#ifndef MLIR_CONVERSION_ARMNEONSTRUCTUREDTOINTR_ARMNEONSTRUCTUREDTOINTR_H_
+#define MLIR_CONVERSION_ARMNEONSTRUCTUREDTOINTR_ARMNEONSTRUCTUREDTOINTR_H_
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class FuncOp;
+template <typename T>
+class OperationPass;
+
+/// Populates `patterns` with the patterns lowering Arm NEON structured ops to
+/// their intrinsic counterparts.
+void populateConvertArmNeonStructuredToIntrPatterns(
+    RewritePatternSet &patterns);
+
+/// Creates a FuncOp pass that greedily applies the lowering patterns.
+std::unique_ptr<OperationPass<FuncOp>>
+createConvertArmNeonStructuredToIntrPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARMNEONSTRUCTUREDTOINTR_ARMNEONSTRUCTUREDTOINTR_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -10,6 +10,7 @@
 #define MLIR_CONVERSION_PASSES_H
 
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/ArmNeonStructuredToIntr/ArmNeonStructuredToIntr.h"
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
 #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
 #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -592,4 +592,14 @@
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// ArmNeonStructuredToIntr
+//===----------------------------------------------------------------------===//
+
+def ConvertArmNeonStructuredToIntr : Pass<"arm-neon-structured-to-intr", "FuncOp"> {
+  let summary = "Convert Arm NEON structured ops to intrinsics";
+  let constructor = "mlir::createConvertArmNeonStructuredToIntrPass()";
+  let dependentDialects = ["arm_neon::ArmNeonDialect", "vector::VectorDialect"];
+}
+
 #endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
--- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeon.td
@@ -15,6 +15,7 @@
 
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
 
 //===----------------------------------------------------------------------===//
 // ArmNeon dialect definition
@@ -117,4 +118,44 @@
     "$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
 }
 
+class ArmNeon_StructuredOp<string mnemonic, list<OpTrait> traits = []>
+    : Op</*dialect=*/ArmNeon_Dialect,
+         /*opName=*/"structured." # mnemonic,
+         /*traits=*/traits>;
+
+def StructuredSdotOp : ArmNeon_StructuredOp<"sdot", [
+    NoSideEffect,
+    AllTypesMatch<["b", "c"]>,
+    AllTypesMatch<["a", "res"]>,
+    TypesMatchWith<"res has the same number of elements as operand b",
+                   "b", "res",
+                   "VectorType::get({$_self.cast<VectorType>().getShape()[0]},"
+                   "IntegerType::get($_self.getContext(), 32))">]
+    // TODO: how do we express the 2d input shape requirement here, and the
+    // requirement that the inner dimension of b and c is 4?
+  > {
+  let summary = "sdot op";
+  let description = [{
+    The two input vectors `b` and `c` have a 2D shape, consisting of either 2
+    or 4 rows, each row having length 4. This operation computes the pair-wise
+    dot-products of the rows of `b` and `c` and accumulates them with the
+    corresponding entry of `a`:
+
+    ```
+    res[i] := a[i] + dot_product(b[i, ...], c[i, ...])
+    ```
+
+  }];
+  // Supports either:
+  //   (vector<2xi32>, vector<2x4xi8>, vector<2x4xi8>) -> vector<2xi32>
+  //   (vector<4xi32>, vector<4x4xi8>, vector<4x4xi8>) -> vector<4xi32>
+  // TODO: how do we express 2D shape requirements here?
+  let arguments = (ins VectorOfLengthAndType<[4, 2], [I32]>:$a,
+                       VectorOfLengthAndType<[16, 8], [I8]>:$b,
+                       VectorOfLengthAndType<[16, 8], [I8]>:$c);
+  let results = (outs VectorOfLengthAndType<[4, 2], [I32]>:$res);
+  let assemblyFormat =
+    "$a `,` $b `,` $c attr-dict `:` type($b) `,` type($c) `to` type($res)";
+}
+
 #endif // ARMNEON_OPS
diff --git a/mlir/lib/Conversion/ArmNeonStructuredToIntr/ArmNeonStructuredToIntr.cpp b/mlir/lib/Conversion/ArmNeonStructuredToIntr/ArmNeonStructuredToIntr.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ArmNeonStructuredToIntr/ArmNeonStructuredToIntr.cpp
@@ -0,0 +1,83 @@
+//===- ArmNeonStructuredToIntr.cpp - Structured ops to intrinsics ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArmNeonStructuredToIntr/ArmNeonStructuredToIntr.h"
+#include "../PassDetail.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+namespace {
+
+/// Lowers the structured sdot op, whose operands `b` and `c` are 2D vectors,
+/// to the sdot intrinsic op operating on the flattened 1D vectors.
+class StructuredSdotLoweringPattern
+    : public OpRewritePattern<StructuredSdotOp> {
+public:
+  using OpRewritePattern<StructuredSdotOp>::OpRewritePattern;
+
+  /// Flattens the 2D operands `b` and `c` to 1D vectors with a shape cast and
+  /// replaces the op with the 1D sdot intrinsic op on the flattened operands.
+  LogicalResult matchAndRewrite(StructuredSdotOp op,
+                                PatternRewriter &rewriter) const override {
+    auto elemType = op.b().getType().cast<VectorType>().getElementType();
+    // Rows of `b` have length 4 (see the op description), so the flattened
+    // vector holds 4 elements per row.
+    int64_t length = op.b().getType().cast<VectorType>().getShape()[0] * 4;
+    auto flattenedVectorType = VectorType::get({length}, elemType);
+    auto structuredB = op.b();
+    auto structuredC = op.c();
+    auto loc = op.getLoc();
+    auto flatB = rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType,
+                                                      structuredB);
+    auto flatC = rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType,
+                                                      structuredC);
+    Value newOp = rewriter.create<SdotOp>(loc, op->getResult(0).getType(),
+                                          op.a(), flatB, flatC);
+    rewriter.replaceOp(op, {newOp});
+    return success();
+  }
+};
+
+/// Pass applying the structured-to-intrinsics patterns greedily to a function.
+class ConvertArmNeonStructuredToIntr
+    : public ConvertArmNeonStructuredToIntrBase<
+          ConvertArmNeonStructuredToIntr> {
+  void runOnOperation() override {
+    auto func = getOperation();
+    auto *context = &getContext();
+
+    RewritePatternSet patterns(context);
+    populateConvertArmNeonStructuredToIntrPatterns(patterns);
+
+    if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+namespace mlir {
+
+void populateConvertArmNeonStructuredToIntrPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<StructuredSdotLoweringPattern>(patterns.getContext());
+}
+
+std::unique_ptr<OperationPass<FuncOp>>
+createConvertArmNeonStructuredToIntrPass() {
+  return std::make_unique<ConvertArmNeonStructuredToIntr>();
+}
+
+} // namespace mlir
diff --git a/mlir/lib/Conversion/ArmNeonStructuredToIntr/CMakeLists.txt b/mlir/lib/Conversion/ArmNeonStructuredToIntr/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ArmNeonStructuredToIntr/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRArmNeonStructuredToIntr
+  ArmNeonStructuredToIntr.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmNeonStructuredToIntr
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArmNeon
+  MLIRPass
+  MLIRTransforms
+  MLIRIR
+  MLIRVector
+  )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(AffineToStandard)
+add_subdirectory(ArmNeonStructuredToIntr)
 add_subdirectory(AsyncToLLVM)
 add_subdirectory(ComplexToLLVM)
 add_subdirectory(ComplexToStandard)
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -76,6 +76,10 @@
 class VectorDialect;
 } // end namespace vector
 
+namespace arm_neon {
+class ArmNeonDialect;
+} // end namespace arm_neon
+
 #define GEN_PASS_CLASSES
 #include "mlir/Conversion/Passes.h.inc"
 
diff --git a/mlir/test/Target/LLVMIR/arm-neon-structured.mlir b/mlir/test/Target/LLVMIR/arm-neon-structured.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/arm-neon-structured.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt -arm-neon-structured-to-intr %s | FileCheck %s
+
+// CHECK-LABEL: arm_neon_structured_sdot_4x4_i8i8
+func @arm_neon_structured_sdot_4x4_i8i8(%a: vector<4xi32>, %b: vector<4x4xi8>, %c: vector<4x4xi8>) -> vector<4xi32> {
+  // CHECK: arm_neon.intr.sdot %{{.*}}, %{{.*}}, %{{.*}} : vector<16xi8>, vector<16xi8> to vector<4xi32>
+  // CHECK-NEXT: return %{{.*}} : vector<4xi32>
+  %0 = arm_neon.structured.sdot %a, %b, %c : vector<4x4xi8>, vector<4x4xi8> to vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// CHECK-LABEL: arm_neon_structured_sdot_2x4_i8i8
+func @arm_neon_structured_sdot_2x4_i8i8(%a: vector<2xi32>, %b: vector<2x4xi8>, %c: vector<2x4xi8>) -> vector<2xi32> {
+  // CHECK: arm_neon.intr.sdot %{{.*}}, %{{.*}}, %{{.*}} : vector<8xi8>, vector<8xi8> to vector<2xi32>
+  // CHECK-NEXT: return %{{.*}} : vector<2xi32>
+  %0 = arm_neon.structured.sdot %a, %b, %c : vector<2x4xi8>, vector<2x4xi8> to vector<2xi32>
+  return %0 : vector<2xi32>
+}