// Review notes for this patch. Text before the first "diff --git" header is normally ignored by git apply.
// - TosaOps.td: Tosa_ConcatOp replaces its DeclareOpInterfaceMethods trait with InferTensorType.
//   It also declares a static isCompatibleReturnTypes(TypeRange, TypeRange) in extraClassDeclaration.
// - TosaOps.cpp: isCompatibleReturnTypes accepts exactly one type on each side and compares the
//   pair with verifyCompatibleShape. This presumably checks shape only; element types are not
//   compared here. Confirm that this is intended.
// - TosaOps.cpp: inferReturnTypeComponents now reports a size mismatch on a non-axis dimension via
//   emitOptionalError(location, "Cannot concat tensors ... dimension ", i) instead of a bare failure().
//   It also passes the element type of the first operand into ShapedTypeComponents on both the
//   unranked and ranked return paths.
// - Tests: the mismatched-shape concat case is removed from tosa-infer-shapes.mlir (@test_concat_failure).
//   It is re-added to invalid.mlir as an expected-error case (@test_concat).
// NOTE(review): this copy looks damaged.
//   - Newlines are collapsed.
//   - Text between angle brackets appears stripped: "DeclareOpInterfaceMethods,",
//     "::std::optional location", ".cast()", "-> tensor {".
//   - The first hunk header (-1419,8 +1419,7) implies two removed lines for one added, but only
//     one "-" line is visible.
//   The patch will likely not apply as-is; regenerate it from the original commit rather than
//   hand-repairing it.
// NOTE(review): operands.getType()[0] assumes at least one operand. Confirm that the op definition
//   guarantees a non-empty input list.
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1419,8 +1419,7 @@ // Operator: concat //===----------------------------------------------------------------------===// def Tosa_ConcatOp : Tosa_Op<"concat", [ - DeclareOpInterfaceMethods, + InferTensorType, Pure]> { let summary = "Concatenates tensors along one dimension."; @@ -1439,6 +1438,12 @@ ); let hasCanonicalizer = 1; + + let extraClassDeclaration = [{ + /// Returns true when two result types are compatible for this op; + /// Method used by InferTypeOpInterface. + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -422,6 +422,12 @@ return success(); } +bool tosa::ConcatOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { + if (l.size() != r.size() || l.size() != 1) + return false; + return succeeded(verifyCompatibleShape(l[0], r[0])); +} + LogicalResult tosa::ConcatOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions, @@ -447,14 +453,17 @@ if (outputShape[i] == ShapedType::kDynamic) outputShape[i] = operandShape.getDimSize(i); if (outputShape[i] != operandShape.getDimSize(i)) - return failure(); + return emitOptionalError(location, + "Cannot concat tensors with different sizes" + " on the non-axis dimension ", + i); } hasRankedInput = true; } - + Type inputType = operands.getType()[0].cast().getElementType(); if (!hasRankedInput) { - inferredReturnShapes.push_back(ShapedTypeComponents()); + inferredReturnShapes.push_back(ShapedTypeComponents(inputType)); return success(); } @@ 
-475,7 +484,7 @@ outputShape[axis] = concatDimSize; - inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); + inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType)); return success(); } diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -36,4 +36,10 @@ return %0 : tensor<1x27x27x16xi8> } +// ----- +func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor { + // expected-error@+1 {{Cannot concat tensors with different sizes on the non-axis dimension 1}} + %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor + return %0 : tensor +} diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -491,16 +491,6 @@ // ----- -// CHECK-LABEL: @test_concat_failure -func.func @test_concat_failure(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> () { - // CHECK: "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor - %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor - - return -} - -// ----- - // CHECK-LABEL: @test_padding_no_const func.func @test_padding_no_const(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xi32>) -> () { // CHECK: "tosa.pad"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<2x2xi32>) -> tensor