Changeset View
Changeset View
Standalone View
Standalone View
llvm/lib/Target/X86/X86PartialReduction.cpp
Show All 13 Lines | |||||
#include "X86.h" | #include "X86.h" | ||||
#include "X86TargetMachine.h" | #include "X86TargetMachine.h" | ||||
#include "llvm/Analysis/ValueTracking.h" | #include "llvm/Analysis/ValueTracking.h" | ||||
#include "llvm/CodeGen/TargetPassConfig.h" | #include "llvm/CodeGen/TargetPassConfig.h" | ||||
#include "llvm/IR/Constants.h" | #include "llvm/IR/Constants.h" | ||||
#include "llvm/IR/IRBuilder.h" | #include "llvm/IR/IRBuilder.h" | ||||
#include "llvm/IR/Instructions.h" | #include "llvm/IR/Instructions.h" | ||||
#include "llvm/IR/IntrinsicInst.h" | |||||
#include "llvm/IR/IntrinsicsX86.h" | #include "llvm/IR/IntrinsicsX86.h" | ||||
#include "llvm/IR/Operator.h" | #include "llvm/IR/Operator.h" | ||||
#include "llvm/IR/PatternMatch.h" | |||||
#include "llvm/Pass.h" | #include "llvm/Pass.h" | ||||
#include "llvm/Support/KnownBits.h" | #include "llvm/Support/KnownBits.h" | ||||
using namespace llvm; | using namespace llvm; | ||||
#define DEBUG_TYPE "x86-partial-reduction" | #define DEBUG_TYPE "x86-partial-reduction" | ||||
namespace { | namespace { | ||||
▲ Show 20 Lines • Show All 182 Lines • ▼ Show 20 Lines | |||||
bool X86PartialReduction::trySADReplacement(Instruction *Op) { | bool X86PartialReduction::trySADReplacement(Instruction *Op) { | ||||
if (!ST->hasSSE2()) | if (!ST->hasSSE2()) | ||||
return false; | return false; | ||||
// TODO: There's nothing special about i32, any integer type above i16 should | // TODO: There's nothing special about i32, any integer type above i16 should | ||||
// work just as well. | // work just as well. | ||||
if (!cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(32)) | if (!cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(32)) | ||||
return false; | return false; | ||||
Value *LHS; | |||||
craig.topper: Sink RHS into the `else` so it doesn't get accidentally used. | |||||
if (match(Op, PatternMatch::m_Intrinsic<Intrinsic::abs>())) { | |||||
LHS = Op->getOperand(0); | |||||
} else { | |||||
// Operand should be a select. | // Operand should be a select. | ||||
auto *SI = dyn_cast<SelectInst>(Op); | auto *SI = dyn_cast<SelectInst>(Op); | ||||
if (!SI) | if (!SI) | ||||
return false; | return false; | ||||
Value *RHS; | |||||
// Select needs to implement absolute value. | // Select needs to implement absolute value. | ||||
Value *LHS, *RHS; | |||||
auto SPR = matchSelectPattern(SI, LHS, RHS); | auto SPR = matchSelectPattern(SI, LHS, RHS); | ||||
if (SPR.Flavor != SPF_ABS) | if (SPR.Flavor != SPF_ABS) | ||||
return false; | return false; | ||||
} | |||||
Add curly braces to be consistent with the else. See the 4th example here https://llvm.org/docs/CodingStandards.html#don-t-use-braces-on-simple-single-statement-bodies-of-if-else-loop-statements craig.topper: Add curly braces to be consistent with the else. See the 4th example here https://llvm. | |||||
// Need a subtract of two values. | // Need a subtract of two values. | ||||
auto *Sub = dyn_cast<BinaryOperator>(LHS); | auto *Sub = dyn_cast<BinaryOperator>(LHS); | ||||
if (!Sub || Sub->getOpcode() != Instruction::Sub) | if (!Sub || Sub->getOpcode() != Instruction::Sub) | ||||
return false; | return false; | ||||
// Look for zero extend from i8. | // Look for zero extend from i8. | ||||
auto getZeroExtendedVal = [](Value *Op) -> Value * { | auto getZeroExtendedVal = [](Value *Op) -> Value * { | ||||
if (auto *ZExt = dyn_cast<ZExtInst>(Op)) | if (auto *ZExt = dyn_cast<ZExtInst>(Op)) | ||||
if (cast<VectorType>(ZExt->getOperand(0)->getType()) | if (cast<VectorType>(ZExt->getOperand(0)->getType()) | ||||
->getElementType() | ->getElementType() | ||||
->isIntegerTy(8)) | ->isIntegerTy(8)) | ||||
return ZExt->getOperand(0); | return ZExt->getOperand(0); | ||||
return nullptr; | return nullptr; | ||||
}; | }; | ||||
// Both operands of the subtract should be extends from vXi8. | // Both operands of the subtract should be extends from vXi8. | ||||
Value *Op0 = getZeroExtendedVal(Sub->getOperand(0)); | Value *Op0 = getZeroExtendedVal(Sub->getOperand(0)); | ||||
Value *Op1 = getZeroExtendedVal(Sub->getOperand(1)); | Value *Op1 = getZeroExtendedVal(Sub->getOperand(1)); | ||||
if (!Op0 || !Op1) | if (!Op0 || !Op1) | ||||
return false; | return false; | ||||
IRBuilder<> Builder(SI); | IRBuilder<> Builder(Op); | ||||
auto *OpTy = cast<FixedVectorType>(Op->getType()); | auto *OpTy = cast<FixedVectorType>(Op->getType()); | ||||
unsigned NumElts = OpTy->getNumElements(); | unsigned NumElts = OpTy->getNumElements(); | ||||
unsigned IntrinsicNumElts; | unsigned IntrinsicNumElts; | ||||
Intrinsic::ID IID; | Intrinsic::ID IID; | ||||
if (ST->hasBWI() && NumElts >= 64) { | if (ST->hasBWI() && NumElts >= 64) { | ||||
IID = Intrinsic::x86_avx512_psad_bw_512; | IID = Intrinsic::x86_avx512_psad_bw_512; | ||||
IntrinsicNumElts = 64; | IntrinsicNumElts = 64; | ||||
} else if (ST->hasAVX2() && NumElts >= 32) { | } else if (ST->hasAVX2() && NumElts >= 32) { | ||||
IID = Intrinsic::x86_avx2_psad_bw; | IID = Intrinsic::x86_avx2_psad_bw; | ||||
IntrinsicNumElts = 32; | IntrinsicNumElts = 32; | ||||
} else { | } else { | ||||
IID = Intrinsic::x86_sse2_psad_bw; | IID = Intrinsic::x86_sse2_psad_bw; | ||||
IntrinsicNumElts = 16; | IntrinsicNumElts = 16; | ||||
} | } | ||||
Function *PSADBWFn = Intrinsic::getDeclaration(SI->getModule(), IID); | Function *PSADBWFn = Intrinsic::getDeclaration(Op->getModule(), IID); | ||||
if (NumElts < 16) { | if (NumElts < 16) { | ||||
// Pad input with zeroes. | // Pad input with zeroes. | ||||
SmallVector<int, 32> ConcatMask(16); | SmallVector<int, 32> ConcatMask(16); | ||||
for (unsigned i = 0; i != NumElts; ++i) | for (unsigned i = 0; i != NumElts; ++i) | ||||
ConcatMask[i] = i; | ConcatMask[i] = i; | ||||
for (unsigned i = NumElts; i != 16; ++i) | for (unsigned i = NumElts; i != 16; ++i) | ||||
ConcatMask[i] = (i % NumElts) + NumElts; | ConcatMask[i] = (i % NumElts) + NumElts; | ||||
▲ Show 20 Lines • Show All 48 Lines • ▼ Show 20 Lines | for (unsigned i = 0; i != SubElts; ++i) | ||||
ConcatMask[i] = i; | ConcatMask[i] = i; | ||||
for (unsigned i = SubElts; i != NumElts; ++i) | for (unsigned i = SubElts; i != NumElts; ++i) | ||||
ConcatMask[i] = (i % SubElts) + SubElts; | ConcatMask[i] = (i % SubElts) + SubElts; | ||||
Value *Zero = Constant::getNullValue(Ops[0]->getType()); | Value *Zero = Constant::getNullValue(Ops[0]->getType()); | ||||
Ops[0] = Builder.CreateShuffleVector(Ops[0], Zero, ConcatMask); | Ops[0] = Builder.CreateShuffleVector(Ops[0], Zero, ConcatMask); | ||||
} | } | ||||
SI->replaceAllUsesWith(Ops[0]); | Op->replaceAllUsesWith(Ops[0]); | ||||
SI->eraseFromParent(); | Op->eraseFromParent(); | ||||
return true; | return true; | ||||
} | } | ||||
// Walk backwards from the ExtractElementInst and determine if it is the end of | // Walk backwards from the ExtractElementInst and determine if it is the end of | ||||
// a horizontal reduction. Return the input to the reduction if we find one. | // a horizontal reduction. Return the input to the reduction if we find one. | ||||
static Value *matchAddReduction(const ExtractElementInst &EE, | static Value *matchAddReduction(const ExtractElementInst &EE, | ||||
bool &ReduceInOneBB) { | bool &ReduceInOneBB) { | ||||
▲ Show 20 Lines • Show All 190 Lines • Show Last 20 Lines |
Sink RHS into the else so it doesn't get accidentally used.