diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -140,8 +140,13 @@ /// Returns the value for the given `builtin` variable. This function gets or /// inserts the global variable associated for the builtin within the nearest /// symbol table enclosing `op`. Returns null Value on error. +/// +/// The name of a newly inserted global variable is `prefix`, followed by the +/// stringified `builtin`, followed by `suffix`. Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, - OpBuilder &builder); + OpBuilder &builder, + StringRef prefix = "__builtin__", + StringRef suffix = "__"); /// Gets the value at the given `offset` of the push constant storage with a /// total of `elementCount` `integerType` integers. A global variable will be diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -702,14 +702,16 @@ } /// Gets name of global variable for a builtin. -static std::string getBuiltinVarName(spirv::BuiltIn builtin) { - return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__"; +static std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix, + StringRef suffix) { + return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str(); } /// Gets or inserts a global variable for a builtin within `body` block. 
static spirv::GlobalVariableOp getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, - Type integerType, OpBuilder &builder) { + Type integerType, OpBuilder &builder, + StringRef prefix, StringRef suffix) { if (auto varOp = getBuiltinVariable(body, builtin)) return varOp; @@ -725,7 +727,7 @@ case spirv::BuiltIn::GlobalInvocationId: { auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType), spirv::StorageClass::Input); - std::string name = getBuiltinVarName(builtin); + std::string name = getBuiltinVarName(builtin, prefix, suffix); newVarOp = builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); break; @@ -735,7 +737,7 @@ case spirv::BuiltIn::SubgroupSize: { auto ptrType = spirv::PointerType::get(integerType, spirv::StorageClass::Input); - std::string name = getBuiltinVarName(builtin); + std::string name = getBuiltinVarName(builtin, prefix, suffix); newVarOp = builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin); break; @@ -749,8 +751,8 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin, - Type integerType, - OpBuilder &builder) { + Type integerType, OpBuilder &builder, + StringRef prefix, StringRef suffix) { Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); if (!parent) { op->emitError("expected operation to be within a module-like op"); @@ -759,7 +761,7 @@ spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(), - builtin, integerType, builder); + builtin, integerType, builder, prefix, suffix); Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp); return builder.create<spirv::LoadOp>(op->getLoc(), ptr); }