diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -70,6 +70,9 @@
   /// Returns the number of elements held by this attribute.
   int64_t size() const { return getNumElements(); }
 
+  /// Returns if the number of elements held by this attribute is 0.
+  bool empty() const { return size() == 0; }
+
   /// Generates a new ElementsAttr by mapping each int value to a new
   /// underlying APInt. The new values can represent either an integer or float.
   /// This ElementsAttr should contain integers.
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -534,8 +534,14 @@
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
     auto isPotentiallyNonEmptyShape = [](Value shape) {
-      if (auto constShape = shape.getDefiningOp<ConstShapeOp>())
-        return constShape.shape().size() != 0;
+      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
+        if (extentTensorTy.getDimSize(0) == 0)
+          return false;
+      }
+      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
+        if (constShape.shape().empty())
+          return false;
+      }
       return true;
     };
     auto newOperands = llvm::to_vector<8>(
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -641,13 +641,13 @@
 // -----
 // Empty shape arguments can be removed from broadcastable ops.
 // CHECK-LABEL: func @f
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<?xindex>)
-func @f(%arg0 : tensor<?xindex>, %arg1 : tensor<?xindex>) {
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<?xindex>, %{{.*}}: tensor<0xindex>)
+func @f(%arg0 : tensor<?xindex>, %arg1 : tensor<?xindex>, %arg2 : tensor<0xindex>) {
   // CHECK-NOT: const_shape
   // CHECK: cstr_broadcastable %[[ARG0]], %[[ARG1]] : tensor<?xindex>, tensor<?xindex>
   %0 = shape.const_shape [] : !shape.shape
-  %1 = shape.cstr_broadcastable %arg0, %arg1, %0
-      : tensor<?xindex>, tensor<?xindex>, !shape.shape
+  %1 = shape.cstr_broadcastable %arg0, %arg1, %0, %arg2
+      : tensor<?xindex>, tensor<?xindex>, !shape.shape, tensor<0xindex>
   "consume.witness"(%1) : (!shape.witness) -> ()
   return
 }
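
Note (illustrative, not part of the patch): with the added extent-tensor type check, a shape operand whose type is tensor<0xindex> is treated as statically empty even when it is not produced by shape.const_shape, e.g. when it is a block argument. A minimal sketch of the expected canonicalization, using hypothetical value names %a, %b, and %e:

  // Before -canonicalize:
  %w = shape.cstr_broadcastable %a, %b, %e
      : tensor<?xindex>, tensor<?xindex>, tensor<0xindex>

  // After -canonicalize, the statically empty operand is dropped:
  %w = shape.cstr_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>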