diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td @@ -60,12 +60,15 @@ : Availability { let interfaceName = name; - let queryFnRetType = scheme.returnType; + let queryFnRetType = "llvm::Optional<" # scheme.returnType # ">"; let queryFnName = "getMinVersion"; - let mergeAction = "$overall = static_cast<" # scheme.returnType # ">(" - "std::max($overall, $instance))"; - let initializer = "static_cast<" # scheme.returnType # ">(uint32_t(0))"; + let mergeAction = "{ " + "if ($overall.hasValue()) { " + "$overall = static_cast<" # scheme.returnType # ">(" + "std::max(*$overall, $instance)); " + "} else { $overall = $instance; }}"; + let initializer = "::llvm::None"; let instanceType = scheme.cppNamespace # "::" # scheme.className; let instance = scheme.cppNamespace # "::" # scheme.className # "::" # @@ -76,12 +79,15 @@ : Availability { let interfaceName = name; - let queryFnRetType = scheme.returnType; + let queryFnRetType = "llvm::Optional<" # scheme.returnType # ">"; let queryFnName = "getMaxVersion"; - let mergeAction = "$overall = static_cast<" # scheme.returnType # ">(" - "std::min($overall, $instance))"; - let initializer = "static_cast<" # scheme.returnType # ">(~uint32_t(0))"; + let mergeAction = "{ " + "if ($overall.hasValue()) { " + "$overall = static_cast<" # scheme.returnType # ">(" + "std::min(*$overall, $instance)); " + "} else { $overall = $instance; }}"; + let initializer = "::llvm::None"; let instanceType = scheme.cppNamespace # "::" # scheme.className; let instance = scheme.cppNamespace # "::" # scheme.className # "::" # diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -239,10 +239,12 @@ // TODO: the following interfaces definitions are duplicating with the above. // Remove them once we are able to support dialect-specific contents in ODS. def QueryMinVersionInterface : SPIRVOpInterface<"QueryMinVersionInterface"> { - let methods = [InterfaceMethod<"", "::mlir::spirv::Version", "getMinVersion">]; + let methods = [InterfaceMethod< + "", "::llvm::Optional<::mlir::spirv::Version>", "getMinVersion">]; } def QueryMaxVersionInterface : SPIRVOpInterface<"QueryMaxVersionInterface"> { - let methods = [InterfaceMethod<"", "::mlir::spirv::Version", "getMaxVersion">]; + let methods = [InterfaceMethod< + "", "::llvm::Optional<::mlir::spirv::Version>", "getMaxVersion">]; } def QueryExtensionInterface : SPIRVOpInterface<"QueryExtensionInterface"> { let methods = [InterfaceMethod< diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -843,22 +843,24 @@ // Make sure this op is available at the given version. Ops not implementing // QueryMinVersionInterface/QueryMaxVersionInterface are available to all // SPIR-V versions. - if (auto minVersion = dyn_cast(op)) - if (minVersion.getMinVersion() > this->targetEnv.getVersion()) { + if (auto minVersionIfx = dyn_cast(op)) { + Optional minVersion = minVersionIfx.getMinVersion(); + if (minVersion && *minVersion > this->targetEnv.getVersion()) { LLVM_DEBUG(llvm::dbgs() << op->getName() << " illegal: requiring min version " - << spirv::stringifyVersion(minVersion.getMinVersion()) - << "\n"); + << spirv::stringifyVersion(*minVersion) << "\n"); return false; } - if (auto maxVersion = dyn_cast(op)) - if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) { + } + if (auto maxVersionIfx = dyn_cast(op)) { + Optional maxVersion = maxVersionIfx.getMaxVersion(); + if (maxVersion && *maxVersion < this->targetEnv.getVersion()) { LLVM_DEBUG(llvm::dbgs() << op->getName() << " illegal: requiring max version " - << spirv::stringifyVersion(maxVersion.getMaxVersion()) - << "\n"); + << spirv::stringifyVersion(*maxVersion) << "\n"); return false; } + } // Make sure this op's required extensions are allowed to use. Ops not // implementing QueryExtensionInterface do not require extensions to be diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -109,13 +109,17 @@ // requirements. WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult { // Op min version requirements - if (auto minVersion = dyn_cast(op)) { - deducedVersion = std::max(deducedVersion, minVersion.getMinVersion()); - if (deducedVersion > allowedVersion) { - return op->emitError("'") << op->getName() << "' requires min version " - << spirv::stringifyVersion(deducedVersion) - << " but target environment allows up to " - << spirv::stringifyVersion(allowedVersion); + if (auto minVersionIfx = dyn_cast(op)) { + Optional minVersion = minVersionIfx.getMinVersion(); + if (minVersion) { + deducedVersion = std::max(deducedVersion, *minVersion); + if (deducedVersion > allowedVersion) { + return op->emitError("'") + << op->getName() << "' requires min version " + << spirv::stringifyVersion(deducedVersion) + << " but target environment allows up to " + << spirv::stringifyVersion(allowedVersion); + } } } diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp --- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp @@ -43,13 +43,23 @@ auto opName = op->getName(); auto &os = llvm::outs(); - if (auto minVersion = dyn_cast(op)) - os << opName << " min version: " - << spirv::stringifyVersion(minVersion.getMinVersion()) << "\n"; + if (auto minVersionIfx = dyn_cast(op)) { + Optional minVersion = minVersionIfx.getMinVersion(); + os << opName << " min version: "; + if (minVersion) + os << spirv::stringifyVersion(*minVersion) << "\n"; + else + os << "None\n"; + } - if (auto maxVersion = dyn_cast(op)) - os << opName << " max version: " - << spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n"; + if (auto maxVersionIfx = dyn_cast(op)) { + Optional maxVersion = maxVersionIfx.getMaxVersion(); + os << opName << " max version: "; + if (maxVersion) + os << spirv::stringifyVersion(*maxVersion) << "\n"; + else + os << "None\n"; + } if (auto extension = dyn_cast(op)) { os << opName << " extensions: ["; @@ -81,7 +91,7 @@ } namespace mlir { -void registerPrintOpAvailabilityPass() { +void registerPrintSpirvAvailabilityPass() { PassRegistration(); } } // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -31,7 +31,7 @@ namespace mlir { void registerConvertToTargetEnvPass(); void registerPassManagerTestPass(); -void registerPrintOpAvailabilityPass(); +void registerPrintSpirvAvailabilityPass(); void registerShapeFunctionTestPasses(); void registerSideEffectTestPasses(); void registerSliceAnalysisTestPass(); @@ -119,7 +119,7 @@ void registerTestPasses() { registerConvertToTargetEnvPass(); registerPassManagerTestPass(); - registerPrintOpAvailabilityPass(); + registerPrintSpirvAvailabilityPass(); registerShapeFunctionTestPasses(); registerSideEffectTestPasses(); registerSliceAnalysisTestPass();