diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp @@ -189,18 +189,20 @@ static LogicalResult replaceWithLoadOrStore(Operation *op, ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, - unsigned alignment) { + unsigned alignment, bool isVolatile, + bool isNoneTemporal) { if (auto loadOp = dyn_cast(op)) { auto dstType = typeConverter.convertType(loadOp.getType()); if (!dstType) return failure(); - rewriter.replaceOpWithNewOp(loadOp, dstType, loadOp.ptr(), - alignment); + rewriter.replaceOpWithNewOp( + loadOp, dstType, loadOp.ptr(), alignment, isVolatile, isNoneTemporal); return success(); } auto storeOp = cast(op); rewriter.replaceOpWithNewOp(storeOp, storeOp.value(), - storeOp.ptr(), alignment); + storeOp.ptr(), alignment, + isVolatile, isNoneTemporal); return success(); } @@ -594,19 +596,31 @@ LogicalResult matchAndRewrite(SPIRVop op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (op.memory_access().hasValue() && - op.memory_access().getValue() != spirv::MemoryAccess::None) { - auto memoryAccess = op.memory_access().getValue(); - if (memoryAccess == spirv::MemoryAccess::Aligned) { - unsigned alignment = op.alignment().getValue().getZExtValue(); - replaceWithLoadOrStore(op, rewriter, this->typeConverter, alignment); - return success(); - } + + if (!op.memory_access().hasValue()) { + replaceWithLoadOrStore(op, rewriter, this->typeConverter, /*alignment=*/0, + /*isVolatile=*/false, /*isNonTemporal*/ false); + return success(); + } + auto memoryAccess = op.memory_access().getValue(); + switch (memoryAccess) { + case spirv::MemoryAccess::Aligned: + case spirv::MemoryAccess::None: + case spirv::MemoryAccess::Nontemporal: + case spirv::MemoryAccess::Volatile: { + unsigned alignment = memoryAccess == spirv::MemoryAccess::Aligned + ? op.alignment().getValue().getZExtValue() + : 0; + bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal; + bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile; + replaceWithLoadOrStore(op, rewriter, this->typeConverter, alignment, + isVolatile, isNonTemporal); + return success(); + } + default: // There is no support of other memory access attributes. return failure(); } - replaceWithLoadOrStore(op, rewriter, this->typeConverter, 0); - return success(); } }; diff --git a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir @@ -25,6 +25,20 @@ return } +func @load_volatile() { + %0 = spv.Variable : !spv.ptr + // CHECK: %{{.*}} = llvm.load volatile %{{.*}} : !llvm<"float*"> + %1 = spv.Load "Function" %0 ["Volatile"] : f32 + return +} + +func @load_nontemporal() { + %0 = spv.Variable : !spv.ptr + // CHECK: %{{.*}} = llvm.load %{{.*}} {nontemporal} : !llvm<"float*"> + %1 = spv.Load "Function" %0 ["Nontemporal"] : f32 + return +} + //===----------------------------------------------------------------------===// // spv.Store //===----------------------------------------------------------------------===// @@ -50,6 +64,20 @@ return } +func @store_volatile(%arg0 : f32) -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: llvm.store volatile %{{.*}}, %{{.*}} : !llvm<"float*"> + spv.Store "Function" %0, %arg0 ["Volatile"] : f32 + return +} + +func @store_nontemporal(%arg0 : f32) -> () { + %0 = spv.Variable : !spv.ptr + // CHECK: llvm.store %{{.*}}, %{{.*}} {nontemporal} : !llvm<"float*"> + spv.Store "Function" %0, %arg0 ["Nontemporal"] : f32 + return +} + //===----------------------------------------------------------------------===// // spv.Variable //===----------------------------------------------------------------------===//