diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md --- a/mlir/docs/Dialects/SPIR-V.md +++ b/mlir/docs/Dialects/SPIR-V.md @@ -725,6 +725,28 @@ } ``` +## Target environment + +SPIR-V aims to support multiple execution environments as specified by client +APIs. These execution environments affect the availability of certain SPIR-V +features. For example, a [Vulkan 1.1][VulkanSpirv] implementation must support +the 1.0, 1.1, 1.2, and 1.3 versions of SPIR-V and the 1.0 version of the SPIR-V +extended instructions for GLSL. Further Vulkan extensions may enable more SPIR-V +instructions. + +SPIR-V compilation should also take the execution environment into +consideration, so that we generate SPIR-V modules valid for the target environment. +This is conveyed by the `spv.target_env` attribute. It is a triple of: + +* `version`: a 32-bit integer indicating the target SPIR-V version. +* `extensions`: a string array attribute containing allowed extensions. +* `capabilities`: a 32-bit integer array attribute containing allowed + capabilities. + +The dialect conversion framework will utilize the information in `spv.target_env` +to properly filter out patterns and ops not available in the target execution +environment. + ## Shader interface (ABI) SPIR-V itself is just expressing computation happening on GPU device. SPIR-V @@ -852,12 +874,18 @@ additional rules are imposed by [Vulkan execution environment][VulkanSpirv]. The lowering described below implements both these requirements.) +### `SPIRVConversionTarget` + +The `mlir::spirv::SPIRVConversionTarget` class derives from the +`mlir::ConversionTarget` class and serves as a utility to define a conversion +target satisfying a given [`spv.target_env`](#target-environment). It registers +proper hooks to check the dynamic legality of SPIR-V ops. Users can further +register other legality constraints into the returned `SPIRVConversionTarget`.
-### SPIRVTypeConverter +### `SPIRVTypeConverter` -The `mlir::spirv::SPIRVTypeConverter` derives from -`mlir::TypeConverter` and provides type conversion for standard -types to SPIR-V types: +The `mlir::SPIRVTypeConverter` derives from `mlir::TypeConverter` and provides +type conversion for standard types to SPIR-V types: * [Standard Integer][MlirIntegerType] -> Standard Integer * [Standard Float][MlirFloatType] -> Standard Float @@ -874,11 +902,11 @@ (TODO: Allow for configuring the integer width to use for `index` types in the SPIR-V dialect) -### SPIRVOpLowering +### `SPIRVOpLowering` -`mlir::spirv::SPIRVOpLowering` is a base class that can be used to define the -patterns used for implementing the lowering. For now this only provides derived -classes access to an instance of `mlir::spirv::SPIRVTypeLowering` class. +`mlir::SPIRVOpLowering` is a base class that can be used to define the patterns +used for implementing the lowering. For now this only provides derived classes +access to an instance of `mlir::SPIRVTypeLowering` class. ### Utility functions for lowering diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h @@ -13,8 +13,10 @@ #ifndef MLIR_DIALECT_SPIRV_SPIRVLOWERING_H #define MLIR_DIALECT_SPIRV_SPIRVLOWERING_H +#include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/TargetAndABI.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/SmallSet.h" namespace mlir { @@ -48,7 +50,30 @@ }; namespace spirv { -enum class BuiltIn : uint32_t; +class SPIRVConversionTarget : public ConversionTarget { +public: + /// Creates a SPIR-V conversion target for the given target environment. 
+ static std::unique_ptr get(TargetEnvAttr targetEnv, + MLIRContext *context); + +private: + SPIRVConversionTarget(TargetEnvAttr targetEnv, MLIRContext *context); + + // Be explicit that instance of this class cannot be copied or moved: there + // are lambdas capturing fields of the instance. + SPIRVConversionTarget(const SPIRVConversionTarget &) = delete; + SPIRVConversionTarget(SPIRVConversionTarget &&) = delete; + SPIRVConversionTarget &operator=(const SPIRVConversionTarget &) = delete; + SPIRVConversionTarget &operator=(SPIRVConversionTarget &&) = delete; + + /// Returns true if the given `op` is legal to use under the current target + /// environment. + bool isLegalOp(Operation *op); + + Version givenVersion; /// SPIR-V version to target + llvm::SmallSet givenExtensions; /// Allowed extensions + llvm::SmallSet givenCapabilities; /// Allowed capabilities +}; /// Returns a value that represents a builtin variable value within the SPIR-V /// module. diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h --- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h +++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h @@ -27,21 +27,33 @@ namespace spirv { enum class StorageClass : uint32_t; -/// Attribute name for specifying argument ABI information. +/// Returns the attribute name for specifying argument ABI information. StringRef getInterfaceVarABIAttrName(); -/// Get the InterfaceVarABIAttr given its fields. +/// Gets the InterfaceVarABIAttr given its fields. InterfaceVarABIAttr getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, StorageClass storageClass, MLIRContext *context); -/// Attribute name for specifying entry point information. +/// Returns the attribute name for specifying entry point information. StringRef getEntryPointABIAttrName(); -/// Get the EntryPointABIAttr given its fields. +/// Gets the EntryPointABIAttr given its fields. 
EntryPointABIAttr getEntryPointABIAttr(ArrayRef localSize, MLIRContext *context); + +/// Returns the attribute name for specifying SPIR-V target environment. +StringRef getTargetEnvAttrName(); + +/// Returns the default target environment: SPIR-V 1.0 with Shader capability +/// and no extra extensions. +TargetEnvAttr getDefaultTargetEnv(MLIRContext *context); + +/// Queries the target environment from the given `op` or returns the default +/// target environment (SPIR-V 1.0 with Shader capability and no extra +/// extensions) if not provided. +TargetEnvAttr lookupTargetEnvOrDefault(Operation *op); } // namespace spirv } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td --- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td +++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.td @@ -1,4 +1,4 @@ -//===- SPIRVBase.td - MLIR SPIR-V Op Definitions Base file -*- tablegen -*-===// +//===- TargetAndABI.td - SPIR-V Target and ABI definitions -*- tablegen -*-===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,41 +6,54 @@ // //===----------------------------------------------------------------------===// // -// This is the base file for supporting lowering to SPIR-V dialect. This -// file defines SPIR-V attributes used for specifying the shader -// interface or ABI. This is because SPIR-V module is expected to work in -// an execution environment as specified by a client API. A SPIR-V module -// needs to "link" correctly with the execution environment regarding the -// resources that are used in the SPIR-V module and get populated with -// data via the client API. The shader interface (or ABI) is passed into -// SPIR-V lowering path via attributes defined in this file. 
A -// compilation flow targeting SPIR-V is expected to attach such +// This is the base file for supporting lowering to SPIR-V dialect. This file +// defines SPIR-V attributes used for specifying the shader interface or ABI. +// This is because SPIR-V module is expected to work in an execution environment +// as specified by a client API. A SPIR-V module needs to "link" correctly with +// the execution environment regarding the resources that are used in the SPIR-V +// module and get populated with data via the client API. The shader interface +// (or ABI) is passed into SPIR-V lowering path via attributes defined in this +// file. A compilation flow targeting SPIR-V is expected to attach such // attributes to resources and other suitable places. // //===----------------------------------------------------------------------===// -#ifndef SPIRV_LOWERING -#define SPIRV_LOWERING +#ifndef SPIRV_TARGET_AND_ABI +#define SPIRV_TARGET_AND_ABI include "mlir/Dialect/SPIRV/SPIRVBase.td" // For arguments that eventually map to spv.globalVariable for the // shader interface, this attribute specifies the information regarding -// the global variable : +// the global variable: // 1) Descriptor Set. // 2) Binding number. // 3) Storage class. -def SPV_InterfaceVarABIAttr: - StructAttr<"InterfaceVarABIAttr", SPV_Dialect, - [StructFieldAttr<"descriptor_set", I32Attr>, - StructFieldAttr<"binding", I32Attr>, - StructFieldAttr<"storage_class", SPV_StorageClassAttr>]>; +def SPV_InterfaceVarABIAttr : StructAttr<"InterfaceVarABIAttr", SPV_Dialect, [ + StructFieldAttr<"descriptor_set", I32Attr>, + StructFieldAttr<"binding", I32Attr>, + StructFieldAttr<"storage_class", SPV_StorageClassAttr> +]>; // For entry functions, this attribute specifies information related to entry // points in the generated SPIR-V module: // 1) WorkGroup Size. 
-def SPV_EntryPointABIAttr: - StructAttr<"EntryPointABIAttr", SPV_Dialect, - [StructFieldAttr<"local_size", I32ElementsAttr>]>; +def SPV_EntryPointABIAttr : StructAttr<"EntryPointABIAttr", SPV_Dialect, [ + StructFieldAttr<"local_size", I32ElementsAttr> +]>; -#endif // SPIRV_LOWERING +def SPV_ExtensionArrayAttr : TypedArrayAttrBase< + SPV_ExtensionAttr, "SPIR-V extension array attribute">; + +def SPV_CapabilityArrayAttr : TypedArrayAttrBase< + SPV_CapabilityAttr, "SPIR-V capability array attribute">; + +// For the generated SPIR-V module, this attribute specifies the target version, +// allowed extensions and capabilities. +def SPV_TargetEnvAttr : StructAttr<"TargetEnvAttr", SPV_Dialect, [ + StructFieldAttr<"version", SPV_VersionAttr>, + StructFieldAttr<"extensions", SPV_ExtensionArrayAttr>, + StructFieldAttr<"capabilities", SPV_CapabilityArrayAttr> +]>; + +#endif // SPIRV_TARGET_AND_ABI diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp @@ -55,8 +55,8 @@ } // namespace void GPUToSPIRVPass::runOnModule() { - auto context = &getContext(); - auto module = getModule(); + MLIRContext *context = &getContext(); + ModuleOp module = getModule(); SmallVector kernelModules; OpBuilder builder(context); @@ -76,12 +76,12 @@ populateGPUToSPIRVPatterns(context, typeConverter, patterns, workGroupSize); populateStandardToSPIRVPatterns(context, typeConverter, patterns); - ConversionTarget target(*context); - target.addLegalDialect(); - target.addDynamicallyLegalOp( + std::unique_ptr target = spirv::SPIRVConversionTarget::get( + spirv::lookupTargetEnvOrDefault(module), context); + target->addDynamicallyLegalOp( [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); }); - if (failed(applyFullConversion(kernelModules, target, patterns, + if 
(failed(applyFullConversion(kernelModules, *target, patterns, &typeConverter))) { return signalPassFailure(); } diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp @@ -64,19 +64,20 @@ } void ConvertStandardToSPIRVPass::runOnModule() { - OwningRewritePatternList patterns; - auto context = &getContext(); - auto module = getModule(); + MLIRContext *context = &getContext(); + ModuleOp module = getModule(); SPIRVTypeConverter typeConverter; + OwningRewritePatternList patterns; populateStandardToSPIRVPatterns(context, typeConverter, patterns); patterns.insert(context, typeConverter); - ConversionTarget target(*(module.getContext())); - target.addLegalDialect(); - target.addDynamicallyLegalOp( + + std::unique_ptr target = spirv::SPIRVConversionTarget::get( + spirv::lookupTargetEnvOrDefault(module), context); + target->addDynamicallyLegalOp( [&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); }); - if (failed(applyPartialConversion(module, target, patterns))) { + if (failed(applyPartialConversion(module, *target, patterns))) { return signalPassFailure(); } } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -648,15 +648,26 @@ StringRef symbol = attribute.first.strref(); Attribute attr = attribute.second; - if (symbol != spirv::getEntryPointABIAttrName()) + // TODO(antiagainst): figure out a way to generate the description from the + // StructAttr definition. 
+ if (symbol == spirv::getEntryPointABIAttrName()) { + if (!attr.isa()) + return op->emitError("'") + << symbol + << "' attribute must be a dictionary attribute containing one " + "32-bit integer elements attribute: 'local_size'"; + } else if (symbol == spirv::getTargetEnvAttrName()) { + if (!attr.isa()) + return op->emitError("'") + << symbol + << "' must be a dictionary attribute containing one 32-bit " + "integer attribute 'version', one string array attribute " + "'extensions', and one 32-bit integer array attribute " + "'capabilities'"; + } else { return op->emitError("found unsupported '") << symbol << "' attribute on operation"; - - if (!spirv::EntryPointABIAttr::classof(attr)) - return op->emitError("'") - << symbol - << "' attribute must be a dictionary attribute containing one " - "integer elements attribute: 'local_size'"; + } return success(); } @@ -673,11 +684,11 @@ << symbol << "' attribute on region " << (forArg ? "argument" : "result"); - if (!spirv::InterfaceVarABIAttr::classof(attr)) + if (!attr.isa()) return emitError(loc, "'") << symbol << "' attribute must be a dictionary attribute containing three " - "integer attributes: 'descriptor_set', 'binding', and " + "32-bit integer attributes: 'descriptor_set', 'binding', and " "'storage_class'"; return success(); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -15,6 +15,11 @@ #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "llvm/ADT/Sequence.h" +#include "llvm/Support/Debug.h" + +#include + +#define DEBUG_TYPE "mlir-spirv-lowering" using namespace mlir; @@ -214,3 +219,93 @@ funcOp.setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo); return success(); } + +//===----------------------------------------------------------------------===// +// SPIR-V ConversionTarget 
+//===----------------------------------------------------------------------===// + +std::unique_ptr +spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetEnv, + MLIRContext *context) { + std::unique_ptr target( + // std::make_unique does not work here because the constructor is private. + new SPIRVConversionTarget(targetEnv, context)); + SPIRVConversionTarget *targetPtr = target.get(); + target->addDynamicallyLegalDialect( + Optional( + // We need to capture the raw pointer here because it is stable: + // target will be destroyed once this function is returned. + [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); })); + return target; +} + +spirv::SPIRVConversionTarget::SPIRVConversionTarget( + spirv::TargetEnvAttr targetEnv, MLIRContext *context) + : ConversionTarget(*context), + givenVersion(static_cast(targetEnv.version().getInt())) { + for (Attribute extAttr : targetEnv.extensions()) + givenExtensions.insert( + *spirv::symbolizeExtension(extAttr.cast().getValue())); + + for (Attribute capAttr : targetEnv.capabilities()) + givenCapabilities.insert( + static_cast(capAttr.cast().getInt())); +} + +bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) { + // Make sure this op is available at the given version. Ops not implementing + // QueryMinVersionInterface/QueryMaxVersionInterface are available to all + // SPIR-V versions. + if (auto minVersion = dyn_cast(op)) + if (minVersion.getMinVersion() > givenVersion) { + LLVM_DEBUG(llvm::dbgs() + << op->getName() << " illegal: requiring min version " + << spirv::stringifyVersion(minVersion.getMinVersion()) + << "\n"); + return false; + } + if (auto maxVersion = dyn_cast(op)) + if (maxVersion.getMaxVersion() < givenVersion) { + LLVM_DEBUG(llvm::dbgs() + << op->getName() << " illegal: requiring max version " + << spirv::stringifyVersion(maxVersion.getMaxVersion()) + << "\n"); + return false; + } + + // Make sure this op's required extensions are allowed to use. 
For each op, + // we return a vector of vector for its extension requirements following + // ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) + // convention. Ops not implementing QueryExtensionInterface do not require + // extensions to be available. + if (auto extensions = dyn_cast(op)) { + auto exts = extensions.getExtensions(); + for (const auto &ors : exts) + if (llvm::all_of(ors, [this](spirv::Extension ext) { + return this->givenExtensions.count(ext) == 0; + })) { + LLVM_DEBUG(llvm::dbgs() << op->getName() + << " illegal: missing required extension\n"); + return false; + } + } + + // Make sure this op's required capabilities are allowed to use. For each op, + // we return a vector of vector for its capability requirements following + // ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) + // convention. Ops not implementing the capability query interface do not + // require capabilities to be available. + if (auto capabilities = dyn_cast(op)) { + auto caps = capabilities.getCapabilities(); + for (const auto &ors : caps) + if (llvm::all_of(ors, [this](spirv::Capability cap) { + return this->givenCapabilities.count(cap) == 0; + })) { + LLVM_DEBUG(llvm::dbgs() << op->getName() + << " illegal: missing required capability\n"); + return false; + } + } + + return true; +}; diff --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp --- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp @@ -1,4 +1,4 @@ -//===- SPIRVLowering.cpp - Standard to SPIR-V dialect conversion--===// +//===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information.
@@ -8,6 +8,7 @@ #include "mlir/Dialect/SPIRV/TargetAndABI.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/Operation.h" using namespace mlir; @@ -16,32 +17,48 @@ #include "mlir/Dialect/SPIRV/TargetAndABI.cpp.inc" } -StringRef mlir::spirv::getInterfaceVarABIAttrName() { +StringRef spirv::getInterfaceVarABIAttrName() { return "spv.interface_var_abi"; } -mlir::spirv::InterfaceVarABIAttr -mlir::spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, - spirv::StorageClass storageClass, - MLIRContext *context) { +spirv::InterfaceVarABIAttr +spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, + spirv::StorageClass storageClass, + MLIRContext *context) { Type i32Type = IntegerType::get(32, context); - return mlir::spirv::InterfaceVarABIAttr::get( + return spirv::InterfaceVarABIAttr::get( IntegerAttr::get(i32Type, descriptorSet), IntegerAttr::get(i32Type, binding), IntegerAttr::get(i32Type, static_cast(storageClass)), context); } -StringRef mlir::spirv::getEntryPointABIAttrName() { - return "spv.entry_point_abi"; -} +StringRef spirv::getEntryPointABIAttrName() { return "spv.entry_point_abi"; } -mlir::spirv::EntryPointABIAttr -mlir::spirv::getEntryPointABIAttr(ArrayRef localSize, - MLIRContext *context) { +spirv::EntryPointABIAttr +spirv::getEntryPointABIAttr(ArrayRef localSize, MLIRContext *context) { assert(localSize.size() == 3); - return mlir::spirv::EntryPointABIAttr::get( + return spirv::EntryPointABIAttr::get( DenseElementsAttr::get( VectorType::get(3, IntegerType::get(32, context)), localSize) .cast(), context); } + +StringRef spirv::getTargetEnvAttrName() { return "spv.target_env"; } + +spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) { + Builder builder(context); + return spirv::TargetEnvAttr::get( + builder.getI32IntegerAttr(static_cast(spirv::Version::V_1_0)), + builder.getI32ArrayAttr({}), + builder.getI32ArrayAttr( + {static_cast(spirv::Capability::Shader)}), 
+ context); +} + +spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) { + if (auto attr = op->getAttrOfType( + spirv::getTargetEnvAttrName())) + return attr; + return getDefaultTargetEnv(op->getContext()); +} diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -217,16 +217,16 @@ OwningRewritePatternList patterns; patterns.insert(context, typeConverter); - ConversionTarget target(*context); - target.addLegalDialect(); + std::unique_ptr target = spirv::SPIRVConversionTarget::get( + spirv::lookupTargetEnvOrDefault(module), context); auto entryPointAttrName = spirv::getEntryPointABIAttrName(); - target.addDynamicallyLegalOp([&](FuncOp op) { + target->addDynamicallyLegalOp([&](FuncOp op) { return op.getAttrOfType(entryPointAttrName) && op.getNumResults() == 0 && op.getNumArguments() == 0; }); - target.addLegalOp(); + target->addLegalOp(); if (failed( - applyPartialConversion(module, target, patterns, &typeConverter))) { + applyPartialConversion(module, *target, patterns, &typeConverter))) { return signalPassFailure(); } diff --git a/mlir/test/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/Dialect/SPIRV/TestAvailability.cpp --- a/mlir/test/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/Dialect/SPIRV/TestAvailability.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/SPIRV/SPIRVLowering.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Function.h" @@ -13,14 +14,18 @@ using namespace mlir; +//===----------------------------------------------------------------------===// +// Printing op availability pass +//===----------------------------------------------------------------------===// + namespace { /// A 
pass for testing SPIR-V op availability. -struct TestAvailability : public FunctionPass { +struct PrintOpAvailability : public FunctionPass { void runOnFunction() override; }; } // end anonymous namespace -void TestAvailability::runOnFunction() { +void PrintOpAvailability::runOnFunction() { auto f = getFunction(); llvm::outs() << f.getName() << "\n"; @@ -70,5 +75,105 @@ }); } -static PassRegistration pass("test-spirv-op-availability", - "Test SPIR-V op availability"); +static PassRegistration + printOpAvailabilityPass("test-spirv-op-availability", + "Test SPIR-V op availability"); + +//===----------------------------------------------------------------------===// +// Converting target environment pass +//===----------------------------------------------------------------------===// + +namespace { +/// A pass for testing conversion under a SPIR-V target environment. +struct ConvertToTargetEnv : public FunctionPass { + void runOnFunction() override; +}; + +struct ConvertToAtomCmpExchangeWeak : public RewritePattern { + ConvertToAtomCmpExchangeWeak(MLIRContext *context); + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; +}; + +struct ConvertToGroupNonUniformBallot : public RewritePattern { + ConvertToGroupNonUniformBallot(MLIRContext *context); + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; +}; + +struct ConvertToSubgroupBallot : public RewritePattern { + ConvertToSubgroupBallot(MLIRContext *context); + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; +}; +} // end anonymous namespace + +void ConvertToTargetEnv::runOnFunction() { + MLIRContext *context = &getContext(); + FuncOp fn = getFunction(); + + auto targetEnv = fn.getOperation() + ->getAttr(spirv::getTargetEnvAttrName()) + .cast(); + auto target = spirv::SPIRVConversionTarget::get(targetEnv, context); + + OwningRewritePatternList patterns; + patterns.insert(context); + + if
(failed(applyPartialConversion(fn, *target, patterns))) + return signalPassFailure(); +} + +ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context) + : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", + {"spv.AtomicCompareExchangeWeak"}, 1, context) {} + +PatternMatchResult +ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + Value ptr = op->getOperand(0); + Value value = op->getOperand(1); + Value comparator = op->getOperand(2); + + // Create a spv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits in + // memory semantics to additionally require AtomicStorage capability. + rewriter.replaceOpWithNewOp( + op, value->getType(), ptr, spirv::Scope::Workgroup, + spirv::MemorySemantics::AcquireRelease | + spirv::MemorySemantics::AtomicCounterMemory, + spirv::MemorySemantics::Acquire, value, comparator); + return matchSuccess(); +} + +ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot( + MLIRContext *context) + : RewritePattern("test.convert_to_group_non_uniform_ballot_op", + {"spv.GroupNonUniformBallot"}, 1, context) {} + +PatternMatchResult ConvertToGroupNonUniformBallot::matchAndRewrite( + Operation *op, PatternRewriter &rewriter) const { + Value predicate = op->getOperand(0); + + rewriter.replaceOpWithNewOp( + op, op->getResult(0)->getType(), spirv::Scope::Workgroup, predicate); + return matchSuccess(); +} + +ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context) + : RewritePattern("test.convert_to_subgroup_ballot_op", + {"spv.SubgroupBallotKHR"}, 1, context) {} + +PatternMatchResult +ConvertToSubgroupBallot::matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + Value predicate = op->getOperand(0); + + rewriter.replaceOpWithNewOp( + op, op->getResult(0)->getType(), predicate); + return matchSuccess(); +} + +static PassRegistration + convertToTargetEnvPass("test-spirv-target-env", + "Test SPIR-V target environment"); diff --git 
a/mlir/test/Dialect/SPIRV/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/target-and-abi.mlir --- a/mlir/test/Dialect/SPIRV/target-and-abi.mlir +++ b/mlir/test/Dialect/SPIRV/target-and-abi.mlir @@ -26,14 +26,14 @@ // spv.entry_point_abi //===----------------------------------------------------------------------===// -// expected-error @+1 {{'spv.entry_point_abi' attribute must be a dictionary attribute containing one integer elements attribute: 'local_size'}} +// expected-error @+1 {{'spv.entry_point_abi' attribute must be a dictionary attribute containing one 32-bit integer elements attribute: 'local_size'}} func @spv_entry_point() attributes { spv.entry_point_abi = 64 } { return } // ----- -// expected-error @+1 {{'spv.entry_point_abi' attribute must be a dictionary attribute containing one integer elements attribute: 'local_size'}} +// expected-error @+1 {{'spv.entry_point_abi' attribute must be a dictionary attribute containing one 32-bit integer elements attribute: 'local_size'}} func @spv_entry_point() attributes { spv.entry_point_abi = {local_size = 64} } { return } @@ -51,14 +51,14 @@ // spv.interface_var_abi //===----------------------------------------------------------------------===// -// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three integer attributes: 'descriptor_set', 'binding', and 'storage_class'}} +// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}} func @interface_var( %arg0 : f32 {spv.interface_var_abi = 64} ) { return } // ----- -// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three integer attributes: 'descriptor_set', 'binding', and 'storage_class'}} +// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 
'binding', and 'storage_class'}} func @interface_var( %arg0 : f32 {spv.interface_var_abi = {binding = 0: i32}} ) { return } @@ -74,7 +74,7 @@ // ----- -// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three integer attributes: 'descriptor_set', 'binding', and 'storage_class'}} +// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}} func @interface_var() -> (f32 {spv.interface_var_abi = 64}) { %0 = constant 10.0 : f32 @@ -83,7 +83,7 @@ // ----- -// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three integer attributes: 'descriptor_set', 'binding', and 'storage_class'}} +// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three 32-bit integer attributes: 'descriptor_set', 'binding', and 'storage_class'}} func @interface_var() -> (f32 {spv.interface_var_abi = {binding = 0: i32}}) { %0 = constant 10.0 : f32 @@ -99,3 +99,49 @@ %0 = constant 10.0 : f32 return %0: f32 } + +// ----- + +//===----------------------------------------------------------------------===// +// spv.target_env +//===----------------------------------------------------------------------===// + +// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', and one 32-bit integer array attribute 'capabilities'}} +func @target_env_wrong_type() attributes { + spv.target_env = 64 +} { return } + +// ----- + +// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', and one 32-bit integer array attribute 'capabilities'}} +func @target_env_missing_fields() attributes { + spv.target_env = {version = 0: i32} +} { return } + +// 
----- + +// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', and one 32-bit integer array attribute 'capabilities'}} +func @target_env_wrong_extension_type() attributes { + spv.target_env = {version = 0: i32, extensions = [32: i32], capabilities = [1: i32]} +} { return } + +// ----- + +// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', and one 32-bit integer array attribute 'capabilities'}} +func @target_env_wrong_extension() attributes { + spv.target_env = {version = 0: i32, extensions = ["SPV_Something"], capabilities = [1: i32]} +} { return } + +// ----- + +func @target_env() attributes { + // CHECK: spv.target_env = {capabilities = [1 : i32], extensions = ["SPV_KHR_storage_buffer_storage_class"], version = 0 : i32} + spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_storage_buffer_storage_class"], capabilities = [1: i32]} +} { return } + +// ----- + +// expected-error @+1 {{'spv.target_env' must be a dictionary attribute containing one 32-bit integer attribute 'version', one string array attribute 'extensions', and one 32-bit integer array attribute 'capabilities'}} +func @target_env_extra_fields() attributes { + spv.target_env = {version = 0: i32, extensions = ["SPV_KHR_storage_buffer_storage_class"], capabilities = [1: i32], extra = 32} +} { return } diff --git a/mlir/test/Dialect/SPIRV/target-env.mlir b/mlir/test/Dialect/SPIRV/target-env.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/target-env.mlir @@ -0,0 +1,120 @@ +// RUN: mlir-opt -disable-pass-threading -test-spirv-target-env %s | FileCheck %s + +// Note: The following tests check that a spv.target_env can properly control +// the conversion target and filter unavailable ops during the conversion. 
+// We don't care about the op argument consistency too much; so certain enum +// values for enum attributes may not make much sense for the test op. + +// spv.AtomicCompareExchangeWeak is available from SPIR-V 1.0 to 1.3 under +// Kernel capability. +// spv.AtomicCompareExchangeWeak has two memory semantics enum attributes, +// whose values, if containing the AtomicCounterMemory bit, additionally require +// the AtomicStorage capability. + +// spv.GroupNonUniformBallot is available starting from SPIR-V 1.3 under +// GroupNonUniformBallot capability. + +// spv.SubgroupBallotKHR is available in all SPIR-V versions under +// SubgroupBallotKHR capability and SPV_KHR_shader_ballot extension. + +// Enum case symbol (value) map: +// Version: 1.0 (0), 1.1 (1), 1.2 (2), 1.3 (3), 1.4 (4) +// Capability: Kernel (6), AtomicStorage (21), GroupNonUniformBallot (64), +// SubgroupBallotKHR (4423) + +//===----------------------------------------------------------------------===// +// MaxVersion +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @cmp_exchange_weak_suitable_version_capabilities +func @cmp_exchange_weak_suitable_version_capabilities(%ptr: !spv.ptr, %value: i32, %comparator: i32) -> i32 attributes { + spv.target_env = {version = 1: i32, extensions = [], capabilities = [6: i32, 21: i32]} +} { + // CHECK: spv.AtomicCompareExchangeWeak "Workgroup" "AcquireRelease|AtomicCounterMemory" "Acquire" + %0 = "test.convert_to_atomic_compare_exchange_weak_op"(%ptr, %value, %comparator): (!spv.ptr, i32, i32) -> (i32) + return %0: i32 +} + +// CHECK-LABEL: @cmp_exchange_weak_unsupported_version +func @cmp_exchange_weak_unsupported_version(%ptr: !spv.ptr, %value: i32, %comparator: i32) -> i32 attributes { + spv.target_env = {version = 4: i32, extensions = [], capabilities = [6: i32, 21: i32]} +} { + // CHECK: test.convert_to_atomic_compare_exchange_weak_op + %0 = "test.convert_to_atomic_compare_exchange_weak_op"(%ptr, %value, %comparator): 
(!spv.ptr, i32, i32) -> (i32) + return %0: i32 +} + +//===----------------------------------------------------------------------===// +// MinVersion +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @group_non_uniform_ballot_suitable_version +func @group_non_uniform_ballot_suitable_version(%predicate: i1) -> vector<4xi32> attributes { + spv.target_env = {version = 4: i32, extensions = [], capabilities = [64: i32]} +} { + // CHECK: spv.GroupNonUniformBallot "Workgroup" + %0 = "test.convert_to_group_non_uniform_ballot_op"(%predicate): (i1) -> (vector<4xi32>) + return %0: vector<4xi32> +} + +// CHECK-LABEL: @group_non_uniform_ballot_unsupported_version +func @group_non_uniform_ballot_unsupported_version(%predicate: i1) -> vector<4xi32> attributes { + spv.target_env = {version = 1: i32, extensions = [], capabilities = [64: i32]} +} { + // CHECK: test.convert_to_group_non_uniform_ballot_op + %0 = "test.convert_to_group_non_uniform_ballot_op"(%predicate): (i1) -> (vector<4xi32>) + return %0: vector<4xi32> +} + +//===----------------------------------------------------------------------===// +// Capability +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @cmp_exchange_weak_missing_capability_kernel +func @cmp_exchange_weak_missing_capability_kernel(%ptr: !spv.ptr, %value: i32, %comparator: i32) -> i32 attributes { + spv.target_env = {version = 3: i32, extensions = [], capabilities = [21: i32]} +} { + // CHECK: test.convert_to_atomic_compare_exchange_weak_op + %0 = "test.convert_to_atomic_compare_exchange_weak_op"(%ptr, %value, %comparator): (!spv.ptr, i32, i32) -> (i32) + return %0: i32 +} + +// CHECK-LABEL: @cmp_exchange_weak_missing_capability_atomic_storage +func @cmp_exchange_weak_missing_capability_atomic_storage(%ptr: !spv.ptr, %value: i32, %comparator: i32) -> i32 attributes { + spv.target_env = {version = 3: i32, extensions = [], capabilities = [6: i32]} +} { + 
// CHECK: test.convert_to_atomic_compare_exchange_weak_op + %0 = "test.convert_to_atomic_compare_exchange_weak_op"(%ptr, %value, %comparator): (!spv.ptr, i32, i32) -> (i32) + return %0: i32 +} + +// CHECK-LABEL: @subgroup_ballot_missing_capability +func @subgroup_ballot_missing_capability(%predicate: i1) -> vector<4xi32> attributes { + spv.target_env = {version = 4: i32, extensions = ["SPV_KHR_shader_ballot"], capabilities = []} +} { + // CHECK: test.convert_to_subgroup_ballot_op + %0 = "test.convert_to_subgroup_ballot_op"(%predicate): (i1) -> (vector<4xi32>) + return %0: vector<4xi32> +} + +//===----------------------------------------------------------------------===// +// Extension +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @subgroup_ballot_suitable_extension +func @subgroup_ballot_suitable_extension(%predicate: i1) -> vector<4xi32> attributes { + spv.target_env = {version = 4: i32, extensions = ["SPV_KHR_shader_ballot"], capabilities = [4423: i32]} +} { + // CHECK: spv.SubgroupBallotKHR + %0 = "test.convert_to_subgroup_ballot_op"(%predicate): (i1) -> (vector<4xi32>) + return %0: vector<4xi32> +} + +// CHECK-LABEL: @subgroup_ballot_missing_extension +func @subgroup_ballot_missing_extension(%predicate: i1) -> vector<4xi32> attributes { + spv.target_env = {version = 4: i32, extensions = [], capabilities = [4423: i32]} +} { + // CHECK: test.convert_to_subgroup_ballot_op + %0 = "test.convert_to_subgroup_ballot_op"(%predicate): (i1) -> (vector<4xi32>) + return %0: vector<4xi32> +}