diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -645,6 +645,20 @@ let constructor = "tosa::createTosaToLinalg()"; } +//===----------------------------------------------------------------------===// +// TosaDecomposeRewrites +//===----------------------------------------------------------------------===// + +def TosaDecomposeRewrites : FunctionPass<"tosa-decompose-rewrites"> { + let summary = "Applies Tosa operations decomposition rewrites"; + let description = [{ + Pass to apply the Tosa operations decomposition rewrites + exposed as populate functions in include/mlir/Dialect/Tosa/Transforms/Passes.h + }]; + + let constructor = "tosa::createTosaDecomposeRewrites()"; +} + //===----------------------------------------------------------------------===// // TosaToLinalgNamed //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h --- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h +++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h @@ -21,6 +21,7 @@ std::unique_ptr createTosaToLinalg(); std::unique_ptr createTosaToLinalgNamed(); +std::unique_ptr createTosaDecomposeRewrites(); /// Populates passes to convert from TOSA to Linalg on buffers. At the end of /// the pass, the function will only contain linalg ops or standard ops if the diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -19,10 +19,16 @@ namespace mlir { namespace tosa { -std::unique_ptr createTosaDecomposeTransposeConvPass(); +// Expose Rewrite Functions that decompose TOSA Ops into further TOSA Ops.
+// The rewrites can be selectively added to a conversion pass. +void populateTosaDecomposeConv2D(MLIRContext *ctx, RewritePatternSet &patterns); +void populateTosaDecomposeTransposeConv(MLIRContext *ctx, + RewritePatternSet &patterns); +void populateTosaDecomposeDepthwise(MLIRContext *ctx, + RewritePatternSet &patterns); + std::unique_ptr createTosaInferShapesPass(); std::unique_ptr createTosaMakeBroadcastablePass(); -std::unique_ptr createTosaOptimizationPass(); std::unique_ptr createTosaTestQuantUtilAPIPass(); #define GEN_PASS_REGISTRATION diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -15,21 +15,6 @@ include "mlir/Pass/PassBase.td" -def TosaDecomposeTransposeConv : FunctionPass<"tosa-decompose-transpose-conv"> { - let summary = "Deompose transpose convolutiions into standard convolutions."; - let description = [{ - Pass that uses shape manipulation and convolution operations to transform - a transpose convolution into a regular convolution.
- }]; - - let constructor = "createTosaDecomposeTransposeConvPass()"; - let dependentDialects = [ - "StandardOpsDialect", - "tensor::TensorDialect", - "tosa::TosaDialect", - ]; -} - def TosaInferShapes : FunctionPass<"tosa-infer-shapes"> { let summary = "Propagate shapes across TOSA operations"; let description = [{ @@ -58,13 +43,4 @@ let constructor = "createTosaMakeBroadcastablePass()"; } -def TosaOptimization : FunctionPass<"tosa-optimization"> { - let summary = "TOSA operation optimizations"; - let description = [{ - "Pass to perform optimizations on TOSA operations" - }]; - - let constructor = "createTosaOptimizationPass()"; -} - #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -61,13 +61,35 @@ signalPassFailure(); } }; + +struct TosaDecomposeRewrites + : public TosaDecomposeRewritesBase { + void runOnFunction() override { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + auto func = getFunction(); + + mlir::tosa::populateTosaDecomposeConv2D(ctx, patterns); + mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns); + mlir::tosa::populateTosaDecomposeDepthwise(ctx, patterns); + + if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed()) + signalPassFailure(); + } +}; + } // namespace std::unique_ptr mlir::tosa::createTosaToLinalg() { return std::make_unique(); } +std::unique_ptr mlir::tosa::createTosaDecomposeRewrites() { + return std::make_unique(); +} + void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm) { + pm.addNestedPass(createTosaDecomposeRewrites()); pm.addNestedPass(createTosaMakeBroadcastablePass()); pm.addNestedPass(createTosaToLinalgNamed()); pm.addNestedPass(mlir::createCanonicalizerPass()); diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -1,8 +1,9 @@ add_mlir_dialect_library(MLIRTosaTransforms TosaDecomposeTransposeConv.cpp + TosaDecomposeConv2D.cpp + TosaDecomposeDepthwise.cpp TosaInferShapes.cpp TosaMakeBroadcastable.cpp - TosaOptimization.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp @@ -0,0 +1,123 @@ +//===- TosaDecomposeConv2D.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Decompose TOSA Conv2D operation to a series of TOSA Ops specifically +// (1) Convert a 1x1 Convolution to a Reshape->FC->Reshape +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +struct Conv2DIsFullyConnected : public OpRewritePattern { + explicit Conv2DIsFullyConnected(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(tosa::Conv2DOp op, + PatternRewriter &rewriter) const override { + Value input = op.input(); + Value weight = op.weight(); + ShapedType inputType = input.getType().cast(); + ShapedType weightType = weight.getType().cast(); + ShapedType resultType = op.getType().cast(); + + if (!inputType.hasStaticShape() ||
!weightType.hasRank()) { + return failure(); + } + + // Stride must be 1 for this optimization. + for (Attribute stride : op.stride().getValue()) { + if (!stride.cast().getValue().isOne()) { + return failure(); + } + } + + // Padding must be zero for this optimization; a padded 1x1 kernel + // produces a larger spatial output than the reshapes below assume. + for (Attribute pad : op.pad().getValue()) { + if (pad.cast<IntegerAttr>().getInt() != 0) { + return failure(); + } + } + + // Only works for a 1x1 kernel. + ArrayRef weightShape = weightType.getShape(); + if (weightShape[1] != 1 || weightShape[2] != 1) { + return failure(); + } + + // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC]. + ArrayRef inputShape = inputType.getShape(); + llvm::SmallVector revisedInputShape{ + inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]}; + auto revisedInputShapeType = RankedTensorType::get( + revisedInputShape, + input.getType().dyn_cast().getElementType()); + auto reshapedInput = rewriter + .create( + op.getLoc(), revisedInputShapeType, input, + rewriter.getI64ArrayAttr(revisedInputShape)) + .getResult(); + + // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC]. + llvm::SmallVector revisedWeightShape{weightShape[0], + weightShape[3]}; + auto revisedWeightShapeType = RankedTensorType::get( + revisedWeightShape, + weight.getType().dyn_cast().getElementType()); + auto reshapedWeight = rewriter + .create( + op.getLoc(), revisedWeightShapeType, weight, + rewriter.getI64ArrayAttr(revisedWeightShape)) + .getResult(); + + // Perform a fully connected network over the reshaped input and weight.
+ llvm::SmallVector fullyConnectedShape{ + inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]}; + auto fullyConnectedShapeType = RankedTensorType::get( + fullyConnectedShape, + resultType.dyn_cast().getElementType()); + + Value fullyConnectedValue; + if (op.quantization_info()) { + fullyConnectedValue = + rewriter + .create( + op.getLoc(), fullyConnectedShapeType, reshapedInput, + reshapedWeight, op.bias(), op.quantization_info().getValue()) + .getResult(); + } else { + fullyConnectedValue = rewriter + .create( + op.getLoc(), fullyConnectedShapeType, + reshapedInput, reshapedWeight, op.bias()) + .getResult(); + } + + // Reshape output to [N, IH, IW, OC]. + llvm::SmallVector outputShape{inputShape[0], inputShape[1], + inputShape[2], weightShape[0]}; + rewriter.replaceOpWithNewOp( + op, resultType, fullyConnectedValue, + rewriter.getI64ArrayAttr(outputShape)); + return success(); + } +}; + +} // namespace + +void mlir::tosa::populateTosaDecomposeConv2D(MLIRContext *ctx, + RewritePatternSet &patterns) { + patterns.insert(ctx); +} diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -0,0 +1,126 @@ +//===- TosaDecomposeDepthwise.cpp -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Decompose TOSA Depthwise operation to a series of TOSA Ops specifically +// (1) Convert a 1x1 Depthwise to Reshape -> Mul -> Reshape -> Add +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +struct DepthwiseConv2DIsMul : public OpRewritePattern { + explicit DepthwiseConv2DIsMul(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op, + PatternRewriter &rewriter) const override { + Value input = op.input(); + Value weight = op.weight(); + ShapedType inputType = input.getType().cast(); + ShapedType weightType = weight.getType().cast(); + ShapedType resultType = op.output().getType().cast(); + Type inputEType = inputType.getElementType(); + + if (!(inputType.hasStaticShape() && weightType.hasStaticShape() && + resultType.hasStaticShape())) { + return failure(); + } + + // Quantization information needs to still be performed. + if (op.quantization_info() || !inputEType.isa()) { + return failure(); + } + + // Stride must be 1 for this optimization. + for (Attribute stride : op.stride().getValue()) { + if (!stride.cast().getValue().isOne()) { + return failure(); + } + } + + // Padding must be zero for this optimization; a padded 1x1 kernel + // produces a larger spatial output than the reshapes below assume. + for (Attribute pad : op.pad().getValue()) { + if (pad.cast<IntegerAttr>().getInt() != 0) { + return failure(); + } + } + + // Only works for a 1x1 kernel. + ArrayRef weightShape = weightType.getShape(); + if (weightShape[0] != 1 || weightShape[1] != 1) { + return failure(); + } + + // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
+ ArrayRef inputShape = inputType.getShape(); + llvm::SmallVector revisedInputShape{ + inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1}; + auto revisedInputShapeType = RankedTensorType::get( + revisedInputShape, + input.getType().dyn_cast().getElementType()); + auto reshapedInput = rewriter + .create( + op.getLoc(), revisedInputShapeType, input, + rewriter.getI64ArrayAttr(revisedInputShape)) + .getResult(); + + // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M]. + llvm::SmallVector revisedWeightShape{1, 1, 1, weightShape[2], + weightShape[3]}; + auto revisedWeightShapeType = RankedTensorType::get( + revisedWeightShape, + weight.getType().dyn_cast().getElementType()); + auto reshapedWeight = rewriter + .create( + op.getLoc(), revisedWeightShapeType, weight, + rewriter.getI64ArrayAttr(revisedWeightShape)) + .getResult(); + + // Perform an elementwise mul over the reshaped input and weight. + llvm::SmallVector mulShape{inputShape[0], inputShape[1], + inputShape[2], inputShape[3], + weightShape[3]}; + auto mulShapeType = RankedTensorType::get( + mulShape, + weight.getType().dyn_cast().getElementType()); + Value mulValue = + rewriter + .create(op.getLoc(), mulShapeType, reshapedInput, + reshapedWeight, /*shift=*/0) + .getResult(); + + // Reshape output to [N, H, W, C * M]. + auto outputShape = op.output().getType().cast().getShape(); + auto outputShapeType = RankedTensorType::get( + outputShape, + input.getType().dyn_cast().getElementType()); + auto outputValue = + rewriter.create(op.getLoc(), outputShapeType, mulValue, + rewriter.getI64ArrayAttr(outputShape)); + + // Add in the bias.
+ rewriter.replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue, + op.bias()); + return success(); + } +}; + +} // namespace + +void mlir::tosa::populateTosaDecomposeDepthwise(MLIRContext *ctx, + RewritePatternSet &patterns) { + patterns.insert(ctx); +} diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -7,17 +7,19 @@ // //===----------------------------------------------------------------------===// // -// Insert reshape to binary op's input if needed to match rank +// Decompose TOSA TransposeConv operation to a series of TOSA Ops, specifically: +// (1) Convert a dilated TransposeConv2D to Conv2D, including +// reversing/reshaping etc. of the weights. +// (2) Convert a strided TransposeConv2D to Conv2D, including +// transposing/reversing/reshaping etc. of the weights and of the +// input/output tensors. // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Tosa/IR//TosaOps.h" -#include "mlir/Dialect/Tosa/Transforms/PassDetail.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace mlir::tosa; @@ -369,22 +371,10 @@ } }; -/// Pass that enables broadcast by making all input arrays have the same -/// number of dimensions.
Insert RESHAPE operations to lower rank operand -struct TosaDecomposeTransposeConv - : public TosaDecomposeTransposeConvBase { -public: - void runOnFunction() override { - auto func = getFunction(); - RewritePatternSet patterns(func.getContext()); - patterns - .insert( - func.getContext()); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); - } -}; } // namespace -std::unique_ptr mlir::tosa::createTosaDecomposeTransposeConvPass() { - return std::make_unique(); +void mlir::tosa::populateTosaDecomposeTransposeConv( + MLIRContext *ctx, RewritePatternSet &patterns) { + patterns.insert(ctx); + patterns.insert(ctx); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaOptimization.cpp +++ /dev/null @@ -1,243 +0,0 @@ -//===- TosaOptimization.cpp ------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Pass to perform optimizations on TOSA operations -// -//===----------------------------------------------------------------------===// - -#include "mlir/Analysis/DataFlowAnalysis.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Tosa/Transforms/PassDetail.h" -#include "mlir/Dialect/Tosa/Transforms/Passes.h" -#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Matchers.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/Support/FormatVariadic.h" - -using namespace mlir; -using namespace mlir::tosa; - -#define PASS_NAME "tosa-optimization" -#define DEBUG_TYPE PASS_NAME - -namespace { - -struct Conv2DIsFullyConnected : public OpRewritePattern { - explicit Conv2DIsFullyConnected(MLIRContext *context) - : OpRewritePattern(context) {} - - LogicalResult matchAndRewrite(tosa::Conv2DOp op, - PatternRewriter &rewriter) const override { - Value input = op.input(); - Value weight = op.weight(); - ShapedType inputType = input.getType().cast(); - ShapedType weightType = weight.getType().cast(); - ShapedType resultType = op.getType().cast(); - - if (!inputType.hasStaticShape() || !weightType.hasRank()) { - return failure(); - } - - // Stride must be 1 for this optimization. - for (Attribute stride : op.stride().getValue()) { - if (!stride.cast().getValue().isOne()) { - return failure(); - } - } - - // Only works for a 1x1 kernel.
- ArrayRef weightShape = weightType.getShape(); - if (weightShape[1] != 1 || weightShape[2] != 1) { - return failure(); - } - - // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC]. - ArrayRef inputShape = inputType.getShape(); - llvm::SmallVector revisedInputShape{ - inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]}; - auto revisedInputShapeType = RankedTensorType::get( - revisedInputShape, - input.getType().dyn_cast().getElementType()); - auto reshapedInput = rewriter - .create( - op.getLoc(), revisedInputShapeType, input, - rewriter.getI64ArrayAttr(revisedInputShape)) - .getResult(); - - // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC]. - llvm::SmallVector revisedWeightShape{weightShape[0], - weightShape[3]}; - auto revisedWeightShapeType = RankedTensorType::get( - revisedWeightShape, - weight.getType().dyn_cast().getElementType()); - auto reshapedWeight = rewriter - .create( - op.getLoc(), revisedWeightShapeType, weight, - rewriter.getI64ArrayAttr(revisedWeightShape)) - .getResult(); - - // Perform a fully connected network over the reshaped input and weight. - llvm::SmallVector fullyConnectedShape{ - inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]}; - auto fullyConnectedShapeType = RankedTensorType::get( - fullyConnectedShape, - resultType.dyn_cast().getElementType()); - - Value fullyConnectedValue; - if (op.quantization_info()) { - fullyConnectedValue = - rewriter - .create( - op.getLoc(), fullyConnectedShapeType, reshapedInput, - reshapedWeight, op.bias(), op.quantization_info().getValue()) - .getResult(); - } else { - fullyConnectedValue = rewriter - .create( - op.getLoc(), fullyConnectedShapeType, - reshapedInput, reshapedWeight, op.bias()) - .getResult(); - } - - // Reshape output to [N, IH, IW, OC].
- llvm::SmallVector outputShape{inputShape[0], inputShape[1], - inputShape[2], weightShape[0]}; - rewriter.replaceOpWithNewOp( - op, resultType, fullyConnectedValue, - rewriter.getI64ArrayAttr(outputShape)); - return success(); - } -}; - -struct DepthwiseConv2DIsMul : public OpRewritePattern { - explicit DepthwiseConv2DIsMul(MLIRContext *context) - : OpRewritePattern(context) {} - - LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op, - PatternRewriter &rewriter) const override { - Value input = op.input(); - Value weight = op.weight(); - ShapedType inputType = input.getType().cast(); - ShapedType weightType = weight.getType().cast(); - ShapedType resultType = op.output().getType().cast(); - Type inputEType = inputType.getElementType(); - - if (!(inputType.hasStaticShape() && weightType.hasStaticShape() && - resultType.hasStaticShape())) { - return failure(); - } - - // Quantization information needs to still be performed. - if (op.quantization_info() || !inputEType.isa()) { - return failure(); - } - - // Stride must be 1 for this optimization. - for (Attribute stride : op.stride().getValue()) { - if (!stride.cast().getValue().isOne()) { - return failure(); - } - } - - // Only works for a 1x1 kernel. - ArrayRef weightShape = weightType.getShape(); - if (weightShape[0] != 1 || weightShape[1] != 1) { - return failure(); - } - - // Reshape input to [N, H, W, C] -> [N, H, W, C, 1]. - ArrayRef inputShape = inputType.getShape(); - llvm::SmallVector revisedInputShape{ - inputShape[0], inputShape[1], inputShape[2], inputShape[3], 1}; - auto revisedInputShapeType = RankedTensorType::get( - revisedInputShape, - input.getType().dyn_cast().getElementType()); - auto reshapedInput = rewriter - .create( - op.getLoc(), revisedInputShapeType, input, - rewriter.getI64ArrayAttr(revisedInputShape)) - .getResult(); - - // Reshape kernel to [KH, KW, C, M] -> [1, 1, 1, C, M].
- llvm::SmallVector revisedWeightShape{1, 1, 1, weightShape[2], - weightShape[3]}; - auto revisedWeightShapeType = RankedTensorType::get( - revisedWeightShape, - weight.getType().dyn_cast().getElementType()); - auto reshapedWeight = rewriter - .create( - op.getLoc(), revisedWeightShapeType, weight, - rewriter.getI64ArrayAttr(revisedWeightShape)) - .getResult(); - - // Perform an elementwise mul over the reshaped input and weight. - llvm::SmallVector mulShape{inputShape[0], inputShape[1], - inputShape[2], inputShape[3], - weightShape[3]}; - auto mulShapeType = RankedTensorType::get( - mulShape, - weight.getType().dyn_cast().getElementType()); - Value mulValue = - rewriter - .create(op.getLoc(), mulShapeType, reshapedInput, - reshapedWeight, /*shift=*/0) - .getResult(); - - // Reshape output to [N, H, W, C * M]. - auto outputShape = op.output().getType().cast().getShape(); - auto outputShapeType = RankedTensorType::get( - outputShape, - input.getType().dyn_cast().getElementType()); - auto outputValue = - rewriter.create(op.getLoc(), outputShapeType, mulValue, - rewriter.getI64ArrayAttr(outputShape)); - - // Add in the bias.
- rewriter - .replaceOpWithNewOp(op, outputShapeType, outputValue, - op.bias()) - .getResult(); - return success(); - } -}; - -class TosaOptimization : public PassWrapper { -public: - explicit TosaOptimization() = default; - void runOnFunction() override; - - StringRef getArgument() const final { return PASS_NAME; } - StringRef getDescription() const final { - return "Applies TOSA Operation Optimizations"; - } -}; - -void TosaOptimization::runOnFunction() { - OwningRewritePatternList patterns(&getContext()); - - patterns.insert(&getContext()); - patterns.insert(&getContext()); - - auto func = getFunction(); - if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed()) { - signalPassFailure(); - } -} - -} // namespace - -std::unique_ptr mlir::tosa::createTosaOptimizationPass() { - return std::make_unique(); -} diff --git a/mlir/test/Dialect/Tosa/operation_optimization.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir rename from mlir/test/Dialect/Tosa/operation_optimization.mlir rename to mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir --- a/mlir/test/Dialect/Tosa/operation_optimization.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir @@ -1,69 +1,40 @@ -// RUN: mlir-opt --split-input-file --tosa-optimization %s | FileCheck %s - -// ----- - -// CHECK-LABEL: @conv2d_as_fully_connected -func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> { - // CHECK-NOT: "tosa.conv2d" - // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]} - // CHECK-SAME: -> tensor<400x2xf32> - // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]} - // CHECK-SAME: -> tensor<3x2xf32> - // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2) - // CHECK-SAME: -> tensor<400x3xf32> - // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]} - // CHECK-SAME: -> tensor<4x10x10x3xf32> - // CHECK: return %[[VAR3]] - %0 =
"tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32> - return %0 : tensor<4x10x10x3xf32> -} - -// ----- - -// CHECK-LABEL: @conv2d_as_fully_connected_quant -func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> { - // CHECK-NOT: "tosa.conv2d" - // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]} - // CHECK-SAME: -> tensor<400x2xi8> - // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]} - // CHECK-SAME: -> tensor<3x2xi8> - // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2) - // CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32} - // CHECK-SAME: -> tensor<400x3xi32> - // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]} - // CHECK-SAME: -> tensor<4x10x10x3xi32> - // CHECK: return %[[VAR3]] - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32> - return %0 : tensor<4x10x10x3xi32> -} - -// ----- - -// CHECK-LABEL: @depthwise_conv2d_as_mul -func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> { - // CHECK-NOT: "tosa.depthwise_conv2d" - // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]} - // CHECK-SAME: -> tensor<4x10x10x2x1xf32> - // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]} - // CHECK-SAME: -> tensor<1x1x1x2x3xf32> - // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]]) - // CHECK-SAME: -> tensor<4x10x10x2x3xf32> - // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]} - // CHECK-SAME: ->
tensor<4x10x10x6xf32> - // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2) - // CHECK-SAME: -> tensor<4x10x10x6xf32> - // CHECK: return %[[VAR4]] - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32> - return %0 : tensor<4x10x10x6xf32> -} - -// ----- - -// CHECK-LABEL: @depthwise_conv2d_as_mul_q -func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> { - // CHECK: "tosa.depthwise_conv2d" - %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32> - return %0 : tensor<4x10x10x6xi32> -} - -// ----- +// RUN: mlir-opt --split-input-file --tosa-decompose-test-pass %s | FileCheck %s + +// ----- + +// CHECK-LABEL: @conv2d_as_fully_connected +func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> { + // CHECK-NOT: "tosa.conv2d" + // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]} + // CHECK-SAME: -> tensor<400x2xf32> + // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]} + // CHECK-SAME: -> tensor<3x2xf32> + // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2) + // CHECK-SAME: -> tensor<400x3xf32> + // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]} + // CHECK-SAME: -> tensor<4x10x10x3xf32> + // CHECK: return %[[VAR3]] + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32> + return %0 : tensor<4x10x10x3xf32> +} + +// ----- + +// CHECK-LABEL:
@conv2d_as_fully_connected_quant +func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<3x1x1x2xi8>, %arg2: tensor<3xi32>) -> tensor<4x10x10x3xi32> { + // CHECK-NOT: "tosa.conv2d" + // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]} + // CHECK-SAME: -> tensor<400x2xi8> + // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]} + // CHECK-SAME: -> tensor<3x2xi8> + // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2) + // CHECK-SAME: quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32} + // CHECK-SAME: -> tensor<400x3xi32> + // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]} + // CHECK-SAME: -> tensor<4x10x10x3xi32> + // CHECK: return %[[VAR3]] + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 42 : i32, weight_zp = 24 : i32}} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32> + return %0 : tensor<4x10x10x3xi32> +} + +// ----- diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir @@ -0,0 +1,32 @@ +// RUN: mlir-opt --split-input-file --tosa-decompose-test-pass %s | FileCheck %s + +// ----- + +// CHECK-LABEL: @depthwise_conv2d_as_mul +func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> { + // CHECK-NOT: "tosa.depthwise_conv2d" + // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [4, 10, 10, 2, 1]} + // CHECK-SAME: -> tensor<4x10x10x2x1xf32> + // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [1, 1, 1, 2, 3]} + // CHECK-SAME: -> tensor<1x1x1x2x3xf32> + // CHECK: %[[VAR2:.*]] = "tosa.mul"(%[[VAR0]], %[[VAR1]]) + // CHECK-SAME: -> tensor<4x10x10x2x3xf32> + // CHECK: %[[VAR3:.*]] =
"tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 6]} + // CHECK-SAME: -> tensor<4x10x10x6xf32> + // CHECK: %[[VAR4:.*]] = "tosa.add"(%[[VAR3]], %arg2) + // CHECK-SAME: -> tensor<4x10x10x6xf32> + // CHECK: return %[[VAR4]] + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<6xf32>) -> tensor<4x10x10x6xf32> + return %0 : tensor<4x10x10x6xf32> +} + +// ----- + +// CHECK-LABEL: @depthwise_conv2d_as_mul_q +func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> { + // CHECK: "tosa.depthwise_conv2d" + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32> + return %0 : tensor<4x10x10x6xi32> +} + +// ----- diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir --- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --split-input-file --tosa-decompose-transpose-conv %s | FileCheck %s +// RUN: mlir-opt --split-input-file --tosa-decompose-test-pass %s | FileCheck %s // CHECK-LABEL: @transpose_conv2d func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<2x?x?x5xf32> { diff --git a/mlir/test/lib/Dialect/Tosa/CMakeLists.txt b/mlir/test/lib/Dialect/Tosa/CMakeLists.txt --- a/mlir/test/lib/Dialect/Tosa/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Tosa/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRTosaTestPasses TosaTestPasses.cpp + TosaTestDecomposePass.cpp EXCLUDE_FROM_LIBMLIR @@ -11,5 +12,7 @@ LINK_LIBS PUBLIC MLIRPass MLIRTosa + MLIRTosaTransforms + MLIRTransformUtils ) diff --git
a/mlir/test/lib/Dialect/Tosa/TosaTestDecomposePass.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestDecomposePass.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Tosa/TosaTestDecomposePass.cpp @@ -0,0 +1,53 @@ +//===- TosaTestDecomposePass.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Pass to apply the Tosa operations decomposition rewrites +// exposed as populate functions in +// include/mlir/Dialect/Tosa/Transforms/Passes.h +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define PASS_NAME "tosa-decompose-test-pass" + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +struct TosaTestDecompose : public PassWrapper { + StringRef getArgument() const final { return PASS_NAME; } + StringRef getDescription() const final { + return "TOSA Test Pass to verify the decomposition patterns"; + } + void runOnFunction() override; +}; + +void TosaTestDecompose::runOnFunction() { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + auto func = getFunction(); + + populateTosaDecomposeConv2D(ctx, patterns); + populateTosaDecomposeTransposeConv(ctx, patterns); + populateTosaDecomposeDepthwise(ctx, patterns); + + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); +} + +} // namespace + +namespace mlir { +void registerTosaTestDecomposePass() { PassRegistration(); } +} // namespace mlir diff --git
a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -52,6 +52,7 @@ void registerTestSpirvModuleCombinerPass(); void registerTestTraitsPass(); void registerTosaTestQuantUtilAPIPass(); +void registerTosaTestDecomposePass(); void registerVectorizerTestPass(); namespace test { @@ -142,6 +143,7 @@ registerTestTraitsPass(); registerVectorizerTestPass(); registerTosaTestQuantUtilAPIPass(); + registerTosaTestDecomposePass(); mlir::test::registerConvertCallOpPass(); mlir::test::registerInliner();