diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -632,6 +632,12 @@ } resultIndex++; } + // If the source index is greater than or equal to the number of static + // sizes specified, then the semantics of the `SubViewOp` is that the size is + // the same as the size of the source. Do nothing here, and rely on subview + // folding/canonicalization to remove the subview itself. + if (sourceIndex >= subview.static_sizes().size()) + return {}; assert(subview.isDynamicSize(sourceIndex) && "expected dynamic subview size"); return subview.getDynamicSize(sourceIndex); @@ -1913,6 +1919,29 @@ return success(); } }; + +/// Fold a `memref.subview` with empty offsets, sizes and strides to its source, +/// inserting a `memref.cast` when the result and source types differ. +class TrivialSubViewOpFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SubViewOp subViewOp, + PatternRewriter &rewriter) const override { + if (!subViewOp.static_offsets().empty() || + !subViewOp.static_sizes().empty() || + !subViewOp.static_strides().empty()) + return failure(); + Value source = subViewOp.source(); + if (subViewOp.getSourceType() == subViewOp.getType()) { + rewriter.replaceOp(subViewOp, source); + return success(); + } + rewriter.replaceOpWithNewOp(subViewOp, subViewOp.getType(), source); + return success(); + } +}; + } // namespace /// Return the canonical type of the result of a subview. @@ -1938,7 +1967,7 @@ results .add, - SubViewOpMemRefCastFolder>(context); + SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context); } OpFoldResult SubViewOp::fold(ArrayRef operands) { @@ -1949,6 +1978,10 @@ resultShapedType == sourceShapedType) { return getViewSource(); } + // If the offsets, sizes and strides are empty, the subview is a no-op.
+ if (static_offsets().empty() && static_sizes().empty() && + static_strides().empty() && getType() == getSourceType()) + return getViewSource(); return {}; } diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -417,3 +417,14 @@ // CHECK: %[[RESULT:.+]] = memref.subview // CHECK-SAME: memref<2x5x7x1xf32> to memref<1x4x1xf32, #{{.+}}> // CHECK: return %[[RESULT]] + +// ----- + +func @no_op_scalar_subview(%arg0 : memref) -> memref (0)>> { + %0 = memref.subview %arg0[] [] [] : memref to memref (0)>> + return %0 : memref (0)>> +} +// CHECK: func @no_op_scalar_subview( +// CHECK-SAME: %[[ARG0:.+]]: memref +// CHECK: %[[CAST:.+]] = memref.cast %[[ARG0]] +// CHECK: return %[[CAST]] diff --git a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir --- a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir @@ -251,3 +251,18 @@ // CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]](%[[ARG7]])[%[[ARG3]]] // CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[IDX0]], %[[IDX1]], %[[C0]]] // CHECK-SAME: {in_bounds = [true], permutation_map = #[[MAP2]]} : vector<4xf32>, memref (d0 * s1 + s0 + d1)> +func @no_op_subview(%arg0 : memref) -> index { + %c1 = arith.constant 1 : index + %0 = memref.subview %arg0[] [] [] : memref to memref + %1 = memref.dim %0, %c1 : memref + return %1 : index +} +// CHECK: func @no_op_subview( +// CHECK-SAME: %[[ARG0:.+]]: memref