diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -1057,15 +1057,6 @@
   return success();
 }
 
-/// Checks if two ShapedTypes are the same, ignoring the element type.
-static bool areSameShapedTypeIgnoringElementType(ShapedType a, ShapedType b) {
-  if (a.getTypeID() != b.getTypeID())
-    return false;
-  if (!a.hasRank())
-    return !b.hasRank();
-  return a.getShape() == b.getShape();
-}
-
 LogicalResult OpTrait::impl::verifyElementwise(Operation *op) {
   auto isMappableType = [](Type type) {
     return type.isa<VectorType, TensorType>();
@@ -1097,8 +1088,11 @@
   auto mustMatchType = operandMappableTypes[0].cast<ShapedType>();
   for (auto type :
        llvm::concat<Type>(resultMappableTypes, operandMappableTypes)) {
-    if (!areSameShapedTypeIgnoringElementType(type.cast<ShapedType>(),
-                                              mustMatchType)) {
+    // Require the same base type as the first non-scalar operand (by TypeID)
+    // and a compatible shape, so a dynamic dimension may match a static one.
+    auto shapedTy = type.cast<ShapedType>();
+    if (shapedTy.getTypeID() != mustMatchType.getTypeID() ||
+        failed(verifyCompatibleShape(shapedTy, mustMatchType))) {
       return op->emitOpError() << "all non-scalar operands/results must have "
                                   "the same shape and base type: found "
                                << type << " and " << mustMatchType;
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -162,6 +162,7 @@
 func @failedSameOperandAndResultType_operand_result_mismatch(%t10 : tensor<10xf32>, %t20 : tensor<20xf32>) {
   // expected-error@+1 {{requires the same type for all operands and results}}
   "test.same_operand_and_result_type"(%t10, %t20) : (tensor<10xf32>, tensor<20xf32>) -> tensor<10xf32>
+  return
 }
 
 // -----
@@ -169,6 +170,7 @@
 func @failedElementwiseMappable_different_rankedness(%arg0: tensor<?xf32>, %arg1: tensor<*xf32>) {
   // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<*xf32>' and 'tensor<?xf32>'}}
   %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return
 }
 
 // -----
@@ -176,13 +178,17 @@
 func @failedElementwiseMappable_different_rank(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) {
   // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<?x?xf32>' and 'tensor<?xf32>'}}
   %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?xf32>
+  return
 }
 
 // -----
 
-func @failedElementwiseMappable_different_shape(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
-  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<5xf32>' and 'tensor<?xf32>'}}
-  %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor<?xf32>, tensor<5xf32>) -> tensor<?xf32>
+// A dynamic dimension is compatible with a static one, so no error is expected.
+func @elementwiseMappable_dynamic_shapes(%arg0: tensor<?xf32>,
+    %arg1: tensor<5xf32>) {
+  %0 = "test.elementwise_mappable"(%arg0, %arg1) :
+    (tensor<?xf32>, tensor<5xf32>) -> tensor<?xf32>
+  return
 }
 
 // -----