diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -892,6 +892,8 @@
     /// The result of a memref_cast is always a memref.
     Type getType() { return getResult().getType(); }
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def MulFOp : FloatArithmeticOp<"mulf"> {
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -1744,6 +1744,103 @@
   return false;
 }
 
+namespace {
+
+/// Canonicalize either:
+///
+/// ```mlir
+/// ... = memref_cast ... : memref<8x16xf32> to memref<?x?xf32>
+/// ... = memref_cast ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
+///       to memref<?x?xf32>
+/// ```
+///
+/// into
+///
+/// ```mlir
+/// ... = memref_cast ... : ... to memref<8x16xf32>
+/// ```
+///
+/// One big caveat is that this canonicalization changes the result type and
+/// cannot work in a generic fashion at the moment. As a consequence, the
+/// pattern inspects uses and bails out if it hits an op it does not know
+/// about.
+// TODO(ntv, rriddle): this may be a good use case for an OpInterface.
+class MemRefCastToDynamicFolder final : public OpRewritePattern<MemRefCastOp> {
+public:
+  using OpRewritePattern<MemRefCastOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(MemRefCastOp castOp,
+                                     PatternRewriter &rewriter) const override {
+    MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
+    MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
+
+    // If we don't have MemRefType as source and destination, bail out.
+    if (!sourceType || !resultType)
+      return matchFailure();
+
+    AffineMap sourceMap = sourceType.getAffineMaps().empty()
+                              ? AffineMap()
+                              : sourceType.getAffineMaps().front();
+    MemRefType desiredResultType =
+        MemRefType::get(sourceType.getShape(), sourceType.getElementType(), {});
+
+    // If we're already in canonical form, all is good.
+    if (resultType == desiredResultType)
+      return matchFailure();
+
+    // If resultType has a layout map, it needs to be the same as the source
+    // type to canonicalize.
+    if (!resultType.getAffineMaps().empty() && sourceType != resultType)
+      return matchFailure();
+
+    // Inspect uses and bail out on any op we do not know about.
+    // TODO(ntv, rriddle): this may be a good use case for an OpInterface.
+    for (auto &u : castOp.getResult().getUses()) {
+      if (isa<DimOp>(u.getOwner()) || isa<LoadOp>(u.getOwner()))
+        continue;
+      return matchFailure();
+    }
+
+    // Ensure that:
+    //   1. the source is static;
+    //   2. source and target have the same rank (will be extended when
+    //      needed);
+    //   3. if the result is partially static, the sizes match.
+    if (!sourceType.hasStaticShape() ||
+        sourceType.getRank() != resultType.getRank())
+      return matchFailure();
+    for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
+      auto sourceSize = std::get<0>(it);
+      auto resultSize = std::get<1>(it);
+      if (ShapedType::isDynamic(resultSize))
+        continue;
+      if (sourceSize != resultSize)
+        return matchFailure();
+    }
+
+    // If the source has a layout map, it must be the canonical strided form
+    // computed from the static shape.
+    if (sourceMap) {
+      int64_t offset;
+      SmallVector<int64_t, 4> strides;
+      auto res = getStridesAndOffset(sourceType, strides, offset);
+      (void)res;
+      assert(succeeded(res) && "expected strided memref");
+      auto stridedMap =
+          makeStridedLinearLayoutMap(strides, offset, sourceMap.getContext());
+      if (sourceMap != stridedMap)
+        return matchFailure();
+    }
+
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(castOp, castOp.source(),
+                                              desiredResultType);
+
+    return matchSuccess();
+  }
+};
+
+} // end anonymous namespace
+
+void MemRefCastOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<MemRefCastToDynamicFolder>(context);
+}
+
 OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
   return impl::foldCastOp(*this);
 }
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -878,3 +878,15 @@
   // CHECK: return %[[C7]], %[[C11]]
   return %7, %8 : index, index
 }
+
+// CHECK-LABEL: func @memref_cast_canonicalize
+func @memref_cast_canonicalize(%arg0: memref<4x8xf32>) -> f32 {
+  // CHECK: constant 0 : index
+  %c0 = constant 0 : index
+  // The cast to a fully dynamic memref goes away.
+  %0 = memref_cast %arg0 : memref<4x8xf32> to memref<?x?xf32>
+
+  // CHECK-NEXT: load {{.*}} : memref<4x8xf32>
+  %1 = load %0[%c0, %c0] : memref<?x?xf32>
+  return %1 : f32
+}
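
For illustration, here is roughly how the pattern acts on the strided-layout case from the doc comment. This is a hand-written sketch, not part of the patch: the `#strided` alias and the `@cast_strided` function name are made up, and the exact printed form of the canonicalized IR may differ.

```mlir
#strided = affine_map<(i, j) -> (16 * i + j)>

// Before canonicalization: the cast erases both the layout map and the
// static sizes.
func @cast_strided(%arg0: memref<8x16xf32, #strided>) -> f32 {
  %c0 = constant 0 : index
  %0 = memref_cast %arg0 : memref<8x16xf32, #strided> to memref<?x?xf32>
  %1 = load %0[%c0, %c0] : memref<?x?xf32>
  return %1 : f32
}

// After canonicalization: the source layout map (i, j) -> (16 * i + j) is
// exactly the strided form computed from the static 8x16 shape, so the
// pattern rewrites the result to the fully static memref type.
func @cast_strided(%arg0: memref<8x16xf32, #strided>) -> f32 {
  %c0 = constant 0 : index
  %0 = memref_cast %arg0 : memref<8x16xf32, #strided> to memref<8x16xf32>
  %1 = load %0[%c0, %c0] : memref<8x16xf32>
  return %1 : f32
}
```

In the map-free case (as in the new test), the rewritten cast is an identity cast from `memref<4x8xf32>` to `memref<4x8xf32>`, which the existing `fold` hook (via `impl::foldCastOp`) then removes entirely; with a layout map on the source, the cast itself remains but its users see the fully static type.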
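Conversely, a sketch (again with a made-up function name) of the use inspection forcing a bail-out: the cast result escapes through `return`, which is not an op the pattern knows about, and changing the result type would alter the enclosing function's signature.

```mlir
func @cast_escapes(%arg0: memref<4x8xf32>) -> memref<?x?xf32> {
  // Not rewritten: %0 is used by an op the pattern does not recognize.
  %0 = memref_cast %arg0 : memref<4x8xf32> to memref<?x?xf32>
  return %0 : memref<?x?xf32>
}
```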