diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -634,6 +634,91 @@
   }
 };
 
+/// Converts `spv.ExecutionMode` into a global struct constant that holds
+/// execution mode information.
+class ExecutionModePattern
+    : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::ExecutionModeOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The global is created in the enclosing `ModuleOp`; bail out rather
+    // than dereferencing a null module if there is none.
+    ModuleOp module = op.getParentOfType<ModuleOp>();
+    if (!module)
+      return failure();
+
+    // Create the global struct's name that would be associated with this
+    // entry point's execution mode. We set it to be:
+    //   __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
+    // where {mode} is the numeric value of the execution mode. An entry point
+    // may have several `spv.ExecutionMode` ops, so the mode must be part of
+    // the name to keep the generated symbols unique.
+    IntegerAttr executionModeAttr = op.execution_modeAttr();
+    std::string moduleName;
+    if (module.getName().hasValue())
+      moduleName = "_" + module.getName().getValue().str() + "_";
+    else
+      moduleName = "_";
+    std::string executionModeInfoName = llvm::formatv(
+        "__spv_{0}{1}_execution_mode_info_{2}", moduleName, op.fn().str(),
+        executionModeAttr.getValue().getZExtValue());
+
+    MLIRContext *context = rewriter.getContext();
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+
+    // Create a struct type, corresponding to the C struct below.
+    // struct {
+    //   int32_t executionMode;
+    //   int32_t values[];          // optional values
+    // };
+    auto llvmI32Type = LLVM::LLVMType::getInt32Ty(context);
+    SmallVector<LLVM::LLVMType, 2> fields;
+    fields.push_back(llvmI32Type);
+    ArrayAttr values = op.values();
+    if (!values.empty()) {
+      auto arrayType = LLVM::LLVMType::getArrayTy(llvmI32Type, values.size());
+      fields.push_back(arrayType);
+    }
+    auto structType = LLVM::LLVMType::getStructTy(context, fields);
+
+    // Create `llvm.mlir.global` with initializer region containing one block.
+    // The global reuses the location of the op it replaces so diagnostics on
+    // it point back at the original `spv.ExecutionMode`.
+    auto global = rewriter.create<LLVM::GlobalOp>(
+        op.getLoc(), structType, /*isConstant=*/true, LLVM::Linkage::External,
+        executionModeInfoName, Attribute());
+    Location loc = global.getLoc();
+    Region &region = global.getInitializerRegion();
+    Block *block = rewriter.createBlock(&region);
+
+    // Initialize the struct and set the execution mode value at index 0.
+    rewriter.setInsertionPoint(block, block->begin());
+    Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
+    Value executionMode =
+        rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
+    structValue = rewriter.create<LLVM::InsertValueOp>(
+        loc, structType, structValue, executionMode,
+        rewriter.getI32ArrayAttr({0}));
+
+    // Insert extra operands if they exist into the array at index 1 of the
+    // execution mode info struct.
+    for (auto en : llvm::enumerate(values)) {
+      Value entry =
+          rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, en.value());
+      structValue = rewriter.create<LLVM::InsertValueOp>(
+          loc, structType, structValue, entry,
+          rewriter.getI32ArrayAttr({1, static_cast<int32_t>(en.index())}));
+    }
+    rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 /// Converts `spv.globalVariable` to `llvm.mlir.global`. Note that SPIR-V global
 /// returns a pointer, whereas in LLVM dialect the global holds an actual value.
 /// This difference is handled by `spv._address_of` and `llvm.mlir.addressof`ops
@@ -1385,12 +1470,8 @@
       FunctionCallPattern, LoopPattern, SelectionPattern,
       ErasePattern<spirv::MergeOp>,
 
-      // Entry points and execution mode
-      // Module generated from SPIR-V could have other "internal" functions, so
-      // having entry point and execution mode metadat can be useful. For now,
-      // simply remove them.
-      // TODO: Support EntryPoint/ExecutionMode properly.
-      ErasePattern<spirv::EntryPointOp>, ErasePattern<spirv::ExecutionModeOp>,
+      // Entry points and execution mode are handled separately.
+      ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
 
       // GLSL extended instruction set ops
       DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
diff --git a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
--- a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
@@ -63,16 +63,61 @@
 //===----------------------------------------------------------------------===//
 
 // CHECK: module {
+// CHECK-NEXT: llvm.mlir.global external constant @{{.*}}() : !llvm.struct<(i32)> {
+// CHECK-NEXT:   %[[UNDEF:.*]] = llvm.mlir.undef : !llvm.struct<(i32)>
+// CHECK-NEXT:   %[[VAL:.*]] = llvm.mlir.constant(31 : i32) : !llvm.i32
+// CHECK-NEXT:   %[[RET:.*]] = llvm.insertvalue %[[VAL]], %[[UNDEF]][0 : i32] : !llvm.struct<(i32)>
+// CHECK-NEXT:   llvm.return %[[RET]] : !llvm.struct<(i32)>
+// CHECK-NEXT: }
 // CHECK-NEXT: llvm.func @empty
 // CHECK-NEXT: llvm.return
 // CHECK-NEXT: }
 // CHECK-NEXT: }
-spv.module Logical GLSL450 {
+spv.module Logical OpenCL {
   spv.func @empty() "None" {
     spv.Return
   }
-  spv.EntryPoint "GLCompute" @empty
-  spv.ExecutionMode @empty "LocalSize", 1, 1, 1
+  spv.EntryPoint "Kernel" @empty
+  spv.ExecutionMode @empty "ContractionOff"
+}
+
+// CHECK: module {
+// CHECK-NEXT: llvm.mlir.global external constant @{{.*}}() : !llvm.struct<(i32, array<3 x i32>)> {
+// CHECK-NEXT:   %[[UNDEF:.*]] = llvm.mlir.undef : !llvm.struct<(i32, array<3 x i32>)>
+// CHECK-NEXT:   %[[EM:.*]] = llvm.mlir.constant(18 : i32) : !llvm.i32
+// CHECK-NEXT:   %[[T0:.*]] = llvm.insertvalue %[[EM]], %[[UNDEF]][0 : i32] : !llvm.struct<(i32, array<3 x i32>)>
+// CHECK-NEXT:   %[[C0:.*]] = llvm.mlir.constant(32 : i32) : !llvm.i32
+// CHECK-NEXT:   %[[T1:.*]] = llvm.insertvalue %[[C0]], %[[T0]][1 : i32, 0 : i32] : !llvm.struct<(i32, array<3 x i32>)>
+// CHECK-NEXT:   %[[C1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+// CHECK-NEXT:   %[[T2:.*]] = llvm.insertvalue %[[C1]], %[[T1]][1 : i32, 1 : i32] : !llvm.struct<(i32, array<3 x i32>)>
+// CHECK-NEXT:   %[[C2:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+// CHECK-NEXT:   %[[RET:.*]] = llvm.insertvalue %[[C2]], %[[T2]][1 : i32, 2 : i32] : !llvm.struct<(i32, array<3 x i32>)>
+// CHECK-NEXT:   llvm.return %[[RET]] : !llvm.struct<(i32, array<3 x i32>)>
+// CHECK-NEXT: }
+// CHECK-NEXT:   llvm.func @bar
+// CHECK-NEXT:     llvm.return
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+spv.module Logical OpenCL {
+  spv.func @bar() "None" {
+    spv.Return
+  }
+  spv.EntryPoint "Kernel" @bar
+  spv.ExecutionMode @bar "LocalSizeHint", 32, 1, 1
 }
+
+// Several execution modes on one entry point must produce distinct globals.
+// CHECK: module {
+// CHECK-DAG: llvm.mlir.global external constant @{{.*}}_execution_mode_info_31() : !llvm.struct<(i32)>
+// CHECK-DAG: llvm.mlir.global external constant @{{.*}}_execution_mode_info_18() : !llvm.struct<(i32, array<3 x i32>)>
+// CHECK: llvm.func @baz
+spv.module Logical OpenCL {
+  spv.func @baz() "None" {
+    spv.Return
+  }
+  spv.EntryPoint "Kernel" @baz
+  spv.ExecutionMode @baz "ContractionOff"
+  spv.ExecutionMode @baz "LocalSizeHint", 32, 1, 1
+}
 
 //===----------------------------------------------------------------------===//