diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td @@ -46,6 +46,14 @@ def SPIRV_CapabilityArrayAttr : TypedArrayAttrBase< SPIRV_CapabilityAttr, "SPIR-V capability array attribute">; +// Attribute form of the SPIR-V LinkageAttributes decoration: a linkage name +// string plus a LinkageType (e.g. Import). It is carried by spirv.func and +// spirv.GlobalVariable through their optional `linkage_attributes` argument. +def SPIRV_LinkageAttributesAttr : SPIRV_Attr<"LinkageAttributes", "linkage_attributes"> { + let parameters = (ins + "std::string":$linkage_name, + "mlir::spirv::LinkageTypeAttr":$linkage_type + ); + let assemblyFormat = "`<` struct(params) `>`"; +} + // Description of cooperative matrix operations supported on the // target. Represents `VkCooperativeMatrixPropertiesNV`. See // https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkCooperativeMatrixPropertiesNV.html diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td @@ -15,6 +15,7 @@ #ifndef MLIR_DIALECT_SPIRV_IR_STRUCTURE_OPS #define MLIR_DIALECT_SPIRV_IR_STRUCTURE_OPS +include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.td" include "mlir/Dialect/SPIRV/IR/SPIRVBase.td" include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/FunctionInterfaces.td" @@ -294,7 +295,8 @@ OptionalAttr:$arg_attrs, OptionalAttr:$res_attrs, StrAttr:$sym_name, - SPIRV_FunctionControlAttr:$function_control + SPIRV_FunctionControlAttr:$function_control, + OptionalAttr:$linkage_attributes ); let results = (outs); @@ -385,7 +387,8 @@ OptionalAttr:$location, OptionalAttr:$binding, OptionalAttr:$descriptor_set, - OptionalAttr:$builtin + OptionalAttr:$builtin, + OptionalAttr:$linkage_attributes ); let results = (outs); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ 
b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -3522,8 +3522,17 @@ } entryPoints[key] = entryPointOp; } else if (auto funcOp = dyn_cast(op)) { - if (funcOp.isExternal()) - return op.emitError("'spirv.module' cannot contain external functions"); + // If the function is external and does not have 'Import' + // linkage_attributes(LinkageAttributes), throw an error. 'Import' + // LinkageAttributes is used to import external functions. + auto linkageAttr = funcOp.getLinkageAttributes(); + auto hasImportLinkage = + linkageAttr && (linkageAttr.value().getLinkageType().getValue() == + spirv::LinkageType::Import); + if (funcOp.isExternal() && !hasImportLinkage) + return op.emitError( + "'spirv.module' cannot contain external functions " + "without 'Import' linkage_attributes (LinkageAttributes)"); // TODO: move this check to spirv.func. for (auto &block : funcOp) diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -267,6 +267,27 @@ } typeDecorations[words[0]] = words[2]; break; + case spirv::Decoration::LinkageAttributes: { + if (words.size() < 4) { + return emitError(unknownLoc, "OpDecorate with ") + << decorationName + << " needs at least 1 string and 1 integer literal"; + } + // LinkageAttributes has two parameters ["linkageName", linkageType] + // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import + // "linkageName" is a string literal encoded as uint32_t words (including + // its NUL terminator), so its size is variable, which makes words.size() + // variable too: words.size() = 3 + strlen(name)/4 + 1, i.e. + // 3 + ceildiv(strlen(name) + 1, 4).
+ unsigned wordIndex = 2; + auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str(); + // The name occupies a variable number of words, so only after decoding it + // can we check that the linkage type literal is actually present; the + // size check above does not cover names of 4+ characters. + // NOTE(review): decodeStringLiteral presumably relies on the literal being + // NUL-terminated within `words`; confirm the binary reader guarantees it. + if (wordIndex >= words.size()) { + return emitError(unknownLoc, "OpDecorate with ") + << decorationName << " is missing the linkage type literal"; + } + // Reject out-of-range values instead of casting them blindly into the enum. + auto linkageType = spirv::symbolizeLinkageType(words[wordIndex++]); + if (!linkageType) { + return emitError(unknownLoc, "OpDecorate with ") + << decorationName << " has an invalid linkage type literal"; + } + auto linkageTypeAttr = + opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(*linkageType); + auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>( + linkageName, linkageTypeAttr); + decorations[words[0]].set(symbol, linkageAttr.dyn_cast()); + break; + } case spirv::Decoration::Aliased: case spirv::Decoration::Block: case spirv::Decoration::BufferBlock: @@ -380,6 +401,12 @@ std::string fnName = getFunctionSymbol(fnID); auto funcOp = opBuilder.create( unknownLoc, fnName, functionType, fnControl.value()); + // Attach the decorations recorded for this function (e.g. + // linkage_attributes) as attributes on the op; single lookup via find. + auto decorationIt = decorations.find(fnID); + if (decorationIt != decorations.end()) { + for (auto attr : decorationIt->second.getAttrs()) { + funcOp->setAttr(attr.getName(), attr.getValue()); + } + } curFunction = funcMap[fnID] = funcOp; auto *entryBlock = funcOp.addEntryBlock(); LLVM_DEBUG({ @@ -430,6 +457,16 @@ } } + // entryBlock is needed to access the arguments. Once that is done, we can + // erase the block for functions with 'Import' LinkageAttributes, since these + // are essentially function declarations, so they have no body. + auto linkageAttr = funcOp.getLinkageAttributes(); + auto hasImportLinkage = + linkageAttr && (linkageAttr.value().getLinkageType().getValue() == + spirv::LinkageType::Import); + if (hasImportLinkage) + funcOp.eraseBody(); + + // RAII guard to reset the insertion point to the module's region after // deserializing the body of this function.
OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -18,6 +18,7 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" #include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "spirv-serialization" @@ -208,52 +209,101 @@ if (failed(processName(funcID, op.getName()))) { return failure(); } + // Handle external functions with linkage_attributes(LinkageAttributes) + // differently. + auto linkageAttr = op.getLinkageAttributes(); + auto hasImportLinkage = + linkageAttr && (linkageAttr.value().getLinkageType().getValue() == + spirv::LinkageType::Import); + if (op.isExternal() && !hasImportLinkage) { + return op.emitError( + "'spirv.module' cannot contain external functions " + "without 'Import' linkage_attributes (LinkageAttributes)"); + } + if (op.isExternal()) { + // An external function that reaches this point has 'Import' linkage, so + // it is only a declaration: emit one OpFunctionParameter per argument + // type straight from the signature. Deliberately avoid materializing a + // temporary entry block to obtain block arguments: that mutates the op + // being serialized, leaves a stray body behind if processType() fails, + // and records soon-to-be-destroyed block arguments in valueIDMap. No + // body can reference these parameters, so no value mapping is needed.
+ for (Type argType : op.getFunctionType().getInputs()) { + uint32_t argTypeID = 0; + if (failed(processType(op.getLoc(), argType, argTypeID))) { + return failure(); + } + encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, + {argTypeID, getNextID()}); + } + } else { + // Declare the parameters. + for (auto arg : op.getArguments()) { + uint32_t argTypeID = 0; + if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { + return failure(); + } + auto argValueID = getNextID(); + valueIDMap[arg] = argValueID; + encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, + {argTypeID, argValueID}); + } - // Declare the parameters. - for (auto arg : op.getArguments()) { - uint32_t argTypeID = 0; - if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { + // Some instructions (e.g., OpVariable) in a function must be in the first + // block in the function. These instructions will be put in + // functionHeader. Thus, we put the label in functionHeader first, and + // omit it from the first block. OpLabel only needs to be added for + // functions with a body (including an empty body); 'Import' declarations + // are handled in the branch above and get neither an OpLabel nor a + // terminating instruction.
+ encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel, + {getOrCreateBlockID(&op.front())}); + if (failed(processBlock(&op.front(), /*omitLabel=*/true))) + return failure(); + if (failed(visitInPrettyBlockOrder( + &op.front(), [&](Block *block) { return processBlock(block); }, + /*skipHeader=*/true))) { return failure(); } - auto argValueID = getNextID(); - valueIDMap[arg] = argValueID; - encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, - {argTypeID, argValueID}); - } - - // Process the body. - if (op.isExternal()) { - return op.emitError("external function is unhandled"); - } - - // Some instructions (e.g., OpVariable) in a function must be in the first - // block in the function. These instructions will be put in functionHeader. - // Thus, we put the label in functionHeader first, and omit it from the first - // block. - encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel, - {getOrCreateBlockID(&op.front())}); - if (failed(processBlock(&op.front(), /*omitLabel=*/true))) - return failure(); - if (failed(visitInPrettyBlockOrder( - &op.front(), [&](Block *block) { return processBlock(block); }, - /*skipHeader=*/true))) { - return failure(); - } - // There might be OpPhi instructions who have value references needing to fix. - for (const auto &deferredValue : deferredPhiValues) { - Value value = deferredValue.first; - uint32_t id = getValueID(value); - LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value - << " to id = " << id << '\n'); - assert(id && "OpPhi references undefined value!"); - for (size_t offset : deferredValue.second) - functionBody[offset] = id; + // There might be OpPhi instructions who have value references needing to + // fix.
+ for (const auto &deferredValue : deferredPhiValues) { + Value value = deferredValue.first; + uint32_t id = getValueID(value); + LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value + << " to id = " << id << '\n'); + assert(id && "OpPhi references undefined value!"); + for (size_t offset : deferredValue.second) + functionBody[offset] = id; + } + deferredPhiValues.clear(); } - deferredPhiValues.clear(); - LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName() << "' --\n"); + // Insert Decorations based on Function Attributes. + // Only attributes we should be considering for decoration are the + // ::mlir::spirv::Decoration attributes. + + for (auto attr : op->getAttrs()) { + // Only generate OpDecorate op for spirv::Decoration attributes. + auto isValidDecoration = mlir::spirv::symbolizeEnum( + llvm::convertToCamelFromSnakeCase(attr.getName().strref(), + /*capitalizeFirst=*/true)); + if (isValidDecoration != std::nullopt) { + if (failed(processDecoration(op.getLoc(), funcID, attr))) { + return failure(); + } + } + } // Insert OpFunctionEnd.
encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {}); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -208,7 +208,8 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, NamedAttribute attr) { auto attrName = attr.getName().strref(); - auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true); + auto decorationName = + llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true); auto decoration = spirv::symbolizeDecoration(decorationName); if (!decoration) { return emitError( @@ -218,6 +219,18 @@ } SmallVector args; switch (*decoration) { + case spirv::Decoration::LinkageAttributes: { + // Get the value of the Linkage Attributes + // e.g., LinkageAttributes=["linkageName", linkageType]. + auto linkageAttr = attr.getValue().dyn_cast(); + // The decoration is selected purely by attribute name, so guard against a + // same-named attribute of another kind instead of dereferencing null. + if (!linkageAttr) { + return emitError(loc, "expected a LinkageAttributesAttr value for '") + << attrName << "'"; + } + auto linkageName = linkageAttr.getLinkageName(); + auto linkageType = linkageAttr.getLinkageType().getValue(); + // Encode the Linkage Name (string literal to uint32_t). + spirv::encodeStringLiteralInto(args, linkageName); + // Encode the LinkageType and add it to the args.
+ args.push_back(static_cast(linkageType)); + break; + } case spirv::Decoration::Binding: case spirv::Decoration::DescriptorSet: case spirv::Decoration::Location: diff --git a/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir b/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir @@ -0,0 +1,19 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s + +spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @linkage_attr_test_kernel() "DontInline" attributes {} { + %uchar_0 = spirv.Constant 0 : i8 + %ushort_1 = spirv.Constant 1 : i16 + %uint_0 = spirv.Constant 0 : i32 + spirv.FunctionCall @outside.func.with.linkage(%uchar_0):(i8) -> () + spirv.Return + } + // CHECK: linkage_attributes = #spirv.linkage_attributes> + spirv.func @outside.func.with.linkage(%arg0 : i8) -> () "Pure" attributes { + linkage_attributes=#spirv.linkage_attributes< + linkage_name="outside.func", + linkage_type= + > + } + spirv.func @inside.func() -> () "Pure" attributes {} {spirv.Return} +} diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir @@ -270,6 +270,26 @@ // ----- +spirv.module Logical GLSL450 requires #spirv.vce { + // CHECK: linkage_attributes = #spirv.linkage_attributes> + spirv.func @outside.func.with.linkage(%arg0 : i8) -> () "Pure" attributes { + linkage_attributes=#spirv.linkage_attributes< + linkage_name="outside.func", + linkage_type= + > + } + spirv.func @inside.func() -> () "Pure" attributes {} {spirv.Return} +} +// ----- + +spirv.module Logical GLSL450 requires #spirv.vce { + // expected-error @+1 {{'spirv.module' cannot contain external functions without 'Import' linkage_attributes (LinkageAttributes)}} + spirv.func @outside.func.without.linkage(%arg0 : i8) -> () "Pure" + spirv.func
@inside.func() -> () "Pure" attributes {} {spirv.Return} +} + +// ----- + // expected-error @+1 {{expected function_control attribute specified as string}} spirv.func @missing_function_control() { spirv.Return } @@ -360,6 +380,19 @@ spirv.GlobalVariable @var0 : !spirv.ptr } +// ----- + +spirv.module Logical GLSL450 requires #spirv.vce { + // CHECK: linkage_attributes = #spirv.linkage_attributes> + spirv.GlobalVariable @var1 { + linkage_attributes=#spirv.linkage_attributes< + linkage_name="outSideGlobalVar1", + linkage_type= + > + } : !spirv.ptr +} + + // ----- spirv.module Logical GLSL450 { diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir --- a/mlir/test/Target/SPIRV/decorations.mlir +++ b/mlir/test/Target/SPIRV/decorations.mlir @@ -55,4 +55,14 @@ // CHECK: relaxed_precision spirv.GlobalVariable @var {location = 0 : i32, relaxed_precision} : !spirv.ptr, Output> } +// ----- +spirv.module Logical GLSL450 requires #spirv.vce { + // CHECK: linkage_attributes = #spirv.linkage_attributes> + spirv.GlobalVariable @var1 { + linkage_attributes=#spirv.linkage_attributes< + linkage_name="outSideGlobalVar1", + linkage_type= + > + } : !spirv.ptr +} diff --git a/mlir/test/Target/SPIRV/function-decorations.mlir b/mlir/test/Target/SPIRV/function-decorations.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/SPIRV/function-decorations.mlir @@ -0,0 +1,19 @@ +// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s + +spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @linkage_attr_test_kernel() "DontInline" attributes {} { + %uchar_0 = spirv.Constant 0 : i8 + %ushort_1 = spirv.Constant 1 : i16 + %uint_0 = spirv.Constant 0 : i32 + spirv.FunctionCall @outside.func.with.linkage(%uchar_0):(i8) -> () + spirv.Return + } + // The callee below is a body-less 'Import' declaration; after the + // serialize/deserialize round trip it must still carry linkage_attributes. + // CHECK: linkage_attributes = #spirv.linkage_attributes> + spirv.func @outside.func.with.linkage(%arg0 : i8) -> () "Pure" attributes { + 
linkage_attributes=#spirv.linkage_attributes< + linkage_name="outside.func", + linkage_type= + > + } + spirv.func @inside.func() -> () "Pure" attributes {} {spirv.Return} +} diff --git a/mlir/test/Target/SPIRV/global-variable.mlir b/mlir/test/Target/SPIRV/global-variable.mlir --- a/mlir/test/Target/SPIRV/global-variable.mlir +++ b/mlir/test/Target/SPIRV/global-variable.mlir @@ -34,3 +34,15 @@ spirv.Return } } + +// ----- + +spirv.module Logical GLSL450 requires #spirv.vce { + // Checks that linkage_attributes on a spirv.GlobalVariable are kept in the + // output. + // CHECK: linkage_attributes = #spirv.linkage_attributes> + spirv.GlobalVariable @var1 { + linkage_attributes=#spirv.linkage_attributes< + linkage_name="outSideGlobalVar1", + linkage_type= + > + } : !spirv.ptr +}