// Review preamble. Everything from the first "diff --git" header onward is the
// patch text, left exactly as received; only this preamble was added.
//
// What the patch does, per file:
//   * mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
//       Adds `let hasCanonicalizer = 1;` next to an existing `hasFolder = 1`.
//       The op header is outside the hunk; presumably this is the transpose op,
//       since the .cpp hunk defines TransposeOp::getCanonicalizationPatterns
//       -- TODO confirm.
//   * mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
//       - Adds ConstantTransposeOptimization. Its matchAndRewrite returns
//         failure() unless input1 matches m_Constant, the defining op of
//         input1 has exactly one user, and perms matches m_Constant. Otherwise
//         it visits every source linear index in [0, numElements),
//         delinearizes it against the input shape, picks destination indices
//         as srcIndices[permValues[dim]], re-linearizes against the output
//         shape, and stores the source element at that destination slot. The
//         op is then replaced through replaceOpWithNewOp with
//         DenseElementsAttr::get(outputType, outputValues); the new op kind is
//         not visible in this text, the added tests expect "tosa.const".
//       - Adds TransposeOp::getCanonicalizationPatterns, which calls
//         results.insert(context).
//       - Changes a folder whose header lies outside the hunk: after the
//         existing `if (!operands[1]) return {};`, a splat operand 0 is folded
//         to input.reshape(result type). The hand-written identity check over
//         perms is replaced by llvm::equal against llvm::seq(0, perms.size()),
//         still combined with `input1().getType() == getType()`.
//   * mlir/test/Dialect/Tosa/canonicalize.mlir
//       The RUN line gains --split-input-file. Six functions are added:
//       transpose_fold_splat, transpose_fold_2d_float, transpose_fold_4d_int,
//       transpose_nofold_non_cst_input, transpose_nofold_non_cst_perms and
//       transpose_nofold_multi_users.
//
// NOTE(review): this copy looks damaged in transit. Original line breaks are
//   gone and angle-bracket arguments appear to be missing (`OpRewritePattern`,
//   `.cast()`, `SmallVector`, `permAttr.getValues()`, `results.insert(context)`,
//   `replaceOpWithNewOp(`, `llvm::seq(0, ...)`, bare `tensor` in the tests).
//   It presumably will not apply or compile as-is; re-export from version
//   control before use.
// NOTE(review): srcLinearIndex and totalCount are `int` while numElements is
//   `int64_t`; confirm constants large enough to overflow `int` cannot reach
//   this pattern, or widen the indices.
// NOTE(review): the pattern reads perms with getZExtValue() while the folder
//   uses getSExtValue(), and no range/uniqueness check on permValues is
//   visible before `srcIndices[permValues[dim]]`; confirm the op verifier
//   guarantees a valid permutation of the right length.
// NOTE(review): no static-shape check on the result type is visible before
//   DenseElementsAttr::get(outputType, ...) in the pattern or input.reshape(...)
//   in the folder; confirm a dynamically shaped result cannot reach either
//   path, or add a guard.
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1534,6 +1534,7 @@ outs Tosa_Tensor1Dto6D:$output ); + let hasCanonicalizer = 1; let hasFolder = 1; } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -159,6 +159,71 @@ results.insert(context); } +struct ConstantTransposeOptimization + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::TransposeOp op, + PatternRewriter &rewriter) const override { + DenseElementsAttr inputValues; + if (!matchPattern(op.input1(), m_Constant(&inputValues))) + return failure(); + // Make sure the input is a constant that has a single user. + if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers())) + return failure(); + + DenseIntElementsAttr permAttr; + if (!matchPattern(op.perms(), m_Constant(&permAttr))) + return failure(); + auto permValues = llvm::to_vector<6>(llvm::map_range( + // TOSA allows both 32- and 64-bit integer tensors here. + permAttr.getValues(), + [](const APInt &val) { return val.getZExtValue(); })); + + auto inputType = op.input1().getType().cast(); + ArrayRef inputShape = inputType.getShape(); + int64_t numElements = inputType.getNumElements(); + + auto outputType = op.getType().cast(); + ArrayRef outputShape = outputType.getShape(); + + SmallVector outputValues; + outputValues.resize(numElements); + + // Transpose the input constant. Because we don't know its rank in advance, + // we need to loop over the range [0, element count) and delinearize the + // index. 
+ for (int srcLinearIndex = 0; srcLinearIndex < numElements; + ++srcLinearIndex) { + SmallVector srcIndices(inputType.getRank(), 0); + int totalCount = srcLinearIndex; + for (int dim = inputType.getRank() - 1; dim >= 0; --dim) { + srcIndices[dim] = totalCount % inputShape[dim]; + totalCount /= inputShape[dim]; + } + + SmallVector dstIndices(outputType.getRank(), 0); + for (int dim = outputType.getRank() - 1; dim >= 0; --dim) + dstIndices[dim] = srcIndices[permValues[dim]]; + + uint64_t dstLinearIndex = dstIndices.front(); + for (int dim = 1; dim < outputType.getRank(); ++dim) + dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim]; + + outputValues[dstLinearIndex] = inputValues.getValue(srcIndices); + } + + rewriter.replaceOpWithNewOp( + op, outputType, DenseElementsAttr::get(outputType, outputValues)); + return success(); + } +}; + +void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // Operator Folders. //===----------------------------------------------------------------------===// @@ -225,15 +290,18 @@ if (!operands[1]) return {}; - DenseIntElementsAttr perms = operands[1].cast(); - - bool isRange = true; - for (auto it : llvm::enumerate(perms)) { - isRange = isRange && - it.value().getSExtValue() == static_cast(it.index()); + // Transposing splat values just means reshaping. 
+ if (auto input = operands[0].dyn_cast_or_null()) { + if (input.isSplat()) + return input.reshape(getType().cast()); } - if (isRange && input1().getType() == getType()) + auto perms = llvm::to_vector<6>(llvm::map_range( + operands[1].cast().getValues(), + [](const APInt &val) { return val.getSExtValue(); })); + + if (llvm::equal(llvm::seq(0, perms.size()), perms) && + input1().getType() == getType()) return input1(); return {}; } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --canonicalize %s | FileCheck %s +// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s // CHECK-LABEL: @argmax_nofold func @argmax_nofold(%arg0: tensor) -> tensor { @@ -237,3 +237,80 @@ %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor return %1 : tensor } + +// ----- + +// CHECK-LABEL: @transpose_fold_splat +func @transpose_fold_splat() -> tensor<3x2xf32> { + %input = "tosa.const"() {value = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: %[[CST:.+]] = "tosa.const"() + // CHECK-SAME{LITERAL}: value = dense<4.000000e+00> : tensor<3x2xf32> + %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + // CHECK: return %[[CST]] + return %1 : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_fold_2d_float +func @transpose_fold_2d_float() -> tensor<3x2xf32> { + %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: %[[CST:.+]] = "tosa.const"() + // CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], 
[2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32> + %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + // CHECK: return %[[CST]] + return %1 : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_fold_4d_int +func @transpose_fold_4d_int() -> tensor<3x1x4x2xi32> { + %input = "tosa.const"() {value = dense<[[ + [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], + [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]] + ]]> : tensor<1x2x3x4xi32>} : () -> tensor<1x2x3x4xi32> + %perms = "tosa.const"() {value = dense<[2, 0, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64> + // CHECK: %[[CST:.+]] = "tosa.const"() + // CHECK-SAME{LITERAL}: value = dense<[ + // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]], + // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]], + // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]] + // CHECK-SAME{LITERAL}: ]> + %1 = "tosa.transpose"(%input, %perms) : (tensor<1x2x3x4xi32>, tensor<4xi64>) -> tensor<3x1x4x2xi32> + // CHECK: return %[[CST]] + return %1 : tensor<3x1x4x2xi32> +} + +// ----- + +// CHECK-LABEL: @transpose_nofold_non_cst_input +func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>) -> tensor<3x2xf32> { + %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: tosa.transpose + %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_nofold_non_cst_perms +func @transpose_nofold_non_cst_perms(%perms: tensor<2xi32>) -> tensor<3x2xf32> { + %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + // CHECK: tosa.transpose + %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_nofold_multi_users +func 
@transpose_nofold_multi_users() -> (tensor<3x2xf32>, tensor<2x3xf32>) { + %input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: tosa.transpose + %1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + return %1, %input : tensor<3x2xf32>, tensor<2x3xf32> +}