diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -52,7 +52,7 @@ explicit operator bool() { return matchedOperation != nullptr; } - Operation *getMatchedOperation() { return matchedOperation; } + Operation *getMatchedOperation() const { return matchedOperation; } ArrayRef getMatchedChildren() { return matchedChildren; } private: diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -538,25 +538,25 @@ /// Up to 3-D patterns are supported. /// If the command line argument requests a pattern of higher order, returns an /// empty pattern list which will conservatively result in no vectorization. -static std::vector -makePatterns(const DenseSet &parallelLoops, int vectorRank, - ArrayRef fastestVaryingPattern) { +static Optional +makePattern(const DenseSet &parallelLoops, int vectorRank, + ArrayRef fastestVaryingPattern) { using matcher::For; int64_t d0 = fastestVaryingPattern.empty() ? -1 : fastestVaryingPattern[0]; int64_t d1 = fastestVaryingPattern.size() < 2 ? -1 : fastestVaryingPattern[1]; int64_t d2 = fastestVaryingPattern.size() < 3 ? 
-1 : fastestVaryingPattern[2]; switch (vectorRank) { case 1: - return {For(isVectorizableLoopPtrFactory(parallelLoops, d0))}; + return For(isVectorizableLoopPtrFactory(parallelLoops, d0)); case 2: - return {For(isVectorizableLoopPtrFactory(parallelLoops, d0), - For(isVectorizableLoopPtrFactory(parallelLoops, d1)))}; + return For(isVectorizableLoopPtrFactory(parallelLoops, d0), + For(isVectorizableLoopPtrFactory(parallelLoops, d1))); case 3: - return {For(isVectorizableLoopPtrFactory(parallelLoops, d0), - For(isVectorizableLoopPtrFactory(parallelLoops, d1), - For(isVectorizableLoopPtrFactory(parallelLoops, d2))))}; + return For(isVectorizableLoopPtrFactory(parallelLoops, d0), + For(isVectorizableLoopPtrFactory(parallelLoops, d1), + For(isVectorizableLoopPtrFactory(parallelLoops, d2)))); default: { - return std::vector(); + return llvm::None; } } } @@ -1259,6 +1259,48 @@ return vectorizeLoopNest(loopsToVectorize, strategy); } +/// Traverses all the loop matches and classifies them into intersection +/// buckets. Two matches intersect if any of them encloses the other one. A +/// match intersects with a bucket if the match intersects with the root +/// (outermost) loop in that bucket. +static void computeIntersectionBuckets( + ArrayRef matches, + std::vector> &intersectionBuckets) { + assert(intersectionBuckets.empty() && "Expected empty output"); + // Keeps track of the root (outermost) loop of each bucket. + SmallVector bucketRoots; + + for (const NestedMatch &match : matches) { + AffineForOp matchRoot = cast(match.getMatchedOperation()); + bool intersects = false; + for (int i = 0, end = intersectionBuckets.size(); i < end; ++i) { + AffineForOp bucketRoot = bucketRoots[i]; + // Add match to the bucket if the bucket root encloses the match root. + if (bucketRoot->isAncestor(matchRoot)) { + intersectionBuckets[i].push_back(match); + intersects = true; + break; + } + // Add match to the bucket if the match root encloses the bucket root. 
The + // match root becomes the new bucket root. + if (matchRoot->isAncestor(bucketRoot)) { + bucketRoots[i] = matchRoot; + intersectionBuckets[i].push_back(match); + intersects = true; + break; + } + } + + // Match doesn't intersect with any existing bucket. Create a new bucket for + // it. + if (!intersects) { + bucketRoots.push_back(matchRoot); + intersectionBuckets.push_back(SmallVector()); + intersectionBuckets.back().push_back(match); + } + } +} + /// Internal implementation to vectorize affine loops in 'loops' using the n-D /// vectorization factors in 'vectorSizes'. By default, each vectorization /// factor is applied inner-to-outer to the loops of each loop nest. @@ -1267,35 +1309,51 @@ static void vectorizeLoops(Operation *parentOp, DenseSet &loops, ArrayRef vectorSizes, ArrayRef fastestVaryingPattern) { - for (auto &pat : - makePatterns(loops, vectorSizes.size(), fastestVaryingPattern)) { - LLVM_DEBUG(dbgs() << "\n******************************************"); - LLVM_DEBUG(dbgs() << "\n******************************************"); - LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on parent op\n"); - LLVM_DEBUG(parentOp->print(dbgs())); - - unsigned patternDepth = pat.getDepth(); - - SmallVector matches; - pat.match(parentOp, &matches); - // Iterate over all the top-level matches and vectorize eagerly. - // This automatically prunes intersecting matches. - for (auto m : matches) { + // Compute 1-D, 2-D or 3-D loop pattern to be matched on the target loops. 
+ Optional pattern = + makePattern(loops, vectorSizes.size(), fastestVaryingPattern); + if (!pattern.hasValue()) { + LLVM_DEBUG(dbgs() << "\n[early-vect] pattern couldn't be computed\n"); + return; + } + + LLVM_DEBUG(dbgs() << "\n******************************************"); + LLVM_DEBUG(dbgs() << "\n******************************************"); + LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on parent op\n"); + LLVM_DEBUG(dbgs() << *parentOp << "\n"); + + unsigned patternDepth = pattern->getDepth(); + + // Compute all the pattern matches and classify them into buckets of + // intersecting matches. + SmallVector allMatches; + pattern->match(parentOp, &allMatches); + std::vector> intersectionBuckets; + computeIntersectionBuckets(allMatches, intersectionBuckets); + + // Iterate over all buckets and vectorize the matches eagerly. We can only + // vectorize one match from each bucket since all the matches within a bucket + // intersect. + for (auto &intersectingMatches : intersectionBuckets) { + for (NestedMatch &match : intersectingMatches) { VectorizationStrategy strategy; // TODO: depending on profitability, elect to reduce the vector size. strategy.vectorSizes.assign(vectorSizes.begin(), vectorSizes.end()); - if (failed(analyzeProfitability(m.getMatchedChildren(), 1, patternDepth, - &strategy))) { + if (failed(analyzeProfitability(match.getMatchedChildren(), 1, + patternDepth, &strategy))) { continue; } - vectorizeLoopIfProfitable(m.getMatchedOperation(), 0, patternDepth, + vectorizeLoopIfProfitable(match.getMatchedOperation(), 0, patternDepth, &strategy); - // TODO: if pattern does not apply, report it; alter the - // cost/benefit. - (void)vectorizeRootMatch(m, strategy); + // Vectorize match. Skip the rest of intersecting matches in the bucket if + // vectorization succeeded. + // TODO: if pattern does not apply, report it; alter the cost/benefit. // TODO: some diagnostics if failure to vectorize occurs. 
+ if (succeeded(vectorizeRootMatch(match, strategy))) + break; } } + LLVM_DEBUG(dbgs() << "\n"); }