# Review notes for the patch below (reviewer annotations placed before the
# first `diff --git` header; patch tools skip this leading text).
#
# Purpose: adds a ViewOp canonicalization pattern, ViewOpMemrefCastFolder, to
# mlir/lib/Dialect/StandardOps/Ops.cpp, changes the results.insert call in
# ViewOp::getCanonicalizationPatterns (presumably to register the new pattern;
# its template arguments are elided in this copy), and adds a test case to
# mlir/test/Transforms/canonicalize.mlir.
#
# What the pattern does, per the hunk: if operand 0 of a ViewOp is defined by
# a MemRefCastOp whose own operand is defined by an AllocOp, the ViewOp is
# replaced by a new ViewOp with the same result type that takes the alloc
# result directly as its source. In every other case it returns
# matchFailure().
#
# NOTE(review): this copy appears lossy. The patch is collapsed onto a single
# line, and text between angle brackets seems to be missing. Examples are the
# template arguments of OpRewritePattern, dyn_cast_or_null,
# replaceOpWithNewOp and results.insert, and the non-numeric memref types in
# the test. It cannot be applied as-is. Restore it from the original source
# before use.
#
# NOTE(review): newOperands hard-codes operands [1] and [2]. That assumes the
# view always has exactly two operands after the source. If ViewOp's
# offset/size operand list is variadic (the `[...][...]` syntax in the test
# suggests so, but confirm), two failures follow:
#   - a view with fewer than two trailing operands would index out of range;
#   - a view with more would silently drop sizes.
# Forwarding every operand after the source avoids both, e.g.
# viewOp.getOperands().drop_front(). TODO: confirm the API for this revision.
#
# NOTE(review): `auto newOperands = {a, b};` deduces std::initializer_list.
# Spelling out the container type would make the conversion at the
# replaceOpWithNewOp call explicit.
#
# NOTE(review): the extra leading `memrefOperand` argument to
# replaceOpWithNewOp presumably selects an overload that also erases the
# memref_cast if it becomes dead. That would match the "remove memrefcast
# values" comment. Verify against the PatternRewriter API of this revision.
#
# NOTE(review): the new CHECK line expects `std.view %0[][%c15, %c7]`, but the
# input is written as `view %6[%c15][%c7]`. So %c15 moves from the offset
# brackets into the size list. Confirm this is the intended printing for the
# (elided) result type, and not the offset operand being reinterpreted as a
# size.
#
# NOTE(review): nit - the `if (!memrefCastOp)` check braces its single
# statement, while the `if (!allocOp)` check right after it does not.
# Consider making them consistent.
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -2527,11 +2527,36 @@ } }; +struct ViewOpMemrefCastFolder : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(ViewOp viewOp, + PatternRewriter &rewriter) const override { + Value memrefOperand = viewOp.getOperand(0); + MemRefCastOp memrefCastOp = + dyn_cast_or_null(memrefOperand.getDefiningOp()); + if (!memrefCastOp) { + return matchFailure(); + } + Value allocOperand = memrefCastOp.getOperand(); + AllocOp allocOp = dyn_cast_or_null(allocOperand.getDefiningOp()); + if (!allocOp) + return matchFailure(); + + auto newOperands = {viewOp.getOperands()[1], viewOp.getOperands()[2]}; + + // Replace view op and remove memrefcast values. + rewriter.replaceOpWithNewOp(memrefOperand, viewOp, viewOp.getType(), + allocOperand, newOperands); + return matchSuccess(); + } +}; + } // end anonymous namespace void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -735,6 +735,11 @@ : memref<2048xi8> to memref load %5[%c0, %c0] : memref + // Test: folding static alloc and memref_cast into a view. + // CHECK: std.view %0[][%c15, %c7] : memref<2048xi8> to memref + %6 = memref_cast %0 : memref<2048xi8> to memref + %7 = view %6[%c15][%c7] : memref to memref + load %7[%c0, %c0] : memref return }