diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // // This file defines the operation set for the TOSA dialect as defined in -// the TOSA specfication (https://developer.mlplatform.org/w/tosa/). +// the TOSA specfication (https://developer.mlplatform.org/w/tosa/). // //===----------------------------------------------------------------------===// @@ -58,7 +58,7 @@ //===----------------------------------------------------------------------===// def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [ DeclareOpInterfaceMethods, + ["inferReturnTypeComponents"]>, NoSideEffect]> { let summary = "Performs max pooling on the input."; @@ -118,8 +118,6 @@ let builders = [Tosa_ConvOpQuantInfoBuilder]; let verifier = [{ return verifyConvOp(*this); }]; - - let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// @@ -187,8 +185,6 @@ let builders = [Tosa_ConvOpQuantInfoBuilder]; let verifier = [{ return verifyConvOp(*this); }]; - - let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// @@ -326,9 +322,9 @@ let description = [{ Clamp to an arbitrary minimum and maximum value. - Maximum and minimum values are specified as values in the range of the + Maximum and minimum values are specified as values in the range of the input type. - No zero point subtraction is done to the values, thus to clamp to the zero + No zero point subtraction is done to the values, thus to clamp to the zero point value, the zero point itself should be supplied as the minimum value. }]; @@ -488,7 +484,7 @@ let description = [{ Elementwise bitwise AND of input1 and input2. Axis of size 1 - will be broadcast as necessary. + will be broadcast as necessary. }]; let arguments = (ins @@ -1379,7 +1375,7 @@ let summary = "Concatenates tensors along one dimension."; let description = [{ - Concatenate a variadic amount of tensors along a given axis. No data + Concatenate a variadic amount of tensors along a given axis. No data conversion happens during a concat operation. }]; @@ -1405,7 +1401,7 @@ let summary = "Pads a tensor with value specified."; let description = [{ - Pads a tensor along borders of each dimension with pad_value. + Pads a tensor along borders of each dimension with pad_value. }]; let arguments = (ins @@ -1510,7 +1506,7 @@ //===----------------------------------------------------------------------===// def Tosa_TileOp: Tosa_Op<"tile", [ DeclareOpInterfaceMethods, + ["inferReturnTypeComponents"]>, NoSideEffect]> { let summary = "Tile operator"; @@ -1534,7 +1530,7 @@ //===----------------------------------------------------------------------===// def Tosa_TransposeOp : Tosa_Op<"transpose", [ DeclareOpInterfaceMethods, + ["inferReturnTypeComponents"]>, NoSideEffect]> { let summary = "Transpose operator"; @@ -1565,7 +1561,7 @@ //===----------------------------------------------------------------------===// def Tosa_GatherOp : Tosa_Op<"gather", [ DeclareOpInterfaceMethods, + ["inferReturnTypeComponents"]>, NoSideEffect]> { let summary = "Gather operation,"; @@ -1697,7 +1693,7 @@ //===----------------------------------------------------------------------===// // Operator: rescale //===----------------------------------------------------------------------===// -def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect, +def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "Tosa rescale operator"; diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -22,6 +22,7 @@ std::unique_ptr createTosaDecomposeTransposeConvPass(); std::unique_ptr createTosaInferShapesPass(); std::unique_ptr createTosaMakeBroadcastablePass(); +std::unique_ptr createTosaOptimizationPass(); std::unique_ptr createTosaTestQuantUtilAPIPass(); #define GEN_PASS_REGISTRATION diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -1,4 +1,4 @@ -//===-- Passes.td - TOSA optimization pass declarations ----*- tablegen -*-===// +//===-- Passes.td - TOSA pass declarations ----*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// // -// This file declares the optimization passes for the TOSA Dialect in MLIR. +// This file declares the passes for the TOSA Dialect in MLIR. // //===----------------------------------------------------------------------===// @@ -58,4 +58,13 @@ let constructor = "createTosaMakeBroadcastablePass()"; } +def TosaOptimization : FunctionPass<"tosa-optimization"> { + let summary = "TOSA operation optimizations"; + let description = [{ + "Pass to perform optimizations on TOSA operations" + }]; + + let constructor = "createTosaOptimizationPass()"; +} + #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -423,195 +423,6 @@ results.insert(context); } -struct Conv2DFullyConnectedOptimization - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tosa::Conv2DOp op, - PatternRewriter &rewriter) const override { - Value input = op.input(); - Value weight = op.weight(); - ShapedType inputType = input.getType().cast(); - ShapedType weightType = weight.getType().cast(); - ShapedType resultType = op.getType().cast(); - - if (!inputType.hasStaticShape() || !weightType.hasRank()) { - return failure(); - } - - // Stride must be 1 for this optimization. - for (Attribute stride : op.stride().getValue()) { - if (!stride.cast().getValue().isOne()) { - return failure(); - } - } - - // Only works for a 1x1 kernel. - ArrayRef weightShape = weightType.getShape(); - if (weightShape[1] != 1 || weightShape[2] != 1) { - return failure(); - } - - // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC]. - ArrayRef inputShape = inputType.getShape(); - llvm::SmallVector revisedInputShape{ - inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]}; - auto revisedInputShapeType = RankedTensorType::get( - revisedInputShape, - input.getType().dyn_cast().getElementType()); - auto reshapedInput = rewriter - .create( - op.getLoc(), revisedInputShapeType, input, - rewriter.getI64ArrayAttr(revisedInputShape)) - .getResult(); - - // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC]. - llvm::SmallVector revisedWeightShape{weightShape[0], - weightShape[3]}; - auto revisedWeightShapeType = RankedTensorType::get( - revisedWeightShape, - weight.getType().dyn_cast().getElementType()); - auto reshapedWeight = rewriter - .create( - op.getLoc(), revisedWeightShapeType, weight, - rewriter.getI64ArrayAttr(revisedWeightShape)) - .getResult(); - - // Perform a fully connected network over the reshaped input and weight. - llvm::SmallVector fullyConnectedShape{ - inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]}; - auto fullyConnectedShapeType = RankedTensorType::get( - fullyConnectedShape, - resultType.dyn_cast().getElementType()); - - Value fullyConnectedValue; - if (op.quantization_info()) { - fullyConnectedValue = - rewriter - .create( - op.getLoc(), fullyConnectedShapeType, reshapedInput, - reshapedWeight, op.bias(), op.quantization_info().getValue()) - .getResult(); - } else { - fullyConnectedValue = rewriter - .create( - op.getLoc(), fullyConnectedShapeType, - reshapedInput, reshapedWeight, op.bias()) - .getResult(); - } - - // Reshape output to [N, IH, IW, OC]. - llvm::SmallVector outputShape{inputShape[0], inputShape[1], - inputShape[2], weightShape[0]}; - rewriter.replaceOpWithNewOp( - op, resultType, fullyConnectedValue, - rewriter.getI64ArrayAttr(outputShape)); - return success(); - } -}; - -void Conv2DOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -struct DepthwiseConv2DMulOptimization - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op, - PatternRewriter &rewriter) const override { - Value input = op.input(); - Value weight = op.weight(); - ShapedType inputType = input.getType().cast(); - ShapedType weightType = weight.getType().cast(); - ShapedType resultType = op.output().getType().cast(); - Type inputEType = inputType.getElementType(); - - if (!(inputType.hasStaticShape() && weightType.hasStaticShape() && - resultType.hasStaticShape())) { - return failure(); - } - - // Quantization information needs to still be performed. - if (op.quantization_info() || !inputEType.isa()) { - return failure(); - } - - // Stride must be 1 for this optimization. - for (Attribute stride : op.stride().getValue()) { - if (!stride.cast().getValue().isOne()) { - return failure(); - } - } - - // Only works for a 1x1 kernel. - ArrayRef weightShape = weightType.getShape(); - if (weightShape[0] != 1 || weightShape[1] != 1) { - return failure(); - } - - // Reshape input to [N, H, W, C] -> [N, H, W, C, 1]. - ArrayRef inputShape = inputType.getShape(); - llvm::SmallVector revisedInputShape{ - inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1}; - auto revisedInputShapeType = RankedTensorType::get( - revisedInputShape, - input.getType().dyn_cast().getElementType()); - auto reshapedInput = rewriter - .create( - op.getLoc(), revisedInputShapeType, input, - rewriter.getI64ArrayAttr(revisedInputShape)) - .getResult(); - - // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M]. - llvm::SmallVector revisedWeightShape{1, 1, 1, weightShape[2], - weightShape[3]}; - auto revisedWeightShapeType = RankedTensorType::get( - revisedWeightShape, - weight.getType().dyn_cast().getElementType()); - auto reshapedWeight = rewriter - .create( - op.getLoc(), revisedWeightShapeType, weight, - rewriter.getI64ArrayAttr(revisedWeightShape)) - .getResult(); - - // Perform an elementwise mul over the reshaped input and weight. - llvm::SmallVector mulShape{inputShape[0], inputShape[1], - inputShape[2], inputShape[3], - weightShape[3]}; - auto mulShapeType = RankedTensorType::get( - mulShape, - weight.getType().dyn_cast().getElementType()); - Value mulValue = - rewriter - .create(op.getLoc(), mulShapeType, reshapedInput, - reshapedWeight, /*shift=*/0) - .getResult(); - - // Reshape output to [N, H, W, C * M]. - auto outputShape = op.output().getType().cast().getShape(); - auto outputShapeType = RankedTensorType::get( - outputShape, - input.getType().dyn_cast().getElementType()); - auto outputValue = - rewriter.create(op.getLoc(), outputShapeType, mulValue, - rewriter.getI64ArrayAttr(outputShape)); - - // Add in the bias. - rewriter - .replaceOpWithNewOp(op, outputShapeType, outputValue, - op.bias()) - .getResult(); - return success(); - } -}; - -void DepthwiseConv2DOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - //===----------------------------------------------------------------------===// // Operator Folders. //===----------------------------------------------------------------------===// @@ -710,7 +521,8 @@ // TOSA Operator Verifiers. //===----------------------------------------------------------------------===// -template static LogicalResult verifyConvOp(T op) { +template +static LogicalResult verifyConvOp(T op) { // All TOSA conv ops have an input() and weight(). auto inputType = op.input().getType().template dyn_cast(); auto weightType = op.weight().getType().template dyn_cast(); diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ TosaDecomposeTransposeConv.cpp TosaInferShapes.cpp TosaMakeBroadcastable.cpp + TosaOptimization.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp @@ -0,0 +1,243 @@ +//===- TosaOptimization.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Pass to perform optimizations on TOSA operations +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DataFlowAnalysis.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/PassDetail.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/FormatVariadic.h" + +using namespace mlir; +using namespace mlir::tosa; + +#define PASS_NAME "tosa-optimization" +#define DEBUG_TYPE PASS_NAME + +namespace { + +struct Conv2DIsFullyConnected : public OpRewritePattern { + explicit Conv2DIsFullyConnected(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(tosa::Conv2DOp op, + PatternRewriter &rewriter) const { + Value input = op.input(); + Value weight = op.weight(); + ShapedType inputType = input.getType().cast(); + ShapedType weightType = weight.getType().cast(); + ShapedType resultType = op.getType().cast(); + + if (!inputType.hasStaticShape() || !weightType.hasRank()) { + return failure(); + } + + // Stride must be 1 for this optimization. + for (Attribute stride : op.stride().getValue()) { + if (!stride.cast().getValue().isOne()) { + return failure(); + } + } + + // Only works for a 1x1 kernel. + ArrayRef weightShape = weightType.getShape(); + if (weightShape[1] != 1 || weightShape[2] != 1) { + return failure(); + } + + // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC]. + ArrayRef inputShape = inputType.getShape(); + llvm::SmallVector revisedInputShape{ + inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]}; + auto revisedInputShapeType = RankedTensorType::get( + revisedInputShape, + input.getType().dyn_cast().getElementType()); + auto reshapedInput = rewriter + .create( + op.getLoc(), revisedInputShapeType, input, + rewriter.getI64ArrayAttr(revisedInputShape)) + .getResult(); + + // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC]. + llvm::SmallVector revisedWeightShape{weightShape[0], + weightShape[3]}; + auto revisedWeightShapeType = RankedTensorType::get( + revisedWeightShape, + weight.getType().dyn_cast().getElementType()); + auto reshapedWeight = rewriter + .create( + op.getLoc(), revisedWeightShapeType, weight, + rewriter.getI64ArrayAttr(revisedWeightShape)) + .getResult(); + + // Perform a fully connected network over the reshaped input and weight. + llvm::SmallVector fullyConnectedShape{ + inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]}; + auto fullyConnectedShapeType = RankedTensorType::get( + fullyConnectedShape, + resultType.dyn_cast().getElementType()); + + Value fullyConnectedValue; + if (op.quantization_info()) { + fullyConnectedValue = + rewriter + .create( + op.getLoc(), fullyConnectedShapeType, reshapedInput, + reshapedWeight, op.bias(), op.quantization_info().getValue()) + .getResult(); + } else { + fullyConnectedValue = rewriter + .create( + op.getLoc(), fullyConnectedShapeType, + reshapedInput, reshapedWeight, op.bias()) + .getResult(); + } + + // Reshape output to [N, IH, IW, OC]. + llvm::SmallVector outputShape{inputShape[0], inputShape[1], + inputShape[2], weightShape[0]}; + rewriter.replaceOpWithNewOp( + op, resultType, fullyConnectedValue, + rewriter.getI64ArrayAttr(outputShape)); + return success(); + } +}; + +struct DepthwiseConv2DIsMul : public OpRewritePattern { + explicit DepthwiseConv2DIsMul(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op, + PatternRewriter &rewriter) const { + Value input = op.input(); + Value weight = op.weight(); + ShapedType inputType = input.getType().cast(); + ShapedType weightType = weight.getType().cast(); + ShapedType resultType = op.output().getType().cast(); + Type inputEType = inputType.getElementType(); + + if (!(inputType.hasStaticShape() && weightType.hasStaticShape() && + resultType.hasStaticShape())) { + return failure(); + } + + // Quantization information needs to still be performed. + if (op.quantization_info() || !inputEType.isa()) { + return failure(); + } + + // Stride must be 1 for this optimization. + for (Attribute stride : op.stride().getValue()) { + if (!stride.cast().getValue().isOne()) { + return failure(); + } + } + + // Only works for a 1x1 kernel. + ArrayRef weightShape = weightType.getShape(); + if (weightShape[0] != 1 || weightShape[1] != 1) { + return failure(); + } + + // Reshape input to [N, H, W, C] -> [N, H, W, C, 1]. + ArrayRef inputShape = inputType.getShape(); + llvm::SmallVector revisedInputShape{ + inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1}; + auto revisedInputShapeType = RankedTensorType::get( + revisedInputShape, + input.getType().dyn_cast().getElementType()); + auto reshapedInput = rewriter + .create( + op.getLoc(), revisedInputShapeType, input, + rewriter.getI64ArrayAttr(revisedInputShape)) + .getResult(); + + // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M]. + llvm::SmallVector revisedWeightShape{1, 1, 1, weightShape[2], + weightShape[3]}; + auto revisedWeightShapeType = RankedTensorType::get( + revisedWeightShape, + weight.getType().dyn_cast().getElementType()); + auto reshapedWeight = rewriter + .create( + op.getLoc(), revisedWeightShapeType, weight, + rewriter.getI64ArrayAttr(revisedWeightShape)) + .getResult(); + + // Perform an elementwise mul over the reshaped input and weight. + llvm::SmallVector mulShape{inputShape[0], inputShape[1], + inputShape[2], inputShape[3], + weightShape[3]}; + auto mulShapeType = RankedTensorType::get( + mulShape, + weight.getType().dyn_cast().getElementType()); + Value mulValue = + rewriter + .create(op.getLoc(), mulShapeType, reshapedInput, + reshapedWeight, /*shift=*/0) + .getResult(); + + // Reshape output to [N, H, W, C * M]. + auto outputShape = op.output().getType().cast().getShape(); + auto outputShapeType = RankedTensorType::get( + outputShape, + input.getType().dyn_cast().getElementType()); + auto outputValue = + rewriter.create(op.getLoc(), outputShapeType, mulValue, + rewriter.getI64ArrayAttr(outputShape)); + + // Add in the bias. + rewriter + .replaceOpWithNewOp(op, outputShapeType, outputValue, + op.bias()) + .getResult(); + return success(); + } +}; + +class TosaOptimization : public PassWrapper { +public: + explicit TosaOptimization() {} + void runOnFunction() override; + + StringRef getArgument() const final { return PASS_NAME; } + StringRef getDescription() const final { + return "Applies TOSA Operation Optimizations"; + } +}; + +void TosaOptimization::runOnFunction() { + OwningRewritePatternList patterns(&getContext()); + + patterns.insert(&getContext()); + patterns.insert(&getContext()); + + auto func = getFunction(); + if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed()) { + signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr mlir::tosa::createTosaOptimizationPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -68,43 +68,6 @@ // ----- -// CHECK-LABEL: @conv2d_as_fully_connected -func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> { - // CHECK-NOT: "tosa.conv2d" - // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]} - // CHECK-SAME: -> tensor<400x2xf32> - // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]} - // CHECK-SAME: -> tensor<3x2xf32> - // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2) - // CHECK-SAME: -> tensor<400x3xf32> - // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]} - // CHECK-SAME: -> tensor<4x10x10x3xf32> - // CHECK: return %[[VAR3]] - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32> - return %0 : tensor<4x10x10x3xf32> -} - -// ----- - -// CHECK-LABEL: @conv2d_as_fully_connected_quant -func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> { - // CHECK-NOT: "tosa.conv2d" - // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]} - // CHECK-SAME: -> tensor<400x2xi8> - // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]} - // CHECK-SAME: -> tensor<3x2xi8> - // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2) - // CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32} - // CHECK-SAME: -> tensor<400x3xi32> - // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]} - // CHECK-SAME: -> tensor<4x10x10x3xi32> - // CHECK: return %[[VAR3]] - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32> - return %0 : tensor<4x10x10x3xi32> -} - -// ----- - // CHECK-LABEL: @conv2d_stride_2 func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> { // CHECK: "tosa.conv2d" @@ -127,35 +90,6 @@ // ----- -// CHECK-LABEL: @depthwise_conv2d_as_mul -func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> { - // CHECK-NOT: "tosa.depthwise_conv2d" - // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]} - // CHECK-SAME: -> tensor<4x10x10x2x1xf32> - // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]} - // CHECK-SAME: -> tensor<1x1x1x2x3xf32> - // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]]) - // CHECK-SAME: -> tensor<4x10x10x2x3xf32> - // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]} - // CHECK-SAME: -> tensor<4x10x10x6xf32> - // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2) - // CHECK-SAME: -> tensor<4x10x10x6xf32> - // CHECK: return %[[VAR4]] - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32> - return %0 : tensor<4x10x10x6xf32> -} - -// ----- - -// CHECK-LABEL: @depthwise_conv2d_as_mul_q -func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> { - // CHECK: "tosa.depthwise_conv2d" - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32> - return %0 : tensor<4x10x10x6xi32> -} - -// ----- - // CHECK-LABEL: @depthwise_conv2d_stride_2 func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> { // CHECK: "tosa.depthwise_conv2d" @@ -172,7 +106,7 @@ return %0 : tensor<4x10x10x6xf32> } -// ---- +// ----- // CHECK-LABEL: @pad_noop func @pad_noop(%arg0: tensor) -> tensor { @@ -182,7 +116,7 @@ return %1 : tensor } -// ---- +// ----- // CHECK-LABEL: @pad_determine_val_i32 func @pad_determine_val_i32(%arg0: tensor, %arg1 : tensor<2x2xi32>) -> tensor { @@ -193,7 +127,7 @@ return %1 : tensor } -// ---- +// ----- // CHECK-LABEL: @pad_determine_val_f32 func @pad_determine_val_f32(%arg0: tensor, %arg1 : tensor<2x2xi32>) -> tensor { @@ -204,7 +138,7 @@ return %1 : tensor } -// ---- +// ----- // CHECK-LABEL: @pad_determine_val_quant func @pad_determine_val_quant(%arg0: tensor, %arg1 : tensor<2x2xi32>) -> tensor { diff --git a/mlir/test/Dialect/Tosa/operation_optimization.mlir b/mlir/test/Dialect/Tosa/operation_optimization.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tosa/operation_optimization.mlir @@ -0,0 +1,69 @@ +// RUN: mlir-opt --split-input-file --tosa-optimization %s | FileCheck %s + +// ----- + +// CHECK-LABEL: @conv2d_as_fully_connected +func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> { + // CHECK-NOT: "tosa.conv2d" + // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]} + // CHECK-SAME: -> tensor<400x2xf32> + // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]} + // CHECK-SAME: -> tensor<3x2xf32> + // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2) + // CHECK-SAME: -> tensor<400x3xf32> + // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]} + // CHECK-SAME: -> tensor<4x10x10x3xf32> + // CHECK: return %[[VAR3]] + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32> + return %0 : tensor<4x10x10x3xf32> +} + +// ----- + +// CHECK-LABEL: @conv2d_as_fully_connected_quant +func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> { + // CHECK-NOT: "tosa.conv2d" + // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]} + // CHECK-SAME: -> tensor<400x2xi8> + // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]} + // CHECK-SAME: -> tensor<3x2xi8> + // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2) + // CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32} + // CHECK-SAME: -> tensor<400x3xi32> + // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]} + // CHECK-SAME: -> tensor<4x10x10x3xi32> + // CHECK: return %[[VAR3]] + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32> + return %0 : tensor<4x10x10x3xi32> +} + +// ----- + +// CHECK-LABEL: @depthwise_conv2d_as_mul +func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> { + // CHECK-NOT: "tosa.depthwise_conv2d" + // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]} + // CHECK-SAME: -> tensor<4x10x10x2x1xf32> + // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]} + // CHECK-SAME: -> tensor<1x1x1x2x3xf32> + // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]]) + // CHECK-SAME: -> tensor<4x10x10x2x3xf32> + // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]} + // CHECK-SAME: -> tensor<4x10x10x6xf32> + // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2) + // CHECK-SAME: -> tensor<4x10x10x6xf32> + // CHECK: return %[[VAR4]] + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32> + return %0 : tensor<4x10x10x6xf32> +} + +// ----- + +// CHECK-LABEL: @depthwise_conv2d_as_mul_q +func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> { + // CHECK: "tosa.depthwise_conv2d" + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32> + return %0 : tensor<4x10x10x6xi32> +} + +// -----