diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -358,7 +358,7 @@ return nullptr; // TODO: Support folding with more than 2 input shapes - if (operands.size() > 2 && !operands[2].isa()) + if (shapes().size() > 2) return nullptr; auto rhsShape = llvm::to_vector<6>(