diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1833,18 +1833,23 @@ return res; } -/// Infer the canonical type of the result of a subview operation. Returns a -/// type with rank `resultRank` that is either the rank of the rank-reduced -/// type, or the non-rank-reduced type. +/// Compute the canonical result type of a SubViewOp. Call `inferResultType` to +/// deduce the result type for the given `sourceType`. Additionally, reduce the +/// rank of the inferred result type if `currentResultType` is lower rank than +/// `currentSourceType`. Use this signature if `sourceType` is updated together +/// with the result type. In this case, it is important to compute the dropped +/// dimensions using `currentSourceType` whose strides align with +/// `currentResultType`. static MemRefType getCanonicalSubViewResultType( - MemRefType currentResultType, MemRefType sourceType, - ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes, - ArrayRef<OpFoldResult> mixedStrides) { + MemRefType currentResultType, MemRefType currentSourceType, + MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets, + ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) { auto nonRankReducedType = SubViewOp::inferResultType(sourceType, mixedOffsets, mixedSizes, mixedStrides) .cast<MemRefType>(); llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims = - computeMemRefRankReductionMask(sourceType, currentResultType, mixedSizes); + computeMemRefRankReductionMask(currentSourceType, currentResultType, + mixedSizes); // Return nullptr as failure mode. if (!unusedDims) return nullptr; @@ -1861,6 +1866,18 @@ nonRankReducedType.getMemorySpace()); } +/// Compute the canonical result type of a SubViewOp. Call `inferResultType` to +/// deduce the result type. Additionally, reduce the rank of the inferred result +/// type if `currentResultType` is lower rank than `sourceType`.
+static MemRefType getCanonicalSubViewResultType( + MemRefType currentResultType, MemRefType sourceType, + ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes, + ArrayRef<OpFoldResult> mixedStrides) { + return getCanonicalSubViewResultType(currentResultType, sourceType, + sourceType, mixedOffsets, mixedSizes, + mixedStrides); +} + namespace { /// Pattern to rewrite a subview op with MemRefCast arguments. /// This essentially pushes memref.cast past its consuming subview when @@ -1897,13 +1914,18 @@ if (!CastOp::canFoldIntoConsumerOp(castOp)) return failure(); - /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on - /// the cast source operand type and the SubViewOp static information. This - /// is the resulting type if the MemRefCastOp were folded. + // Compute the SubViewOp result type after folding the MemRefCastOp. Use the + // MemRefCastOp source operand type to infer the result type and the current + // SubViewOp source operand type to compute the dropped dimensions if the + // operation is rank-reducing.
auto resultType = getCanonicalSubViewResultType( - subViewOp.getType(), castOp.source().getType().cast<MemRefType>(), + subViewOp.getType(), subViewOp.getSourceType(), + castOp.source().getType().cast<MemRefType>(), subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(), subViewOp.getMixedStrides()); + if (!resultType) + return failure(); + Value newSubView = rewriter.create<SubViewOp>( subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(), subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(), diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -1,11 +1,11 @@ // RUN: mlir-opt %s -canonicalize --split-input-file -allow-unregistered-dialect | FileCheck %s -// CHECK-LABEL: func @subview_of_memcast +// CHECK-LABEL: func @subview_of_size_memcast // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8> -// CHECK: %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}> +// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}> // CHECK: %[[M:.+]] = memref.cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}> // CHECK: return %[[M]] : memref<16x32xi8, #{{.*}}> -func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) -> +func @subview_of_size_memcast(%arg : memref<4x6x16x32xi8>) -> memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{ %0 = memref.cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8> %1 = memref.subview %0[0, 1, 0] [1, 1, 16] [1, 1, 1] : @@ -16,6 +16,27 @@ // ----- +// CHECK-DAG: #[[MAP0:[0-9a-z]+]] = affine_map<(d0, d1)[s0] -> (d0 * 7 + s0 + d1)> +// CHECK-DAG: #[[MAP1:[0-9a-z]+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> +#map0 = affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)> +#map1 = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 *
s1 + s0 + d1 * s2 + d2 * s3)> +#map2 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> + +// CHECK: func @subview_of_strides_memcast +// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<1x1x?xf32, #{{.*}}> +// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][0, 0, 0] [1, 1, 4] +// CHECK-SAME: to memref<1x4xf32, #[[MAP0]]> +// CHECK: %[[M:.+]] = memref.cast %[[S]] +// CHECK-SAME: to memref<1x4xf32, #[[MAP1]]> +// CHECK: return %[[M]] +func @subview_of_strides_memcast(%arg : memref<1x1x?xf32, #map0>) -> memref<1x4xf32, #map2> { + %0 = memref.cast %arg : memref<1x1x?xf32, #map0> to memref<1x1x?xf32, #map1> + %1 = memref.subview %0[0, 0, 0] [1, 1, 4] [1, 1, 1] : memref<1x1x?xf32, #map1> to memref<1x4xf32, #map2> + return %1 : memref<1x4xf32, #map2> +} + +// ----- + // CHECK-LABEL: func @subview_of_static_full_size // CHECK-SAME: %[[ARG0:.+]]: memref<4x6x16x32xi8> // CHECK-NOT: memref.subview