diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2944,18 +2944,6 @@
                          nonReducedType.getMemorySpace());
 }
 
-/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
-/// to deduce the result type. Additionally, reduce the rank of the inferred
-/// result type if `currentResultType` is lower rank than `sourceType`.
-static MemRefType getCanonicalSubViewResultType(
-    MemRefType currentResultType, MemRefType sourceType,
-    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
-    ArrayRef<OpFoldResult> mixedStrides) {
-  return getCanonicalSubViewResultType(currentResultType, sourceType,
-                                       sourceType, mixedOffsets, mixedSizes,
-                                       mixedStrides);
-}
-
 Value mlir::memref::createCanonicalRankReducingSubViewOp(
     OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
   auto memrefType = llvm::cast<MemRefType>(memref.getType());
@@ -3108,9 +3096,34 @@
   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                         ArrayRef<OpFoldResult> mixedSizes,
                         ArrayRef<OpFoldResult> mixedStrides) {
-    return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
-                                         mixedOffsets, mixedSizes,
-                                         mixedStrides);
+    // Infer a memref type without taking into account any rank reductions.
+    MemRefType nonReducedType = cast<MemRefType>(SubViewOp::inferResultType(
+        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides));
+
+    // Directly return the non-rank reduced type if there are no dropped dims.
+    // Test with `none()` (no bit set), not `empty()` (zero-size vector);
+    // only the former means that no dimension is dropped.
+    llvm::SmallBitVector droppedDims = op.getDroppedDims();
+    if (droppedDims.none())
+      return nonReducedType;
+
+    // Take the strides and offset from the non-rank reduced type.
+    auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);
+
+    // Drop dims from shape and strides.
+    SmallVector<int64_t> targetShape;
+    SmallVector<int64_t> targetStrides;
+    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
+      if (droppedDims.test(i))
+        continue;
+      targetStrides.push_back(nonReducedStrides[i]);
+      targetShape.push_back(nonReducedType.getDimSize(i));
+    }
+
+    return MemRefType::get(targetShape, nonReducedType.getElementType(),
+                           StridedLayoutAttr::get(nonReducedType.getContext(),
+                                                  offset, targetStrides),
+                           nonReducedType.getMemorySpace());
   }
 };
 
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -931,7 +931,7 @@
 
 // -----
 
-// CHECK-lABEL: func @ub_negative_alloc_size
+// CHECK-LABEL: func private @ub_negative_alloc_size
 func.func private @ub_negative_alloc_size() -> memref<?x?x?xi1> {
   %idx1 = index.constant 1
   %c-2 = arith.constant -2 : index
@@ -940,3 +940,18 @@
   %alloc = memref.alloc(%c15, %c-2, %idx1) : memref<?x?x?xi1>
   return %alloc : memref<?x?x?xi1>
 }
+
+// -----
+
+// CHECK-LABEL: func @subview_rank_reduction(
+// CHECK-SAME: %[[arg0:.*]]: memref<1x384x384xf32>, %[[arg1:.*]]: index
+func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
+    -> memref<?x?xf32, strided<[384, 1], offset: ?>> {
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[subview:.*]] = memref.subview %[[arg0]][0, %[[arg1]], %[[arg1]]] [1, 1, %[[arg1]]] [1, 1, 1] : memref<1x384x384xf32> to memref<1x?xf32, strided<[384, 1], offset: ?>>
+  // CHECK: %[[cast:.*]] = memref.cast %[[subview]] : memref<1x?xf32, strided<[384, 1], offset: ?>> to memref<?x?xf32, strided<[384, 1], offset: ?>>
+  %0 = memref.subview %arg0[0, %idx, %idx] [1, %c1, %idx] [1, 1, 1]
+      : memref<1x384x384xf32> to memref<?x?xf32, strided<[384, 1], offset: ?>>
+  // CHECK: return %[[cast]]
+  return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
+}