diff --git a/mlir/include/mlir/Dialect/Traits.h b/mlir/include/mlir/Dialect/Traits.h
--- a/mlir/include/mlir/Dialect/Traits.h
+++ b/mlir/include/mlir/Dialect/Traits.h
@@ -47,7 +47,7 @@
 bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
                          SmallVectorImpl<int64_t> &resultShape);
 
-/// Returns true if a broadcast between the 2 shapes is guaranteed to be
+/// Returns true if a broadcast between n shapes is guaranteed to be
 /// successful and not result in an error. False does not guarantee that the
 /// shapes are not broadcastable; it might guarantee that they are not
 /// broadcastable or it might mean that this function does not have enough
@@ -59,6 +59,7 @@
 /// dimension, while this function will return false because it's possible for
 /// both shapes to have a dimension greater than 1 and different which would
 /// fail to broadcast.
+bool staticallyKnownBroadcastable(ArrayRef<SmallVector<int64_t, 6>> shapes);
 bool staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                   ArrayRef<int64_t> shape2);
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -491,37 +491,50 @@
 }
 
-OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
-  // TODO: Add folding for the nary case
-  if (operands.size() != 2)
-    return nullptr;
+// Returns true if exactly one element of `attributes` is not known to be a
+// scalar, i.e. is either null or has a non-zero number of elements. Returns
+// false if there is no such element or more than one.
+static bool hasSingleNonScalar(ArrayRef<Attribute> attributes) {
+  bool nonScalarSeen = false;
+  for (const auto &attribute : attributes) {
+    if (!attribute ||
+        attribute.cast<DenseIntElementsAttr>().getNumElements() != 0) {
+      if (nonScalarSeen)
+        return false;
+      nonScalarSeen = true;
+    }
+  }
+  return nonScalarSeen;
+}
 
-  // Both operands are not needed if one is a scalar.
-  if (operands[0] &&
-      operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
-    return BoolAttr::get(getContext(), true);
-  if (operands[1] &&
-      operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
+OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
+  // No broadcasting is needed if all operands but one are scalar.
+  if (hasSingleNonScalar(operands))
     return BoolAttr::get(getContext(), true);
 
-  if (operands[0] && operands[1]) {
-    auto lhsShape = llvm::to_vector<6>(
-        operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
-    auto rhsShape = llvm::to_vector<6>(
-        operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
-    SmallVector<int64_t, 6> resultShape;
-    if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
-      return BoolAttr::get(getContext(), true);
-  }
+  // If every operand attribute is available, check the extents directly.
+  if ([&] {
+        SmallVector<SmallVector<int64_t, 6>, 6> extents;
+        for (const auto &operand : operands) {
+          if (!operand)
+            return false;
+          extents.push_back(llvm::to_vector<6>(
+              operand.cast<DenseIntElementsAttr>().getValues<int64_t>()));
+        }
+        return OpTrait::util::staticallyKnownBroadcastable(extents);
+      }())
+    return BoolAttr::get(getContext(), true);
 
   // Lastly, see if folding can be completed based on what constraints are known
   // on the input shapes.
-  SmallVector<int64_t, 6> lhsShape, rhsShape;
-  if (failed(getShapeVec(shapes()[0], lhsShape)))
-    return nullptr;
-  if (failed(getShapeVec(shapes()[1], rhsShape)))
-    return nullptr;
-
-  if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
+  if ([&] {
+        SmallVector<SmallVector<int64_t, 6>, 6> extents;
+        for (const auto &shape : shapes()) {
+          extents.emplace_back();
+          if (failed(getShapeVec(shape, extents.back())))
+            return false;
+        }
+        return OpTrait::util::staticallyKnownBroadcastable(extents);
+      }())
     return BoolAttr::get(getContext(), true);
 
   // Because a failing witness result here represents an eventual assertion
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -15,19 +15,46 @@
 
 bool OpTrait::util::staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
                                                  ArrayRef<int64_t> shape2) {
-  // Two dimensions are compatible when
-  //   1. they are defined and equal, or
-  //   2. one of them is 1
-  return llvm::all_of(llvm::zip(llvm::reverse(shape1), llvm::reverse(shape2)),
-                      [](auto dimensions) {
-                        auto dim1 = std::get<0>(dimensions);
-                        auto dim2 = std::get<1>(dimensions);
-                        if (dim1 == 1 || dim2 == 1)
-                          return true;
-                        if (dim1 == dim2 && !ShapedType::isDynamic(dim1))
-                          return true;
-                        return false;
-                      });
+  SmallVector<SmallVector<int64_t, 6>, 2> extents;
+  extents.emplace_back(shape1.begin(), shape1.end());
+  extents.emplace_back(shape2.begin(), shape2.end());
+  return staticallyKnownBroadcastable(extents);
+}
+
+bool OpTrait::util::staticallyKnownBroadcastable(
+    ArrayRef<SmallVector<int64_t, 6>> shapes) {
+  // Start from zero rather than `shapes[0]` so that an empty list of shapes is
+  // handled without an out-of-bounds access; it is trivially broadcastable.
+  size_t maxRank = 0;
+  for (const auto &shape : shapes)
+    maxRank = std::max(maxRank, shape.size());
+
+  // We look backwards through every column of `shapes`.
+  for (size_t i = 0; i != maxRank; ++i) {
+    bool seenDynamic = false;
+    Optional<int64_t> other;
+    for (ArrayRef<int64_t> extent : shapes) {
+      int64_t dim = i >= extent.size() ? 1 : extent[extent.size() - i - 1];
+
+      if (dim == 1)
+        continue;
+
+      // Dimensions are compatible when
+      //   1. One is dynamic, the rest is 1.
+      if (ShapedType::isDynamic(dim)) {
+        if (seenDynamic || other)
+          return false;
+        seenDynamic = true;
+      }
+
+      //   2. All are 1 or a specific constant.
+      if (other && dim != *other)
+        return false;
+
+      other = dim;
+    }
+  }
+  return true;
 }
 
 bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -589,6 +589,92 @@
   return
 }
 
+// -----
+// Fold ternary broadcastable
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [8, 1] : !shape.shape
+  %cs1 = shape.const_shape [1, 8] : !shape.shape
+  %cs2 = shape.const_shape [1, 1] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// Fold ternary broadcastable with a dynamic extent
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [8, 1] : !shape.shape
+  %cs1 = shape.const_shape [1, -1] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs0, %cs1 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// A dynamic extent cannot be proven compatible with a static extent of 8
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK: shape.cstr_broadcastable
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [8, 1] : !shape.shape
+  %cs1 = shape.const_shape [1, 8] : !shape.shape
+  %cs2 = shape.const_shape [1, -1] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// Two dynamic extents in the same dimension cannot be folded at compile time
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK: shape.cstr_broadcastable
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [8, 1] : !shape.shape
+  %cs1 = shape.const_shape [1, -1] : !shape.shape
+  %cs2 = shape.const_shape [1, -1] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// Broadcastable with scalars and a non-scalar can be constant folded
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs0, %arg0 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+// One scalar and one non-scalar and one unknown cannot be folded.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) {
+  // CHECK: shape.cstr_broadcastable
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %cs0 = shape.const_shape [] : !shape.shape
+  %cs1 = shape.const_shape [2] : !shape.shape
+  %0 = shape.cstr_broadcastable %cs0, %cs1, %arg0 : !shape.shape, !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
 // -----
 // Fold `rank` based on constant shape.