diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1155,7 +1155,8 @@
 
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
                            [MemRead]>:$memref,
-                       Variadic<Index>:$indices);
+                       Variadic<Index>:$indices,
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
   let results = (outs AnyType:$result);
 
   let extraClassDeclaration = [{
@@ -1690,7 +1691,8 @@
   let arguments = (ins AnyType:$value,
                        Arg<AnyMemRef, "the reference to store in",
                            [MemWrite]>:$memref,
-                       Variadic<Index>:$indices);
+                       Variadic<Index>:$indices,
+                       DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
 
   let builders = [
     OpBuilder<(ins "Value":$valueToStore, "Value":$memref), [{
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -731,7 +731,8 @@
     Value dataPtr =
         getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                              adaptor.getIndices(), rewriter);
-    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr);
+    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, dataPtr, 0, false,
+                                              loadOp.getNontemporal());
     return success();
   }
 };
@@ -748,7 +749,8 @@
 
     Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                          adaptor.getIndices(), rewriter);
-    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr);
+    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
+                                               0, false, op.getNontemporal());
     return success();
   }
 };
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -537,3 +537,25 @@
 
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @load_non_temporal(
+func.func @load_non_temporal(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>>) {
+  %0 = memref.alloc() : memref<32xf32, affine_map<(d0) -> (d0)>>
+  %1 = arith.constant 7 : index
+  // CHECK: llvm.load %{{.*}} {nontemporal} : !llvm.ptr<f32>
+  %2 = memref.load %0[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>>
+  func.return
+}
+
+// -----
+
+// CHECK-LABEL: func @store_non_temporal(
+func.func @store_non_temporal(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>>, %data : f32) {
+  %0 = memref.alloc() : memref<32xf32, affine_map<(d0) -> (d0)>>
+  %1 = arith.constant 7 : index
+  // CHECK: llvm.store %{{.*}}, %{{.*}} {nontemporal} : !llvm.ptr<f32>
+  memref.store %data, %0[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>>
+  func.return
+}