diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -892,6 +892,8 @@
     /// The result of a memref_cast is always a memref.
     Type getType() { return getResult().getType(); }
   }];
+
+  let hasCanonicalizer = 1;
 }

 def MulFOp : FloatArithmeticOp<"mulf"> {
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -1627,6 +1627,9 @@
          (a.isa<IntegerType>() && b.isIndex());
 }

+void IndexCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {}
+
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
@@ -1744,6 +1747,77 @@
   return false;
 }

+namespace {
+
+/// Canonicalize either:
+///
+/// ```mlir
+///   ... = memref_cast ... : memref<8x16xf32> to memref<?x?xf32>
+///   ... = memref_cast ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
+///         to memref<?x?xf32>
+/// ```
+///
+/// into
+///
+/// ```mlir
+///   ... = memref_cast ... : ... to memref<8x16xf32>
+/// ```
+class MemRefCastToDynamicFolder final : public OpRewritePattern<MemRefCastOp> {
+public:
+  using OpRewritePattern<MemRefCastOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(MemRefCastOp castOp,
+                                     PatternRewriter &rewriter) const override {
+    MemRefType sourceType = castOp.source().getType().cast<MemRefType>();
+    AffineMap sourceMap = (sourceType.getAffineMaps().empty())
+                              ? AffineMap()
+                              : sourceType.getAffineMaps().front();
+    MemRefType resultType = castOp.getType().cast<MemRefType>();
+    MemRefType desiredResultType =
+        MemRefType::get(sourceType.getShape(), sourceType.getElementType(), {});
+
+    // If we're already in canonical form, all is good.
+    if (resultType == desiredResultType)
+      return matchFailure();
+
+    // If resultType has a map, it needs to be the same as the source type to
+    // canonicalize.
+    if (!resultType.getAffineMaps().empty() && sourceType != resultType)
+      return matchFailure();
+
+    // Ensure that:
+    //   1. source is static
+    //   2. source and target have the same rank (will be extended when needed)
+    //   3. if result is partially static, ensure sizes match.
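+    // For example, a memref<8x16xf32> source can be canonicalized when the
+    // result is memref<8x?xf32>, but not when it is memref<4x?xf32>.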
+    if (!sourceType.hasStaticShape() ||
+        sourceType.getRank() != resultType.getRank())
+      return matchFailure();
+    for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
+      auto sourceSize = std::get<0>(it);
+      auto resultSize = std::get<1>(it);
+      if (ShapedType::isDynamic(resultSize))
+        continue;
+      if (sourceSize != resultSize)
+        return matchFailure();
+    }
+
+    // Note: sourceMap is only tested here; the rewrite below always produces
+    // an identity-layout result type (see the examples in the class comment).
+    if (sourceMap) {
+    }
+
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(castOp, castOp.source(),
+                                              desiredResultType);
+
+    return matchSuccess();
+  }
+};
+
+} // end anonymous namespace
+
+void MemRefCastOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<MemRefCastToDynamicFolder>(context);
+}
+
 OpFoldResult MemRefCastOp::fold(ArrayRef<Attribute> operands) {
   return impl::foldCastOp(*this);
 }
@@ -1970,6 +2044,9 @@
   return a.isa<IntegerType>() && b.isa<FloatType>();
 }

+void SIToFPOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                           MLIRContext *context) {}
+
 //===----------------------------------------------------------------------===//
 // SelectOp
 //===----------------------------------------------------------------------===//
@@ -2234,6 +2311,9 @@
   return impl::foldCastOp(*this);
 }

+void TensorCastOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {}
+
 //===----------------------------------------------------------------------===//
 // Helpers for Tensor[Load|Store]Op
 //===----------------------------------------------------------------------===//
@@ -2987,6 +3067,9 @@
   return false;
 }

+void FPExtOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context) {}
+
 //===----------------------------------------------------------------------===//
 // FPTruncOp
 //===----------------------------------------------------------------------===//
@@ -2998,6 +3081,9 @@
   return false;
 }

+void FPTruncOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context) {}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -323,14 +323,16 @@
 }

 // CHECK-LABEL: func @alloc_const_fold
-func @alloc_const_fold() -> memref<?xf32> {
-  // CHECK-NEXT: %0 = alloc() : memref<4xf32>
+func @alloc_const_fold() -> f32 {
+  // CHECK: constant 0 : index
+  %c0 = constant 0 : index
   %c4 = constant 4 : index
+  // CHECK-NEXT: %0 = alloc() : memref<4xf32>
   %a = alloc(%c4) : memref<?xf32>

-  // CHECK-NEXT: %1 = memref_cast %0 : memref<4xf32> to memref<?xf32>
-  // CHECK-NEXT: return %1 : memref<?xf32>
-  return %a : memref<?xf32>
+  // CHECK-NEXT: load {{.*}} : memref<4xf32>
+  %1 = load %a[%c0] : memref<?xf32>
+  return %1 : f32
 }

 // CHECK-LABEL: func @dead_alloc_fold
@@ -384,7 +386,7 @@
 }

 // CHECK-LABEL: func @dyn_shape_fold(%arg0: index, %arg1: index)
-func @dyn_shape_fold(%L : index, %M : index) -> (memref<?x?xi32>, memref<?x?xf32>) {
+func @dyn_shape_fold(%L : index, %M : index) -> (i32, f32) {
   // CHECK: %c0 = constant 0 : index
   %zero = constant 0 : index
   // The constants below disappear after they propagate into shapes.
@@ -415,7 +417,12 @@
   // CHECK: alloc() : memref<9x9xf32>
   %d = alloc(%nine, %nine) : memref<?x?xf32>

-  return %c, %d : memref<?x?xi32>, memref<?x?xf32>
+  // Avoid DCE.
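+  // (Without these loads the allocs above would be dead and erased, leaving
+  // nothing for the CHECK lines on the folded static shapes to match.)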
+  %c0 = constant 0 : index
+  %vc = load %c[%c0, %c0] : memref<?x?xi32>
+  %vd = load %d[%c0, %c0] : memref<?x?xf32>
+
+  return %vc, %vd : i32, f32
 }

 #map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
@@ -496,7 +503,7 @@
 }

 // CHECK-LABEL: func @const_fold_propagate
-func @const_fold_propagate() -> memref<?x?xf32> {
+func @const_fold_propagate() -> f32 {
   %VT_i = constant 512 : index

   %VT_i_s = affine.apply affine_map<(d0) -> (d0 floordiv 8)> (%VT_i)
@@ -504,7 +511,12 @@

   // CHECK: = alloc() : memref<64x32xf32>
   %Av = alloc(%VT_i_s, %VT_k_l) : memref<?x?xf32>
-  return %Av : memref<?x?xf32>
+
+  // Avoid DCE.
+  %c0 = constant 0 : index
+  %res = load %Av[%c0, %c0] : memref<?x?xf32>
+
+  return %res : f32
 }

 // CHECK-LABEL: func @br_folding
@@ -878,3 +890,16 @@
   // CHECK: return %[[C7]], %[[C11]]
   return %7, %8 : index, index
 }
+
+// CHECK-LABEL: func @memref_cast_canonicalize
+func @memref_cast_canonicalize(%arg0: memref<4x8xf32>) -> f32 {
+  // CHECK: constant 0 : index
+  %c0 = constant 0 : index
+  // Goes away
+  %0 = memref_cast %arg0 : memref<4x8xf32> to memref<?x?xf32>
+
+  // CHECK-NEXT: load {{.*}} : memref<4x8xf32>
+  %1 = load %0[%c0, %c0] : memref<?x?xf32>
+  return %1 : f32
+}
+
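For reference, the IR that the new `@memref_cast_canonicalize` test expects after running the canonicalizer (e.g. via `mlir-opt -canonicalize`) looks roughly as follows. This is a sketch inferred from the CHECK lines above, not output taken from the patch; exact SSA value names may differ. `MemRefCastToDynamicFolder` first retypes the cast to the static `memref<4x8xf32>`, and the existing `MemRefCastOp::fold` then erases the now same-type cast, so the load reads directly from `%arg0`.

```mlir
// Sketch of the canonicalized form checked by @memref_cast_canonicalize.
func @memref_cast_canonicalize(%arg0: memref<4x8xf32>) -> f32 {
  %c0 = constant 0 : index
  // The memref_cast has been folded away; the load now uses the static type.
  %0 = load %arg0[%c0, %c0] : memref<4x8xf32>
  return %0 : f32
}
```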