Summary (review note, not part of the diff): generalizes the dim-folding
branch in MemRefOps.cpp from a single op type (the removed line binds the
cast result to `subview`) to an interface (`sizeInterface`, with
mlir/Interfaces/ViewLikeInterface.h now included), and adds a
canonicalize.mlir test expecting memref.dim of a memref.reinterpret_cast to
fold to the reinterpret_cast's size operand (%size).

NOTE(review): this copy of the patch will not apply as-is. All of its lines
are joined into one, and text in angle brackets appears to have been
stripped. Both dyn_cast_or_null calls lack a template argument (presumably
<SubViewOp> before and <OffsetSizeAndStrideOpInterface> after; confirm
against the original commit). Every memref type in the new test is a bare
`memref`. memref<32xf32> in the context lines survived, so the loss looks
like markup sanitization. Restore from the original source before use.
NOTE(review): the assert message "Expected dynamic subview size" is now
stale. The branch is no longer subview-specific (the new test exercises
reinterpret_cast), so consider wording such as "Expected dynamic size".
NOTE(review): confirm that unsignedIndex (a dimension of the dim op's
memref) always indexes the defining op's size list. If an implementing op
can drop dimensions (e.g. a rank-reducing subview), the assert or the
returned size could refer to the wrong dimension.
NOTE(review): in the FileCheck pattern %[[SIZE:.[a-z0-9A-Z_]+]] the leading
`.` matches any single character. It happens to work for %size, but it is
looser than the %{{[a-z0-9A-Z_]+}} pattern on the previous CHECK-SAME line.
NOTE(review): the test uses strides: [0]. A unit stride ([1]) would
presumably be a more representative layout; the stride is not what this
test checks.
NOTE(review): the canonicalize.mlir hunk ends with two added blank lines at
end of file.

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/STLExtras.h" using namespace mlir; @@ -679,10 +680,11 @@ return *(view.getDynamicSizes().begin() + memrefType.getDynamicDimIndex(unsignedIndex)); - if (auto subview = dyn_cast_or_null(definingOp)) { - assert(subview.isDynamicSize(unsignedIndex) && + if (auto sizeInterface = + dyn_cast_or_null(definingOp)) { + assert(sizeInterface.isDynamicSize(unsignedIndex) && "Expected dynamic subview size"); - return subview.getDynamicSize(unsignedIndex); + return sizeInterface.getDynamicSize(unsignedIndex); } // dim(memrefcast) -> dim diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -192,3 +192,18 @@ memref.dealloc %1 : memref<32xf32> return } + +// ----- + +// CHECK-LABEL: func @dim_of_sized_view +// CHECK-SAME: %{{[a-z0-9A-Z_]+}}: memref +// CHECK-SAME: %[[SIZE:.[a-z0-9A-Z_]+]]: index +// CHECK: return %[[SIZE]] : index +func @dim_of_sized_view(%arg : memref, %size: index) -> index { + %c0 = constant 0 : index + %0 = memref.reinterpret_cast %arg to offset: [0], sizes: [%size], strides: [0] : memref to memref + %1 = memref.dim %0, %c0 : memref + return %1 : index +} + +