diff --git a/llvm/include/llvm/Transforms/Scalar/NaryReassociate.h b/llvm/include/llvm/Transforms/Scalar/NaryReassociate.h
--- a/llvm/include/llvm/Transforms/Scalar/NaryReassociate.h
+++ b/llvm/include/llvm/Transforms/Scalar/NaryReassociate.h
@@ -149,6 +149,11 @@
   // Tries to match Op1 and Op2 by using V.
   bool matchTernaryOp(BinaryOperator *I, Value *V, Value *&Op1, Value *&Op2);
 
+  // Tries to match I = (X << C1) + C2 with LHS = (X << C1) and RHS = C2, where
+  // the low C1 bits of the constant C2 are zero.
+  bool matchConstLShiftOp(BinaryOperator *I, Value *LHS, Value *RHS, Value *&X,
+                          const APInt *&C1, const APInt *&C2);
+
   // Gets SCEV for (LHS op RHS).
   const SCEV *getBinarySCEV(BinaryOperator *I, const SCEV *LHS,
                             const SCEV *RHS);
diff --git a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
--- a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
@@ -477,10 +477,12 @@
 
 Instruction *NaryReassociatePass::tryReassociateBinaryOp(Value *LHS, Value *RHS,
                                                          BinaryOperator *I) {
+  // To be conservative, we reassociate I only when it is the only user of LHS.
+  if (!LHS->hasOneUse())
+    return nullptr;
+
   Value *A = nullptr, *B = nullptr;
-  // To be conservative, we reassociate I only when it is the only user of (A op
-  // B).
-  if (LHS->hasOneUse() && matchTernaryOp(I, LHS, A, B)) {
+  if (matchTernaryOp(I, LHS, A, B)) {
     // I = (A op B) op RHS
     //   = (A op RHS) op B or (B op RHS) op A
     const SCEV *AExpr = SE->getSCEV(A), *BExpr = SE->getSCEV(B);
@@ -495,6 +497,19 @@
               tryReassociatedBinaryOp(getBinarySCEV(I, BExpr, RHSExpr), A, I))
         return NewI;
     }
+    return nullptr;
+  }
+
+  // Handle special case: I = (X << C1) + C2, where C1 and C2 are constant
+  // values such that I can be represented as (X + C3) << C1, C2 = C3 << C1.
+  // Then (X + C3) expression can be CSE'd later.
+  Value *X = nullptr;
+  const APInt *C1 = nullptr, *C2 = nullptr;
+  if (matchConstLShiftOp(I, LHS, RHS, X, C1, C2)) {
+    auto *C3Val = ConstantInt::get(I->getType(), C2->lshr(*C1));
+    auto *Add = BinaryOperator::CreateAdd(X, C3Val, "add.nary", I);
+    auto *C1Val = ConstantInt::get(I->getType(), *C1);
+    return BinaryOperator::CreateShl(Add, C1Val, "shl.nary", I);
   }
   return nullptr;
 }
@@ -523,6 +538,23 @@
   return NewI;
 }
 
+// Returns true if I = (X << C1) + C2, where LHS = (X << C1), RHS = C2, and
+// the low C1 bits of C2 are zero, i.e. C2 = C3 << C1 for some constant C3.
+// On a successful match X, C1 and C2 refer to the matched operands.
+bool NaryReassociatePass::matchConstLShiftOp(BinaryOperator *I, Value *LHS,
+                                             Value *RHS, Value *&X,
+                                             const APInt *&C1,
+                                             const APInt *&C2) {
+  if (I->getOpcode() == Instruction::Add) {
+    // Compare at APInt level: for integer types wider than 64 bits the shift
+    // amount C1 may not fit in uint64_t, and getZExtValue() would then assert.
+    if (match(LHS, m_Shl(m_Value(X), m_APInt(C1))) && match(RHS, m_APInt(C2)) &&
+        C1->ule(C2->countTrailingZeros()))
+      return true;
+  }
+  return false;
+}
+
 bool NaryReassociatePass::matchTernaryOp(BinaryOperator *I, Value *V,
                                          Value *&Op1, Value *&Op2) {
   switch (I->getOpcode()) {
diff --git a/llvm/test/Transforms/NaryReassociate/NVPTX/const-shl.ll b/llvm/test/Transforms/NaryReassociate/NVPTX/const-shl.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/NaryReassociate/NVPTX/const-shl.ll
@@ -0,0 +1,63 @@
+; RUN: opt < %s -passes=instcombine,nary-reassociate -S | FileCheck %s
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s --check-prefix=INSTCOMB
+
+define i32 @foo(i32 %x) {
+; CHECK-LABEL: @foo
+; CHECK-NEXT: %add.nary = add i32 %x, 787
+; CHECK-NEXT: %shl.nary = shl i32 %add.nary, 22
+; CHECK-NEXT: %add.nary1 = add i32 %x, 787
+; CHECK-NEXT: %shl.nary2 = shl i32 %add.nary1, 11
+; CHECK-NEXT: %add.nary3 = add i32 %x, 787
+; CHECK-NEXT: %shl.nary4 = shl i32 %add.nary3, 3
+; CHECK-NEXT: %add.nary5 = add i32 %x, 787
+; CHECK-NEXT: %shl.nary6 = shl i32 %add.nary5, 13
+; CHECK-NEXT: %or1 = or i32 %shl.nary, %shl.nary2
+; CHECK-NEXT: %or2 = or i32 %shl.nary4, %shl.nary6
+; CHECK-NEXT: %or3 = or i32 %or1, %or2
+; CHECK-NEXT: ret i32 %or3
+
+  %add1 = shl i32 %x, 22
+  %shl1 = add i32 %add1, -994050048
+  %add2 = shl i32 %x, 11
+  %shl2 = add i32 %add2, 1611776
+  %add3 = shl i32 %x, 3
+  %shl3 = add i32 %add3, 6296
+  %add4 = shl i32 %x, 13
+  %shl4 = add i32 %add4, 6447104
+  %or1 = or i32 %shl1, %shl2
+  %or2 = or i32 %shl3, %shl4
+  %or3 = or i32 %or1, %or2
+  ret i32 %or3
+}
+
+; An example of how instcombine performs transformations
+; that block GVN/CSE from happening.
+
+define i32 @bar(i32 %x) {
+; INSTCOMB-LABEL: @bar
+; INSTCOMB-NEXT: %add.nary = shl i32 %x, 22
+; INSTCOMB-NEXT: %shl.nary = add i32 %add.nary, -994050048
+; INSTCOMB-NEXT: %add.nary1 = shl i32 %x, 11
+; INSTCOMB-NEXT: %shl.nary2 = add i32 %add.nary1, 1611776
+; INSTCOMB-NEXT: %add.nary3 = shl i32 %x, 3
+; INSTCOMB-NEXT: %shl.nary4 = add i32 %add.nary3, 6296
+; INSTCOMB-NEXT: %add.nary5 = shl i32 %x, 13
+; INSTCOMB-NEXT: %shl.nary6 = add i32 %add.nary5, 6447104
+; INSTCOMB-NEXT: %or1 = or i32 %shl.nary, %shl.nary2
+; INSTCOMB-NEXT: %or2 = or i32 %shl.nary4, %shl.nary6
+; INSTCOMB-NEXT: %or3 = or i32 %or1, %or2
+; INSTCOMB-NEXT: ret i32 %or3
+
+  %add.nary = add i32 %x, 787
+  %shl.nary = shl i32 %add.nary, 22
+  %add.nary1 = add i32 %x, 787
+  %shl.nary2 = shl i32 %add.nary1, 11
+  %add.nary3 = add i32 %x, 787
+  %shl.nary4 = shl i32 %add.nary3, 3
+  %add.nary5 = add i32 %x, 787
+  %shl.nary6 = shl i32 %add.nary5, 13
+  %or1 = or i32 %shl.nary, %shl.nary2
+  %or2 = or i32 %shl.nary4, %shl.nary6
+  %or3 = or i32 %or1, %or2
+  ret i32 %or3
+}