diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1303,35 +1303,44 @@ let summary = "operation to produce a memref with a smaller rank."; let description = [{ The `memref.collapse_shape` op produces a new view with a smaller rank - whose sizes are a reassociation of the original `view`. Depending on - whether or not the reassociated MemRefType is contiguous, the resulting - memref may require explicit alloc and copies. + whose sizes are a reassociation of the original `view`. The operation is + limited to such reassociations, where subsequent, contiguous dimensions are + collapsed into a single dimension. Such reassociations never require + additional allocs or copies. + + Collapsing non-contiguous dimensions is undefined behavior. When a group of + dimensions can be statically proven to be non-contiguous, collapses of such + groups are rejected in the verifier on a best-effort basis. In the general + case, collapses of dynamically-sized dims with dynamic strides cannot be + proven to be contiguous or non-contiguous due to limitations in the memref + type. A reassociation is defined as a continuous grouping of dimensions and is represented with an array of I64ArrayAttr attribute. - For now, it is assumed that either: - 1. a reassociation produces and consumes contiguous MemRefType or, - 2. the reshape op will be folded into its consumers (by changing the shape - of the computations). - All other cases are undefined behavior and a reshape op may not lower to - LLVM if it cannot be proven statically that it does not require alloc+copy. + Note: Only the dimensions within a reassociation group must be contiguous. + The remaining dimensions may be non-contiguous. 
- The result memref type of a reshape can be zero-ranked if the operand - memref type is statically shaped with all dimensions being unit extent. In - such case the reassociation map is empty. - - The verification rule is that the reassociation maps are applied to the - operand memref with the larger rank to obtain the result memref with the - smaller rank. + The result memref type can be zero-ranked if the source memref type is + statically shaped with all dimensions being unit extent. In such a case, the + reassociation indices must be empty. Examples: ```mlir // Dimension collapse (i, j) -> i' and k -> k' %1 = memref.collapse_shape %0 [[0, 1], [2]] : - memref into memref + memref into memref ``` + + For simplicity, this op may not be used to cast dynamicity of dimension + sizes and/or strides. I.e., a result dimension must be dynamic if and only + if at least one dimension in the corresponding reassociation group is + dynamic. Similarly, the stride of a result dimension must be dynamic if and + only if the corresponding start dimension in the source type is dynamic. + + Note: This op currently assumes that the inner strides of the + source/result layout map are the faster-varying ones. }]; let builders = [ // Builders for a contracting reshape whose result type is computed from diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1701,105 +1701,6 @@ return true; } -/// Compute the MemRefType obtained by applying the `reassociation` (which is -/// expected to be valid) to `type`. -/// If `type` is Contiguous MemRefType, this always produce a contiguous -/// MemRefType.
-static MemRefType -computeReshapeCollapsedType(MemRefType type, - ArrayRef reassociation) { - auto sizes = type.getShape(); - AffineExpr offset; - SmallVector strides; - auto status = getStridesAndOffset(type, strides, offset); - auto isIdentityLayout = type.getLayout().isIdentity(); - (void)status; - assert(succeeded(status) && "expected strided memref"); - - SmallVector newSizes; - newSizes.reserve(reassociation.size()); - SmallVector newStrides; - newStrides.reserve(reassociation.size()); - - // Use the fact that reassociation is valid to simplify the logic: only use - // each map's rank. - assert(isReassociationValid(reassociation) && "invalid reassociation"); - unsigned currentDim = 0; - for (AffineMap m : reassociation) { - unsigned dim = m.getNumResults(); - int64_t size = 1; - AffineExpr stride = strides[currentDim + dim - 1]; - if (isIdentityLayout || - isReshapableDimBand(currentDim, dim, sizes, strides)) { - for (unsigned d = 0; d < dim; ++d) { - int64_t currentSize = sizes[currentDim + d]; - if (ShapedType::isDynamic(currentSize)) { - size = ShapedType::kDynamicSize; - break; - } - size *= currentSize; - } - } else { - size = ShapedType::kDynamicSize; - stride = AffineExpr(); - } - newSizes.push_back(size); - newStrides.push_back(stride); - currentDim += dim; - } - - // Early-exit: if `type` is contiguous, the result must be contiguous. - if (canonicalizeStridedLayout(type).getLayout().isIdentity()) - return MemRefType::Builder(type).setShape(newSizes).setLayout({}); - - // Convert back to int64_t because we don't have enough information to create - // new strided layouts from AffineExpr only. This corresponds to a case where - // copies may be necessary. 
- int64_t intOffset = ShapedType::kDynamicStrideOrOffset; - if (auto o = offset.dyn_cast()) - intOffset = o.getValue(); - SmallVector intStrides; - intStrides.reserve(strides.size()); - for (auto stride : newStrides) { - if (auto cst = stride.dyn_cast_or_null()) - intStrides.push_back(cst.getValue()); - else - intStrides.push_back(ShapedType::kDynamicStrideOrOffset); - } - auto layout = - makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext()); - return canonicalizeStridedLayout( - MemRefType::Builder(type).setShape(newSizes).setLayout( - AffineMapAttr::get(layout))); -} - -void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, - ArrayRef reassociation, - ArrayRef attrs) { - auto memRefType = src.getType().cast(); - auto resultType = computeReshapeCollapsedType( - memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( - b.getContext(), reassociation))); - build(b, result, resultType, src, attrs); - result.addAttribute(getReassociationAttrName(), - getReassociationIndicesAttribute(b, reassociation)); -} - -template ::value> -static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType, - MemRefType collapsedType) { - if (failed( - verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) - return failure(); - auto maps = op.getReassociationMaps(); - MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps); - if (collapsedType != expectedType) - return op.emitOpError("expected collapsed type to be ") - << expectedType << ", but got " << collapsedType; - return success(); -} - /// Compute the layout map after expanding a given source MemRef type with the /// specified reassociation indices. static FailureOr @@ -1925,8 +1826,174 @@ CollapseMixedReshapeOps>(context); } +/// Compute the layout map after collapsing a given source MemRef type with the +/// specified reassociation indices. +/// +/// Note: All collapsed dims in a reassociation group must be contiguous.
It is +/// not possible to check this by inspecting a MemRefType in the general case. +/// But it is assumed. If this is not the case, the behavior is undefined. +static FailureOr +computeCollapsedLayoutMap(MemRefType srcType, ArrayRef resultShape, + ArrayRef reassociation) { + SmallVector srcStrides, resultStrides; + int64_t srcOffset; + if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) + return failure(); + assert(resultShape.size() == reassociation.size() && "invalid reassociation"); + + // Iterate over all reassociation groups from the back. Example: + // source shape = [20, ?, 5, 10, 2] + // source strides = [ ?, ?, 800, 80, 4] + // reassociation = [[0, 1], [2, 3], [4]] + // result shape = [ ?, 50, 2] + // + // Note: The result shape is not needed in this computation. It is just used + // to check that the size of the reassociation is correct. + for (ReassociationIndices group : llvm::reverse(reassociation)) { + // A result dim has the same stride as the last dimension (least + // significant one) in the corresponding reassociation group. E.g.: + // reassociation = [[0, 1], [2, 3], [4]] + // | | | + // v v v + // ? 80 4 + int64_t resultStride = srcStrides[group.pop_back_val()]; + + // The following is just a best-effort check for non-contiguous source + // strides within a reassociation group. E.g.: + // reassociation = [[0, 1], [2, 3], [4]] + // ^^^^^^ + // Iteratively compute the next stride within the reassociation group + // one-by-one. Start with the stride computed above. E.g.: + // reassociation = [[0, 1], [2, 3], [4]] + // | + // v + // nextStride = 80 + int64_t nextStride = resultStride; + for (int64_t nextDim : llvm::reverse(group)) { + // Next expected stride is previous stride multiplied by dim size, e.g.: + // reassociation = [[0, 1], [2, 3], [4]] + // | + // v + // nextStride = 80 * 10 = 800 + nextStride = + computeNextStride(nextStride, srcType.getDimSize(nextDim + 1)); + + // Ensure that the source actually has this stride value.
E.g.: + // source strides = [ ?, ?, 800, 80, 4] + // | + // v + // same stride, OK + // If strides are dynamic, we cannot verify anything statically. + if (!ShapedType::isDynamicStrideOrOffset(srcStrides[nextDim]) && + !ShapedType::isDynamicStrideOrOffset(nextStride) && + srcStrides[nextDim] != nextStride) { + // Attempting to collapse non-contiguous dimensions. This is forbidden. + // Note: This check does not handle cases where strides and dimension + // sizes are dynamic. Such dims could still turn out to be non- + // contiguous at runtime. This check is only a best effort to catch + // illegal collapses at verification time. + return failure(); + } + } + + resultStrides.push_back(resultStride); + } + + return makeStridedLinearLayoutMap( + llvm::to_vector<8>(llvm::reverse(resultStrides)), srcOffset, + srcType.getContext()); +} + +static MemRefType +computeCollapsedType(MemRefType srcType, + ArrayRef reassociation) { + SmallVector resultShape; + for (const ReassociationIndices &group : reassociation) { + int64_t groupSize = 1; + for (int64_t srcDim : group) { + if (srcType.isDynamicDim(srcDim)) { + // Source dim is dynamic, so the collapsed dim is also dynamic. + groupSize = ShapedType::kDynamicSize; + break; + } + + groupSize *= srcType.getDimSize(srcDim); + } + + resultShape.push_back(groupSize); + } + + if (srcType.getLayout().isIdentity()) { + // If the source is contiguous (i.e., no layout map specified), so is the + // result. + MemRefLayoutAttrInterface layout; + return MemRefType::get(resultShape, srcType.getElementType(), layout, + srcType.getMemorySpace()); + } + + // Source may not be fully contiguous. Compute the layout map. + // Note: Dimensions that are collapsed into a single dim are assumed to be + // contiguous.
+ FailureOr computedLayout = + computeCollapsedLayoutMap(srcType, resultShape, reassociation); + assert(succeeded(computedLayout) && + "invalid source layout map or collapsing non-contiguous dims"); + auto computedType = + MemRefType::get(resultShape, srcType.getElementType(), *computedLayout, + srcType.getMemorySpace()); + return canonicalizeStridedLayout(computedType); +} + +void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src, + ArrayRef reassociation, + ArrayRef attrs) { + auto srcType = src.getType().cast(); + MemRefType resultType = computeCollapsedType(srcType, reassociation); + build(b, result, resultType, src, attrs); + result.addAttribute(getReassociationAttrName(), + getReassociationIndicesAttribute(b, reassociation)); +} + LogicalResult CollapseShapeOp::verify() { - return verifyReshapeOp(*this, getSrcType(), getResultType()); + MemRefType srcType = getSrcType(); + MemRefType resultType = getResultType(); + + // Verify result shape. + if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(), + srcType.getShape(), getReassociationIndices(), + /*allowMultipleDynamicDimsPerGroup=*/true))) + return failure(); + + // Compute expected result type (including layout map). + MemRefType expectedResultType; + if (srcType.getLayout().isIdentity()) { + // If the source is contiguous (i.e., no layout map specified), so is the + // result. + MemRefLayoutAttrInterface layout; + expectedResultType = + MemRefType::get(resultType.getShape(), srcType.getElementType(), layout, + srcType.getMemorySpace()); + } else { + // Source may not be fully contiguous. Compute the layout map. + // Note: Dimensions that are collapsed into a single dim are assumed to be + // contiguous.
+ FailureOr computedLayout = computeCollapsedLayoutMap( + srcType, resultType.getShape(), getReassociationIndices()); + if (failed(computedLayout)) + return emitOpError( + "invalid source layout map or collapsing non-contiguous dims"); + auto computedType = + MemRefType::get(resultType.getShape(), srcType.getElementType(), + *computedLayout, srcType.getMemorySpace()); + expectedResultType = canonicalizeStridedLayout(computedType); + } + + auto canonicalizedResultType = canonicalizeStridedLayout(resultType); + if (expectedResultType != canonicalizedResultType) + return emitOpError("expected collapsed type to be ") + << expectedResultType << " but found " << canonicalizedResultType; + + return success(); } struct CollapseShapeOpMemRefCastFolder @@ -1943,9 +2010,9 @@ if (!CastOp::canFoldIntoConsumerOp(cast)) return failure(); - Type newResultType = computeReshapeCollapsedType( - cast.getOperand().getType().cast(), - op.getReassociationMaps()); + Type newResultType = + computeCollapsedType(cast.getOperand().getType().cast(), + op.getReassociationIndices()); if (newResultType == op.getResultType()) { rewriter.updateRootInPlace( diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -409,7 +409,7 @@ // ----- func @collapse_shape_to_higher_rank(%arg0: memref) { - // expected-error @+1 {{expected the type 'memref' to have higher rank than the type = 'memref<1xf32>'}} + // expected-error @+1 {{op reassociation index 0 is out of bounds}} %0 = memref.collapse_shape %arg0 [[0]] : memref into memref<1xf32> } @@ -432,15 +432,8 @@ // ----- -func @collapse_shape(%arg0: memref) { - // expected-error @+1 {{expected to collapse or expand dims}} - %0 = memref.collapse_shape %arg0 [[0]] : memref into memref -} - -// ----- - func @collapse_shape_mismatch_indices_num(%arg0: memref) { - // expected-error @+1 {{expected rank of the collapsed type(2) to be the
number of reassociation maps(1)}} + // expected-error @+1 {{invalid number of reassociation groups: found 1, expected 2}} %0 = memref.collapse_shape %arg0 [[0, 1]] : memref into memref } @@ -448,15 +441,26 @@ // ----- func @collapse_shape_invalid_reassociation(%arg0: memref) { - // expected-error @+1 {{expected reassociation map #1 to be valid and contiguous}} + // expected-error @+1 {{reassociation indices must be contiguous}} %0 = memref.collapse_shape %arg0 [[0, 1], [1, 2]] : memref into memref } // ----- +func @collapse_shape_reshaping_non_contiguous( + %arg0: memref<3x4x5xf32, offset: 0, strides: [270, 50, 10]>) { + // expected-error @+1 {{invalid source layout map or collapsing non-contiguous dims}} + %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : + memref<3x4x5xf32, offset: 0, strides: [270, 50, 10]> + into memref<12x5xf32, offset: 0, strides: [50, 1]> + return +} + +// ----- + func @collapse_shape_wrong_collapsed_type(%arg0: memref) { - // expected-error @+1 {{expected collapsed type to be 'memref', but got 'memref (d0 * s0 + d1)>>'}} + // expected-error @+1 {{expected collapsed type to be 'memref' but found 'memref (d0 * s0 + d1)>>'}} %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref into memref (d0 * s0 + d1)>> } @@ -485,7 +489,7 @@ func @collapse_shape_illegal_static_memref (%arg0: memref<2x3x2x4x5xf32>) -> memref<2x3x20xf32> { - // expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}} + // expected-error @+1 {{collapsed dim size (20) must equal reassociation group size (40)}} %0 = memref.collapse_shape %arg0 [[0], [1], [2, 3, 4]] : memref<2x3x2x4x5xf32> into memref<2x3x20xf32> return %0 : memref<2x3x20xf32> @@ -537,7 +541,7 @@ func @collapse_shape_illegal_mixed_memref(%arg0 : memref) -> memref { - // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}} + // expected-error @+1 {{collapsed dim (1) must be dynamic if and only if reassociation group is dynamic}} %0 = 
memref.collapse_shape %arg0 [[0, 1], [2]] : memref into memref return %0 : memref @@ -547,7 +551,7 @@ func @collapse_shape_illegal_mixed_memref_2(%arg0 : memref) -> memref { - // expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}} + // expected-error @+1 {{collapsed dim (1) must be dynamic if and only if reassociation group is dynamic}} %0 = memref.collapse_shape %arg0 [[0], [1, 2]] : memref into memref return %0 : memref diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -111,7 +111,9 @@ %arg2: tensor<3x?x5xf32>, %arg3: memref<30x20xf32, offset : 100, strides : [4000, 2]>, %arg4: memref<1x5xf32, affine_map<(d0, d1)[s0] -> (d0 * 5 + s0 + d1)>>, - %arg5: memref) { + %arg5: memref, + %arg6: memref<3x4x5xf32, offset: 0, strides: [240, 60, 10]>, + %arg7: memref<1x2049xi64, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>) { // Reshapes that collapse and expand back a contiguous buffer. // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] // CHECK-SAME: memref<3x4x5xf32> into memref<12x5xf32> @@ -164,6 +166,17 @@ memref<1x5xf32, affine_map<(d0, d1)[s0] -> (d0 * 5 + s0 + d1)>> into memref<1x1x5xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 5 + s0 + d2 + d1 * 5)>> + // Note: Only the two collapsed dimensions are contiguous in the following test case. +// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]] + %r6 = memref.collapse_shape %arg6 [[0, 1], [2]] : + memref<3x4x5xf32, offset: 0, strides: [240, 60, 10]> into + memref<12x5xf32, offset: 0, strides: [60, 10]> + +// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1]] + %r7 = memref.collapse_shape %arg7 [[0, 1]] : + memref<1x2049xi64, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> into + memref<2049xi64, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>> + // Reshapes that expand and collapse back a contiguous buffer with some 1's.
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] // CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>