diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -185,6 +185,27 @@
       rewriter.getIntegerAttr(rewriter.getI32Type(), value));
 }
 
+/// Replaces `op`, expected to be a `spv.Load` or `spv.Store`, with the
+/// corresponding LLVM dialect load or store using the given `alignment`.
+/// Fails without rewriting `op` if the loaded type cannot be converted.
+static LogicalResult replaceWithLoadOrStore(Operation *op,
+                                            ConversionPatternRewriter &rewriter,
+                                            LLVMTypeConverter &typeConverter,
+                                            unsigned alignment) {
+  if (auto loadOp = dyn_cast(op)) {
+    auto dstType = typeConverter.convertType(loadOp.getType());
+    if (!dstType)
+      return failure();
+    rewriter.replaceOpWithNewOp(loadOp, dstType, loadOp.ptr(),
+                                alignment);
+    return success();
+  }
+  auto storeOp = cast(op);
+  rewriter.replaceOpWithNewOp(storeOp, storeOp.value(),
+                              storeOp.ptr(), alignment);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Type conversion
 //===----------------------------------------------------------------------===//
@@ -566,6 +587,35 @@
   }
 };
 
+/// Converts `spv.Load` and `spv.Store` to LLVM dialect. An absent or `None`
+/// memory access attribute yields a plain load/store, `Aligned` forwards the
+/// alignment, and any other memory access attribute makes the pattern fail.
+template
+class LoadStorePattern : public SPIRVToLLVMConversion {
+public:
+  using SPIRVToLLVMConversion::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(SPIRVop op, ArrayRef operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Propagate the helper's result: a failed type conversion must make the
+    // pattern fail rather than report success without rewriting `op`.
+    if (!op.memory_access().hasValue() ||
+        op.memory_access().getValue() == spirv::MemoryAccess::None)
+      return replaceWithLoadOrStore(op, rewriter, this->typeConverter, 0);
+
+    auto memoryAccess = op.memory_access().getValue();
+    if (memoryAccess == spirv::MemoryAccess::Aligned) {
+      unsigned alignment = op.alignment().getValue().getZExtValue();
+      return replaceWithLoadOrStore(op, rewriter, this->typeConverter,
+                                    alignment);
+    }
+
+    // Other memory access attributes are not supported yet.
+    return failure();
+  }
+};
+
 /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect.
 template
 class NotPattern : public SPIRVToLLVMConversion {
@@ -973,6 +1023,7 @@
       NotPattern,
 
       // Memory ops
+      LoadStorePattern, LoadStorePattern,
       VariablePattern,
 
       // Miscellaneous ops
diff --git a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
--- a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir
@@ -1,5 +1,55 @@
 // RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// spv.Load
+//===----------------------------------------------------------------------===//
+
+func @load() {
+  %0 = spv.Variable : !spv.ptr
+  // CHECK: %{{.*}} = llvm.load %{{.*}} : !llvm<"float*">
+  %1 = spv.Load "Function" %0 : f32
+  return
+}
+
+func @load_none() {
+  %0 = spv.Variable : !spv.ptr
+  // CHECK: %{{.*}} = llvm.load %{{.*}} : !llvm<"float*">
+  %1 = spv.Load "Function" %0 ["None"] : f32
+  return
+}
+
+func @load_with_alignment() {
+  %0 = spv.Variable : !spv.ptr
+  // CHECK: %{{.*}} = llvm.load %{{.*}} {alignment = 4 : i64} : !llvm<"float*">
+  %1 = spv.Load "Function" %0 ["Aligned", 4] : f32
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.Store
+//===----------------------------------------------------------------------===//
+
+func @store(%arg0 : f32) -> () {
+  %0 = spv.Variable : !spv.ptr
+  // CHECK: llvm.store %{{.*}}, %{{.*}} : !llvm<"float*">
+  spv.Store "Function" %0, %arg0 : f32
+  return
+}
+
+func @store_composite(%arg0 : !spv.struct) -> () {
+  %0 = spv.Variable : !spv.ptr, Function>
+  // CHECK: llvm.store %{{.*}}, %{{.*}} : !llvm<"<{ double }>*">
+  spv.Store "Function" %0, %arg0 : !spv.struct
+  return
+}
+
+func @store_with_alignment(%arg0 : f32) -> () {
+  %0 = spv.Variable : !spv.ptr
+  // CHECK: llvm.store %{{.*}}, %{{.*}} {alignment = 4 : i64} : !llvm<"float*">
+  spv.Store "Function" %0, %arg0 ["Aligned", 4] : f32
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // spv.Variable
 //===----------------------------------------------------------------------===//