diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1111,14 +1111,50 @@
   return success();
 }
 
+/// Extract operands and shape from a tensor with dynamic extents.
+///
+/// Each dynamic extent defined by a non-negative integer constant is folded
+/// into `newShape`; every other dynamic extent (including negative
+/// constants, which cannot be encoded as a static dimension) stays dynamic
+/// and its value is appended to `newOperands`. `dynamicExtents` must hold
+/// exactly one value per dynamic dimension of `resultType`.
+static void operandsAndShape(TensorType resultType,
+                             Operation::operand_range dynamicExtents,
+                             SmallVectorImpl<Value> &newOperands,
+                             SmallVectorImpl<int64_t> &newShape) {
+  auto operandsIt = dynamicExtents.begin();
+  for (int64_t dim : resultType.getShape()) {
+    if (!ShapedType::isDynamic(dim)) {
+      newShape.push_back(dim);
+      continue;
+    }
+    APInt index;
+    // Keep negative constants dynamic: they cannot be encoded in a static
+    // shape, and one may appear mid-rewrite before the verifier runs again.
+    if (!matchPattern(*operandsIt, m_ConstantInt(&index)) ||
+        index.isNegative()) {
+      newShape.push_back(ShapedType::kDynamic);
+      newOperands.push_back(*operandsIt++);
+      continue;
+    }
+    newShape.push_back(index.getSExtValue());
+    operandsIt++;
+  }
+}
+
 LogicalResult GenerateOp::verify() {
   // Ensure that the tensor type has as many dynamic dimensions as are
   // specified by the operands.
   RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
   if (getNumOperands() != resultTy.getNumDynamicDims())
     return emitError("must have as many index operands as dynamic extents "
                      "in the result type");
-
+  // Reject dynamic extents that are known constants with a negative value.
+  for (Value extent : getDynamicExtents()) {
+    APInt size;
+    if (matchPattern(extent, m_ConstantInt(&size)) && size.isNegative())
+      return emitError("tensor dimensions must be non-negative");
+  }
   return success();
 }
 
@@ -1176,24 +1212,11 @@
     if (resultType.hasStaticShape())
       return failure();
 
-    SmallVector<Value, 4> newOperands;
-    SmallVector<int64_t, 4> newShape;
-    auto operandsIt = tensorFromElements.getDynamicExtents().begin();
-
-    for (int64_t dim : resultType.getShape()) {
-      if (!ShapedType::isDynamic(dim)) {
-        newShape.push_back(dim);
-        continue;
-      }
-      APInt index;
-      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
-        newShape.push_back(ShapedType::kDynamic);
-        newOperands.push_back(*operandsIt++);
-        continue;
-      }
-      newShape.push_back(index.getSExtValue());
-      operandsIt++;
-    }
+    Operation::operand_range dynamicExtents =
+        tensorFromElements.getDynamicExtents();
+    SmallVector<Value> newOperands;
+    SmallVector<int64_t> newShape;
+    operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
 
     if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
       return failure();
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -112,6 +112,20 @@
   } : tensor<?x3x?xf32>
   return %tnsr : tensor<?x3x?xf32>
 }
+
+// -----
+
+func.func @generate_negative_size() -> tensor<?x8xi32> {
+  %cst = arith.constant 0 : i32
+  %size = index.constant -128
+  // expected-error@+1 {{tensor dimensions must be non-negative}}
+  %tensor = tensor.generate %size {
+    ^bb0(%arg0: index, %arg1: index):
+      tensor.yield %cst : i32
+  } : tensor<?x8xi32>
+  return %tensor : tensor<?x8xi32>
+}
+
 // -----
 
 func.func @tensor.reshape_element_type_mismatch(