Lower tosa reduce ops to a single tensor.expand_shape.

The reduction result is expanded back to the operand rank directly, instead of
going through an op whose lowering produced tensor.collapse_shape followed by
tensor.expand_shape (see the removed CHECK lines).

Review notes:
 * The expandedDim step is guarded with `expandInputRank != 0`. For a rank-0
   reduction result reassociationMap is empty and `expandInputRank - 1` wraps
   around, so the unguarded indexing would be out of bounds.
 * NOTE(review): the template arguments in the C++ hunk (ReassociationExprs, 4;
   ShapedType; tosa::ReshapeOp; tensor::ExpandShapeOp) were missing from the
   copy under review and are reconstructed here. Tensor types in the test hunks
   that appear as a bare `tensor` are still elided. Confirm both, and the
   context-line whitespace, against upstream before applying.

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -815,8 +815,31 @@
   if (!didEncounterError)
     return failure();
 
-  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, resultTy,
-                                               linalgOp.getResults());
+  // Re-insert the reduced axis as a unit dimension with one
+  // tensor.expand_shape. Source dimension i maps to result dimension i, or
+  // i + 1 when it lies past the reduced axis.
+  SmallVector<ReassociationExprs, 4> reassociationMap;
+  uint64_t expandInputRank =
+      linalgOp.getResults()[0].getType().cast<ShapedType>().getRank();
+  reassociationMap.resize(expandInputRank);
+
+  for (uint64_t i = 0; i < expandInputRank; i++) {
+    int32_t dimToPush = i > axis ? i + 1 : i;
+    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
+  }
+
+  // The unit dimension joins the group of source dimension axis, or of the
+  // last source dimension when axis is the trailing result dimension. A
+  // rank-0 reduction result has no group to extend (and expandInputRank - 1
+  // would wrap around), so its reassociation list stays empty.
+  if (expandInputRank != 0) {
+    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
+    reassociationMap[expandedDim].push_back(
+        rewriter.getAffineDimExpr(expandedDim + 1));
+  }
+
+  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+      op, resultTy, linalgOp.getResults()[0], reassociationMap);
   return success();
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -746,8 +746,7 @@
   // CHECK: ^bb0(%arg1: f32, %arg2: f32)
   // CHECK: %[[RES:.+]] = arith.addf %arg1, %arg2 : f32
   // CHECK: linalg.yield %[[RES]] : f32
-  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor into tensor
-  // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1, 2]] : tensor into tensor
+  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor into tensor
   %0 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor) -> tensor
   return
 }
@@ -768,8 +767,7 @@
   // CHECK: ^bb0(%arg1: f32, %arg2: f32)
   // CHECK: %[[RES:.+]] = arith.mulf %arg1, %arg2 : f32
   // CHECK: linalg.yield %[[RES]] : f32
-  // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1]] : tensor<5x?xf32> into tensor
-  // CHECK: tensor.expand_shape %[[COLLAPSED]] {{\[}}[0, 1, 2]] : tensor into tensor<5x?x1xf32>
+  // CHECK: tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]] : tensor<5x?xf32> into tensor<5x?x1xf32>
   %0 = "tosa.reduce_prod"(%arg0) {axis = 2 : i64} : (tensor<5x?x4xf32>) -> tensor<5x?x1xf32>
   return
 }