diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1561,23 +1561,29 @@ } } + Operation *definingOp = memrefOrTensor().getDefiningOp(); + // dim(tensor_load(memref)) -> dim(memref) + if (auto tensorLoadOp = dyn_cast_or_null(definingOp)) { + setOperand(0, tensorLoadOp.memref()); + return getResult(); + } + // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. auto memrefType = argTy.dyn_cast(); if (!memrefType) return {}; // The size at the given index is now known to be a dynamic size of a memref. - auto *memref = memrefOrTensor().getDefiningOp(); unsigned unsignedIndex = index.getValue().getZExtValue(); - if (auto alloc = dyn_cast_or_null(memref)) + if (auto alloc = dyn_cast_or_null(definingOp)) return *(alloc.getDynamicSizes().begin() + memrefType.getDynamicDimIndex(unsignedIndex)); - if (auto view = dyn_cast_or_null(memref)) + if (auto view = dyn_cast_or_null(definingOp)) return *(view.getDynamicSizes().begin() + memrefType.getDynamicDimIndex(unsignedIndex)); - if (auto subview = dyn_cast_or_null(memref)) { + if (auto subview = dyn_cast_or_null(definingOp)) { assert(subview.isDynamicSize(unsignedIndex) && "Expected dynamic subview size"); return subview.getDynamicSize(unsignedIndex); diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir --- a/mlir/test/Dialect/Standard/canonicalize.mlir +++ b/mlir/test/Dialect/Standard/canonicalize.mlir @@ -31,3 +31,16 @@ %1 = tensor_to_memref %0 : memref return %1 : memref } + +// Test case: Basic folding of dim(tensor_load(m)) -> dim(m). 
+// CHECK-LABEL: func @dim_of_tensor_load(
+// CHECK-SAME: %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
+// CHECK: %[[C0:.*]] = constant 0
+// CHECK: %[[D:.*]] = dim %[[MEMREF]], %[[C0]]
+// CHECK: return %[[D]] : index
+func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
+  %c0 = constant 0 : index
+  %0 = tensor_load %arg0 : memref<?xf32>
+  %1 = dim %0, %c0 : tensor<?xf32>
+  return %1 : index
+}