Index: lib/Transforms/Scalar/StraightLineStrengthReduce.cpp =================================================================== --- lib/Transforms/Scalar/StraightLineStrengthReduce.cpp +++ lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -15,37 +15,33 @@ // // There are many optimizations we can perform in the domain of SLSR. This file // for now contains only an initial step. Specifically, we look for strength -// reduction candidates in two forms: +// reduction candidates in the following forms: // -// Form 1: (B + i) * S -// Form 2: &B[i * S] +// Form 1: B + i * S +// Form 2: (B + i) * S +// Form 3: &B[i * S] // // where S is an integer variable, and i is a constant integer. If we found two -// candidates +// candidates S1 and S2 in the same form and S1 dominates S2, we may be able to +// rewrite S2 in a simpler way with respect to S1. For example, // -// S1: X = (B + i) * S -// S2: Y = (B + i') * S +// S1: X = B + i * S +// S2: Y = B + i' * S => X + (i' - i) * S // -// or +// S1: X = (B + i) * S +// S2: Y = (B + i') * S => X + (i' - i) * S // // S1: X = &B[i * S] -// S2: Y = &B[i' * S] -// -// and S1 dominates S2, we call S1 a basis of S2, and can replace S2 with -// -// Y = X + (i' - i) * S +// S2: Y = &B[i' * S] => &X[(i' - i) * S] // -// or +// Note: (i' - i) * S is folded to the extent possible. // -// Y = &X[(i' - i) * S] -// -// where (i' - i) * S is folded to the extent possible. When S2 has multiple -// bases, we pick the one that is closest to S2, or S2's "immediate" basis. +// When such rewriting is possible, we call S1 a "basis" of S2. When S2 has +// multiple bases, we choose to rewrite S2 with respect to the closest basis or +// the "immediate" basis. // // TODO: // -// - Handle candidates in the form of B + i * S -// // - Floating point arithmetics when fast math is enabled. // // - SLSR may decrease ILP at the architecture level. 
Targets that are very @@ -79,6 +75,7 @@ struct Candidate : public ilist_node { enum Kind { Invalid, // reserved for the default constructor + Add, // B + i * S Mul, // (B + i) * S GEP, // &B[..][i * S][..] }; @@ -146,6 +143,12 @@ // Checks whether I is in a candidate form. If so, adds all the matching forms // to Candidates, and tries to find the immediate basis for each of them. void allocateCandidateAndFindBasis(Instruction *I); + // Allocate candidates and find bases for Add instructions. + void allocateCandidateAndFindBasisForAdd(Instruction *I); + // Given I = LHS + RHS, factors RHS into i * S and makes (LHS + i * S) a + // candidate. + void allocateCandidateAndFindBasisForAdd(Value *LHS, Value *RHS, + Instruction *I); // Allocate candidates and find bases for Mul instructions. void allocateCandidateAndFindBasisForMul(Instruction *I); // Splits LHS into Base + Index and, if succeeds, calls @@ -215,9 +218,9 @@ Basis.CandidateKind == C.CandidateKind); } -static bool isCompletelyFoldable(GetElementPtrInst *GEP, - const TargetTransformInfo *TTI, - const DataLayout *DL) { +static bool isGEPFoldable(GetElementPtrInst *GEP, + const TargetTransformInfo *TTI, + const DataLayout *DL) { GlobalVariable *BaseGV = nullptr; int64_t BaseOffset = 0; bool HasBaseReg = false; @@ -252,6 +255,13 @@ BaseOffset, HasBaseReg, Scale); } +// Returns whether (Base + Index * Stride) can be folded to an addressing mode. +static bool isAddFoldable(const SCEV *Base, ConstantInt *Index, Value *Stride, + TargetTransformInfo *TTI) { + return TTI->isLegalAddressingMode(Base->getType(), nullptr, 0, true, + Index->getSExtValue()); +} + // TODO: We currently implement an algorithm whose time complexity is linear to // the number of existing candidates. However, a better algorithm exists. 
We // could depth-first search the dominator tree, and maintain a hash table that @@ -265,7 +275,12 @@ if (GetElementPtrInst *GEP = dyn_cast(I)) { // If &B[Idx * S] fits into an addressing mode, do not turn it into // non-free computation. - if (isCompletelyFoldable(GEP, TTI, DL)) + if (isGEPFoldable(GEP, TTI, DL)) + return; + } + if (CT == Candidate::Add) { + // Similarly, bail out if (B + Idx * S) is free. + if (isAddFoldable(B, Idx, S, TTI)) return; } @@ -289,6 +304,9 @@ void StraightLineStrengthReduce::allocateCandidateAndFindBasis(Instruction *I) { switch (I->getOpcode()) { + case Instruction::Add: + allocateCandidateAndFindBasisForAdd(I); + break; case Instruction::Mul: allocateCandidateAndFindBasisForMul(I); break; @@ -298,6 +316,37 @@ } } +void StraightLineStrengthReduce::allocateCandidateAndFindBasisForAdd( + Instruction *I) { + // Try matching B + i * S. + if (!isa(I->getType())) + return; + + Value *LHS = I->getOperand(0), *RHS = I->getOperand(1); + allocateCandidateAndFindBasisForAdd(LHS, RHS, I); + if (LHS != RHS) + allocateCandidateAndFindBasisForAdd(RHS, LHS, I); +} + +void StraightLineStrengthReduce::allocateCandidateAndFindBasisForAdd( + Value *LHS, Value *RHS, Instruction *I) { + Value *S = nullptr; + ConstantInt *Idx = nullptr; + if (match(RHS, m_Mul(m_Value(S), m_ConstantInt(Idx)))) { + // I = LHS + RHS = LHS + Idx * S + allocateCandidateAndFindBasis(Candidate::Add, SE->getSCEV(LHS), Idx, S, I); + } else if (match(RHS, m_Shl(m_Value(S), m_ConstantInt(Idx)))) { + // I = LHS + RHS = LHS + (S << Idx) = LHS + S * (1 << Idx) + APInt One(Idx->getBitWidth(), 1); + Idx = ConstantInt::get(Idx->getContext(), One << Idx->getValue()); + allocateCandidateAndFindBasis(Candidate::Add, SE->getSCEV(LHS), Idx, S, I); + } else { + // At least, I = LHS + 1 * RHS + ConstantInt *One = ConstantInt::get(cast(I->getType()), 1); + allocateCandidateAndFindBasis(Candidate::Add, SE->getSCEV(LHS), One, RHS, I); + } +} + void 
StraightLineStrengthReduce::allocateCandidateAndFindBasisForMul( Value *LHS, Value *RHS, Instruction *I) { Value *B = nullptr; @@ -442,6 +491,7 @@ else BumpWithUglyGEP = true; } + // Compute Bump = C - Basis = (i' - i) * S. // Common case 1: if (i' - i) is 1, Bump = S. if (IndexOffset.getSExtValue() == 1) @@ -449,9 +499,18 @@ // Common case 2: if (i' - i) is -1, Bump = -S. if (IndexOffset.getSExtValue() == -1) return Builder.CreateNeg(C.Stride); - // Otherwise, Bump = (i' - i) * sext/trunc(S). - ConstantInt *Delta = ConstantInt::get(Basis.Ins->getContext(), IndexOffset); - Value *ExtendedStride = Builder.CreateSExtOrTrunc(C.Stride, Delta->getType()); + + // Otherwise, Bump = (i' - i) * sext/trunc(S). Note that (i' - i) and S may + // have different bit widths. + IntegerType *DeltaType = + IntegerType::get(Basis.Ins->getContext(), IndexOffset.getBitWidth()); + Value *ExtendedStride = Builder.CreateSExtOrTrunc(C.Stride, DeltaType); + if (IndexOffset.isPowerOf2()) { + // If (i' - i) is a power of 2, Bump = sext/trunc(S) << log(i' - i). + ConstantInt *Exponent = ConstantInt::get(DeltaType, IndexOffset.logBase2()); + return Builder.CreateShl(ExtendedStride, Exponent); + } + Constant *Delta = ConstantInt::get(DeltaType, IndexOffset); return Builder.CreateMul(ExtendedStride, Delta); } @@ -472,6 +531,7 @@ Value *Bump = emitBump(Basis, C, Builder, DL, BumpWithUglyGEP); Value *Reduced = nullptr; // equivalent to but weaker than C.Ins switch (C.CandidateKind) { + case Candidate::Add: case Candidate::Mul: Reduced = Builder.CreateAdd(Basis.Ins, Bump); break; Index: test/Transforms/StraightLineStrengthReduce/X86/no-slsr.ll =================================================================== --- test/Transforms/StraightLineStrengthReduce/X86/no-slsr.ll +++ test/Transforms/StraightLineStrengthReduce/X86/no-slsr.ll @@ -5,8 +5,8 @@ ; Do not perform SLSR on &input[s] and &input[s * 2] which fit into addressing ; modes of X86. 
-define i32 @slsr_gep(i32* %input, i64 %s) { -; CHECK-LABEL: @slsr_gep( +define i32 @no_slsr_gep(i32* %input, i64 %s) { +; CHECK-LABEL: @no_slsr_gep( ; v0 = input[0]; %p0 = getelementptr inbounds i32, i32* %input, i64 0 %v0 = load i32, i32* %p0 @@ -28,3 +28,17 @@ ret i32 %2 } +define void @no_slsr_add(i32 %b, i32 %s) { +; CHECK-LABEL: @no_slsr_add( + %1 = add i32 %b, %s +; CHECK: add i32 %b, %s + call void @foo(i32 %1) + %s2 = mul i32 %s, 2 +; CHECK: %s2 = mul i32 %s, 2 + %2 = add i32 %b, %s2 +; CHECK: add i32 %b, %s2 + call void @foo(i32 %2) + ret void +} + +declare void @foo(i32 %a) Index: test/Transforms/StraightLineStrengthReduce/slsr-add.ll =================================================================== --- /dev/null +++ test/Transforms/StraightLineStrengthReduce/slsr-add.ll @@ -0,0 +1,52 @@ +; RUN: opt < %s -slsr -gvn -dce -S | FileCheck %s + +target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64" + +define void @shl(i32 %b, i32 %s) { +; CHECK-LABEL: @shl( + %1 = add i32 %b, %s +; [[BASIS:%[a-zA-Z0-9]+]] = add i32 %b, %s + call void @foo(i32 %1) + %s2 = shl i32 %s, 1 + %2 = add i32 %b, %s2 +; add i32 [[BASIS]], %s + call void @foo(i32 %2) + ret void +} + +define void @stride_is_2s(i32 %b, i32 %s) { +; CHECK-LABEL: @stride_is_2s( + %s2 = shl i32 %s, 1 +; CHECK: %s2 = shl i32 %s, 1 + %1 = add i32 %b, %s2 +; CHECK: [[t1:%[a-zA-Z0-9]+]] = add i32 %b, %s2 + call void @foo(i32 %1) + %s4 = shl i32 %s, 2 + %2 = add i32 %b, %s4 +; CHECK: [[t2:%[a-zA-Z0-9]+]] = add i32 [[t1]], %s2 + call void @foo(i32 %2) + %s6 = mul i32 %s, 6 + %3 = add i32 %b, %s6 +; CHECK: add i32 [[t2]], %s2 + call void @foo(i32 %3) + ret void +} + +define void @stride_is_3s(i32 %b, i32 %s) { +; CHECK-LABEL: @stride_is_3s( + %1 = add i32 %s, %b +; CHECK: [[t1:%[a-zA-Z0-9]+]] = add i32 %s, %b + call void @foo(i32 %1) + %s4 = shl i32 %s, 2 + %2 = add i32 %s4, %b +; CHECK: [[bump:%[a-zA-Z0-9]+]] = mul i32 %s, 3 +; CHECK: [[t2:%[a-zA-Z0-9]+]] = add i32 [[t1]], [[bump]] + call void @foo(i32 %2) + 
%s7 = mul i32 %s, 7 + %3 = add i32 %s7, %b +; CHECK: add i32 [[t2]], [[bump]] + call void @foo(i32 %3) + ret void +} + +declare void @foo(i32 %a) Index: test/Transforms/StraightLineStrengthReduce/slsr-mul.ll =================================================================== --- test/Transforms/StraightLineStrengthReduce/slsr-mul.ll +++ test/Transforms/StraightLineStrengthReduce/slsr-mul.ll @@ -79,7 +79,7 @@ %b1 = add i32 %b, 2 %mul1 = mul i32 %b1, %s -; CHECK: [[BUMP:%[a-zA-Z0-9]+]] = mul i32 %s, 2 +; CHECK: [[BUMP:%[a-zA-Z0-9]+]] = shl i32 %s, 1 ; CHECK: %mul1 = add i32 %mul0, [[BUMP]] %v1 = call i32 @foo(i32 %mul1)