diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_SPIRV_SPIRVLOWERING_H #define MLIR_DIALECT_SPIRV_SPIRVLOWERING_H +#include "mlir/Dialect/SPIRV/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/TargetAndABI.h" #include "mlir/Transforms/DialectConversion.h" @@ -27,7 +28,7 @@ /// pointers to structs. class SPIRVTypeConverter : public TypeConverter { public: - SPIRVTypeConverter(); + explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr); /// Gets the SPIR-V correspondence for the standard index type. static Type getIndexType(MLIRContext *context); @@ -40,6 +41,9 @@ /// llvm::None if the memory space does not map to any SPIR-V storage class. static Optional getStorageClassForMemorySpace(unsigned space); + +private: + spirv::TargetEnv targetEnv; }; /// Base class to define a conversion pattern to lower `SourceOp` into SPIR-V. @@ -70,11 +74,10 @@ class SPIRVConversionTarget : public ConversionTarget { public: /// Creates a SPIR-V conversion target for the given target environment. - static std::unique_ptr get(TargetEnvAttr targetEnv, - MLIRContext *context); + static std::unique_ptr get(TargetEnvAttr targetAttr); private: - SPIRVConversionTarget(TargetEnvAttr targetEnv, MLIRContext *context); + explicit SPIRVConversionTarget(TargetEnvAttr targetAttr); // Be explicit that instance of this class cannot be copied or moved: there // are lambdas capturing fields of the instance. @@ -87,9 +90,7 @@ /// environment. bool isLegalOp(Operation *op); - Version givenVersion; /// SPIR-V version to target - llvm::SmallSet givenExtensions; /// Allowed extensions - llvm::SmallSet givenCapabilities; /// Allowed capabilities + TargetEnv targetEnv; }; /// Returns the value for the given `builtin` variable.
This function gets or diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h --- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h +++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h @@ -15,6 +15,7 @@ #include "mlir/Dialect/SPIRV/SPIRVAttributes.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/SmallSet.h" namespace mlir { class Operation; @@ -22,6 +23,38 @@ namespace spirv { enum class StorageClass : uint32_t; +/// A wrapper class around a spirv::TargetEnvAttr to provide query methods for +/// allowed version/capabilities/extensions. +class TargetEnv { +public: + explicit TargetEnv(TargetEnvAttr targetAttr); + + Version getVersion(); + + /// Returns true if the given capability is allowed. + bool allows(Capability) const; + /// Returns the first allowed one if any of the given capabilities is allowed. + /// Returns llvm::None otherwise. + Optional allows(ArrayRef) const; + + /// Returns true if the given extension is allowed. + bool allows(Extension) const; + /// Returns the first allowed one if any of the given extensions is allowed. + /// Returns llvm::None otherwise. + Optional allows(ArrayRef) const; + + /// Returns the MLIRContext. + MLIRContext *getContext(); + + /// Allows implicitly converting to the underlying spirv::TargetEnvAttr. + operator TargetEnvAttr() const { return targetAttr; } + +private: + TargetEnvAttr targetAttr; + llvm::SmallSet givenExtensions; ///< Allowed extensions + llvm::SmallSet givenCapabilities; ///< Allowed capabilities +}; + /// Returns the attribute name for specifying argument ABI information.
StringRef getInterfaceVarABIAttrName(); diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp @@ -52,14 +52,15 @@ kernelModules.push_back(builder.clone(*moduleOp.getOperation())); }); - SPIRVTypeConverter typeConverter; + auto targetAttr = spirv::lookupTargetEnvOrDefault(module); + std::unique_ptr target = + spirv::SPIRVConversionTarget::get(targetAttr); + + SPIRVTypeConverter typeConverter(targetAttr); OwningRewritePatternList patterns; populateGPUToSPIRVPatterns(context, typeConverter, patterns); populateStandardToSPIRVPatterns(context, typeConverter, patterns); - std::unique_ptr target = spirv::SPIRVConversionTarget::get( - spirv::lookupTargetEnvOrDefault(module), context); - if (failed(applyFullConversion(kernelModules, *target, patterns, &typeConverter))) { return signalPassFailure(); diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp @@ -25,15 +25,15 @@ MLIRContext *context = &getContext(); ModuleOp module = getModule(); - SPIRVTypeConverter typeConverter; + auto targetAttr = spirv::lookupTargetEnvOrDefault(module); + std::unique_ptr target = + spirv::SPIRVConversionTarget::get(targetAttr); + + SPIRVTypeConverter typeConverter(targetAttr); OwningRewritePatternList patterns; populateLinalgToSPIRVPatterns(context, typeConverter, patterns); populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns); - auto targetEnv = spirv::lookupTargetEnvOrDefault(module); - std::unique_ptr target = - spirv::SPIRVConversionTarget::get(targetEnv, context); - // Allow builtin ops.
target->addLegalOp(); target->addDynamicallyLegalOp( diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -31,14 +31,15 @@ MLIRContext *context = &getContext(); ModuleOp module = getModule(); - SPIRVTypeConverter typeConverter; + auto targetAttr = spirv::lookupTargetEnvOrDefault(module); + std::unique_ptr target = + spirv::SPIRVConversionTarget::get(targetAttr); + + SPIRVTypeConverter typeConverter(targetAttr); OwningRewritePatternList patterns; populateStandardToSPIRVPatterns(context, typeConverter, patterns); populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns); - std::unique_ptr target = spirv::SPIRVConversionTarget::get( - spirv::lookupTargetEnvOrDefault(module), context); - if (failed(applyPartialConversion(module, *target, patterns))) { return signalPassFailure(); } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -157,7 +157,8 @@ return llvm::None; } -SPIRVTypeConverter::SPIRVTypeConverter() { +SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr) + : targetEnv(targetAttr) { addConversion([](Type type) -> Optional { // If the type is already valid in SPIR-V, directly return. return spirv::SPIRVDialect::isValidType(type) ? type : Optional(); @@ -409,11 +410,10 @@ //===----------------------------------------------------------------------===// std::unique_ptr -spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetEnv, - MLIRContext *context) { +spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) { std::unique_ptr target( // std::make_unique does not work here because the constructor is private.
- new SPIRVConversionTarget(targetEnv, context)); + new SPIRVConversionTarget(targetAttr)); SPIRVConversionTarget *targetPtr = target.get(); target->addDynamicallyLegalDialect( Optional( @@ -424,80 +424,57 @@ } spirv::SPIRVConversionTarget::SPIRVConversionTarget( - spirv::TargetEnvAttr targetEnv, MLIRContext *context) - : ConversionTarget(*context), givenVersion(targetEnv.getVersion()) { - for (spirv::Extension ext : targetEnv.getExtensions()) - givenExtensions.insert(ext); - - // Add extensions implied by the current version. - for (spirv::Extension ext : spirv::getImpliedExtensions(givenVersion)) - givenExtensions.insert(ext); - - for (spirv::Capability cap : targetEnv.getCapabilities()) { - givenCapabilities.insert(cap); - - // Add capabilities implied by the current capability. - for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap)) - givenCapabilities.insert(c); - } -} + spirv::TargetEnvAttr targetAttr) + : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {} /// Checks that `candidates` extension requirements are possible to be satisfied -/// with the given `allowedExtensions`. +/// with the given `targetEnv`. /// /// `candidates` is a vector of vector for extension requirements following /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) /// convention.
static LogicalResult checkExtensionRequirements( - Operation *op, const llvm::SmallSet &allowedExtensions, + Operation *op, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { for (const auto &ors : candidates) { - auto chosen = llvm::find_if(ors, [&](spirv::Extension ext) { - return allowedExtensions.count(ext); - }); - - if (chosen == ors.end()) { - SmallVector extStrings; - for (spirv::Extension ext : ors) - extStrings.push_back(spirv::stringifyExtension(ext)); - - LLVM_DEBUG(llvm::dbgs() << op->getName() - << "illegal: requires at least one extension in [" - << llvm::join(extStrings, ", ") - << "] but none allowed in target environment\n"); - return failure(); - } + if (targetEnv.allows(ors)) + continue; + + SmallVector extStrings; + for (spirv::Extension ext : ors) + extStrings.push_back(spirv::stringifyExtension(ext)); + + LLVM_DEBUG(llvm::dbgs() << op->getName() + << " illegal: requires at least one extension in [" + << llvm::join(extStrings, ", ") + << "] but none allowed in target environment\n"); + return failure(); } return success(); } /// Checks that `candidates`capability requirements are possible to be satisfied -/// with the given `allowedCapabilities`. +/// with the given `targetEnv`. /// /// `candidates` is a vector of vector for capability requirements following /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) /// convention.
static LogicalResult checkCapabilityRequirements( - Operation *op, - const llvm::SmallSet &allowedCapabilities, + Operation *op, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { for (const auto &ors : candidates) { - auto chosen = llvm::find_if(ors, [&](spirv::Capability cap) { - return allowedCapabilities.count(cap); - }); - - if (chosen == ors.end()) { - SmallVector capStrings; - for (spirv::Capability cap : ors) - capStrings.push_back(spirv::stringifyCapability(cap)); - - LLVM_DEBUG(llvm::dbgs() - << op->getName() - << "illegal: requires at least one capability in [" - << llvm::join(capStrings, ", ") - << "] but none allowed in target environment\n"); - return failure(); - } + if (targetEnv.allows(ors)) + continue; + + SmallVector capStrings; + for (spirv::Capability cap : ors) + capStrings.push_back(spirv::stringifyCapability(cap)); + + LLVM_DEBUG(llvm::dbgs() << op->getName() + << " illegal: requires at least one capability in [" + << llvm::join(capStrings, ", ") + << "] but none allowed in target environment\n"); + return failure(); } return success(); } @@ -507,7 +484,7 @@ // QueryMinVersionInterface/QueryMaxVersionInterface are available to all // SPIR-V versions. if (auto minVersion = dyn_cast(op)) - if (minVersion.getMinVersion() > givenVersion) { + if (minVersion.getMinVersion() > this->targetEnv.getVersion()) { LLVM_DEBUG(llvm::dbgs() << op->getName() << " illegal: requiring min version " << spirv::stringifyVersion(minVersion.getMinVersion()) @@ -515,7 +492,7 @@ return false; } if (auto maxVersion = dyn_cast(op)) - if (maxVersion.getMaxVersion() < givenVersion) { + if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) { LLVM_DEBUG(llvm::dbgs() << op->getName() << " illegal: requiring max version " << spirv::stringifyVersion(maxVersion.getMaxVersion()) @@ -527,7 +504,7 @@ // implementing QueryExtensionInterface do not require extensions to be // available.
if (auto extensions = dyn_cast(op)) - if (failed(checkExtensionRequirements(op, this->givenExtensions, + if (failed(checkExtensionRequirements(op, this->targetEnv, extensions.getExtensions()))) return false; @@ -535,7 +512,7 @@ // implementing QueryCapabilityInterface do not require capabilities to be // available. if (auto capabilities = dyn_cast(op)) - if (failed(checkCapabilityRequirements(op, this->givenCapabilities, + if (failed(checkCapabilityRequirements(op, this->targetEnv, capabilities.getCapabilities()))) return false; @@ -555,14 +532,13 @@ for (Type valueType : valueTypes) { typeExtensions.clear(); valueType.cast().getExtensions(typeExtensions); - if (failed(checkExtensionRequirements(op, this->givenExtensions, - typeExtensions))) + if (failed(checkExtensionRequirements(op, this->targetEnv, typeExtensions))) return false; typeCapabilities.clear(); valueType.cast().getCapabilities(typeCapabilities); - if (failed(checkCapabilityRequirements(op, this->givenCapabilities, - typeCapabilities))) + if (failed( + checkCapabilityRequirements(op, this->targetEnv, typeCapabilities))) return false; } diff --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp --- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp @@ -16,6 +16,67 @@ using namespace mlir; +//===----------------------------------------------------------------------===// +// TargetEnv +//===----------------------------------------------------------------------===// + +spirv::TargetEnv::TargetEnv(spirv::TargetEnvAttr targetAttr) + : targetAttr(targetAttr) { + for (spirv::Extension ext : targetAttr.getExtensions()) + givenExtensions.insert(ext); + + // Add extensions implied by the current version.
+ for (spirv::Extension ext : + spirv::getImpliedExtensions(targetAttr.getVersion())) + givenExtensions.insert(ext); + + for (spirv::Capability cap : targetAttr.getCapabilities()) { + givenCapabilities.insert(cap); + + // Add capabilities implied by the current capability. + for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap)) + givenCapabilities.insert(c); + } +} + +spirv::Version spirv::TargetEnv::getVersion() { + return targetAttr.getVersion(); +} + +bool spirv::TargetEnv::allows(spirv::Capability capability) const { + return givenCapabilities.count(capability); +} + +Optional +spirv::TargetEnv::allows(ArrayRef caps) const { + auto chosen = llvm::find_if(caps, [this](spirv::Capability cap) { + return givenCapabilities.count(cap); + }); + if (chosen != caps.end()) + return *chosen; + return llvm::None; +} + +bool spirv::TargetEnv::allows(spirv::Extension extension) const { + return givenExtensions.count(extension); +} + +Optional +spirv::TargetEnv::allows(ArrayRef exts) const { + auto chosen = llvm::find_if(exts, [this](spirv::Extension ext) { + return givenExtensions.count(ext); + }); + if (chosen != exts.end()) + return *chosen; + return llvm::None; +} + +MLIRContext *spirv::TargetEnv::getContext() { return targetAttr.getContext(); } + +//===----------------------------------------------------------------------===// +// Utility functions +//===----------------------------------------------------------------------===// + StringRef spirv::getInterfaceVarABIAttrName() { return "spv.interface_var_abi"; } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -224,7 +224,10 @@ spirv::ModuleOp module = getOperation(); MLIRContext *context = &getContext(); - SPIRVTypeConverter typeConverter; + // Use the default target env rather than passing a null attr to TargetEnv. + spirv::TargetEnv
targetEnv(spirv::lookupTargetEnvOrDefault(module)); + + SPIRVTypeConverter typeConverter(targetEnv); OwningRewritePatternList patterns; patterns.insert(context, typeConverter); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -34,22 +34,18 @@ } // namespace /// Checks that `candidates` extension requirements are possible to be satisfied -/// with the given `allowedExtensions` and updates `deducedExtensions` if so. -/// Emits errors attaching to the given `op` on failures. +/// with the given `targetEnv` and updates `deducedExtensions` if so. Emits +/// errors attaching to the given `op` on failures. /// /// `candidates` is a vector of vector for extension requirements following /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) /// convention. static LogicalResult checkAndUpdateExtensionRequirements( - Operation *op, const llvm::SmallSet &allowedExtensions, + Operation *op, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::ExtensionArrayRefVector &candidates, llvm::SetVector &deducedExtensions) { for (const auto &ors : candidates) { - auto chosen = llvm::find_if(ors, [&](spirv::Extension ext) { - return allowedExtensions.count(ext); - }); - - if (chosen != ors.end()) { + if (Optional chosen = targetEnv.allows(ors)) { deducedExtensions.insert(*chosen); } else { SmallVector extStrings; @@ -66,23 +62,18 @@ } /// Checks that `candidates`capability requirements are possible to be satisfied -/// with the given `allowedCapabilities` and updates `deducedCapabilities` if -/// so. Emits errors attaching to the given `op` on failures. +/// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits +/// errors attaching to the given `op` on failures.
/// /// `candidates` is a vector of vector for capability requirements following /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) /// convention. static LogicalResult checkAndUpdateCapabilityRequirements( - Operation *op, - const llvm::SmallSet &allowedCapabilities, + Operation *op, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::CapabilityArrayRefVector &candidates, llvm::SetVector &deducedCapabilities) { for (const auto &ors : candidates) { - auto chosen = llvm::find_if(ors, [&](spirv::Capability cap) { - return allowedCapabilities.count(cap); - }); - - if (chosen != ors.end()) { + if (Optional chosen = targetEnv.allows(ors)) { deducedCapabilities.insert(*chosen); } else { SmallVector capStrings; @@ -101,32 +92,14 @@ void UpdateVCEPass::runOnOperation() { spirv::ModuleOp module = getOperation(); - spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnv(module); - if (!targetEnv) { + spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module); + if (!targetAttr) { module.emitError("missing 'spv.target_env' attribute"); return signalPassFailure(); } - spirv::Version allowedVersion = targetEnv.getVersion(); - - // Build a set for available extensions in the target environment. - llvm::SmallSet allowedExtensions; - for (spirv::Extension ext : targetEnv.getExtensions()) - allowedExtensions.insert(ext); - - // Add extensions implied by the current version. - for (spirv::Extension ext : spirv::getImpliedExtensions(allowedVersion)) - allowedExtensions.insert(ext); - - // Build a set for available capabilities in the target environment. - llvm::SmallSet allowedCapabilities; - for (spirv::Capability cap : targetEnv.getCapabilities()) { - allowedCapabilities.insert(cap); - - // Add capabilities implied by the current capability.
- for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap)) - allowedCapabilities.insert(c); - } + spirv::TargetEnv targetEnv(targetAttr); + spirv::Version allowedVersion = targetAttr.getVersion(); spirv::Version deducedVersion = spirv::Version::V_1_0; llvm::SetVector deducedExtensions; @@ -148,15 +121,14 @@ // Op extension requirements if (auto extensions = dyn_cast(op)) - if (failed(checkAndUpdateExtensionRequirements(op, allowedExtensions, - extensions.getExtensions(), - deducedExtensions))) + if (failed(checkAndUpdateExtensionRequirements( + op, targetEnv, extensions.getExtensions(), deducedExtensions))) return WalkResult::interrupt(); // Op capability requirements if (auto capabilities = dyn_cast(op)) if (failed(checkAndUpdateCapabilityRequirements( - op, allowedCapabilities, capabilities.getCapabilities(), + op, targetEnv, capabilities.getCapabilities(), deducedCapabilities))) return WalkResult::interrupt(); @@ -176,13 +148,13 @@ typeExtensions.clear(); valueType.cast().getExtensions(typeExtensions); if (failed(checkAndUpdateExtensionRequirements( - op, allowedExtensions, typeExtensions, deducedExtensions))) + op, targetEnv, typeExtensions, deducedExtensions))) return WalkResult::interrupt(); typeCapabilities.clear(); valueType.cast().getCapabilities(typeCapabilities); if (failed(checkAndUpdateCapabilityRequirements( - op, allowedCapabilities, typeCapabilities, deducedCapabilities))) + op, targetEnv, typeCapabilities, deducedCapabilities))) return WalkResult::interrupt(); } diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-interface.mlir @@ -1,5 +1,12 @@ // RUN: mlir-opt -spirv-lower-abi-attrs -verify-diagnostics %s -o - | FileCheck %s +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, +
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + // CHECK-LABEL: spv.module spv.module Logical GLSL450 { // CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr, StorageBuffer> @@ -24,4 +31,6 @@ } // CHECK: spv.EntryPoint "GLCompute" [[FN]] // CHECK: spv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1 -} +} // end spv.module + +} // end module diff --git a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir --- a/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir @@ -1,5 +1,12 @@ // RUN: mlir-opt -spirv-lower-abi-attrs -verify-diagnostics %s -o - | FileCheck %s +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + // CHECK-LABEL: spv.module spv.module Logical GLSL450 { // CHECK-DAG: spv.globalVariable [[WORKGROUPSIZE:@.*]] built_in("WorkgroupSize") @@ -119,4 +126,6 @@ } // CHECK: spv.EntryPoint "GLCompute" [[FN]], [[WORKGROUPID]], [[LOCALINVOCATIONID]], [[NUMWORKGROUPS]], [[WORKGROUPSIZE]] // CHECK-NEXT: spv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1 -} +} // end spv.module + +} // end module diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp --- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp @@ -130,7 +130,12 @@ auto targetEnv = fn.getOperation() ->getAttr(spirv::getTargetEnvAttrName()) - .cast(); - auto target = spirv::SPIRVConversionTarget::get(targetEnv, context); + .dyn_cast_or_null<spirv::TargetEnvAttr>(); + if (!targetEnv) { + fn.emitError("missing 'spv.target_env' attribute"); + return signalPassFailure(); + } + + auto target = spirv::SPIRVConversionTarget::get(targetEnv); OwningRewritePatternList patterns; patterns.insert