Index: lib/Transforms/Scalar/StraightLineStrengthReduce.cpp =================================================================== --- lib/Transforms/Scalar/StraightLineStrengthReduce.cpp +++ lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -15,37 +15,33 @@ // // There are many optimizations we can perform in the domain of SLSR. This file // for now contains only an initial step. Specifically, we look for strength -// reduction candidates in two forms: +// reduction candidates in the following forms: // -// Form 1: (B + i) * S -// Form 2: &B[i * S] +// Form 1: B + i * S +// Form 2: (B + i) * S +// Form 3: &B[i * S] // // where S is an integer variable, and i is a constant integer. If we found two -// candidates +// candidates S1 and S2 in the same form and S1 dominates S2, we may be able to +// rewrite S2 in a simpler way with respect to S1. For example, // -// S1: X = (B + i) * S -// S2: Y = (B + i') * S +// S1: X = B + i * S +// S2: Y = B + i' * S => X + (i' - i) * S // -// or +// S1: X = (B + i) * S +// S2: Y = (B + i') * S => X + (i' - i) * S // // S1: X = &B[i * S] -// S2: Y = &B[i' * S] -// -// and S1 dominates S2, we call S1 a basis of S2, and can replace S2 with +// S2: Y = &B[i' * S] => &X[(i' - i) * S] // -// Y = X + (i' - i) * S +// Note: (i' - i) * S is folded to the extent possible. // -// or -// -// Y = &X[(i' - i) * S] -// -// where (i' - i) * S is folded to the extent possible. When S2 has multiple -// bases, we pick the one that is closest to S2, or S2's "immediate" basis. +// When such rewriting is possible, we call S1 a "basis" of S2. When S2 has +// multiple bases, we choose to rewrite S2 with respect to the closest basis or +// the "immediate" basis. // // TODO: // -// - Handle candidates in the form of B + i * S -// // - Floating point arithmetics when fast math is enabled. // // - SLSR may decrease ILP at the architecture level.
Targets that are very
@@ -79,6 +75,7 @@
   struct Candidate : public ilist_node<Candidate> {
     enum Kind {
       Invalid, // reserved for the default constructor
+      Add,     // B + i * S
       Mul,     // (B + i) * S
       GEP,     // &B[..][i * S][..]
     };
@@ -146,6 +143,12 @@
   // Checks whether I is in a candidate form. If so, adds all the matching forms
   // to Candidates, and tries to find the immediate basis for each of them.
   void allocateCandidateAndFindBasis(Instruction *I);
+  // Allocate candidates and find bases for Add instructions.
+  void allocateCandidateAndFindBasisForAdd(Instruction *I);
+  // Given I = LHS + RHS, factors RHS into i * S and makes (LHS + i * S) a
+  // candidate.
+  void allocateCandidateAndFindBasisForAdd(Value *LHS, Value *RHS,
+                                           Instruction *I);
   // Allocate candidates and find bases for Mul instructions.
   void allocateCandidateAndFindBasisForMul(Instruction *I);
   // Splits LHS into Base + Index and, if succeeds, calls
@@ -289,6 +292,9 @@
 
 void StraightLineStrengthReduce::allocateCandidateAndFindBasis(Instruction *I) {
   switch (I->getOpcode()) {
+  case Instruction::Add:
+    allocateCandidateAndFindBasisForAdd(I);
+    break;
   case Instruction::Mul:
     allocateCandidateAndFindBasisForMul(I);
     break;
@@ -298,6 +304,37 @@
   }
 }
 
+void StraightLineStrengthReduce::allocateCandidateAndFindBasisForAdd(
+    Instruction *I) {
+  // Try matching B + i * S.
+  if (!isa<IntegerType>(I->getType()))
+    return;
+
+  Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
+  allocateCandidateAndFindBasisForAdd(LHS, RHS, I);
+  if (LHS != RHS)
+    allocateCandidateAndFindBasisForAdd(RHS, LHS, I);
+}
+
+void StraightLineStrengthReduce::allocateCandidateAndFindBasisForAdd(
+    Value *LHS, Value *RHS, Instruction *I) {
+  Value *S = nullptr;
+  ConstantInt *Idx = nullptr;
+  if (match(RHS, m_Mul(m_Value(S), m_ConstantInt(Idx)))) {
+    // I = LHS + RHS = LHS + Idx * S
+    allocateCandidateAndFindBasis(Candidate::Add, SE->getSCEV(LHS), Idx, S, I);
+  } else if (match(RHS, m_Shl(m_Value(S), m_ConstantInt(Idx)))) {
+    // I = LHS + RHS = LHS + (S << Idx) = LHS + S * (1 << Idx)
+    APInt One(Idx->getBitWidth(), 1);
+    Idx = ConstantInt::get(Idx->getContext(), One << Idx->getValue());
+    allocateCandidateAndFindBasis(Candidate::Add, SE->getSCEV(LHS), Idx, S, I);
+  } else {
+    // At least, I = LHS + 1 * RHS
+    ConstantInt *One = ConstantInt::get(cast<IntegerType>(I->getType()), 1);
+    allocateCandidateAndFindBasis(Candidate::Add, SE->getSCEV(LHS), One, RHS, I);
+  }
+}
+
 void StraightLineStrengthReduce::allocateCandidateAndFindBasisForMul(
     Value *LHS, Value *RHS, Instruction *I) {
   Value *B = nullptr;
@@ -442,6 +479,7 @@
     else
       BumpWithUglyGEP = true;
   }
+
   // Compute Bump = C - Basis = (i' - i) * S.
   // Common case 1: if (i' - i) is 1, Bump = S.
   if (IndexOffset.getSExtValue() == 1)
@@ -449,9 +487,18 @@
   // Common case 2: if (i' - i) is -1, Bump = -S.
   if (IndexOffset.getSExtValue() == -1)
     return Builder.CreateNeg(C.Stride);
-  // Otherwise, Bump = (i' - i) * sext/trunc(S).
-  ConstantInt *Delta = ConstantInt::get(Basis.Ins->getContext(), IndexOffset);
-  Value *ExtendedStride = Builder.CreateSExtOrTrunc(C.Stride, Delta->getType());
+
+  // Otherwise, Bump = (i' - i) * sext/trunc(S). Note that (i' - i) and S may
+  // have different bit widths.
+  IntegerType *DeltaType =
+      IntegerType::get(Basis.Ins->getContext(), IndexOffset.getBitWidth());
+  Value *ExtendedStride = Builder.CreateSExtOrTrunc(C.Stride, DeltaType);
+  if (IndexOffset.isPowerOf2()) {
+    // If (i' - i) is a power of 2, Bump = sext/trunc(S) << log(i' - i).
+    ConstantInt *Exponent = ConstantInt::get(DeltaType, IndexOffset.logBase2());
+    return Builder.CreateShl(ExtendedStride, Exponent);
+  }
+  Constant *Delta = ConstantInt::get(DeltaType, IndexOffset);
   return Builder.CreateMul(ExtendedStride, Delta);
 }
 
@@ -472,6 +519,7 @@
   Value *Bump = emitBump(Basis, C, Builder, DL, BumpWithUglyGEP);
   Value *Reduced = nullptr; // equivalent to but weaker than C.Ins
   switch (C.CandidateKind) {
+  case Candidate::Add:
   case Candidate::Mul:
     Reduced = Builder.CreateAdd(Basis.Ins, Bump);
     break;
Index: test/Transforms/StraightLineStrengthReduce/slsr-add.ll
===================================================================
--- /dev/null
+++ test/Transforms/StraightLineStrengthReduce/slsr-add.ll
@@ -0,0 +1,52 @@
+; RUN: opt < %s -slsr -gvn -dce -S | FileCheck %s
+
+target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64"
+
+define void @shl(i32 %b, i32 %s) {
+; CHECK-LABEL: @shl(
+  %1 = add i32 %b, %s
+; CHECK: [[BASIS:%[a-zA-Z0-9]+]] = add i32 %b, %s
+  call void @foo(i32 %1)
+  %s2 = shl i32 %s, 1
+  %2 = add i32 %b, %s2
+; CHECK: add i32 [[BASIS]], %s
+  call void @foo(i32 %2)
+  ret void
+}
+
+define void @stride_is_2s(i32 %b, i32 %s) {
+; CHECK-LABEL: @stride_is_2s(
+  %s2 = shl i32 %s, 1
+; CHECK: %s2 = shl i32 %s, 1
+  %1 = add i32 %b, %s2
+; CHECK: [[t1:%[a-zA-Z0-9]+]] = add i32 %b, %s2
+  call void @foo(i32 %1)
+  %s4 = shl i32 %s, 2
+  %2 = add i32 %b, %s4
+; CHECK: [[t2:%[a-zA-Z0-9]+]] = add i32 [[t1]], %s2
+  call void @foo(i32 %2)
+  %s6 = mul i32 %s, 6
+  %3 = add i32 %b, %s6
+; CHECK: add i32 [[t2]], %s2
+  call void @foo(i32 %3)
+  ret void
+}
+
+define void @stride_is_3s(i32 %b, i32 %s) {
+; CHECK-LABEL: @stride_is_3s(
+  %1 = add i32 %s, %b
+; CHECK: [[t1:%[a-zA-Z0-9]+]] = add i32 %s, %b
+  call void @foo(i32 %1)
+  %s4 = shl i32 %s, 2
+  %2 = add i32 %s4, %b
+; CHECK: [[bump:%[a-zA-Z0-9]+]] = mul i32 %s, 3
+; CHECK: [[t2:%[a-zA-Z0-9]+]] = add i32 [[t1]], [[bump]]
+  call void @foo(i32 %2)
+  %s7 = mul i32 %s, 7
+  %3 = add i32 %s7, %b
+; CHECK: add i32 [[t2]], [[bump]]
+  call void @foo(i32 %3)
+  ret void
+}
+
+declare void @foo(i32 %a)
Index: test/Transforms/StraightLineStrengthReduce/slsr-mul.ll
===================================================================
--- test/Transforms/StraightLineStrengthReduce/slsr-mul.ll
+++ test/Transforms/StraightLineStrengthReduce/slsr-mul.ll
@@ -79,7 +79,7 @@
 
   %b1 = add i32 %b, 2
   %mul1 = mul i32 %b1, %s
-; CHECK: [[BUMP:%[a-zA-Z0-9]+]] = mul i32 %s, 2
+; CHECK: [[BUMP:%[a-zA-Z0-9]+]] = shl i32 %s, 1
 ; CHECK: %mul1 = add i32 %mul0, [[BUMP]]
   %v1 = call i32 @foo(i32 %mul1)
 