diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -248,6 +248,11 @@
     APInt index;
     if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index)))
       return failure();
+    // Prevent out of bounds accesses. This can happen in invalid code that will
+    // never execute. The unsigned comparison also rejects negative indices,
+    // whose two's-complement bit pattern is a very large unsigned value.
+    if (index.uge(tensorFromElements->getNumOperands()))
+      return failure();
     rewriter.replaceOp(extract,
                        tensorFromElements.getOperand(index.getZExtValue()));
     return success();
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -122,6 +122,51 @@
 
 // -----
 
+// Ensure the canonicalization doesn't crash on a negative constant index.
+// CHECK-LABEL: func @extract_negative_from_tensor.from_elements
+func @extract_negative_from_tensor.from_elements(%element : index) -> index {
+  // CHECK-SAME: ([[ARG:%.*]]: index)
+  %c-1 = constant -1 : index
+  %tensor = tensor.from_elements %element : tensor<1xindex>
+  %extracted_element = tensor.extract %tensor[%c-1] : tensor<1xindex>
+  // CHECK: tensor.from_elements
+  // CHECK: %[[RESULT:.*]] = tensor.extract
+  // CHECK: return %[[RESULT]]
+  return %extracted_element : index
+}
+
+// -----
+
+// Ensure the canonicalization doesn't crash on an index one past the end.
+// CHECK-LABEL: func @extract_oob_from_tensor.from_elements
+func @extract_oob_from_tensor.from_elements(%element : index) -> index {
+  // CHECK-SAME: ([[ARG:%.*]]: index)
+  %c1 = constant 1 : index
+  %tensor = tensor.from_elements %element : tensor<1xindex>
+  %extracted_element = tensor.extract %tensor[%c1] : tensor<1xindex>
+  // CHECK: tensor.from_elements
+  // CHECK: %[[RESULT:.*]] = tensor.extract
+  // CHECK: return %[[RESULT]]
+  return %extracted_element : index
+}
+
+// -----
+
+// Ensure the canonicalization doesn't crash on an index beyond the end.
+// CHECK-LABEL: func @extract_oob2_from_tensor.from_elements
+func @extract_oob2_from_tensor.from_elements(%element : index) -> index {
+  // CHECK-SAME: ([[ARG:%.*]]: index)
+  %c2 = constant 2 : index
+  %tensor = tensor.from_elements %element : tensor<1xindex>
+  %extracted_element = tensor.extract %tensor[%c2] : tensor<1xindex>
+  // CHECK: tensor.from_elements
+  // CHECK: %[[RESULT:.*]] = tensor.extract
+  // CHECK: return %[[RESULT]]
+  return %extracted_element : index
+}
+
+// -----
+
 // CHECK-LABEL: func @extract_from_tensor.generate
 // CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
 func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index {