diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -37,6 +37,8 @@
     there will be tools to lower from the ML frameworks into TOSA.
   }];
 
+  let dependentDialects = ["tensor::TensorDialect"];
+
   let cppNamespace = "mlir::tosa";
   let hasConstantMaterializer = 1;
 }
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_TOSA_IR_TOSAOPS_H
 
 #include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1227,6 +1227,8 @@
   let results = (outs
     Tosa_Tensor1Dto4D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1250,6 +1252,8 @@
   let results = (outs
     Tosa_Tensor1Dto4D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1273,6 +1277,8 @@
   let results = (outs
     Tosa_Tensor1Dto4D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1296,6 +1302,8 @@
   let results = (outs
     Tosa_Tensor1Dto4D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1319,6 +1327,8 @@
   let results = (outs
     Tosa_Tensor1Dto4D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1342,6 +1352,8 @@
   let results = (outs
     Tosa_Tensor1Dto4D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1371,6 +1383,8 @@
   let results = (outs
     Tosa_RankedTensor:$output
   );
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1415,6 +1429,7 @@
   }];
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 
   let arguments = (ins
     Tosa_Tensor:$input1,
@@ -1473,6 +1488,8 @@
   let results = (outs
     Tosa_Tensor1Dto6D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1495,6 +1512,8 @@
   let results = (outs
     Tosa_Tensor1Dto4D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1518,6 +1537,8 @@
   let results = (
     outs Tosa_Tensor1Dto6D:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1650,6 +1671,8 @@
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
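Editorial note on the `dependentDialects` change above: the ConcatOp canonicalization added in TosaOps.cpp below may materialize `tensor.cast` operations, so the TOSA dialect must transitively load the tensor dialect rather than rely on every client registering it. A minimal before/after sketch (not part of the patch; it reuses the types of the @concat_fold_cast test added further down):

  // Before --canonicalize: a single-operand tosa.concat whose result type
  // only relaxes the operand's static shape information.
  %0 = "tosa.concat"(%arg0) {axis = 0 : i64} : (tensor<?x1xf32>) -> tensor<?x?xf32>

  // After: the concat is gone and a tensor.cast takes its place, which is
  // exactly why tensor::TensorDialect becomes a dependent dialect.
  %0 = tensor.cast %arg0 : tensor<?x1xf32> to tensor<?x?xf32>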
"mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" @@ -106,19 +107,31 @@ // Operator Canonicalizers. //===----------------------------------------------------------------------===// -struct RemoveReshapeNoop : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ConcatOptimization : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(tosa::ReshapeOp op, + LogicalResult matchAndRewrite(tosa::ConcatOp op, PatternRewriter &rewriter) const override { - if (op.input1().getType() != op.getType()) + if (op.input1().size() != 1) return failure(); + if (op.input1().front().getType() != op.getType()) { + rewriter + .replaceOpWithNewOp(op, op.getType(), + op.input1().front()) + .getResult(); + return success(); + } - rewriter.replaceOp(op, op.input1()); + rewriter.replaceOp(op, op.input1().front()); return success(); } }; +void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + struct ReshapeReshapeOptimization : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -141,18 +154,88 @@ void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } //===----------------------------------------------------------------------===// // Operator Folders. //===----------------------------------------------------------------------===// +OpFoldResult CastOp::fold(ArrayRef operands) { + if (input().getType() == getType()) + return input(); + return {}; +} + OpFoldResult ConstOp::fold(ArrayRef operands) { assert(operands.empty() && "constant has no operands"); return valueAttr(); } +#define ReduceFolder(OP) \ + OpFoldResult OP::fold(ArrayRef operands) { \ + ShapedType inputTy = input().getType().cast(); \ + if (!inputTy.hasRank()) \ + return {}; \ + if (inputTy.getDimSize(axis()) == 1) \ + return input(); \ + return {}; \ + } + +ReduceFolder(ReduceAllOp) ReduceFolder(ReduceAnyOp) ReduceFolder(ReduceMaxOp) + ReduceFolder(ReduceMinOp) ReduceFolder(ReduceProdOp) + ReduceFolder(ReduceSumOp) +#undef ReduceFolder + + OpFoldResult ReshapeOp::fold(ArrayRef operands) { + auto inputTy = input1().getType().dyn_cast(); + auto outputTy = getType().dyn_cast(); + + if (!inputTy || !outputTy || inputTy != outputTy) + return {}; + return input1(); +} + +OpFoldResult SliceOp::fold(ArrayRef operands) { + auto inputTy = input().getType().dyn_cast(); + auto outputTy = getType().dyn_cast(); + + if (!inputTy || !outputTy || inputTy != outputTy) + return {}; + if (inputTy.hasStaticShape()) + return input(); + + return {}; +} + +OpFoldResult TileOp::fold(ArrayRef operands) { + bool allOnes = true; + for (Attribute val : multiples().getValue()) { + allOnes = allOnes && val.cast().getValue().getSExtValue() == 1; + } + + if (allOnes && input1().getType() == getType()) + return input1(); + return {}; +} + +OpFoldResult TransposeOp::fold(ArrayRef operands) { + if (!operands[1]) + return {}; + + DenseIntElementsAttr perms = operands[1].cast(); + + bool isRange = true; + for (auto it : llvm::enumerate(perms)) { + isRange = isRange && + it.value().getSExtValue() == static_cast(it.index()); + } + + if (isRange && input1().getType() == getType()) + return input1(); + return {}; +} + //===----------------------------------------------------------------------===// // TOSA Operator Verifiers. 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -237,13 +237,9 @@
   // CHECK: fptrunc
   %23 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf16>
 
-  // CHECK: linalg.generic
-  // CHECK: yield
-  %24 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xf32>
-
   // CHECK: linalg.generic
   // CHECK: divf
-  %25 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
+  %24 = "tosa.reciprocal"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
   return
 }
@@ -383,29 +379,25 @@
   // CHECK: trunci
   %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
 
-  // CHECK: linalg.generic
-  // CHECK: yield
-  %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
-
   // CHECK: linalg.generic
   // CHECK: sexti
-  %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
+  %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
 
   // CHECK: linalg.generic
   // CHECK: constant 0
   // CHECK: cmpi
-  %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
+  %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: sitofp
-  %24 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
+  %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: constant 0
   // CHECK: cmpi sgt
   // CHECK: subi
   // CHECK: select
-  %25 = "tosa.abs"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+  %24 = "tosa.abs"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
 
   return
 }
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -0,0 +1,240 @@
+// RUN: mlir-opt --canonicalize %s | FileCheck %s
+
+// CHECK-LABEL: @argmax_nofold
+func @argmax_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
+  // CHECK: "tosa.argmax"
+  %0 = "tosa.argmax"(%arg0) {axis = 0 : i64}: (tensor<?x1xf32>) -> tensor<?x1xi32>
+  return %0 : tensor<?x1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @cast_fold
+func @cast_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK: return %arg0
+  %0 = "tosa.cast"(%arg0) : (tensor<?x1xf32>) -> tensor<?x1xf32>
+  return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @cast_nofold
+func @cast_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xi32> {
+  // CHECK: "tosa.cast"
+  %0 = "tosa.cast"(%arg0) : (tensor<?x1xf32>) -> tensor<?x1xi32>
+  return %0 : tensor<?x1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @concat_fold
+func @concat_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK: return %arg0
+  %0 = "tosa.concat"(%arg0) {axis = 0 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @concat_fold_cast
+func @concat_fold_cast(%arg0: tensor<?x1xf32>) -> tensor<?x?xf32> {
+  // CHECK: %[[VAR0:.*]] = tensor.cast %arg0
+  // CHECK: return %[[VAR0]]
+  %0 = "tosa.concat"(%arg0) {axis = 0 : i64}: (tensor<?x1xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_all_fold
+func @reduce_all_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK: return %arg0
+  %0 = "tosa.reduce_all"(%arg0) {axis = 1 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_all_nofold
+func @reduce_all_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK: "tosa.reduce_all"
+  %0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_any_fold
+func @reduce_any_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK: return %arg0
+  %0 = "tosa.reduce_any"(%arg0) {axis = 1 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  return %0 : tensor<?x1xf32>
+}
+
+// -----
+// CHECK-LABEL: @reduce_any_nofold
+func @reduce_any_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK: "tosa.reduce_any"
+  %0 = "tosa.reduce_any"(%arg0) {axis = 0 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_max_fold
+func @reduce_max_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK: return %arg0
+  %0 = "tosa.reduce_max"(%arg0) {axis = 1 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_max_nofold
+func @reduce_max_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK: "tosa.reduce_max"
+  %0 = "tosa.reduce_max"(%arg0) {axis = 0 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_min_fold
+func @reduce_min_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK: return %arg0
+  %0 = "tosa.reduce_min"(%arg0) {axis = 1 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_min_nofold
+func @reduce_min_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK: "tosa.reduce_min"
+  %0 = "tosa.reduce_min"(%arg0) {axis = 0 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_prod_fold
+func @reduce_prod_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK: return %arg0
+  %0 = "tosa.reduce_prod"(%arg0) {axis = 1 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_prod_nofold
+func @reduce_prod_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK: "tosa.reduce_prod"
+  %0 = "tosa.reduce_prod"(%arg0) {axis = 0 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_sum_fold
+func @reduce_sum_fold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK: return %arg0
+  %0 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @reduce_sum_nofold
+func @reduce_sum_nofold(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
+  // CHECK: "tosa.reduce_sum"
+  %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64}: (tensor<?x1xf32>) -> tensor<?x1xf32>
+  return %0 : tensor<?x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @reshape_canonicalize
+func @reshape_canonicalize(%arg0: tensor<?x10xf32>) -> tensor<?x10xf32> {
+  // CHECK: return %arg0
+  %0 = "tosa.reshape"(%arg0) {new_shape = [-1, 10]}: (tensor<?x10xf32>) -> tensor<?x10xf32>
+  return %0 : tensor<?x10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @reshape_canonicalize_double
+func @reshape_canonicalize_double(%arg0: tensor<?x10xf32>) -> tensor<?x5xf32> {
+  // CHECK: %[[VAR0:.+]] = "tosa.reshape"(%arg0) {new_shape = [-1, 5]}
+  // CHECK: return %[[VAR0]]
+  %0 = "tosa.reshape"(%arg0) {new_shape = [5, -1]}: (tensor<?x10xf32>) -> tensor<5x?xf32>
+  %1 = "tosa.reshape"(%0) {new_shape = [-1, 5]}: (tensor<5x?xf32>) -> tensor<?x5xf32>
+  return %1 : tensor<?x5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_fold
+func @slice_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
+  // CHECK: return %arg0
+  %0 = "tosa.slice"(%arg0) { size = [3, 4], start = [0, 0]}: (tensor<3x4xf32>) -> tensor<3x4xf32>
+  return %0 : tensor<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_nofold
+func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
+  // CHECK: "tosa.slice"
+  %0 = "tosa.slice"(%arg0) { size = [3, 4], start = [0, 0]}: (tensor<?x4xf32>) -> tensor<?x4xf32>
+  return %0 : tensor<?x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @tile_fold
+func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
+  // CHECK: return %arg0
+  %0 = "tosa.tile"(%arg0) { multiples = [1, 1] }: (tensor<3x4xf32>) -> tensor<3x4xf32>
+  return %0 : tensor<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @tile_nofold
+func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
+  // CHECK: "tosa.tile"
+  %0 = "tosa.tile"(%arg0) { multiples = [1, 2] }: (tensor<3x4xf32>) -> tensor<3x8xf32>
+  return %0 : tensor<3x8xf32>
+}
+
+// -----
+// CHECK-LABEL: @transpose_fold
+func @transpose_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
+  // CHECK: return %arg0
+  %0 = constant dense<[0, 1]> : tensor<2xi32>
+  %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<3x4xf32>
+  return %1 : tensor<3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold
+func @transpose_nofold(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> {
+  // CHECK: "tosa.transpose"
+  %0 = constant dense<[1, 0]> : tensor<2xi32>
+  %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
+  return %1 : tensor<3x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_shape
+func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
+  // CHECK: "tosa.transpose"
+  %0 = constant dense<[0, 1]> : tensor<2xi32>
+  %1 = "tosa.transpose"(%arg0, %0) { perms = [1, 0] }: (tensor<3x4xf32>, tensor<2xi32>) -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}