diff --git a/mlir/include/mlir/Dialect/SPIRV/Passes.h b/mlir/include/mlir/Dialect/SPIRV/Passes.h --- a/mlir/include/mlir/Dialect/SPIRV/Passes.h +++ b/mlir/include/mlir/Dialect/SPIRV/Passes.h @@ -26,12 +26,24 @@ std::unique_ptr> createDecorateSPIRVCompositeTypeLayoutPass(); -/// Creates a module pass that lowers the ABI attributes specified during SPIR-V -/// Lowering. Specifically, -/// 1) Creates the global variables for arguments of entry point function using -/// the specification in the ABI attributes for each argument. -/// 2) Inserts the EntryPointOp and the ExecutionModeOp for entry point -/// functions using the specification in the EntryPointAttr. +/// Creates an operation pass that deduces and attaches the minimal version/ +/// capabilities/extensions requirements for spv.module ops (as `vce_triple`). +/// For each spv.module op, this pass requires a `spv.target_env` attribute on +/// it or an enclosing module-like op to drive the deduction. The reason is +/// that an op can be enabled by multiple extensions/capabilities. So we need +/// to know which one to pick. `spv.target_env` gives the hard limit on +/// what the target environment can support; this pass deduces what is +/// actually needed for a specific spv.module op. +std::unique_ptr> +createUpdateVersionCapabilityExtensionPass(); + +/// Creates an operation pass that lowers the ABI attributes specified during +/// SPIR-V Lowering. Specifically, +/// 1. Creates the global variables for arguments of entry point function using +/// the specification in the `spv.interface_var_abi` attribute for each +/// argument. +/// 2. Inserts the EntryPointOp and the ExecutionModeOp for entry point +/// functions using the specification in the `spv.entry_point_abi` attribute.
std::unique_ptr> createLowerABIAttributesPass(); } // namespace spirv diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h --- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h +++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h @@ -175,6 +175,10 @@ TargetEnvAttr getDefaultTargetEnv(MLIRContext *context); /// Queries the target environment recursively from enclosing symbol table ops +/// containing the given `op`. Returns a null attribute if none is found. +TargetEnvAttr lookupTargetEnv(Operation *op); + +/// Queries the target environment recursively from enclosing symbol table ops /// containing the given `op` or returns the default target environment as /// returned by getDefaultTargetEnv() if not provided. TargetEnvAttr lookupTargetEnvOrDefault(Operation *op); diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -122,6 +122,7 @@ // SPIR-V spirv::createDecorateSPIRVCompositeTypeLayoutPass(); spirv::createLowerABIAttributesPass(); + spirv::createUpdateVersionCapabilityExtensionPass(); createConvertGPUToSPIRVPass(); createConvertStandardToSPIRVPass(); createLegalizeStdOpsForSPIRVLoweringPass(); diff --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp --- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp @@ -294,19 +294,25 @@ spirv::getDefaultResourceLimits(context)); } -spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) { - Operation *symTable = op; - while (symTable) { - symTable = SymbolTable::getNearestSymbolTable(symTable); - if (!symTable) +spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) { + while (op) { + op = SymbolTable::getNearestSymbolTable(op); + if (!op) break; - if (auto attr = symTable->getAttrOfType( + if (auto attr = op->getAttrOfType( spirv::getTargetEnvAttrName())) return attr; - symTable = symTable->getParentOp(); + op
= op->getParentOp(); } + return {}; +} + +spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) { + if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) + return attr; + return getDefaultTargetEnv(op->getContext()); } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRSPIRVTransforms DecorateSPIRVCompositeTypeLayoutPass.cpp LowerABIAttributesPass.cpp + UpdateVCEPass.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -0,0 +1,164 @@ +//===- UpdateVCEPass.cpp - Deduce SPIR-V VCE requirements -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to deduce minimal version/extension/capability +// requirements for a spirv::ModuleOp and attach them as its `vce_triple`. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/Passes.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/SPIRVTypes.h" +#include "mlir/Dialect/SPIRV/TargetAndABI.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Visitors.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallSet.h" + +using namespace mlir; + +namespace { +/// Pass to deduce minimal version/extension/capability requirements for a +/// spirv::ModuleOp.
+class UpdateVCEPass final + : public OperationPass { +private: + void runOnOperation() override; +}; +} // namespace + +void UpdateVCEPass::runOnOperation() { + spirv::ModuleOp module = getOperation(); + + spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnv(module); + if (!targetEnv) { + module.emitError("missing 'spv.target_env' attribute"); + return signalPassFailure(); + } + + spirv::Version allowedVersion = targetEnv.getVersion(); + + // Build a set for available extensions in the target environment. + llvm::SmallSet allowedExtensions; + for (spirv::Extension ext : targetEnv.getExtensions()) + allowedExtensions.insert(ext); + + // Add extensions implied by the allowed target version. + for (spirv::Extension ext : spirv::getImpliedExtensions(allowedVersion)) + allowedExtensions.insert(ext); + + // Build a set for available capabilities in the target environment. + llvm::SmallSet allowedCapabilities; + for (spirv::Capability cap : targetEnv.getCapabilities()) { + allowedCapabilities.insert(cap); + + // Add capabilities (recursively) implied by the current capability. + for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap)) + allowedCapabilities.insert(c); + } + + spirv::Version deducedVersion = spirv::Version::V_1_0; + llvm::SetVector deducedExtensions; + llvm::SetVector deducedCapabilities; + + // Walk each SPIR-V op to deduce the minimal version/extension/capability + // requirements. Stop at the first op the target environment cannot support. + WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult { + if (auto minVersion = dyn_cast(op)) { + deducedVersion = std::max(deducedVersion, minVersion.getMinVersion()); + if (deducedVersion > allowedVersion) { + return op->emitError("'") << op->getName() << "' requires min version " + << spirv::stringifyVersion(deducedVersion) + << " but target environment allows up to " + << spirv::stringifyVersion(allowedVersion); + } + } + + // Deduce this op's extension requirement.
For each op, the query interface + // returns a vector of vector for its extension requirements following + // ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) + // convention. Ops not implementing QueryExtensionInterface do not require + // extensions to be available. + if (auto extensions = dyn_cast(op)) { + for (const auto &ors : extensions.getExtensions()) { + bool satisfied = false; // True when at least one extension can be used + for (spirv::Extension ext : ors) { + if (allowedExtensions.count(ext)) { + deducedExtensions.insert(ext); + satisfied = true; + break; + } + } + + if (!satisfied) { + SmallVector extStrings; + for (spirv::Extension ext : ors) + extStrings.push_back(spirv::stringifyExtension(ext)); + + return op->emitError("'") + << op->getName() << "' requires at least one extension in [" + << llvm::join(extStrings, ", ") + << "] but none allowed in target environment"; + } + } + } + + // Deduce this op's capability requirement. For each op, the query interface + // returns a vector of vector for its capability requirements following + // ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) + // convention. Ops not implementing QueryCapabilityInterface do not require + // capabilities to be available.
+ if (auto capabilities = dyn_cast(op)) { + for (const auto &ors : capabilities.getCapabilities()) { + bool satisfied = false; // True when at least one capability can be used + for (spirv::Capability cap : ors) { + if (allowedCapabilities.count(cap)) { + deducedCapabilities.insert(cap); + satisfied = true; + break; + } + } + + if (!satisfied) { + SmallVector capStrings; + for (spirv::Capability cap : ors) + capStrings.push_back(spirv::stringifyCapability(cap)); + + return op->emitError("'") + << op->getName() << "' requires at least one capability in [" + << llvm::join(capStrings, ", ") + << "] but none allowed in target environment"; + } + } + } + + return WalkResult::advance(); + }); + + if (walkResult.wasInterrupted()) + return signalPassFailure(); + + // TODO(antiagainst): verify that the deduced version is consistent with + // SPIR-V ops' maximal version requirements. + + auto triple = spirv::VerCapExtAttr::get( + deducedVersion, deducedCapabilities.getArrayRef(), + deducedExtensions.getArrayRef(), &getContext()); + module.setAttr("vce_triple", triple); +} + +std::unique_ptr> +mlir::spirv::createUpdateVersionCapabilityExtensionPass() { + return std::make_unique(); +} + +static PassRegistration + pass("spirv-update-vce", + "Deduce and attach minimal (version, capabilities, extensions) " + "requirements to spv.module ops"); diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir @@ -0,0 +1,146 @@ +// RUN: mlir-opt -spirv-update-vce %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// Version +//===----------------------------------------------------------------------===// + +// Test deducing minimal version. +// spv.IAdd is available from v1.0.
+ +// CHECK: vce_triple = #spv.vce +spv.module "Logical" "GLSL450" { + spv.func @iadd(%val : i32) -> i32 "None" { + %0 = spv.IAdd %val, %val: i32 + spv.ReturnValue %0: i32 + } +} attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} + +// Test deducing minimal version. +// spv.GroupNonUniformBallot is available since v1.3. + +// CHECK: vce_triple = #spv.vce +spv.module "Logical" "GLSL450" { + spv.func @group_non_uniform_ballot(%predicate : i1) -> vector<4xi32> "None" { + %0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32> + spv.ReturnValue %0: vector<4xi32> + } +} attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} + +//===----------------------------------------------------------------------===// +// Capability +//===----------------------------------------------------------------------===// + +// Test deducing minimal capabilities. + +// CHECK: vce_triple = #spv.vce +spv.module "Logical" "GLSL450" { + spv.func @iadd(%val : i32) -> i32 "None" { + %0 = spv.IAdd %val, %val: i32 + spv.ReturnValue %0: i32 + } +} attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} + +// Test deducing implied capability. +// AtomicStorage implies Shader. + +// CHECK: vce_triple = #spv.vce +spv.module "Logical" "GLSL450" { + spv.func @iadd(%val : i32) -> i32 "None" { + %0 = spv.IAdd %val, %val: i32 + spv.ReturnValue %0: i32 + } +} attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} + +// Test selecting the capability available in the target environment.
+// spv.GroupNonUniformIAdd op itself can be enabled via any of +// * GroupNonUniformArithmetic +// * GroupNonUniformClustered +// * GroupNonUniformPartitionedNV +// Its 'Reduce' group operation can be enabled via any of +// * Kernel +// * GroupNonUniformArithmetic +// * GroupNonUniformBallot + +// CHECK: vce_triple = #spv.vce +spv.module "Logical" "GLSL450" { + spv.func @group_non_uniform_iadd(%val : i32) -> i32 "None" { + %0 = spv.GroupNonUniformIAdd "Subgroup" "Reduce" %val : i32 + spv.ReturnValue %0: i32 + } +} attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} + +// CHECK: vce_triple = #spv.vce +spv.module "Logical" "GLSL450" { + spv.func @group_non_uniform_iadd(%val : i32) -> i32 "None" { + %0 = spv.GroupNonUniformIAdd "Subgroup" "Reduce" %val : i32 + spv.ReturnValue %0: i32 + } +} attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} + +//===----------------------------------------------------------------------===// +// Extension +//===----------------------------------------------------------------------===// + +// Test deducing minimal extensions. +// spv.SubgroupBallotKHR requires the SPV_KHR_shader_ballot extension. + +// CHECK: vce_triple = #spv.vce +spv.module "Logical" "GLSL450" { + spv.func @subgroup_ballot(%predicate : i1) -> vector<4xi32> "None" { + %0 = spv.SubgroupBallotKHR %predicate: vector<4xi32> + spv.ReturnValue %0: vector<4xi32> + } +} attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} + +// Test deducing implied extension. +// Vulkan memory model requires SPV_KHR_vulkan_memory_model, which is enabled +// implicitly by v1.5.
+ +// CHECK: vce_triple = #spv.vce +spv.module "Logical" "Vulkan" { + spv.func @iadd(%val : i32) -> i32 "None" { + %0 = spv.IAdd %val, %val: i32 + spv.ReturnValue %0: i32 + } +} attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +}