diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1827,14 +1827,16 @@
 /// type with rank `resultRank` that is either the rank of the rank-reduced
 /// type, or the non-rank-reduced type.
 static MemRefType getCanonicalSubViewResultType(
-    MemRefType currentResultType, MemRefType sourceType,
-    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
-    ArrayRef<OpFoldResult> mixedStrides) {
-  auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets,
-                                                       mixedSizes, mixedStrides)
-                                .cast<MemRefType>();
+    MemRefType currentResultType, MemRefType currentSourceType,
+    MemRefType newSourceType, ArrayRef<OpFoldResult> mixedOffsets,
+    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
+  auto nonRankReducedType =
+      SubViewOp::inferResultType(newSourceType, mixedOffsets, mixedSizes,
+                                 mixedStrides)
+          .cast<MemRefType>();
   llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
-      computeMemRefRankReductionMask(sourceType, currentResultType, mixedSizes);
+      computeMemRefRankReductionMask(currentSourceType, currentResultType,
+                                     mixedSizes);
   // Return nullptr as failure mode.
   if (!unusedDims)
     return nullptr;
@@ -1891,9 +1893,13 @@
     /// the cast source operand type and the SubViewOp static information. This
     /// is the resulting type if the MemRefCastOp were folded.
     auto resultType = getCanonicalSubViewResultType(
-        subViewOp.getType(), castOp.source().getType().cast<MemRefType>(),
+        subViewOp.getType(), subViewOp.getSourceType(),
+        castOp.source().getType().cast<MemRefType>(),
         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
         subViewOp.getMixedStrides());
+    if (!resultType)
+      return failure();
+
     Value newSubView = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
@@ -1911,8 +1917,8 @@
                         ArrayRef<OpFoldResult> mixedSizes,
                         ArrayRef<OpFoldResult> mixedStrides) {
     return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
-                                         mixedOffsets, mixedSizes,
-                                         mixedStrides);
+                                         op.getSourceType(), mixedOffsets,
+                                         mixedSizes, mixedStrides);
   }
 };
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -1,11 +1,11 @@
 // RUN: mlir-opt %s -canonicalize --split-input-file -allow-unregistered-dialect | FileCheck %s
 
-// CHECK-LABEL: func @subview_of_memcast
+// CHECK-LABEL: func @subview_of_size_memcast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
-//       CHECK:   %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+//       CHECK:   %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
 //       CHECK:   %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
 //       CHECK:   return %[[M]] : memref<16x32xi8, #{{.*}}>
-func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
+func @subview_of_size_memcast(%arg : memref<4x6x16x32xi8>) ->
   memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
   %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
   %1 = memref.subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] :
@@ -16,6 +16,27 @@
 
 // -----
 
+// CHECK-DAG: #[[MAP0:[0-9a-z]+]] = affine_map<(d0, d1)[s0] -> (d0 * 7 + s0 + d1)>
+// CHECK-DAG: #[[MAP1:[0-9a-z]+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+#map0 = affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>
+#map1 = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>
+#map2 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+
+//       CHECK: func @subview_of_strides_memcast
+//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<1x1x?xf32, #{{.*}}>
+//       CHECK:   %[[S:.+]] = memref.subview %[[ARG0]][0, 0, 0] [1, 1, 4]
+//  CHECK-SAME:     to memref<1x4xf32, #[[MAP0]]>
+//       CHECK:   %[[M:.+]] = memref.cast %[[S]]
+//  CHECK-SAME:     to memref<1x4xf32, #[[MAP1]]>
+//       CHECK:   return %[[M]]
+func @subview_of_strides_memcast(%arg : memref<1x1x?xf32, #map0>) -> memref<1x4xf32, #map2> {
+  %0 = memref.cast %arg : memref<1x1x?xf32, #map0> to memref<1x1x?xf32, #map1>
+  %1 = memref.subview %0[0, 0, 0] [1, 1, 4] [1, 1, 1] : memref<1x1x?xf32, #map1> to memref<1x4xf32, #map2>
+  return %1 : memref<1x4xf32, #map2>
+}
+
+// -----
+
 // CHECK-LABEL: func @subview_of_static_full_size
 //  CHECK-SAME:   %[[ARG0:.+]]: memref<4x6x16x32xi8>
 //   CHECK-NOT:   memref.subview