diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -12,6 +12,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -52,16 +53,6 @@ /// provide an op-specified hook so that Linalg ops may override the behavior. LoopRangeBuilder defaultLoopRangesBuilder(LinalgOp op); -using ReassociationIndices = SmallVector; -using ReassociationIndicesRef = ArrayRef; -using ReassociationExprs = SmallVector; - -/// Return the reassociations maps to use to reshape given the source type and -/// the target type when possible. Return llvm::None when this computation -/// failed. -Optional> -getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType); - /// Returns the name mangled library call name to disambiguate between different /// overloads at the C level. The name mangling scheme is basic and uses MLIR /// type names: diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -0,0 +1,266 @@ +//===- ReshapeOpsUtils.h - Utilities used by reshape ops ---*- C++ -*------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines utilities and common canonicalization patterns for +// reshape operations. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H +#define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H + +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { + +using ReassociationIndices = SmallVector; +using ReassociationIndicesRef = ArrayRef; +using ReassociationExprs = SmallVector; + +/// Attribute name for the ArrayAttr which encodes reassociation indices. +/// NOTE(review): constexpr implies inline, so define it here, not in a .cpp. +constexpr StringRef getReassociationAttrName(); + +/// Collapse reassociation maps that are used in a pair of reshape ops where one +/// is a producer and the other is the consumer. Only valid to use this method +/// when both are collapsing dimensions or both are expanding dimensions. +/// +/// For example, +/// mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>, +/// affine_map<(d0, d1, d2, d3, d4) -> (d2)>, +/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] +/// mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>, +/// affine_map<(d0, d1, d2) -> (d2)>] +/// +/// is folded into +/// +/// result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, +/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] +/// TODO: Use reassociation indices instead of affine maps here. +Optional> +collapseReassociationIndices(ArrayRef mapsProducer, + ArrayRef mapsConsumer, + MLIRContext *context); + +/// Return the reassociation indices for reshaping between the source type and +/// the target type when possible. Return llvm::None when this computation +/// failed. +Optional> +getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType); + +/// Return true if the reassociation specification is valid, false otherwise. +/// When false, the `invalidIndex` integer pointer is optionally filled with the +/// index of the offending reassociation map. 
+bool isReassociationValid(ArrayRef<AffineMap> reassociation,
+                          int *invalidIndex = nullptr);
+
+/// Parse a reshape-like op, i.e. linalg::(Tensor)ExpandShapeOp,
+/// linalg::(Tensor)CollapseShapeOp.
+ParseResult parseReshapeLikeOp(OpAsmParser &parser, OperationState &result);
+
+/// Print a reshape-like op, i.e. linalg::(Tensor)ExpandShapeOp,
+/// linalg::(Tensor)CollapseShapeOp.
+template <typename ReshapeLikeOp>
+void printReshapeOp(OpAsmPrinter &p, ReshapeLikeOp op) {
+  p << op.getOperationName() << ' ' << op.src() << " [";
+
+  llvm::interleaveComma(op.reassociation(), p, [&](const Attribute &attr) {
+    p << '[';
+    auto arrayAttr = attr.template cast<ArrayAttr>();
+    llvm::interleaveComma(arrayAttr, p, [&](const Attribute &attr) {
+      p << attr.cast<IntegerAttr>().getInt();
+    });
+    p << ']';
+  });
+
+  p << "] ";
+  p.printOptionalAttrDict(op->getAttrs(),
+                          /*elidedAttrs=*/{op.getReassociationAttrName()});
+  p << ": " << op.src().getType() << " into " << op.getType();
+}
+
+template <typename ReshapeOpTy, typename InverseReshapeOpTy>
+static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
+                                  ArrayRef<Attribute> operands) {
+  // Fold producer-consumer reshape ops where the operand type of the producer
+  // is the same as the return type of the consumer.
+  auto reshapeSrcOp =
+      reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
+  if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
+    return reshapeSrcOp.src();
+  // Reshape of a constant can be replaced with a new constant.
+  if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
+    return elements.reshape(
+        reshapeOp.getResult().getType().template cast<ShapedType>());
+  }
+  return nullptr;
+}
+
+/// Verify that the shapes of the reshaped types satisfy the following rules:
+/// 1) if a dimension in the collapsed type is static, then the corresponding
+///    dimensions in the expanded shape should be
+///    a) static
+///    b) such that their product equals the collapsed dimension size.
+/// 2) if a dimension in the collapsed type is dynamic, at least one of the
+///    corresponding dimensions in the expanded type should be dynamic; for
+///    expanding reshapes, exactly one of them must be dynamic.
+template <typename OpTy>
+static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
+                                             ShapedType expandedType,
+                                             bool isExpandingReshape) {
+  ArrayRef<int64_t> collapsedShape = collapsedType.getShape();
+  ArrayRef<int64_t> expandedShape = expandedType.getShape();
+  unsigned expandedDimStart = 0;
+  for (auto map : llvm::enumerate(op.getReassociationMaps())) {
+    Optional<int64_t> dynamicShape;
+    int64_t linearizedStaticShape = 1;
+    for (auto dim : llvm::enumerate(expandedShape.slice(
+             expandedDimStart, map.value().getNumResults()))) {
+      if (ShapedType::isDynamic(dim.value())) {
+        if (isExpandingReshape && dynamicShape) {
+          return op->emitOpError("invalid to have a single dimension (")
+                 << map.index() << ") expanded into multiple dynamic dims ("
+                 << expandedDimStart + dynamicShape.getValue() << ","
+                 << expandedDimStart + dim.index() << ")";
+        }
+        dynamicShape = dim.index();
+      } else {
+        linearizedStaticShape *= dim.value();
+      }
+    }
+    if (dynamicShape) {
+      if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
+        return op->emitOpError("expected dimension ")
+               << map.index()
+               << " of collapsed type to be dynamic since one or more of the "
+                  "corresponding dimensions in the expanded type is dynamic";
+      }
+    } else {
+      if (collapsedShape[map.index()] != linearizedStaticShape) {
+        return op->emitOpError("expected dimension ")
+               << map.index() << " of collapsed type to be static value of "
+               << linearizedStaticShape << " ";
+      }
+    }
+    expandedDimStart += map.value().getNumResults();
+  }
+  return success();
+}
+
+/// Common verifier for reshape-like types. `expandedType` is the higher-rank
+/// and `collapsedType` the lower-rank of the op's `src`/`result` types.
+template <typename Op, typename T>
+static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
+                                            T collapsedType, bool isExpansion) {
+  unsigned expandedRank = expandedType.getRank();
+  unsigned collapsedRank = collapsedType.getRank();
+  if (expandedRank < collapsedRank)
+    return op.emitOpError("expected the type ")
+           << expandedType
+           << " to have higher rank than the type = " << collapsedType;
+  if (expandedRank == 0)
+    return op.emitOpError("expected non-zero memref ranks");
+  if (expandedRank == collapsedRank)
+    return op.emitOpError("expected to collapse or expand dims");
+
+  if (collapsedRank == 0) {
+    // If collapsed rank is 0, then expanded type must be static shaped and of
+    // sizes 1.
+    if (llvm::any_of(expandedType.getShape(),
+                     [](int64_t dim) -> bool { return dim != 1; }))
+      return op.emitOpError("invalid to reshape tensor/memref with non-unit "
+                            "extent dimensions to zero-rank tensor/memref");
+    return success();
+  }
+  if (collapsedRank != op.reassociation().size())
+    return op.emitOpError("expected rank of the collapsed type(")
+           << collapsedRank << ") to be the number of reassociation maps("
+           << op.reassociation().size() << ")";
+  auto maps = op.getReassociationMaps();
+  for (auto it : llvm::enumerate(maps))
+    if (it.value().getNumDims() != expandedRank)
+      return op.emitOpError("expected reassociation map #")
+             << it.index() << " of same rank as expanded memref("
+             << expandedRank << "), but got " << it.value().getNumDims();
+  int invalidIdx = 0;
+  if (!isReassociationValid(maps, &invalidIdx))
+    return op.emitOpError("expected reassociation map #")
+           << invalidIdx << " to be valid and contiguous";
+  return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion);
+}
+
+/// Pattern to collapse producer/consumer reshape ops that are both collapsing
+/// dimensions or are both expanding dimensions.
+template +struct CollapseReshapeOps : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, + PatternRewriter &rewriter) const override { + auto srcReshapeOp = reshapeOp.src().template getDefiningOp(); + if (!srcReshapeOp) + return failure(); + + ShapedType resultType = reshapeOp.getResultType(); + Optional> reassociationIndices = + collapseReassociationIndices(srcReshapeOp.getReassociationMaps(), + reshapeOp.getReassociationMaps(), + rewriter.getContext()); + if (!reassociationIndices) + return failure(); + rewriter.replaceOpWithNewOp( + reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices); + return success(); + } +}; + +/// Pattern to collapse producer/consumer reshape ops that are both collapsing +/// dimensions or are both expanding dimensions. +template +struct CollapseMixedReshapeOps : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, + PatternRewriter &rewriter) const override { + auto srcReshapeOp = + reshapeOp.src().template getDefiningOp(); + if (!srcReshapeOp) + return failure(); + + ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType(); + ShapedType intermediateType = reshapeOp.getSrcType(); + ShapedType resultType = reshapeOp.getResultType(); + + // If the source reshape can be collapsed/expanded into the target reshape + // they can still be folded. This can only be reasoned about statically + // for cases where + // - either all shapes are static, or + // - The number of dynamic dimensions matches in the source of source and + // result with all other dimensions being 1. 
+ Optional> reassociationIndices = + getReassociationIndicesForReshape(srcReshapeSrcType, resultType); + if (!reassociationIndices) + return failure(); + bool originalOpExpands = + intermediateType.getRank() > srcReshapeSrcType.getRank(); + bool resultingOpExpands = + resultType.getRank() > srcReshapeSrcType.getRank(); + if (!(resultingOpExpands ^ originalOpExpands)) + rewriter.replaceOpWithNewOp( + reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices); + else + rewriter.replaceOpWithNewOp( + reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices); + return success(); + } +}; + +} // namespace mlir + +#endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H diff --git a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt --- a/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt +++ b/mlir/lib/Conversion/TosaToLinalg/CMakeLists.txt @@ -10,6 +10,7 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC + MLIRDialectUtils MLIRIR MLIRLinalg MLIRLinalgUtils diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -1120,8 +1121,7 @@ (operandTy.getRank() > resultTy.getRank() ? resultTy.getShape() : operandTy.getShape()); unsigned currSrcDim = 0, currDstDim = 0; - SmallVector reassociationMap( - collapsedShape.size()); + SmallVector reassociationMap(collapsedShape.size()); // First scan all dimensions in the source shapes to see whether we have a // perfect case where consecutive dimensions in source are collapsed. 
For @@ -1176,11 +1176,11 @@ std::accumulate(expandedShape.begin(), expandedShape.end(), 1, std::multiplies()); auto elemTy = operandTy.getElementType(); - SmallVector collapsingMap = { + SmallVector collapsingMap = { // Use operandTy here because we need to collapse all operands // dimensions. getIdentityExprs(operandTy.getShape().size())}; - SmallVector expandingMap = { + SmallVector expandingMap = { // Use resultTy here because we need to expand to all result // dimensions. getIdentityExprs(resultTy.getShape().size())}; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1069,338 +1069,20 @@ // ReshapeOp //===----------------------------------------------------------------------===// -Optional> -mlir::linalg::getReassociationIndicesForReshape(ShapedType sourceType, - ShapedType targetType) { - // Make the sourceType greater rank than the targetType. If they are same - // rank, then its an unsupported reshape op. - if (sourceType.getRank() == targetType.getRank()) - return llvm::None; - if (sourceType.getRank() < targetType.getRank()) - std::swap(sourceType, targetType); - - ArrayRef sourceShape = sourceType.getShape(); - ArrayRef targetShape = targetType.getShape(); - unsigned sourceDim = 0; - SmallVector reassociationMap; - reassociationMap.reserve(targetType.getRank()); - - ReassociationIndices currIndices; - int64_t prodOfCollapsedDims = 1; - while (sourceDim < sourceShape.size()) { - unsigned targetDim = reassociationMap.size(); - - // If all the dimensions of the targetShape are exhausted, then the - // remaining dims in the source shape must be all 1s. So for such cases, set - // 1 as the target shape. The actual reassociation indices will be handled - // later. - int64_t currTargetShape = - (targetDim < targetType.getRank() ? 
targetShape[targetDim] : 1); - while (sourceShape[sourceDim] != ShapedType::kDynamicSize && - prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape && - sourceDim < sourceShape.size()) { - prodOfCollapsedDims *= sourceShape[sourceDim]; - currIndices.push_back(sourceDim++); - } - - // If the current expanded dimension is dynamic, then the collapsed - // dimensions should also be dynamic and product of all previous unprocessed - // dimensions of the expanded shape should be 1. - if (sourceShape[sourceDim] == ShapedType::kDynamicSize && - (currTargetShape != ShapedType::kDynamicSize || - prodOfCollapsedDims != 1)) - return llvm::None; - - // If the collapsed dim is dynamic, the current expanded dim should also - // be dynamic. - if (currTargetShape == ShapedType::kDynamicSize && - sourceShape[sourceDim] != ShapedType::kDynamicSize) - return llvm::None; - - // For static shapes, if the product of dimensions of the expanded shape - // should match the collapsed dimension shape. - if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape) - return llvm::None; - - currIndices.push_back(sourceDim++); - // If the reassociation is empty but the currIndices is not, this by - // definition is folding unit-dimensions with the result being scalar type. - // So only append the `currIndices` if reassociation map is not empty. - if (targetDim == targetShape.size()) { - if (!reassociationMap.empty() && !currIndices.empty()) - reassociationMap.back().append(currIndices.begin(), currIndices.end()); - // Break out of the loops. We should be done here. - break; - } - reassociationMap.emplace_back(ReassociationIndices{}); - std::swap(reassociationMap.back(), currIndices); - prodOfCollapsedDims = 1; - } - // All the dimensions in the two shapes must have been processed. 
- if (reassociationMap.size() != targetShape.size() || - sourceDim != sourceShape.size()) - return llvm::None; - return reassociationMap; -} - -template -static void print(OpAsmPrinter &p, ReshapeLikeOp op) { - p << op.getOperationName() << ' ' << op.src() << " ["; - - llvm::interleaveComma(op.reassociation(), p, [&](const Attribute &attr) { - p << '['; - auto arrayAttr = attr.template cast(); - llvm::interleaveComma(arrayAttr, p, [&](const Attribute &attr) { - p << attr.cast().getInt(); - }); - p << ']'; - }); - - p << "] "; - p.printOptionalAttrDict(op->getAttrs(), - /*elidedAttrs=*/{op.getReassociationAttrName()}); - p << ": " << op.src().getType() << " into " << op.getType(); -} - static void print(OpAsmPrinter &p, linalg::ExpandShapeOp op) { - print(p, op); + ::mlir::printReshapeOp(p, op); } static void print(OpAsmPrinter &p, linalg::CollapseShapeOp op) { - print(p, op); + ::mlir::printReshapeOp(p, op); } static void print(OpAsmPrinter &p, linalg::TensorExpandShapeOp op) { - print(p, op); + ::mlir::printReshapeOp(p, op); } static void print(OpAsmPrinter &p, linalg::TensorCollapseShapeOp op) { - print(p, op); -} - -static constexpr StringRef getReassociationAttrName() { - return "reassociation"; -} - -static ParseResult parseReshapeLikeOp(OpAsmParser &parser, - OperationState &result) { - // Parse the operand. - OpAsmParser::OperandType src; - if (parser.parseOperand(src)) - return failure(); - - // Parse reassociation indices. 
- Builder &b = parser.getBuilder(); - SmallVector reassociation; - if (parser.parseLSquare()) - return failure(); - - while (true) { - if (succeeded(parser.parseOptionalRSquare())) - break; - if (parser.parseLSquare()) - return failure(); - SmallVector indices; - while (true) { - int64_t index; - if (parser.parseInteger(index)) - return failure(); - indices.push_back(index); - - if (succeeded(parser.parseOptionalComma())) - continue; - if (failed(parser.parseRSquare())) - return failure(); - break; - } - reassociation.push_back(b.getI64ArrayAttr(indices)); - if (succeeded(parser.parseOptionalComma())) - continue; - if (failed(parser.parseRSquare())) - return failure(); - break; - } - - result.addAttribute(getReassociationAttrName(), - b.getArrayAttr(reassociation)); - - // Parse optional attributes. - parser.parseOptionalAttrDict(result.attributes); - - // Parse types. - Type srcType; - Type resultType; - if (parser.parseColon() || parser.parseType(srcType) || - parser.resolveOperand(src, srcType, result.operands) || - parser.parseKeyword("into") || parser.parseType(resultType)) - return failure(); - result.addTypes(resultType); - return success(); -} - -/// Collapse reassociation maps that are used in pair of reshape ops where one -/// is a producer and other is the consumer. Only valid to use this method when -/// both the producer and consumer are collapsing dimensions or both are -/// expanding dimensions. 
-/// -/// For example, -/// mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>, -/// affine_map<(d0, d1, d2, d3, d4) -> (d2)>, -/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] -/// mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>, -/// affine_map<(d0, d1, d2) -> (d2)>] -/// -/// is folded into -/// -/// result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, -/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] -static Optional> -collapseReassociationIndices(ArrayRef mapsProducer, - ArrayRef mapsConsumer, - MLIRContext *context) { - // Make the producer the larger sized vector. If they are of same size, the - // resulting reshape is not a supported reshape op. - if (mapsProducer.size() == mapsConsumer.size()) - return llvm::None; - if (mapsProducer.size() < mapsConsumer.size()) - std::swap(mapsProducer, mapsConsumer); - - // Handle the corner case of the result being a rank 0 shaped type. Return an - // empty reassociation. - if (mapsConsumer.empty()) - return SmallVector{}; - if (mapsProducer.size() != mapsConsumer[0].getNumDims()) - return llvm::None; - - unsigned currDim = 0; - SmallVector reassociationMaps; - for (AffineMap rhs : mapsConsumer) { - ReassociationIndices reassociations; - for (AffineExpr rhsExpr : rhs.getResults()) { - AffineDimExpr dimExpr = rhsExpr.cast(); - for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults(); - i < e; ++i) - reassociations.push_back(currDim++); - } - reassociationMaps.push_back(std::move(reassociations)); - } - return reassociationMaps; -} - -namespace { -/// Pattern to collapse producer/consumer reshape ops that are both collapsing -/// dimensions or are both expanding dimensions. 
-template -struct CollapseReshapeOps : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, - PatternRewriter &rewriter) const override { - auto srcReshapeOp = reshapeOp.src().template getDefiningOp(); - if (!srcReshapeOp) - return failure(); - - ShapedType resultType = reshapeOp.getResultType(); - Optional> reassociationIndices = - collapseReassociationIndices(srcReshapeOp.getReassociationMaps(), - reshapeOp.getReassociationMaps(), - rewriter.getContext()); - if (!reassociationIndices) - return failure(); - rewriter.replaceOpWithNewOp( - reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices); - return success(); - } -}; - -/// Pattern to collapse producer/consumer reshape ops that are both collapsing -/// dimensions or are both expanding dimensions. -template -struct CollapseMixedReshapeOps : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, - PatternRewriter &rewriter) const override { - auto srcReshapeOp = - reshapeOp.src().template getDefiningOp(); - if (!srcReshapeOp) - return failure(); - - ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType(); - ShapedType intermediateType = reshapeOp.getSrcType(); - ShapedType resultType = reshapeOp.getResultType(); - - // If the source reshape can be collapsed/expanded into the target reshape - // they can still be folded. This can only be reasoned about statically - // for cases where - // - either all shapes are static, or - // - The number of dynamic dimensions matches in the source of source and - // result with all other dimensions being 1. 
- Optional> reassociationIndices = - getReassociationIndicesForReshape(srcReshapeSrcType, resultType); - if (!reassociationIndices) - return failure(); - bool originalOpExpands = - intermediateType.getRank() > srcReshapeSrcType.getRank(); - bool resultingOpExpands = - resultType.getRank() > srcReshapeSrcType.getRank(); - if (!(resultingOpExpands ^ originalOpExpands)) - rewriter.replaceOpWithNewOp( - reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices); - else - rewriter.replaceOpWithNewOp( - reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices); - return success(); - } -}; -} // namespace - -template -static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, - ArrayRef operands) { - // Fold producer-consumer reshape ops that where the operand type of the - // producer is same as the return type of the consumer. - auto reshapeSrcOp = - reshapeOp.src().template getDefiningOp(); - if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType()) - return reshapeSrcOp.src(); - // Reshape of a constant can be replaced with a new constant. - if (auto elements = operands.front().dyn_cast_or_null()) { - return elements.reshape( - reshapeOp.getResult().getType().template cast()); - } - return nullptr; -} - -/// Return true if the reassociation specification is valid, false otherwise. -/// When false, the `invalidIndex` integer pointer is optionally filled with the -/// index of the offending reassociation map. 
-static bool isReassociationValid(ArrayRef reassociation, - int *invalidIndex = nullptr) { - if (reassociation.empty()) - return true; - unsigned nDims = reassociation[0].getNumDims(); - unsigned nextExpectedDim = 0; - for (auto it : llvm::enumerate(reassociation)) { - auto m = it.value(); - if (m.getNumDims() != nDims || m.getNumSymbols() != 0) { - if (invalidIndex) - *invalidIndex = it.index(); - return false; - } - for (auto e : m.getResults()) { - auto d = e.dyn_cast(); - if (!d || d.getPosition() != nextExpectedDim++) { - if (invalidIndex) - *invalidIndex = it.index(); - return false; - } - } - } - if (nextExpectedDim != nDims) { - if (invalidIndex) - *invalidIndex = reassociation.size() - 1; - return false; - } - return true; + ::mlir::printReshapeOp(p, op); } /// Detect whether memref dims [dim, dim + extent) can be reshaped without @@ -1736,106 +1418,12 @@ Value mlir::linalg::CollapseShapeOp::getViewSource() { return src(); } -/// Verify that shapes of the reshaped types using following rules -/// 1) if a dimension in the collapsed type is static, then the corresponding -/// dimensions in the expanded shape should be -/// a) static -/// b) the product should be same as the collaped shape. -/// 2) if a dimension in the collaped type is dynamic, one and only one of the -/// corresponding dimensions in the expanded type should be dynamic. This -/// rule is only needed with reshape operations that are expanding. 
-template -static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType, - ShapedType expandedType, - bool isExpandingReshape) { - ArrayRef collapsedShape = collapsedType.getShape(); - ArrayRef expandedShape = expandedType.getShape(); - unsigned expandedDimStart = 0; - for (auto map : llvm::enumerate(op.getReassociationMaps())) { - Optional dynamicShape; - int64_t linearizedStaticShape = 1; - for (auto dim : llvm::enumerate(expandedShape.slice( - expandedDimStart, map.value().getNumResults()))) { - if (ShapedType::isDynamic(dim.value())) { - if (isExpandingReshape && dynamicShape) { - return op->emitOpError("invalid to have a single dimension (") - << map.index() << ") expanded into multiple dynamic dims (" - << expandedDimStart + dynamicShape.getValue() << "," - << expandedDimStart + dim.index() << ")"; - } - dynamicShape = dim.index(); - } else { - linearizedStaticShape *= dim.value(); - } - } - if (dynamicShape) { - if (!ShapedType::isDynamic(collapsedShape[map.index()])) { - return op->emitOpError("expected dimension ") - << map.index() - << " of collapsed type to be dynamic since one or more of the " - "corresponding dimensions in the expanded type is dynamic"; - } - } else { - if (collapsedShape[map.index()] != linearizedStaticShape) { - return op->emitOpError("expected dimension ") - << map.index() << " of collapsed type to be static value of " - << linearizedStaticShape << " "; - } - } - expandedDimStart += map.value().getNumResults(); - } - return success(); -} - -// Common verifier for reshape-like types. Fills `expandedType` and -// `collapsedType` with the proper `src` or `result` type. 
-template ::value || - std::is_same::value> -static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, - T collapsedType) { - unsigned expandedRank = expandedType.getRank(); - unsigned collapsedRank = collapsedType.getRank(); - if (expandedRank < collapsedRank) - return op.emitOpError("expected the type ") - << expandedType - << " to have higher rank than the type = " << collapsedType; - if (expandedRank == 0) - return op.emitOpError("expected non-zero memref ranks"); - if (expandedRank == collapsedRank) - return op.emitOpError("expected to collapse or expand dims"); - - if (collapsedRank == 0) { - // If collapsed rank is 0, then expanded type must be static shaped and of - // sizes 1. - if (llvm::any_of(expandedType.getShape(), - [](int64_t dim) -> bool { return dim != 1; })) - return op.emitOpError("invalid to reshape tensor/memref with non-unit " - "extent dimensions to zero-rank tensor/memref"); - return success(); - } - if (collapsedRank != op.reassociation().size()) - return op.emitOpError("expected rank of the collapsed type(") - << collapsedRank << ") to be the number of reassociation maps(" - << op.reassociation().size() << ")"; - auto maps = op.getReassociationMaps(); - for (auto it : llvm::enumerate(maps)) - if (it.value().getNumDims() != expandedRank) - return op.emitOpError("expected reassociation map #") - << it.index() << " of same rank as expanded memref(" - << expandedRank << "), but got " << it.value().getNumDims(); - int invalidIdx = 0; - if (!isReassociationValid(maps, &invalidIdx)) - return op.emitOpError("expected reassociation map #") - << invalidIdx << " to be valid and contiguous"; - return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion); -} - -template -static LogicalResult verifyReshapeOp(TensorReshapeOp op, - MemRefType expandedType, +template ::value> +static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType, MemRefType collapsedType) { - if (failed(verifyReshapeLikeTypes(op, 
expandedType, collapsedType))) + if (failed( + verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) return failure(); auto maps = op.getReassociationMaps(); MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps); @@ -1923,11 +1511,14 @@ getReassociationIndicesAttribute(b, reassociation)); } -template +template ::value> static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType) { - if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType))) + if (failed( + verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion))) return failure(); auto maps = op.getReassociationMaps(); diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt --- a/mlir/lib/Dialect/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/Utils/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(MLIRDialectUtils + ReshapeOpsUtils.cpp StructuredOpsUtils.cpp StaticValueUtils.cpp diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -0,0 +1,225 @@
+//===- ReshapeOpsUtils.cpp - Utilities used by reshape ops ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+
+using namespace mlir;
+
+// TODO: `constexpr` functions are implicitly inline and must be defined in
+// every translation unit that uses them, so this definition belongs in
+// ReshapeOpsUtils.h rather than here.
+constexpr StringRef mlir::getReassociationAttrName() { return "reassociation"; }
+
+Optional<SmallVector<ReassociationIndices>>
+mlir::getReassociationIndicesForReshape(ShapedType sourceType,
+                                        ShapedType targetType) {
+  // Make the sourceType greater rank than the targetType. If they are same
+  // rank, then its an unsupported reshape op.
+  if (sourceType.getRank() == targetType.getRank())
+    return llvm::None;
+  if (sourceType.getRank() < targetType.getRank())
+    std::swap(sourceType, targetType);
+
+  ArrayRef<int64_t> sourceShape = sourceType.getShape();
+  ArrayRef<int64_t> targetShape = targetType.getShape();
+  unsigned sourceDim = 0;
+  SmallVector<ReassociationIndices> reassociationMap;
+  reassociationMap.reserve(targetType.getRank());
+
+  ReassociationIndices currIndices;
+  int64_t prodOfCollapsedDims = 1;
+  // Greedily group source dimensions until every target dimension has been
+  // matched. Source dimensions left over after that are handled below.
+  while (sourceDim < sourceShape.size() &&
+         reassociationMap.size() < targetShape.size()) {
+    int64_t currTargetShape = targetShape[reassociationMap.size()];
+
+    // Accumulate static source dimensions while their product stays below the
+    // target size. The bound is checked first so that `sourceShape` is never
+    // indexed out of range.
+    while (sourceDim < sourceShape.size() &&
+           sourceShape[sourceDim] != ShapedType::kDynamicSize &&
+           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
+      prodOfCollapsedDims *= sourceShape[sourceDim];
+      currIndices.push_back(sourceDim++);
+    }
+
+    // The source dimensions ran out before the target dimension was matched.
+    if (sourceDim == sourceShape.size())
+      return llvm::None;
+
+    // If the current expanded dimension is dynamic, then the collapsed
+    // dimension should also be dynamic and the product of all previous
+    // unprocessed dimensions of the expanded shape should be 1.
+    if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
+        (currTargetShape != ShapedType::kDynamicSize ||
+         prodOfCollapsedDims != 1))
+      return llvm::None;
+
+    // If the collapsed dim is dynamic, the current expanded dim should also
+    // be dynamic.
+    if (currTargetShape == ShapedType::kDynamicSize &&
+        sourceShape[sourceDim] != ShapedType::kDynamicSize)
+      return llvm::None;
+
+    // For static shapes, the product of the grouped dimensions of the expanded
+    // shape must match the collapsed dimension size. When both dimensions are
+    // dynamic this holds trivially because `prodOfCollapsedDims` is 1 here.
+    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
+      return llvm::None;
+
+    currIndices.push_back(sourceDim++);
+    reassociationMap.emplace_back(ReassociationIndices{});
+    std::swap(reassociationMap.back(), currIndices);
+    prodOfCollapsedDims = 1;
+  }
+
+  // All the dimensions of the target shape must have been matched.
+  if (reassociationMap.size() != targetShape.size())
+    return llvm::None;
+
+  // Any remaining source dimensions must be unit dimensions. They are folded
+  // into the last group; when the target is a scalar there is no group and the
+  // reassociation stays empty.
+  for (; sourceDim < sourceShape.size(); ++sourceDim) {
+    if (sourceShape[sourceDim] != 1)
+      return llvm::None;
+    if (!reassociationMap.empty())
+      reassociationMap.back().push_back(sourceDim);
+  }
+  return reassociationMap;
+}
+
+// Parses `operand reassociation attr-dict? : srcType into resultType`.
+ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser,
+                                     OperationState &result) {
+  // Parse the operand.
+  OpAsmParser::OperandType src;
+  if (parser.parseOperand(src))
+    return failure();
+
+  // Parse reassociation indices: a bracketed, comma-separated list of
+  // non-empty bracketed integer lists, e.g. `[[0, 1], [2]]` or `[]`.
+  Builder &b = parser.getBuilder();
+  SmallVector<Attribute, 4> reassociation;
+  if (parser.parseLSquare())
+    return failure();
+  if (failed(parser.parseOptionalRSquare())) {
+    do {
+      if (parser.parseLSquare())
+        return failure();
+      SmallVector<int64_t, 4> indices;
+      do {
+        int64_t index;
+        if (parser.parseInteger(index))
+          return failure();
+        indices.push_back(index);
+      } while (succeeded(parser.parseOptionalComma()));
+      if (parser.parseRSquare())
+        return failure();
+      reassociation.push_back(b.getI64ArrayAttr(indices));
+    } while (succeeded(parser.parseOptionalComma()));
+    if (parser.parseRSquare())
+      return failure();
+  }
+
+  result.addAttribute(getReassociationAttrName(),
+                      b.getArrayAttr(reassociation));
+
+  // Parse optional attributes. Propagate a malformed attribute dictionary as
+  // a parse failure instead of ignoring it.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  // Parse types.
+  Type srcType;
+  Type resultType;
+  if (parser.parseColon() || parser.parseType(srcType) ||
+      parser.resolveOperand(src, srcType, result.operands) ||
+      parser.parseKeyword("into") || parser.parseType(resultType))
+    return failure();
+  result.addTypes(resultType);
+  return success();
+}
+
+Optional<SmallVector<ReassociationIndices>>
+mlir::collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
+                                   ArrayRef<AffineMap> mapsConsumer,
+                                   MLIRContext *context) {
+  // Make the producer the larger sized vector. If they are of same size, the
+  // resulting reshape is not a supported reshape op.
+  if (mapsProducer.size() == mapsConsumer.size())
+    return llvm::None;
+  if (mapsProducer.size() < mapsConsumer.size())
+    std::swap(mapsProducer, mapsConsumer);
+
+  // Handle the corner case of the result being a rank 0 shaped type. Return an
+  // empty reassociation.
+  if (mapsConsumer.empty())
+    return SmallVector<ReassociationIndices>{};
+
+  unsigned currDim = 0;
+  SmallVector<ReassociationIndices> reassociationMaps;
+  reassociationMaps.reserve(mapsConsumer.size());
+  for (AffineMap rhs : mapsConsumer) {
+    // Each consumer map must have one dimension per producer map so that its
+    // dimension positions index into `mapsProducer`.
+    if (rhs.getNumDims() != mapsProducer.size())
+      return llvm::None;
+    ReassociationIndices reassociations;
+    for (AffineExpr rhsExpr : rhs.getResults()) {
+      // Only plain dimension expressions describe a reassociation; reject
+      // anything else instead of asserting.
+      auto dimExpr = rhsExpr.dyn_cast<AffineDimExpr>();
+      if (!dimExpr)
+        return llvm::None;
+      unsigned numProducerResults =
+          mapsProducer[dimExpr.getPosition()].getNumResults();
+      for (unsigned i = 0; i < numProducerResults; ++i)
+        reassociations.push_back(currDim++);
+    }
+    reassociationMaps.push_back(std::move(reassociations));
+  }
+  return reassociationMaps;
+}
+
+// A reassociation is valid when every map has the same number of dimensions
+// and no symbols, and the map results, read in order, are exactly d0, d1, ...
+// through the last dimension. An empty reassociation is valid.
+bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
+                                int *invalidIndex) {
+  if (reassociation.empty())
+    return true;
+  unsigned nDims = reassociation[0].getNumDims();
+  unsigned nextExpectedDim = 0;
+  for (auto it : llvm::enumerate(reassociation)) {
+    auto m = it.value();
+    if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
+      if (invalidIndex)
+        *invalidIndex = it.index();
+      return false;
+    }
+    for (auto e : m.getResults()) {
+      auto d = e.dyn_cast<AffineDimExpr>();
+      if (!d || d.getPosition() != nextExpectedDim++) {
+        if (invalidIndex)
+          *invalidIndex = it.index();
+        return false;
+      }
+    }
+  }
+  if (nextExpectedDim != nDims) {
+    if (invalidIndex)
+      *invalidIndex = reassociation.size() - 1;
+    return false;
+  }
+  return true;
+}