Index: lib/Transforms/Scalar/NaryReassociate.cpp =================================================================== --- lib/Transforms/Scalar/NaryReassociate.cpp +++ lib/Transforms/Scalar/NaryReassociate.cpp @@ -71,8 +71,8 @@ // // Limitations and TODO items: // -// 1) We only considers n-ary adds for now. This should be extended and -// generalized. +// 1) We only consider n-ary adds and muls for now. This should be extended +// and generalized. // //===----------------------------------------------------------------------===// @@ -145,12 +145,14 @@ unsigned I, Value *LHS, Value *RHS, Type *IndexedType); - // Reassociate Add for better CSE. - Instruction *tryReassociateAdd(BinaryOperator *I); - // A helper function for tryReassociateAdd. LHS and RHS are explicitly passed. - Instruction *tryReassociateAdd(Value *LHS, Value *RHS, Instruction *I); - // Rewrites I to LHS + RHS if LHS is computed already. - Instruction *tryReassociatedAdd(const SCEV *LHS, Value *RHS, Instruction *I); + // Reassociate binary operators for better CSE. + Instruction *tryReassociateBinaryOp(BinaryOperator *I); + // A helper function for tryReassociateBinaryOp. LHS and RHS are explicitly + // passed. + Instruction *tryReassociateBinaryOp(Value *LHS, Value *RHS, Instruction *I); + // Rewrites I to (LHS op RHS) if LHS is computed already. + Instruction *tryReassociatedBinaryOp(const SCEV *LHS, Value *RHS, + Instruction *I); // Returns the closest dominator of \c Dominatee that computes // \c CandidateExpr. Returns null if not found. 
@@ -219,6 +221,7 @@ switch (I->getOpcode()) { case Instruction::Add: case Instruction::GetElementPtr: + case Instruction::Mul: return true; default: return false; @@ -276,7 +279,8 @@ Instruction *NaryReassociate::tryReassociate(Instruction *I) { switch (I->getOpcode()) { case Instruction::Add: - return tryReassociateAdd(cast<BinaryOperator>(I)); + case Instruction::Mul: + return tryReassociateBinaryOp(cast<BinaryOperator>(I)); case Instruction::GetElementPtr: return tryReassociateGEP(cast<GetElementPtrInst>(I)); default: @@ -453,45 +457,68 @@ return NewGEP; } -Instruction *NaryReassociate::tryReassociateAdd(BinaryOperator *I) { +Instruction *NaryReassociate::tryReassociateBinaryOp(BinaryOperator *I) { Value *LHS = I->getOperand(0), *RHS = I->getOperand(1); - if (auto *NewI = tryReassociateAdd(LHS, RHS, I)) + if (auto *NewI = tryReassociateBinaryOp(LHS, RHS, I)) return NewI; - if (auto *NewI = tryReassociateAdd(RHS, LHS, I)) + if (auto *NewI = tryReassociateBinaryOp(RHS, LHS, I)) return NewI; return nullptr; } -Instruction *NaryReassociate::tryReassociateAdd(Value *LHS, Value *RHS, - Instruction *I) { +Instruction *NaryReassociate::tryReassociateBinaryOp(Value *LHS, Value *RHS, + Instruction *I) { Value *A = nullptr, *B = nullptr; - // To be conservative, we reassociate I only when it is the only user of A+B. - if (LHS->hasOneUse() && match(LHS, m_Add(m_Value(A), m_Value(B)))) { - // I = (A + B) + RHS - // = (A + RHS) + B or (B + RHS) + A + unsigned Opcode = I->getOpcode(); + bool Matched = + (Opcode == Instruction::Add && + match(LHS, m_Add(m_Value(A), m_Value(B)))) || + (Opcode == Instruction::Mul && match(LHS, m_Mul(m_Value(A), m_Value(B)))); + // To be conservative, we reassociate I only when it is the only user of (A op + // B). 
+ if (LHS->hasOneUse() && Matched) { + // I = (A op B) op RHS + // = (A op RHS) op B or (B op RHS) op A const SCEV *AExpr = SE->getSCEV(A), *BExpr = SE->getSCEV(B); const SCEV *RHSExpr = SE->getSCEV(RHS); if (BExpr != RHSExpr) { - if (auto *NewI = tryReassociatedAdd(SE->getAddExpr(AExpr, RHSExpr), B, I)) + const SCEV *Expr = Opcode == Instruction::Add + ? SE->getAddExpr(AExpr, RHSExpr) + : SE->getMulExpr(AExpr, RHSExpr); + if (auto *NewI = tryReassociatedBinaryOp(Expr, B, I)) return NewI; } if (AExpr != RHSExpr) { - if (auto *NewI = tryReassociatedAdd(SE->getAddExpr(BExpr, RHSExpr), A, I)) + const SCEV *Expr = Opcode == Instruction::Add + ? SE->getAddExpr(BExpr, RHSExpr) + : SE->getMulExpr(BExpr, RHSExpr); + if (auto *NewI = tryReassociatedBinaryOp(Expr, A, I)) return NewI; } } return nullptr; } -Instruction *NaryReassociate::tryReassociatedAdd(const SCEV *LHSExpr, - Value *RHS, Instruction *I) { +Instruction *NaryReassociate::tryReassociatedBinaryOp(const SCEV *LHSExpr, + Value *RHS, + Instruction *I) { // Look for the closest dominator LHS of I that computes LHSExpr, and replace - // I with LHS + RHS. + // I with LHS op RHS. 
auto *LHS = findClosestMatchingDominator(LHSExpr, I); if (LHS == nullptr) return nullptr; - Instruction *NewI = BinaryOperator::CreateAdd(LHS, RHS, "", I); + Instruction *NewI = nullptr; + switch (I->getOpcode()) { + case Instruction::Add: + NewI = BinaryOperator::CreateAdd(LHS, RHS, "", I); + break; + case Instruction::Mul: + NewI = BinaryOperator::CreateMul(LHS, RHS, "", I); + break; + default: + llvm_unreachable("Unexpected instruction."); + } NewI->takeName(I); return NewI; } Index: test/Transforms/NaryReassociate/nary-mul.ll =================================================================== --- /dev/null +++ test/Transforms/NaryReassociate/nary-mul.ll @@ -0,0 +1,18 @@ +; RUN: opt < %s -nary-reassociate -S | FileCheck %s + +target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64" + +declare void @foo(i32) + +; CHECK-LABEL: @bar( +define void @bar(i32 %a, i32 %b, i32 %c) { + %1 = mul i32 %a, %c +; CHECK: [[BASE:%[a-zA-Z0-9]+]] = mul i32 %a, %c + call void @foo(i32 %1) + %2 = mul i32 %a, %b + %3 = mul i32 %2, %c +; CHECK: mul i32 [[BASE]], %b + call void @foo(i32 %3) + ret void +} +