diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -29,6 +29,8 @@
 
 using namespace mlir;
 
+static constexpr const char kSPIRVModule[] = "__spv__";
+
 //===----------------------------------------------------------------------===//
 // Utility functions
 //===----------------------------------------------------------------------===//
@@ -237,6 +239,85 @@
   return success();
 }
 
+/// Creates a name for the given entry point op or execution mode op and a given
+/// module. The name is of the following form:
+///   __spv__{module_name}_{fn_name}_entry_point_info
+/// where the `{module_name}_` part is omitted if the module has no name.
+/// This is used to avoid symbolic conflicts, as well as to encode entry-point
+/// function name.
+template <typename SPIRVOp>
+static std::string getEntryPointInfoName(SPIRVOp op, ModuleOp module) {
+  std::string moduleName;
+  if (module.getName().hasValue())
+    moduleName = kSPIRVModule + module.getName().getValue().str() + "_";
+  else
+    moduleName = kSPIRVModule;
+  return moduleName + op.fn().str() + "_entry_point_info";
+}
+
+/// Creates a global struct containing entry point information. The global is
+/// inserted at the start of `module` with the location of `op`; its
+/// initializer returns an undefined struct whose fields are filled in later by
+/// `setEntryPointInfo`.
+static LLVM::GlobalOp createEntryPointInfo(Operation *op, StringRef name,
+                                           ModuleOp module,
+                                           PatternRewriter &builder) {
+  MLIRContext *context = builder.getContext();
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(module.getBody());
+
+  // Create a struct type, corresponding to the C struct below.
+  // struct {
+  //   int32_t executionModel;
+  //   int32_t executionMode;
+  // };
+  // TODO: support interface variables and execution mode values.
+  auto llvmI32Type = LLVM::LLVMType::getInt32Ty(context);
+  SmallVector<LLVM::LLVMType, 2> fields = {llvmI32Type, llvmI32Type};
+  auto structType = LLVM::LLVMType::getStructTy(context, fields);
+
+  // Create `llvm.mlir.global` with initializer region containing one block.
+  // Reuse the location of the op that triggered the creation.
+  auto dstGlobal = builder.create<LLVM::GlobalOp>(
+      op->getLoc(), structType, /*isConstant=*/true, LLVM::Linkage::External,
+      name, Attribute());
+  Location loc = dstGlobal.getLoc();
+  Region &region = dstGlobal.getInitializerRegion();
+  Block *block = builder.createBlock(&region);
+
+  // Initialize the struct and add a terminator to the block.
+  builder.setInsertionPoint(block, block->begin());
+  Value structValue = builder.create<LLVM::UndefOp>(loc, structType);
+  builder.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
+  return dstGlobal;
+}
+
+/// Fills entry point info struct, extracting information from the given op.
+/// `spv.EntryPoint` stores its execution model into field 0 and
+/// `spv.ExecutionMode` stores its execution mode into field 1 of `pred`.
+/// Returns the updated struct value.
+static Value setEntryPointInfo(Operation *op, Location loc, Value pred,
+                               PatternRewriter &rewriter) {
+  // Select the value to store and the struct field that receives it.
+  IntegerAttr valueAttr;
+  int32_t fieldIndex;
+  if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
+    valueAttr = entryPointOp.execution_modelAttr();
+    fieldIndex = 0;
+  } else {
+    // If a function has several execution modes, each one overwrites field 1,
+    // so only the last converted one is kept.
+    // TODO: support multiple execution modes per entry point.
+    valueAttr = cast<spirv::ExecutionModeOp>(op).execution_modeAttr();
+    fieldIndex = 1;
+  }
+
+  auto llvmI32Type = LLVM::LLVMType::getInt32Ty(rewriter.getContext());
+  Value value = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, valueAttr);
+  return rewriter.create<LLVM::InsertValueOp>(
+      loc, pred.getType(), pred, value, rewriter.getI32ArrayAttr(fieldIndex));
+}
+
 //===----------------------------------------------------------------------===//
 // Type conversion
 //===----------------------------------------------------------------------===//
@@ -634,6 +715,49 @@
   }
 };
 
+/// Converts `spv.EntryPoint` and `spv.ExecutionMode` into a global struct
+/// constant that holds all entry point information. All ops that refer to the
+/// same function share one global, which is created on first use; each op
+/// inserts its field into the global's initializer and is then erased.
+template <typename SPIRVOp>
+class EntryPointPattern : public SPIRVToLLVMConversion<SPIRVOp> {
+public:
+  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(SPIRVOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // First, check if the global struct has been created.
+    ModuleOp module = op.template getParentOfType<ModuleOp>();
+    auto entryPointInfoName = getEntryPointInfoName(op, module);
+    auto entryPointInfo =
+        module.lookupSymbol<LLVM::GlobalOp>(entryPointInfoName);
+
+    // If entry point info struct has not been created, create it.
+    if (!entryPointInfo)
+      entryPointInfo =
+          createEntryPointInfo(op, entryPointInfoName, module, rewriter);
+
+    // The initializer block ends with `llvm.return %struct`. Chain the new
+    // field insertion onto the struct value that is currently returned.
+    Block *block = &entryPointInfo.getInitializerRegion().front();
+    Operation *terminator = block->getTerminator();
+    Value previousStruct = terminator->getOperand(0);
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(terminator);
+    Value newStruct = setEntryPointInfo(op, terminator->getLoc(),
+                                        previousStruct, rewriter);
+
+    // Return the updated struct. Go through the rewriter so that the in-place
+    // modification of the terminator is tracked.
+    rewriter.updateRootInPlace(
+        terminator, [&] { terminator->setOperand(0, newStruct); });
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 /// Converts `spv.globalVariable` to `llvm.mlir.global`. Note that SPIR-V global
 /// returns a pointer, whereas in LLVM dialect the global holds an actual value.
 /// This difference is handled by `spv._address_of` and `llvm.mlir.addressof`ops
@@ -1385,12 +1509,9 @@
       FunctionCallPattern, LoopPattern, SelectionPattern,
       ErasePattern<spirv::MergeOp>,
 
-      // Entry points and execution mode
-      // Module generated from SPIR-V could have other "internal" functions, so
-      // having entry point and execution mode metadat can be useful. For now,
-      // simply remove them.
-      // TODO: Support EntryPoint/ExecutionMode properly.
-      ErasePattern<spirv::EntryPointOp>, ErasePattern<spirv::ExecutionModeOp>,
+      // Entry points and execution modes are stored in a global struct.
+      EntryPointPattern<spirv::EntryPointOp>,
+      EntryPointPattern<spirv::ExecutionModeOp>,
 
       // GLSL extended instruction set ops
       DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
diff --git a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
--- a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
@@ -63,16 +63,24 @@
 //===----------------------------------------------------------------------===//
 
 // CHECK:      module {
+// CHECK-NEXT:   llvm.mlir.global external constant @{{.*}}() : !llvm.struct<(i32, i32)> {
+// CHECK-NEXT:     %[[UNDEF:.*]] = llvm.mlir.undef : !llvm.struct<(i32, i32)>
+// CHECK-NEXT:     %[[VAL1:.*]] = llvm.mlir.constant(6 : i32) : !llvm.i32
+// CHECK-NEXT:     %[[T0:.*]] = llvm.insertvalue %[[VAL1]], %[[UNDEF]][0 : i32] : !llvm.struct<(i32, i32)>
+// CHECK-NEXT:     %[[VAL2:.*]] = llvm.mlir.constant(31 : i32) : !llvm.i32
+// CHECK-NEXT:     %[[T1:.*]] = llvm.insertvalue %[[VAL2]], %[[T0]][1 : i32] : !llvm.struct<(i32, i32)>
+// CHECK-NEXT:     llvm.return %[[T1]] : !llvm.struct<(i32, i32)>
+// CHECK-NEXT:   }
 // CHECK-NEXT:   llvm.func @empty
 // CHECK-NEXT:     llvm.return
 // CHECK-NEXT:   }
 // CHECK-NEXT: }
-spv.module Logical GLSL450 {
+spv.module Logical OpenCL {
   spv.func @empty() "None" {
     spv.Return
   }
-  spv.EntryPoint "GLCompute" @empty
-  spv.ExecutionMode @empty "LocalSize", 1, 1, 1
+  spv.EntryPoint "Kernel" @empty
+  spv.ExecutionMode @empty "ContractionOff"
 }
 
 //===----------------------------------------------------------------------===//