diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -9,6 +9,7 @@
 add_llvm_target(DirectXCodeGen
   DirectXSubtarget.cpp
   DirectXTargetMachine.cpp
+  DXILOpLowering.cpp
   DXILPointerType.cpp
   DXILPrepare.cpp
   PointerTypeAnalysis.cpp
diff --git a/llvm/lib/Target/DirectX/DXILConstants.h b/llvm/lib/Target/DirectX/DXILConstants.h
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILConstants.h
@@ -0,0 +1,29 @@
+//===- DXILConstants.h - Essential DXIL constants -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains essential DXIL constants.
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_DIRECTX_DXILCONSTANTS_H
+#define LLVM_LIB_TARGET_DIRECTX_DXILCONSTANTS_H
+
+namespace llvm {
+namespace DXIL {
+// Enumeration for operations specified by DXIL
+enum class OpCode : unsigned {
+  Sin = 13, // returns sine(theta) for theta in radians.
+};
+// Groups for DXIL operations with equivalent function templates
+enum class OpCodeClass : unsigned {
+  Unary,
+};
+
+} // namespace DXIL
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_DIRECTX_DXILCONSTANTS_H
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -0,0 +1,283 @@
+//===- DXILOpLowering.cpp - Lowering LLVM intrinsic to DXILOp function ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains passes and utilities to lower llvm intrinsic calls
+/// to DXILOp function calls.
+//===----------------------------------------------------------------------===//
+
+#include "DXILConstants.h"
+#include "DirectX.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "dxil-op-lower"
+
+using namespace llvm;
+
+namespace {
+
+const char *DXILOpNamePrefix = "dx.op.";
+
+// Maps each LLVM intrinsic handled by this pass to the DXIL op it lowers to.
+DenseMap<Intrinsic::ID, DXIL::OpCode> LowerMap = {
+    {Intrinsic::sin, DXIL::OpCode::Sin}};
+
+// Bit flags for the overload types a DXIL op can be instantiated with.
+// OpCodeProperty::OverloadTys holds the bitwise OR of the supported kinds.
+enum OverloadKind : uint16_t {
+  VOID = 1,
+  HALF = 1 << 1,
+  FLOAT = 1 << 2,
+  DOUBLE = 1 << 3,
+  I1 = 1 << 4,
+  I8 = 1 << 5,
+  I16 = 1 << 6,
+  I32 = 1 << 7,
+  I64 = 1 << 8,
+  UserDefineType = 1 << 9,
+  ObjectType = 1 << 10,
+};
+
+// Returns the name suffix ("f32", "i16", ...) for a scalar overload kind.
+const char *getOverloadTypeName(OverloadKind Kind) {
+  switch (Kind) {
+  case OverloadKind::HALF:
+    return "f16";
+  case OverloadKind::FLOAT:
+    return "f32";
+  case OverloadKind::DOUBLE:
+    return "f64";
+  case OverloadKind::I1:
+    return "i1";
+  case OverloadKind::I8:
+    return "i8";
+  case OverloadKind::I16:
+    return "i16";
+  case OverloadKind::I32:
+    return "i32";
+  case OverloadKind::I64:
+    return "i64";
+  default:
+    llvm_unreachable("invalid overload type for name");
+  }
+}
+
+// Maps an integer bit width to its OverloadKind. Other widths have no DXIL
+// overload; they can come from user IR, so they are a fatal error.
+OverloadKind getIntegerOverloadKind(unsigned Bits) {
+  switch (Bits) {
+  case 1:
+    return OverloadKind::I1;
+  case 8:
+    return OverloadKind::I8;
+  case 16:
+    return OverloadKind::I16;
+  case 32:
+    return OverloadKind::I32;
+  case 64:
+    return OverloadKind::I64;
+  default:
+    report_fatal_error("invalid integer width for DXIL op overload");
+  }
+}
+
+// Classifies Ty as a single OverloadKind. Unsupported types (e.g. vectors)
+// can come from user IR, so they are a fatal error rather than unreachable.
+OverloadKind getOverloadKind(Type *Ty) {
+  switch (Ty->getTypeID()) {
+  case Type::VoidTyID:
+    return OverloadKind::VOID;
+  case Type::HalfTyID:
+    return OverloadKind::HALF;
+  case Type::FloatTyID:
+    return OverloadKind::FLOAT;
+  case Type::DoubleTyID:
+    return OverloadKind::DOUBLE;
+  case Type::IntegerTyID:
+    return getIntegerOverloadKind(cast<IntegerType>(Ty)->getBitWidth());
+  case Type::PointerTyID:
+    return OverloadKind::UserDefineType;
+  case Type::StructTyID:
+    return OverloadKind::ObjectType;
+  default:
+    report_fatal_error("invalid overload type for DXIL op");
+  }
+}
+
+// Returns the overload suffix used in a DXIL op function name for Ty.
+std::string getTypeName(OverloadKind Kind, Type *Ty) {
+  if (Kind < OverloadKind::UserDefineType)
+    return getOverloadTypeName(Kind);
+  // Named structs are identified by their name. Pointers (UserDefineType) are
+  // not StructTypes, so Ty must not be unconditionally cast<> to one.
+  if (auto *ST = dyn_cast<StructType>(Ty))
+    if (ST->hasName())
+      return ST->getName().str();
+  // Otherwise use the printed form of the type (e.g. "ptr").
+  std::string Str;
+  raw_string_ostream OS(Str);
+  Ty->print(OS);
+  return OS.str();
+}
+
+// FIXME: generate this table with tableGen.
+// Indexed by DXIL::OpCodeClass.
+const char *OpCodeClassNames[] = {
+    "unary",
+};
+
+// Static properties of a DXIL operation.
+struct OpCodeProperty {
+  DXIL::OpCode OpCode;
+  const char *OpCodeName;
+  DXIL::OpCodeClass OpCodeClass;
+  uint16_t OverloadTys; // Bitwise OR of supported OverloadKind values.
+  llvm::Attribute::AttrKind FuncAttr;
+};
+
+// Returns the class name ("unary", ...) used in DXIL op function names.
+const char *getOpCodeClassName(const OpCodeProperty &Prop) {
+  return OpCodeClassNames[static_cast<unsigned>(Prop.OpCodeClass)];
+}
+
+// Builds the DXIL op function name: "dx.op.<class>" for a void overload,
+// otherwise "dx.op.<class>.<type>" (e.g. "dx.op.unary.f32").
+std::string constructOverloadName(OverloadKind Kind, Type *Ty,
+                                  const OpCodeProperty &Prop) {
+  if (Kind == OverloadKind::VOID)
+    return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
+  return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
+          getTypeName(Kind, Ty))
+      .str();
+}
+
+using OC = DXIL::OpCode;
+using OCC = DXIL::OpCodeClass;
+using OK = OverloadKind;
+using AK = Attribute::AttrKind;
+
+// FIXME: generate this table with tableGen.
+// Must stay sorted by OpCode: getOpCodeProperty binary-searches it.
+const OpCodeProperty OpCodeProps[] = {
+    {OC::Sin, "Sin", OCC::Unary, OK::FLOAT | OK::HALF, AK::ReadNone},
+};
+
+// Returns the entry for DXILOp in OpCodeProps.
+const OpCodeProperty &getOpCodeProperty(DXIL::OpCode DXILOp) {
+  // FIXME: change search to indexing with DXILOp once all DXIL ops are added.
+  const auto *It = llvm::lower_bound(
+      OpCodeProps, DXILOp, [](const OpCodeProperty &P, DXIL::OpCode Op) {
+        return P.OpCode < Op;
+      });
+  if (It == std::end(OpCodeProps) || It->OpCode != DXILOp)
+    llvm_unreachable("DXIL op is missing from OpCodeProps");
+  return *It;
+}
+
+// Returns the DXIL op function that replaces calls to the intrinsic
+// declaration F. Its signature is F's with a leading i32 opcode parameter.
+// The name depends only on the op class and overload type, so ops sharing
+// both share one function and an existing declaration is reused.
+FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F,
+                                    Module &M) {
+  const OpCodeProperty &Prop = getOpCodeProperty(DXILOp);
+
+  // Get return type as overload type for DXILOp.
+  // Only simple mapping case here, so return type is good enough.
+  Type *OverloadTy = F.getReturnType();
+
+  OverloadKind Kind = getOverloadKind(OverloadTy);
+  // FIXME: find the issue and report error in clang instead of check it in
+  // backend. Until then this is reachable from user IR (e.g. llvm.sin.f64),
+  // so it must be a real error rather than llvm_unreachable.
+  if (!(Prop.OverloadTys & Kind))
+    report_fatal_error(Twine("invalid overload for DXIL op ") +
+                       Prop.OpCodeName);
+
+  std::string FnName = constructOverloadName(Kind, OverloadTy, Prop);
+
+  SmallVector<Type *> ArgTypes;
+  // DXIL has i32 opcode as first arg.
+  ArgTypes.push_back(Type::getInt32Ty(M.getContext()));
+  FunctionType *FT = F.getFunctionType();
+  ArgTypes.append(FT->param_begin(), FT->param_end());
+  FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false);
+  return M.getOrInsertFunction(FnName, DXILOpFT);
+}
+
+// Replaces each call to the intrinsic declaration F with a call to the DXIL
+// op function for DXILOp, passing the opcode as the leading i32 argument.
+// F is erased once it has no users left.
+void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) {
+  FunctionCallee DXILOpFn = createDXILOpFunction(DXILOp, F, M);
+  Value *DXILOpArg = ConstantInt::get(Type::getInt32Ty(M.getContext()),
+                                      static_cast<unsigned>(DXILOp));
+  // Calls are erased while walking F's users, so use an early-inc range.
+  for (User *U : make_early_inc_range(F.users())) {
+    auto *CI = dyn_cast<CallInst>(U);
+    if (!CI)
+      continue;
+    IRBuilder<> B(CI);
+    SmallVector<Value *> Args;
+    Args.push_back(DXILOpArg);
+    Args.append(CI->arg_begin(), CI->arg_end());
+    CallInst *DXILCI = B.CreateCall(DXILOpFn, Args);
+    DXILCI->takeName(CI);
+    CI->replaceAllUsesWith(DXILCI);
+    CI->eraseFromParent();
+  }
+  if (F.user_empty())
+    F.eraseFromParent();
+}
+
+} // namespace
+
+namespace {
+class DXILOpLowering : public ModulePass {
+public:
+  bool runOnModule(Module &M) override {
+    bool Updated = false;
+    // lowerIntrinsic may erase F, so use an early-inc range over functions.
+    for (Function &F : make_early_inc_range(M.functions())) {
+      if (!F.isDeclaration())
+        continue;
+      Intrinsic::ID IID = F.getIntrinsicID();
+      auto LowerIt = LowerMap.find(IID);
+      if (LowerIt == LowerMap.end())
+        continue;
+      lowerIntrinsic(LowerIt->second, F, M);
+      Updated = true;
+    }
+    return Updated;
+  }
+
+  DXILOpLowering() : ModulePass(ID) {}
+
+  static char ID; // Pass identification.
+};
+char DXILOpLowering::ID = 0;
+
+} // end anonymous namespace
+
+INITIALIZE_PASS_BEGIN(DXILOpLowering, DEBUG_TYPE, "DXIL Op Lowering", false,
+                      false)
+INITIALIZE_PASS_END(DXILOpLowering, DEBUG_TYPE, "DXIL Op Lowering", false,
+                    false)
+
+ModulePass *llvm::createDXILOpLoweringPass() { return new DXILOpLowering(); }
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -23,6 +23,13 @@
 
 /// Pass to convert modules into DXIL-compatable modules
 ModulePass *createDXILPrepareModulePass();
+
+/// Initializer for DXILOpLowering
+void initializeDXILOpLoweringPass(PassRegistry &);
+
+/// Pass to lower LLVM intrinsic calls to DXIL op function calls.
+ModulePass *createDXILOpLoweringPass();
+
 } // namespace llvm
 
 #endif // LLVM_LIB_TARGET_DIRECTX_DIRECTX_H
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -34,6 +34,7 @@
   RegisterTargetMachine<DirectXTargetMachine> X(getTheDirectXTarget());
   auto *PR = PassRegistry::getPassRegistry();
   initializeDXILPrepareModulePass(*PR);
+  initializeDXILOpLoweringPass(*PR);
 }
 
 class DXILTargetObjectFile : public TargetLoweringObjectFile {
@@ -84,6 +85,7 @@
     PassManagerBase &PM, raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
     CodeGenFileType FileType, bool DisableVerify,
     MachineModuleInfoWrapperPass *MMIWP) {
+  PM.add(createDXILOpLoweringPass());
   PM.add(createDXILPrepareModulePass());
   switch (FileType) {
   case CGFT_AssemblyFile:
diff --git a/llvm/test/CodeGen/DirectX/sin.ll b/llvm/test/CodeGen/DirectX/sin.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/sin.ll
@@ -0,0 +1,43 @@
+; RUN: llc %s --filetype=asm -o - | FileCheck %s
+
+; Make sure dxil operation function calls for sin are generated for float and half.
+; CHECK:call float @dx.op.unary.f32(i32 13, float %{{.*}})
+; CHECK:call half @dx.op.unary.f16(i32 13, half %{{.*}})
+
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-pc-shadermodel6.7-library"
+
+; Function Attrs: noinline nounwind optnone
+define noundef float @_Z3foof(float noundef %a) #0 {
+entry:
+  %a.addr = alloca float, align 4
+  store float %a, ptr %a.addr, align 4
+  %0 = load float, ptr %a.addr, align 4
+  %1 = call float @llvm.sin.f32(float %0)
+  ret float %1
+}
+
+; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
+declare float @llvm.sin.f32(float) #1
+
+; Function Attrs: noinline nounwind optnone
+define noundef half @_Z3barDh(half noundef %a) #0 {
+entry:
+  %a.addr = alloca half, align 2
+  store half %a, ptr %a.addr, align 2
+  %0 = load half, ptr %a.addr, align 2
+  %1 = call half @llvm.sin.f16(half %0)
+  ret half %1
+}
+
+; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
+declare half @llvm.sin.f16(half) #1
+
+attributes #0 = { noinline nounwind optnone "frame-pointer"="none" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { nocallback nofree nosync nounwind readnone speculatable willreturn }
+
+!llvm.module.flags = !{!0}
+!llvm.ident = !{!1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{!"clang version 15.0.0 (https://github.com/llvm/llvm-project.git 73417c517644db5c419c85c0b3cb6750172fcab5)"}