diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -9,6 +9,7 @@
 add_llvm_target(DirectXCodeGen
   DirectXSubtarget.cpp
   DirectXTargetMachine.cpp
+  DXILOpLowering.cpp
   DXILPointerType.cpp
   DXILPrepare.cpp
   PointerTypeAnalysis.cpp
diff --git a/llvm/lib/Target/DirectX/DXILConstants.h b/llvm/lib/Target/DirectX/DXILConstants.h
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILConstants.h
@@ -0,0 +1,29 @@
+//===- DXILConstants.h - Essential DXIL constants -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains essential DXIL constants.
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_DIRECTX_DXILCONSTANTS_H
+#define LLVM_LIB_TARGET_DIRECTX_DXILCONSTANTS_H
+
+namespace llvm {
+namespace DXIL {
+// Enumeration for operations specified by DXIL
+enum class OpCode : unsigned {
+  Sin = 13, // returns sine(theta) for theta in radians.
+};
+// Groups for DXIL operations with equivalent function templates
+enum class OpCodeClass : unsigned {
+  Unary,
+};
+
+} // namespace DXIL
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_DIRECTX_DXILCONSTANTS_H
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -0,0 +1,309 @@
+//===- DXILOpLowering.cpp - Lowering LLVM intrinsic to DXILOp function ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file This file contains passes and utilities to lower LLVM intrinsic calls
+/// to DXILOp function calls.
+//===----------------------------------------------------------------------===//
+
+#include "DXILConstants.h"
+#include "DirectX.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "dxil-op-lower"
+
+using namespace llvm;
+using namespace llvm::DXIL;
+
+constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
+
+// Overload categories for DXIL operations. Every enumerator is a distinct bit
+// so that a set of permitted overloads fits in a uint16_t mask.
+enum OverloadKind : uint16_t {
+  VOID = 1,
+  HALF = 1 << 1,
+  FLOAT = 1 << 2,
+  DOUBLE = 1 << 3,
+  I1 = 1 << 4,
+  I8 = 1 << 5,
+  I16 = 1 << 6,
+  I32 = 1 << 7,
+  I64 = 1 << 8,
+  UserDefineType = 1 << 9,
+  ObjectType = 1 << 10,
+};
+
+// Return the name suffix for the scalar overload kind \p Kind, e.g. "f32".
+// VOID, ObjectType and UserDefineType have no fixed suffix and are invalid.
+static const char *getOverloadTypeName(OverloadKind Kind) {
+  switch (Kind) {
+  case OverloadKind::HALF:
+    return "f16";
+  case OverloadKind::FLOAT:
+    return "f32";
+  case OverloadKind::DOUBLE:
+    return "f64";
+  case OverloadKind::I1:
+    return "i1";
+  case OverloadKind::I8:
+    return "i8";
+  case OverloadKind::I16:
+    return "i16";
+  case OverloadKind::I32:
+    return "i32";
+  case OverloadKind::I64:
+    return "i64";
+  case OverloadKind::VOID:
+  case OverloadKind::ObjectType:
+  case OverloadKind::UserDefineType:
+    break;
+  }
+  // Kept after the switch so that no path falls off the end of the function,
+  // including values that are not exactly one enumerator.
+  llvm_unreachable("invalid overload type for name");
+}
+
+// Classify \p Ty as an OverloadKind. A type with no DXIL overload kind is a
+// fatal error rather than unreachable, because it can come from the input IR.
+static OverloadKind getOverloadKind(Type *Ty) {
+  switch (Ty->getTypeID()) {
+  case Type::VoidTyID:
+    return OverloadKind::VOID;
+  case Type::HalfTyID:
+    return OverloadKind::HALF;
+  case Type::FloatTyID:
+    return OverloadKind::FLOAT;
+  case Type::DoubleTyID:
+    return OverloadKind::DOUBLE;
+  case Type::IntegerTyID:
+    switch (cast<IntegerType>(Ty)->getBitWidth()) {
+    case 1:
+      return OverloadKind::I1;
+    case 8:
+      return OverloadKind::I8;
+    case 16:
+      return OverloadKind::I16;
+    case 32:
+      return OverloadKind::I32;
+    case 64:
+      return OverloadKind::I64;
+    default:
+      break;
+    }
+    break;
+  case Type::PointerTyID:
+    return OverloadKind::UserDefineType;
+  case Type::StructTyID:
+    return OverloadKind::ObjectType;
+  default:
+    break;
+  }
+  report_fatal_error("invalid overload type");
+}
+
+// Return the overload suffix for \p Ty: the fixed scalar name when \p Kind is
+// a scalar kind, the struct name for a named struct, and the printed type
+// otherwise.
+static std::string getTypeName(OverloadKind Kind, Type *Ty) {
+  if (Kind < OverloadKind::UserDefineType)
+    return getOverloadTypeName(Kind);
+
+  // getOverloadKind() maps pointers to UserDefineType, so \p Ty is not
+  // necessarily a StructType here; an unconditional cast<> would assert.
+  if (auto *ST = dyn_cast<StructType>(Ty)) {
+    if (ST->hasName())
+      return ST->getName().str();
+  }
+
+  std::string Str;
+  raw_string_ostream OS(Str);
+  Ty->print(OS);
+  return OS.str();
+}
+
+// Static properties.
+struct OpCodeProperty {
+  DXIL::OpCode OpCode;
+  // FIXME: change OpCodeName into index to a large string constant when move to
+  // tableGen.
+  const char *OpCodeName;
+  DXIL::OpCodeClass OpCodeClass;
+  // Bitmask of the OverloadKind values this operation accepts.
+  uint16_t OverloadTys;
+  llvm::Attribute::AttrKind FuncAttr;
+};
+
+// Return the lower-case class name that forms part of the DXIL op function
+// name for \p Prop, e.g. "unary".
+static const char *getOpCodeClassName(const OpCodeProperty &Prop) {
+  // FIXME: generate this table with tableGen.
+  static const char *OpCodeClassNames[] = {
+      "unary",
+  };
+  unsigned Index = static_cast<unsigned>(Prop.OpCodeClass);
+  assert(Index < (sizeof(OpCodeClassNames) / sizeof(OpCodeClassNames[0])) &&
+         "Out of bound OpCodeClass");
+  return OpCodeClassNames[Index];
+}
+
+// Build the DXIL op function name: "dx.op.<class>" for a void overload,
+// otherwise "dx.op.<class>.<type>".
+static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
+                                         const OpCodeProperty &Prop) {
+  if (Kind == OverloadKind::VOID) {
+    return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
+  }
+  return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
+          getTypeName(Kind, Ty))
+      .str();
+}
+
+// Look up the static properties of \p DXILOp. The opcode must be in the table.
+static const OpCodeProperty *getOpCodeProperty(DXIL::OpCode DXILOp) {
+  // FIXME: generate this table with tableGen.
+  // Entries must stay sorted by OpCode for the binary search below.
+  static const OpCodeProperty OpCodeProps[] = {
+      {DXIL::OpCode::Sin, "Sin", OpCodeClass::Unary,
+       OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone},
+  };
+  // FIXME: change search to indexing with
+  // DXILOp once all DXIL op is added.
+  const OpCodeProperty *Prop = llvm::lower_bound(
+      OpCodeProps, DXILOp, [](const OpCodeProperty &P, DXIL::OpCode Op) {
+        return P.OpCode < Op;
+      });
+  // lower_bound returns the end pointer or a larger opcode on a miss; catch
+  // that here instead of handing the caller an unrelated or invalid entry.
+  assert(Prop != std::end(OpCodeProps) && Prop->OpCode == DXILOp &&
+         "DXIL opcode is missing from OpCodeProps");
+  return Prop;
+}
+
+// Declare in \p M the DXIL op function that replaces intrinsic \p F for
+// operation \p DXILOp. Its signature is the return type of \p F, an i32 opcode
+// parameter, then the parameters of \p F.
+static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F,
+                                           Module &M) {
+  const OpCodeProperty *Prop = getOpCodeProperty(DXILOp);
+
+  // Get return type as overload type for DXILOp.
+  // Only simple mapping case here, so return type is good enough.
+  Type *OverloadTy = F.getReturnType();
+
+  OverloadKind Kind = getOverloadKind(OverloadTy);
+  // FIXME: find the issue and report error in clang instead of check it in
+  // backend.
+  // The overload is dictated by the input IR (e.g. llvm.sin.f64), so report a
+  // real error; llvm_unreachable is undefined behavior without assertions.
+  if ((Prop->OverloadTys & static_cast<uint16_t>(Kind)) == 0)
+    report_fatal_error("invalid overload");
+
+  std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop);
+  assert(!M.getFunction(FnName) && "Function already exists");
+
+  auto &Ctx = M.getContext();
+  Type *OpCodeTy = Type::getInt32Ty(Ctx);
+
+  SmallVector<Type *> ArgTypes;
+  // DXIL has i32 opcode as first arg.
+  ArgTypes.emplace_back(OpCodeTy);
+  FunctionType *FT = F.getFunctionType();
+  ArgTypes.append(FT->param_begin(), FT->param_end());
+  FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false);
+  return M.getOrInsertFunction(FnName, DXILOpFT);
+}
+
+// Replace every CallInst user of intrinsic \p F with a call to the matching
+// DXIL op function, then erase \p F once nothing uses it.
+static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) {
+  auto DXILOpFn = createDXILOpFunction(DXILOp, F, M);
+  IRBuilder<> B(M.getContext());
+  Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
+  // Early-inc range: the loop erases the call it is currently visiting.
+  for (User *U : make_early_inc_range(F.users())) {
+    CallInst *CI = dyn_cast<CallInst>(U);
+    if (!CI)
+      continue;
+
+    SmallVector<Value *> Args;
+    Args.emplace_back(DXILOpArg);
+    Args.append(CI->arg_begin(), CI->arg_end());
+    B.SetInsertPoint(CI);
+    CallInst *DXILCI = B.CreateCall(DXILOpFn, Args);
+    CI->replaceAllUsesWith(DXILCI);
+    CI->eraseFromParent();
+  }
+  if (F.user_empty())
+    F.eraseFromParent();
+}
+
+// Lower each intrinsic declaration in \p M that has a DXIL op mapping.
+// Returns true if any intrinsic was lowered.
+static bool lowerIntrinsics(Module &M) {
+  bool Updated = false;
+  static const SmallDenseMap<Intrinsic::ID, DXIL::OpCode> LowerMap = {
+      {Intrinsic::sin, DXIL::OpCode::Sin}};
+  // Early-inc range: lowerIntrinsic() may erase the function being visited.
+  for (Function &F : make_early_inc_range(M.functions())) {
+    if (!F.isDeclaration())
+      continue;
+    Intrinsic::ID ID = F.getIntrinsicID();
+    auto LowerIt = LowerMap.find(ID);
+    if (LowerIt == LowerMap.end())
+      continue;
+    lowerIntrinsic(LowerIt->second, F, M);
+    Updated = true;
+  }
+  return Updated;
+}
+
+namespace {
+/// New pass manager pass that lowers LLVM intrinsic calls to DXIL op function
+/// calls.
+class DXILOpLowering : public PassInfoMixin<DXILOpLowering> {
+public:
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
+    if (lowerIntrinsics(M))
+      return PreservedAnalyses::none();
+    return PreservedAnalyses::all();
+  }
+};
+} // namespace
+
+namespace {
+/// Legacy pass manager wrapper around the same lowering.
+class DXILOpLoweringLegacy : public ModulePass {
+public:
+  bool runOnModule(Module &M) override { return lowerIntrinsics(M); }
+  StringRef getPassName() const override { return "DXIL Op Lowering"; }
+  DXILOpLoweringLegacy() : ModulePass(ID) {}
+
+  static char ID; // Pass identification.
+};
+char DXILOpLoweringLegacy::ID = 0;
+
+} // end anonymous namespace
+
+INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
+                      false, false)
+INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
+                    false)
+
+ModulePass *llvm::createDXILOpLoweringLegacyPass() {
+  return new DXILOpLoweringLegacy();
+}
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -23,6 +23,13 @@
 
 /// Pass to convert modules into DXIL-compatable modules
 ModulePass *createDXILPrepareModulePass();
+
+/// Initializer for DXILOpLowering
+void initializeDXILOpLoweringLegacyPass(PassRegistry &);
+
+/// Pass to lower LLVM intrinsic calls to DXIL op function calls.
+ModulePass *createDXILOpLoweringLegacyPass();
+
 } // namespace llvm
 
 #endif // LLVM_LIB_TARGET_DIRECTX_DIRECTX_H
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -34,6 +34,7 @@
   RegisterTargetMachine<DirectXTargetMachine> X(getTheDirectXTarget());
   auto *PR = PassRegistry::getPassRegistry();
   initializeDXILPrepareModulePass(*PR);
+  initializeDXILOpLoweringLegacyPass(*PR);
 }
 
 class DXILTargetObjectFile : public TargetLoweringObjectFile {
@@ -84,6 +85,7 @@
     PassManagerBase &PM, raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
     CodeGenFileType FileType, bool DisableVerify,
     MachineModuleInfoWrapperPass *MMIWP) {
+  PM.add(createDXILOpLoweringLegacyPass());
   PM.add(createDXILPrepareModulePass());
   switch (FileType) {
   case CGFT_AssemblyFile:
diff --git a/llvm/test/CodeGen/DirectX/sin.ll b/llvm/test/CodeGen/DirectX/sin.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/sin.ll
@@ -0,0 +1,43 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for sin are generated for float and half.
+; CHECK:call float @dx.op.unary.f32(i32 13, float %{{.*}})
+; CHECK:call half @dx.op.unary.f16(i32 13, half %{{.*}})
+
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64"
+target triple = "dxil-pc-shadermodel6.7-library"
+
+; Function Attrs: noinline nounwind optnone
+define noundef float @_Z3foof(float noundef %a) #0 {
+entry:
+  %a.addr = alloca float, align 4
+  store float %a, ptr %a.addr, align 4
+  %0 = load float, ptr %a.addr, align 4
+  %1 = call float @llvm.sin.f32(float %0)
+  ret float %1
+}
+
+; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
+declare float @llvm.sin.f32(float) #1
+
+; Function Attrs: noinline nounwind optnone
+define noundef half @_Z3barDh(half noundef %a) #0 {
+entry:
+  %a.addr = alloca half, align 2
+  store half %a, ptr %a.addr, align 2
+  %0 = load half, ptr %a.addr, align 2
+  %1 = call half @llvm.sin.f16(half %0)
+  ret half %1
+}
+
+; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn
+declare half @llvm.sin.f16(half) #1
+
+attributes #0 = { noinline nounwind optnone "frame-pointer"="none" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { nocallback nofree nosync nounwind readnone speculatable willreturn }
+
+!llvm.module.flags = !{!0}
+!llvm.ident = !{!1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{!"clang version 15.0.0 (https://github.com/llvm/llvm-project.git 73417c517644db5c419c85c0b3cb6750172fcab5)"}
diff --git a/llvm/tools/opt/opt.cpp b/llvm/tools/opt/opt.cpp
--- a/llvm/tools/opt/opt.cpp
+++ b/llvm/tools/opt/opt.cpp
@@ -476,7 +476,7 @@
       "x86-",    "xcore-",  "wasm-",  "systemz-", "ppc-",    "nvvm-",
       "nvptx-",  "mips-",   "lanai-", "hexagon-", "bpf-",    "avr-",
       "thumb2-", "arm-",    "si-",    "gcn-",     "amdgpu-", "aarch64-",
-      "amdgcn-", "polly-",  "riscv-"};
+      "amdgcn-", "polly-",  "riscv-", "dxil-"};
   std::vector<StringRef> PassNameContain = {"ehprepare"};
   std::vector<StringRef> PassNameExact = {
"safe-stack", "cost-model",