diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -1,4 +1,4 @@ -//===- SPIRVLowering.cpp - Standard to SPIR-V dialect conversion--===// +//===- SPIRVLowering.cpp - SPIR-V lowering utilities ----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -15,6 +15,7 @@ #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Support/Debug.h" #include @@ -441,6 +442,66 @@ } } +/// Checks that the extension requirements in `candidates` can be satisfied +/// by `allowedExtensions`; otherwise emits an LLVM_DEBUG note and fails. +/// +/// `candidates` is a vector of OR-groups that are AND-ed together, i.e. it +/// follows the ((Extension::A OR Extension::B) AND (Extension::C OR +/// Extension::D)) convention. +static LogicalResult checkExtensionRequirements( + Operation *op, const llvm::SmallSet &allowedExtensions, + const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { + for (const auto &ors : candidates) { + auto chosen = llvm::find_if(ors, [&](spirv::Extension ext) { + return allowedExtensions.count(ext); + }); + + if (chosen == ors.end()) { + SmallVector extStrings; + for (spirv::Extension ext : ors) + extStrings.push_back(spirv::stringifyExtension(ext)); + + LLVM_DEBUG(llvm::dbgs() << op->getName() + << " illegal: requires at least one extension in [" + << llvm::join(extStrings, ", ") + << "] but none allowed in target environment\n"); + return failure(); + } + } + return success(); +} + +/// Checks that the capability requirements in `candidates` can be satisfied +/// by `allowedCapabilities`; otherwise emits an LLVM_DEBUG note and fails.
+/// +/// `candidates` is a vector of OR-groups that are AND-ed together, i.e. it +/// follows the ((Capability::A OR Capability::B) AND (Capability::C OR +/// Capability::D)) convention. +static LogicalResult checkCapabilityRequirements( + Operation *op, + const llvm::SmallSet &allowedCapabilities, + const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { + for (const auto &ors : candidates) { + auto chosen = llvm::find_if(ors, [&](spirv::Capability cap) { + return allowedCapabilities.count(cap); + }); + + if (chosen == ors.end()) { + SmallVector capStrings; + for (spirv::Capability cap : ors) + capStrings.push_back(spirv::stringifyCapability(cap)); + + LLVM_DEBUG(llvm::dbgs() + << op->getName() + << " illegal: requires at least one capability in [" + << llvm::join(capStrings, ", ") + << "] but none allowed in target environment\n"); + return failure(); + } + } + return success(); +} + bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) { // Make sure this op is available at the given version. Ops not implementing // QueryMinVersionInterface/QueryMaxVersionInterface are available to all @@ -462,38 +523,53 @@ return false; } - // Make sure this op's required extensions are allowed to use. For each op, - // we return a vector of vector for its extension requirements following - // ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) - // convention. Ops not implementing QueryExtensionInterface do not require - // extensions to be available. - if (auto extensions = dyn_cast(op)) { - auto exts = extensions.getExtensions(); - for (const auto &ors : exts) - if (llvm::all_of(ors, [this](spirv::Extension ext) { - return this->givenExtensions.count(ext) == 0; - })) { - LLVM_DEBUG(llvm::dbgs() << op->getName() - << " illegal: missing required extension\n"); - return false; - } - } + // Make sure the extensions this op requires are allowed. Ops not + // implementing QueryExtensionInterface do not require extensions to be + // available.
+ if (auto extensions = dyn_cast(op)) + if (failed(checkExtensionRequirements(op, this->givenExtensions, + extensions.getExtensions()))) + return false; - // Make sure this op's required extensions are allowed to use. For each op, - // we return a vector of vector for its capability requirements following - // ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D)) - // convention. Ops not implementing QueryExtensionInterface do not require - // extensions to be available. - if (auto capabilities = dyn_cast(op)) { - auto caps = capabilities.getCapabilities(); - for (const auto &ors : caps) - if (llvm::all_of(ors, [this](spirv::Capability cap) { - return this->givenCapabilities.count(cap) == 0; - })) { - LLVM_DEBUG(llvm::dbgs() << op->getName() - << " illegal: missing required capability\n"); - return false; - } + // Make sure the capabilities this op requires are allowed. Ops not + // implementing QueryCapabilityInterface do not require capabilities to be + // available. + if (auto capabilities = dyn_cast(op)) + if (failed(checkCapabilityRequirements(op, this->givenCapabilities, + capabilities.getCapabilities()))) + return false; + + SmallVector valueTypes; + valueTypes.append(op->operand_type_begin(), op->operand_type_end()); + valueTypes.append(op->result_type_begin(), op->result_type_end()); + + // Special treatment for global variables, whose type requirements are + // conveyed by type attributes. + if (auto globalVar = dyn_cast(op)) + valueTypes.push_back(globalVar.type()); + + // Ops whose operand/result types are not (yet) SPIR-V types are illegal; + // this also keeps the SPIRVType casts below from asserting. + if (llvm::any_of(valueTypes, + [](Type type) { return !type.isa<spirv::SPIRVType>(); })) + return false; + + // Make sure the op's operands/results use types that are allowed by the + // target environment.
+ SmallVector, 4> typeExtensions; + SmallVector, 8> typeCapabilities; + for (Type valueType : valueTypes) { + typeExtensions.clear(); + valueType.cast().getExtensions(typeExtensions); + if (failed(checkExtensionRequirements(op, this->givenExtensions, + typeExtensions))) + return false; + + typeCapabilities.clear(); + valueType.cast().getCapabilities(typeCapabilities); + if (failed(checkCapabilityRequirements(op, this->givenCapabilities, + typeCapabilities))) + return false; } return true; diff --git a/mlir/test/Conversion/GPUToSPIRV/if.mlir b/mlir/test/Conversion/GPUToSPIRV/if.mlir --- a/mlir/test/Conversion/GPUToSPIRV/if.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/if.mlir @@ -1,6 +1,12 @@ // RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s -module attributes {gpu.container_module} { +module attributes { + gpu.container_module, + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { func @main(%arg0 : memref<10xf32>, %arg1 : i1) { %c0 = constant 1 : index "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0, %arg0, %arg1) { kernel = "kernel_simple_selection", kernel_module = @kernels} : (index, index, index, index, index, index, memref<10xf32>, i1) -> () diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir --- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir @@ -1,6 +1,12 @@ // RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s -module attributes {gpu.container_module} { +module attributes { + gpu.container_module, + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { func @load_store(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>) { %c0 = constant 0 : index %c12 = constant 12 :
index diff --git a/mlir/test/Conversion/GPUToSPIRV/loop.mlir b/mlir/test/Conversion/GPUToSPIRV/loop.mlir --- a/mlir/test/Conversion/GPUToSPIRV/loop.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/loop.mlir @@ -1,6 +1,12 @@ // RUN: mlir-opt -convert-gpu-to-spirv %s -o - | FileCheck %s -module attributes {gpu.container_module} { +module attributes { + gpu.container_module, + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { func @loop(%arg0 : memref<10xf32>, %arg1 : memref<10xf32>) { %c0 = constant 1 : index "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0, %arg0, %arg1) { kernel = "loop_kernel", kernel_module = @kernels} : (index, index, index, index, index, index, memref<10xf32>, memref<10xf32>) -> () diff --git a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir @@ -1,5 +1,12 @@ // RUN: mlir-opt -convert-std-to-spirv %s -o - | FileCheck %s +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + //===----------------------------------------------------------------------===// // std binary arithmetic ops //===----------------------------------------------------------------------===// @@ -366,3 +373,5 @@ store %0, %arg1[] : memref return } + +} // end module diff --git a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir @@ -4,6 +4,13 @@ // the desired output. Adding all of patterns within a single pass does // not seem to work.
+module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + //===----------------------------------------------------------------------===// // std.subview //===----------------------------------------------------------------------===// @@ -51,3 +58,5 @@ store %arg5, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> return } + +} // end module