diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td @@ -379,12 +379,25 @@ let arguments = (ins TypeAttr:$type, StrAttr:$sym_name, - OptionalAttr:$initializer + OptionalAttr:$initializer, + OptionalAttr:$location, + OptionalAttr:$binding, + OptionalAttr:$descriptorSet, + OptionalAttr:$builtin ); let results = (outs); let builders = [ + OpBuilder<(ins "TypeAttr":$type, + "StringAttr":$sym_name, + CArg<"FlatSymbolRefAttr", "nullptr">:$initializer), + [{ + $_state.addAttribute("type", type); + $_state.addAttribute(sym_nameAttrName($_state.name), sym_name); + if (initializer) + $_state.addAttribute(initializerAttrName($_state.name), initializer); + }]>, OpBuilder<(ins "TypeAttr":$type, "ArrayRef":$namedAttrs), [{ $_state.addAttribute("type", type); @@ -393,7 +406,16 @@ OpBuilder<(ins "Type":$type, "StringRef":$name, "unsigned":$descriptorSet, "unsigned":$binding)>, OpBuilder<(ins "Type":$type, "StringRef":$name, - "spirv::BuiltIn":$builtin)> + "spirv::BuiltIn":$builtin)>, + OpBuilder<(ins "Type":$type, + "StringRef":$sym_name, + CArg<"FlatSymbolRefAttr", "{}">:$initializer), + [{ + $_state.addAttribute("type", TypeAttr::get(type)); + $_state.addAttribute(sym_nameAttrName($_state.name), $_builder.getStringAttr(sym_name)); + if (initializer) + $_state.addAttribute(initializerAttrName($_state.name), initializer); + }]> ]; let hasOpcode = 0; diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -733,17 +733,22 @@ // required by SPIR-V runner. // This is okay because multiple invocations are not supported yet.
auto storageClass = srcType.getStorageClass(); - if (storageClass != spirv::StorageClass::Input && - storageClass != spirv::StorageClass::Private && - storageClass != spirv::StorageClass::Output && - storageClass != spirv::StorageClass::StorageBuffer) { + switch (storageClass) { + case spirv::StorageClass::Input: + case spirv::StorageClass::Private: + case spirv::StorageClass::Output: + case spirv::StorageClass::StorageBuffer: + case spirv::StorageClass::UniformConstant: + break; + default: return failure(); } // LLVM dialect spec: "If the global value is a constant, storing into it is - // not allowed.". This corresponds to SPIR-V 'Input' storage class that is - // read-only. - bool isConstant = storageClass == spirv::StorageClass::Input; + // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant' + // storage classes, which are read-only. + bool isConstant = (storageClass == spirv::StorageClass::Input) || + (storageClass == spirv::StorageClass::UniformConstant); // SPIR-V spec: "By default, functions and global variables are private to a // module and cannot be accessed by other modules. However, a module may be // written to export or import functions and global (module scope) @@ -752,9 +757,14 @@ auto linkage = storageClass == spirv::StorageClass::Private ?
LLVM::Linkage::Private : LLVM::Linkage::External; - rewriter.replaceOpWithNewOp( + auto newGlobalOp = rewriter.replaceOpWithNewOp( op, dstType, isConstant, linkage, op.sym_name(), Attribute(), /*alignment=*/0); + + // Carry the `location` attribute, if present, over to the new LLVM global. + if (auto location = op.locationAttr()) + newGlobalOp->setAttr(op.locationAttrName(), location); + return success(); } }; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -92,7 +92,16 @@ if (!integerValueAttr) { return failure(); } - value = integerValueAttr.getInt(); + + // IntegerAttr::getInt/getSInt/getUInt each assert a specific signedness, + // so dispatch on the attribute's type; signless (and index) use getInt(). + if (integerValueAttr.getType().isSignedInteger()) + value = integerValueAttr.getSInt(); + else if (integerValueAttr.getType().isUnsignedInteger()) + value = integerValueAttr.getUInt(); + else + value = integerValueAttr.getInt(); + return success(); } @@ -2066,8 +2075,7 @@ void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state, Type type, StringRef name, unsigned descriptorSet, unsigned binding) { - build(builder, state, TypeAttr::get(type), builder.getStringAttr(name), - nullptr); + build(builder, state, TypeAttr::get(type), builder.getStringAttr(name)); state.addAttribute( spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet), builder.getI32IntegerAttr(descriptorSet)); @@ -2079,8 +2087,7 @@ void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state, Type type, StringRef name, spirv::BuiltIn builtin) { - build(builder, state, TypeAttr::get(type), builder.getStringAttr(name), - nullptr); + build(builder, state, TypeAttr::get(type), builder.getStringAttr(name)); state.addAttribute( spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn), builder.getStringAttr(spirv::stringifyBuiltIn(builtin))); diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@
-262,6 +262,7 @@ case spirv::Decoration::NonWritable: case spirv::Decoration::NoPerspective: case spirv::Decoration::Restrict: + case spirv::Decoration::RelaxedPrecision: if (words.size() != 2) { return emitError(unknownLoc, "OpDecoration with ") << decorationName << "needs a single target "; diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -241,6 +241,7 @@ case spirv::Decoration::NonWritable: case spirv::Decoration::NoPerspective: case spirv::Decoration::Restrict: + case spirv::Decoration::RelaxedPrecision: // For unit attributes, the args list has no values so we do nothing if (auto unitAttr = attr.second.dyn_cast()) break; diff --git a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir @@ -67,6 +67,26 @@ } } +spv.module Logical GLSL450 { + // CHECK: llvm.mlir.global external @bar() {location = 1 : i32} : i32 + // CHECK-LABEL: @foo + spv.GlobalVariable @bar {location = 1 : i32} : !spv.ptr + spv.func @foo() "None" { + %0 = spv.mlir.addressof @bar : !spv.ptr + spv.Return + } +} + +spv.module Logical GLSL450 { + // CHECK: llvm.mlir.global external constant @bar() {location = 3 : i32} : f32 + // CHECK-LABEL: @foo + spv.GlobalVariable @bar {descriptor_set = 0 : i32, location = 3 : i32} : !spv.ptr + spv.func @foo() "None" { + %0 = spv.mlir.addressof @bar : !spv.ptr + spv.Return + } +} + //===----------------------------------------------------------------------===// // spv.Load //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir --- a/mlir/test/Target/SPIRV/decorations.mlir +++
b/mlir/test/Target/SPIRV/decorations.mlir @@ -49,3 +49,10 @@ spv.GlobalVariable @var bind(0, 0) {restrict} : !spv.ptr[0])>, StorageBuffer> } +// ----- + +spv.module Logical GLSL450 requires #spv.vce { + // CHECK: relaxed_precision + spv.GlobalVariable @var {location = 0 : i32, relaxed_precision} : !spv.ptr, Output> +} +