diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp @@ -66,7 +66,8 @@ op.getMemRefType())); rewriter.replaceOpWithNewOp( - op, newResTy, adaptor.getMemref(), adaptor.getIndices()); + op, newResTy, adaptor.getMemref(), adaptor.getIndices(), + op.getNontemporal()); return success(); } }; @@ -88,7 +89,8 @@ op.getMemRefType())); rewriter.replaceOpWithNewOp( - op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices()); + op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(), + op.getNontemporal()); return success(); } }; diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -384,10 +384,14 @@ return failure(); llvm::TypeSwitch(loadOp) - .Case([&](auto op) { - rewriter.replaceOpWithNewOp(loadOp, subViewOp.getSource(), + .Case([&](AffineLoadOp op) { + rewriter.replaceOpWithNewOp(loadOp, subViewOp.getSource(), sourceIndices); }) + .Case([&](memref::LoadOp op) { + rewriter.replaceOpWithNewOp( + loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal()); + }) .Case([&](vector::TransferReadOp transferReadOp) { rewriter.replaceOpWithNewOp( transferReadOp, transferReadOp.getVectorType(), @@ -490,10 +494,15 @@ return failure(); llvm::TypeSwitch(storeOp) - .Case([&](auto op) { - rewriter.replaceOpWithNewOp( + .Case([&](AffineStoreOp op) { + rewriter.replaceOpWithNewOp( storeOp, storeOp.getValue(), subViewOp.getSource(), sourceIndices); }) + .Case([&](memref::StoreOp op) { + rewriter.replaceOpWithNewOp( + storeOp, storeOp.getValue(), subViewOp.getSource(), sourceIndices, + op.getNontemporal()); + }) .Case([&](vector::TransferWriteOp op) { rewriter.replaceOpWithNewOp( op, op.getValue(), subViewOp.getSource(), sourceIndices, diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -541,21 +541,20 @@ // ----- // CHECK-LABEL: func @load_non_temporal( -func.func @load_non_temporal(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>>) { - %0 = memref.alloc() : memref<32xf32, affine_map<(d0) -> (d0)>> +func.func @load_non_temporal(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>>) { %1 = arith.constant 7 : index // CHECK: llvm.load %{{.*}} {nontemporal} : !llvm.ptr - %2 = memref.load %0[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>> + %2 = memref.load %arg0[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>> func.return } // ----- // CHECK-LABEL: func @store_non_temporal( -func.func @store_non_temporal(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>>, %data : f32) { - %0 = memref.alloc() : memref<32xf32, affine_map<(d0) -> (d0)>> +func.func @store_non_temporal(%input : memref<32xf32, affine_map<(d0) -> (d0)>>, %output : memref<32xf32, affine_map<(d0) -> (d0)>>) { %1 = arith.constant 7 : index + %2 = memref.load %input[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>> // CHECK: llvm.store %{{.*}}, %{{.*}} {nontemporal} : !llvm.ptr - memref.store %data, %0[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>> + memref.store %2, %output[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>> func.return } diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -894,3 +894,15 @@ to memref> return %1 : memref> } + +// ----- + +// CHECK-LABEL: func @load_store_nontemporal( +func.func @load_store_nontemporal(%input : memref<32xf32, affine_map<(d0) -> (d0)>>, %output : memref<32xf32, affine_map<(d0) -> (d0)>>) { + %1 = arith.constant 7 : index + // CHECK: memref.load %{{.*}}[%{{.*}}] {nontemporal = true} : memref<32xf32> + %2 = memref.load %input[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>> + // CHECK: memref.store %{{.*}}, %{{.*}}[%{{.*}}] {nontemporal = true} : memref<32xf32> + memref.store %2, %output[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>> + func.return +} diff --git a/mlir/test/Dialect/MemRef/emulate-wide-int.mlir b/mlir/test/Dialect/MemRef/emulate-wide-int.mlir --- a/mlir/test/Dialect/MemRef/emulate-wide-int.mlir +++ b/mlir/test/Dialect/MemRef/emulate-wide-int.mlir @@ -44,3 +44,19 @@ memref.store %c1, %m[%c0] : memref<4xi64, 1> return } + + +// CHECK-LABEL: func @alloc_load_store_i64_nontemporal +// CHECK: [[C1:%.+]] = arith.constant dense<[1, 0]> : vector<2xi32> +// CHECK-NEXT: [[M:%.+]] = memref.alloc() : memref<4xvector<2xi32>, 1> +// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] {nontemporal = true} : memref<4xvector<2xi32>, 1> +// CHECK-NEXT: memref.store [[C1]], [[M]][{{%.+}}] {nontemporal = true} : memref<4xvector<2xi32>, 1> +// CHECK-NEXT: return +func.func @alloc_load_store_i64_nontemporal() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : i64 + %m = memref.alloc() : memref<4xi64, 1> + %v = memref.load %m[%c0] {nontemporal = true} : memref<4xi64, 1> + memref.store %c1, %m[%c0] {nontemporal = true} : memref<4xi64, 1> + return +} diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -502,3 +502,25 @@ to memref> return %1 : memref> } + +// ----- + +// CHECK-LABEL: func @fold_load_keep_nontemporal( +// CHECK: memref.load %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} +func.func @fold_load_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 { + %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>> + %1 = memref.load %0[%arg3, %arg4] {nontemporal = true }: memref<4x4xf32, strided<[64, 3], offset: ?>> + return %1 : f32 +} + + +// ----- + +// CHECK-LABEL: func @fold_store_keep_nontemporal( +// CHECK: memref.store %{{.+}}, %{{.+}}[%{{.+}}, %{{.+}}] {nontemporal = true} : memref<12x32xf32> +func.func @fold_store_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) { + %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : + memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>> + memref.store %arg5, %0[%arg3, %arg4] {nontemporal=true}: memref<4x4xf32, strided<[64, 3], offset: ?>> + return +} \ No newline at end of file