diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp @@ -258,6 +258,22 @@ namespace { +class AddressOfPattern : public SPIRVToLLVMConversion { +public: + using SPIRVToLLVMConversion::SPIRVToLLVMConversion; + + LogicalResult + matchAndRewrite(spirv::AddressOfOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto dstType = typeConverter.convertType(op.pointer().getType()); + if (!dstType) + return failure(); + rewriter.replaceOpWithNewOp( + op, dstType.cast(), op.variable()); + return success(); + } +}; + class BitFieldInsertPattern : public SPIRVToLLVMConversion { public: @@ -487,6 +503,55 @@ } }; +/// Converts `spv.globalVariable` to `llvm.mlir.global`. Note that SPIR-V global +/// returns a pointer, whereas in LLVM dialect the global holds an actual value. +/// This difference is handled by `spv._address_of` and `llvm.mlir.addressof` +/// ops that both return a pointer. +class GlobalVariablePattern + : public SPIRVToLLVMConversion { +public: + using SPIRVToLLVMConversion::SPIRVToLLVMConversion; + + LogicalResult + matchAndRewrite(spirv::GlobalVariableOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + // Currently, there is no support for initialization with a constant value + // in SPIR-V dialect. Specialization constants are not considered either. + if (op.initializer()) + return failure(); + + auto srcType = op.type().cast(); + auto dstType = typeConverter.convertType(srcType.getPointeeType()); + if (!dstType) + return failure(); + + // Limit conversion to the current invocation only for now. 
+ auto storageClass = srcType.getStorageClass(); + if (storageClass != spirv::StorageClass::Input && + storageClass != spirv::StorageClass::Private && + storageClass != spirv::StorageClass::Output) { + return failure(); + } + + // LLVM dialect spec: "If the global value is a constant, storing into it is + // not allowed.". This corresponds to SPIR-V 'Input' storage class that is + // read-only. + bool isConstant = storageClass == spirv::StorageClass::Input; + // SPIR-V spec: "By default, functions and global variables are private to a + // module and cannot be accessed by other modules. However, a module may be + // written to export or import functions and global (module scope) + // variables.". Therefore, map 'Private' storage class to private linkage, + // 'Input' and 'Output' to external linkage. + auto linkage = storageClass == spirv::StorageClass::Private + ? LLVM::Linkage::Private + : LLVM::Linkage::External; + rewriter.replaceOpWithNewOp( + op, dstType.cast(), isConstant, linkage, op.sym_name(), + Attribute()); + return success(); + } +}; + /// Converts SPIR-V cast ops that do not have straightforward LLVM /// equivalent in LLVM dialect. 
template @@ -1017,8 +1082,8 @@ NotPattern, // Memory ops - LoadStorePattern, LoadStorePattern, - VariablePattern, + AddressOfPattern, GlobalVariablePattern, LoadStorePattern, + LoadStorePattern, VariablePattern, // Miscellaneous ops DirectConversionPattern, diff --git a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir @@ -1,5 +1,24 @@ // RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s +//===----------------------------------------------------------------------===// +// spv.globalVariable and spv._address_of +//===----------------------------------------------------------------------===// + +spv.module Logical GLSL450 { + // CHECK: llvm.mlir.global external constant @var() : !llvm.float + spv.globalVariable @var : !spv.ptr +} + +spv.module Logical GLSL450 { + // CHECK: llvm.mlir.global private @struct() : !llvm<"<{ float, [10 x float] }>"> + spv.globalVariable @struct : !spv.ptr>, Private> + spv.func @func() -> () "None" { + // CHECK: %{{.*}} = llvm.mlir.addressof @struct : !llvm<"<{ float, [10 x float] }>*"> + %0 = spv._address_of @struct : !spv.ptr>, Private> + spv.Return + } +} + //===----------------------------------------------------------------------===// // spv.Load //===----------------------------------------------------------------------===//