diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -301,6 +301,44 @@
 raw_ostream &operator<<(raw_ostream &os, SubViewOp::Range &range);
 
+/// Determines whether MemRefCastOp casts to a more dynamic version of the
+/// source memref. This is useful to fold a memref_cast into a consuming op
+/// and implement canonicalization patterns for ops in different dialects that
+/// may consume the results of memref_cast operations. Such foldable memref_cast
+/// operations are typically inserted as `view` and `subview` ops are
+/// canonicalized, to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+/// 1. source and result are ranked memrefs with strided semantics and same
+/// element type and rank.
+/// 2. each of the source's size, offset or stride has more static information
+/// than the corresponding result's size, offset or stride.
+///
+/// Example 1:
+/// ```mlir
+///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
+///   %2 = consumer %1 ... : memref<?x?xf32> ...
+/// ```
+///
+/// may fold into:
+///
+/// ```mlir
+///   %2 = consumer %0 ... : memref<8x16xf32> ...
+/// ```
+///
+/// Example 2:
+/// ```
+///   %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
+///          to memref<?x?xf32>
+///   consumer %1 : memref<?x?xf32> ...
+/// ```
+///
+/// may fold into:
+///
+/// ```
+///   consumer %0 ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
+/// ```
+bool canFoldIntoConsumerOp(MemRefCastOp castOp);
 
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_IR_STANDARDOPS_IR_OPS_H
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2606,6 +2606,7 @@
   }];
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -44,82 +44,16 @@
 template <typename NamedStructuredOpType>
 static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op);
 
-/// Determines whether it is possible to fold it away in the parent Linalg op:
-///
-/// ```mlir
-///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
-///   %2 = linalg.slice %1 ... : memref<?x?xf32> ...
-///   // or
-///   %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
-///          to memref<?x?xf32>
-///   linalg.generic(%1 ...) : memref<?x?xf32> ...
-/// ```
-///
-/// into
-///
-/// ```mlir
-///   %2 = linalg.slice %0 ... : memref<8x16xf32> ...
-///   // or
-///   linalg.generic(%0 ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
-/// ```
-///
-static bool canFold(MemRefCastOp castOp) {
-  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
-  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
-
-  // If we don't have MemRefType as source and destination, bail out.
-  if (!sourceType || !resultType)
-    return false;
-
-  // If resultType has a map, it needs to be the same as the source type to
-  // canonicalize.
-  if (!resultType.getAffineMaps().empty() &&
-      sourceType.getAffineMaps() != resultType.getAffineMaps())
-    return false;
-
-  // Ensure that:
-  //   1. source is static
-  //   2. source and target have the same rank (will be extended when needed)
-  //   3. if result is partially static, ensure sizes match.
-  if (!sourceType.hasStaticShape() ||
-      sourceType.getRank() != resultType.getRank())
-    return false;
-
-  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
-    auto sourceSize = std::get<0>(it);
-    auto resultSize = std::get<1>(it);
-    if (ShapedType::isDynamic(resultSize))
-      continue;
-    if (sourceSize != resultSize)
-      return false;
-  }
-
-  // If source has a map, it can only canonicalize if it is the canonical
-  // strided layout map.
-  if (sourceType.getAffineMaps().empty())
-    return true;
-
-  int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  auto res = getStridesAndOffset(sourceType, strides, offset);
-  (void)res;
-  assert(succeeded(res));
-  auto stridedMap =
-      makeStridedLinearLayoutMap(strides, offset, castOp.getContext());
-  AffineMap sourceMap = sourceType.getAffineMaps().front();
-  return sourceMap == stridedMap;
-}
-
 /// This is a common class used for patterns of the form
 /// ```
 ///    someop(memrefcast) -> someop
 /// ```
-/// It folds the source of any memref_cast into the root operation directly.
+/// It folds the source of the memref_cast into the root operation directly.
 static LogicalResult foldMemRefCast(Operation *op) {
   bool folded = false;
   for (OpOperand &operand : op->getOpOperands()) {
     auto castOp = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
-    if (castOp && canFold(castOp)) {
+    if (castOp && canFoldIntoConsumerOp(castOp)) {
       operand.set(castOp.getOperand());
       folded = true;
     }
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2519,6 +2519,111 @@
 } // end anonymous namespace
 
+/// Determines whether MemRefCastOp casts to a more dynamic version of the
+/// source memref. This is useful to fold a memref_cast into a consuming op
+/// and implement canonicalization patterns for ops in different dialects that
+/// may consume the results of memref_cast operations. Such foldable memref_cast
+/// operations are typically inserted as `view` and `subview` ops are
+/// canonicalized, to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+/// 1. source and result are ranked memrefs with strided semantics and same
+/// element type and rank.
+/// 2. each of the source's size, offset or stride has more static information
+/// than the corresponding result's size, offset or stride.
+///
+/// Example 1:
+/// ```mlir
+///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
+///   %2 = consumer %1 ... : memref<?x?xf32> ...
+/// ```
+///
+/// may fold into:
+///
+/// ```mlir
+///   %2 = consumer %0 ... : memref<8x16xf32> ...
+/// ```
+///
+/// Example 2:
+/// ```
+///   %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
+///          to memref<?x?xf32>
+///   consumer %1 : memref<?x?xf32> ...
+/// ```
+///
+/// may fold into:
+///
+/// ```
+///   consumer %0 ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
+/// ```
+bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
+  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
+  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
+
+  // Requires ranked MemRefType.
+  if (!sourceType || !resultType)
+    return false;
+
+  // Requires same elemental type.
+  if (sourceType.getElementType() != resultType.getElementType())
+    return false;
+
+  // Requires same rank.
+  if (sourceType.getRank() != resultType.getRank())
+    return false;
+
+  // Only fold casts between strided memref forms.
+  int64_t sourceOffset, resultOffset;
+  SmallVector<int64_t, 4> sourceStrides, resultStrides;
+  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
+      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+    return false;
+
+  // If cast is towards more static sizes along any dimension, don't fold.
+  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
+    auto ss = std::get<0>(it), st = std::get<1>(it);
+    if (ss != st)
+      if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
+        return false;
+  }
+
+  // If cast is towards a more static offset, don't fold.
+  if (sourceOffset != resultOffset)
+    if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
+        !MemRefType::isDynamicStrideOrOffset(resultOffset))
+      return false;
+
+  // If cast is towards more static strides along any dimension, don't fold.
+  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
+    auto ss = std::get<0>(it), st = std::get<1>(it);
+    if (ss != st)
+      if (MemRefType::isDynamicStrideOrOffset(ss) &&
+          !MemRefType::isDynamicStrideOrOffset(st))
+        return false;
+  }
+
+  return true;
+}
+
+OpFoldResult SubViewOp::fold(ArrayRef<Attribute>) {
+  auto folds = [](Operation *op) {
+    bool folded = false;
+    for (OpOperand &operand : op->getOpOperands()) {
+      auto castOp =
+          dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
+      if (castOp && canFoldIntoConsumerOp(castOp)) {
+        operand.set(castOp.getOperand());
+        folded = true;
+      }
+    }
+    return folded ? success() : failure();
+  };
+
+  if (succeeded(folds(*this)))
+    return getResult();
+  return {};
+}
+
 void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
   results.insert<SubViewOpShapeFolder, SubViewOpStrideFolder,
                  SubViewOpOffsetFolder>(context);
 }
+
+// -----
+
+// CHECK-LABEL: func @memref_cast_folding_subview
+func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<?x?xf32, offset: ?, strides: [?, ?]>) {
+  %0 = memref_cast %arg0 : memref<4x5xf32> to memref<?x?xf32>
+  // CHECK-NEXT: subview %{{.*}}: memref<4x5xf32>
+  %1 = subview %0[][%i,%i][]: memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+  // CHECK-NEXT: return %{{.*}}
+  return %1: memref<?x?xf32, offset: ?, strides: [?, ?]>
+}
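For reference, a dialect outside this patch can reuse the new helper in its own fold hook exactly the way the Linalg `foldMemRefCast` helper and the `SubViewOp::fold` lambda above do. The snippet below is a minimal illustrative sketch, not part of the diff: `MyConsumerOp` is a hypothetical single-result op with memref operands, and only `canFoldIntoConsumerOp` / `MemRefCastOp` come from the patch itself.

```c++
// Sketch only: `MyConsumerOp` is an assumed placeholder op in a downstream
// dialect; the folding logic mirrors the SubViewOp::fold added in this patch.
#include "mlir/Dialect/StandardOps/IR/Ops.h"

using namespace mlir;

OpFoldResult MyConsumerOp::fold(ArrayRef<Attribute> /*operands*/) {
  bool folded = false;
  for (OpOperand &operand : getOperation()->getOpOperands()) {
    // Look through a memref_cast that only erased static information; the
    // consumer keeps its own (more dynamic) types, so replacing the operand
    // in place is type-safe.
    auto castOp =
        dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
    if (castOp && canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  // Returning the op's own result signals an in-place fold to the driver;
  // returning a null OpFoldResult means nothing was folded.
  return folded ? OpFoldResult(getResult()) : OpFoldResult();
}
```

The same pattern is what `foldMemRefCast` applies across all operands of the Linalg structured ops in this diff; the helper only guarantees that the cast goes from a more static to a more dynamic type, so dropping it never loses type information needed by the consumer.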