diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -3514,8 +3514,18 @@ } entryPoints[key] = entryPointOp; } else if (auto funcOp = dyn_cast(op)) { - if (funcOp.isExternal()) - return op.emitError("'spirv.module' cannot contain external functions"); + // An external function is only allowed when it carries 'Import' + // linkage_attributes (LinkageAttributes), i.e. it is an imported declaration. + // NOTE(review): linkageAttr is not checked to be a 2-element string array. + auto linkageAttr = funcOp->getAttr("linkage_attributes"); + if (funcOp.isExternal() && + !(linkageAttr && spirv::symbolizeEnum( + linkageAttr.dyn_cast()[1] + .dyn_cast() + .strref()) == spirv::LinkageType::Import)) + return op.emitError( + "'spirv.module' cannot contain external functions " + "without 'Import' linkage_attributes(LinkageAttributes)"); // TODO: move this check to spirv.func. for (auto &block : funcOp) diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -267,6 +267,28 @@ } typeDecorations[words[0]] = words[2]; break; + case spirv::Decoration::LinkageAttributes: { + if (words.size() < 4) { + return emitError(unknownLoc, "OpDecorate with ") + << decorationName + << " needs at least 1 string and 1 integer literal"; + } + // LinkageAttributes has two operands: a "Name" string literal and a + // LinkageType, e.g. OpDecorate %imported_func LinkageAttributes "outside.func" Import + // words = [target id, decoration, name words..., linkage type]. 
The name is + // a nul-terminated string packed into 32-bit words (strlen(name)/4 + 1 of + // them), so words.size() = 3 + strlen(name)/4 + 1 varies with the name + // length; the LinkageType is always the last word. + unsigned wordIndex = 2; + StringRef linkageName = spirv::decodeStringLiteral(words, wordIndex); +
std::vector attrElements; + attrElements.push_back(opBuilder.getStringAttr(linkageName)); + attrElements.push_back(opBuilder.getStringAttr(stringifyLinkageType( + static_cast(words[words.size() - 1])))); + ArrayAttr linkageAttrVal = opBuilder.getArrayAttr(attrElements); + decorations[words[0]].set(symbol, linkageAttrVal.dyn_cast()); + break; + } case spirv::Decoration::Aliased: case spirv::Decoration::Block: case spirv::Decoration::BufferBlock: @@ -380,6 +402,20 @@ std::string fnName = getFunctionSymbol(fnID); auto funcOp = opBuilder.create( unknownLoc, fnName, functionType, fnControl.value()); + bool isImportedFunc = false; + // Copy the decorations recorded for this function onto funcOp as attributes. + if (decorations.count(fnID)) { + for (auto attr : decorations[fnID].getAttrs()) { + funcOp->setAttr(attr.getName(), attr.getValue()); + if (attr.getName() == "linkage_attributes" && + spirv::symbolizeEnum(attr.getValue() + .dyn_cast()[1] + .dyn_cast() + .strref()) == + spirv::LinkageType::Import) + isImportedFunc = true; + } + } curFunction = funcMap[fnID] = funcOp; auto *entryBlock = funcOp.addEntryBlock(); LLVM_DEBUG({ @@ -430,6 +466,11 @@ } } + // The entry block was created so the function arguments could be + // materialized. Functions with 'Import' LinkageAttributes are declarations + // without a body, so erase the block again for them. + if (isImportedFunc) + funcOp.eraseBody(); // RAII guard to reset the insertion point to the module's region after // deserializing the body of this function.
OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder); diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -18,6 +18,7 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" #include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "spirv-serialization" @@ -208,52 +209,102 @@ if (failed(processName(funcID, op.getName()))) { return failure(); } + // External functions can only be serialized as 'Import' declarations + // (linkage_attributes / LinkageAttributes): parameters, but no body. + if (op.isExternal()) { + auto linkageAttr = op->getAttr("linkage_attributes"); + if (linkageAttr && spirv::symbolizeEnum( + linkageAttr.dyn_cast()[1] + .dyn_cast() + .strref()) == spirv::LinkageType::Import) { + // Temporarily add an entry block so that block arguments matching the + // function signature exist; they are used to emit OpFunctionParameter + // for this 'Import' declaration. + // WARNING: this mutates the op being serialized. While the block exists,
+ // op.isExternal() returns false; the block is erased again once the + // parameters have been emitted. + // NOTE(review): the early `return failure()` below skips eraseBody(), + // leaving the temporary body attached to the op. + op.addEntryBlock(); + for (auto arg : op.getArguments()) { + uint32_t argTypeID = 0; + if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { + return failure(); + } + auto argValueID = getNextID(); + valueIDMap[arg] = argValueID; + encodeInstructionInto(functionHeader, + spirv::Opcode::OpFunctionParameter, + {argTypeID, argValueID}); + } + // The temporary block holds no operations and was only needed for its + // arguments; erase it so the op is an external declaration again. + // NOTE(review): valueIDMap keeps entries keyed by the erased arguments. + op.eraseBody(); + } else + return op.emitError( + "'spirv.module' cannot contain external functions " + "without 'Import' linkage_attributes(LinkageAttributes)"); + } else { + // Declare the parameters. + for (auto arg : op.getArguments()) { + uint32_t argTypeID = 0; + if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { + return failure(); + } + auto argValueID = getNextID(); + valueIDMap[arg] = argValueID; + encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, + {argTypeID, argValueID}); + } - // Declare the parameters. - for (auto arg : op.getArguments()) { - uint32_t argTypeID = 0; - if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) { + // Some instructions (e.g., OpVariable) in a function must be in the first + // block in the function. These instructions will be put in functionHeader. + // Thus, we put the label in functionHeader first, and omit it from the + // first block.
+ // Only functions with a body (possibly empty) get an OpLabel. 'Import' + // declarations are handled by the branch above and emit neither an + // OpLabel nor a terminating instruction, since they are function + // declarations. + encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel, + {getOrCreateBlockID(&op.front())}); + if (failed(processBlock(&op.front(), /*omitLabel=*/true))) + return failure(); + if (failed(visitInPrettyBlockOrder( + &op.front(), [&](Block *block) { return processBlock(block); }, + /*skipHeader=*/true))) { return failure(); } - auto argValueID = getNextID(); - valueIDMap[arg] = argValueID; - encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter, - {argTypeID, argValueID}); - } - - // Process the body. - if (op.isExternal()) { - return op.emitError("external function is unhandled"); - } - // Some instructions (e.g., OpVariable) in a function must be in the first - // block in the function. These instructions will be put in functionHeader. - // Thus, we put the label in functionHeader first, and omit it from the first - // block. - encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel, - {getOrCreateBlockID(&op.front())}); - if (failed(processBlock(&op.front(), /*omitLabel=*/true))) - return failure(); - if (failed(visitInPrettyBlockOrder( - &op.front(), [&](Block *block) { return processBlock(block); }, - /*skipHeader=*/true))) { - return failure(); - } - - // There might be OpPhi instructions who have value references needing to fix.
- for (const auto &deferredValue : deferredPhiValues) { - Value value = deferredValue.first; - uint32_t id = getValueID(value); - LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value - << " to id = " << id << '\n'); - assert(id && "OpPhi references undefined value!"); - for (size_t offset : deferredValue.second) - functionBody[offset] = id; + // There might be OpPhi instructions whose value references still need to + // be fixed up. + for (const auto &deferredValue : deferredPhiValues) { + Value value = deferredValue.first; + uint32_t id = getValueID(value); + LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value + << " to id = " << id << '\n'); + assert(id && "OpPhi references undefined value!"); + for (size_t offset : deferredValue.second) + functionBody[offset] = id; + } + deferredPhiValues.clear(); } - deferredPhiValues.clear(); - LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName() << "' --\n"); + // Emit an OpDecorate for each function attribute whose name, converted + // from snake_case to CamelCase (e.g. linkage_attributes -> + // LinkageAttributes), names a spirv::Decoration; other attributes are skipped. + for (auto attr : op->getAttrs()) { + // Only attributes that name a spirv::Decoration are decorated. + if (mlir::spirv::symbolizeEnum( + llvm::convertToCamelFromSnakeCase(attr.getName().strref(), + /*capitalizeFirst=*/true)) != + std::nullopt) { + if (failed(processDecoration(op.getLoc(), funcID, attr))) { + return failure(); + } + } + } // Insert OpFunctionEnd.
encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {}); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -208,7 +208,8 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, NamedAttribute attr) { auto attrName = attr.getName().strref(); - auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true); + auto decorationName = + llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true); auto decoration = spirv::symbolizeDecoration(decorationName); if (!decoration) { return emitError( @@ -218,6 +219,27 @@ } SmallVector args; switch (*decoration) { + case spirv::Decoration::LinkageAttributes: { + // The attribute value is expected to be a two-element array of strings, + // e.g. linkage_attributes = ["Name", "LinkageType"]. + // NOTE(review): only the element count is checked. A non-ArrayAttr value + // leaves arrayAttrVal null; non-string elements or an unknown LinkageType + // string hit the unchecked dyn_casts / .value() below. + auto arrayAttrVal = attr.getValue().dyn_cast(); + if (arrayAttrVal.size() != 2) + return emitError(loc, "attribute must have 2 values ") << attrName; + // Encode the linkage name as a string literal packed into 32-bit words. + spirv::encodeStringLiteralInto( + args, arrayAttrVal[0].dyn_cast().strref()); + // Encode the LinkageType: its numeric enum value is appended as a + // single word after the name. 
+ auto linkageType = static_cast( + spirv::symbolizeEnum( + arrayAttrVal[1].dyn_cast().strref()) + .value()); + args.push_back(linkageType); + break; + } case spirv::Decoration::Binding: case spirv::Decoration::DescriptorSet: case spirv::Decoration::Location: diff --git a/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir b/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/IR/function-decorations.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt
-split-input-file -verify-diagnostics %s | FileCheck %s + +spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @linkage_attr_test_kernel() "DontInline" attributes {} { + %uchar_0 = spirv.Constant 0 : i8 + %ushort_1 = spirv.Constant 1 : i16 + %uint_0 = spirv.Constant 0 : i32 + spirv.FunctionCall @outside.func.with.linkage(%uchar_0):(i8) -> () + spirv.Return + } + // CHECK: linkage_attributes = ["outSideGlobalVar1", "Import"] + spirv.GlobalVariable @var1 {linkage_attributes=["outSideGlobalVar1", "Import"]} : !spirv.ptr + + // CHECK: linkage_attributes = ["outside.func", "Import"] + spirv.func @outside.func.with.linkage(%arg0 : i8) -> () "Pure" attributes {linkage_attributes=["outside.func", "Import"]} + spirv.func @inside.func() -> () "Pure" attributes {} {spirv.Return} +} diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir @@ -270,6 +270,21 @@ // ----- +spirv.module Logical GLSL450 requires #spirv.vce { + // CHECK: linkage_attributes = ["outside.func", "Import"] + spirv.func @outside.func.with.linkage(%arg0 : i8) -> () "Pure" attributes {linkage_attributes=["outside.func", "Import"]} + spirv.func @inside.func() -> () "Pure" attributes {} {spirv.Return} +} +// ----- + +spirv.module Logical GLSL450 requires #spirv.vce { + // expected-error @+1 {{'spirv.module' cannot contain external functions without 'Import' linkage_attributes(LinkageAttributes)}} + spirv.func @outside.func.with.linkage(%arg0 : i8) -> () "Pure" + spirv.func @inside.func() -> () "Pure" attributes {} {spirv.Return} +} + +// ----- + // expected-error @+1 {{expected function_control attribute specified as string}} spirv.func @missing_function_control() { spirv.Return } @@ -362,6 +377,13 @@ // ----- +spirv.module Logical GLSL450 requires #spirv.vce { + // CHECK: linkage_attributes = ["outSideGlobalVar1", "Import"] + spirv.GlobalVariable 
@var1 {linkage_attributes=["outSideGlobalVar1", "Import"]} : !spirv.ptr +} + +// ----- + spirv.module Logical GLSL450 { // expected-error @+1 {{expected spirv.ptr type}} spirv.GlobalVariable @var0 : f32 @@ -405,6 +427,7 @@ } } + // ----- //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir --- a/mlir/test/Target/SPIRV/decorations.mlir +++ b/mlir/test/Target/SPIRV/decorations.mlir @@ -55,4 +55,9 @@ // CHECK: relaxed_precision spirv.GlobalVariable @var {location = 0 : i32, relaxed_precision} : !spirv.ptr, Output> } +// ----- +spirv.module Logical GLSL450 requires #spirv.vce { + // CHECK: linkage_attributes = ["outSideGlobalVar1", "Import"] + spirv.GlobalVariable @var1 {linkage_attributes=["outSideGlobalVar1", "Import"]} : !spirv.ptr +} diff --git a/mlir/test/Target/SPIRV/function-decorations.mlir b/mlir/test/Target/SPIRV/function-decorations.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/SPIRV/function-decorations.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s + +spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @linkage_attr_test_kernel() "DontInline" attributes {} { + %uchar_0 = spirv.Constant 0 : i8 + %ushort_1 = spirv.Constant 1 : i16 + %uint_0 = spirv.Constant 0 : i32 + spirv.FunctionCall @outside.func.with.linkage(%uchar_0):(i8) -> () + spirv.Return + } + // CHECK: linkage_attributes = ["outside.func", "Import"] + spirv.func @outside.func.with.linkage(%arg0 : i8) -> () "Pure" attributes {linkage_attributes=["outside.func", "Import"]} + spirv.func @inside.func() -> () "Pure" attributes {} {spirv.Return} +}