diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -83,8 +83,11 @@
   llvm::SmallSet<Capability, 8> givenCapabilities; /// Allowed capabilities
 };
 
-/// Returns a value that represents a builtin variable value within the SPIR-V
-/// module.
+/// Returns the value for the given `builtin` variable. This function gets or
+/// inserts the global variable associated with the builtin within the nearest
+/// enclosing op that has a symbol table. Returns null Value if such an
+/// enclosing op cannot be found or if no global variable can be created for
+/// `builtin`.
 Value getBuiltinVariableValue(Operation *op, BuiltIn builtin,
                               OpBuilder &builder);
 
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -221,11 +221,11 @@
 // Builtin Variables
 //===----------------------------------------------------------------------===//
 
-/// Look through all global variables in `moduleOp` and check if there is a
-/// spv.globalVariable that has the same `builtin` attribute.
-static spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
+static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
                                                   spirv::BuiltIn builtin) {
-  for (auto varOp : moduleOp.getBlock().getOps<spirv::GlobalVariableOp>()) {
+  // Look through all global variables in the given `body` block and check if
+  // there is a spv.globalVariable that has the same `builtin` attribute.
+  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
     if (auto builtinAttr = varOp.getAttrOfType<StringAttr>(
             spirv::SPIRVDialect::getAttributeName(
                 spirv::Decoration::BuiltIn))) {
@@ -243,16 +243,16 @@
   return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
 }
 
-/// Gets or inserts a global variable for a builtin within a module.
+/// Gets or inserts a global variable for a builtin within `body` block.
 static spirv::GlobalVariableOp
-getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc,
-                           spirv::BuiltIn builtin, OpBuilder &builder) {
-  if (auto varOp = getBuiltinVariable(moduleOp, builtin)) {
+getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
+                           OpBuilder &builder) {
+  if (auto varOp = getBuiltinVariable(body, builtin))
     return varOp;
-  }
-  auto ip = builder.saveInsertionPoint();
-  builder.setInsertionPointToStart(&moduleOp.getBlock());
-  auto name = getBuiltinVarName(builtin);
+
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(&body);
+
   spirv::GlobalVariableOp newVarOp;
   switch (builtin) {
   case spirv::BuiltIn::NumWorkgroups:
@@ -263,6 +263,7 @@
     auto ptrType = spirv::PointerType::get(
         VectorType::get({3}, builder.getIntegerType(32)),
         spirv::StorageClass::Input);
+    std::string name = getBuiltinVarName(builtin);
     newVarOp =
         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
     break;
@@ -271,22 +272,26 @@
     emitError(loc, "unimplemented builtin variable generation for ")
         << stringifyBuiltIn(builtin);
   }
-  builder.restoreInsertionPoint(ip);
   return newVarOp;
 }
 
-/// Gets the global variable associated with a builtin and add
-/// it if it doesn't exist.
 Value mlir::spirv::getBuiltinVariableValue(Operation *op,
                                            spirv::BuiltIn builtin,
                                            OpBuilder &builder) {
-  auto moduleOp = op->getParentOfType<spirv::ModuleOp>();
-  if (!moduleOp) {
-    op->emitError("expected operation to be within a SPIR-V module");
+  Operation *parent = op->getParentOp();
+  while (parent && !parent->hasTrait<OpTrait::SymbolTable>())
+    parent = parent->getParentOp();
+  if (!parent) {
+    op->emitError("expected operation to be within a module-like op");
     return nullptr;
   }
-  spirv::GlobalVariableOp varOp =
-      getOrInsertBuiltinVariable(moduleOp, op->getLoc(), builtin, builder);
+
+  spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(
+      *parent->getRegion(0).begin(), op->getLoc(), builtin, builder);
+  // getOrInsertBuiltinVariable has already emitted an error if it could not
+  // create a variable for this builtin; propagate that failure as null.
+  if (!varOp)
+    return nullptr;
   Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
   return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
 }