diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -309,6 +309,33 @@ } namespace { +// Replace the op defining oldValue with newValue, whose memref type has been +// refined (e.g. from a dynamic-sized to a statically-sized memref) by the +// alloc, view or subview canonicalization patterns. If any user of oldValue is +// one of the ops matched below (intended: function calls and returns, whose +// operand types must match a signature), a memref_cast back to oldValue's type +// is inserted; otherwise all uses are rewired to newValue directly. +// NOTE(review): skipping the cast assumes every other user accepts the refined +// type; ops that need an exact operand type (e.g. operands forwarded to block +// arguments) would be left invalid; an allowlist of safe users is preferable. +void replaceOpInsertMemrefCastWhenNeeded(PatternRewriter &rewriter, + Value oldValue, Value newValue) { + Operation *oldOp = oldValue.getDefiningOp(); + bool needsMemrefCast = + llvm::any_of(oldValue.getUsers(), [](Operation *user) { + return isa(user) || isa(user) || isa(user); + }); + if (!needsMemrefCast) { + // replaceOp rewires every use to newValue, which already carries the + // refined type, so oldValue itself never needs to be retyped in place. + rewriter.replaceOp(oldOp, {newValue}); + return; + } + auto castOp = rewriter.create(oldValue.getLoc(), newValue, + oldValue.getType()); + rewriter.replaceOp(oldOp, {castOp}); +} + /// Fold constant dimensions into an alloc operation. struct SimplifyAllocConst : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -359,11 +386,7 @@ // Create and insert the alloc op for the new memref. auto newAlloc = rewriter.create(alloc.getLoc(), newMemRefType, newOperands, IntegerAttr()); - // Insert a cast so we have the same type as the old alloc.
- auto resultCast = rewriter.create(alloc.getLoc(), newAlloc, - alloc.getType()); - - rewriter.replaceOp(alloc, {resultCast}); + replaceOpInsertMemrefCastWhenNeeded(rewriter, alloc, newAlloc); return success(); } }; @@ -2025,9 +2048,7 @@ auto newSubViewOp = rewriter.create( subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), ArrayRef(), subViewOp.strides(), newMemRefType); - // Insert a memref_cast for compatibility of the uses of the op. - rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, - subViewOp.getType()); + replaceOpInsertMemrefCastWhenNeeded(rewriter, subViewOp, newSubViewOp); return success(); } }; @@ -2074,9 +2095,7 @@ auto newSubViewOp = rewriter.create( subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), subViewOp.sizes(), ArrayRef(), newMemRefType); - // Insert a memref_cast for compatibility of the uses of the op. - rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, - subViewOp.getType()); + replaceOpInsertMemrefCastWhenNeeded(rewriter, subViewOp, newSubViewOp); return success(); } }; @@ -2125,9 +2144,7 @@ auto newSubViewOp = rewriter.create( subViewOp.getLoc(), subViewOp.source(), ArrayRef(), subViewOp.sizes(), subViewOp.strides(), newMemRefType); - // Insert a memref_cast for compatibility of the uses of the op. - rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, - subViewOp.getType()); + replaceOpInsertMemrefCastWhenNeeded(rewriter, subViewOp, newSubViewOp); return success(); } }; @@ -2442,8 +2459,6 @@ auto newViewOp = rewriter.create(viewOp.getLoc(), newMemRefType, viewOp.getOperand(0), newOperands); - // Insert a cast so we have the same type as the old memref type. - rewriter.replaceOpWithNewOp(viewOp, newViewOp, - viewOp.getType()); + replaceOpInsertMemrefCastWhenNeeded(rewriter, viewOp, newViewOp); return success(); } };