diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1753,6 +1753,7 @@
     Optional<int64_t> getConstantIndex();
   }];
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1531,6 +1531,58 @@
   return {};
 }
 
+namespace {
+/// Fold dim of a memref reshape operation to a load into the reshape's shape
+/// operand.
+///
+/// The load is inserted directly after the reshape, so the pattern only
+/// applies when the dim's index operand is already available at that point;
+/// otherwise the rewrite would create a use not dominated by its definition.
+struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto reshape = dim.memrefOrTensor().getDefiningOp<MemRefReshapeOp>();
+
+    if (!reshape)
+      return failure();
+
+    // The index must dominate the reshape. Check this cheaply instead of
+    // computing DominanceInfo:
+    //   1. An index defined in the reshape's block must be a block argument
+    //      or be produced by an op that precedes the reshape.
+    //   2. An index defined in another block dominates the reshape if the dim
+    //      lives in the reshape's block (the index dominates the dim but is
+    //      defined outside that block), or if it is defined in a proper
+    //      ancestor region of the reshape.
+    Value index = dim.index();
+    Block *reshapeBlock = reshape->getBlock();
+    if (index.getParentBlock() == reshapeBlock) {
+      Operation *indexDef = index.getDefiningOp();
+      if (indexDef && reshape->isBeforeInBlock(indexDef))
+        return failure();
+    } else if (dim->getBlock() != reshapeBlock &&
+               !index.getParentRegion()->isProperAncestor(
+                   reshape->getParentRegion())) {
+      return failure();
+    }
+
+    // Place the load directly after the reshape to ensure that the shape memref
+    // was not mutated.
+    rewriter.setInsertionPointAfter(reshape);
+    rewriter.replaceOpWithNewOp<LoadOp>(dim, reshape.shape(),
+                                        llvm::makeArrayRef({index}));
+    return success();
+  }
+};
+} // end anonymous namespace.
+
+void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                        MLIRContext *context) {
+  results.insert<DimOfMemRefReshape>(context);
+}
+
 // ---------------------------------------------------------------------------
 // DmaStartOp
 // ---------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -59,3 +59,23 @@
   %1 = dim %0, %c3 : tensor<2x?x4x?x5xindex>
   return %1 : index
 }
+
+// Test case: Folding of dim(memref_reshape %v %shp, %idx) -> load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xindex>
+//  CHECK-NEXT:   %[[IDX:.*]] = constant 3
+//  CHECK-NEXT:   %[[DIM:.*]] = load %[[SHP]][%[[IDX]]]
+//  CHECK-NEXT:   store
+//   CHECK-NOT:   dim
+//       CHECK:   return %[[DIM]] : index
+func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
+  -> index {
+  %c3 = constant 3 : index
+  %0 = memref_reshape %arg0(%arg1)
+      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+  // Update the shape to test that the load ends up in the right place.
+  store %c3, %arg1[%c3] : memref<?xindex>
+  %1 = dim %0, %c3 : memref<*xf32>
+  return %1 : index
+}