diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -26,6 +26,49 @@
 using namespace mlir;
 using namespace mlir::memref;
 
+namespace {
+/// Idiomatic saturated operations on offsets, sizes and strides.
+namespace saturated_arith {
+struct Wrapper {
+  static Wrapper stride(int64_t v) {
+    return (ShapedType::isDynamicStrideOrOffset(v)) ? Wrapper{true, 0}
+                                                    : Wrapper{false, v};
+  }
+  static Wrapper offset(int64_t v) {
+    return (ShapedType::isDynamicStrideOrOffset(v)) ? Wrapper{true, 0}
+                                                    : Wrapper{false, v};
+  }
+  static Wrapper size(int64_t v) {
+    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
+  }
+  int64_t asOffset() {
+    return saturated ? ShapedType::kDynamicStrideOrOffset : v;
+  }
+  int64_t asSize() { return saturated ? ShapedType::kDynamicSize : v; }
+  int64_t asStride() {
+    return saturated ? ShapedType::kDynamicStrideOrOffset : v;
+  }
+  bool operator==(Wrapper other) {
+    return (saturated && other.saturated) ||
+           (!saturated && !other.saturated && v == other.v);
+  }
+  bool operator!=(Wrapper other) { return !(*this == other); }
+  Wrapper operator+(Wrapper other) {
+    if (saturated || other.saturated)
+      return Wrapper{true, 0};
+    return Wrapper{false, other.v + v};
+  }
+  Wrapper operator*(Wrapper other) {
+    if (saturated || other.saturated)
+      return Wrapper{true, 0};
+    return Wrapper{false, other.v * v};
+  }
+  bool saturated;
+  int64_t v;
+};
+} // namespace saturated_arith
+} // namespace
+
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
@@ -1558,24 +1601,6 @@
 // Reassociative reshape ops
 //===----------------------------------------------------------------------===//
 
-/// Helper function that computes a stride based on the size/stride of the
-/// previous dimension.
-///
-/// E.g., memref<20x10x5xf32, offset: 0, strides: [50, 5, 1]>
-///                                                ^^
-///                                                compute this one
-///       prevStride = 5, prevDimSize = 10
-///       nextStride = 5 * 10 = 50
-static int64_t computeNextStride(int64_t prevStride, int64_t prevDimSize) {
-  if (ShapedType::isDynamicStrideOrOffset(prevStride))
-    return ShapedType::kDynamicStrideOrOffset;
-
-  if (ShapedType::isDynamic(prevDimSize))
-    return ShapedType::kDynamicStrideOrOffset;
-
-  return prevStride * prevDimSize;
-}
-
 /// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
 /// result and operand. Layout maps are verified separately.
 ///
@@ -1677,57 +1702,41 @@
 static FailureOr<AffineMap>
 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
                          ArrayRef<ReassociationIndices> reassociation) {
-  SmallVector<int64_t> srcStrides, resultStrides(resultShape.size(), 0);
   int64_t srcOffset;
+  SmallVector<int64_t> srcStrides;
   if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
     return failure();
   assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
 
-  // Ensure that inner strides are the fastest-varying ones. Other source layout
-  // maps are currently not supported.
-  int64_t lastStride = 0;
-  for (int64_t s : llvm::reverse(srcStrides)) {
-    if (!ShapedType::isDynamicStrideOrOffset(s)) {
-      if (s < lastStride)
-        return failure();
-      lastStride = s;
-    }
-  }
-
-  // Iterate over all reassociation groups from the back. Example:
-  //   strides        = [1000, ?, 2]
-  //   source shape   = [20, 10, 5]
-  //   result shape   = [ 2, 10, 2, 5, 5]
-  //   reassociation  = [[0, 1], [2, 3], [4]]
-  for (const auto &it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
-    ReassociationIndices indices = std::get<0>(it);
-    int64_t srcGroupStride = std::get<1>(it);
-
-    // The first result dimension (least significant one) in each reassociation
-    // group has the same stride as the corresponding source dimension. E.g.:
-    //   reassociation  = [[0, 1], [2, 3], [4]]
-    //                        |       |     |
-    //                        v       v     v
-    //                      1000      ?     2
-    resultStrides[indices.pop_back_val()] = srcGroupStride;
-
-    // Compute the strides for the remaining dims in the reassociation group.
-    for (int64_t resultDim : llvm::reverse(indices)) {
-      // E.g.:
-      //   reassociation  = [[0, 1], [2, 3], [4]]
-      //                        |
-      //                        v
-      //                 1000 * 10 = 10000
-      //
-      // If the previous stride or the previous dimension was dynamic, then this
-      // stride will also be dynamic.
-      resultStrides[resultDim] = computeNextStride(resultStrides[resultDim + 1],
-                                                   resultShape[resultDim + 1]);
+  // One-to-one mapping between srcStrides and reassociation packs.
+  // Each srcStride starts at its source value and is expanded according to
+  // the corresponding entries in resultShape.
+  // Example:
+  //   srcStrides      =                   [10000,   1,       100],
+  //   reassociations  =                   [  [0],  [1],  [2, 3, 4]],
+  //   resultSizes     = [2, 5, 4, 3, 2] = [  [2],  [5],  [4, 3, 2]]
+  //     -> For the purpose of stride calculation, the useful sizes are:
+  //                     [x, x, x, 3, 2] = [  [x],  [x],  [x, 3, 2]].
+  //   resultStrides = [10000, 1, 600, 200, 100]
+  // Note that a stride does not get expanded along the first entry of each
+  // shape pack.
+  SmallVector<int64_t> reverseResultStrides;
+  reverseResultStrides.reserve(resultShape.size());
+  unsigned shapeIndex = resultShape.size() - 1;
+  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
+    ReassociationIndices reassoc = std::get<0>(it);
+    int64_t currentStrideToExpand = std::get<1>(it);
+    for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
+      using saturated_arith::Wrapper;
+      reverseResultStrides.push_back(currentStrideToExpand);
+      currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) *
                               Wrapper::size(resultShape[shapeIndex--]))
+                                  .asStride();
     }
   }
-
-  return makeStridedLinearLayoutMap(resultStrides, srcOffset,
-                                    srcType.getContext());
+  return makeStridedLinearLayoutMap(
+      llvm::to_vector<8>(llvm::reverse(reverseResultStrides)), srcOffset,
+      srcType.getContext());
 }
 
 static FailureOr<MemRefType>
@@ -1804,94 +1813,52 @@
 /// not possible to check this by inspecting a MemRefType in the general case.
 /// But it is assumed. If this is not the case, the behavior is undefined.
 static FailureOr<AffineMap>
-computeCollapsedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
+computeCollapsedLayoutMap(MemRefType srcType,
                           ArrayRef<ReassociationIndices> reassociation) {
-  SmallVector<int64_t> srcStrides, resultStrides;
   int64_t srcOffset;
+  SmallVector<int64_t> srcStrides;
+  auto srcShape = srcType.getShape();
   if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
     return failure();
-  assert(resultShape.size() == reassociation.size() && "invalid reassociation");
-
-  // Iterate over all reassociation groups from the back. Example:
-  //   source shape   = [20, ?, 5, 10, 2]
-  //   source strides = [ ?, ?, 800, 80, 4]
-  //   reassociation  = [[0, 1], [2, 3], [4]]
-  //   result shape   = [ ?, 50, 2]
-  //
-  // Note: The result shape is not needed in this computation. It is just used
-  // to check that the size of the reassociation is correct.
-  for (ReassociationIndices group : llvm::reverse(reassociation)) {
-    // A result dim has the same stride as the first dimension (least
-    // significant one) in the corresponding reassociation group. E.g.:
-    //   reassociation  = [[0, 1], [2, 3], [4]]
-    //                        |       |     |
-    //                        v       v     v
-    //                        ?      80     4
-    int64_t resultStride = srcStrides[group.pop_back_val()];
-
-    // The following is just a best-effort check for non-contiguous source
-    // strides within a reassociation group. E.g.:
-    //   reassociation  = [[0, 1], [2, 3], [4]]
-    //                             ^^^^^^
-    // Iteratively compute the next stride within the reassociation group
-    // one-by-one. Start with the stride computed above. E.g.:
-    //   reassociation  = [[0, 1], [2, 3], [4]]
-    //                                |
-    //                                v
-    //                    nextStride = 80
-    int64_t nextStride = resultStride;
-    for (int64_t nextDim : llvm::reverse(group)) {
-      // Next expected stride is previous stride multiplied by dim size, e.g.:
-      //   reassociation  = [[0, 1], [2, 3], [4]]
-      //                        |
-      //                        v
-      //            nextStride = 80 * 10 = 800
-      nextStride =
-          computeNextStride(nextStride, srcType.getDimSize(nextDim + 1));
-
-      // Ensure that the source actually has this stride value. E.g.:
-      //   source strides = [ ?, ?, 800, 80, 4]
-      //                            |
-      //                            v
-      //                     same stride, OK
-      // If strides are dynamic, we cannot verify anything statically.
-      if (!ShapedType::isDynamicStrideOrOffset(srcStrides[nextDim]) &&
-          !ShapedType::isDynamicStrideOrOffset(nextStride) &&
-          srcStrides[nextDim] != nextStride) {
-        // Attempting to collapse non-contiguous dimensions. This is forbidden.
-        // Note: This check does not handle cases where strides and dimension
-        // sizes are dynamic. Such dims could still turn out to be non-
-        // contiguous at runtime. This check is only a best effort to catch
-        // illegal collapses at verification time.
+
+  // The result strides are exactly the strides of the last entry of each
+  // reassociation.
+  SmallVector<int64_t> resultStrides;
+  resultStrides.reserve(reassociation.size());
+  for (ReassociationIndices reassoc : reassociation)
+    resultStrides.push_back(srcStrides[reassoc.back()]);
+
+  // Validate that each reassociation group is contiguous.
+  unsigned resultStrideIndex = resultStrides.size() - 1;
+  for (ReassociationIndices reassoc : llvm::reverse(reassociation)) {
+    auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
+    using saturated_arith::Wrapper;
+    auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
+    for (int64_t idx : llvm::reverse(trailingReassocs)) {
+      stride = stride * Wrapper::size(srcShape[idx]);
+      // Both strides are either static values that must be equal, or both are
+      // dynamic. The dynamic case is best effort for now: we cannot check it
+      // statically. One exception to the dynamic check is when the srcShape
+      // entry is `1`, in which case it can never produce non-contiguity.
+      if (stride != Wrapper::stride(srcStrides[idx - 1]) && srcShape[idx] != 1)
         return failure();
-      }
     }
-
-    resultStrides.push_back(resultStride);
   }
-
-  return makeStridedLinearLayoutMap(
-      llvm::to_vector<8>(llvm::reverse(resultStrides)), srcOffset,
-      srcType.getContext());
+  return makeStridedLinearLayoutMap(resultStrides, srcOffset,
+                                    srcType.getContext());
 }
 
 static MemRefType
 computeCollapsedType(MemRefType srcType,
                      ArrayRef<ReassociationIndices> reassociation) {
   SmallVector<int64_t> resultShape;
+  resultShape.reserve(reassociation.size());
   for (const ReassociationIndices &group : reassociation) {
-    int64_t groupSize = 1;
-    for (int64_t srcDim : group) {
-      if (srcType.isDynamicDim(srcDim)) {
-        // Source dim is dynamic, so the collapsed dim is also dynamic.
-        groupSize = ShapedType::kDynamicSize;
-        break;
-      }
-
-      groupSize *= srcType.getDimSize(srcDim);
-    }
-
-    resultShape.push_back(groupSize);
+    using saturated_arith::Wrapper;
+    auto groupSize = Wrapper::size(1);
+    for (int64_t srcDim : group)
+      groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim));
+    resultShape.push_back(groupSize.asSize());
   }
 
   if (srcType.getLayout().isIdentity()) {
@@ -1906,7 +1873,7 @@
   // Note: Dimensions that are collapsed into a single dim are assumed to be
   // contiguous.
   FailureOr<AffineMap> computedLayout =
-      computeCollapsedLayoutMap(srcType, resultShape, reassociation);
+      computeCollapsedLayoutMap(srcType, reassociation);
   assert(succeeded(computedLayout) &&
          "invalid source layout map or collapsing non-contiguous dims");
   auto computedType =
@@ -1948,8 +1915,8 @@
   // Source may not be fully contiguous. Compute the layout map.
   // Note: Dimensions that are collapsed into a single dim are assumed to be
   // contiguous.
-  FailureOr<AffineMap> computedLayout = computeCollapsedLayoutMap(
-      srcType, resultType.getShape(), getReassociationIndices());
+  FailureOr<AffineMap> computedLayout =
+      computeCollapsedLayoutMap(srcType, getReassociationIndices());
   if (failed(computedLayout))
     return emitOpError(
         "invalid source layout map or collapsing non-contiguous dims");
@@ -2066,29 +2033,6 @@
 // SubViewOp
 //===----------------------------------------------------------------------===//
 
-namespace {
-/// Helpers to write more idiomatic operations.
-namespace saturated_arith {
-struct Wrapper {
-  explicit Wrapper(int64_t v) : v(v) {}
-  operator int64_t() { return v; }
-  int64_t v;
-};
-Wrapper operator+(Wrapper a, int64_t b) {
-  if (ShapedType::isDynamicStrideOrOffset(a) ||
-      ShapedType::isDynamicStrideOrOffset(b))
-    return Wrapper(ShapedType::kDynamicStrideOrOffset);
-  return Wrapper(a.v + b);
-}
-Wrapper operator*(Wrapper a, int64_t b) {
-  if (ShapedType::isDynamicStrideOrOffset(a) ||
-      ShapedType::isDynamicStrideOrOffset(b))
-    return Wrapper(ShapedType::kDynamicStrideOrOffset);
-  return Wrapper(a.v * b);
-}
-} // namespace saturated_arith
-} // namespace
-
 /// A subview result type can be fully inferred from the source type and the
 /// static representation of offsets, sizes and strides. Special sentinels
 /// encode the dynamic case.
@@ -2114,8 +2058,11 @@
   int64_t targetOffset = sourceOffset;
   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
     auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
-    using namespace saturated_arith;
-    targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
+    using saturated_arith::Wrapper;
+    targetOffset =
+        (Wrapper::offset(targetOffset) +
+         Wrapper::offset(staticOffset) * Wrapper::stride(targetStride))
+            .asOffset();
   }
 
   // Compute target stride whose value is:
@@ -2124,8 +2071,10 @@
   targetStrides.reserve(staticOffsets.size());
   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
-    using namespace saturated_arith;
-    targetStrides.push_back(Wrapper(sourceStride) * staticStride);
+    using saturated_arith::Wrapper;
+    targetStrides.push_back(
+        (Wrapper::stride(sourceStride) * Wrapper::stride(staticStride))
+            .asStride());
   }
 
   // The type is now known.
@@ -2305,8 +2254,8 @@
 /// For ViewLikeOpInterface.
 Value SubViewOp::getViewSource() { return source(); }
 
-/// Return true if t1 and t2 have equal offsets (both dynamic or of same static
-/// value).
+/// Return true if t1 and t2 have equal offsets (both dynamic or of same
+/// static value).
 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
   AffineExpr t1Offset, t2Offset;
   SmallVector<AffineExpr> t1Strides, t2Strides;
@@ -2431,12 +2380,12 @@
   return res;
 }
 
-/// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
-/// deduce the result type for the given `sourceType`. Additionally, reduce the
-/// rank of the inferred result type if `currentResultType` is lower rank than
-/// `currentSourceType`. Use this signature if `sourceType` is updated together
-/// with the result type. In this case, it is important to compute the dropped
-/// dimensions using `currentSourceType` whose strides align with
+/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
+/// to deduce the result type for the given `sourceType`. Additionally, reduce
+/// the rank of the inferred result type if `currentResultType` is lower rank
+/// than `currentSourceType`. Use this signature if `sourceType` is updated
+/// together with the result type. In this case, it is important to compute
+/// the dropped dimensions using `currentSourceType` whose strides align with
 /// `currentResultType`.
 static MemRefType getCanonicalSubViewResultType(
     MemRefType currentResultType, MemRefType currentSourceType,
@@ -2464,9 +2413,9 @@
       nonRankReducedType.getMemorySpace());
 }
 
-/// Compute the canonical result type of a SubViewOp. Call `inferResultType` to
-/// deduce the result type. Additionally, reduce the rank of the inferred result
-/// type if `currentResultType` is lower rank than `sourceType`.
+/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
+/// to deduce the result type. Additionally, reduce the rank of the inferred
+/// result type if `currentResultType` is lower rank than `sourceType`.
 static MemRefType getCanonicalSubViewResultType(
     MemRefType currentResultType, MemRefType sourceType,
     ArrayRef<OpFoldResult> mixedOffsets, ArrayRef<OpFoldResult> mixedSizes,
@@ -2478,8 +2427,8 @@
 
 /// Helper method to check if a `subview` operation is trivially a no-op. This
 /// is the case if the all offsets are zero, all strides are 1, and the source
-/// shape is same as the size of the subview. In such cases, the subview can be
-/// folded into its source.
+/// shape is same as the size of the subview. In such cases, the subview can
+/// be folded into its source.
 static bool isTrivialSubViewOp(SubViewOp subViewOp) {
   if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
     return false;
@@ -2536,7 +2485,8 @@
   LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                 PatternRewriter &rewriter) const override {
-    // Any constant operand, just return to let SubViewOpConstantFolder kick in.
+    // Any constant operand, just return to let SubViewOpConstantFolder kick
+    // in.
     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
           return matchPattern(operand, matchConstantIndex());
         }))
@@ -2549,10 +2499,10 @@
     if (!CastOp::canFoldIntoConsumerOp(castOp))
       return failure();
 
-    // Compute the SubViewOp result type after folding the MemRefCastOp. Use the
-    // MemRefCastOp source operand type to infer the result type and the current
-    // SubViewOp source operand type to compute the dropped dimensions if the
-    // operation is rank-reducing.
+    // Compute the SubViewOp result type after folding the MemRefCastOp. Use
+    // the MemRefCastOp source operand type to infer the result type and the
+    // current SubViewOp source operand type to compute the dropped dimensions
+    // if the operation is rank-reducing.
     auto resultType = getCanonicalSubViewResultType(
         subViewOp.getType(), subViewOp.getSourceType(),
         castOp.source().getType().cast<MemRefType>(),
@@ -2571,8 +2521,8 @@
   }
 };
 
-/// Canonicalize subview ops that are no-ops. When the source shape is not same
-/// as a result shape due to use of `affine_map`.
+/// Canonicalize subview ops that are no-ops. When the source shape is not
+/// same as a result shape due to use of `affine_map`.
 class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
 public:
   using OpRewritePattern<SubViewOp>::OpRewritePattern;
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -517,18 +517,6 @@
 
 // -----
 
-func @expand_shape_unsupported_src_layout(
-    %arg0 : memref<20x2x10x5xf32, offset: 0, strides: [100, 10, 50, 1]>)
-    -> memref<20x2x2x5x5xf32, offset : 0, strides : [100, 10, 250, 50, 1]> {
-  // expected-error @+1 {{invalid source layout map}}
-  %0 = memref.expand_shape %arg0 [[0], [1], [2, 3], [4]] :
-      memref<20x2x10x5xf32, offset: 0, strides: [100, 10, 50, 1]>
-      into memref<20x2x2x5x5xf32, offset : 0, strides : [100, 10, 250, 50, 1]>
-  return %0 : memref<20x2x2x5x5xf32, offset : 0, strides : [100, 10, 250, 50, 1]>
-}
-
-// -----
-
 func @expand_shape_invalid_static_dim_size(%arg0 : memref<?x21xf32>)
     -> memref<?x4x5xf32> {
   // expected-error @+1 {{collapsed dim size (21) must equal reassociation group size (20)}}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -281,6 +281,29 @@
 
 // -----
 
+// CHECK-LABEL: func @expand_collapse_shape_transposed_layout
+func @expand_collapse_shape_transposed_layout(
+    %m0: memref<?x?xf32, offset : 0, strides : [1, 10]>,
+    %m1: memref<4x5x6xf32, offset : 0, strides : [1, ?, 1000]>) {
+
+  %r0 = memref.expand_shape %m0 [[0], [1, 2]] :
+    memref<?x?xf32, offset : 0, strides : [1, 10]> into
+    memref<?x?x5xf32, offset : 0, strides : [1, 50, 10]>
+  %rr0 = memref.collapse_shape %r0 [[0], [1, 2]] :
+    memref<?x?x5xf32, offset : 0, strides : [1, 50, 10]> into
+    memref<?x?xf32, offset : 0, strides : [1, 10]>
+
+  %r1 = memref.expand_shape %m1 [[0, 1], [2], [3, 4]] :
+    memref<4x5x6xf32, offset : 0, strides : [1, ?, 1000]> into
+    memref<2x2x5x2x3xf32, offset : 0, strides : [2, 1, ?, 3000, 1000]>
+  %rr1 = memref.collapse_shape %r1 [[0, 1], [2], [3, 4]] :
+    memref<2x2x5x2x3xf32, offset : 0, strides : [2, 1, ?, 3000, 1000]> into
+    memref<4x5x6xf32, offset : 0, strides : [1, ?, 1000]>
+  return
+}
+
+// -----
+
 func @rank(%t : memref<4x4x?xf32>) {
   // CHECK: %{{.*}} = memref.rank %{{.*}} : memref<4x4x?xf32>
   %0 = "memref.rank"(%t) : (memref<4x4x?xf32>) -> index
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -5,6 +5,7 @@
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)>
 // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0 + 1)>
 // CHECK-DAG: #[[$MAP4:.*]] = affine_map<() -> (1)>
+// CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
 
 // CHECK-LABEL: func @dim(
 // CHECK-SAME:     %[[TENSOR:.*]]: tensor<f32>,
@@ -337,6 +338,17 @@
   return %1 : tensor<?x7x2x5xf32>
 }
 
+// CHECK-LABEL: func @tensor.expand_shape_of_slice2(
+// CHECK-SAME:     %[[t1:.*]]: tensor<1x2xf32>
+func @tensor.expand_shape_of_slice2(%t1: tensor<1x2xf32>) -> tensor<1xf32> {
+  // CHECK: memref.subview {{.*}} : memref<1x2xf32> to memref<1x1xf32, #[[$MAP5]]>
+  %0 = tensor.extract_slice %t1[0, 0][1, 1][1, 1] : tensor<1x2xf32> to tensor<1x1xf32>
+  // CHECK: memref.collapse_shape %{{.*}} [
+  // CHECK-SAME: [0, 1]] : memref<1x1xf32, #[[$MAP5]]> into memref<1xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x1xf32> into tensor<1xf32>
+  return %1 : tensor<1xf32>
+}
+
 // CHECK-LABEL: func @tensor.collapse_shape(
 // CHECK-SAME:     %[[t1:.*]]: tensor<2x?x?xf32>
 func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?x?xf32> {
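
As a quick cross-check of the saturated stride arithmetic introduced in this patch, the standalone C++ sketch below replays both worked examples from the comments: the new one in computeExpandedLayoutMap and the `strides = [1000, ?, 2]` one from the deleted comment block. It is illustrative only: `Sat`, `kDynamic`, and `expandStrides` are names invented for this sketch, the single `kDynamic` sentinel stands in for MLIR's separate `ShapedType::kDynamicSize` / `ShapedType::kDynamicStrideOrOffset` markers, and no MLIR headers are involved.

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative sentinel; MLIR uses two distinct dynamic markers, but a
// single one suffices to demonstrate the saturation behavior.
constexpr int64_t kDynamic = INT64_MIN;

// Mirrors saturated_arith::Wrapper from the patch: any operation that
// involves a dynamic (saturated) operand yields a dynamic result.
struct Sat {
  bool saturated;
  int64_t v;
  static Sat of(int64_t v) {
    return v == kDynamic ? Sat{true, 0} : Sat{false, v};
  }
  int64_t get() const { return saturated ? kDynamic : v; }
  Sat operator*(Sat o) const {
    return (saturated || o.saturated) ? Sat{true, 0} : Sat{false, v * o.v};
  }
};

// Follows computeExpandedLayoutMap: walk the reassociation groups from the
// back; within each group the innermost result dim inherits the source
// stride, and every further dim multiplies by the result size to its right.
std::vector<int64_t>
expandStrides(const std::vector<int64_t> &srcStrides,
              const std::vector<std::vector<int64_t>> &reassociation,
              const std::vector<int64_t> &resultShape) {
  std::vector<int64_t> reverseStrides;
  size_t shapeIdx = resultShape.size() - 1;
  for (size_t g = reassociation.size(); g-- > 0;) {
    Sat stride = Sat::of(srcStrides[g]);
    for (size_t i = 0, e = reassociation[g].size(); i < e; ++i) {
      reverseStrides.push_back(stride.get());
      // Post-decrement mirrors the patch's shapeIndex-- idiom; the last
      // product of each group is computed but never pushed.
      stride = stride * Sat::of(resultShape[shapeIdx--]);
    }
  }
  return {reverseStrides.rbegin(), reverseStrides.rend()};
}

int main() {
  // Static example from the new computeExpandedLayoutMap comment.
  auto s = expandStrides({10000, 1, 100}, {{0}, {1}, {2, 3, 4}},
                         {2, 5, 4, 3, 2});
  assert((s == std::vector<int64_t>{10000, 1, 600, 200, 100}));

  // Dynamic example from the deleted comment block (strides = [1000, ?, 2]):
  // a dynamic source stride saturates every stride derived from it, while
  // the other groups are still computed statically.
  auto d = expandStrides({1000, kDynamic, 2}, {{0, 1}, {2, 3}, {4}},
                         {2, 10, 2, 5, 5});
  assert((d == std::vector<int64_t>{10000, 1000, kDynamic, kDynamic, 2}));
  return 0;
}

Compiled with any C++17 compiler, both asserts pass and exercise the same arithmetic paths as the patch: the innermost dimension of each reassociation group inherits the source stride, and dynamic values propagate through every product they touch instead of being special-cased at each call site as the old computeNextStride helper did.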