diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include "llvm/Support/Casting.h"
 
 using namespace mlir;
 using namespace mlir::memref;
@@ -360,17 +361,74 @@
     return success();
   }
 };
+
+/// Fold away memref.cast users of alloc/alloca by giving the allocation the
+/// cast's destination type directly, so the cast is no longer needed. This is
+/// only done when the cast loses no information: the destination type must be
+/// ranked and have a static shape, strides, and offset. In some cases the
+/// allocated buffer's type must not change, e.g. when it is passed to a
+/// func.call, so the pattern only applies when every user is a memref op.
+template <typename T>
+struct SimplifyCastedAlloc : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(T alloc,
+                                PatternRewriter &rewriter) const override {
+    // Valid users are ops from the memref dialect. Cast users are allowed
+    // only if their result type is ranked and has a static shape, strides,
+    // and offset. All memref.cast users must have the same result type,
+    // which becomes the alloc's new type.
+    MemRefType dstType;
+    Dialect *memrefDialect = alloc->getDialect();
+    for (auto user : alloc->getUsers()) {
+      if (user->getDialect() != memrefDialect)
+        return failure();
+      auto castUser = dyn_cast<CastOp>(user);
+      if (!castUser)
+        continue;
+      auto dstBaseMemRefType =
+          castUser.getDest().getType().template cast<BaseMemRefType>();
+      if (!dstBaseMemRefType.hasRank())
+        return failure();
+      if (!dstBaseMemRefType.hasStaticShape())
+        return failure();
+      auto currDstType = dstBaseMemRefType.template cast<MemRefType>();
+      auto [strides, offset] = getStridesAndOffset(currDstType);
+      if (offset == ShapedType::kDynamic ||
+          llvm::is_contained(strides, ShapedType::kDynamic))
+        return failure();
+      if (!dstType)
+        dstType = currDstType;
+      else if (dstType != currDstType)
+        return failure();
+    }
+
+    // If no cast operations were found, the pattern fails.
+    if (!dstType)
+      return failure();
+
+    // If the alloc's type would not change, the pattern fails.
+    if (alloc.getMemref().getType().template cast<MemRefType>() == dstType)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<T>(alloc, dstType, alloc.getDynamicSizes(),
+                                   alloc.getSymbolOperands(),
+                                   alloc.getAlignmentAttr());
+    return success();
+  }
+};
 } // namespace
 
 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
+  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>,
+              SimplifyCastedAlloc<AllocOp>>(context);
 }
 
 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
-      context);
+  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>,
+              SimplifyCastedAlloc<AllocaOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -940,3 +940,95 @@
   %alloc = memref.alloc(%c15, %c-2, %idx1) : memref
   return %alloc : memref
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_casted_alloc
+// CHECK: %[[alloc:.*]] = memref.alloc() : memref<1x15x30x256xsi8, #map>
+// CHECK-NOT: memref.cast
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0 * 115200 + d1 * 7680 + d2 * 256 + d3)>
+func.func @fold_casted_alloc(%arg0: memref<1x15x30x256xsi8, #map>) -> (memref<1x15x30x256xsi8, #map>, memref<1x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>) {
+  %alloc = memref.alloc() : memref<1x15x30x256xsi8>
+  memref.copy %arg0, %alloc : memref<1x15x30x256xsi8, #map> to memref<1x15x30x256xsi8>
+  %cast = memref.cast %alloc : memref<1x15x30x256xsi8> to memref<1x15x30x256xsi8, #map>
+  %transpose = memref.transpose %alloc (d0, d1, d2, d3) -> (d0, d3, d1, d2) : memref<1x15x30x256xsi8> to memref<1x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>
+  return %cast, %transpose : memref<1x15x30x256xsi8, #map>, memref<1x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @uncasted_alloc_is_unchanged
+// CHECK: %[[alloc:.*]] = memref.alloc() : memref<1x15x30x256xsi8>
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0 * 115200 + d1 * 7680 + d2 * 256 + d3)>
+func.func @uncasted_alloc_is_unchanged(%arg0: memref<1x15x30x256xsi8, #map>) -> (memref<1x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>) {
+  %alloc = memref.alloc() : memref<1x15x30x256xsi8>
+  memref.copy %arg0, %alloc : memref<1x15x30x256xsi8, #map> to memref<1x15x30x256xsi8>
+  %transpose = memref.transpose %alloc (d0, d1, d2, d3) -> (d0, d3, d1, d2) : memref<1x15x30x256xsi8> to memref<1x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>
+  return %transpose : memref<1x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @non_memref_dialect_users
+// CHECK: %[[alloc:.*]] = memref.alloc() : memref<1x15x30x256xsi8>
+// CHECK: memref.cast
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0 * 115200 + d1 * 7680 + d2 * 256 + d3)>
+func.func @non_memref_dialect_users() -> (memref<1x15x30x256xsi8, #map>, memref<1x15x30x256xsi8>) {
+  %alloc = memref.alloc() : memref<1x15x30x256xsi8>
+  %other_dialect_op = "my.op"(%alloc) : (memref<1x15x30x256xsi8>) -> (memref<1x15x30x256xsi8>)
+  %cast = memref.cast %alloc : memref<1x15x30x256xsi8> to memref<1x15x30x256xsi8, #map>
+  return %cast, %other_dialect_op : memref<1x15x30x256xsi8, #map>, memref<1x15x30x256xsi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @alloc_casted_to_dynamic_shape
+// CHECK: %[[alloc:.*]] = memref.alloc() : memref<1x15x30x256xsi8>
+// CHECK: memref.cast
+
+func.func @alloc_casted_to_dynamic_shape() -> (memref<?x?x?x?xsi8>) {
+  %alloc = memref.alloc() : memref<1x15x30x256xsi8>
+  %cast = memref.cast %alloc : memref<1x15x30x256xsi8> to memref<?x?x?x?xsi8>
+  return %cast : memref<?x?x?x?xsi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @alloc_casted_to_dynamic_stride
+// CHECK: %[[alloc:.*]] = memref.alloc() : memref<1x15x30x256xsi8, strided<[115200, 1, 7680, 256]>>
+// CHECK: memref.cast
+
+func.func @alloc_casted_to_dynamic_stride() -> (memref<1x15x30x256xsi8, strided<[?, 1, 7680, 256]>>) {
+  %alloc = memref.alloc() : memref<1x15x30x256xsi8, strided<[115200, 1, 7680, 256]>>
+  %cast = memref.cast %alloc : memref<1x15x30x256xsi8, strided<[115200, 1, 7680, 256]>> to memref<1x15x30x256xsi8, strided<[?, 1, 7680, 256]>>
+  return %cast : memref<1x15x30x256xsi8, strided<[?, 1, 7680, 256]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @casted_alloc_conflicts
+// CHECK: %[[alloc:.*]] = memref.alloc() : memref<1x15x30x256xsi8>
+// CHECK: memref.cast
+// CHECK: memref.cast
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0 * 115200 + d1 * 7680 + d2 * 256 + d3)>
+func.func @casted_alloc_conflicts() -> (memref<1x15x30x256xsi8, #map>, memref<?x?x?x?xsi8>) {
+  %alloc = memref.alloc() : memref<1x15x30x256xsi8>
+  %cast = memref.cast %alloc : memref<1x15x30x256xsi8> to memref<1x15x30x256xsi8, #map>
+  %cast_0 = memref.cast %alloc : memref<1x15x30x256xsi8> to memref<?x?x?x?xsi8>
+  return %cast, %cast_0 : memref<1x15x30x256xsi8, #map>, memref<?x?x?x?xsi8>
+}
+// -----
+
+// CHECK-LABEL: func @unranked_casted_alloc
+// CHECK: %[[alloc:.*]] = memref.alloc() : memref<1x15x30x256xsi8>
+// CHECK: memref.cast
+
+func.func @unranked_casted_alloc() -> (memref<*xsi8>) {
+  %alloc = memref.alloc() : memref<1x15x30x256xsi8>
+  %cast = memref.cast %alloc : memref<1x15x30x256xsi8> to memref<*xsi8>
+  return %cast : memref<*xsi8>
+}
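Note (illustrative, not part of the patch): a minimal sketch of the rewrite performed by SimplifyCastedAlloc, using a hypothetical function @example and a layout #layout that is equivalent to the identity strides of the allocated buffer. The exact printed form after canonicalization may differ (the layout may be shown as strided<[16, 1]>), and the now-trivial same-type memref.cast is expected to be cleaned up by the existing cast folding.

Before canonicalization:

  #layout = affine_map<(d0, d1) -> (d0 * 16 + d1)>
  func.func @example(%src: memref<4x16xf32, #layout>) -> memref<4x16xf32, #layout> {
    %alloc = memref.alloc() : memref<4x16xf32>
    memref.copy %src, %alloc : memref<4x16xf32, #layout> to memref<4x16xf32>
    %cast = memref.cast %alloc : memref<4x16xf32> to memref<4x16xf32, #layout>
    return %cast : memref<4x16xf32, #layout>
  }

After canonicalization, the allocation takes the cast's destination type and the cast disappears; all users of the alloc (here the memref.copy) see the retyped value:

  #layout = affine_map<(d0, d1) -> (d0 * 16 + d1)>
  func.func @example(%src: memref<4x16xf32, #layout>) -> memref<4x16xf32, #layout> {
    %alloc = memref.alloc() : memref<4x16xf32, #layout>
    memref.copy %src, %alloc : memref<4x16xf32, #layout> to memref<4x16xf32, #layout>
    return %alloc : memref<4x16xf32, #layout>
  }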