diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -31,23 +31,17 @@
 namespace saturated_arith {
 struct Wrapper {
   static Wrapper stride(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
-                                      : Wrapper{false, v};
+    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
   }
   static Wrapper offset(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0}
-                                      : Wrapper{false, v};
+    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
   }
   static Wrapper size(int64_t v) {
     return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
   }
-  int64_t asOffset() {
-    return saturated ? ShapedType::kDynamic : v;
-  }
+  int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
   int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
-  int64_t asStride() {
-    return saturated ? ShapedType::kDynamic : v;
-  }
+  int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
   bool operator==(Wrapper other) {
     return (saturated && other.saturated) ||
            (!saturated && !other.saturated && v == other.v);
@@ -731,8 +725,7 @@
   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
     auto ss = std::get<0>(it), st = std::get<1>(it);
     if (ss != st)
-      if (ShapedType::isDynamic(ss) &&
-          !ShapedType::isDynamic(st))
+      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
         return false;
   }

@@ -765,8 +758,7 @@
   // same. They are also compatible if either one is dynamic (see
   // description of MemRefCastOp for details).
   auto checkCompatible = [](int64_t a, int64_t b) {
-    return (ShapedType::isDynamic(a) ||
-            ShapedType::isDynamic(b) || a == b);
+    return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
   };
   if (!checkCompatible(aOffset, bOffset))
     return false;
@@ -1889,8 +1881,7 @@
   // Match offset in result memref type and in static_offsets attribute.
   int64_t expectedOffset = getStaticOffsets().front();
   if (!ShapedType::isDynamic(resultOffset) &&
-      !ShapedType::isDynamic(expectedOffset) &&
-      resultOffset != expectedOffset)
+      !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset)
     return emitError("expected result type with offset = ")
            << expectedOffset << " instead of " << resultOffset;

@@ -2944,18 +2935,6 @@
                                  nonRankReducedType.getMemorySpace());
 }

-/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
-/// to deduce the result type. Additionally, reduce the rank of the inferred
-/// result type if `currentResultType` is lower rank than `sourceType`.
-static MemRefType getCanonicalSubViewResultType(
-    MemRefType currentResultType, MemRefType sourceType,
-    ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
-    ArrayRef<OpFoldResult> mixedStrides) {
-  return getCanonicalSubViewResultType(currentResultType, sourceType,
-                                       sourceType, mixedOffsets, mixedSizes,
-                                       mixedStrides);
-}
-
 Value mlir::memref::createCanonicalRankReducingSubViewOp(
     OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
   auto memrefType = llvm::cast<MemRefType>(memref.getType());
@@ -3108,9 +3087,32 @@
   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                         ArrayRef<OpFoldResult> mixedSizes,
                         ArrayRef<OpFoldResult> mixedStrides) {
-    return getCanonicalSubViewResultType(op.getType(), op.getSourceType(),
-                                         mixedOffsets, mixedSizes,
-                                         mixedStrides);
+    // Infer a memref type without taking into account any rank reductions.
+    MemRefType nonReducedType = cast<MemRefType>(SubViewOp::inferResultType(
+        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides));
+
+    // Directly return the non-rank reduced type if there are no dropped dims.
+    llvm::SmallBitVector droppedDims = op.getDroppedDims();
+    if (droppedDims.none())
+      return nonReducedType;
+
+    // Take the strides and offset from the non-rank reduced type.
+    auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);
+
+    // Drop dims from shape and strides.
+    SmallVector<int64_t> targetShape;
+    SmallVector<int64_t> targetStrides;
+    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
+      if (droppedDims.test(i))
+        continue;
+      targetStrides.push_back(nonReducedStrides[i]);
+      targetShape.push_back(nonReducedType.getDimSize(i));
+    }
+
+    return MemRefType::get(targetShape, nonReducedType.getElementType(),
+                           StridedLayoutAttr::get(nonReducedType.getContext(),
+                                                  offset, targetStrides),
+                           nonReducedType.getMemorySpace());
   }
 };

diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -931,7 +931,7 @@

 // -----

-// CHECK-lABEL: func @ub_negative_alloc_size
+// CHECK-LABEL: func private @ub_negative_alloc_size
 func.func private @ub_negative_alloc_size() -> memref<?x?x?xf32> {
   %idx1 = index.constant 1
   %c-2 = arith.constant -2 : index
@@ -940,3 +940,18 @@
   %alloc = memref.alloc(%c15, %c-2, %idx1) : memref<?x?x?xf32>
   return %alloc : memref<?x?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @subview_rank_reduction(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<1x384x384xf32>, %[[arg1:.*]]: index
+func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
+    -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[subview:.*]] = memref.subview %[[arg0]][0, %[[arg1]], %[[arg1]]] [1, 1, %[[arg1]]] [1, 1, 1] : memref<1x384x384xf32> to memref<1x?xf32, strided<[384, 1], offset: ?>>
+  // CHECK: %[[cast:.*]] = memref.cast %[[subview]] : memref<1x?xf32, strided<[384, 1], offset: ?>> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  %0 = memref.subview %arg0[0, %idx, %idx] [1, %c1, %idx] [1, 1, 1]
+      : memref<1x384x384xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
+  // CHECK: return %[[cast]]
+  return %0 : memref<?x?xf32, strided<[?, ?], offset: ?>>
+}
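
Note (reviewer sketch, not part of the patch): the loop in the new
SubViewReturnTypeCanonicalizer keeps only the shape/stride entries whose
dimension survives rank reduction. The standalone C++ sketch below mirrors
that filtering with plain vectors instead of the MLIR API; the names
StridedInfo and dropDims are hypothetical, introduced here purely for
illustration, and -1 stands in for a dynamic extent (ShapedType::kDynamic).

#include <cstddef>
#include <cstdint>
#include <vector>

struct StridedInfo {
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
  int64_t offset;
};

// droppedDims[i] is true iff dimension i is removed by rank reduction.
StridedInfo dropDims(const StridedInfo &nonReduced,
                     const std::vector<bool> &droppedDims) {
  StridedInfo reduced;
  reduced.offset = nonReduced.offset; // rank reduction leaves the offset alone
  for (std::size_t i = 0; i < nonReduced.shape.size(); ++i) {
    if (droppedDims[i])
      continue; // skip rank-reduced (unit) dimensions
    reduced.shape.push_back(nonReduced.shape[i]);
    reduced.strides.push_back(nonReduced.strides[i]);
  }
  return reduced;
}

int main() {
  // Mirrors @subview_rank_reduction: the non-reduced inferred type is
  // memref<1x1x?xf32, strided<[147456, 384, 1], offset: ?>> and dim 0 is
  // dropped, yielding memref<1x?xf32, strided<[384, 1], offset: ?>>.
  StridedInfo nonReduced{{1, 1, -1}, {147456, 384, 1}, -1};
  StridedInfo reduced = dropDims(nonReduced, {true, false, false});
  // reduced.shape == {1, -1}, reduced.strides == {384, 1}, matching the
  // CHECK lines of the new test.
  return 0;
}

Because only unit dims are deleted while the offset and the surviving strides
are kept verbatim, the rank-reduced type computed this way is still
cast-compatible with the op's original result type, which is why the
canonicalized test inserts a memref.cast instead of changing the function's
return type.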