diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -157,10 +157,7 @@
   if (funcOp
           .walk([&](memref::AllocOp allocOp) -> WalkResult {
             Value oldMemRef = allocOp.getResult();
-            if (!oldMemRef.getType()
-                     .cast<MemRefType>()
-                     .getLayout()
-                     .isIdentity() &&
+            if (!allocOp.getType().getLayout().isIdentity() &&
                 !isMemRefNormalizable(oldMemRef.getUsers()))
               return WalkResult::interrupt();
             return WalkResult::advance();
@@ -173,11 +170,9 @@
             for (unsigned resIndex :
                  llvm::seq<unsigned>(0, callOp.getNumResults())) {
               Value oldMemRef = callOp.getResult(resIndex);
-              if (oldMemRef.getType().isa<MemRefType>())
-                if (!oldMemRef.getType()
-                         .cast<MemRefType>()
-                         .getLayout()
-                         .isIdentity() &&
+              // Non-memref results fail the dyn_cast and are skipped.
+              if (auto memrefType = oldMemRef.getType().dyn_cast<MemRefType>())
+                if (!memrefType.getLayout().isIdentity() &&
                     !isMemRefNormalizable(oldMemRef.getUsers()))
                   return WalkResult::interrupt();
             }
@@ -188,8 +183,9 @@
 
   for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
     BlockArgument oldMemRef = funcOp.getArgument(argIndex);
-    if (oldMemRef.getType().isa<MemRefType>())
-      if (!oldMemRef.getType().cast<MemRefType>().getLayout().isIdentity() &&
+    // Non-memref arguments fail the dyn_cast and are skipped.
+    if (auto memrefType = oldMemRef.getType().dyn_cast<MemRefType>())
+      if (!memrefType.getLayout().isIdentity() &&
           !isMemRefNormalizable(oldMemRef.getUsers()))
         return false;
   }