diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -34,6 +34,17 @@ } // namespace tosa } // namespace mlir +//===----------------------------------------------------------------------===// +// Utility Functions +//===----------------------------------------------------------------------===// +namespace mlir { +namespace tosa { +/// Appends the canonicalization patterns of all TOSA ops to `patterns`. +void populateTosaOpsCanonicalizationPatterns(MLIRContext *ctx, + RewritePatternSet &patterns); +} // namespace tosa +} // namespace mlir + #define GET_OP_CLASSES #include "mlir/Dialect/Tosa/IR/TosaOps.h.inc" diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -26,7 +26,10 @@ RewritePatternSet &patterns); void populateTosaDecomposeDepthwise(MLIRContext *ctx, RewritePatternSet &patterns); +void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx, + RewritePatternSet &patterns); +std::unique_ptr createTosaLayerwiseConstantFoldPass(); std::unique_ptr createTosaInferShapesPass(); std::unique_ptr createTosaMakeBroadcastablePass(); std::unique_ptr createTosaTestQuantUtilAPIPass(); diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -15,6 +15,15 @@ include "mlir/Pass/PassBase.td" +def TosaLayerwiseConstantFoldPass : Pass<"tosa-layerwise-constant-fold", "func::FuncOp"> { + let summary = "Fold layerwise operations on constant tensors"; + let description = [{ + Pass that enables folding of full-layer operations on constant tensors. 
+ }]; + + let constructor = "createTosaLayerwiseConstantFoldPass()"; +} + def TosaInferShapes : Pass<"tosa-infer-shapes", "func::FuncOp"> { let summary = "Propagate shapes across TOSA operations"; let description = [{ diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -76,6 +76,8 @@ pm.addNestedPass(tosa::createTosaMakeBroadcastablePass()); pm.addNestedPass(tosa::createTosaToLinalgNamed()); pm.addNestedPass(createCanonicalizerPass()); + // TODO: Remove pass that operates on const tensor and enable optionality + pm.addNestedPass(tosa::createTosaLayerwiseConstantFoldPass()); pm.addNestedPass(tosa::createTosaMakeBroadcastablePass()); pm.addNestedPass(tosa::createTosaToLinalg()); } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -94,6 +94,20 @@ // Operator Canonicalizers. 
//===----------------------------------------------------------------------===// +template +void addOpsCanonicalizations(MLIRContext *ctx, RewritePatternSet &patterns) { + (void)std::initializer_list{ + 0, (Args::getCanonicalizationPatterns(patterns, ctx), 0)...}; +} + +void mlir::tosa::populateTosaOpsCanonicalizationPatterns( + MLIRContext *ctx, RewritePatternSet &patterns) { + addOpsCanonicalizations< +#define GET_OP_LIST +#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" + >(ctx, patterns); +} + struct ConcatOptimization : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -189,70 +203,6 @@ return success(); } -struct ConstantTransposeOptimization - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tosa::TransposeOp op, - PatternRewriter &rewriter) const override { - auto outputType = op.getType().cast(); - ArrayRef outputShape = outputType.getShape(); - // TOSA supports quantized types. - if (!outputType.getElementType().isIntOrIndexOrFloat()) - return failure(); - - DenseElementsAttr inputValues; - if (!matchPattern(op.input1(), m_Constant(&inputValues))) - return failure(); - // Make sure the input is a constant that has a single user. - if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers())) - return failure(); - - DenseIntElementsAttr permAttr; - if (!matchPattern(op.perms(), m_Constant(&permAttr))) - return failure(); - auto permValues = llvm::to_vector<6>(llvm::map_range( - // TOSA allows both 32- and 64-bit integer tensors here. - permAttr.getValues(), - [](const APInt &val) { return val.getZExtValue(); })); - - auto inputType = op.input1().getType().cast(); - ArrayRef inputShape = inputType.getShape(); - int64_t numElements = inputType.getNumElements(); - - SmallVector outputValues; - outputValues.resize(numElements); - - // Transpose the input constant. 
Because we don't know its rank in advance, - // we need to loop over the range [0, element count) and delinearize the - // index. - auto attrValues = inputValues.getValues(); - for (int srcLinearIndex = 0; srcLinearIndex < numElements; - ++srcLinearIndex) { - SmallVector srcIndices(inputType.getRank(), 0); - int totalCount = srcLinearIndex; - for (int dim = inputType.getRank() - 1; dim >= 0; --dim) { - srcIndices[dim] = totalCount % inputShape[dim]; - totalCount /= inputShape[dim]; - } - - SmallVector dstIndices(outputType.getRank(), 0); - for (int dim = outputType.getRank() - 1; dim >= 0; --dim) - dstIndices[dim] = srcIndices[permValues[dim]]; - - uint64_t dstLinearIndex = dstIndices.front(); - for (int dim = 1; dim < outputType.getRank(); ++dim) - dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim]; - - outputValues[dstLinearIndex] = attrValues[srcIndices]; - } - - rewriter.replaceOpWithNewOp( - op, outputType, DenseElementsAttr::get(outputType, outputValues)); - return success(); - } -}; - struct NoOpOptimization : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -282,7 +232,6 @@ void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); results.add(context); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -2,7 +2,9 @@ TosaDecomposeTransposeConv.cpp TosaDecomposeConv2D.cpp TosaDecomposeDepthwise.cpp + TosaFoldConstantTranspose.cpp TosaInferShapes.cpp + TosaLayerwiseConstantFoldPass.cpp TosaMakeBroadcastable.cpp TosaOptionalDecompositions.cpp diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp +++ 
b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp @@ -1,4 +1,4 @@ -//===- TosaDecomposeConv2D.cpp ------------------------------------------===// +//===- TosaDecomposeConv2D.cpp --------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -1,5 +1,4 @@ -//===- TosaDecomposeDepthwise.cpp -//------------------------------------------===// +//===- TosaDecomposeDepthwise.cpp -----------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -1,5 +1,4 @@ -//===- TosaDecomposeTransposeConv.cpp -//------------------------------------------===// +//===- TosaDecomposeTransposeConv.cpp -------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp @@ -0,0 +1,91 @@ +//===- TosaFoldConstantTranspose.cpp --------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Fold TOSA Transpose operations whose input and perms are both constant +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +struct TosaFoldConstantTranspose : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::TransposeOp op, + PatternRewriter &rewriter) const override { + auto outputType = op.getType().cast(); + // TOSA supports quantized types; only int/index/float elements are folded. + if (!outputType.getElementType().isIntOrIndexOrFloat()) + return failure(); + + DenseElementsAttr inputValues; + if (!matchPattern(op.input1(), m_Constant(&inputValues))) + return failure(); + // Make sure the input is a constant that has a single user. + if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers())) + return failure(); + + DenseIntElementsAttr permAttr; + if (!matchPattern(op.perms(), m_Constant(&permAttr))) + return failure(); + auto permValues = llvm::to_vector<6>(llvm::map_range( + // TOSA allows both 32- and 64-bit integer tensors here. 
+ permAttr.getValues(), + [](const APInt &val) { return val.getZExtValue(); })); + + auto inputType = op.input1().getType().cast(); + ArrayRef inputShape = inputType.getShape(); + int64_t numElements = inputType.getNumElements(); + + SmallVector outputValues; + outputValues.resize(numElements); + + // Transpose the input constant. Because we don't know its rank in advance, + // we need to loop over the range [0, element count) and delinearize the + // index. + auto attrValues = inputValues.getValues(); + ArrayRef outputShape = outputType.getShape(); + for (int srcLinearIndex = 0; srcLinearIndex < numElements; + ++srcLinearIndex) { + SmallVector srcIndices(inputType.getRank(), 0); + int totalCount = srcLinearIndex; + for (int dim = inputType.getRank() - 1; dim >= 0; --dim) { + srcIndices[dim] = totalCount % inputShape[dim]; + totalCount /= inputShape[dim]; + } + + SmallVector dstIndices(outputType.getRank(), 0); + for (int dim = outputType.getRank() - 1; dim >= 0; --dim) + dstIndices[dim] = srcIndices[permValues[dim]]; + + uint64_t dstLinearIndex = dstIndices.front(); + for (int dim = 1; dim < outputType.getRank(); ++dim) + dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim]; + + outputValues[dstLinearIndex] = attrValues[srcIndices]; + } + + rewriter.replaceOpWithNewOp( + op, outputType, DenseElementsAttr::get(outputType, outputValues)); + return success(); + } +}; + +} // namespace + +void mlir::tosa::populateTosaFoldConstantTransposePatterns( + MLIRContext *ctx, RewritePatternSet &patterns) { + patterns.add(ctx); +} diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp @@ -1,4 +1,4 @@ -//===- TosaInferShapes.cpp ------------------------------------------===// +//===- TosaInferShapes.cpp ------------------------------------------------===// // // Part of the LLVM 
Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp copy from mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp copy to mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp @@ -1,5 +1,4 @@ -//===- TosaOptionalDecompositions.cpp -//------------------------------------------===// +//===- TosaLayerwiseConstantFoldPass.cpp ----------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,9 +6,7 @@ // //===----------------------------------------------------------------------===// // -// Pass to apply the Tosa operations decompositions -// exposed as populate functions in -// include/mlir/Dialect/Tosa/Transforms/Passes.h +// This file implements constant folding transformations on TOSA operations // //===----------------------------------------------------------------------===// @@ -20,19 +17,19 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; +using namespace mlir::tosa; namespace { -struct TosaOptionalDecompositions - : public TosaOptionalDecompositionsBase { +struct TosaLayerwiseConstantFoldPass + : public TosaLayerwiseConstantFoldPassBase { void runOnOperation() override { auto *ctx = &getContext(); RewritePatternSet patterns(ctx); auto func = getOperation(); - mlir::tosa::populateTosaDecomposeConv2D(ctx, patterns); - mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns); - mlir::tosa::populateTosaDecomposeDepthwise(ctx, patterns); + mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns); + 
mlir::tosa::populateTosaOpsCanonicalizationPatterns(ctx, patterns); if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed()) signalPassFailure(); @@ -41,6 +38,6 @@ } // namespace -std::unique_ptr mlir::tosa::createTosaOptionalDecompositions() { - return std::make_unique(); +std::unique_ptr mlir::tosa::createTosaLayerwiseConstantFoldPass() { + return std::make_unique(); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp --- a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp @@ -1,5 +1,4 @@ -//===- TosaOptionalDecompositions.cpp -//------------------------------------------===// +//===- TosaOptionalDecompositions.cpp -------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -391,104 +391,6 @@ return %0 : tensor<3x8xf32> } -// CHECK-LABEL: @transpose_fold -func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { - // CHECK: return %arg0 - %0 = arith.constant dense<[0, 1]> : tensor<2xi32> - %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<3x4xf32> - return %1 : tensor<3x4xf32> -} - -// CHECK-LABEL: @transpose_nofold -func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { - // CHECK: "tosa.transpose" - %0 = arith.constant dense<[1, 0]> : tensor<2xi32> - %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x3xf32>, tensor<2xi32>) -> tensor<3x3xf32> - return %1 : tensor<3x3xf32> -} - -// CHECK-LABEL: @transpose_nofold_shape -func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor { - // CHECK: "tosa.transpose" - %0 = arith.constant dense<[1, 0]> : tensor<2xi32> - %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor - return %1 : tensor -} - -// CHECK-LABEL: @transpose_fold_splat -func.func @transpose_fold_splat() -> tensor<3x2xf32> { - %input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK: %[[CST:.+]] = "tosa.const"() - // CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32> - %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> - // CHECK: return %[[CST]] - return %1 : tensor<3x2xf32> -} - -// CHECK-LABEL: @transpose_fold_2d_float -func.func @transpose_fold_2d_float() -> tensor<3x2xf32> { - %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %perms = 
"tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK: %[[CST:.+]] = "tosa.const"() - // CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32> - %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> - // CHECK: return %[[CST]] - return %1 : tensor<3x2xf32> -} - -// CHECK-LABEL: @transpose_fold_4d_int -func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> { - %input = "tosa.const"() {value = dense<[[ - [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], - [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]] - ]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32> - %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64> - // CHECK: %[[CST:.+]] = "tosa.const"() - // CHECK-SAME{LITERAL}: value = dense<[ - // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]], - // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]], - // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]] - // CHECK-SAME{LITERAL}: ]> - %1 = "tosa.transpose"(%input, %perms) : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32> - // CHECK: return %[[CST]] - return %1 : tensor<3x1x4x2xi32> -} - -// CHECK-LABEL: @transpose_nofold_non_cst_input -func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> { - %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK: tosa.transpose - %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> - return %1 : tensor<3x2xf32> -} - -// CHECK-LABEL: @transpose_nofold_non_cst_perms -func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> { - %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - // CHECK: tosa.transpose - %1 = 
"tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> - return %1 : tensor<3x2xf32> -} - -// CHECK-LABEL: @transpose_nofold_multi_users -func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) { - %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK: tosa.transpose - %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> - return %1, %input : tensor<3x2xf32>, tensor<2x3xf32> -} - -// CHECK-LABEL: @transpose_nofold_quantized_types -func.func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> { - %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32> - %input = "tosa.const"() {value = dense<[[[[-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]]]]> : tensor<1x1x1x16xi8>} : () -> tensor<1x1x1x16xi8> - // CHECK: tosa.transpose - %0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> - return %0: tensor<1x1x16x1x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> -} - // CHECK-LABEL: @transpose_no_op func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> 
tensor<3x4x5x6xf32> { // CHECK: return %arg0 diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -0,0 +1,99 @@ +// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s + +// CHECK-LABEL: @transpose_fold +func.func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { + // CHECK: return %arg0 + %0 = arith.constant dense<[0, 1]> : tensor<2xi32> + %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<3x4xf32> + return %1 : tensor<3x4xf32> +} + +// CHECK-LABEL: @transpose_nofold +func.func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { + // CHECK: "tosa.transpose" + %0 = arith.constant dense<[1, 0]> : tensor<2xi32> + %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x3xf32>, tensor<2xi32>) -> tensor<3x3xf32> + return %1 : tensor<3x3xf32> +} + +// CHECK-LABEL: @transpose_nofold_shape +func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor { + // CHECK: "tosa.transpose" + %0 = arith.constant dense<[1, 0]> : tensor<2xi32> + %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @transpose_fold_splat +func.func @transpose_fold_splat() -> tensor<3x2xf32> { + %input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: %[[CST:.+]] = "tosa.const"() + // CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32> + %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + // CHECK: return %[[CST]] + return %1 : tensor<3x2xf32> +} + +// CHECK-LABEL: @transpose_fold_2d_float +func.func @transpose_fold_2d_float() -> tensor<3x2xf32> { + %input = "tosa.const"() {value = 
dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: %[[CST:.+]] = "tosa.const"() + // CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32> + %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + // CHECK: return %[[CST]] + return %1 : tensor<3x2xf32> +} + +// CHECK-LABEL: @transpose_fold_4d_int +func.func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> { + %input = "tosa.const"() {value = dense<[[ + [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], + [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]] + ]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32> + %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64> + // CHECK: %[[CST:.+]] = "tosa.const"() + // CHECK-SAME{LITERAL}: value = dense<[ + // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]], + // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]], + // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]] + // CHECK-SAME{LITERAL}: ]> + %1 = "tosa.transpose"(%input, %perms) : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32> + // CHECK: return %[[CST]] + return %1 : tensor<3x1x4x2xi32> +} + +// CHECK-LABEL: @transpose_nofold_non_cst_input +func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> { + %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: tosa.transpose + %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> +} + +// CHECK-LABEL: @transpose_nofold_non_cst_perms +func.func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> { + %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> 
: tensor<2x3xf32>} : () -> tensor<2x3xf32> + // CHECK: tosa.transpose + %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> +} + +// CHECK-LABEL: @transpose_nofold_multi_users +func.func @transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) { + %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: tosa.transpose + %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + return %1, %input : tensor<3x2xf32>, tensor<2x3xf32> +} + +// CHECK-LABEL: @transpose_nofold_quantized_types +func.func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> { + %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32> + %input = "tosa.const"() {value = dense<[[[[-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]]]]> : tensor<1x1x1x16xi8>} : () -> tensor<1x1x1x16xi8> + // CHECK: tosa.transpose + %0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> + return %0: tensor<1x1x16x1x!quant.uniform:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> +}