diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h @@ -49,6 +49,12 @@ void startEmit(OpOperand *so, unsigned lv, SparseTensorLoopEmitter *le); + /// Generates loop boundary statements (entering/exiting loops). The callback + /// may update the passed-in parameters; updates are stored back in the env. + Optional genLoopBoundary( + function_ref(MutableArrayRef parameters)> + callback); + // // Merger delegates. // diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp @@ -37,6 +37,27 @@ } } +Optional CodegenEnv::genLoopBoundary( + function_ref(MutableArrayRef parameters)> + callback) { + SmallVector params; // Order: reduction, expansion count, insertion chain. + if (isReduc()) + params.push_back(redVal); + if (isExpand()) + params.push_back(expCount); + if (insChain != nullptr) + params.push_back(insChain); + auto r = callback(params); // callback may update params in place + unsigned i = 0; // Unpack in push order; guards must evaluate as above. + if (isReduc()) + updateReduc(params[i++]); + if (isExpand()) + updateExpandCount(params[i++]); + if (insChain != nullptr) + updateInsertionChain(params[i]); // last slot, no increment needed + return r; +} + //===----------------------------------------------------------------------===// // Code generation environment topological sort methods //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -604,34 +604,6 @@ // Sparse compiler synthesis methods (statements and expressions).
//===----------------------------------------------------------------------===// -/// Generates loop boundary statements (entering/exiting loops). The function -/// passes and updates the reduction value. -static Optional genLoopBoundary( - CodegenEnv &env, - function_ref(MutableArrayRef reduc)> - callback) { - SmallVector reduc; - if (env.isReduc()) - reduc.push_back(env.getReduc()); - if (env.isExpand()) - reduc.push_back(env.getExpandCount()); - if (env.getInsertionChain()) - reduc.push_back(env.getInsertionChain()); - - auto r = callback(reduc); - - // Callback should do in-place update on reduction value vector. - unsigned i = 0; - if (env.isReduc()) - env.updateReduc(reduc[i++]); - if (env.isExpand()) - env.updateExpandCount(reduc[i++]); - if (env.getInsertionChain()) - env.updateInsertionChain(reduc[i]); - - return r; -} - /// Local bufferization of all dense and sparse data structures. static void genBuffers(CodegenEnv &env, OpBuilder &builder) { linalg::GenericOp op = env.op(); @@ -1066,7 +1038,7 @@ isCompressedDLT(env.dlt(tid, idx)) || isSingletonDLT(env.dlt(tid, idx)); bool isParallel = isParallelFor(env, isOuter, isSparse); - Operation *loop = *genLoopBoundary(env, [&](MutableArrayRef reduc) { + Operation *loop = *env.genLoopBoundary([&](MutableArrayRef reduc) { if (env.merger().isFilterLoop(idx)) { // extraTids/extraDims must be empty because filter loops only // corresponding to the one and only sparse tensor level. @@ -1092,7 +1064,7 @@ ArrayRef condDims, ArrayRef extraTids, ArrayRef extraDims) { - Operation *loop = *genLoopBoundary(env, [&](MutableArrayRef reduc) { + Operation *loop = *env.genLoopBoundary([&](MutableArrayRef reduc) { // Construct the while-loop with a parameter for each // index.
return env.emitter()->enterCoIterationOverTensorsAtDims( @@ -1425,7 +1397,7 @@ needsUniv = false; } - genLoopBoundary(env, [&](MutableArrayRef reduc) { + env.genLoopBoundary([&](MutableArrayRef reduc) { env.emitter()->exitCurrentLoop(rewriter, env.op().getLoc(), reduc); return std::nullopt; });