diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1198,6 +1198,14 @@
     if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
       return failure();
 
+    // A negative constant extent cannot be folded into a static dimension:
+    // the only negative value a shape may hold is ShapedType::kDynamic, so
+    // building the RankedTensorType below would be invalid. Bail out instead.
+    for (int64_t newdim : newShape) {
+      if (newdim < 0 && !ShapedType::isDynamic(newdim))
+        return failure();
+    }
+
     auto loc = tensorFromElements.getLoc();
     auto newOp = rewriter.create<GenerateOp>(
         loc, RankedTensorType::get(newShape, resultType.getElementType()),
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -377,6 +377,21 @@
 
 // -----
 
+// A negative constant size must not be folded into the static result shape.
+// CHECK-LABEL: @generate_negative_size
+func.func @generate_negative_size() -> tensor<?x8xi32> {
+  %cst = arith.constant 0 : i32
+  %size = index.constant -128
+  // CHECK: tensor.generate
+  %tensor = tensor.generate %size {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : i32
+  } : tensor<?x8xi32>
+  return %tensor : tensor<?x8xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @from_elements.constant
 func.func @from_elements.constant() -> tensor<3xindex> {
   // CHECK: %[[CST:.*]] = arith.constant dense<[1, 2, 1]> : tensor<3xindex>