diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -617,6 +617,11 @@
     return *(alloc.getDynamicSizes().begin() +
              memrefType.getDynamicDimIndex(unsignedIndex));
 
+  // Like memref.alloc, memref.alloca lists one size operand per dynamic dim.
+  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
+    return *(alloca.getDynamicSizes().begin() +
+             memrefType.getDynamicDimIndex(unsignedIndex));
+
   if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
     return *(view.getDynamicSizes().begin() +
              memrefType.getDynamicDimIndex(unsignedIndex));
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -134,6 +134,34 @@
 
 // -----
 
+// Test case: Folding of memref.dim(memref.alloca(%size), %idx) -> %size
+// CHECK-LABEL: func @dim_of_alloca(
+//  CHECK-SAME:     %[[SIZE:[0-9a-z]+]]: index
+//  CHECK-NEXT:   return %[[SIZE]] : index
+func @dim_of_alloca(%size: index) -> index {
+  %0 = memref.alloca(%size) : memref<?xindex>
+  %c0 = constant 0 : index
+  %1 = memref.dim %0, %c0 : memref<?xindex>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.alloca(rank(%v)), %idx) -> rank(%v)
+// CHECK-LABEL: func @dim_of_alloca_with_dynamic_size(
+//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>
+//  CHECK-NEXT:   %[[RANK:.*]] = rank %[[MEM]] : memref<*xf32>
+//  CHECK-NEXT:   return %[[RANK]] : index
+func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
+  %0 = rank %arg0 : memref<*xf32>
+  %1 = memref.alloca(%0) : memref<?xindex>
+  %c0 = constant 0 : index
+  %2 = memref.dim %1, %c0 : memref<?xindex>
+  return %2 : index
+}
+
+// -----
+
 // Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
 // CHECK-LABEL: func @dim_of_memref_reshape(
 //  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,