Index: include/llvm/IR/PromoteHalf.h =================================================================== --- include/llvm/IR/PromoteHalf.h +++ include/llvm/IR/PromoteHalf.h @@ -0,0 +1,21 @@ +//===- PromoteHalf.h - LLVM half float transformation ----------------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_IR_PROMOTE_HALF_H +#define LLVM_IR_PROMOTE_HALF_H + +namespace llvm { + +class ModulePass; + +ModulePass *createPromoteHalfPass(); + +} // End llvm namespace + +#endif Index: include/llvm/InitializePasses.h =================================================================== --- include/llvm/InitializePasses.h +++ include/llvm/InitializePasses.h @@ -201,6 +201,7 @@ void initializeMetaRenamerPass(PassRegistry&); void initializeMergeFunctionsPass(PassRegistry&); void initializeModuleDebugInfoPrinterPass(PassRegistry&); +void initializePromoteHalfPass(PassRegistry&); void initializeNoAAPass(PassRegistry&); void initializeObjCARCAliasAnalysisPass(PassRegistry&); void initializeObjCARCAPElimPass(PassRegistry&); Index: lib/IR/CMakeLists.txt =================================================================== --- lib/IR/CMakeLists.txt +++ lib/IR/CMakeLists.txt @@ -44,6 +44,7 @@ Value.cpp ValueSymbolTable.cpp ValueTypes.cpp + PromoteHalf.cpp Verifier.cpp ) Index: lib/IR/Core.cpp =================================================================== --- lib/IR/Core.cpp +++ lib/IR/Core.cpp @@ -50,6 +50,7 @@ initializePrintFunctionPassWrapperPass(Registry); initializePrintBasicBlockPassPass(Registry); initializeVerifierLegacyPassPass(Registry); + initializePromoteHalfPass(Registry); } void LLVMInitializeCore(LLVMPassRegistryRef R) { Index: lib/IR/PromoteHalf.cpp =================================================================== --- 
lib/IR/PromoteHalf.cpp +++ lib/IR/PromoteHalf.cpp @@ -0,0 +1,1084 @@ +//===-- PromoteHalf.cpp - Implement the conversion for fp16 operations -----==// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// When Opts.NativeHalfType is set to 1, Clang will generate IR for half float +// Currently, most CPUs don't support fp16 register, fp16 should be promoted +// to float for computation, then store back the result to fp16. +// +// This is a module pass to transform operations on fp16 to operations on fp32 +// Backend could add this pass as needed. +//===----------------------------------------------------------------------===// + +#include "LLVMContextImpl.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PromoteHalf.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/PassManager.h" + +using namespace llvm; + +namespace { + +class Fp16InstVisitor : public InstVisitor { + friend class InstVisitor; + + LLVMContext *Ctx; + Module *M; + +public: + explicit Fp16InstVisitor() : Ctx(nullptr) {} + + bool transformFunction(Function &F) { + M = F.getParent(); + Ctx = &M->getContext(); + + for (Function::const_iterator I = F.begin(), E = F.end(); I != E; ++I) { + if (I->empty() || !I->back().isTerminator()) { + return true; + } + } + + visit(const_cast(F)); + return true; + } + +private: + // Below functions convert fp16 vector to fp32/fp64 vector or convert fp32/fp64 vector to fp16 vector. + // if operands are fp16 vector, they need to be converted to fp32/fp64 vector. + // There are effective solution to do such vector conversion between fp16 and fp32/fp64, implement them in future. 
+ Value* convertFp16VToFp32V(Value *src, Instruction *I); + Value* convertFp32VToFp16V(Value *src, Instruction *I); + + Value* convertFp16VToFp64V(Value *src, Instruction *I); + Value* convertFp64VToFp16V(Value *src, Instruction *I); + + // Create a vector of fp16 with specified size + VectorType* getFp16Vector(unsigned size); + + // Below two functions are used to do fp16<--->fp32 conversion + // There are many ways to do such conversions, lib calls or intrinsic calls + Function* getFp16ToFp32() { + return llvm::Intrinsic::getDeclaration(M, llvm::Intrinsic::convert_from_fp16, Type::getFloatTy(*Ctx)); + } + Function* getFp32ToFp16() { + return llvm::Intrinsic::getDeclaration(M, llvm::Intrinsic::convert_to_fp16, Type::getFloatTy(*Ctx)); + } + +private: + // InstVisitor overrides... + using InstVisitor::visit; + + void visitFunction(Function &F); + void visitBasicBlock(BasicBlock &BB); + void visit(Instruction &I); + + // Terminator Instructions + // 'ret' Instruction + // 'br' Instruction + // 'switch' Instruction + // 'indirectbr' Instruction + // 'invoke' Instruction + // 'resume' Instruction + // 'unreachable' Instruction + void visitReturnInst(ReturnInst &RI); + void visitBranchInst(BranchInst &BI); + void visitSwitchInst(SwitchInst &SI); + void visitIndirectBrInst(IndirectBrInst &BI); + void visitInvokeInst(InvokeInst &II); + + // Binary Operations + // 'add' Instruction + // 'fadd' Instruction + // 'sub' Instruction + // 'fsub' Instruction + // 'mul' Instruction + // 'fmul' Instruction + // 'udiv' Instruction + // 'sdiv' Instruction + // 'fdiv' Instruction + // 'urem' Instruction + // 'srem' Instruction + // 'frem' Instruction + + // Bitwise Binary Operations + // 'shl' Instruction + // 'lshr' Instruction + // 'ashr' Instruction + // 'and' Instruction + // 'or' Instruction + // 'xor' Instruction + void visitBinaryOperator(BinaryOperator &B); + + // Vector Operations + // 'extractelement' Instruction + // 'insertelement' Instruction + // 'shufflevector' 
Instruction + void visitExtractElementInst(ExtractElementInst &EI); + void visitInsertElementInst(InsertElementInst &EI); + void visitShuffleVectorInst(ShuffleVectorInst &EI); + + // Aggregate Operations + // 'extractvalue' Instruction + // 'insertvalue' Instruction + void visitExtractValueInst(ExtractValueInst &EVI); + void visitInsertValueInst(InsertValueInst &IVI); + + // Memory Access and Addressing Operations + // 'alloca' Instruction + // 'load' Instruction + // 'store' Instruction + // 'fence' Instruction + // 'cmpxchg' Instruction + // 'atomicrmw' Instruction + // 'getelementptr' Instruction + void visitAllocaInst(AllocaInst &AI); + void visitLoadInst(LoadInst &LI); + void visitStoreInst(StoreInst &SI); + void visitFenceInst(FenceInst &FI); + void visitAtomicCmpXchgInst(AtomicCmpXchgInst &CXI); + void visitAtomicRMWInst(AtomicRMWInst &RMWI); + void visitGetElementPtrInst(GetElementPtrInst &GEP); + + // Conversion Operations + // 'trunc .. to' Instruction + // 'zext .. to' Instruction + // 'sext .. to' Instruction + // 'fptrunc .. to' Instruction + // 'fpext .. to' Instruction + // 'fptoui .. to' Instruction + // 'fptosi .. to' Instruction + // 'uitofp .. to' Instruction + // 'sitofp .. to' Instruction + // 'ptrtoint .. to' Instruction + // 'inttoptr .. to' Instruction + // 'bitcast .. to' Instruction + // 'addrspacecast .. 
to' Instruction + void visitTruncInst(TruncInst &I); + void visitZExtInst(ZExtInst &I); + void visitSExtInst(SExtInst &I); + void visitFPTruncInst(FPTruncInst &I); + void visitFPExtInst(FPExtInst &I); + void visitFPToUIInst(FPToUIInst &I); + void visitFPToSIInst(FPToSIInst &I); + void visitUIToFPInst(UIToFPInst &I); + void visitSIToFPInst(SIToFPInst &I); + void visitIntToPtrInst(IntToPtrInst &I); + void visitPtrToIntInst(PtrToIntInst &I); + void visitBitCastInst(BitCastInst &I); + void visitAddrSpaceCastInst(AddrSpaceCastInst &I); + + // Other Operations + // 'icmp' Instruction + // 'fcmp' Instruction + // 'phi' Instruction + // 'select' Instruction + // 'call' Instruction + // 'va_arg' Instruction + // 'landingpad' Instruction + void visitICmpInst(ICmpInst &IC); + void visitFCmpInst(FCmpInst &FC); + void visitPHINode(PHINode &PN); + void visitSelectInst(SelectInst &SI); + void visitCallInst(CallInst &CI); + void visitVAArgInst(VAArgInst &VAA); + void visitLandingPadInst(LandingPadInst &LPI); +}; +} // End anonymous namespace + +VectorType* Fp16InstVisitor::getFp16Vector(unsigned size) { + LLVMContextImpl *pImpl = Ctx->pImpl; + for (DenseMap, VectorType*>::iterator I = pImpl->VectorTypes.begin(), + E = pImpl->VectorTypes.end(); + I != E; ++I) { + VectorType *vecTy = I->second; + if (I->first.second == size && I->first.first->isHalfTy()) { + return vecTy; + } + } + return VectorType::get(Type::getInt16Ty(*Ctx), size); +} + +// Below functions convert fp16 vector to fp32/fp64 vector or convert fp32/fp64 vector to fp16 vector. +// if operands are fp16 vector, they need to be converted to fp32/fp64 vector. +// There are effective solution to do such vector conversion between fp16 and fp32/fp64, implement them in future. 
+Value* Fp16InstVisitor::convertFp16VToFp32V(Value *src, Instruction *next) { + VectorType *vTy = dyn_cast(src->getType()); + if (vTy == NULL || !vTy->getElementType()->isIntegerTy(16)) { + return NULL; + } + + unsigned size = vTy->getNumElements(); + Function *func = getFp16ToFp32(); + Value *Undef = UndefValue::get(VectorType::get(Type::getFloatTy(*Ctx), size)); + + // extract element from vector and convert i16 to fp32 + Value *extract, *convert, *insertElem = Undef; + for (unsigned i = 0; i < size; i++) { + // extractelement <4 x i16> src, i32 i + extract = ExtractElementInst::Create(src, ConstantInt::get(Type::getInt32Ty(*Ctx), i), "", next); + // call float @llvm.convert.from.fp16(i16 %) + convert = CallInst::Create(func, extract, "", next); + // insertelement % undef, float %3, i32 i + insertElem = InsertElementInst::Create(insertElem, convert, ConstantInt::get(Type::getInt32Ty(*Ctx), i), "", next); + } + // return this vector + return insertElem; +} + +Value* Fp16InstVisitor::convertFp32VToFp16V(Value *src, Instruction *next) { + VectorType *vTy = dyn_cast(src->getType()); + if (vTy == NULL || !vTy->getElementType()->isFloatTy()) { + return NULL; + } + + unsigned size = vTy->getNumElements(); + Function *func = getFp32ToFp16(); + Value *Undef = UndefValue::get(getFp16Vector(size)); + + // extract element from vector and convert fp32 to i16 + Value *extract, *convert, *insertElem = Undef; + for (unsigned i = 0; i < size; i++) { + // extractelement <4 x float> src, i32 i + extract = ExtractElementInst::Create(src, ConstantInt::get(Type::getInt32Ty(*Ctx), i), "", next); + // call i16 @llvm.convert.to.fp16(float %) + convert = CallInst::Create(func, extract, "", next); + // insertelement % undef, i16 %3, i32 i + insertElem = InsertElementInst::Create(insertElem, convert, ConstantInt::get(Type::getInt32Ty(*Ctx), i), "", next); + } + // return this vector + return insertElem; +} + +Value* Fp16InstVisitor::convertFp16VToFp64V(Value *src, Instruction *next) { + 
VectorType *vTy = dyn_cast(src->getType()); + if (vTy == NULL || !vTy->getElementType()->isIntegerTy(16)) { + return NULL; + } + + unsigned size = vTy->getNumElements(); + Function *func = getFp16ToFp32(); + Value *Undef = UndefValue::get(VectorType::get(Type::getDoubleTy(*Ctx), size)); + + // extract element from vector and convert i16 to fp64 + Value *extract, *convert, *insertElem = Undef; + for (unsigned i = 0; i < size; i++) { + // extractelement <4 x i16> src, i32 i + extract = ExtractElementInst::Create(src, ConstantInt::get(Type::getInt32Ty(*Ctx), i), "", next); + + // call float @llvm.convert.from.fp16(i16 %) + convert = CallInst::Create(func, extract, "", next); + + // ext float to double, use double @Intrinsic::convert_from_fp16() in future + Value *extValue = CastInst::CreateFPCast(convert, Type::getDoubleTy(*Ctx), "", next); + + // insertelement % undef, double %3, i32 i + insertElem = InsertElementInst::Create(insertElem, extValue, ConstantInt::get(Type::getInt32Ty(*Ctx), i), "", next); + } + // return this vector + return insertElem; +} + +Value* Fp16InstVisitor::convertFp64VToFp16V(Value *src, Instruction *next) { + VectorType *vTy = dyn_cast(src->getType()); + if (vTy == NULL || !vTy->getElementType()->isDoubleTy()) { + return NULL; + } + + unsigned size = vTy->getNumElements(); + Function *func = getFp32ToFp16(); + Value *Undef = UndefValue::get(getFp16Vector(size)); + + // extract element from vector and convert fp64 to i16 + Value *extract, *convert, *insertElem = Undef; + for (unsigned i = 0; i < size; i++) { + // extractelement <4 x double> src, i32 i + extract = ExtractElementInst::Create(src, ConstantInt::get(Type::getInt32Ty(*Ctx), i), "", next); + + // fptrunc double to float first (an integer trunc is not valid on FP operands), use Intrinsic::convert_to_fp16(double) in future + Value *trunc = CastInst::CreateFPCast(extract, Type::getFloatTy(*Ctx), "", next); + + // call i16 @llvm.convert.to.fp16(float %) + convert = CallInst::Create(func, trunc, "", next); + // insertelement % 
undef, i16 %3, i32 i + insertElem = InsertElementInst::Create(insertElem, convert, ConstantInt::get(Type::getInt32Ty(*Ctx), i), "", next); + } + // return this vector + return insertElem; +} + +void Fp16InstVisitor::visitFunction(Function &F) { + // Won't use half. +} + +void Fp16InstVisitor::visitBasicBlock(BasicBlock &BB) { + // Won't use half. +} + +void Fp16InstVisitor::visit(Instruction &I) { + InstVisitor::visit(I); +} + +/// ------ Terminator Instructions ------/// + +void Fp16InstVisitor::visitReturnInst(ReturnInst &RI) { + Value *ret = RI.getReturnValue(); + if (ret && ret->getType()->isHalfTy()) { + ret->mutateType(Type::getInt16Ty(*Ctx)); + } +} + +void Fp16InstVisitor::visitBranchInst(BranchInst &BI) { + // Won't use half. +} + +void Fp16InstVisitor::visitSwitchInst(SwitchInst &SI) { + // Won't use half. +} + +void Fp16InstVisitor::visitIndirectBrInst(IndirectBrInst &BI) { + // Won't use half. +} + +void Fp16InstVisitor::visitInvokeInst(InvokeInst &II) { + // Won't use half. +} + +/// ------ Binary Operations ------/// + +// fsub/fadd/fmul/fdiv/frem , +void Fp16InstVisitor::visitBinaryOperator(BinaryOperator &B) { + switch (B.getOpcode()) { + // Won't use half type. 
+ // case Instruction::Add: + // case Instruction::Sub: + // case Instruction::Mul: + // case Instruction::SDiv: + // case Instruction::UDiv: + // case Instruction::SRem: + // case Instruction::URem: + // break; + // case Instruction::And: + // case Instruction::Or: + // case Instruction::Xor: + // break; + // case Instruction::Shl: + // case Instruction::LShr: + // case Instruction::AShr: + // break; + + // Check that floating-point arithmetic operands + case Instruction::FAdd: + case Instruction::FSub: + case Instruction::FMul: + case Instruction::FDiv: + case Instruction::FRem: { + Value *value1 = B.getOperand(0); + Value *value2 = B.getOperand(1); + if (value1->getType()->isHalfTy()) { + value1->mutateType(Type::getInt16Ty(*Ctx)); + } + if (value2->getType()->isHalfTy()) { + value2->mutateType(Type::getInt16Ty(*Ctx)); + } + if (value1->getType()->isIntegerTy(16) && value2->getType()->isIntegerTy(16)) { + B.mutateType(Type::getInt16Ty(*Ctx)); + + // call float @llvm.convert.from.fp16(i16 %) on both operands + Function *func = getFp16ToFp32(); + CallInst *call1 = CallInst::Create(func, value1, "", &B); + CallInst *call2 = CallInst::Create(func, value2, "", &B); + + BinaryOperator *BI = BinaryOperator::Create(B.getOpcode(), call1, call2, "", B.getNextNode()); + // convert float result to fp16 + // call i16 @llvm.convert.to.fp16(float %add) + Function *func2 = getFp32ToFp16(); + CallInst *call = CallInst::Create(func2, BI, "", BI->getNextNode()); + B.replaceAllUsesWith(call); + + B.eraseFromParent(); + } else if (value1->getType()->isVectorTy() && value2->getType()->isVectorTy()) { + VectorType *PTy = dyn_cast(value1->getType()); + if (!PTy->getElementType()->isIntegerTy(16)) return; + VectorType *PTy2 = dyn_cast(value2->getType()); + if (!PTy2->getElementType()->isIntegerTy(16)) return; + + // convert parameters to float vector + Value *v1 = convertFp16VToFp32V(value1, &B); + Value *v2 = convertFp16VToFp32V(value2, &B); + BinaryOperator *BI = 
BinaryOperator::Create(B.getOpcode(), v1, v2, "", B.getNextNode()); + // convert fp32 vector to fp16 vector + Value *v3 = convertFp32VToFp16V(BI, BI->getNextNode()); + + B.replaceAllUsesWith(v3); + B.eraseFromParent(); + } + break; + } + default: + break; + } +} + +/// ------ Vector Operations ------/// + +// = insertelement <4 x i32> %vec, i32 1, i32 0 +// element type may be any integer or floating point type, or a pointer to these types +void Fp16InstVisitor::visitInsertElementInst(InsertElementInst &IE) { + Value *value2 = IE.getOperand(1); + if (value2->getType()->isHalfTy()) { + value2->mutateType(Type::getInt16Ty(*Ctx)); + } +} + +// = extractelement <4 x i32> %vec, i32 0 +// element type may be any integer or floating point type, or a pointer to these types +void Fp16InstVisitor::visitExtractElementInst(ExtractElementInst &EI) { + // primitive half won't be operand of this inst. +} + +// = shufflevector > , > , +void Fp16InstVisitor::visitShuffleVectorInst(ShuffleVectorInst &SV) { + // primitive half won't be operand of this inst. +} + +/// ------ Aggregate Operations ------/// + +// = extractvalue , {, }* +void Fp16InstVisitor::visitExtractValueInst(ExtractValueInst &EVI) { + // primitive half won't be operand of this inst. 
+} + +// = insertvalue , , {, }* +void Fp16InstVisitor::visitInsertValueInst(InsertValueInst &IVI) { + Value *value2 = IVI.getOperand(1); + if (value2->getType()->isHalfTy()) { + value2->mutateType(Type::getInt16Ty(*Ctx)); + } +} + +/// ------ Memory Access and Addressing Operations ------/// + +// %x = alloca half, align 2; %v = alloca <4 x half>, align 8 +void Fp16InstVisitor::visitAllocaInst(AllocaInst &AI) { + // This is handled in type transformation +} + +// %5 = load half* %x, align 2 +void Fp16InstVisitor::visitLoadInst(LoadInst &LI) { + // This is handled in type transformation +} + +// store half %4, half* %x, align 2 +void Fp16InstVisitor::visitStoreInst(StoreInst &SI) { + Value *value = SI.getOperand(0); + Value *pointer = SI.getOperand(1); + PointerType *PTy = dyn_cast(pointer->getType()); + + if (value->getType()->isHalfTy() || PTy->getElementType()->isHalfTy()) { + if (value->getType()->isHalfTy()) { + value->mutateType(Type::getInt16Ty(*Ctx)); + } + + if (PTy->getElementType()->isHalfTy()) { + pointer->mutateType(Type::getInt16PtrTy(*Ctx)); + } + } + // store vector is handled in type transformation +} + +void Fp16InstVisitor::visitFenceInst(FenceInst &FI) { + // This instruction won't use half as operands +} + +void Fp16InstVisitor::visitAtomicCmpXchgInst(AtomicCmpXchgInst &CXI) { + // This instruction won't use half as operands +} + +void Fp16InstVisitor::visitAtomicRMWInst(AtomicRMWInst &RMWI) { + // This instruction won't use half as operands +} + +void Fp16InstVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { + // This instruction won't use half as operands +} + +/// ------ Conversion Operations ------/// + +void Fp16InstVisitor::visitTruncInst(TruncInst &I) { + // operands must be of integer types +} + +void Fp16InstVisitor::visitZExtInst(ZExtInst &I) { + // operands must be of integer types +} + +void Fp16InstVisitor::visitSExtInst(SExtInst &I) { + // operands must be of integer types +} + +// fptrunc float %add to half +void 
Fp16InstVisitor::visitFPTruncInst(FPTruncInst &I) { + Value *Src = I.getOperand(0); + Type *SrcTy = Src->getType(); + Type *DestTy = I.getType(); + + if ((DestTy->isHalfTy() || DestTy->isIntegerTy(16)) && + (SrcTy->isFloatTy() || SrcTy->isDoubleTy())) { + if (DestTy->isHalfTy()) { + I.mutateType(Type::getInt16Ty(*Ctx)); + } + llvm::Instruction *next = I.getNextNode(); + + if (SrcTy->isDoubleTy()) { + // trunc double to float first, use Intrinsic::convert_to_fp16(double) in future + Src = CastInst::CreateFPCast(Src, Type::getFloatTy(*Ctx), "", next); + } + + // replace this instruction with a intrinsic call + // call i16 @llvm.convert.to.fp16(float %add) + Function *func = getFp32ToFp16(); + CallInst *call = CallInst::Create(func, Src, "", next); + I.replaceAllUsesWith(call); + + I.eraseFromParent(); + } else if (SrcTy->isVectorTy() && DestTy->isVectorTy()) { + VectorType *PTy = dyn_cast(SrcTy); + if (!PTy->getElementType()->isFloatTy() && !PTy->getElementType()->isDoubleTy()) return; + VectorType *PTy2 = dyn_cast(DestTy); + if (!PTy2->getElementType()->isIntegerTy(16)) return; + + Value *v1 = PTy->getElementType()->isFloatTy() ? 
convertFp32VToFp16V(Src, &I) : convertFp64VToFp16V(Src, &I); + I.replaceAllUsesWith(v1); + I.eraseFromParent(); + } +} + +// fpext half %0 to float +void Fp16InstVisitor::visitFPExtInst(FPExtInst &I) { + Value *Src = I.getOperand(0); + Type *SrcTy = Src->getType(); + Type *DestTy = I.getType(); + if ((SrcTy->isHalfTy() || SrcTy->isIntegerTy(16)) && + (DestTy->isFloatTy() || DestTy->isDoubleTy())) { + if (SrcTy->isHalfTy()) { + Src->mutateType(Type::getInt16Ty(*Ctx)); + } + llvm::Instruction *next = I.getNextNode(); + + // replace this instruction with a intrinsic call + // call float @llvm.convert.from.fp16(i16 %1) + Function *func = getFp16ToFp32(); + CallInst *call = CallInst::Create(func, Src, "", next); + + Value *extValue = call; + if (DestTy->isDoubleTy()) { + // ext float to double, use double @Intrinsic::convert_from_fp16() in future + extValue = CastInst::CreateFPCast(extValue, Type::getDoubleTy(*Ctx), "", next); + } + + I.replaceAllUsesWith(extValue); + + I.eraseFromParent(); + } else if (SrcTy->isVectorTy() && DestTy->isVectorTy()) { + VectorType *PTy = dyn_cast(DestTy); + if (!PTy->getElementType()->isFloatTy() && !PTy->getElementType()->isDoubleTy()) return; + VectorType *PTy2 = dyn_cast(SrcTy); + if (!PTy2->getElementType()->isIntegerTy(16)) return; + Value *v1 = PTy->getElementType()->isFloatTy() ? 
convertFp16VToFp32V(Src, &I) : convertFp16VToFp64V(Src, &I); + I.replaceAllUsesWith(v1); + I.eraseFromParent(); + } +} + +// = uitofp to +void Fp16InstVisitor::visitUIToFPInst(UIToFPInst &I) { + // uitofp to float first, then call i16 @llvm.convert.to.fp16(float %add) + Type *DestTy = I.getType(); + if (DestTy->isHalfTy() || DestTy->isIntegerTy(16)) { + I.mutateType(Type::getInt16Ty(*Ctx)); + + // call uitofp to float + CastInst *ui2fp = CastInst::Create(Instruction::UIToFP, I.getOperand(0), Type::getFloatTy(*Ctx), "", I.getNextNode()); + + // convert float result to fp16 + // call i16 @llvm.convert.to.fp16(float %add) + Function *func = getFp32ToFp16(); + CallInst *call = CallInst::Create(func, ui2fp, "", ui2fp->getNextNode()); + I.replaceAllUsesWith(call); + + I.eraseFromParent(); + } else if (DestTy->isVectorTy()) { + VectorType *PTy = dyn_cast(DestTy); + if (!PTy->getElementType()->isIntegerTy(16)) return; + + VectorType *vecTy = VectorType::get(Type::getFloatTy(*Ctx), PTy->getNumElements()); + // call uitofp to float vector + CastInst *ui2fp = CastInst::Create(Instruction::UIToFP, I.getOperand(0), vecTy, "", I.getNextNode()); + + // convert float vector to half vector + Value *v1 = convertFp32VToFp16V(ui2fp, ui2fp->getNextNode()); + + I.replaceAllUsesWith(v1); + I.eraseFromParent(); + } +} + +// = sitofp to +void Fp16InstVisitor::visitSIToFPInst(SIToFPInst &I) { + // sitofp to float first, then call i16 @llvm.convert.to.fp16(float %add) + Type *DestTy = I.getType(); + if (DestTy->isHalfTy() || DestTy->isIntegerTy(16)) { + I.mutateType(Type::getInt16Ty(*Ctx)); + + // call sitofp to float + CastInst *si2fp = CastInst::Create(Instruction::SIToFP, I.getOperand(0), Type::getFloatTy(*Ctx), "", I.getNextNode()); + + // convert float result to fp16 + // call i16 @llvm.convert.to.fp16(float %add) + Function *func = getFp32ToFp16(); + CallInst *call = CallInst::Create(func, si2fp, "", si2fp->getNextNode()); + I.replaceAllUsesWith(call); + + I.eraseFromParent(); + } else 
if (DestTy->isVectorTy()) { + VectorType *PTy = dyn_cast(DestTy); + if (!PTy->getElementType()->isIntegerTy(16)) return; + + VectorType *vecTy = VectorType::get(Type::getFloatTy(*Ctx), PTy->getNumElements()); + + // call sitofp to float vector + CastInst *si2fp = CastInst::Create(Instruction::SIToFP, I.getOperand(0), vecTy, "", I.getNextNode()); + + // convert float vector to half vector + Value *v1 = convertFp32VToFp16V(si2fp, si2fp->getNextNode()); + + I.replaceAllUsesWith(v1); + I.eraseFromParent(); + } +} + +// = fptoui to +void Fp16InstVisitor::visitFPToUIInst(FPToUIInst &I) { + // call float @llvm.convert.from.fp16(i16 %1), then call fptoui + Value *Src = I.getOperand(0); + Type *SrcTy = Src->getType(); + if (SrcTy->isHalfTy() || SrcTy->isIntegerTy(16)) { + Src->mutateType(Type::getInt16Ty(*Ctx)); + // convert operator to fp32 + // call float @llvm.convert.from.fp16(i16 %1) + Function *func = getFp16ToFp32(); + CallInst *call = CallInst::Create(func, Src, "", &I); + + // call fptoui for float + I.setOperand(0, call); + } else if (SrcTy->isVectorTy()) { + VectorType *PTy = dyn_cast(SrcTy); + if (!PTy->getElementType()->isIntegerTy(16)) return; + + Value *v1 = convertFp16VToFp32V(Src, &I); + // call fptoui for float vector + I.setOperand(0, v1); + } +} + +// = fptosi to +void Fp16InstVisitor::visitFPToSIInst(FPToSIInst &I) { + // call float @llvm.convert.from.fp16(i16 %1), then call fptoui + Value *Src = I.getOperand(0); + Type *SrcTy = Src->getType(); + if (SrcTy->isHalfTy() || SrcTy->isIntegerTy(16)) { + Src->mutateType(Type::getInt16Ty(*Ctx)); + + // convert operator to fp32 + // call float @llvm.convert.from.fp16(i16 %1) + Function *func = getFp16ToFp32(); + CallInst *call = CallInst::Create(func, Src, "", &I); + + // call fptosi for float + I.setOperand(0, call); + } else if (SrcTy->isVectorTy()) { + VectorType *PTy = dyn_cast(SrcTy); + if (!PTy->getElementType()->isIntegerTy(16)) return; + + Value *v1 = convertFp16VToFp32V(Src, &I); + // call fptosi for 
float vector + I.setOperand(0, v1); + } +} + +// = ptrtoint to +void Fp16InstVisitor::visitPtrToIntInst(PtrToIntInst &I) { + // This is handled in type transformation +} + +// = inttoptr to +void Fp16InstVisitor::visitIntToPtrInst(IntToPtrInst &I) { + // This is handled in type transformation +} + +// = bitcast to +void Fp16InstVisitor::visitBitCastInst(BitCastInst &I) { + // convert source half to i16 or i16* + Value *value1 = I.getOperand(0); + if (value1->getType()->isHalfTy()) { + value1->mutateType(Type::getInt16Ty(*Ctx)); + } else if (value1->getType()->isPointerTy()) { + if (value1->getType()->getPointerElementType()->isHalfTy()) { + value1->mutateType(Type::getInt16PtrTy(*Ctx)); + } + } + + // convert dest to i16 or i16* + if (I.getType()->isHalfTy()) { + I.mutateType(Type::getInt16Ty(*Ctx)); + } else if (I.getType()->isPointerTy()) { + if (I.getType()->getPointerElementType()->isHalfTy()) { + I.mutateType(Type::getInt16PtrTy(*Ctx)); + } + } + // bitcast vector type is handled in type transformation. +} + +void Fp16InstVisitor::visitAddrSpaceCastInst(AddrSpaceCastInst &I) { + // FIXME, it is not clear whether half will be used in this instruction +} + +/// ------ Other Operations ------/// + +void Fp16InstVisitor::visitICmpInst(ICmpInst &IC) { + // Won't use half. 
+} + +// = fcmp oeq float 4.0, 5.0 ; yields: result=false +void Fp16InstVisitor::visitFCmpInst(FCmpInst &FC) { + Value *value1 = FC.getOperand(0); + Value *value2 = FC.getOperand(1); + if (value1->getType()->isHalfTy()) { + value1->mutateType(Type::getInt16Ty(*Ctx)); + } + if (value2->getType()->isHalfTy()) { + value2->mutateType(Type::getInt16Ty(*Ctx)); + } + if (value1->getType()->isIntegerTy(16) && value2->getType()->isIntegerTy(16)) { + // call float @llvm.convert.from.fp16(i16 %) on both operands + Function *func = getFp16ToFp32(); + CallInst *call1 = CallInst::Create(func, value1, "", &FC); + CallInst *call2 = CallInst::Create(func, value2, "", &FC); + FC.setOperand(0, call1); + FC.setOperand(1, call2); + } else if (value1->getType()->isVectorTy() && value2->getType()->isVectorTy()) { + VectorType *PTy = dyn_cast(value1->getType()); + if (!PTy->getElementType()->isIntegerTy(16)) return; + VectorType *PTy2 = dyn_cast(value2->getType()); + if (!PTy2->getElementType()->isIntegerTy(16)) return; + + // convert parameters to float vector + Value *v1 = convertFp16VToFp32V(value1, &FC); + Value *v2 = convertFp16VToFp32V(value2, &FC); + + FC.setOperand(0, v1); + FC.setOperand(1, v2); + } +} + +void Fp16InstVisitor::visitPHINode(PHINode &PN) { + if (PN.getType()->isHalfTy()) { + PN.mutateType(Type::getInt16Ty(*Ctx)); + } +} + +// = select selty , , +void Fp16InstVisitor::visitSelectInst(SelectInst &SI) { + Value *value1 = SI.getOperand(1); + Value *value2 = SI.getOperand(2); + if (value1->getType()->isHalfTy()) { + value1->mutateType(Type::getInt16Ty(*Ctx)); + } + if (value2->getType()->isHalfTy()) { + value2->mutateType(Type::getInt16Ty(*Ctx)); + } + if (SI.getType()->isHalfTy()) { + SI.mutateType(Type::getInt16Ty(*Ctx)); + } +} + +void Fp16InstVisitor::visitCallInst(CallInst &CI) { + unsigned size = CI.getNumArgOperands(); + for (unsigned i = 0; i < size; i++) { + Value *arg = CI.getArgOperand(i); + if (arg && arg->getType()->isHalfTy()) { + 
arg->mutateType(Type::getInt16Ty(*Ctx)); + } + } +} + +void Fp16InstVisitor::visitVAArgInst(VAArgInst &VAA) { + // Won't handle it since half can't be a function param. +} + +void Fp16InstVisitor::visitLandingPadInst(LandingPadInst &LPI) { + // Won't handle half as an exception class. +} + + +namespace { + class PromoteHalf : public ModulePass { + Fp16InstVisitor halfConv; + public: + static char ID; + explicit PromoteHalf() : ModulePass(ID), halfConv() { + initializePromoteHalfPass(*PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); // this pass rewrites instructions and types, so only the CFG can be claimed as preserved + } + private: + // convert half to i16 for the subtype, Idx is the index. + void convertContainedType(unsigned Idx, Type *ty); + + // convert half to i16 for all derived types in the module. + void convertDerivedTypes(LLVMContext &Ctx); + + // convert half to i16 in composite types + // return true if element type is changed, else return false + bool convertPtr(PointerType *ptrTy); + bool convertVec(VectorType *vecTy); + bool convertArray(ArrayType *arrTy); + bool convertStruct(StructType *stTy); + }; +} + +char PromoteHalf::ID = 0; +INITIALIZE_PASS(PromoteHalf, "promote_fp16", "Promote fp16 to fp32", false, false) + +ModulePass *llvm::createPromoteHalfPass() { + return new PromoteHalf(); +} + +void PromoteHalf::convertContainedType(unsigned Idx, Type *ty) { + assert(Idx < ty->getNumContainedTypes() && "Index out of range!"); + Type** ContainedTys = const_cast(ty->subtype_begin()); + ContainedTys[Idx] = Type::getInt16Ty(ty->getContext()); +} + +bool PromoteHalf::convertPtr(PointerType *ptrTy) { + Type *ty = ptrTy->getElementType(); + if (ty->isHalfTy()) { + convertContainedType(0, ptrTy); + return true; + } + return false; +} + +bool PromoteHalf::convertVec(VectorType *vecTy) { + Type *ty = vecTy->getElementType(); + if (ty->isHalfTy()) { + convertContainedType(0, vecTy); + return true; + } + return false; +} + +bool 
PromoteHalf::convertArray(ArrayType *arrTy) { + Type *ty = arrTy->getElementType(); + if (ty->isHalfTy()) { + convertContainedType(0, arrTy); + return true; + } + return false; +} + +bool PromoteHalf::convertStruct(StructType *stTy) { + bool changed = false; + for (unsigned i = 0, size = stTy->getNumElements(); i < size; i++) { + Type *ty = stTy->getElementType(i); + if (ty->isHalfTy()) { + convertContainedType(i, stTy); + changed = true; // keep scanning: later members may be half as well + } + } + return changed; +} + +void PromoteHalf::convertDerivedTypes(LLVMContext &Ctx) { + LLVMContextImpl *pImpl = Ctx.pImpl; + + // DenseMap, VectorType*> VectorTypes; + SmallVector, 32> VectorTypeStack; + for (DenseMap, VectorType*>::iterator I = pImpl->VectorTypes.begin(), + E = pImpl->VectorTypes.end(); + I != E; ++I) { + VectorType *vecTy = I->second; + Type *ty_old = vecTy->getElementType(); + if (convertVec(vecTy)) { + VectorTypeStack.push_back(std::pair(ty_old, vecTy)); + } + } + // remap std::pair and VectorType* + for (unsigned i = 0; i < VectorTypeStack.size(); i++) { + VectorType *ty = VectorTypeStack[i].second; + // keep 'half vector', in case it will be used. 
+ std::pair key_new(ty->getElementType(), ty->getNumElements()); + pImpl->VectorTypes.insert(std::pair, VectorType*>(key_new, ty)); + } + + // StringMap NamedStructTypes; + for (StringMap::iterator I = pImpl->NamedStructTypes.begin(), + E = pImpl->NamedStructTypes.end(); + I != E; ++I) { + StructType *stTy = I->second; + convertStruct(stTy); + } + + // StructTypeMap AnonStructTypes; + for (LLVMContextImpl::StructTypeMap::iterator I = pImpl->AnonStructTypes.begin(), + E = pImpl->AnonStructTypes.end(); + I != E; ++I) { + StructType *stTy = I->first; + convertStruct(stTy); + } + + // DenseMap, ArrayType*> ArrayTypes; + SmallVector, 32> ArrayTypeStack; + for (DenseMap, ArrayType*>::iterator I = pImpl->ArrayTypes.begin(), + E = pImpl->ArrayTypes.end(); + I != E; ++I) { + ArrayType *arrTy = I->second; + Type *ty_old = arrTy->getElementType(); + if (convertArray(arrTy)) { + ArrayTypeStack.push_back(std::pair(ty_old, arrTy)); + } + } + // remap std::pair and ArrayType* + for (unsigned i = 0; i < ArrayTypeStack.size(); i++) { + ArrayType *ty = ArrayTypeStack[i].second; + std::pair key_old(ArrayTypeStack[i].first, ty->getNumElements()); + pImpl->ArrayTypes.erase(key_old); + + std::pair key_new(ty->getElementType(), ty->getNumElements()); + pImpl->ArrayTypes.insert(std::pair, ArrayType*>(key_new, ty)); + } + + // DenseMap PointerTypes; + SmallVector, 32> PointerTypeStack; + for (DenseMap::iterator I = pImpl->PointerTypes.begin(), + E = pImpl->PointerTypes.end(); + I != E; ++I) { + PointerType *ptrTy = I->second; + Type *ty_old = ptrTy->getElementType(); + if (convertPtr(ptrTy)) { + PointerTypeStack.push_back(std::pair(ty_old, ptrTy)); + } + } + // remap Type* and PointerType* + for (unsigned i = 0; i < PointerTypeStack.size(); i++) { + PointerType *ty = PointerTypeStack[i].second; + pImpl->PointerTypes.erase(PointerTypeStack[i].first); + + pImpl->PointerTypes.insert(std::pair(ty->getElementType(), ty)); + } + + // DenseMap, PointerType*> ASPointerTypes; + SmallVector, 32> 
ASPointerTypeStack; + for (DenseMap, PointerType*>::iterator I = pImpl->ASPointerTypes.begin(), + E = pImpl->ASPointerTypes.end(); + I != E; ++I) { + PointerType *ptrTy = I->second; + Type *ty_old = ptrTy->getElementType(); + if (convertPtr(ptrTy)) { + ASPointerTypeStack.push_back(std::pair(ty_old, ptrTy)); + } + } + // remap std::pair and PointerType* + for (unsigned i = 0; i < ASPointerTypeStack.size(); i++) { + PointerType *ty = ASPointerTypeStack[i].second; + std::pair key_old(ASPointerTypeStack[i].first, ty->getAddressSpace()); + pImpl->VectorTypes.erase(key_old); + + std::pair key_new(ty->getElementType(), ty->getAddressSpace()); + pImpl->ASPointerTypes.insert(std::pair, PointerType*>(key_new, ty)); + } + + // DenseMap FunctionTypeMap; + for (LLVMContextImpl::FunctionTypeMap::iterator I = pImpl->FunctionTypes.begin(), + E = pImpl->FunctionTypes.end(); + I != E; ++I) { + FunctionType *funcTy = I->first; + Type *ty_ret = funcTy->getReturnType(); + if (ty_ret && ty_ret->isHalfTy()) { + convertContainedType(0, funcTy); + } + for (unsigned i = 0; i < funcTy->getNumParams(); i++) { + Type *ty_param = funcTy->getParamType(i); + if (ty_param && ty_param->isHalfTy()) { + convertContainedType(i + 1, funcTy); + } + } + } +} + +bool PromoteHalf::runOnModule(Module &M) { + LLVMContext &Ctx = M.getContext(); + LLVMContextImpl *pImpl = Ctx.pImpl; + + Fp16InstVisitor halfConv; + + // replace half with i16 in all derived types + // primitive half type is handled in 'instruction and variables visit' + convertDerivedTypes(Ctx); + + // change the type of undef value from half to i16 + for (DenseMap::iterator I = pImpl->UVConstants.begin(), + E = pImpl->UVConstants.end(); + I != E; ++I) { + UndefValue *UV = dyn_cast(I->second); + if (UV != NULL && UV->getType()->isHalfTy()) { + UV->mutateType(Type::getInt16Ty(Ctx)); + } + } + + for (LLVMContextImpl::FPMapTy::iterator I = pImpl->FPConstants.begin(), + E = pImpl->FPConstants.end(); + I != E; ++I) { + ConstantFP *CF = 
dyn_cast(I->second); + if (CF != NULL && CF->getType()->isHalfTy()) { + CF->mutateType(Type::getInt16Ty(Ctx)); + } + } + + // convert ConstantFP(const half) to ConstantInt(i16) + SmallVector WorkStack; + for (LLVMContextImpl::FPMapTy::iterator I = pImpl->FPConstants.begin(), + E = pImpl->FPConstants.end(); + I != E; ++I) { + ConstantFP *CF = dyn_cast(I->second); + if (CF != NULL && CF->getType()->isIntegerTy(16)) { + APInt apVal = CF->getValueAPF().bitcastToAPInt(); + ConstantInt *vInt = ConstantInt::get(Ctx, apVal); + CF->replaceAllUsesWith(vInt); + WorkStack.push_back(I->first); + } + } + for (unsigned i = 0; i < WorkStack.size(); i++) { + pImpl->FPConstants.erase(WorkStack[i]); + } + + bool ret = false; + // go through function to transform half operands/instructions + for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) { + if (!I->isDeclaration()) { + ret |= !halfConv.transformFunction(*I); + } + } + return ret; +}