diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1441,12 +1441,18 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) { auto index = operands[1].dyn_cast_or_null<IntegerAttr>(); + auto argTy = memrefOrTensor().getType(); - // All forms of folding require a known index. - if (!index) - return {}; + // All forms of folding require a known index or a known rank of 1 (in which + // case we can assume the index is 0 if the code is correct). + if (!index) { + auto shapedTy = argTy.dyn_cast<ShapedType>(); + if (!shapedTy || !shapedTy.hasRank() || shapedTy.getRank() != 1) + return {}; + + index = IntegerAttr::get(IndexType::get(getContext()), 0); + } - auto argTy = memrefOrTensor().getType(); // Fold if the shape extent along the given index is known. if (auto shapedTy = argTy.dyn_cast<ShapedType>()) { // Folding for unranked types (UnrankedMemRefType, UnrankedTensorType) is @@ -1502,6 +1508,10 @@ return *(alloc.getDynamicSizes().begin() + memrefType.getDynamicDimIndex(unsignedIndex)); + if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp)) + return *(alloca.getDynamicSizes().begin() + + memrefType.getDynamicDimIndex(unsignedIndex)); + if (auto view = dyn_cast_or_null<ViewOp>(definingOp)) return *(view.getDynamicSizes().begin() + memrefType.getDynamicDimIndex(unsignedIndex)); diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir --- a/mlir/test/Dialect/Standard/canonicalize.mlir +++ b/mlir/test/Dialect/Standard/canonicalize.mlir @@ -134,6 +134,36 @@ // ----- +// Test case: Folding of dim(alloca(rank(%v)), %idx) -> rank(%v) +// CHECK-LABEL: func @dim_of_1d_alloca_with_dynamic_size( +// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32> +// CHECK-NEXT: %[[RANK:.*]] = rank %[[MEM]] : memref<*xf32> +// CHECK-NEXT: return %[[RANK]] : index +func @dim_of_1d_alloca_with_dynamic_size(%arg0: memref<*xf32>) + -> index { + %0 = rank %arg0 : memref<*xf32> + %1
= alloca(%0) : memref<?xindex> + %c0 = constant 0 : index + %2 = dim %1, %c0 : memref<?xindex> + return %2 : index +} + +// ----- + +// Test case: Folding of dim(alloca(%sz), %idx) -> %sz for a 1-D memref, +// even when %idx is not a constant (any index other than 0 is out of bounds). +// CHECK-LABEL: func @dim_of_1d_alloca_with_unknown_index( +// CHECK-SAME: %[[SZ:[0-9a-z]+]]: index +// CHECK-NEXT: return %[[SZ]] : index +func @dim_of_1d_alloca_with_unknown_index(%sz: index, %idx: index) + -> index { + %0 = alloca(%sz) : memref<?xindex> + %1 = dim %0, %idx : memref<?xindex> + return %1 : index +} + +// ----- + + // Test case: Folding of dim(memref_reshape %v %shp, %idx) -> load %shp[%idx] // CHECK-LABEL: func @dim_of_memref_reshape( // CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,