diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3494,6 +3494,15 @@
       source().getType().cast<ShapedType>().getRank() == 0)
     return getViewSource();
 
+  // A subview with a static result type identical to the source type (same
+  // shape, element type, layout and memory space) covers the whole source, so
+  // it folds to the source. Comparing full types rather than only shapes keeps
+  // the folded value's type equal to the op's result type, as fold requires.
+  auto sourceType = source().getType().cast<ShapedType>();
+  auto resultType = getResult().getType().cast<ShapedType>();
+  if (resultType.hasStaticShape() && resultType == sourceType)
+    return getViewSource();
+
   return {};
 }
 
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -204,6 +204,17 @@
 
 // -----
 
+// CHECK-LABEL: func @subview_of_static_full_size
+// CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8>
+// CHECK-NOT: subview
+// CHECK: return %[[ARG0]] : memref<4x6x16x32xi8>
+func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4x6x16x32xi8> {
+  %0 = subview %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<4x6x16x32xi8>
+  return %0 : memref<4x6x16x32xi8>
+}
+
+// -----
+
 // CHECK-LABEL: func @trivial_subtensor
 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
 // CHECK-NOT: subtensor