diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -88,16 +88,22 @@
   LogicalResult matchAndRewrite(tosa::ReshapeOp op,
                                 PatternRewriter &rewriter) const override {
     Value input = op.getInput1();
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    ShapedType resultTy = op.getType().cast<ShapedType>();
     ArrayAttr newShape = op.getNewShape();
 
+    if (inputTy.getElementType() != resultTy.getElementType())
+      return rewriter.notifyMatchFailure(op, "element type does not match.");
+
     // Check if input is constant
     DenseElementsAttr inputAttr;
     if (!matchPattern(input, m_Constant(&inputAttr)))
-      return failure();
+      return rewriter.notifyMatchFailure(op, "Non-constant input.");
 
     // Check if has >1 consumer and is not splat
     if (!input.hasOneUse() && !inputAttr.isSplat())
-      return failure();
+      return rewriter.notifyMatchFailure(op,
+                                         "Used more than once or not-splat");
 
     // Grab the new shape
     SmallVector<int64_t> newShapeValues = llvm::to_vector<6>(
@@ -132,7 +138,7 @@
     return success();
   }
 
-struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
+struct TransposeNoOp : public OpRewritePattern<tosa::TransposeOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(tosa::TransposeOp op,
@@ -159,9 +165,60 @@
   }
 };
 
+// Determines the case when tosa.transpose is a tosa.reshape operation.
+struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    DenseIntElementsAttr permAttr;
+    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
+      return rewriter.notifyMatchFailure(op, "Non-constant permutation");
+
+    auto input = op.getInput1();
+    auto inputTy = input.getType().cast<ShapedType>();
+    if (!inputTy.hasRank())
+      return rewriter.notifyMatchFailure(op, "Unranked input.");
+
+    int64_t numDynDims = 0;
+    for (int i = 0; i < inputTy.getRank(); ++i)
+      if (inputTy.isDynamicDim(i))
+        numDynDims++;
+
+    if (numDynDims > 1)
+      return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
+
+    SmallVector<int64_t> permValues = llvm::to_vector<6>(
+        llvm::map_range(permAttr.getValues<APInt>(),
+                        [](const APInt &val) { return val.getSExtValue(); }));
+
+    SmallVector<int64_t> nonZeroPerms;
+    nonZeroPerms.reserve(permValues.size());
+    for (auto idx : permValues) {
+      auto sz = inputTy.getDimSize(idx);
+      if (sz != 1)
+        nonZeroPerms.push_back(idx);
+    }
+
+    for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
+      if (nonZeroPerms[i - 1] > nonZeroPerms[i])
+        return rewriter.notifyMatchFailure(op,
+                                           "Transpose changes memory layout.");
+
+    SmallVector<int64_t> newShape;
+    newShape.reserve(inputTy.getRank());
+    for (int i = 0, s = inputTy.getRank(); i < s; ++i)
+      newShape.push_back(inputTy.getDimSize(permValues[i]));
+
+    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+        op, op.getType(), op.getInput1(), rewriter.getI64ArrayAttr(newShape));
+    return success();
+  }
+};
+
 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<ConstantTransposeOptimization, NoOpOptimization>(context);
+  results.add<ConstantTransposeOptimization, TransposeNoOp, TransposeIsReshape>(
+      context);
 }
 
 struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
@@ -958,6 +1015,11 @@
   if (!operands[1])
     return {};
 
+  auto inputTy = getInput1().getType().cast<ShapedType>();
+  auto resultTy = getType().cast<ShapedType>();
+  if (inputTy.getElementType() != resultTy.getElementType())
+    return {};
+
   // Transposing splat values just means reshaping.
   if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
     if (input.isSplat())
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -400,6 +400,14 @@
   return %1 : tensor<3x4x5x6xf32>
 }
 
+// CHECK-LABEL: @transpose_is_reshape
+func.func @transpose_is_reshape(%arg0: tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32> {
+  // CHECK: "tosa.reshape"(%arg0) {new_shape = [1, 4, 1, 5]} : (tensor<1x4x5x1xf32>) -> tensor<1x4x1x5xf32>
+  %perms = "tosa.const"() {value = dense<[3, 1, 0, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+  %0 = "tosa.transpose"(%arg0, %perms) : (tensor<1x4x5x1xf32>, tensor<4xi32>) -> tensor<1x4x1x5xf32>
+  return %0 : tensor<1x4x1x5xf32>
+}
+
 // CHECK-LABEL: @single_bit_reshape
 // https://github.com/llvm/llvm-project/issues/55440
 func.func @single_bit_reshape() -> tensor<1xi1> {
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -90,12 +90,12 @@
 }
 
 // CHECK-LABEL: @transpose_nofold_quantized_types
-func.func @transpose_nofold_quantized_types() -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> {
+func.func @transpose_nofold_quantized_types() -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>> {
   %perms = "tosa.const"() {value = dense<[1, 2, 3, 0]> : tensor<4xi32>} : () -> tensor<4xi32>
-  %input = "tosa.const"() {value = dense<[[[[-127, 127, 127, -127, -127, -127, -127, -127, -127, 127, 127, 127, 127, 127, -127, 127]]]]> : tensor<1x1x1x16xi8>} : () -> tensor<1x1x1x16xi8>
+  %input = "tosa.const"() {value = dense<-127> : tensor<2x1x1x2xi8>} : () -> tensor<2x1x1x2xi8>
   // CHECK: tosa.transpose
-  %0 = "tosa.transpose"(%input, %perms) : (tensor<1x1x1x16xi8>, tensor<4xi32>) -> tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
-  return %0: tensor<1x1x16x1x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,2.100000e+00,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>
+  %0 = "tosa.transpose"(%input, %perms) : (tensor<2x1x1x2xi8>, tensor<4xi32>) -> tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
+  return %0: tensor<1x1x2x2x!quant.uniform<i8<-127:127>:f32:3, {1.000000e-01,1.000000e-01}>>
 }
 
 // -----