diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
--- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
@@ -80,6 +80,15 @@
   RegImmPair(Register Reg, int64_t Imm) : Reg(Reg), Imm(Imm) {}
 };
 
+/// Used to describe an addressing mode, similar to ExtAddrMode in
+/// CodeGenPrepare. It holds the registers, the scale and the displacement.
+struct ExtAddrMode {
+  Register BaseReg;
+  Register ScaledReg;
+  int64_t Scale;
+  int64_t Displacement;
+};
+
 //---------------------------------------------------------------------------
 ///
 /// TargetInstrInfo - Interface to description of machine instruction set
@@ -968,6 +977,15 @@
     return None;
   }
 
+  /// Returns true if MI is an instruction that defines Reg to have a constant
+  /// value, recording that value in ImmVal. ImmVal should be interpreted
+  /// modulo the size of Reg.
+  virtual bool getConstValDefinedInReg(const MachineInstr &MI,
+                                       const Register Reg,
+                                       int64_t &ImmVal) const {
+    return false;
+  }
+
   /// Store the specified register of the given register class to the specified
   /// stack frame index. The store instruction is to be added to the given
   /// machine basic block before the specified machine instruction. If isKill
@@ -1270,6 +1288,16 @@
     return false;
   }
 
+  /// Target-dependent implementation to get the values constituting the
+  /// address of a MachineInstr that is accessing memory. These values are
+  /// returned as a struct ExtAddrMode which contains all relevant information
+  /// to make up the address.
+  virtual Optional<ExtAddrMode>
+  getAddrModeFromMemoryOp(const MachineInstr &MemI,
+                          const TargetRegisterInfo *TRI) const {
+    return None;
+  }
+
   /// Returns true if MI's Def is NullValueReg, and the MI
   /// does not change the Zero value. i.e. cases such as rax = shr rax, X where
   /// NullValueReg = rax. Note that if the NullValueReg is non-zero, this
diff --git a/llvm/lib/CodeGen/ImplicitNullChecks.cpp b/llvm/lib/CodeGen/ImplicitNullChecks.cpp
--- a/llvm/lib/CodeGen/ImplicitNullChecks.cpp
+++ b/llvm/lib/CodeGen/ImplicitNullChecks.cpp
@@ -378,26 +378,100 @@
   if (MI.getDesc().getNumDefs() > 1)
     return SR_Unsuitable;
 
-  // FIXME: This handles only simple addressing mode.
-  if (!TII->getMemOperandWithOffset(MI, BaseOp, Offset, OffsetIsScalable, TRI))
+  if (!MI.mayLoadOrStore() || MI.isPredicable())
+    return SR_Unsuitable;
+  auto AM = TII->getAddrModeFromMemoryOp(MI, TRI);
+  if (!AM)
     return SR_Unsuitable;
+  auto AddrMode = *AM;
+  const Register BaseReg = AddrMode.BaseReg, ScaledReg = AddrMode.ScaledReg;
+  int64_t Displacement = AddrMode.Displacement;
 
   // We need the base of the memory instruction to be same as the register
   // where the null check is performed (i.e. PointerReg).
-  if (!BaseOp->isReg() || BaseOp->getReg() != PointerReg)
+  if (BaseReg != PointerReg && ScaledReg != PointerReg)
     return SR_Unsuitable;
-
-  // Scalable offsets are a part of scalable vectors (SVE for AArch64). That
-  // target is in-practice unsupported for ImplicitNullChecks.
-  if (OffsetIsScalable)
+  const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
+  unsigned PointerRegSizeInBits = TRI->getRegSizeInBits(PointerReg, MRI);
+  // Bail out if the sizes of BaseReg, ScaledReg and PointerReg are not the
+  // same.
+  if ((BaseReg &&
+       TRI->getRegSizeInBits(BaseReg, MRI) != PointerRegSizeInBits) ||
+      (ScaledReg &&
+       TRI->getRegSizeInBits(ScaledReg, MRI) != PointerRegSizeInBits))
     return SR_Unsuitable;
 
-  if (!MI.mayLoadOrStore() || MI.isPredicable())
+  // Returns true if RegUsedInAddr is used for calculating the displacement
+  // depending on addressing mode. Also calculates the Displacement.
+  auto CalculateDisplacementFromAddrMode = [&](Register RegUsedInAddr,
+                                               int64_t Multiplier) {
+    // The register can be NoRegister, which is defined as zero for all
+    // targets. Consider an instruction of interest such as
+    // `movq 8(,%rdi,8), %rax`: here the ScaledReg is %rdi with no BaseReg.
+    if (!RegUsedInAddr)
+      return false;
+    assert(Multiplier && "expected to be non-zero!");
+    MachineInstr *ModifyingMI = nullptr;
+    for (auto It = std::next(MachineBasicBlock::const_reverse_iterator(&MI));
+         It != MI.getParent()->rend(); It++) {
+      const MachineInstr *CurrMI = &*It;
+      if (CurrMI->modifiesRegister(RegUsedInAddr, TRI)) {
+        ModifyingMI = const_cast<MachineInstr *>(CurrMI);
+        break;
+      }
+    }
+    if (!ModifyingMI)
+      return false;
+    // Check for a const value defined in the register by ModifyingMI. This
+    // means all previous values of that register have been invalidated.
+    int64_t ImmVal;
+    if (!TII->getConstValDefinedInReg(*ModifyingMI, RegUsedInAddr, ImmVal))
+      return false;
+    // Calculate the reg size in bits, since this is needed for bailing out in
+    // case of overflow.
+    int32_t RegSizeInBits = TRI->getRegSizeInBits(RegUsedInAddr, MRI);
+    APInt ImmValC(RegSizeInBits, ImmVal, true /*IsSigned*/);
+    APInt MultiplierC(RegSizeInBits, Multiplier);
+    assert(MultiplierC.isStrictlyPositive() &&
+           "expected to be a positive value!");
+    bool IsOverflow;
+    // The sign of the product depends on the sign of ImmVal, since
+    // Multiplier is always positive.
+    APInt Product = ImmValC.smul_ov(MultiplierC, IsOverflow);
+    if (IsOverflow)
+      return false;
+    APInt DisplacementC(64, Displacement, true /*isSigned*/);
+    DisplacementC = Product.sadd_ov(DisplacementC, IsOverflow);
+    if (IsOverflow)
+      return false;
+
+    // We only handle displacements up to 64 bits wide.
+    if (DisplacementC.getActiveBits() > 64)
+      return false;
+    Displacement = DisplacementC.getSExtValue();
+    return true;
+  };
+
+  // If a register used in the address is constant, fold its effect into the
+  // displacement for ease of analysis.
+  bool BaseRegIsConstVal = false, ScaledRegIsConstVal = false;
+  if (CalculateDisplacementFromAddrMode(BaseReg, 1))
+    BaseRegIsConstVal = true;
+  if (CalculateDisplacementFromAddrMode(ScaledReg, AddrMode.Scale))
+    ScaledRegIsConstVal = true;
+
+  // The register which is not null checked should be part of the Displacement
+  // calculation, otherwise we do not know whether the Displacement is made up
+  // of some symbolic values.
+  // This matters because we do not want to incorrectly assume that the load
+  // falls within the zeroth faulting page in the "sane offset check" below.
+  if ((BaseReg && BaseReg != PointerReg && !BaseRegIsConstVal) ||
+      (ScaledReg && ScaledReg != PointerReg && !ScaledRegIsConstVal))
     return SR_Unsuitable;
 
   // We want the mem access to be issued at a sane offset from PointerReg,
   // so that if PointerReg is null then the access reliably page faults.
-  if (!(-PageSize < Offset && Offset < PageSize))
+  if (!(-PageSize < Displacement && Displacement < PageSize))
     return SR_Unsuitable;
 
   // Finally, check whether the current memory access aliases with previous one.
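
Reviewer note (not part of the patch): the fold performed by CalculateDisplacementFromAddrMode reduces to the following self-contained computation. This is a minimal sketch that assumes 64-bit registers so all APInt widths agree; the helper name foldConstIntoDisplacement is illustrative only.

  #include "llvm/ADT/APInt.h"
  #include "llvm/ADT/Optional.h"
  #include <cstdint>

  using namespace llvm;

  // Computes ImmVal * Multiplier + Displacement with signed-overflow checks,
  // mirroring the lambda above; returns None if either step overflows.
  static Optional<int64_t> foldConstIntoDisplacement(int64_t ImmVal,
                                                     int64_t Multiplier,
                                                     int64_t Displacement) {
    APInt ImmValC(64, ImmVal, /*isSigned=*/true);
    APInt MultiplierC(64, Multiplier); // The scale is always positive.
    bool IsOverflow;
    // The sign of the product follows ImmVal since Multiplier is positive.
    APInt Product = ImmValC.smul_ov(MultiplierC, IsOverflow);
    if (IsOverflow)
      return None;
    APInt DisplacementC(64, Displacement, /*isSigned=*/true);
    DisplacementC = Product.sadd_ov(DisplacementC, IsOverflow);
    if (IsOverflow)
      return None;
    return DisplacementC.getSExtValue();
  }

In the imp_null_check_address_mul_overflow MIR test added below, the index register is defined by MOV64ri to -2^63 and the scale is 2, so smul_ov reports overflow and the candidate is rejected before the displacement is ever compared against PageSize.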
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -113,6 +113,10 @@
   /// Hint that pairing the given load or store is unprofitable.
   static void suppressLdStPair(MachineInstr &MI);
 
+  Optional<ExtAddrMode>
+  getAddrModeFromMemoryOp(const MachineInstr &MemI,
+                          const TargetRegisterInfo *TRI) const override;
+
   bool getMemOperandsWithOffsetWidth(
       const MachineInstr &MI, SmallVectorImpl<const MachineOperand *> &BaseOps,
       int64_t &Offset, bool &OffsetIsScalable, unsigned &Width,
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -2144,6 +2144,25 @@
   return true;
 }
 
+Optional<ExtAddrMode>
+AArch64InstrInfo::getAddrModeFromMemoryOp(const MachineInstr &MemI,
+                                          const TargetRegisterInfo *TRI) const {
+  const MachineOperand *Base; // Filled with the base operand of MI.
+  int64_t Offset;             // Filled with the offset of MI.
+  bool OffsetIsScalable;
+  if (!getMemOperandWithOffset(MemI, Base, Offset, OffsetIsScalable, TRI))
+    return None;
+
+  if (!Base->isReg())
+    return None;
+  ExtAddrMode AM;
+  AM.BaseReg = Base->getReg();
+  AM.Displacement = Offset;
+  AM.ScaledReg = 0;
+  AM.Scale = 0; // No scaled register, so leave no indeterminate scale behind.
+  return AM;
+}
+
 bool AArch64InstrInfo::getMemOperandWithOffsetWidth(
     const MachineInstr &LdSt, const MachineOperand *&BaseOp, int64_t &Offset,
     bool &OffsetIsScalable, unsigned &Width,
diff --git a/llvm/lib/Target/X86/X86InstrInfo.h b/llvm/lib/Target/X86/X86InstrInfo.h
--- a/llvm/lib/Target/X86/X86InstrInfo.h
+++ b/llvm/lib/Target/X86/X86InstrInfo.h
@@ -317,6 +317,13 @@
                      SmallVectorImpl<MachineOperand> &Cond,
                      bool AllowModify) const override;
 
+  Optional<ExtAddrMode>
+  getAddrModeFromMemoryOp(const MachineInstr &MemI,
+                          const TargetRegisterInfo *TRI) const override;
+
+  bool getConstValDefinedInReg(const MachineInstr &MI, const Register Reg,
+                               int64_t &ImmVal) const override;
+
   bool preservesZeroValueInReg(const MachineInstr *MI,
                                const Register NullValueReg,
                                const TargetRegisterInfo *TRI) const override;
diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp
--- a/llvm/lib/Target/X86/X86InstrInfo.cpp
+++ b/llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -3663,6 +3663,45 @@
   }
 }
 
+Optional<ExtAddrMode>
+X86InstrInfo::getAddrModeFromMemoryOp(const MachineInstr &MemI,
+                                      const TargetRegisterInfo *TRI) const {
+  const MCInstrDesc &Desc = MemI.getDesc();
+  int MemRefBegin = X86II::getMemoryOperandNo(Desc.TSFlags);
+  if (MemRefBegin < 0)
+    return None;
+
+  MemRefBegin += X86II::getOperandBias(Desc);
+
+  auto &BaseOp = MemI.getOperand(MemRefBegin + X86::AddrBaseReg);
+  if (!BaseOp.isReg()) // Can be an MO_FrameIndex
+    return None;
+
+  const MachineOperand &DispMO = MemI.getOperand(MemRefBegin + X86::AddrDisp);
+  // Displacement can be symbolic
+  if (!DispMO.isImm())
+    return None;
+
+  ExtAddrMode AM;
+  AM.BaseReg = BaseOp.getReg();
+  AM.ScaledReg = MemI.getOperand(MemRefBegin + X86::AddrIndexReg).getReg();
+  AM.Scale = MemI.getOperand(MemRefBegin + X86::AddrScaleAmt).getImm();
+  AM.Displacement = DispMO.getImm();
+  return AM;
+}
+
+bool X86InstrInfo::getConstValDefinedInReg(const MachineInstr &MI,
+                                           const Register Reg,
+                                           int64_t &ImmVal) const {
+  if (MI.getOpcode() != X86::MOV32ri && MI.getOpcode() != X86::MOV64ri)
+    return false;
+  // Mov Src can be a global address.
+  if (!MI.getOperand(1).isImm() || MI.getOperand(0).getReg() != Reg)
+    return false;
+  ImmVal = MI.getOperand(1).getImm();
+  return true;
+}
+
 bool X86InstrInfo::preservesZeroValueInReg(
     const MachineInstr *MI, const Register NullValueReg,
     const TargetRegisterInfo *TRI) const {
diff --git a/llvm/test/CodeGen/X86/implicit-null-check-negative.ll b/llvm/test/CodeGen/X86/implicit-null-check-negative.ll
--- a/llvm/test/CodeGen/X86/implicit-null-check-negative.ll
+++ b/llvm/test/CodeGen/X86/implicit-null-check-negative.ll
@@ -129,4 +129,24 @@
   %t = load i64, i64* %x.loc
   ret i64 %t
 }
+
+; The memory op is not within the faulting page.
+define i64 @imp_null_check_load_addr_outside_faulting_page(i64* %x) {
+ entry:
+  %c = icmp eq i64* %x, null
+  br i1 %c, label %is_null, label %not_null, !make.implicit !0
+
+ is_null:
+  ret i64 42
+
+ not_null:
+  %y = ptrtoint i64* %x to i64
+  %shry = shl i64 %y, 3
+  %shry.add = add i64 %shry, 68719472640
+  %y.ptr = inttoptr i64 %shry.add to i64*
+  %x.loc = getelementptr i64, i64* %y.ptr, i64 1
+  %t = load i64, i64* %x.loc
+  ret i64 %t
+}
+
 !0 = !{}
diff --git a/llvm/test/CodeGen/X86/implicit-null-check.ll b/llvm/test/CodeGen/X86/implicit-null-check.ll
--- a/llvm/test/CodeGen/X86/implicit-null-check.ll
+++ b/llvm/test/CodeGen/X86/implicit-null-check.ll
@@ -1,4 +1,3 @@
-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -verify-machineinstrs -O3 -mtriple=x86_64-apple-macosx -enable-implicit-null-checks < %s | FileCheck %s
 
 define i32 @imp_null_check_load(i32* %x) {
@@ -593,14 +592,12 @@
 
 ; Same as imp_null_check_load_shift_addr but shift is by 3 and this is now
 ; converted into complex addressing.
-; TODO: Can be converted into implicit null check
 define i64 @imp_null_check_load_shift_by_3_addr(i64* %x) {
 ; CHECK-LABEL: imp_null_check_load_shift_by_3_addr:
 ; CHECK:       ## %bb.0: ## %entry
-; CHECK-NEXT:    testq %rdi, %rdi
-; CHECK-NEXT:    je LBB22_1
+; CHECK-NEXT:  Ltmp18:
+; CHECK-NEXT:    movq 8(,%rdi,8), %rax ## on-fault: LBB22_1
 ; CHECK-NEXT:  ## %bb.2: ## %not_null
-; CHECK-NEXT:    movq 8(,%rdi,8), %rax
 ; CHECK-NEXT:    retq
 ; CHECK-NEXT:  LBB22_1: ## %is_null
 ; CHECK-NEXT:    movl $42, %eax
@@ -621,4 +618,31 @@
   %t = load i64, i64* %x.loc
   ret i64 %t
 }
+
+define i64 @imp_null_check_load_shift_add_addr(i64* %x) {
+; CHECK-LABEL: imp_null_check_load_shift_add_addr:
+; CHECK:       ## %bb.0: ## %entry
+; CHECK:         movq 3526(,%rdi,8), %rax ## on-fault: LBB23_1
+; CHECK-NEXT:  ## %bb.2: ## %not_null
+; CHECK-NEXT:    retq
+; CHECK-NEXT:  LBB23_1: ## %is_null
+; CHECK-NEXT:    movl $42, %eax
+; CHECK-NEXT:    retq
+
+ entry:
+  %c = icmp eq i64* %x, null
+  br i1 %c, label %is_null, label %not_null, !make.implicit !0
+
+ is_null:
+  ret i64 42
+
+ not_null:
+  %y = ptrtoint i64* %x to i64
+  %shry = shl i64 %y, 3
+  %shry.add = add i64 %shry, 3518
+  %y.ptr = inttoptr i64 %shry.add to i64*
+  %x.loc = getelementptr i64, i64* %y.ptr, i64 1
+  %t = load i64, i64* %x.loc
+  ret i64 %t
+}
 !0 = !{}
diff --git a/llvm/test/CodeGen/X86/implicit-null-checks.mir b/llvm/test/CodeGen/X86/implicit-null-checks.mir
--- a/llvm/test/CodeGen/X86/implicit-null-checks.mir
+++ b/llvm/test/CodeGen/X86/implicit-null-checks.mir
@@ -377,6 +377,22 @@
     ret i32 undef
   }
 
+  define i32 @imp_null_check_address_mul_overflow(i32* %x, i32 %a) {
+  entry:
+    %c = icmp eq i32* %x, null
+    br i1 %c, label %is_null, label %not_null, !make.implicit !0
+
+  is_null:                                          ; preds = %entry
+    ret i32 42
+
+  not_null:                                         ; preds = %entry
+    %y = ptrtoint i32* %x to i32
+    %y64 = zext i32 %y to i64
+    %b = mul i64 %y64, 9223372036854775807 ; 0x7fffffffffffffff, i.e. 2^63 - 1
+    %z = trunc i64 %b to i32
+    ret i32 %z
+  }
+
 attributes #0 = { "target-features"="+bmi,+bmi2" }
 
 !0 = !{}
@@ -1316,3 +1332,32 @@
     RETQ $eax
 
 ...
+---
+name:            imp_null_check_address_mul_overflow
+# CHECK-LABEL: name:            imp_null_check_address_mul_overflow
+# CHECK:  bb.0.entry:
+# CHECK-NOT: FAULTING_OP
+alignment:       16
+tracksRegLiveness: true
+liveins:
+  - { reg: '$rdi' }
+  - { reg: '$rsi' }
+body:             |
+  bb.0.entry:
+    liveins: $rsi, $rdi
+
+    TEST64rr $rdi, $rdi, implicit-def $eflags
+    JCC_1 %bb.1, 4, implicit $eflags
+
+  bb.2.not_null:
+    liveins: $rdi, $rsi
+
+    $rcx = MOV64ri -9223372036854775808
+    $eax = MOV32rm killed $rdi, 2, $rcx, 0, $noreg, implicit-def $rax
+    RETQ $eax
+
+  bb.1.is_null:
+    $eax = MOV32ri 42
+    RETQ $eax
+
+...
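
Reviewer note (not part of the patch): the two new hooks are designed to compose. A minimal usage sketch, assuming the patched TargetInstrInfo above; the helper effectiveDisplacement and the instructions MemMI/DefMI are hypothetical names, and the APInt overflow checks that ImplicitNullChecks performs are deliberately omitted here.

  #include "llvm/ADT/Optional.h"
  #include "llvm/CodeGen/MachineInstr.h"
  #include "llvm/CodeGen/TargetInstrInfo.h"

  using namespace llvm;

  // Returns the effective displacement of MemMI when its scaled index
  // register is defined to a constant by DefMI, and None otherwise.
  static Optional<int64_t> effectiveDisplacement(const TargetInstrInfo *TII,
                                                 const TargetRegisterInfo *TRI,
                                                 const MachineInstr &MemMI,
                                                 const MachineInstr &DefMI) {
    Optional<ExtAddrMode> AM = TII->getAddrModeFromMemoryOp(MemMI, TRI);
    if (!AM)
      return None;
    int64_t ImmVal;
    if (!AM->ScaledReg ||
        !TII->getConstValDefinedInReg(DefMI, AM->ScaledReg, ImmVal))
      return None;
    // The address is BaseReg + ImmVal * Scale + Displacement, so with a
    // constant index the last two terms fold into one displacement.
    return ImmVal * AM->Scale + AM->Displacement;
  }

Splitting the work this way keeps the arithmetic in the target-independent pass, while each target only has to describe its addressing mode and recognize its move-immediate opcodes (MOV32ri/MOV64ri on X86).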