diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -20,8 +20,10 @@
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include "llvm/Support/Casting.h"
 
 using namespace mlir;
 using namespace mlir::memref;
@@ -354,17 +356,81 @@
     return success();
   }
 };
+
+/// Fold memref.cast ops into their source alloc/alloca so the allocation
+/// directly has the cast's destination type and the cast becomes redundant.
+/// This is allowed only when the cast loses no information: both types must
+/// be ranked and have static shape, strides, and offset. In some cases the
+/// allocated buffer type must not change, e.g. when it is passed to a
+/// func.call, so the pattern is limited to alloc/alloca ops whose users are
+/// all memref-dialect ops. The pattern also only considers casts in the same
+/// block as the allocation, to avoid folding a cast inside control flow.
+template <typename T>
+struct SimplifyCastedAlloc : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(T alloc,
+                                PatternRewriter &rewriter) const override {
+    // All users must be memref-dialect ops. Folding the cast is allowed only
+    // if the result types of both the alloc and the cast are ranked and have
+    // static shape, strides, and offset. Every memref.cast user must have the
+    // same destination type, which becomes the alloc's new type.
+    auto srcBaseMemRefType =
+        alloc.getMemref().getType().template cast<BaseMemRefType>();
+    MemRefType dstType;
+    Dialect *memrefDialect = alloc->getDialect();
+    for (auto user : alloc->getUsers()) {
+      if (user->getDialect() != memrefDialect)
+        return failure();
+      auto castUser = dyn_cast<memref::CastOp>(user);
+      if (!castUser)
+        continue;
+      if (castUser->getBlock() != alloc->getBlock())
+        return failure();
+      auto dstBaseMemRefType =
+          castUser.getDest().getType().template cast<BaseMemRefType>();
+      if (!srcBaseMemRefType.hasRank() || !dstBaseMemRefType.hasRank())
+        return failure();
+      if (!srcBaseMemRefType.hasStaticShape() ||
+          !dstBaseMemRefType.hasStaticShape())
+        return failure();
+      auto currDstType = dstBaseMemRefType.template cast<MemRefType>();
+      auto [strides, offset] = getStridesAndOffset(currDstType);
+      if (offset == ShapedType::kDynamic ||
+          llvm::is_contained(strides, ShapedType::kDynamic))
+        return failure();
+      if (!dstType)
+        dstType = currDstType;
+      else if (dstType != currDstType)
+        return failure();
+    }
+
+    // If no cast operations were found, the pattern fails.
+    if (!dstType)
+      return failure();
+
+    // If the alloc type would not change, there is nothing to fold.
+    if (alloc.getMemref().getType().template cast<MemRefType>() == dstType)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<T>(alloc, dstType, alloc.getDynamicSizes(),
+                                   alloc.getSymbolOperands(),
+                                   alloc.getAlignmentAttr());
+    return success();
+  }
+};
 } // namespace
 
 void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
+  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>,
+              SimplifyCastedAlloc<AllocOp>>(context);
 }
 
 void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
-      context);
+  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>,
+              SimplifyCastedAlloc<AllocaOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -965,3 +965,129 @@
 // CHECK: return %[[cast]]
   return %0 : memref<?x?xf32, strided<[?, ?], offset: ?>>
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_casted_alloc
+// CHECK: %[[alloc:.*]] = memref.alloc() : memref<1x15x30x256xsi8, #map>
+// CHECK-NOT: memref.cast
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0 * 115200 + d1 * 7680 + d2 * 256 + d3)>
+func.func @fold_casted_alloc(%arg0: memref<1x15x30x256xsi8, #map>) -> (memref<1x15x30x256xsi8, #map>, memref<1x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>) {
+  %alloc = memref.alloc() : memref<1x15x30x256xsi8>
+  memref.copy %arg0, %alloc : memref<1x15x30x256xsi8, #map> to memref<1x15x30x256xsi8>
+  %cast = memref.cast %alloc : memref<1x15x30x256xsi8> to memref<1x15x30x256xsi8, #map>
+  %transpose = memref.transpose %alloc (d0, d1, d2, d3) -> (d0, d3, d1, d2) : memref<1x15x30x256xsi8> to memref<1x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>
+  return %cast, %transpose : memref<1x15x30x256xsi8, #map>, memref<1x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @uncasted_alloc_is_unchanged
+// CHECK: %[[alloc:.*]] = memref.alloc() : memref<1x15x30x256xsi8>
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0 * 115200 + d1 * 7680 + d2 * 256 + d3)>
+func.func @uncasted_alloc_is_unchanged(%arg0: memref<1x15x30x256xsi8, #map>) -> (memref<1x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>) {
+  %alloc = memref.alloc() : memref<1x15x30x256xsi8>
+  memref.copy %arg0, %alloc : memref<1x15x30x256xsi8, #map> to memref<1x15x30x256xsi8>
+  %transpose = memref.transpose %alloc (d0, d1, d2, d3) -> (d0, d3, d1, d2) : memref<1x15x30x256xsi8> to memref<1x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>
+  return %transpose : memref<1x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @non_memref_dialect_users
+// CHECK: %[[alloc:.*]] = memref.alloc() : memref<1x15x30x256xsi8>
+// CHECK: memref.cast
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0 * 115200 + d1 * 7680 + d2 * 256 + d3)>
+func.func @non_memref_dialect_users() -> (memref<1x15x30x256xsi8, #map>, memref<1x15x30x256xsi8>) {
+  %alloc = memref.alloc() : memref<1x15x30x256xsi8>
+  %other_dialect_op = "my.op"(%alloc) : (memref<1x15x30x256xsi8>) -> (memref<1x15x30x256xsi8>)
+  %cast = memref.cast %alloc : memref<1x15x30x256xsi8> to memref<1x15x30x256xsi8, #map>
+  return %cast, %other_dialect_op : memref<1x15x30x256xsi8, #map>, memref<1x15x30x256xsi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @alloc_casted_to_dynamic_shape
+// CHECK: %[[alloc:.*]] = memref.alloc() : memref<1x15x30x256xsi8>
+// CHECK: memref.cast
+
+func.func @alloc_casted_to_dynamic_shape() -> (memref<?x15x30x256xsi8>) {
+  %alloc = memref.alloc() : memref<1x15x30x256xsi8>
+  %cast = memref.cast %alloc : memref<1x15x30x256xsi8> to memref<?x15x30x256xsi8>
+  return %cast : memref<?x15x30x256xsi8>
+}
+
+// -----
+
+// CHECK-LABEL: func @alloc_casted_to_dynamic_stride
+// CHECK: %[[alloc:.*]] = memref.alloc() : memref<1x15x30x256xsi8, strided<[115200, 1, 7680, 256]>>
+// CHECK: memref.cast
+
+func.func @alloc_casted_to_dynamic_stride() -> (memref<1x15x30x256xsi8, strided<[?, 1, 7680, 256]>>) {
+  %alloc = memref.alloc() : memref<1x15x30x256xsi8, strided<[115200, 1, 7680, 256]>>
+  %cast = memref.cast %alloc : memref<1x15x30x256xsi8, strided<[115200, 1, 7680, 256]>> to memref<1x15x30x256xsi8, strided<[?, 1, 7680, 256]>>
+  return %cast : memref<1x15x30x256xsi8, strided<[?, 1, 7680, 256]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @casted_alloc_conflicts
+// CHECK: %[[alloc:.*]] = memref.alloc() : memref<1x15x30x256xsi8>
+// CHECK: memref.cast
+// CHECK: memref.cast
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0 * 115200 + d1 * 7680 + d2 * 256 + d3)>
+func.func @casted_alloc_conflicts() -> (memref<1x15x30x256xsi8, #map>, memref<?x15x30x256xsi8>) {
+  %alloc = memref.alloc() : memref<1x15x30x256xsi8>
+  %cast = memref.cast %alloc : memref<1x15x30x256xsi8> to memref<1x15x30x256xsi8, #map>
+  %cast_0 = memref.cast %alloc : memref<1x15x30x256xsi8> to memref<?x15x30x256xsi8>
+  return %cast, %cast_0 : memref<1x15x30x256xsi8, #map>, memref<?x15x30x256xsi8>
+}
+// -----
+
+// CHECK-LABEL: func @unranked_casted_alloc
+// CHECK: %[[alloc:.*]] = memref.alloc() : memref<1x15x30x256xsi8>
+// CHECK: memref.cast
+
+func.func @unranked_casted_alloc() -> (memref<*xsi8>) {
+  %alloc = memref.alloc() : memref<1x15x30x256xsi8>
+  %cast = memref.cast %alloc : memref<1x15x30x256xsi8> to memref<*xsi8>
+  return %cast : memref<*xsi8>
+}
+
+// -----
+
+
+// CHECK-LABEL: func @cast_and_alloc_in_different_blocks
+// CHECK: %[[VAL_0:.*]] = memref.alloc() : memref<1x15x30x256xsi8, #map>
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<1x15x30x256xsi8, #map>
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0 * 115200 + d1 * 7680 + d2 * 256 + d3)>
+func.func @cast_and_alloc_in_different_blocks(%cond: i1) -> () {
+  %alloc = memref.alloc() : memref<1x15x30x256xsi8, #map>
+  %alloc_1 = memref.alloc() : memref<1x15x30x256xsi8, #map>
+  memref.copy %alloc, %alloc_1 : memref<1x15x30x256xsi8, #map> to memref<1x15x30x256xsi8, #map>
+  scf.if %cond {
+    %cast = memref.cast %alloc : memref<1x15x30x256xsi8, #map> to memref<1x15x30x256xsi8>
+  }
+  return
+}
+
+// -----
+
+
+// CHECK-LABEL: func @do_not_change_dynamic_alloc(
+// CHECK-SAME: %{{.*}} %[[VAL_1:.*]]: index)
+// CHECK: %[[alloc:.*]] = memref.alloc(%[[VAL_1]]) : memref<?x15x30x256xsi8>
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0 * 115200 + d1 * 7680 + d2 * 256 + d3)>
+func.func @do_not_change_dynamic_alloc(%arg0: memref<?x15x30x256xsi8, #map>, %arg1: index) -> (memref<1x15x30x256xsi8, #map>, memref<?x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>) {
+  %alloc = memref.alloc(%arg1) : memref<?x15x30x256xsi8>
+  memref.copy %arg0, %alloc : memref<?x15x30x256xsi8, #map> to memref<?x15x30x256xsi8>
+  %cast = memref.cast %alloc : memref<?x15x30x256xsi8> to memref<1x15x30x256xsi8, #map>
+  %transpose = memref.transpose %alloc (d0, d1, d2, d3) -> (d0, d3, d1, d2) : memref<?x15x30x256xsi8> to memref<?x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>
+  return %cast, %transpose : memref<1x15x30x256xsi8, #map>, memref<?x256x15x30xsi8, strided<[115200, 1, 7680, 256]>>
+}
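
Note (illustrative sketch, not part of the patch): SimplifyCastedAlloc is registered through AllocOp::getCanonicalizationPatterns and AllocaOp::getCanonicalizationPatterns, so any greedy rewrite driver that collects the memref canonicalization patterns picks it up; the canonicalize.mlir tests above exercise it via mlir-opt's canonicalization pass. The sketch below shows one way a downstream pass could run just these patterns directly; the helper name foldCastedAllocs is hypothetical.

    #include "mlir/Dialect/MemRef/IR/MemRef.h"
    #include "mlir/IR/PatternMatch.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    using namespace mlir;

    // Collect the alloc/alloca canonicalization patterns (which, with this
    // patch, include SimplifyCastedAlloc) and apply them greedily to `root`.
    static LogicalResult foldCastedAllocs(Operation *root) {
      MLIRContext *ctx = root->getContext();
      RewritePatternSet patterns(ctx);
      memref::AllocOp::getCanonicalizationPatterns(patterns, ctx);
      memref::AllocaOp::getCanonicalizationPatterns(patterns, ctx);
      return applyPatternsAndFoldGreedily(root, std::move(patterns));
    }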