[mlir][MemRef] NormalizeMemRefs: use dyn_cast instead of isa + cast

- For memref.alloc, read the layout from allocOp.getType() instead of
  re-casting the type of allocOp.getResult() to MemRefType.
- For call results and function arguments, replace the
  isa<MemRefType>() check followed by a repeated cast<MemRefType>() with a
  single dyn_cast<MemRefType>() bound in the `if` condition.

Intended as a pure refactoring (NFC): the identity-layout and
isMemRefNormalizable() checks are unchanged.

diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -157,10 +157,7 @@
   if (funcOp
           .walk([&](memref::AllocOp allocOp) -> WalkResult {
             Value oldMemRef = allocOp.getResult();
-            if (!oldMemRef.getType()
-                     .cast<MemRefType>()
-                     .getLayout()
-                     .isIdentity() &&
+            if (!allocOp.getType().getLayout().isIdentity() &&
                 !isMemRefNormalizable(oldMemRef.getUsers()))
               return WalkResult::interrupt();
             return WalkResult::advance();
@@ -173,11 +170,9 @@
             for (unsigned resIndex :
                  llvm::seq<unsigned>(0, callOp.getNumResults())) {
               Value oldMemRef = callOp.getResult(resIndex);
-              if (oldMemRef.getType().isa<MemRefType>())
-                if (!oldMemRef.getType()
-                         .cast<MemRefType>()
-                         .getLayout()
-                         .isIdentity() &&
+              if (auto oldMemRefType =
+                      oldMemRef.getType().dyn_cast<MemRefType>())
+                if (!oldMemRefType.getLayout().isIdentity() &&
                     !isMemRefNormalizable(oldMemRef.getUsers()))
                   return WalkResult::interrupt();
             }
@@ -188,8 +183,8 @@
 
   for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
     BlockArgument oldMemRef = funcOp.getArgument(argIndex);
-    if (oldMemRef.getType().isa<MemRefType>())
-      if (!oldMemRef.getType().cast<MemRefType>().getLayout().isIdentity() &&
+    if (auto oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>())
+      if (!oldMemRefType.getLayout().isIdentity() &&
           !isMemRefNormalizable(oldMemRef.getUsers()))
         return false;
   }