diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1019,7 +1019,7 @@ loc, outputTy.getShape(), outputTy.getElementType()); Value zeroTensor = rewriter.create(loc, initTensor, zero).getResult(0); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()}, ValueRange{zeroTensor}); return success(); diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -844,13 +844,13 @@ // CHECK-LABEL: @matmul -func @matmul(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) { +func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>, %arg2: tensor<1x6xf32>) -> (tensor<1x5x6xf32>) { // CHECK: [[C0:%.+]] = constant 0 - // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 6] - // CHECK: [[FILLED:%.+]] = linalg.fill([[INIT]], [[C0]]) : tensor<5x6xf32>, f32 -> tensor<5x6xf32> - // CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<5x3xf32>, tensor<3x6xf32>) outs([[FILLED]] : tensor<5x6xf32>) -> tensor<5x6xf32> - %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<5x3xf32>, tensor<3x6xf32>) -> (tensor<5x6xf32>) - return %0 : tensor<5x6xf32> + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 6] + // CHECK: [[FILLED:%.+]] = linalg.fill([[INIT]], [[C0]]) : tensor<1x5x6xf32>, f32 -> tensor<1x5x6xf32> + // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x6xf32>) outs([[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32> + %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x3xf32>, tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) + return %0 : tensor<1x5x6xf32> } // -----