Review notes. Text before the first "diff --git" header is ignored by
git apply and by patch; it is not part of either file change.

What this patch does, as far as the hunks below show:
  * Traits.cpp renames areCompatibleShapes(shape1, shape2) to
    isCompatibleInferredReturnShape(inferred, existing). It also makes
    the per-dimension check asymmetric. Equal dims, or a -1 on either
    side, were already accepted; now an inferred dim of 1 is also
    accepted against any existing dim.
  * The call site in verifyCompatibleOperandBroadcast now passes
    (resultShape, actualSuffix). That is, the shape reported in the
    error as the "broadcasted operands's shapes" comes first, and the
    suffix of the declared result type comes second.
  * Traits.cpp changes getShapeString() from formatv to an
    llvm::interleave over the dims. It prints '?' for any dim where
    ShapedType::isDynamic(dim) holds and the number otherwise. Dims
    are joined with "x" and the result is wrapped in single quotes.
  * traits.mlir adds a new "// -----"-separated case. Its second
    operand is unranked (tensor<*xi32>) and is fed to
    "test.broadcastable".

NOTE(review): this copy of the patch looks damaged and will not apply
as-is.
  * All newlines are collapsed into two physical lines.
  * Text inside angle brackets appears to be stripped. For example,
    "ArrayRef shape", "t.isa()" and the bare "tensor" types in the new
    test are not valid C++ or MLIR on their own.
  * The template and shape arguments were presumably lost in an HTML
    round-trip.
TODO: regenerate the diff from the source tree rather than
hand-repairing it.

NOTE(review): isCompatible still tests for dynamic dims with the
literal -1, while getShapeString uses ShapedType::isDynamic. Consider
using ShapedType::isDynamic in both, so the two checks stay in sync if
the dynamic-dimension sentinel ever changes.

NOTE(review): because "dim1 == 1" is accepted, an inferred (broadcast)
dim of 1 passes against any declared result size in that position,
e.g. inferred 1 vs existing 5. Confirm this leniency is intended. The
added comment only explains the reverse case: existing 1 with inferred
greater than 1 is rejected.

NOTE(review): the new test reuses the name
@broadcast_tensor_tensor_tensor. Its visible hunk has no CHECK or
expected-error line. It presumably relies on the file being run with
-split-input-file and on the op verifying cleanly; confirm against the
file's RUN line.

diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp --- a/mlir/lib/Dialect/Traits.cpp +++ b/mlir/lib/Dialect/Traits.cpp @@ -192,14 +192,18 @@ llvm::any_of(types, [](Type t) { return t.isa(); })); } -static bool areCompatibleShapes(ArrayRef shape1, - ArrayRef shape2) { +static bool isCompatibleInferredReturnShape(ArrayRef inferred, + ArrayRef existing) { auto isCompatible = [](int64_t dim1, int64_t dim2) { - return dim1 == dim2 || dim1 == -1 || dim2 == -1; + // If the inferred and existing dim is the same, or one of them is unknown + // then it is compatible, else if the inferred dim is 1 then it is also + // compatible. But if the existing dim is 1 and the inferred is greater than + // 1 then flag. + return dim1 == dim2 || dim1 == -1 || dim2 == -1 || dim1 == 1; }; - if (shape1.size() != shape2.size()) + if (inferred.size() != existing.size()) return false; - for (auto p : llvm::zip(shape1, shape2)) + for (auto p : llvm::zip(inferred, existing)) if (!isCompatible(std::get<0>(p), std::get<1>(p))) return false; return true; @@ -208,8 +212,20 @@ static std::string getShapeString(ArrayRef shape) { // TODO: should replace with printing shape more uniformly across here and // when in type.
- return std::string( - formatv("'{0:$[x]}'", llvm::make_range(shape.begin(), shape.end()))); + std::string ret; + llvm::raw_string_ostream ss(ret); + ss << '\''; + llvm::interleave( + shape, ss, + [&](int64_t dim) { + if (ShapedType::isDynamic(dim)) + ss << '?'; + else + ss << dim; + }, + "x"); + ss << '\''; + return ss.str(); } LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) { @@ -252,7 +268,7 @@ for (auto type : rankedResults) { ArrayRef actualSuffix = getShape(type).take_back(resultShape.size()); - if (!areCompatibleShapes(actualSuffix, resultShape)) + if (!isCompatibleInferredReturnShape(resultShape, actualSuffix)) return op->emitOpError() << "result type " << getShapeString(getShape(type)) << " not broadcast compatible with broadcasted operands's shapes " diff --git a/mlir/test/Dialect/traits.mlir b/mlir/test/Dialect/traits.mlir --- a/mlir/test/Dialect/traits.mlir +++ b/mlir/test/Dialect/traits.mlir @@ -111,6 +111,13 @@ // ----- +func @broadcast_tensor_tensor_tensor(%arg0: tensor, %arg1: tensor<*xi32>) -> tensor { + %0 = "test.broadcastable"(%arg0, %arg1) : (tensor, tensor<*xi32>) -> tensor + return %0 : tensor +} + +// ----- + + // Unranked operands but ranked result func @broadcast_tensor_tensor_tensor(tensor<*xi32>, tensor<*xi32>) -> tensor<2xi32> { ^bb0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>):