diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -631,6 +631,36 @@
   }
 };
 
+/// Converts `spv.Store` to `llvm.store`. Only the `None` and `Aligned` memory
+/// access attributes are handled; any other attribute fails the match.
+class StorePattern : public SPIRVToLLVMConversion<spirv::StoreOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::StoreOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::StoreOp storeOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (storeOp.memory_access().hasValue() &&
+        storeOp.memory_access().getValue() != spirv::MemoryAccess::None) {
+      auto memoryAccess = storeOp.memory_access().getValue();
+      if (memoryAccess == spirv::MemoryAccess::Aligned) {
+        // In SPIR-V dialect, alignment is parsed as 32 bit integer, therefore
+        // extend to 64 bits.
+        auto alignment = rewriter.getI64IntegerAttr(
+            storeOp.alignment().getValue().getZExtValue());
+        rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, storeOp.value(),
+                                                   storeOp.ptr(), alignment);
+        return success();
+      }
+      // There is no support of other memory access attributes.
+      return failure();
+    }
+    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, storeOp.value(),
+                                               storeOp.ptr());
+    return success();
+  }
+};
+
 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
 public:
   using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
@@ -854,7 +884,7 @@
       NotPattern<spirv::LogicalNotOp>,
 
       // Memory ops
-      VariablePattern,
+      StorePattern, VariablePattern,
 
       // Miscellaneous ops
       DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
diff --git a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
--- a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
@@ -1,5 +1,30 @@
 // RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// spv.StoreOp
+//===----------------------------------------------------------------------===//
+
+func @store(%arg0 : f32) -> () {
+  %0 = spv.Variable : !spv.ptr<f32, Function>
+  // CHECK: llvm.store %{{.*}}, %{{.*}} : !llvm<"float*">
+  spv.Store "Function" %0, %arg0 : f32
+  return
+}
+
+func @store_composite(%arg0 : !spv.struct<f64>) -> () {
+  %0 = spv.Variable : !spv.ptr<!spv.struct<f64>, Function>
+  // CHECK: llvm.store %{{.*}}, %{{.*}} : !llvm<"<{ double }>*">
+  spv.Store "Function" %0, %arg0 : !spv.struct<f64>
+  return
+}
+
+func @store_with_alignment(%arg0 : f32) -> () {
+  %0 = spv.Variable : !spv.ptr<f32, Function>
+  // CHECK: llvm.store %{{.*}}, %{{.*}} {alignment = 4 : i64} : !llvm<"float*">
+  spv.Store "Function" %0, %arg0 ["Aligned", 4] : f32
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // spv.Variable
 //===----------------------------------------------------------------------===//