diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md --- a/mlir/docs/Dialects/SPIR-V.md +++ b/mlir/docs/Dialects/SPIR-V.md @@ -932,9 +932,15 @@ * Binding number for the corresponding resource variable. * Storage class for the corresponding resource variable. -The SPIR-V dialect provides a [`LowerABIAttributesPass`][MlirSpirvPasses] for -consuming these attributes and create SPIR-V module complying with the -interface. +The SPIR-V dialect provides a [`LowerABIAttributesPass`][MlirSpirvPasses] that +uses this information to lower the entry point function and its ABI in a way +consistent with the Vulkan validation rules. Specifically, the pass: + +* Creates `spv.GlobalVariable`s for the arguments, and replaces all uses of + each argument with its variable. The SSA value used for replacement is + obtained using the `spv.mlir.addressof` operation. +* Adds the `spv.EntryPoint` and `spv.ExecutionMode` operations into the + `spv.module` for the entry function. ## Serialization and deserialization @@ -1052,29 +1058,8 @@ the pointer type are derived from the memref's memory space with `SPIRVTypeConverter::getStorageClassForMemorySpace()`. -### `SPIRVOpLowering` - -`mlir::SPIRVOpLowering` is a base class that can be used to define the patterns -used for implementing the lowering. For now this only provides derived classes -access to an instance of `mlir::SPIRVTypeLowering` class. - ### Utility functions for lowering -#### Setting shader interface - -The method `mlir::spirv::setABIAttrs` allows setting the [shader interface -attributes](#shader-interface-abi) for a function that is to be an entry -point function within the `spv.module` on lowering. A later pass -`mlir::spirv::LowerABIAttributesPass` uses this information to lower the entry -point function and its ABI consistent with the Vulkan validation -rules. Specifically, - -* Creates `spv.GlobalVariable`s for the arguments, and replaces all uses of - the argument with this variable.
The SSA value used for replacement is - obtained using the `spv.mlir.addressof` operation. -* Adds the `spv.EntryPoint` and `spv.ExecutionMode` operations into the - `spv.module` for the entry function. - #### Setting layout for shader interface variables SPIR-V validation rules for shaders require composite objects to be explicitly diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -21,6 +21,10 @@ namespace mlir { +//===----------------------------------------------------------------------===// +// Type Converter +//===----------------------------------------------------------------------===// + /// Type conversion from builtin types to SPIR-V types for shader interface. /// /// Non-32-bit scalar types require special hardware support that may not exist @@ -63,24 +67,22 @@ spirv::TargetEnv targetEnv; }; -/// Appends to a pattern list additional patterns for translating the builtin -/// `func` op to the SPIR-V dialect. These patterns do not handle shader -/// interface/ABI; they convert function parameters to be of SPIR-V allowed -/// types. -void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, - RewritePatternSet &patterns); - -namespace spirv { -class AccessChainOp; -class FuncOp; +//===----------------------------------------------------------------------===// +// Conversion Target +//===----------------------------------------------------------------------===// +// The default SPIR-V conversion target. +// +// It takes a SPIR-V target environment and controls operation legality based on +// their availability in the target environment. class SPIRVConversionTarget : public ConversionTarget { public: /// Creates a SPIR-V conversion target for the given target environment.
- static std::unique_ptr get(TargetEnvAttr targetAttr); + static std::unique_ptr + get(spirv::TargetEnvAttr targetAttr); private: - explicit SPIRVConversionTarget(TargetEnvAttr targetAttr); + explicit SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr); // Be explicit that instance of this class cannot be copied or moved: there // are lambdas capturing fields of the instance. @@ -93,16 +95,37 @@ /// environment. bool isLegalOp(Operation *op); - TargetEnv targetEnv; + spirv::TargetEnv targetEnv; }; +//===----------------------------------------------------------------------===// +// Patterns and Utility Functions +//===----------------------------------------------------------------------===// + +/// Appends to a pattern list additional patterns for translating the builtin +/// `func` op to the SPIR-V dialect. These patterns do not handle shader +/// interface/ABI; they convert function parameters to be of SPIR-V allowed +/// types. +void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, + RewritePatternSet &patterns); + +namespace spirv { +class AccessChainOp; + /// Returns the value for the given `builtin` variable. This function gets or /// inserts the global variable associated for the builtin within the nearest -/// enclosing op that has a symbol table. Returns null Value if such an -/// enclosing op cannot be found. +/// symbol table enclosing `op`. Returns null Value on error. Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, OpBuilder &builder); +/// Gets the value at the given `offset` of the push constant storage with a +/// total of `elementCount` 32-bit integers. A global variable will be created +/// in the nearest symbol table enclosing `op` for the push constant storage if +/// it does not already exist. Load ops will be created via the given `builder` to load +/// values from the push constant. Returns null Value on error.
+Value getPushConstantValue(Operation *op, unsigned elementCount, + unsigned offset, OpBuilder &builder); + /// Generates IR to perform index linearization with the given `indices` and /// their corresponding `strides`, adding an initial `offset`. Value linearizeIndex(ValueRange indices, ArrayRef strides, @@ -118,11 +141,6 @@ ValueRange indices, Location loc, OpBuilder &builder); -/// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its -/// arguments. -LogicalResult setABIAttrs(spirv::FuncOp funcOp, - EntryPointABIAttr entryPointInfo, - ArrayRef argABIInfo); } // namespace spirv } // namespace mlir diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -208,8 +208,13 @@ return nullptr; rewriter.eraseOp(funcOp); - if (failed(spirv::setABIAttrs(newFuncOp, entryPointInfo, argABIInfo))) - return nullptr; + // Set the attributes for argument and the function. 
+ StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName(); + for (auto argIndex : llvm::seq(0, argABIInfo.size())) { + newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]); + } + newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo); + return newFuncOp; } diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -54,7 +54,7 @@ auto targetAttr = spirv::lookupTargetEnvOrDefault(module); std::unique_ptr target = - spirv::SPIRVConversionTarget::get(targetAttr); + SPIRVConversionTarget::get(targetAttr); SPIRVTypeConverter typeConverter(targetAttr); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp @@ -27,7 +27,7 @@ auto targetAttr = spirv::lookupTargetEnvOrDefault(module); std::unique_ptr target = - spirv::SPIRVConversionTarget::get(targetAttr); + SPIRVConversionTarget::get(targetAttr); SPIRVTypeConverter typeConverter(targetAttr); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp @@ -33,7 +33,7 @@ auto targetAttr = spirv::lookupTargetEnvOrDefault(module); std::unique_ptr target = - spirv::SPIRVConversionTarget::get(targetAttr); + SPIRVConversionTarget::get(targetAttr); SPIRVTypeConverter typeConverter(targetAttr); ScfToSPIRVContext scfContext; diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp --- 
a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp @@ -32,7 +32,7 @@ auto targetAttr = spirv::lookupTargetEnvOrDefault(module); std::unique_ptr target = - spirv::SPIRVConversionTarget::get(targetAttr); + SPIRVConversionTarget::get(targetAttr); SPIRVTypeConverter typeConverter(targetAttr); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp @@ -34,7 +34,7 @@ auto targetAttr = spirv::lookupTargetEnvOrDefault(module); std::unique_ptr target = - spirv::SPIRVConversionTarget::get(targetAttr); + SPIRVConversionTarget::get(targetAttr); SPIRVTypeConverter typeConverter(targetAttr); RewritePatternSet patterns(context); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -171,12 +171,14 @@ } return bitWidth / 8; } + if (auto vecType = t.dyn_cast()) { auto elementSize = getTypeNumBytes(vecType.getElementType()); if (!elementSize) return llvm::None; return vecType.getNumElements() * *elementSize; } + if (auto memRefType = t.dyn_cast()) { // TODO: Layout should also be controlled by the ABI attributes. For now // using the layout from MemRef. @@ -207,7 +209,9 @@ memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]); } return (offset + memrefSize) * elementSize.getValue(); - } else if (auto tensorType = t.dyn_cast()) { + } + + if (auto tensorType = t.dyn_cast()) { if (!tensorType.hasStaticShape()) { return llvm::None; } @@ -221,6 +225,7 @@ } return size; } + // TODO: Add size computation for other types. 
return llvm::None; } @@ -602,6 +607,80 @@ return builder.create(op->getLoc(), ptr); } +//===----------------------------------------------------------------------===// +// Push constant storage +//===----------------------------------------------------------------------===// + +/// Returns the pointer type for the push constant storage containing +/// `elementCount` 32-bit integer values. +static spirv::PointerType getPushConstantStorageType(unsigned elementCount, + Builder &builder) { + auto arrayType = spirv::ArrayType::get( + SPIRVTypeConverter::getIndexType(builder.getContext()), elementCount, + /*stride=*/4); + auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0); + return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant); +} + +/// Returns the push constant variable containing `elementCount` 32-bit integer +/// values in `body`. Returns null op if no such variable exists; a push constant +/// variable with a different element count is not returned. +static spirv::GlobalVariableOp getPushConstantVariable(Block &body, + unsigned elementCount) { + for (auto varOp : body.getOps()) { + auto ptrType = varOp.type().cast(); + // Note that Vulkan requires "There must be no more than one push constant + // block statically used per shader entry point." So we should always reuse + // the existing one. + if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) { + auto numElements = ptrType.getPointeeType() + .cast() + .getElementType(0) + .cast() + .getNumElements(); + if (numElements == elementCount) + return varOp; + } + } + return nullptr; +} + +/// Gets or inserts a global variable for push constant storage containing +/// `elementCount` 32-bit integer values in `block`.
+static spirv::GlobalVariableOp +getOrInsertPushConstantVariable(Location loc, Block &block, + unsigned elementCount, OpBuilder &b) { + if (auto varOp = getPushConstantVariable(block, elementCount)) + return varOp; + + auto builder = OpBuilder::atBlockBegin(&block, b.getListener()); + auto type = getPushConstantStorageType(elementCount, builder); + const char *name = "__push_constant_var__"; + return builder.create(loc, type, name, + /*initializer=*/nullptr); +} + +Value spirv::getPushConstantValue(Operation *op, unsigned elementCount, + unsigned offset, OpBuilder &builder) { + Location loc = op->getLoc(); + Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); + if (!parent) { + op->emitError("expected operation to be within a module-like op"); + return nullptr; + } + + spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable( + loc, parent->getRegion(0).front(), elementCount, builder); + + auto i32Type = SPIRVTypeConverter::getIndexType(builder.getContext()); + Value zeroOp = spirv::ConstantOp::getZero(i32Type, loc, builder); + Value offsetOp = builder.create( + loc, i32Type, builder.getI32IntegerAttr(offset)); + auto addrOp = builder.create(loc, varOp); + auto acOp = builder.create( + loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp})); + return builder.create(loc, acOp); +} + //===----------------------------------------------------------------------===// // Index calculation //===----------------------------------------------------------------------===// @@ -661,45 +740,27 @@ return builder.create(loc, basePtr, linearizedIndices); } -//===----------------------------------------------------------------------===// -// Set ABI attributes for lowering entry functions. -//===----------------------------------------------------------------------===// - -LogicalResult -mlir::spirv::setABIAttrs(spirv::FuncOp funcOp, - spirv::EntryPointABIAttr entryPointInfo, - ArrayRef argABIInfo) { - // Set the attributes for argument and the function. 
- StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName(); - for (auto argIndex : llvm::seq(0, argABIInfo.size())) { - funcOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]); - } - funcOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo); - return success(); -} - //===----------------------------------------------------------------------===// // SPIR-V ConversionTarget //===----------------------------------------------------------------------===// -std::unique_ptr -spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) { +std::unique_ptr +SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) { std::unique_ptr target( // std::make_unique does not work here because the constructor is private. new SPIRVConversionTarget(targetAttr)); SPIRVConversionTarget *targetPtr = target.get(); - target->addDynamicallyLegalDialect( + target->addDynamicallyLegalDialect( // We need to capture the raw pointer here because it is stable: // target will be destroyed once this function is returned. [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); }); return target; } -spirv::SPIRVConversionTarget::SPIRVConversionTarget( - spirv::TargetEnvAttr targetAttr) +SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr) : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {} -bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) { +bool SPIRVConversionTarget::isLegalOp(Operation *op) { // Make sure this op is available at the given version. Ops not implementing // QueryMinVersionInterface/QueryMaxVersionInterface are available to all // SPIR-V versions. 
diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp --- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp @@ -137,7 +137,7 @@ return signalPassFailure(); } - auto target = spirv::SPIRVConversionTarget::get(targetEnv); + auto target = SPIRVConversionTarget::get(targetEnv); RewritePatternSet patterns(context); patterns.add