Restrict the `gep %x:T, 0 -> %x` fold in LLVM::GEPOp::fold to single-index GEPs.

The fold previously fired whenever three things held:
  - the result type matched the base type,
  - the GEP had exactly one *dynamic* index (getIndices()),
  - that index matched m_Zero().
That condition ignores constant indices. As a result, a GEP such as

  llvm.getelementptr %x[%c0, 1] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(i32, i32)>

(one zero dynamic index plus the constant struct index 1) was folded to %x.
That is wrong, because the constant index 1 selects the second struct field.
The patch adds a further requirement: getStructIndices().size() == 1.
NOTE(review): this relies on getStructIndices() listing every index position,
constant and dynamic alike. Confirm against the GEPOp ODS definition.

Test: adds @fold_gep_neg to canonicalize.mlir. It checks that the GEP above
is still present after canonicalization and is returned unchanged.

NOTE(review): the patch body below has been collapsed onto a single line and
will not apply as-is. Its original line structure must be restored first.
The hunk header shows `fold(ArrayRef operands)`, which appears to have lost
its template argument. Hunk-header text is informational only.
NOTE(review): a GEP whose only index is a constant zero (no dynamic index)
is still not folded, because the condition requires getIndices().size() == 1.
Consider handling that case in a follow-up.

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2699,7 +2699,7 @@ OpFoldResult LLVM::GEPOp::fold(ArrayRef operands) { // gep %x:T, 0 -> %x if (getBase().getType() == getType() && getIndices().size() == 1 && - matchPattern(getIndices()[0], m_Zero())) + getStructIndices().size() == 1 && matchPattern(getIndices()[0], m_Zero())) return getBase(); return {}; } diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir --- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir +++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir @@ -100,6 +100,17 @@ llvm.return %c : !llvm.ptr } +// CHECK-LABEL: fold_gep_neg +// CHECK-SAME: %[[a0:arg[0-9]+]] +// CHECK-NEXT: %[[C:.*]] = arith.constant 0 +// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr %[[a0]][%[[C]], 1] +// CHECK-NEXT: llvm.return %[[RES]] +llvm.func @fold_gep_neg(%x : !llvm.ptr) -> !llvm.ptr { + %c0 = arith.constant 0 : i32 + %0 = llvm.getelementptr %x[%c0, 1] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(i32, i32)> + llvm.return %0 : !llvm.ptr +} + // ----- // Check that LLVM constants participate in cross-dialect constant folding. The