diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -425,6 +425,9 @@
 bool tosa::ConcatOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
   if (l.size() != r.size() || l.size() != 1)
     return false;
+  // verifyCompatibleShape below ignores element types, so check them first.
+  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))
+    return false;
   return succeeded(verifyCompatibleShape(l[0], r[0]));
 }
 
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -46,6 +46,15 @@
 
 // -----
 
+// Result shape is dynamic (compatible with 3x2) but the element type differs from f32.
+func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
+  // expected-error@+1 {{'tosa.concat' op inferred type(s) 'tensor<3x2xf32>' are incompatible with return type(s) of operation 'tensor<?x?xi8>}}
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
+  return %0 : tensor<?x?xi8>
+}
+
+// -----
+
 func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> {
   // expected-error@+1 {{'tosa.pad' op padding of pad is not constant}}
   %0 = "tosa.pad"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32>