diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1881,12 +1881,98 @@
   }
 };
 
+/// Returns true if both types are ranked shaped types of the same rank.
+static bool isSameRank(Type type1, Type type2) {
+  auto shaped1 = type1.dyn_cast<ShapedType>();
+  if (!shaped1)
+    return false;
+
+  auto shaped2 = type2.dyn_cast<ShapedType>();
+  if (!shaped2)
+    return false;
+
+  if (!shaped1.hasRank() || !shaped2.hasRank())
+    return false;
+
+  return shaped1.getRank() == shaped2.getRank();
+}
+
+/// Returns true if every entry of `values` is a known constant integer (as
+/// reported by getConstantIntValue) equal to `expectedVal`. An entry that is
+/// not a known constant makes the result false.
+static bool isMixedValuesEqual(ArrayRef<OpFoldResult> values,
+                               int64_t expectedVal) {
+  for (OpFoldResult val : values) {
+    auto intVal = getConstantIntValue(val);
+    if (!intVal || *intVal != expectedVal)
+      return false;
+  }
+  return true;
+}
+
+/// If `memref` is produced by a memref.subview that keeps the rank of its
+/// source and has all offsets equal to 0 and all strides equal to 1, returns
+/// that subview's source; otherwise returns a null Value. For such a subview
+/// an index tuple names the same element in the subview and in its source, so
+/// loads and stores through it can address the source with unchanged indices.
+static Value getIdentitySubViewSource(Value memref) {
+  auto subView = memref.getDefiningOp<SubViewOp>();
+  if (!subView)
+    return Value();
+
+  // A rank-changing subview is not indexed with the same number of indices
+  // as its source, so the indices cannot be forwarded unchanged.
+  if (!isSameRank(subView.source().getType(), subView.getType()))
+    return Value();
+
+  if (!isMixedValuesEqual(subView.getMixedOffsets(), 0) ||
+      !isMixedValuesEqual(subView.getMixedStrides(), 1))
+    return Value();
+
+  return subView.source();
+}
+
+namespace {
+/// Forwards `memref.load` through an identity-indexed subview (see
+/// getIdentitySubViewSource) so that it reads directly from the source.
+struct SubviewLoadPropagate : public OpRewritePattern<LoadOp> {
+  using OpRewritePattern<LoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(LoadOp op,
+                                PatternRewriter &rewriter) const override {
+    Value source = getIdentitySubViewSource(op.memref());
+    if (!source)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<LoadOp>(op, source, op.indices());
+    return success();
+  }
+};
+
+/// Forwards `memref.store` through an identity-indexed subview (see
+/// getIdentitySubViewSource) so that it writes directly to the source.
+struct SubviewStorePropagate : public OpRewritePattern<StoreOp> {
+  using OpRewritePattern<StoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(StoreOp op,
+                                PatternRewriter &rewriter) const override {
+    Value source = getIdentitySubViewSource(op.memref());
+    if (!source)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<StoreOp>(op, op.value(), source, op.indices());
+    return success();
+  }
+};
+} // namespace
+
 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
   results
       .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
                SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
-           SubViewOpMemRefCastFolder>(context);
+           SubViewOpMemRefCastFolder, SubviewLoadPropagate,
+           SubviewStorePropagate>(context);
 }
 
 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -395,3 +395,60 @@
   memref.store %0, %arg0[] : memref<memref<?xf32>>
   return
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+func @subview_load_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : index) -> f32
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = memref.subview %arg0[%c0, %c0, %c0] [%c4, %c1, %arg1] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?x?xf32, #map0>
+  %1 = memref.load %0[%arg2, %c0, %c1] : memref<?x?x?xf32, #map0>
+  return %1 : f32
+}
+// CHECK-LABEL: func @subview_load_canonicalize
+// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x?xf32>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[RESULT:.+]] = memref.load %[[ARG0]][%[[ARG2]], %[[C0]], %[[C1]]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+func @subview_store_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : index)
+{
+  %cst42 = constant 0.0 : f32
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %0 = memref.subview %arg0[%c0, %c0, %c0] [%c4, %c1, %arg1] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?x?xf32, #map0>
+  memref.store %cst42, %0[%arg2, %c0, %c1] : memref<?x?x?xf32, #map0>
+  return
+}
+// CHECK-LABEL: func @subview_store_canonicalize
+// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x?xf32>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
+// CHECK-DAG: %[[VAL:.+]] = constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK: memref.store %[[VAL]], %[[ARG0]][%[[ARG2]], %[[C0]], %[[C1]]]
+// CHECK: return
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+func @subview_load_dynamic_offset_no_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : index) -> f32
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = memref.subview %arg0[%arg1, %c0, %c0] [%arg2, %arg2, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?x?xf32, #map0>
+  %1 = memref.load %0[%c0, %c0, %c0] : memref<?x?x?xf32, #map0>
+  return %1 : f32
+}
+// CHECK-LABEL: func @subview_load_dynamic_offset_no_canonicalize
+// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x?xf32>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]]
+// CHECK: %[[RESULT:.+]] = memref.load %[[SUBVIEW]]
+// CHECK: return %[[RESULT]]
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -721,6 +721,9 @@
 
 // -----
 
+// Opaque call used to keep loads/stores from folding into the subview source.
+func private @break_chain(%arg0: memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>) -> memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
+
 // CHECK-DAG: #[[$BASE_MAP0:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>
 // CHECK-DAG: #[[$SUBVIEW_MAP0:map[0-9]+]] = affine_map<(d0, d1, d2)[s0] -> (d0 * 64 + s0 + d1 * 4 + d2)>
 // CHECK-DAG: #[[$SUBVIEW_MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 79)>
@@ -763,7 +766,8 @@
   %1 = memref.subview %0[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1]
     : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to
       memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
-  %v0 = memref.load %1[%c0, %c0, %c0] : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
+  %temp1 = call @break_chain(%1) : (memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>) -> (memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>)
+  %v0 = memref.load %temp1[%c0, %c0, %c0] : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
 
   // Test: subview with one dynamic operand can also be folded.
   // CHECK: memref.subview %[[ALLOC0]][0, %[[ARG0]], 0] [7, 11, 15] [1, 1, 1] :
@@ -772,7 +776,8 @@
   %2 = memref.subview %0[%c0, %arg0, %c0] [%c7, %c11, %c15] [%c1, %c1, %c1]
     : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to
       memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
-  memref.store %v0, %2[%c0, %c0, %c0] : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
+  %temp2 = call @break_chain(%2) : (memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>) -> (memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>)
+  memref.store %v0, %temp2[%c0, %c0, %c0] : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
 
   // CHECK: %[[ALLOC1:.*]] = memref.alloc(%[[ARG0]])
   %3 = memref.alloc(%arg0) : memref<?x16x4xf32, offset : 0, strides : [64, 4, 1]>
@@ -783,7 +788,8 @@
   %4 = memref.subview %3[%c0, %c0, %c0] [%c7, %c11, %c15] [%c1, %c1, %c1]
     : memref<?x16x4xf32, offset : 0, strides : [64, 4, 1]> to
       memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
-  memref.store %v0, %4[%c0, %c0, %c0] : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
+  %temp4 = call @break_chain(%4) : (memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>) -> (memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>)
+  memref.store %v0, %temp4[%c0, %c0, %c0] : memref<?x?x?xf32, offset : ?, strides : [?, ?, ?]>
 
   // Test: subview offset operands are folded correctly w.r.t. base strides.
   // CHECK: memref.subview %[[ALLOC0]][1, 2, 7] [7, 11, 2] [1, 1, 1] :