diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -450,34 +450,53 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
-  if (operands.size() == 1)
+  // A single operand is the result, unless its type differs from the result
+  // type: replacing the op would then require a cast, which is not a fold.
+  if (shapes().size() == 1) {
+    if (shapes().front().getType() != getType())
+      return nullptr;
     return shapes().front();
+  }
 
+  // If all but one of the operands are known to be empty shapes, return the
+  // remaining shape. A fold must preserve the result type exactly, so this
+  // only applies when the remaining operand already has the result type.
+  Value onlyPotentiallyNonEmptyShape;
+  for (auto it : llvm::zip(shapes(), operands)) {
+    Value shape = std::get<0>(it);
+    Attribute constShape = std::get<1>(it);
+    // A constant operand with no elements is a known empty shape.
+    if (constShape && constShape.cast<DenseElementsAttr>().size() == 0)
+      continue;
+    if (onlyPotentiallyNonEmptyShape) {
+      // A second potentially non-empty shape: this fold does not apply.
+      onlyPotentiallyNonEmptyShape = nullptr;
+      break;
+    }
+    onlyPotentiallyNonEmptyShape = shape;
+  }
+  if (onlyPotentiallyNonEmptyShape &&
+      onlyPotentiallyNonEmptyShape.getType() == getType())
+    return onlyPotentiallyNonEmptyShape;
+
   // TODO: Support folding with more than 2 input shapes
   if (shapes().size() > 2)
     return nullptr;
 
-  if (!operands[1])
-    return nullptr;
-
-  auto rhsShape = llvm::to_vector<6>(
-      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  if (rhsShape.empty())
-    return shapes()[0];
-
-  if (!operands[0])
+  // Constant-fold two known shapes to their broadcasted shape.
+  if (!operands[0] || !operands[1])
     return nullptr;
-
   auto lhsShape = llvm::to_vector<6>(
       operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
-  if (lhsShape.empty())
-    return shapes()[1];
-
+  auto rhsShape = llvm::to_vector<6>(
+      operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
   SmallVector<int64_t, 6> resultShape;
+
   // If the shapes are not compatible, we can't fold it.
   // TODO: Fold to an "error".
   if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
     return nullptr;
+
   Builder builder(getContext());
   return builder.getIndexTensorAttr(resultShape);
 }
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -120,6 +120,46 @@
 
 // -----
 
+// All but one operands are known empty shapes.
+// CHECK-LABEL: @all_but_one_empty
+// CHECK-SAME: (%[[ARG:.*]]: !shape.shape)
+func @all_but_one_empty(%arg0 : !shape.shape) -> !shape.shape {
+  // CHECK: return %[[ARG]]
+  %0 = shape.const_shape [] : !shape.shape
+  %1 = constant dense<[]> : tensor<0xindex>
+  %2 = shape.broadcast %0, %arg0, %1, %0 : !shape.shape, !shape.shape,
+      tensor<0xindex>, !shape.shape -> !shape.shape
+  return %2 : !shape.shape
+}
+
+// -----
+
+// All but one operands are known empty shapes, but the remaining operand's
+// type differs from the result type. Folding would need a cast. No folding.
+// CHECK-LABEL: @all_but_one_empty_type_mismatch
+func @all_but_one_empty_type_mismatch(%arg0 : tensor<?xindex>)
+    -> tensor<3xindex> {
+  // CHECK: %[[RES:.*]] = shape.broadcast
+  // CHECK: return %[[RES]]
+  %0 = constant dense<[]> : tensor<0xindex>
+  %1 = shape.broadcast %0, %arg0 : tensor<0xindex>, tensor<?xindex>
+      -> tensor<3xindex>
+  return %1 : tensor<3xindex>
+}
+
+// -----
+
+// Single operand whose type differs from the result type. No folding.
+// CHECK-LABEL: @single_operand_type_mismatch
+func @single_operand_type_mismatch(%arg0 : tensor<?xindex>) -> !shape.shape {
+  // CHECK: %[[RES:.*]] = shape.broadcast
+  // CHECK: return %[[RES]]
+  %0 = shape.broadcast %arg0 : tensor<?xindex> -> !shape.shape
+  return %0 : !shape.shape
+}
+
+// -----
+
 // Incompatible shapes. No folding.
 // CHECK-LABEL: func @f
 func @f() -> !shape.shape {