diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -360,17 +360,74 @@
     return success();
   }
 };
+
+/// Fold the type of a memref.cast into the casted alloc/alloca: rewrite the
+/// allocation to directly produce the cast's destination type, making the
+/// memref.cast redundant. This is only legal if the cast loses no
+/// information, i.e. the destination type has a fully static shape, strides,
+/// and offset. In some cases the allocated buffer's type must not change,
+/// for example when it is used by a func.call, so the pattern is limited to
+/// alloc/alloca ops whose users all belong to the memref dialect.
+template <typename T>
+struct SimplifyCastedAlloc : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(T alloc,
+                                PatternRewriter &rewriter) const override {
+    if (llvm::none_of(alloc->getUsers(),
+                      [&](Operation *op) { return isa<CastOp>(op); }))
+      return failure();
+    auto isValidUser = [&](Operation *op) -> bool {
+      if (!isa<MemRefDialect>(op->getDialect()))
+        return false;
+      auto castUser = dyn_cast<CastOp>(op);
+      if (!castUser)
+        return true;
+      MemRefType currDstMemRefType =
+          castUser.getDest().getType().dyn_cast<MemRefType>();
+      if (!currDstMemRefType)
+        return false;
+      if (!currDstMemRefType.hasStaticShape())
+        return false;
+      auto [strides, offset] = getStridesAndOffset(currDstMemRefType);
+      return offset != ShapedType::kDynamic &&
+             !llvm::is_contained(strides, ShapedType::kDynamic);
+    };
+    if (!llvm::all_of(alloc->getUsers(), isValidUser))
+      return failure();
+    // All memref.cast users must cast to the same type, which becomes the
+    // alloc's new type.
+    MemRefType dstType;
+    for (Operation *user : alloc->getUsers()) {
+      if (!isa<CastOp>(user))
+        continue;
+      MemRefType currDstType =
+          cast<CastOp>(user).getDest().getType().template cast<MemRefType>();
+      if (!dstType)
+        dstType = currDstType;
+      else if (dstType != currDstType)
+        return failure();
+    }
+    if (alloc.getMemref().getType().template cast<MemRefType>() == dstType)
+      return failure();
+    rewriter.replaceOpWithNewOp<T>(alloc, dstType, alloc.getDynamicSizes(),
+                                   alloc.getSymbolOperands(),
+                                   alloc.getAlignmentAttr());
+    return success();
+  }
+};
 } // namespace
 
 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
+  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>,
+              SimplifyCastedAlloc<AllocOp>>(context);
 }
 
 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
-      context);
+  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>,
+              SimplifyCastedAlloc<AllocaOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -940,3 +940,22 @@
   %alloc = memref.alloc(%c15, %c-2, %idx1) : memref<?x?x?xi32>
   return %alloc : memref<?x?x?xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_casted_alloc(
+// CHECK-SAME: %[[arg:.*]]: memref<1x15x30x256xsi8, #map>) -> (memref<1x15x30x256xsi8, #map>, memref<1x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>) {
+// CHECK: %[[alloc:.*]] = memref.alloc() : memref<1x15x30x256xsi8, #map>
+// CHECK: memref.copy %[[arg]], %[[alloc]] : memref<1x15x30x256xsi8, #map> to memref<1x15x30x256xsi8, #map>
+// CHECK-NOT: memref.cast
+// CHECK: %[[transpose:.*]] = memref.transpose %[[alloc]] (d0, d1, d2, d3) -> (d0, d3, d1, d2) : memref<1x15x30x256xsi8, #map> to memref<1x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>
+// CHECK: return %[[alloc]], %[[transpose]] : memref<1x15x30x256xsi8, #map>, memref<1x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0 * 115200 + d1 * 7680 + d2 * 256 + d3)>
+func.func @fold_casted_alloc(%arg0: memref<1x15x30x256xsi8, #map>) -> (memref<1x15x30x256xsi8, #map>, memref<1x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>) {
+  %alloc = memref.alloc() : memref<1x15x30x256xsi8>
+  memref.copy %arg0, %alloc : memref<1x15x30x256xsi8, #map> to memref<1x15x30x256xsi8>
+  %cast = memref.cast %alloc : memref<1x15x30x256xsi8> to memref<1x15x30x256xsi8, #map>
+  %transpose = memref.transpose %alloc (d0, d1, d2, d3) -> (d0, d3, d1, d2) : memref<1x15x30x256xsi8> to memref<1x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>
+  return %cast, %transpose : memref<1x15x30x256xsi8, #map>, memref<1x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>
+}
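
For illustration, here is a minimal sketch of what the new pattern does, on
hypothetical IR that is not part of this patch ("test.use" stands in for an
arbitrary consumer of the casted value). Before canonicalization, the buffer
is allocated with the default layout and then cast to an equivalent, fully
static strided type:

  %alloc = memref.alloc() : memref<4x8xf32>
  %cast = memref.cast %alloc : memref<4x8xf32> to memref<4x8xf32, strided<[8, 1]>>
  "test.use"(%cast) : (memref<4x8xf32, strided<[8, 1]>>) -> ()

After SimplifyCastedAlloc fires, the alloc directly produces the cast's
destination type, and the now same-type cast can be folded away, as the
CHECK-NOT line in the test above expects:

  %alloc = memref.alloc() : memref<4x8xf32, strided<[8, 1]>>
  "test.use"(%alloc) : (memref<4x8xf32, strided<[8, 1]>>) -> ()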