diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td --- a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td +++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td @@ -16,9 +16,14 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/IR/OpBase.td" + //===----------------------------------------------------------------------===// -// GPU target attribute interface. +// GPU target attribute. //===----------------------------------------------------------------------===// +def GPUTargetAttrTrait : NativeTrait<"TargetAttrTrait", ""> { + let cppNamespace = "::mlir::gpu"; +} + def GPUTargetAttrInterface : AttrInterface<"TargetAttrInterface"> { let description = [{ Interface for GPU target attributes. Attributes implementing this interface @@ -42,7 +47,19 @@ ]; } -def GPUTargetArrayAttr : TypedArrayAttrBase()">, + "with the GPU `TargetAttrTrait` trait." +>; + +def GPUTargetAttr : ConfinedAttr { + let description = [{ + Generic GPU target attribute. These attributes must implement the GPU + `TargetAttrInterface` interface or promise the interface. 
+ }]; +} + +def GPUTargetArrayAttr : TypedArrayAttrBase; def GPUNonEmptyTargetArrayAttr : diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td --- a/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td +++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationAttrs.td @@ -32,8 +32,9 @@ #gpu.object<#nvvm.target, "..."> ``` }]; - let parameters = (ins "TargetAttrInterface":$target, "StringAttr":$object); + let parameters = (ins "Attribute":$target, "StringAttr":$object); let assemblyFormat = [{`<` $target `,` $object `>`}]; + let genVerifyDecl = 1; } def GPUObjectArrayAttr : diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h --- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h +++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h @@ -24,6 +24,15 @@ class ModuleTranslation; } namespace gpu { +/// This class indicates that the attribute associated with this trait is a GPU +/// target attribute. These kinds of attributes must implement an interface for +/// handling the serialization of GPU Modules into strings. +template <typename ConcreteType> +class TargetAttrTrait + : public AttributeTrait::TraitBase<ConcreteType, TargetAttrTrait> { + // TODO: Verify the attribute promises or implements the interface. +}; + /// This class indicates that the attribute associated with this trait is a GPU /// offloading translation attribute. These kinds of attributes must implement /// an interface for handling the translation of GPU offloading operations like diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1619,7 +1619,7 @@ // NVVM target attribute. 
//===----------------------------------------------------------------------===// -def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target"> { +def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target", [GPUTargetAttrTrait]> { let description = [{ GPU target attribute for controlling compilation of NVIDIA targets. All parameters decay into default values if not present. diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -404,7 +404,7 @@ //===----------------------------------------------------------------------===// def ROCDL_TargettAttr : - ROCDL_Attr<"ROCDLTarget", "target"> { + ROCDL_Attr<"ROCDLTarget", "target", [GPUTargetAttrTrait]> { let description = [{ ROCDL target attribute for controlling compilation of AMDGPU targets. All parameters decay into default values if not present. diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1954,6 +1954,19 @@ results.add(context); } +//===----------------------------------------------------------------------===// +// GPU object attribute +//===----------------------------------------------------------------------===// + +LogicalResult +gpu::ObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError, + Attribute target, StringAttr object) { + if (target && !target.hasTrait<TargetAttrTrait>()) + return emitError() + << "attribute failed to have the `TargetAttrTrait` trait."; + return success(); +} + //===----------------------------------------------------------------------===// // GPU select object attribute //===----------------------------------------------------------------------===// @@ -1967,7 +1980,7 @@ if (intAttr.getInt() < 0) { return emitError() << "The object index must be positive."; } - } else if (!(::mlir::isa<TargetAttrInterface>(target))) { + } else if (!target.hasTrait<TargetAttrTrait>()) { 
return emitError() << "The target attribute must be a GPU Target attribute."; } diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -57,8 +57,6 @@ MLIROpenACCToLLVMIRTranslation MLIROpenMPToLLVMIRTranslation MLIRROCDLToLLVMIRTranslation - MLIRNVVMTarget - MLIRROCDLTarget ) add_mlir_translation_library(MLIRTargetLLVMIRImport diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -13,8 +13,6 @@ #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/Target/LLVM/NVVM/Target.h" -#include "mlir/Target/LLVM/ROCDL/Target.h" #include "mlir/Target/LLVMIR/Dialect/All.h" #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Tools/mlir-translate/Translation.h" @@ -38,8 +36,6 @@ }, [](DialectRegistry &registry) { registry.insert<DLTIDialect, func::FuncDialect>(); - registerNVVMTarget(registry); - registerROCDLTarget(registry); registerAllToLLVMIRTranslations(registry); }); }