diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -248,6 +248,11 @@
     APInt index;
     if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
       return failure();
+    // Prevent out of bounds accesses. This can happen in invalid code that will
+    // never execute.
+    if (tensorFromElements->getNumOperands() <= index.getZExtValue() ||
+        index.getSExtValue() < 0)
+      return failure();
     rewriter.replaceOp(extract,
                        tensorFromElements.getOperand(index.getZExtValue()));
     return success();
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -116,6 +116,36 @@
   %c0 = constant 0 : index
   %tensor = tensor.from_elements %element : tensor<1xindex>
   %extracted_element = tensor.extract %tensor[%c0] : tensor<1xindex>
   // CHECK: [[ARG]] : index
+  return %extracted_element : index
+}
+
+// -----
+
+// Ensure the optimization doesn't segfault from bad constants
+// CHECK-LABEL: func @extract_negative_from_tensor.from_elements
+func @extract_negative_from_tensor.from_elements(%element : index) -> index {
+  // CHECK-SAME: ([[ARG:%.*]]: index)
+  %c-1 = constant -1 : index
+  %tensor = tensor.from_elements %element : tensor<1xindex>
+  %extracted_element = tensor.extract %tensor[%c-1] : tensor<1xindex>
+  // CHECK: tensor.from_elements
+  // CHECK: %[[RESULT:.*]] = tensor.extract
+  // CHECK: return %[[RESULT]]
+  return %extracted_element : index
+}
+
+// -----
+
+// Ensure the optimization doesn't segfault from bad constants
+// CHECK-LABEL: func @extract_oob_from_tensor.from_elements
+func @extract_oob_from_tensor.from_elements(%element : index) -> index {
+  // CHECK-SAME: ([[ARG:%.*]]: index)
+  %c1 = constant 1 : index
+  %tensor = tensor.from_elements %element : tensor<1xindex>
+  %extracted_element = tensor.extract %tensor[%c1] : tensor<1xindex>
+  // CHECK: tensor.from_elements
+  // CHECK: %[[RESULT:.*]] = tensor.extract
+  // CHECK: return %[[RESULT]]
   return %extracted_element : index
 }