diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1831,7 +1831,8 @@
 def LoadOp : Std_Op<"load",
      [TypesMatchWith<"result type matches element type of 'memref'",
                      "memref", "result",
-                     "$_self.cast<MemRefType>().getElementType()">]> {
+                     "$_self.cast<MemRefType>().getElementType()">,
+                     MemRefsNormalizable]> {
   let summary = "load operation";
   let description = [{
     The `load` op reads an element from a memref specified by an index list. The
@@ -2580,7 +2581,8 @@
 def StoreOp : Std_Op<"store",
      [TypesMatchWith<"type of 'value' matches element type of 'memref'",
                      "memref", "value",
-                     "$_self.cast<MemRefType>().getElementType()">]> {
+                     "$_self.cast<MemRefType>().getElementType()">,
+                     MemRefsNormalizable]> {
   let summary = "store operation";
   let description = [{
     Store a value to a memref location given by indices. The value stored should
diff --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir
--- a/mlir/test/Transforms/normalize-memrefs-ops.mlir
+++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir
@@ -55,3 +55,37 @@
     // CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
     return
 }
+
+// Test that load and store ops (now MemRefsNormalizable) do not block memref normalization.
+
+#map_tile = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 32, d2 mod 32, d3 mod 32)>
+
+// CHECK-LABEL: test_load_store
+// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x14x14xf32>
+func @test_load_store(%arg0 : memref<1x16x14x14xf32>) -> () {
+    %0 = alloc() : memref<1x16x14x14xf32, #map_tile>
+    // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x32xf32>
+    %1 = alloc() : memref<1x16x14x14xf32>
+    // CHECK: %[[v1:[a-z0-9]*]] = alloc() : memref<1x16x14x14xf32>
+    "test.op_norm"(%0, %1) : (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) -> ()
+    // CHECK: "test.op_norm"(%[[v0]], %[[v1]]) : (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) -> ()
+    %cst = constant 3.0 : f32
+    affine.for %i = 0 to 1 {
+      affine.for %j = 0 to 16 {
+        affine.for %k = 0 to 14 {
+          affine.for %l = 0 to 14 {
+            %2 = load %1[%i, %j, %k, %l] : memref<1x16x14x14xf32>
+            // CHECK: load %[[v1]][%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<1x16x14x14xf32>
+            %3 = addf %2, %cst : f32
+            store %3, %arg0[%i, %j, %k, %l] : memref<1x16x14x14xf32>
+            // CHECK: store %{{.*}}, %[[ARG0]][%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<1x16x14x14xf32>
+          }
+        }
+      }
+    }
+    dealloc %0 : memref<1x16x14x14xf32, #map_tile>
+    // CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x32xf32>
+    dealloc %1 : memref<1x16x14x14xf32>
+    // CHECK: dealloc %[[v1]] : memref<1x16x14x14xf32>
+    return
+}