diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -406,13 +406,12 @@ if (sliceStart[axis] >= 0 && (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) { - replaceWithSlice = - rewriter - .create( - sliceOp.getLoc(), sliceOp.getType(), input, - rewriter.getDenseI64ArrayAttr(sliceOp.getStart()), - rewriter.getDenseI64ArrayAttr(sliceSize)) - .getResult(); + replaceWithSlice = rewriter + .create( + sliceOp.getLoc(), sliceOp.getType(), input, + rewriter.getDenseI64ArrayAttr(sliceStart), + rewriter.getDenseI64ArrayAttr(sliceSize)) + .getResult(); break; } sliceStart[axis] -= inputType.getDimSize(axis); diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -542,7 +542,7 @@ // CHECK-LABEL: @canonicalize_concat_slice_on_non_concat_axis // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12xf32>, %[[VAL_1:.*]]: tensor<1x12x12xf32> // CHECK: %[[VAL_2:.*]] = "tosa.slice"(%[[VAL_0]]) <{size = array, start = array}> : (tensor<1x12x12xf32>) -> tensor<1x6x12xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_1]]) <{size = array, start = array}> : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32> +// CHECK: %[[VAL_3:.*]] = "tosa.slice"(%[[VAL_1]]) <{size = array, start = array}> : (tensor<1x12x12xf32>) -> tensor<1x3x12xf32> // CHECK: return %[[VAL_2]], %[[VAL_3]] : tensor<1x6x12xf32>, tensor<1x3x12xf32> func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf32>, %arg1 : tensor<1x12x12xf32>) -> (tensor<1x6x12xf32>, tensor<1x3x12xf32>) { %0 = "tosa.concat"(%arg0, %arg1) {axis = 2 : i64} : (tensor<1x12x12xf32>, tensor<1x12x12xf32>) -> tensor<1x12x24xf32>