[X86] Fold zext(and(shl(x, C1), C2)) into a scaled index during address matching

X86 addressing-mode matching could already turn

  zext (shl nuw x, C1)

into a scaled index register (Scale = 1 << C1, with C1 <= 3). This change
also peeks through a single-use AND with a constant mask sitting between the
zext and the shl, rewriting

  zext (and (shl x, C1), C2)  -->  shl (zext (and x, C2 >> C1)), C1

so the shift can still be folded into the address scale. The updated
lea-dagdag.ll checks show the effect: the separate shift + add sequences
become a single leaq with a scaled index.

The rewrite is exact in the narrow type because the low C1 bits of
(shl x, C1) are always zero, so masking before the shift with C2 >> C1 keeps
exactly the same bits as masking after it with C2. Every newly created node
(including the mask constant) is now passed to insertDAGNode immediately
after it is built.

NOTE(review): when a mask is present, the emitted (and x, C2 >> C1) already
clears the top C1 bits of x. The MaskedValueIsZero(x, HighZeros & Mask)
'nuw' test is therefore stricter than the rewrite needs: HighZeros is in
pre-shift bit positions, while Mask is in post-shift positions. The test is
conservative, not unsound, and is left unchanged by this patch.

diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -2489,34 +2489,60 @@
     // match the shift as a scale factor.
     if (AM.IndexReg.getNode() != nullptr || AM.Scale != 1)
       break;
-    if (N.getOperand(0).getOpcode() != ISD::SHL || !N.getOperand(0).hasOneUse())
+
+    // Peek through mask: zext(and(shl(x,c1),c2))
+    SDValue Src = N.getOperand(0);
+    APInt Mask = APInt::getAllOnes(Src.getScalarValueSizeInBits());
+    if (Src.getOpcode() == ISD::AND && Src.hasOneUse())
+      if (auto *MaskC = dyn_cast<ConstantSDNode>(Src.getOperand(1))) {
+        Mask = MaskC->getAPIntValue();
+        Src = Src.getOperand(0);
+      }
+
+    if (Src.getOpcode() != ISD::SHL || !Src.hasOneUse())
       break;
 
     // Give up if the shift is not a valid scale factor [1,2,3].
-    SDValue Shl = N.getOperand(0);
-    auto *ShAmtC = dyn_cast<ConstantSDNode>(Shl.getOperand(1));
-    if (!ShAmtC || ShAmtC->getZExtValue() > 3)
+    SDValue ShlSrc = Src.getOperand(0);
+    SDValue ShlAmt = Src.getOperand(1);
+    auto *ShAmtC = dyn_cast<ConstantSDNode>(ShlAmt);
+    if (!ShAmtC)
+      break;
+    unsigned ShAmtV = ShAmtC->getZExtValue();
+    if (ShAmtV > 3)
       break;
 
     // The narrow shift must only shift out zero bits (it must be 'nuw').
     // That makes it safe to widen to the destination type.
-    APInt HighZeros = APInt::getHighBitsSet(Shl.getValueSizeInBits(),
-                                            ShAmtC->getZExtValue());
-    if (!CurDAG->MaskedValueIsZero(Shl.getOperand(0), HighZeros))
+    APInt HighZeros =
+        APInt::getHighBitsSet(ShlSrc.getValueSizeInBits(), ShAmtV);
+    if (!CurDAG->MaskedValueIsZero(ShlSrc, HighZeros & Mask))
       break;
 
-    // zext (shl nuw i8 %x, C) to i32 --> shl (zext i8 %x to i32), (zext C)
+    // zext (shl nuw i8 %x, C1) to i32
+    // --> shl (zext i8 %x to i32), (zext C1)
+    // zext (and (shl nuw i8 %x, C1), C2) to i32
+    // --> shl (zext i8 (and %x, C2 >> C1) to i32), (zext C1)
+    MVT SrcVT = ShlSrc.getSimpleValueType();
     MVT VT = N.getSimpleValueType();
     SDLoc DL(N);
-    SDValue Zext = CurDAG->getNode(ISD::ZERO_EXTEND, DL, VT, Shl.getOperand(0));
-    SDValue NewShl = CurDAG->getNode(ISD::SHL, DL, VT, Zext, Shl.getOperand(1));
+
+    SDValue Res = ShlSrc;
+    if (!Mask.isAllOnes()) {
+      Res = CurDAG->getConstant(Mask.lshr(ShAmtV), DL, SrcVT);
+      insertDAGNode(*CurDAG, N, Res);
+      Res = CurDAG->getNode(ISD::AND, DL, SrcVT, ShlSrc, Res);
+      insertDAGNode(*CurDAG, N, Res);
+    }
+    SDValue Zext = CurDAG->getNode(ISD::ZERO_EXTEND, DL, VT, Res);
+    insertDAGNode(*CurDAG, N, Zext);
+    SDValue NewShl = CurDAG->getNode(ISD::SHL, DL, VT, Zext, ShlAmt);
+    insertDAGNode(*CurDAG, N, NewShl);
 
     // Convert the shift to scale factor.
-    AM.Scale = 1 << ShAmtC->getZExtValue();
+    AM.Scale = 1 << ShAmtV;
     AM.IndexReg = Zext;
 
-    insertDAGNode(*CurDAG, N, Zext);
-    insertDAGNode(*CurDAG, N, NewShl);
     CurDAG->ReplaceAllUsesWith(N, NewShl);
     CurDAG->RemoveDeadNode(N.getNode());
     return false;
diff --git a/llvm/test/CodeGen/X86/lea-dagdag.ll b/llvm/test/CodeGen/X86/lea-dagdag.ll
--- a/llvm/test/CodeGen/X86/lea-dagdag.ll
+++ b/llvm/test/CodeGen/X86/lea-dagdag.ll
@@ -153,10 +153,9 @@
 define i64 @shl_and_i8_zext_add_i64(i64 %t0, i8 %t1) {
 ; CHECK-LABEL: shl_and_i8_zext_add_i64:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    shlb $2, %sil
-; CHECK-NEXT:    andb $60, %sil
+; CHECK-NEXT:    andb $15, %sil
 ; CHECK-NEXT:    movzbl %sil, %eax
-; CHECK-NEXT:    addq %rdi, %rax
+; CHECK-NEXT:    leaq (%rdi,%rax,4), %rax
 ; CHECK-NEXT:    retq
   %s = shl i8 %t1, 2
   %m = and i8 %s, 60
@@ -169,9 +168,8 @@
 ; CHECK-LABEL: shl_and_i16_zext_add_i64:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    # kill: def $esi killed $esi def $rsi
-; CHECK-NEXT:    leal (%rsi,%rsi), %eax
-; CHECK-NEXT:    andl $16, %eax
-; CHECK-NEXT:    addq %rdi, %rax
+; CHECK-NEXT:    andl $8, %esi
+; CHECK-NEXT:    leaq (%rdi,%rsi,2), %rax
 ; CHECK-NEXT:    retq
   %s = shl i16 %t1, 1
   %m = and i16 %s, 17
@@ -184,9 +182,8 @@
 ; CHECK-LABEL: shl_and_i32_zext_add_i64:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    # kill: def $esi killed $esi def $rsi
-; CHECK-NEXT:    leal (,%rsi,8), %eax
-; CHECK-NEXT:    andl $5992, %eax # imm = 0x1768
-; CHECK-NEXT:    addq %rdi, %rax
+; CHECK-NEXT:    andl $749, %esi # imm = 0x2ED
+; CHECK-NEXT:    leaq (%rdi,%rsi,8), %rax
 ; CHECK-NEXT:    retq
   %s = shl i32 %t1, 3
   %m = and i32 %s, 5999