diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -28,27 +28,26 @@
 /// Attribute name for the ArrayAttr which encodes reassociation indices.
 constexpr StringRef getReassociationAttrName();
 
-/// Collapse reassociation maps that are used in pair of reshape ops where one
+/// Compose reassociation indices that are used in pair of reshape ops where one
 /// is a producer and other is the consumer. Only valid to use this method when
 /// both the producer and consumer are collapsing dimensions or both are
 /// expanding dimensions.
 ///
 /// For example,
-///   mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
-///                   affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
-///                   affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
-///   mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
-///                   affine_map<(d0, d1, d2) -> (d2)>]
+///   producerReassociation = [[0, 1], [2], [3, 4]]
+///   consumerReassociation = [[0, 1], [2]]
 ///
 /// is folded into
 ///
-///   result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
-///             affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
-/// TODO: Use reassociation indices instead of affine maps here.
-Optional<SmallVector<ReassociationIndices>>
-collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
-                             ArrayRef<AffineMap> mapsConsumer,
-                             MLIRContext *context);
+///   result = [[0, 1, 2], [3, 4]].
+///
+/// Returns llvm::None when both reassociations have the same number of groups,
+/// or when the number of groups in the larger one differs from the total number
+/// of indices in the smaller one.
+Optional<SmallVector<ReassociationIndices>> composeReassociationIndices(
+    ArrayRef<ReassociationIndices> producerReassociations,
+    ArrayRef<ReassociationIndices> consumerReassociations,
+    MLIRContext *context);
 
 /// Return the reassociations maps to use to reshape given the source type and
 /// the target type when possible. Return llvm::None when this computation
@@ -210,8 +209,8 @@
 
     ShapedType resultType = reshapeOp.getResultType();
     Optional<SmallVector<ReassociationIndices>> reassociationIndices =
-        collapseReassociationIndices(srcReshapeOp.getReassociationMaps(),
-                                     reshapeOp.getReassociationMaps(),
-                                     rewriter.getContext());
+        composeReassociationIndices(srcReshapeOp.getReassociationIndices(),
+                                    reshapeOp.getReassociationIndices(),
+                                    rewriter.getContext());
     if (!reassociationIndices)
       return failure();
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -11,6 +11,8 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 
+#include <numeric>
+
 using namespace mlir;
 
 constexpr StringRef mlir::getReassociationAttrName() { return "reassociation"; }
@@ -145,37 +147,45 @@
   return success();
 }
 
-Optional<SmallVector<ReassociationIndices>>
-mlir::collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
-                                   ArrayRef<AffineMap> mapsConsumer,
-                                   MLIRContext *context) {
+Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
+    ArrayRef<ReassociationIndices> producerReassociations,
+    ArrayRef<ReassociationIndices> consumerReassociations,
+    MLIRContext *context) {
+  SmallVector<ReassociationIndices> composedIndices;
   // Make the producer the larger sized vector. If they are of same size, the
   // resulting reshape is not a supported reshape op.
-  if (mapsProducer.size() == mapsConsumer.size())
+  if (producerReassociations.size() == consumerReassociations.size())
     return llvm::None;
-  if (mapsProducer.size() < mapsConsumer.size())
-    std::swap(mapsProducer, mapsConsumer);
+  if (producerReassociations.size() < consumerReassociations.size())
+    std::swap(producerReassociations, consumerReassociations);
 
   // Handle the corner case of the result being a rank 0 shaped type. Return an
   // empty reassociation.
-  if (mapsConsumer.empty())
-    return SmallVector<ReassociationIndices>{};
-  if (mapsProducer.size() != mapsConsumer[0].getNumDims())
+  if (consumerReassociations.empty())
+    return composedIndices;
+
+  // Every consumer index names one producer group, so the consumer must
+  // reference exactly as many indices as the producer has groups. Seed the
+  // sum with size_t so it is not accumulated (and truncated) as an int.
+  size_t consumerDims = std::accumulate(
+      consumerReassociations.begin(), consumerReassociations.end(), size_t{0},
+      [](size_t all, ReassociationIndicesRef indices) {
+        return all + indices.size();
+      });
+  if (producerReassociations.size() != consumerDims)
     return llvm::None;
 
-  unsigned currDim = 0;
-  SmallVector<ReassociationIndices> reassociationMaps;
-  for (AffineMap rhs : mapsConsumer) {
+  // Each composed group is the concatenation of the producer groups selected
+  // by one consumer group.
+  for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
     ReassociationIndices reassociations;
-    for (AffineExpr rhsExpr : rhs.getResults()) {
-      AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
-      for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
-           i < e; ++i)
-        reassociations.push_back(currDim++);
+    for (int64_t consumerIndex : consumerIndices) {
+      for (int64_t producerIndex : producerReassociations[consumerIndex])
+        reassociations.push_back(producerIndex);
     }
-    reassociationMaps.push_back(std::move(reassociations));
+    composedIndices.push_back(std::move(reassociations));
   }
-  return reassociationMaps;
+  return composedIndices;
 }
 
 bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,