diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3496,9 +3496,13 @@
 }
 
 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
-  if (getResult().getType().cast<ShapedType>().getRank() == 0 &&
-      source().getType().cast<ShapedType>().getRank() == 0)
+  auto resultShapedType = getResult().getType().cast<ShapedType>();
+  auto sourceShapedType = source().getType().cast<ShapedType>();
+
+  if (resultShapedType.hasStaticShape() &&
+      resultShapedType == sourceShapedType) {
     return getViewSource();
+  }
 
   return {};
 }
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -204,6 +204,17 @@
 
 // -----
 
+// CHECK-LABEL: func @subview_of_static_full_size
+// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
+// CHECK-NOT: subview
+// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8>
+func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> {
+  %0 = subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8>
+  return %0 : memref<4x6x16x32xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func @trivial_subtensor
 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
 // CHECK-NOT: subtensor