diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1155,6 +1155,21 @@
   while (insertValueOp) {
     if (getPosition() == insertValueOp.getPosition())
       return insertValueOp.getValue();
+    unsigned min =
+        std::min(getPosition().size(), insertValueOp.getPosition().size());
+    // If one position is a prefix of the other, this insertion overlaps the
+    // extracted element, so stop walking up the chain: an older value may be
+    // stale. E.g. %3 must not fold to %f0 below, as %2 overwrites all of [0]:
+    // ```
+    //   %1 = llvm.insertvalue %f0, %0[0, 0] :
+    //     !llvm.array<4 x !llvm.array<4xf32>>
+    //   %2 = llvm.insertvalue %arr, %1[0] :
+    //     !llvm.array<4 x !llvm.array<4xf32>>
+    //   %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>>
+    // ```
+    if (getPosition().getValue().take_front(min) ==
+        insertValueOp.getPosition().getValue().take_front(min))
+      return {};
     insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
   }
   return {};
diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
--- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir
+++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir
@@ -21,3 +21,20 @@
   %5 = llvm.add %3, %4 : i32
   llvm.return %5 : i32
 }
+
+// -----
+
+// CHECK-LABEL: no_fold_extractvalue
+llvm.func @no_fold_extractvalue(%arr: !llvm.array<4xf32>) -> f32 {
+  %f0 = arith.constant 0.0 : f32
+  %0 = llvm.mlir.undef : !llvm.array<4 x !llvm.array<4xf32>>
+
+  // CHECK: insertvalue
+  // CHECK: insertvalue
+  // CHECK: extractvalue
+  %1 = llvm.insertvalue %f0, %0[0, 0] : !llvm.array<4 x !llvm.array<4xf32>>
+  %2 = llvm.insertvalue %arr, %1[0] : !llvm.array<4 x !llvm.array<4xf32>>
+  %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4xf32>>
+
+  llvm.return %3 : f32
+}