Handle 0-d tensors in the extract(from_elements) rewrite and add a 0-d test.

* TensorOps.cpp: in the rewrite that handles an `extract` whose source is
  `tensorFromElements`, when `tensorType.getRank()` is 0 the extract is
  replaced directly with operand 0 of `tensorFromElements` and the pattern
  returns success. This happens before the index computation
  (`indices` / `flatIndex` / `stride`) runs.
* canonicalize.mlir: adds `extract_from_tensor.from_elements_0d`. It builds a
  0-d tensor from one `index` argument and extracts it with an empty index
  list (`%tensor[]`). The CHECK lines expect the captured function argument
  `[[ARG]]` to appear as the `index` result.

NOTE(review): this copy of the patch had every line collapsed onto one line.
The line breaks below were reconstructed so that each hunk matches its @@
header counts:
  * C++ hunk: 6 context lines + 4 added = 10.
  * Test hunk: 6 context lines + 12 added = 18.
The 18-line count implies that the `from_elements_3d` signature spans two
context lines. The leading indentation of the context lines is a best guess.
Confirm it against the target files before running `git apply`.

NOTE(review): several `<...>` spans appear to be missing from this copy. They
are left untouched below:
  * `getType().cast()` has no template argument.
  * `SmallVector indices(rank)` has no element type.
  * The test uses a bare `tensor` type where a 0-d tensor of `index` is
    expected (presumably `tensor<index>`).
Restore these from the original revision. As written, the C++ will not
compile and the .mlir test will not parse.

NOTE(review): `%c0` in the new test is never used, because `%tensor[]` takes
no indices. It looks copied from a neighbouring test and could be dropped.

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -402,6 +402,10 @@
       return failure();
     auto tensorType = tensorFromElements.getType().cast();
     auto rank = tensorType.getRank();
+    if (rank == 0) {
+      rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
+      return success();
+    }
     SmallVector indices(rank);
     int64_t flatIndex = 0;
     int64_t stride = 1;
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -135,6 +135,18 @@
 
 // -----
 
+// CHECK-LABEL: func @extract_from_tensor.from_elements_0d
+func @extract_from_tensor.from_elements_0d(%element : index) -> index {
+  // CHECK-SAME: ([[ARG:%.*]]: index)
+  %c0 = arith.constant 0 : index
+  %tensor = tensor.from_elements %element : tensor
+  %extracted_element = tensor.extract %tensor[] : tensor
+  // CHECK: [[ARG]] : index
+  return %extracted_element : index
+}
+
+// -----
+
 // CHECK-LABEL: func @extract_from_tensor.from_elements_3d
 func @extract_from_tensor.from_elements_3d()
     -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {