diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -900,17 +900,15 @@
       failed(verifyAtLeastNResults(op, 1)))
     return failure();
 
-  auto type = op->getResult(0).getType();
-  auto elementType = getElementTypeOrSelf(type);
-  for (auto resultType : llvm::drop_begin(op->getResultTypes())) {
-    if (getElementTypeOrSelf(resultType) != elementType ||
-        failed(verifyCompatibleShape(resultType, type)))
+  // All operand and result types must equal the type of the first result.
+  Type type = op->getResult(0).getType();
+  for (Type resultType : llvm::drop_begin(op->getResultTypes())) {
+    if (type != resultType)
       return op->emitOpError()
              << "requires the same type for all operands and results";
   }
-  for (auto opType : op->getOperandTypes()) {
-    if (getElementTypeOrSelf(opType) != elementType ||
-        failed(verifyCompatibleShape(opType, type)))
+  for (Type opType : op->getOperandTypes()) {
+    if (type != opType)
       return op->emitOpError()
              << "requires the same type for all operands and results";
   }
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -158,8 +158,6 @@
 func.func @succeededSameOperandAndResultType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>, %t1d: tensor<?xf32>, %i32 : i32) {
   "test.same_operand_and_result_type"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
   "test.same_operand_and_result_type"(%t10x10, %t10x10) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
-  "test.same_operand_and_result_type"(%t1, %tr) : (tensor<1xf32>, tensor<*xf32>) -> tensor<1xf32>
-  "test.same_operand_and_result_type"(%t1, %t1d) : (tensor<1xf32>, tensor<?xf32>) -> tensor<1xf32>
   "test.same_operand_and_result_type"(%i32, %i32) : (i32, i32) -> i32
   return
 }
@@ -174,6 +172,24 @@
 
 // -----
 
+// An unranked operand type is rejected when the result is statically shaped.
+func.func @failedSameOperandAndResultType_operand_result_mismatch(%t10 : tensor<10xf32>, %tr: tensor<*xf32>) {
+  // expected-error@+1 {{requires the same type for all operands and results}}
+  "test.same_operand_and_result_type"(%t10, %tr) : (tensor<10xf32>, tensor<*xf32>) -> tensor<10xf32>
+  return
+}
+
+// -----
+
+// A dynamically shaped operand type is rejected when the result is statically shaped.
+func.func @failedSameOperandAndResultType_operand_result_mismatch(%t10 : tensor<10xf32>, %t1d: tensor<?xf32>) {
+  // expected-error@+1 {{requires the same type for all operands and results}}
+  "test.same_operand_and_result_type"(%t10, %t1d) : (tensor<10xf32>, tensor<?xf32>) -> tensor<10xf32>
+  return
+}
+
+// -----
+
 func.func @failedElementwiseMappable_different_rankedness(%arg0: tensor<?xf32>, %arg1: tensor<*xf32>) {
   // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
   %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<*xf32>) -> tensor<*xf32>