diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -15,6 +15,8 @@ #define MLIR_TRANSFORMS_PASSES_H #include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" + #include #include @@ -59,7 +61,13 @@ /// Creates a pass that transforms a single ParallelLoop over N induction /// variables into another ParallelLoop over less than N induction variables. -std::unique_ptr createParallelLoopCollapsingPass(); +/// Each non-empty `collapsedIndicesK` list is forwarded to +/// collapseParallelLoops as one group of induction-variable indices to merge +/// into a single induction variable; empty lists are skipped. +std::unique_ptr +createParallelLoopCollapsingPass(ArrayRef collapsedIndices0 = {}, + ArrayRef collapsedIndices1 = {}, + ArrayRef collapsedIndices2 = {}); /// Creates a pass to perform optimizations relying on memref dataflow such as /// store to load forwarding, elimination of dead stores, and dead allocs. diff --git a/mlir/lib/Transforms/ParallelLoopCollapsing.cpp b/mlir/lib/Transforms/ParallelLoopCollapsing.cpp --- a/mlir/lib/Transforms/ParallelLoopCollapsing.cpp +++ b/mlir/lib/Transforms/ParallelLoopCollapsing.cpp @@ -19,8 +19,20 @@ using namespace mlir; namespace { + struct ParallelLoopCollapsing : public ParallelLoopCollapsingBase { + ParallelLoopCollapsing() = default; + ParallelLoopCollapsing(const ParallelLoopCollapsing &) = default; + /// Seeds the clCollapsedIndices* options from explicitly supplied lists. + ParallelLoopCollapsing(ArrayRef collapsedIndices0, + ArrayRef collapsedIndices1, + ArrayRef collapsedIndices2) { + clCollapsedIndices0 = collapsedIndices0; + clCollapsedIndices1 = collapsedIndices1; + clCollapsedIndices2 = collapsedIndices2; + } + void runOnOperation() override { Operation *module = getOperation(); @@ -28,18 +40,23 @@ // The common case for GPU dialect will be simplifying the ParallelOp to 3 // arguments, so we do that here to simplify things.
llvm::SmallVector, 3> combinedLoops; - if (clCollapsedIndices0.size()) + if (!clCollapsedIndices0.empty()) combinedLoops.push_back(clCollapsedIndices0); - if (clCollapsedIndices1.size()) + if (!clCollapsedIndices1.empty()) combinedLoops.push_back(clCollapsedIndices1); - if (clCollapsedIndices2.size()) + if (!clCollapsedIndices2.empty()) combinedLoops.push_back(clCollapsedIndices2); collapseParallelLoops(op, combinedLoops); }); } }; + } // namespace -std::unique_ptr mlir::createParallelLoopCollapsingPass() { - return std::make_unique(); +std::unique_ptr +mlir::createParallelLoopCollapsingPass(ArrayRef collapsedIndices0, + ArrayRef collapsedIndices1, + ArrayRef collapsedIndices2) { + return std::make_unique( + collapsedIndices0, collapsedIndices1, collapsedIndices2); }