diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp @@ -145,10 +145,10 @@ /// Check whether all the uses of AllocOps, CallOps and function arguments of a /// function are either of dereferencing type or are uses in: DeallocOp, CallOp /// or ReturnOp. Only if these constraints are satisfied will the function -/// become a candidate for normalization. We follow a conservative approach here -/// wherein even if the non-normalizable memref is not a part of the function's -/// argument or return type, we still label the entire function as -/// non-normalizable. We assume external functions to be normalizable. +/// become a candidate for normalization. When the uses of a memref are +/// non-normalizable and the memref map layout is trivial (identity), we can +/// still label the entire function as normalizable. We assume external +/// functions to be normalizable. bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) { // We assume external functions to be normalizable. 
if (funcOp.isExternal()) @@ -157,7 +157,11 @@ if (funcOp .walk([&](memref::AllocOp allocOp) -> WalkResult { Value oldMemRef = allocOp.getResult(); - if (!isMemRefNormalizable(oldMemRef.getUsers())) + if (!oldMemRef.getType() + .cast<MemRefType>() + .getLayout() + .isIdentity() && + !isMemRefNormalizable(oldMemRef.getUsers())) return WalkResult::interrupt(); return WalkResult::advance(); }) @@ -170,7 +174,11 @@ llvm::seq<unsigned>(0, callOp.getNumResults())) { Value oldMemRef = callOp.getResult(resIndex); if (oldMemRef.getType().isa<MemRefType>()) - if (!isMemRefNormalizable(oldMemRef.getUsers())) + if (!oldMemRef.getType() + .cast<MemRefType>() + .getLayout() + .isIdentity() && + !isMemRefNormalizable(oldMemRef.getUsers())) return WalkResult::interrupt(); } return WalkResult::advance(); @@ -181,7 +189,8 @@ for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) { BlockArgument oldMemRef = funcOp.getArgument(argIndex); if (oldMemRef.getType().isa<MemRefType>()) - if (!isMemRefNormalizable(oldMemRef.getUsers())) + if (!oldMemRef.getType().cast<MemRefType>().getLayout().isIdentity() && + !isMemRefNormalizable(oldMemRef.getUsers())) return false; } diff --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir --- a/mlir/test/Transforms/normalize-memrefs-ops.mlir +++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir @@ -41,6 +41,24 @@ return } +// Test with op_nonnorm whose memref map layouts are identity. This op_nonnorm +// does not block the normalization of other operations. 
+ +// CHECK-LABEL: test_nonnorm_identity_layout +// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x64xf32>) +func.func @test_nonnorm_identity_layout(%arg0 : memref<1x16x14x14xf32, #map0>) -> () { + %0 = memref.alloc() : memref<1x16x14x14xf32> + "test.op_nonnorm"(%0, %0) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> () + "test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32>) -> () + memref.dealloc %0 : memref<1x16x14x14xf32> + + // CHECK: %[[v0:.*]] = memref.alloc() : memref<1x16x14x14xf32> + // CHECK: "test.op_nonnorm"(%[[v0]], %[[v0]]) : (memref<1x16x14x14xf32>, memref<1x16x14x14xf32>) -> () + // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32>) -> () + // CHECK: memref.dealloc %[[v0]] : memref<1x16x14x14xf32> + return +} + // Test with op_norm, with maps in the operations in the function. // CHECK-LABEL: test_norm_mix