diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1165,40 +1165,6 @@ [NoSideEffect, ViewLikeOpInterface])>, Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>, Results<(outs AnyStridedMemRef:$result)>{ - let builders = [ - // Builders for a contracting reshape whose result type is computed from - // `src` and `reassociation`. - OpBuilder<(ins "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs)>, - OpBuilder<(ins "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), - [{ - auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); - build($_builder, $_state, src, reassociationMaps, attrs); - }]>, - - // Builders for a reshape whose result type is passed explicitly. This may - // be either a contracting or expanding reshape. - OpBuilder<(ins "Type":$resultType, "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), - [{ - build($_builder, $_state, resultType, src, attrs); - $_state.addAttribute("reassociation", - getReassociationIndicesAttribute($_builder, reassociation)); - }]>, - OpBuilder<(ins "Type":$resultType, "Value":$src, - "ArrayRef":$reassociation, - CArg<"ArrayRef", "{}">:$attrs), - [{ - auto reassociationMaps = - convertReassociationMapsToIndices($_builder, reassociation); - build($_builder, $_state, resultType, src, reassociationMaps, attrs); - }]> - ]; code commonExtraClassDeclaration = [{ SmallVector getReassociationMaps(); @@ -1257,6 +1223,42 @@ memref into memref ``` }]; + + let builders = [ + // Builders for an expanding reshape whose result type is computed from + // `resultShape`, `src` and `reassociation`.
+ OpBuilder<(ins "ArrayRef":$resultShape, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs)>, + OpBuilder<(ins "ArrayRef":$resultShape, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + auto reassociationMaps = + convertReassociationMapsToIndices($_builder, reassociation); + build($_builder, $_state, resultShape, src, reassociationMaps, attrs); + }]>, + + // Builders for an expanding reshape whose result type is passed + // explicitly. + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + build($_builder, $_state, resultType, src, attrs); + $_state.addAttribute("reassociation", + getReassociationIndicesAttribute($_builder, reassociation)); + }]>, + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + auto reassociationMaps = + convertReassociationMapsToIndices($_builder, reassociation); + build($_builder, $_state, resultType, src, reassociationMaps, attrs); + }]> + ]; + let extraClassDeclaration = commonExtraClassDeclaration; } @@ -1294,6 +1296,42 @@ memref into memref ``` }]; + + let builders = [ + // Builders for a contracting reshape whose result type is computed from + // `src` and `reassociation`. + OpBuilder<(ins "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs)>, + OpBuilder<(ins "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + auto reassociationMaps = + convertReassociationMapsToIndices($_builder, reassociation); + build($_builder, $_state, src, reassociationMaps, attrs); + }]>, + + // Builders for a contracting reshape whose result type is passed + // explicitly.
+ OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + build($_builder, $_state, resultType, src, attrs); + $_state.addAttribute("reassociation", + getReassociationIndicesAttribute($_builder, reassociation)); + }]>, + OpBuilder<(ins "Type":$resultType, "Value":$src, + "ArrayRef":$reassociation, + CArg<"ArrayRef", "{}">:$attrs), + [{ + auto reassociationMaps = + convertReassociationMapsToIndices($_builder, reassociation); + build($_builder, $_state, resultType, src, reassociationMaps, attrs); + }]> + ]; + let extraClassDeclaration = commonExtraClassDeclaration; } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1488,7 +1488,78 @@ /// Compute the MemRefType obtained by applying the `reassociation` (which is /// expected to be valid) to `type`. -/// If `type` is Contiguous MemRefType, this always produce a contiguous +/// If `type` is a contiguous MemRefType, this always produces a contiguous +/// MemRefType. The result always has shape `resultShape`. +static MemRefType computeReshapeExpandedType(MemRefType type, + ArrayRef reassociation, + ArrayRef resultShape) { + // Early-exit: if `type` is contiguous, the result must be contiguous. + if (canonicalizeStridedLayout(type).getLayout().isIdentity()) + return MemRefType::Builder(type).setShape(resultShape).setLayout({}); + + AffineExpr offset; + SmallVector strides; + auto status = getStridesAndOffset(type, strides, offset); + (void)status; + assert(succeeded(status) && "expected strided memref"); + + SmallVector newStrides(resultShape.size(), AffineExpr()); + // Null entries mark strides that cannot be inferred; they become dynamic. + + // Use the fact that reassociation is valid to simplify the logic: only use + // each map's rank.
+ assert(isReassociationValid(reassociation) && "invalid reassociation"); + int64_t resultDim = resultShape.size() - 1; + int64_t sourceDim = strides.size() - 1; + for (AffineMap m : llvm::reverse(reassociation)) { + newStrides[resultDim] = strides[sourceDim]; + int64_t numDynamicDims = static_cast(resultShape[resultDim] == + ShapedType::kDynamicSize); + --resultDim; + + // If this loop has iterations, this source dim is expanded. + for (int64_t r = static_cast(m.getNumResults()) - 2; r >= 0; --r) { + numDynamicDims += static_cast(resultShape[resultDim] == + ShapedType::kDynamicSize); + assert(numDynamicDims < 2 && + "at most one result dim of an expansion may be dynamic"); + if (resultShape[resultDim + 1] != ShapedType::kDynamicSize && + newStrides[resultDim + 1]) + newStrides[resultDim] = + newStrides[resultDim + 1] * resultShape[resultDim + 1]; + --resultDim; + } + + --sourceDim; + } + + assert(resultDim == -1 && "did not process all result dims"); + assert(sourceDim == -1 && "did not process all source dims"); + + // Convert back to int64_t because we don't have enough information to create + // new strided layouts from AffineExpr only. This corresponds to a case where + // copies may be necessary. + int64_t intOffset = ShapedType::kDynamicStrideOrOffset; + if (auto o = offset.dyn_cast()) + intOffset = o.getValue(); + SmallVector intStrides; + intStrides.reserve(newStrides.size()); + for (auto stride : newStrides) { + if (auto cst = stride.dyn_cast_or_null()) + intStrides.push_back(cst.getValue()); + else + intStrides.push_back(ShapedType::kDynamicStrideOrOffset); + } + auto layout = + makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext()); + return canonicalizeStridedLayout(MemRefType::Builder(type) + .setShape(resultShape) + .setLayout(AffineMapAttr::get(layout))); +} + +/// Compute the MemRefType obtained by applying the `reassociation` (which is +/// expected to be valid) to `type`.
+/// If `type` is a contiguous MemRefType, this always produces a contiguous /// MemRefType. static MemRefType computeReshapeCollapsedType(MemRefType type, @@ -1550,13 +1621,16 @@ AffineMapAttr::get(layout))); } -void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src, +void ExpandShapeOp::build(OpBuilder &b, OperationState &result, + ArrayRef resultShape, Value src, ArrayRef reassociation, ArrayRef attrs) { auto memRefType = src.getType().cast(); - auto resultType = computeReshapeCollapsedType( - memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs( - b.getContext(), reassociation))); + auto resultType = computeReshapeExpandedType( + memRefType, + getSymbolLessAffineMaps( + convertReassociationIndicesToExprs(b.getContext(), reassociation)), + resultShape); build(b, result, resultType, src, attrs); result.addAttribute(getReassociationAttrName(), getReassociationIndicesAttribute(b, reassociation)); @@ -1574,25 +1648,19 @@ getReassociationIndicesAttribute(b, reassociation)); } -template ::value> -static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType, - MemRefType collapsedType) { - if (failed( - verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) +static LogicalResult verify(ExpandShapeOp op) { + if (failed(verifyReshapeLikeTypes(op, op.getResultType(), op.getSrcType(), + /*isExpansion=*/true))) return failure(); auto maps = op.getReassociationMaps(); - MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps); - if (collapsedType != expectedType) - return op.emitOpError("expected collapsed type to be ") - << expectedType << ", but got " << collapsedType; + MemRefType expectedType = computeReshapeExpandedType( + op.getSrcType(), maps, op.getResultType().getShape()); + if (op.getResultType() != expectedType) + return op.emitOpError("expected result type to be ") + << expectedType << ", but got " << op.getResultType(); return success(); } -static LogicalResult
verify(ExpandShapeOp op) { - return verifyReshapeOp(op, op.getResultType(), op.getSrcType()); -} - void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add, @@ -1600,7 +1668,15 @@ } static LogicalResult verify(CollapseShapeOp op) { - return verifyReshapeOp(op, op.getSrcType(), op.getResultType()); + if (failed(verifyReshapeLikeTypes(op, op.getSrcType(), op.getResultType(), + /*isExpansion=*/false))) + return failure(); + auto maps = op.getReassociationMaps(); + MemRefType expectedType = computeReshapeCollapsedType(op.getSrcType(), maps); + if (op.getResultType() != expectedType) + return op.emitOpError("expected result type to be ") + << expectedType << ", but got " << op.getResultType(); + return success(); } struct CollapseShapeOpMemRefCastFolder diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -411,7 +411,7 @@ // ----- func @collapse_shape_wrong_collapsed_type(%arg0: memref) { - // expected-error @+1 {{expected collapsed type to be 'memref', but got 'memref (d0 * s0 + d1)>>'}} + // expected-error @+1 {{expected result type to be 'memref', but got 'memref (d0 * s0 + d1)>>'}} %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref into memref (d0 * s0 + d1)>> }