diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1520,6 +1520,40 @@
 }
 
 namespace {
+/// Fold `dim` of a rank-1 memref produced by an allocation-like op
+/// (`AllocLikeOp` is instantiated with `AllocOp` and `AllocaOp`). The dim is
+/// replaced by the allocation's single dynamic size operand if the extent is
+/// dynamic, or by a constant holding the static extent otherwise.
+///
+/// The index operand is not inspected: the only in-bounds index of a rank-1
+/// memref is 0.
+template <typename AllocLikeOp>
+struct DimOf1DAllocLikeOp : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto alloc = dim.memrefOrTensor().getDefiningOp<AllocLikeOp>();
+    if (!alloc)
+      return failure();
+
+    // Use `dyn_cast` so that the null check below is meaningful; `cast` would
+    // assert instead of returning a null type.
+    auto type = alloc.getType().template dyn_cast<MemRefType>();
+    if (!type || type.getRank() != 1)
+      return failure();
+
+    // A rank-1 memref has at most one dynamic size operand, corresponding to
+    // dimension 0.
+    if (alloc.dynamicSizes().size() == 1)
+      rewriter.replaceOp(dim, {alloc.dynamicSizes()[0]});
+    else
+      rewriter.replaceOpWithNewOp<ConstantIndexOp>(dim, type.getShape()[0]);
+
+    return success();
+  }
+};
+
 /// Fold dim of a memref reshape operation to a load into the reshape's shape
 /// operand.
 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
@@ -1560,7 +1594,8 @@
 
 void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
-  results.insert<DimOfMemRefReshape, DimOfCastOp<TensorToMemrefOp>,
+  results.insert<DimOf1DAllocLikeOp<AllocOp>, DimOf1DAllocLikeOp<AllocaOp>,
+                 DimOfMemRefReshape, DimOfCastOp<TensorToMemrefOp>,
                  DimOfCastOp<TensorLoadOp>>(context);
 }
 
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -134,6 +134,35 @@
 
 // -----
 
+// Test case: Folding of dim(alloc() : memref<10xindex>, %c0) -> constant 10
+// CHECK-LABEL: func @dim_of_1d_alloc_with_static_shape
+// CHECK-NEXT: %[[C10:.*]] = constant 10 : index
+// CHECK-NEXT: return %[[C10]] : index
+func @dim_of_1d_alloc_with_static_shape() -> index {
+  %0 = alloc() : memref<10xindex>
+  %c0 = constant 0 : index
+  %1 = dim %0, %c0 : memref<10xindex>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of dim(alloca(rank(%v)), %idx) -> rank(%v)
+// CHECK-LABEL: func @dim_of_1d_alloca_from_rank(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>
+// CHECK-NEXT: %[[RANK:.*]] = rank %[[MEM]] : memref<*xf32>
+// CHECK-NEXT: return %[[RANK]] : index
+func @dim_of_1d_alloca_from_rank(%arg0: memref<*xf32>)
+    -> index {
+  %0 = rank %arg0 : memref<*xf32>
+  %1 = alloca(%0) : memref<?xindex>
+  %c0 = constant 0 : index
+  %2 = dim %1, %c0 : memref<?xindex>
+  return %2 : index
+}
+
+// -----
+
 // Test case: Folding of dim(memref_reshape %v %shp, %idx) -> load %shp[%idx]
 // CHECK-LABEL: func @dim_of_memref_reshape(
 // CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,