diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -186,7 +186,7 @@ static void createCompareFuncImplementation( OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx, uint64_t ny, bool isCoo, - function_ref + function_ref compareBuilder) { OpBuilder::InsertionGuard insertionGuard(builder); @@ -195,13 +195,17 @@ Location loc = func.getLoc(); ValueRange args = entryBlock->getArguments(); - scf::IfOp topIfOp; + Value result; auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) { - scf::IfOp ifOp = compareBuilder(builder, loc, i, j, buffer, (k == nx - 1)); - if (k == 0) { - topIfOp = ifOp; - } else { + bool isFirstDim = (k == 0); + bool isLastDim = (k == nx - 1); + Value val = + compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim); + if (isFirstDim) { + result = val; + } else if (!isLastDim) { OpBuilder::InsertionGuard insertionGuard(builder); + auto ifOp = cast(val.getDefiningOp()); builder.setInsertionPointAfter(ifOp); builder.create(loc, ifOp.getResult(0)); } @@ -209,35 +213,41 @@ forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder); - builder.setInsertionPointAfter(topIfOp); - builder.create(loc, topIfOp.getResult(0)); + builder.setInsertionPointAfterValue(result); + builder.create(loc, result); } -/// Generates an if-statement to compare whether x[i] is equal to x[j]. -static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i, - Value j, Value x, bool isLastDim) { - Value f = constantI1(builder, loc, false); - Value t = constantI1(builder, loc, true); +/// Generates code to compare whether x[i] is equal to x[j] and returns the +/// result of the comparison. 
+static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j, + Value x, bool isFirstDim, bool isLastDim) { Value vi = builder.create(loc, x, i); Value vj = builder.create(loc, x, j); - Value cond = - builder.create(loc, arith::CmpIPredicate::eq, vi, vj); - scf::IfOp ifOp = - builder.create(loc, f.getType(), cond, /*else=*/true); - - // x[1] != x[j]: - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - builder.create(loc, f); + Value res; + if (isLastDim) { + res = builder.create(loc, arith::CmpIPredicate::eq, vi, vj); + // For 1D, we create a compare without any control flow. Otherwise, we + // create YieldOp to return the result in the nested if-stmt. + if (!isFirstDim) + builder.create(loc, res); + } else { + Value ne = + builder.create(loc, arith::CmpIPredicate::ne, vi, vj); + scf::IfOp ifOp = builder.create(loc, builder.getIntegerType(1), + ne, /*else=*/true); + // If (x[i] != x[j]). + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + Value f = constantI1(builder, loc, false); + builder.create(loc, f); - // x[i] == x[j]: - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - if (isLastDim == 1) { - // Finish checking all dimensions. - builder.create(loc, t); + // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that + // checks the remaining dimensions. + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + res = ifOp.getResult(0); } - return ifOp; + return res; } /// Creates a function to compare whether xs[i] is equal to xs[j]. @@ -260,58 +270,49 @@ createEqCompare); } -/// Generates an if-statement to compare whether x[i] is less than x[j]. -static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc, - Value i, Value j, Value x, - bool isLastDim) { - Value f = constantI1(builder, loc, false); - Value t = constantI1(builder, loc, true); +/// Generates code to compare whether x[i] is less than x[j] and returns the +/// result of the comparison. 
+static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i, + Value j, Value x, bool isFirstDim, + bool isLastDim) { Value vi = builder.create(loc, x, i); Value vj = builder.create(loc, x, j); - Value cond = - builder.create(loc, arith::CmpIPredicate::ult, vi, vj); - scf::IfOp ifOp = - builder.create(loc, f.getType(), cond, /*else=*/true); - // If (x[i] < x[j]). - builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - builder.create(loc, t); - - builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); - if (isLastDim == 1) { - // Finish checking all dimensions. - builder.create(loc, f); + Value res; + if (isLastDim) { + res = builder.create(loc, arith::CmpIPredicate::ult, vi, vj); + // For 1D, we create a compare without any control flow. Otherwise, we + // create YieldOp to return the result in the nested if-stmt. + if (!isFirstDim) + builder.create(loc, res); } else { - cond = - builder.create(loc, arith::CmpIPredicate::ult, vj, vi); - scf::IfOp ifOp2 = - builder.create(loc, f.getType(), cond, /*else=*/true); - // Otherwise if (x[j] < x[i]). - builder.setInsertionPointToStart(&ifOp2.getThenRegion().front()); - builder.create(loc, f); - - // Otherwise check the remaining dimensions. - builder.setInsertionPointAfter(ifOp2); - builder.create(loc, ifOp2.getResult(0)); - // Set up the insertion point for the nested if-stmt that checks the - // remaining dimensions. - builder.setInsertionPointToStart(&ifOp2.getElseRegion().front()); + Value ne = + builder.create(loc, arith::CmpIPredicate::ne, vi, vj); + scf::IfOp ifOp = builder.create(loc, builder.getIntegerType(1), + ne, /*else=*/true); + // If (x[i] != x[j]). + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + Value lt = + builder.create(loc, arith::CmpIPredicate::ult, vi, vj); + builder.create(loc, lt); + + // If (x[i] == x[j]). Set up the insertion point for the nested if-stmt that + // checks the remaining dimensions. 
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + res = ifOp.getResult(0); } - return ifOp; + return res; } /// Creates a function to compare whether xs[i] is less than xs[j]. // // The generate IR corresponds to this C like algorithm: -// if (x0[i] < x0[j]) -// return true; -// else if (x0[j] < x0[i]) -// return false; +// if (x0[i] != x0[j]) +// return x0[i] < x0[j]; +// else if (x1[i] != x1[j]) +// return x1[i] < x1[j]; // else -// if (x1[i] < x1[j]) -// return true; -// else if (x1[j] < x1[i])) // and so on ... static void createLessThanFunc(OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx, uint64_t ny,