# [mlir][LLVM] Fold llvm.extractvalue through unrelated llvm.insertvalue ops
#
# ExtractValueOp::fold walks the chain of insertvalue ops that define its
# container operand. For each insertvalue in the chain:
#   - Exact position match: fold to the inserted value (unchanged).
#   - One position is a prefix of the other: stop walking, as before. It now
#     returns `result` instead of `{}`, so an operand rewrite done by an
#     earlier iteration is still reported.
#   - Otherwise (new): rewire this op's container operand to the
#     insertvalue's own container via getContainerMutable().assign(...).
#     Then set `result = getResult()`, i.e. report the op's own result,
#     which MLIR treats as an in-place fold. Keep walking if that container
#     is itself defined by an insertvalue.
# After the loop, `result` is returned instead of `{}`.
#
# Test: fold_unrelated_extractvalue inserts %f0 at [0] of %arr and extracts
# [1]. After canonicalization, no insertvalue is expected before the
# extractvalue.
#
# NOTE(review): this copy of the patch appears mangled. Newlines are
# collapsed onto one line, and template arguments look stripped
# (`ArrayRef operands`; `getDefiningOp()` with no `<...>`, yet its result
# is used as an insertvalue op). It will not apply as-is; regenerate it from
# the original commit.
# NOTE(review): the new test only checks that some extractvalue survives.
# A CHECK that its container operand is %arr would pin the fold more
# tightly. The added blank `+` line just before `+}` (closing the preceding
# test) also looks unintended; confirm.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1341,6 +1341,7 @@ OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef operands) { auto insertValueOp = getContainer().getDefiningOp(); + OpFoldResult result = {}; while (insertValueOp) { if (getPosition() == insertValueOp.getPosition()) return insertValueOp.getValue(); @@ -1358,10 +1359,16 @@ // ``` if (getPosition().getValue().take_front(min) == insertValueOp.getPosition().getValue().take_front(min)) - return {}; + return result; + + // If neither a prefix, nor the exact position, we can extract out of the + // value being inserted into. Moreover, we can try again if that operand + // is itself an insertvalue expression. + getContainerMutable().assign(insertValueOp.getContainer()); + result = getResult(); insertValueOp = insertValueOp.getContainer().getDefiningOp(); } - return {}; + return result; } LogicalResult ExtractValueOp::verify() { diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir --- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir +++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir @@ -37,6 +37,18 @@ %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>> llvm.return %3 : f32 + +} +// ----- + +// CHECK-LABEL: fold_unrelated_extractvalue +llvm.func @fold_unrelated_extractvalue(%arr: !llvm.array<4xf32>) -> f32 { + %f0 = arith.constant 0.0 : f32 + // CHECK-NOT: insertvalue + // CHECK: extractvalue + %2 = llvm.insertvalue %f0, %arr[0] : !llvm.array<4xf32> + %3 = llvm.extractvalue %2[1] : !llvm.array<4xf32> + llvm.return %3 : f32 } // -----