diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -33,6 +33,8 @@
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 
+#include <numeric>
+
 using namespace mlir;
 using namespace mlir::linalg;
 
@@ -1225,47 +1227,48 @@
 /// expanding dimensions.
 ///
 /// For example,
-///   mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
-///                   affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
-///                   affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
-///   mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
-///                   affine_map<(d0, d1, d2) -> (d2)>]
+///   producerReassociation = [[0, 1], [2], [3, 4]]
+///   consumerReassociation = [[0, 1], [2]]
 ///
 /// is folded into
 ///
-///   result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
-///             affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
-static Optional<SmallVector<ReassociationIndices>>
-collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
-                             ArrayRef<AffineMap> mapsConsumer,
-                             MLIRContext *context) {
+///   result = [[0, 1, 2], [3, 4]].
+static Optional<SmallVector<ReassociationIndices>> composeReassociationIndices(
+    ArrayRef<ReassociationIndices> producerReassociations,
+    ArrayRef<ReassociationIndices> consumerReassociations,
+    MLIRContext *context) {
+  SmallVector<ReassociationIndices> composedIndices;
   // Make the producer the larger sized vector. If they are of same size, the
   // resulting reshape is not a supported reshape op.
-  if (mapsProducer.size() == mapsConsumer.size())
+  if (producerReassociations.size() == consumerReassociations.size())
     return llvm::None;
-  if (mapsProducer.size() < mapsConsumer.size())
-    std::swap(mapsProducer, mapsConsumer);
+  if (producerReassociations.size() < consumerReassociations.size())
+    std::swap(producerReassociations, consumerReassociations);
 
   // Handle the corner case of the result being a rank 0 shaped type. Return an
   // empty reassociation.
-  if (mapsConsumer.empty())
-    return SmallVector<ReassociationIndices>{};
-  if (mapsProducer.size() != mapsConsumer[0].getNumDims())
+  if (consumerReassociations.empty())
+    return composedIndices;
+
+  // The consumer must index exactly as many dims as the producer has groups.
+  // Seed with size_t{0}: a plain 0 would make std::accumulate sum in int.
+  size_t consumerDims = std::accumulate(
+      consumerReassociations.begin(), consumerReassociations.end(), size_t{0},
+      [](size_t all, ReassociationIndicesRef indices) {
+        return all + indices.size();
+      });
+  if (producerReassociations.size() != consumerDims)
     return llvm::None;
 
-  unsigned currDim = 0;
-  SmallVector<ReassociationIndices> reassociationMaps;
-  for (AffineMap rhs : mapsConsumer) {
+  for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
     ReassociationIndices reassociations;
-    for (AffineExpr rhsExpr : rhs.getResults()) {
-      AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
-      for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
-           i < e; ++i)
-        reassociations.push_back(currDim++);
+    for (int64_t consumerIndex : consumerIndices) {
+      for (int64_t producerIndex : producerReassociations[consumerIndex])
+        reassociations.push_back(producerIndex);
     }
-    reassociationMaps.push_back(std::move(reassociations));
+    composedIndices.push_back(std::move(reassociations));
   }
-  return reassociationMaps;
+  return composedIndices;
 }
 
 namespace {
@@ -1282,9 +1285,9 @@
 
     ShapedType resultType = reshapeOp.getResultType();
     Optional<SmallVector<ReassociationIndices>> reassociationIndices =
-        collapseReassociationIndices(srcReshapeOp.getReassociationMaps(),
-                                     reshapeOp.getReassociationMaps(),
-                                     rewriter.getContext());
+        composeReassociationIndices(srcReshapeOp.getReassociationIndices(),
+                                    reshapeOp.getReassociationIndices(),
+                                    rewriter.getContext());
     if (!reassociationIndices)
       return failure();
     rewriter.replaceOpWithNewOp<ReshapeOpTy>(