diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SPIRV/LayoutUtils.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -31,6 +32,13 @@ // Utility functions //===----------------------------------------------------------------------===// +/// Returns true if the SPIR-V storage class is for the current invocation only. +static bool isCurrentInvocation(spirv::StorageClass storageClass) { + return storageClass == spirv::StorageClass::Input || + storageClass == spirv::StorageClass::Private || + storageClass == spirv::StorageClass::Output; +} + /// Returns true if the given type is a signed integer or vector type. static bool isSignedIntegerOrVector(Type type) { if (type.isSignedInteger()) @@ -179,6 +187,22 @@ return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter); } +/// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`) +/// offset to LLVM struct. Otherwise, the conversion is not supported. +static Optional +convertStructTypeWithOffset(spirv::StructType type, + LLVMTypeConverter &converter) { + if (type != VulkanLayoutUtils::decorateType(type)) + return llvm::None; + + auto elementsVector = llvm::to_vector<8>( + llvm::map_range(type.getElementTypes(), [&](Type elementType) { + return converter.convertType(elementType).cast(); + })); + return LLVM::LLVMType::getStructTy(type.getContext(), elementsVector, + /*isPacked=*/false); +} + /// Converts SPIR-V struct with no offset to packed LLVM struct.
static Type convertStructTypePacked(spirv::StructType type, LLVMTypeConverter &converter) { @@ -223,16 +247,22 @@ // Type conversion //===----------------------------------------------------------------------===// -/// Converts SPIR-V array type to LLVM array. There is no modelling of array -/// stride at the moment. +/// Converts SPIR-V array type to LLVM array. Natural stride (according to +/// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected +/// when converting ops that manipulate array types. static Optional convertArrayType(spirv::ArrayType type, TypeConverter &converter) { - if (type.getArrayStride() != 0) + unsigned stride = type.getArrayStride(); + Type elementType = type.getElementType(); + auto sizeInBytes = elementType.cast().getSizeInBytes(); + if (stride != 0 && + !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride)) return llvm::None; - auto elementType = - converter.convertType(type.getElementType()).cast(); + + auto llvmElementType = + converter.convertType(elementType).cast(); unsigned numElements = type.getNumElements(); - return LLVM::LLVMType::getArrayTy(elementType, numElements); + return LLVM::LLVMType::getArrayTy(llvmElementType, numElements); } /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not @@ -257,13 +287,15 @@ } /// Converts SPIR-V struct to LLVM struct. There is no support of structs with -/// member decorations or with offset. +/// member decorations. Also, only natural offset is supported.
static Optional convertStructType(spirv::StructType type, LLVMTypeConverter &converter) { SmallVector memberDecorations; type.getMemberDecorations(memberDecorations); - if (type.hasOffset() || !memberDecorations.empty()) + if (!memberDecorations.empty()) return llvm::None; + if (type.hasOffset()) + return convertStructTypeWithOffset(type, converter); return convertStructTypePacked(type, converter); } @@ -273,6 +305,32 @@ namespace { +/// Converts `spv.AccessChain` to `llvm.getelementptr` using remapped operands. +class AccessChainPattern : public SPIRVToLLVMConversion { +public: + using SPIRVToLLVMConversion::SPIRVToLLVMConversion; + + LogicalResult + matchAndRewrite(spirv::AccessChainOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto dstType = typeConverter.convertType(op.component_ptr().getType()); + if (!dstType) + return failure(); + // To use GEP we need to add a first 0 index to go through the pointer. + auto indices = llvm::to_vector<4>(operands.drop_front()); + Type indexType = op.indices().front().getType(); + auto llvmIndexType = typeConverter.convertType(indexType); + if (!llvmIndexType) + return failure(); + Value zero = rewriter.create( + op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); + indices.insert(indices.begin(), zero); + rewriter.replaceOpWithNewOp(op, dstType, operands.front(), + indices); + return success(); + } +}; + class AddressOfPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; @@ -545,11 +603,12 @@ if (!dstType) return failure(); - // Limit conversion to the current invocation only for now. + // Limit conversion to the current invocation only or `StorageBuffer` + // required by SPIR-V runner. + // This is okay because multiple invocations are not supported yet.
auto storageClass = srcType.getStorageClass(); - if (storageClass != spirv::StorageClass::Input && - storageClass != spirv::StorageClass::Private && - storageClass != spirv::StorageClass::Output) { + if (!isCurrentInvocation(storageClass) && + storageClass != spirv::StorageClass::StorageBuffer) { return failure(); } @@ -757,6 +816,20 @@ } }; +/// A template pattern that erases the given `SPIRVOp`. +template +class ErasePattern : public SPIRVToLLVMConversion { +public: + using SPIRVToLLVMConversion::SPIRVToLLVMConversion; + + LogicalResult + matchAndRewrite(SPIRVOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + class ReturnPattern : public SPIRVToLLVMConversion { public: using SPIRVToLLVMConversion::SPIRVToLLVMConversion; @@ -875,18 +948,6 @@ } }; -class MergePattern : public SPIRVToLLVMConversion { -public: - using SPIRVToLLVMConversion::SPIRVToLLVMConversion; - - LogicalResult - matchAndRewrite(spirv::MergeOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - rewriter.eraseOp(op); - return success(); - } -}; - /// Converts `spv.selection` with `spv.BranchConditional` in its header block. /// All blocks within selection should be reachable for conversion to succeed. class SelectionPattern : public SPIRVToLLVMConversion { @@ -1266,11 +1327,18 @@ ConstantScalarAndVectorPattern, // Control Flow ops - BranchConversionPattern, BranchConditionalConversionPattern, LoopPattern, - SelectionPattern, MergePattern, + BranchConversionPattern, BranchConditionalConversionPattern, + FunctionCallPattern, LoopPattern, SelectionPattern, + ErasePattern, + + // Entry points and execution mode + // Module generated from SPIR-V could have other "internal" functions, so + // having entry point and execution mode metadata can be useful. For now, + // simply remove them. + // TODO: Support EntryPoint/ExecutionMode properly.
+ ErasePattern, ErasePattern, // Function Call op - FunctionCallPattern, // GLSL extended instruction set ops DirectConversionPattern, @@ -1295,8 +1363,9 @@ NotPattern, // Memory ops - AddressOfPattern, GlobalVariablePattern, LoadStorePattern, - LoadStorePattern, VariablePattern, + AccessChainPattern, AddressOfPattern, GlobalVariablePattern, + LoadStorePattern, LoadStorePattern, + VariablePattern, // Miscellaneous ops DirectConversionPattern, diff --git a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir @@ -1,5 +1,30 @@ // RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s +//===----------------------------------------------------------------------===// +// spv.AccessChain +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @access_chain +func @access_chain() -> () { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 + %0 = spv.constant 1: i32 + %1 = spv.Variable : !spv.ptr>, Function> + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 + // CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %[[ONE]], %[[ONE]]] : (!llvm.ptr)>>, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm.ptr + %2 = spv.AccessChain %1[%0, %0] : !spv.ptr>, Function>, i32, i32 + return +} + +// CHECK-LABEL: @access_chain_array +func @access_chain_array(%arg0 : i32) -> () { + %0 = spv.Variable : !spv.ptr>, Function> + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 + // CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %{{.*}}] : (!llvm.ptr>>, !llvm.i32, !llvm.i32) -> !llvm.ptr> + %1 = spv.AccessChain %0[%arg0] : !spv.ptr>, Function>, i32 + %2 = spv.Load "Function" %1 ["Volatile"] : !spv.array<4xf32> + return +} + //===----------------------------------------------------------------------===// // spv.globalVariable and
spv._address_of //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir @@ -20,6 +20,23 @@ return } +//===----------------------------------------------------------------------===// +// spv.EntryPoint and spv.ExecutionMode (erased by the conversion) +//===----------------------------------------------------------------------===// + +// CHECK: module { +// CHECK: llvm.func @empty +// CHECK: llvm.return +// CHECK: } +// CHECK: } +spv.module Logical GLSL450 { + spv.func @empty() -> () "None" { + spv.Return + } + spv.EntryPoint "GLCompute" @empty + spv.ExecutionMode @empty "LocalSize", 1, 1, 1 +} + //===----------------------------------------------------------------------===// // spv.Undef //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm-invalid.mlir @@ -1,21 +1,14 @@ // RUN: mlir-opt %s -convert-spirv-to-llvm -verify-diagnostics -split-input-file // expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}} -spv.func @array_with_stride(%arg: !spv.array<4 x f32, stride=4>) -> () "None" { +spv.func @array_with_unnatural_stride(%arg: !spv.array<4 x f32, stride=8>) -> () "None" { spv.Return } // ----- // expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked illegal}} -spv.func @struct_with_offset1(%arg: !spv.struct) -> () "None" { - spv.Return -} - -// ----- - -// expected-error@+1 {{failed to legalize operation 'spv.func' that was explicitly marked
illegal}} -spv.func @struct_with_offset2(%arg: !spv.struct) -> () "None" { +spv.func @struct_with_unnatural_offset(%arg: !spv.struct) -> () "None" { spv.Return } diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-types-to-llvm.mlir @@ -5,7 +5,10 @@ //===----------------------------------------------------------------------===// // CHECK-LABEL: @array(!llvm.array<16 x float>, !llvm.array<32 x vec<4 x float>>) -func @array(!spv.array<16xf32>, !spv.array< 32 x vector<4xf32> >) -> () +func @array(!spv.array<16 x f32>, !spv.array< 32 x vector<4xf32> >) -> () + +// CHECK-LABEL: @array_with_natural_stride(!llvm.array<16 x float>) +func @array_with_natural_stride(!spv.array<16 x f32, stride=4>) -> () //===----------------------------------------------------------------------===// // Pointer type @@ -36,3 +39,6 @@ // CHECK-LABEL: @struct_nested(!llvm.struct)>) func @struct_nested(!spv.struct>) + +// CHECK-LABEL: @struct_with_natural_offset(!llvm.struct<(i8, i32)>) +func @struct_with_natural_offset(!spv.struct) -> ()