diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1520,6 +1520,35 @@
 }
 
 namespace {
+/// Fold dim of a 1-D alloca whose dynamic size is produced by a rank op into
+/// that rank op, i.e. dim(alloca(rank(%v)), %idx) -> rank(%v).
+struct DimOf1DAllocaFromRank : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dim,
+                                PatternRewriter &rewriter) const override {
+    auto alloca = dim.memrefOrTensor().getDefiningOp<AllocaOp>();
+    if (!alloca)
+      return failure();
+
+    // Only a dynamic leading dimension is backed by a size operand; a
+    // statically shaped 1-D alloca such as `alloca() : memref<4xf32>` has
+    // none, and reading operand 0 would be out of range.
+    auto type = alloca.getType().dyn_cast<MemRefType>();
+    if (!type || type.getRank() != 1 || !type.isDynamicDim(0))
+      return failure();
+
+    // NOTE: assumes dynamic sizes precede any symbol operands, so operand 0
+    // is the size of the (single) dynamic dimension.
+    auto rank = alloca.getOperand(0).getDefiningOp<RankOp>();
+    if (!rank)
+      return failure();
+
+    rewriter.replaceOp(dim, rank.getResult());
+    return success();
+  }
+};
+
 /// Fold dim of a memref reshape operation to a load into the reshape's shape
 /// operand.
 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
@@ -1560,8 +1589,9 @@
 
 void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
-  results.insert<DimOfMemRefReshape, DimOfCastOp<TensorLoadOp>,
-                 DimOfCastOp<TensorToMemrefOp>>(context);
+  results.insert<DimOf1DAllocaFromRank, DimOfMemRefReshape,
+                 DimOfCastOp<TensorLoadOp>, DimOfCastOp<TensorToMemrefOp>>(
+      context);
 }
 
 // ---------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -134,6 +134,37 @@
 
 // -----
 
+// Test case: Folding of dim(alloca(rank(%v)), %idx) -> rank(%v)
+// CHECK-LABEL: func @dim_of_1d_alloca_from_rank(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>
+//  CHECK-NEXT:   %[[RANK:.*]] = rank %[[MEM]] : memref<*xf32>
+//  CHECK-NEXT:   return %[[RANK]] : index
+func @dim_of_1d_alloca_from_rank(%arg0: memref<*xf32>)
+    -> index {
+  %0 = rank %arg0 : memref<*xf32>
+  %1 = alloca(%0) : memref<?xindex>
+  %c0 = constant 0 : index
+  %2 = dim %1, %c0 : memref<?xindex>
+  return %2 : index
+}
+
+// -----
+
+// Test case: dim of a statically shaped 1-D alloca is not rewritten by the
+// alloca-from-rank pattern (the alloca has no size operand).
+// CHECK-LABEL: func @dim_of_static_1d_alloca(
+//  CHECK-SAME:     %[[IDX:[0-9a-z]+]]: index
+//  CHECK-NEXT:   %[[MEM:.*]] = alloca() : memref<4xf32>
+//  CHECK-NEXT:   %[[DIM:.*]] = dim %[[MEM]], %[[IDX]] : memref<4xf32>
+//  CHECK-NEXT:   return %[[DIM]] : index
+func @dim_of_static_1d_alloca(%arg0: index) -> index {
+  %0 = alloca() : memref<4xf32>
+  %1 = dim %0, %arg0 : memref<4xf32>
+  return %1 : index
+}
+
+// -----
+
 // Test case: Folding of dim(memref_reshape %v %shp, %idx) -> load %shp[%idx]
 // CHECK-LABEL: func @dim_of_memref_reshape(
 //  CHECK-SAME:   %[[MEM:[0-9a-z]+]]: memref<*xf32>,