diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1556,6 +1556,7 @@
     Tosa_Tensor1Dto6D:$output
   );
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -519,6 +519,78 @@
   results.add(context);
 }
 
+/// Canonicalizes a `tosa.slice` whose input is produced by a `tosa.concat`:
+/// when the sliced region lies entirely inside one concat operand along the
+/// concat axis, the slice is rewritten to read from that operand directly,
+/// with its start offset rebased into the operand's coordinates. The match
+/// fails if the slice straddles several operands or if the concat result or a
+/// visited operand is not a statically shaped ranked tensor.
+struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    Value sliceInput = sliceOp.getInput();
+    auto concatOp = sliceInput.getDefiningOp<tosa::ConcatOp>();
+    if (!concatOp)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice input must be concat operation");
+
+    OperandRange inputs = concatOp.getInput1();
+    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
+    if (!concatType || !concatType.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice input must be a static ranked tensor");
+    int32_t axis = concatOp.getAxis();
+
+    // Mutable copy of the start offsets: the entry on the concat axis is
+    // rebased below as operands are walked.
+    llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
+    llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();
+
+    // Validate slice on the concatenated axis. Slicing along this
+    // axis should span only one of the inputs to the concatenate
+    // operation.
+    std::optional<Value> replaceWithSlice;
+    for (auto input : inputs) {
+      auto inputType = dyn_cast<RankedTensorType>(input.getType());
+      if (!inputType || !inputType.hasStaticShape())
+        return rewriter.notifyMatchFailure(
+            sliceOp, "concat input must be a static ranked tensor");
+
+      if (sliceStart[axis] >= 0 &&
+          (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
+        // Use the rebased offsets, not the original ones: the original start
+        // is relative to the concat result, so it addresses the wrong region
+        // for any operand other than the first.
+        replaceWithSlice =
+            rewriter
+                .create<tosa::SliceOp>(
+                    sliceOp.getLoc(), sliceOp.getType(), input,
+                    rewriter.getDenseI64ArrayAttr(sliceStart),
+                    rewriter.getDenseI64ArrayAttr(sliceSize))
+                .getResult();
+        break;
+      }
+      // Skip past this operand: make the start relative to the next one.
+      sliceStart[axis] -= inputType.getDimSize(axis);
+    }
+
+    if (!replaceWithSlice)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "corresponding concat input not found for slice");
+
+    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
+    return success();
+  }
+};
+
+/// Registers the canonicalization patterns for tosa.slice.
+void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                          MLIRContext *context) {
+  results.add<ConcatSliceOptimization>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -434,3 +434,65 @@
   %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = array, offset = array, border = array} : (tensor<1x15x13x1xi8>) -> tensor<1x15x13x1xi8>
   return %resize : tensor<1x15x13x1xi8>
 }
+
+// -----
+
+// Concat on the last axis (axis = 3): each slice is expected to be replaced
+// by one of the concat operands.
+// CHECK-LABEL: @canonicalize_concat_slice_final_axis
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12x1xf32>, %[[VAL_1:.*]]: tensor<1x12x12x1xf32>
+// CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
+func.func @canonicalize_concat_slice_final_axis(%arg0 : tensor<1x12x12x1xf32>, %arg1 : tensor<1x12x12x1xf32>) -> (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) {
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 3 : i64} : (tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>) -> tensor<1x12x12x2xf32>
+  %1 = "tosa.slice"(%0) {size = array, start = array} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32>
+  %2 = "tosa.slice"(%0) {size = array, start = array} : (tensor<1x12x12x2xf32>) -> tensor<1x12x12x1xf32>
+  return %1, %2 : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
+}
+
+// -----
+
+// Same expectation for a concat on a middle axis (axis = 1).
+// CHECK-LABEL: @canonicalize_concat_slice_middle_axis
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
+// CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12xf32>, tensor<1x12x12xf32>
+func.func @canonicalize_concat_slice_middle_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x12xf32>, tensor<1x12x12xf32>) {
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 1 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x24x12xf32>
+  %1 = "tosa.slice"(%0) {size = array, start = array} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32>
+  %2 = "tosa.slice"(%0) {size = array, start = array} : (tensor<1x24x12xf32>) -> tensor<1x12x12xf32>
+  return %1, %2 : tensor<1x12x12xf32>, tensor<1x12x12xf32>
+}
+
+// -----
+
+// Both slice results (15 and 20 elements on axis 2) are wider than either
+// 12-element concat operand, so no single operand can supply them; the concat
+// and the slices are expected to be left in place.
+// CHECK-LABEL: @canonicalize_cross_concat_inputs
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.concat"(%[[VAL_0]], %[[VAL_1]]) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_2]]) {size = array, start = array} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32>
+// CHECK: %[[VAL_4:.*]] = "tosa.slice"(%[[VAL_2]]) {size = array, start = array} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32>
+// CHECK: return %[[VAL_3]], %[[VAL_4]] : tensor<1x12x15xf32>, tensor<1x12x20xf32>
+func.func @canonicalize_cross_concat_inputs(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x12x15xf32>, tensor<1x12x20xf32>) {
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
+  %1 = "tosa.slice"(%0) {size = array, start = array} : (tensor<1x12x24xf32>) -> tensor<1x12x15xf32>
+  %2 = "tosa.slice"(%0) {size = array, start = array} : (tensor<1x12x24xf32>) -> tensor<1x12x20xf32>
+  return %1, %2 : tensor<1x12x15xf32>, tensor<1x12x20xf32>
+}
+
+// -----
+
+// The slice results are smaller than the operands on the non-concat axis 1;
+// each slice is expected to be rewritten to read from one concat operand
+// instead of from the concat result.
+// CHECK-LABEL: @canonicalize_concat_slice_on_non_concat_axis
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32>
+// CHECK: %[[VAL_2:.*]] = "tosa.slice"(%[[VAL_0]]) {size = array, start = array} : (tensor<1x12x12xf32>) -> tensor<1x6x12xf32>
+// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_1]]) {size = array, start = array} : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32>
+// CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x6x12xf32>, tensor<1x3x12xf32>
+func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x6x12xf32>, tensor<1x3x12xf32>) {
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>
+  %1 = "tosa.slice"(%0) {size = array, start = array} : (tensor<1x12x24xf32>) -> tensor<1x6x12xf32>
+  %2 = "tosa.slice"(%0) {size = array, start = array} : (tensor<1x12x24xf32>) -> tensor<1x3x12xf32>
+  return %1, %2 : tensor<1x6x12xf32>, tensor<1x3x12xf32>
+}