[mlir][memref] Add an optional `nontemporal` attribute to memref.load/store

This patch:
  * adds a `nontemporal` argument declared with DefaultValuedOptionalAttr
    to the `arguments` of memref.load and memref.store in MemRefOps.td
    (its storage type and default value were lost from this copy; the
    default is presumably false - confirm against the original patch);
  * forwards the flag in MemRefToLLVM.cpp as the last builder argument of
    the LLVM load/store replacement, after the literals `0` and `false`
    (presumably the alignment and volatile arguments - confirm);
  * keeps the flag when EmulateWideInt.cpp rebuilds memref.load/store on
    the narrowed type;
  * keeps the flag when FoldMemRefAliasOps.cpp folds a subview into a
    memref.load/memref.store; to do so the shared generic case is split,
    and the AffineLoadOp/AffineStoreOp cases are rebuilt without it;
  * adds lit tests for the LLVM lowering, canonicalization, wide-int
    emulation and subview folding with `{nontemporal = true}`.

NOTE(review): this copy of the patch is damaged by text extraction.
Angle-bracketed spans that began with a letter appear to have been stripped
(e.g. `Variadic:$indices`, `replaceOpWithNewOp(`, `llvm::TypeSwitch(loadOp)`,
`!llvm.ptr`, `to memref>`), and line breaks have been collapsed, so the
patch will neither apply nor compile as shown. Recover the original patch
before using it; the stripped template arguments are not recoverable from
this text alone.

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1155,7 +1155,8 @@ let arguments = (ins Arg:$memref, - Variadic:$indices); + Variadic:$indices, + DefaultValuedOptionalAttr:$nontemporal); let results = (outs AnyType:$result); let extraClassDeclaration = [{ @@ -1690,7 +1691,8 @@ let arguments = (ins AnyType:$value, Arg:$memref, - Variadic:$indices); + Variadic:$indices, + DefaultValuedOptionalAttr:$nontemporal); let builders = [ OpBuilder<(ins "Value":$valueToStore, "Value":$memref), [{ diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -731,7 +731,8 @@ Value dataPtr = getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(), adaptor.getIndices(), rewriter); - rewriter.replaceOpWithNewOp(loadOp, dataPtr); + rewriter.replaceOpWithNewOp(loadOp, dataPtr, 0, false, + loadOp.getNontemporal()); return success(); } }; @@ -748,7 +749,8 @@ Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(), adaptor.getIndices(), rewriter); - rewriter.replaceOpWithNewOp(op, adaptor.getValue(), dataPtr); + rewriter.replaceOpWithNewOp(op, adaptor.getValue(), dataPtr, + 0, false, op.getNontemporal()); return success(); } }; diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp @@ -66,7 +66,8 @@ op.getMemRefType())); rewriter.replaceOpWithNewOp( - op, newResTy, adaptor.getMemref(), adaptor.getIndices()); + op, newResTy, adaptor.getMemref(), adaptor.getIndices(), + op.getNontemporal()); return success(); } }; @@ -88,7 +89,8 @@
op.getMemRefType())); rewriter.replaceOpWithNewOp( - op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices()); + op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(), + op.getNontemporal()); return success(); } }; diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -384,10 +384,14 @@ return failure(); llvm::TypeSwitch(loadOp) - .Case([&](auto op) { - rewriter.replaceOpWithNewOp(loadOp, subViewOp.getSource(), + .Case([&](AffineLoadOp op) { + rewriter.replaceOpWithNewOp(loadOp, subViewOp.getSource(), sourceIndices); }) + .Case([&](memref::LoadOp op) { + rewriter.replaceOpWithNewOp( + loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal()); + }) .Case([&](vector::TransferReadOp transferReadOp) { rewriter.replaceOpWithNewOp( transferReadOp, transferReadOp.getVectorType(), @@ -490,10 +494,15 @@ return failure(); llvm::TypeSwitch(storeOp) - .Case([&](auto op) { - rewriter.replaceOpWithNewOp( + .Case([&](AffineStoreOp op) { + rewriter.replaceOpWithNewOp( storeOp, storeOp.getValue(), subViewOp.getSource(), sourceIndices); }) + .Case([&](memref::StoreOp op) { + rewriter.replaceOpWithNewOp( + storeOp, storeOp.getValue(), subViewOp.getSource(), sourceIndices, + op.getNontemporal()); + }) .Case([&](vector::TransferWriteOp op) { rewriter.replaceOpWithNewOp( op, op.getValue(), subViewOp.getSource(), sourceIndices, diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -537,3 +537,24 @@ return } + +// ----- + +// CHECK-LABEL: func @load_non_temporal( +func.func @load_non_temporal(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>>) { + %1 = 
arith.constant
7 : index + // CHECK: llvm.load %{{.*}} {nontemporal} : !llvm.ptr + %2 = memref.load %arg0[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>> + func.return +} + +// ----- + +// CHECK-LABEL: func @store_non_temporal( +func.func @store_non_temporal(%input : memref<32xf32, affine_map<(d0) -> (d0)>>, %output : memref<32xf32, affine_map<(d0) -> (d0)>>) { + %1 = arith.constant 7 : index + %2 = memref.load %input[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>> + // CHECK: llvm.store %{{.*}}, %{{.*}} {nontemporal} : !llvm.ptr + memref.store %2, %output[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>> + func.return +} diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -894,3 +894,15 @@ to memref> return %1 : memref> } + +// ----- + +// CHECK-LABEL: func @load_store_nontemporal( +func.func @load_store_nontemporal(%input : memref<32xf32, affine_map<(d0) -> (d0)>>, %output : memref<32xf32, affine_map<(d0) -> (d0)>>) { + %1 = arith.constant 7 : index + // CHECK: memref.load %{{.*}}[%{{.*}}] {nontemporal = true} : memref<32xf32> + %2 = memref.load %input[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>> + // CHECK: memref.store %{{.*}}, %{{.*}}[%{{.*}}] {nontemporal = true} : memref<32xf32> + memref.store %2, %output[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>> + func.return +} diff --git a/mlir/test/Dialect/MemRef/emulate-wide-int.mlir b/mlir/test/Dialect/MemRef/emulate-wide-int.mlir --- a/mlir/test/Dialect/MemRef/emulate-wide-int.mlir +++ b/mlir/test/Dialect/MemRef/emulate-wide-int.mlir @@ -44,3 +44,19 @@ memref.store %c1, %m[%c0] : memref<4xi64, 1> return } + + +// CHECK-LABEL: func @alloc_load_store_i64_nontemporal +// CHECK: [[C1:%.+]] = arith.constant dense<[1, 0]> : vector<2xi32> +// CHECK-NEXT: [[M:%.+]] = memref.alloc() :
memref<4xvector<2xi32>, 1> +// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] {nontemporal = true} : memref<4xvector<2xi32>, 1> +// CHECK-NEXT: memref.store [[C1]], [[M]][{{%.+}}] {nontemporal = true} : memref<4xvector<2xi32>, 1> +// CHECK-NEXT: return +func.func @alloc_load_store_i64_nontemporal() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : i64 + %m = memref.alloc() : memref<4xi64, 1> + %v = memref.load %m[%c0] {nontemporal = true} : memref<4xi64, 1> + memref.store %c1, %m[%c0] {nontemporal = true} : memref<4xi64, 1> + return +} diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -502,3 +502,25 @@ to memref> return %1 : memref> } + +// ----- + +// CHECK-LABEL: func @fold_load_keep_nontemporal( +// CHECK: memref.load %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} +func.func @fold_load_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 { + %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>> + %1 = memref.load %0[%arg3, %arg4] {nontemporal = true }: memref<4x4xf32, strided<[64, 3], offset: ?>> + return %1 : f32 +} + + +// ----- + +// CHECK-LABEL: func @fold_store_keep_nontemporal( +// CHECK: memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} : memref<12x32xf32> +func.func @fold_store_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) { + %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : + memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>> + memref.store %arg5, %0[%arg3, %arg4] {nontemporal=true}: memref<4x4xf32, strided<[64, 3], offset: ?>> + return +}