diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -202,10 +202,16 @@
     // inferredDim existingDim Behavior
     // ----------- ----------- --------
     // dynamic     dynamic     OK
-    // dynamic     static      Error
+    // dynamic     static      OK^1
     // static      dynamic     OK
     // static      static      OK if equal
-    return ShapedType::isDynamic(existingDim) || inferredDim == existingDim;
+    //
+    // [1] This allows for implicit dynamic-to-static cast with undefined
+    // behavior if the resulting runtime dimension size does not match the given
+    // static size. This is meant to only flag known invalid cases while the
+    // expectation is that these should be equal dynamically.
+    return ShapedType::isDynamic(existingDim) ||
+           ShapedType::isDynamic(inferredDim) || inferredDim == existingDim;
   };
   if (inferred.size() != existing.size())
     return false;
diff --git a/mlir/test/Dialect/traits.mlir b/mlir/test/Dialect/traits.mlir
--- a/mlir/test/Dialect/traits.mlir
+++ b/mlir/test/Dialect/traits.mlir
@@ -111,9 +111,10 @@
 
 // -----
 
-// Error for inferred dynamic dimension but existing static dimensions
+// Error for inferred dynamic dimension but existing static dimensions; this
+// is now runtime undefined behavior if the sizes do not match.
 func.func @broadcast_tensor_tensor_tensor(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<2xi32> {
-  // expected-error @+1 {{op result type '2' not broadcast compatible with broadcasted operands's shapes '?'}}
+  // not-error @+1 {{op result type '2' not broadcast compatible with broadcasted operands's shapes '?'}}
   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<2xi32>
   return %0 : tensor<2xi32>
 }