diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -22,16 +22,23 @@ using namespace mlir::sparse_tensor; //===----------------------------------------------------------------------===// -// File local helper functions. +// File local shorthand macros //===----------------------------------------------------------------------===// #define CMPI(p, l, r) \ - (builder.create(loc, arith::CmpIPredicate::p, l, r) \ + (builder.create(loc, arith::CmpIPredicate::p, (l), (r)) \ .getResult()) -#define ADDI(lhs, rhs) (builder.create(loc, lhs, rhs)) +#define C_IDX(v) (constantIndex(builder, loc, (v))) +#define ADDI(lhs, rhs) (builder.create(loc, (lhs), (rhs))) +#define ANDI(lhs, rhs) (builder.create(loc, (lhs), (rhs))) +#define SUBI(lhs, rhs) (builder.create(loc, (lhs), (rhs))) +#define MULI(lhs, rhs) (builder.create(loc, (lhs), (rhs))) +#define SELECT(c, l, r) (builder.create(loc, (c), (l), (r))) -#define C_IDX(v) (constantIndex(builder, loc, v)) +//===----------------------------------------------------------------------===// +// File local helper functions. +//===----------------------------------------------------------------------===// /// Generates a pointer/index load from the sparse storage scheme. Narrower /// data types need to be zero extended before casting the value into the @@ -73,9 +80,7 @@ static Value toSliceCrd(OpBuilder &builder, Location loc, Value crd, Value offset, Value stride, Value tensor, Level lvl) { // tensorCrd = sliceCrd * stride + offset - crd = builder.create(loc, crd, stride); - crd = builder.create(loc, crd, offset); - return crd; + return ADDI(MULI(crd, stride), offset); } /// Generates code to compute the (absolute) offset of the slice based on the @@ -86,12 +91,11 @@ static Value offsetFromMinCoord(OpBuilder &builder, Location loc, Value minCrd, Value size, Value isNonEmpty) { Value geSize = CMPI(uge, minCrd, size); - Value pred = builder.create(loc, isNonEmpty, geSize); - // offset - Value mp1 = builder.create(loc, minCrd, C_IDX(1)); - Value mms = builder.create(loc, mp1, size); + Value pred = ANDI(isNonEmpty, geSize); + // Computes minCrd - size + 1 + Value mms = SUBI(ADDI(minCrd, C_IDX(1)), size); // This is the absolute offset related to the underly tensor. - return builder.create(loc, pred, mms, C_IDX(0)); + return SELECT(pred, mms, C_IDX(0)); } /// Converts a coordinate relative to the underlying tensor to the coordinate @@ -103,7 +107,7 @@ Value stride, Value tensor, Level lvl) { // sliceCrd = (tensorCrd - offset) / stride - crd = builder.create(loc, crd, offset); + crd = SUBI(crd, offset); Value rem = builder.create(loc, crd, stride); crd = builder.create(loc, crd, stride); return std::make_pair(crd, rem); @@ -144,7 +148,7 @@ // Must meet all condition to be a valid coordinate in slice. auto pred = conds.front(); for (auto cond : ValueRange(conds).drop_front()) - pred = builder.create(loc, pred, cond); + pred = ANDI(pred, cond); return {newCrd, pred}; } @@ -156,11 +160,11 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, TensorId tid, Level lvl, Value crd) { Value pos = lvl == 0 ? C_IDX(0) : posits[tid][lvl - 1]; - Value mul = builder.create(loc, highs[tid][lvl], pos); + Value mul = MULI(highs[tid][lvl], pos); if (isSparseSlices[tid]) crd = toSliceCrd(builder, loc, crd, sliceOffsets[tid][lvl], sliceStrides[tid][lvl], tensors[tid], lvl); - Value add = builder.create(loc, mul, crd); + Value add = ADDI(mul, crd); return add; } @@ -198,7 +202,7 @@ /*afterBuilder=*/ [](OpBuilder &builder, Location loc, ValueRange ivs) { // pos ++ - Value nextPos = builder.create(loc, ivs[0], C_IDX(1)); + Value nextPos = ADDI(ivs[0], C_IDX(1)); builder.create(loc, nextPos); }); // Return the segment high. @@ -218,10 +222,9 @@ const Value pos = posits[tid][dstLvl]; const Value off = genIndexLoad(builder, loc, mem, pos); // Linearized the coordinates within the same collapse reassociation. - crd = builder.create(loc, crd, off); + crd = ADDI(crd, off); if (i != reassocSize - 1) { - crd = builder.create(loc, crd, - this->lvlSizes[tid][reassoc[i + 1]]); + crd = MULI(crd, this->lvlSizes[tid][reassoc[i + 1]]); } } return crd; @@ -455,7 +458,7 @@ Value size = c0; for (unsigned e = depLvls.size() - 1; e >= 1; e--) { auto [dt, dd] = depLvls[e]; - size = builder.create(loc, size, lvlSizes[dt][dd]); + size = ADDI(size, lvlSizes[dt][dd]); sliceSizes[t][lvl][e - 1] = size; } } @@ -504,7 +507,7 @@ // TODO: we could probably use an SSA value for it. Value sPtrBuf = slicePosBuffer[tid][lvl].back(); Value curP = genIndexLoad(builder, loc, sPtrBuf, c1); - Value nexP = builder.create(loc, curP, c2); + Value nexP = ADDI(curP, c2); builder.create(loc, nexP, sPtrBuf, c1); } } @@ -525,15 +528,13 @@ } case AffineExprKind::Add: { auto binOp = a.cast(); - return builder.create( - loc, genAffine(builder, loc, binOp.getLHS()), - genAffine(builder, loc, binOp.getRHS())); + return ADDI(genAffine(builder, loc, binOp.getLHS()), + genAffine(builder, loc, binOp.getRHS())); } case AffineExprKind::Mul: { auto binOp = a.cast(); - return builder.create( - loc, genAffine(builder, loc, binOp.getLHS()), - genAffine(builder, loc, binOp.getRHS())); + return MULI(genAffine(builder, loc, binOp.getLHS()), + genAffine(builder, loc, binOp.getRHS())); } case AffineExprKind::Constant: { int64_t c = a.cast().getValue(); @@ -717,7 +718,7 @@ // We need to substract the offset to get relative coordinates. // TODO: how to assert relC >=0 during runtime? - insertPoint = builder.create(loc, absC, offset); + insertPoint = SUBI(absC, offset); posits[tid][lvl] = iv; coords[tid][lvl] = insertPoint->getResult(0); }) @@ -743,8 +744,8 @@ condTid.push_back(tid); condLvl.push_back(lvl); } else { - hi = builder.create(loc, lvlSizes[tid][lvl], sliceSz); - hi = builder.create(loc, hi, C_IDX(1)); + hi = SUBI(lvlSizes[tid][lvl], sliceSz); + hi = ADDI(hi, C_IDX(1)); } } else { condTid.push_back(tid); @@ -761,9 +762,7 @@ bool fullyReduc = depFullyReduced(t, l); SliceInfo &info = sliceStack[t].back(); if (fullyReduc) { - posits[t][l] = - genAddress(builder, loc, t, l, - builder.create(loc, info.offset, iv)); + posits[t][l] = genAddress(builder, loc, t, l, ADDI(info.offset, iv)); } else { // Puts sliced dense loop into LoopInfo so that LoopEmitter knows how to // exit it. @@ -930,7 +929,7 @@ // We used the first level bound as the bound the collapsed set of levels. Value op2 = highs[tid][reassoc.front()]; Value opc = CMPI(ult, op1, op2); - cond = cond ? builder.create(loc, cond, opc) : opc; + cond = cond ? ANDI(cond, opc) : opc; // Update positions Value pos = after->getArgument(o++); // For COO, the position is the same across consecutive levels. @@ -973,14 +972,13 @@ // // This "idx" is the index into `llvm::zip(tids, lvls)` for (auto [pred, idx] : slicesPreds) { - Value nextPos = builder.create(loc, yields[idx], C_IDX(1)); - yields[idx] = - builder.create(loc, pred, yields[idx], nextPos); + Value nextPos = ADDI(yields[idx], C_IDX(1)); + yields[idx] = SELECT(pred, yields[idx], nextPos); } Value pred = slicesPreds.front().first; for (int i = 1, e = slicesPreds.size(); i < e; i++) { - pred = builder.create(loc, pred, slicesPreds[i].first); + pred = ANDI(pred, slicesPreds[i].first); } auto ifOp = builder.create(loc, types, pred, /*else*/ true); ifOp->setAttr(getLoopEmitterLoopAttrName(), @@ -1004,8 +1002,7 @@ const auto crd = coords[tid][lvl]; if (min) { Value cmp = CMPI(ult, coords[tid][lvl], min); - min = - builder.create(loc, cmp, coords[tid][lvl], min); + min = SELECT(cmp, coords[tid][lvl], min); } else { min = crd; } @@ -1104,7 +1101,7 @@ const Value pLo = srcLvl == 0 ? c0 : posits[tid][srcLvl - 1]; posits[tid][srcLvl] = genIndexLoad(builder, loc, mem, pLo); - const Value pHi = builder.create(loc, pLo, c1); + const Value pHi = ADDI(pLo, c1); highs[tid][srcLvl] = genIndexLoad(builder, loc, mem, pHi); return; } @@ -1122,7 +1119,7 @@ highs[tid][srcLvl] = (!isUniqueDLT(lvlTypes[tid][srcLvl - 1]) && parentSegHi) ? parentSegHi - : builder.create(loc, pLo, c1); + : ADDI(pLo, c1); return; } } @@ -1313,9 +1310,9 @@ // segment high. Value add = !isUniqueDLT(lvlTypes[tid][reassoc.back()]) ? segHi[tid][reassoc.back()] - : builder.create(loc, pos, one); + : ADDI(pos, one); - operands.push_back(builder.create(loc, cmp, add, pos)); + operands.push_back(SELECT(cmp, add, pos)); // Following loops continue iteration from the break point of the // current while loop. const Value newPos = whileOp->getResult(o++); @@ -1345,7 +1342,7 @@ if (operands.size() + delta < whileOp.getNumResults()) { assert(operands.size() + delta + 1 == whileOp.getNumResults()); // The last one is the universial index. - operands.push_back(builder.create(loc, iv, one)); + operands.push_back(ADDI(iv, one)); // update the loop starting point of current loop sequence loopSeqStack.back().first = whileOp->getResult(o++); } @@ -1402,8 +1399,7 @@ TensorId tid, Level lvl, size_t depth, ValueRange userReduc, bool genYield, LoopBodyBuilder bodyBuilder) { Value c1 = C_IDX(1); - Value sliceHi = - builder.create(loc, offset, sliceSizes[tid][lvl].back()); + Value sliceHi = ADDI(offset, sliceSizes[tid][lvl].back()); SmallVector reduc = { loopLo, // loop lower bounds @@ -1418,7 +1414,7 @@ Value lo = args[0]; Value cont = args[1]; Value inBound = CMPI(ult, lo, loopHi); - Value cond = builder.create(loc, cont, inBound); + Value cond = ANDI(cont, inBound); // continue if not yet break nor out of bound. builder.create(loc, cond, args); }, @@ -1458,7 +1454,7 @@ // Insertion point restored to after ifOp. SmallVector yields; // Increase induction variable. - yields.push_back(builder.create(loc, iv, c1)); + yields.push_back(ADDI(iv, c1)); yields.push_back(cont); yields.append(ifOp.getResults().begin(), ifOp.getResults().end()); builder.create(loc, yields); @@ -1540,24 +1536,25 @@ Value offset = slice->offset; Value sliceSz = sliceSizes[tid][sliceLvl][slice->depth - 1]; lbs.push_back(offset); - ubs.push_back(builder.create(loc, offset, sliceSz)); + ubs.push_back(ADDI(offset, sliceSz)); steps.push_back(c1); lvlSzs.push_back(lvlSizes[tid][sliceLvl]); } - auto denseNest = scf::buildLoopNest( - builder, loc, lbs, ubs, steps, innerArgs, - [&innerArgs, &lvlSzs, &ip, - &pos](OpBuilder &builder, Location loc, ValueRange ivs, - ValueRange iterArgs) -> scf::ValueVector { - for (auto em : llvm::enumerate(ivs)) { - // Linearizes postion: pos = (pos * lvlsize) + iv; - pos = builder.create(loc, pos, lvlSzs[em.index()]); - pos = builder.create(loc, pos, em.value()); - } - ip = builder.saveInsertionPoint(); - innerArgs.assign(iterArgs.begin(), iterArgs.end()); - return innerArgs; - }); + auto denseNest = + scf::buildLoopNest(builder, loc, lbs, ubs, steps, innerArgs, + [&innerArgs, &lvlSzs, &ip, &pos]( + OpBuilder &builder, Location loc, ValueRange ivs, + ValueRange iterArgs) -> scf::ValueVector { + for (auto em : llvm::enumerate(ivs)) { + // Linearizes postion: pos = (pos * lvlsize) + + // iv; + pos = MULI(pos, lvlSzs[em.index()]); + pos = ADDI(pos, em.value()); + } + ip = builder.saveInsertionPoint(); + innerArgs.assign(iterArgs.begin(), iterArgs.end()); + return innerArgs; + }); } // Generates user request loop body. builder.restoreInsertionPoint(ip); @@ -1648,7 +1645,7 @@ Value &minCrd = reduc[1]; Value &curMemSz = reduc[2]; - Value pHi = builder.create(loc, iv, c1); + Value pHi = ADDI(iv, c1); Value sPLo = genIndexLoad(builder, loc, positionsBuffers[tid][lvl], iv); Value sPHi = genIndexLoad(builder, loc, positionsBuffers[tid][lvl], pHi); @@ -1666,8 +1663,7 @@ Value curC = genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], sPLo); Value isSmaller = CMPI(ult, curC, minCrd); - Value newMin = - builder.create(loc, isSmaller, curC, minCrd); + Value newMin = SELECT(isSmaller, curC, minCrd); builder.create(loc, newMin); builder.setInsertionPointToStart(ifNonEmpty.elseBlock()); builder.create(loc, minCrd); @@ -1675,10 +1671,10 @@ minCrd = ifNonEmpty.getResult(0); // filles in builder.create(loc, sPLo, sPtrBuf, curMemSz); - Value nxtMemSize = builder.create(loc, curMemSz, c1); + Value nxtMemSize = ADDI(curMemSz, c1); builder.create(loc, sPHi, sPtrBuf, nxtMemSize); // curMemSize += 2 - curMemSz = builder.create(loc, curMemSz, c2); + curMemSz = ADDI(curMemSz, c2); }); unsigned depth = levelReducedDep[tid][lvl]; @@ -1708,8 +1704,8 @@ // algorithm to (co)iterates over the slice. Value pLoPtr = genIndexLoad(builder, loc, slicePosBuffer[tid][lvl].back(), c1); - pLoPtr = builder.create(loc, pLoPtr, c2); - Value pHiPtr = builder.create(loc, pLoPtr, c1); + pLoPtr = ADDI(pLoPtr, c2); + Value pHiPtr = ADDI(pLoPtr, c1); posits[tid][lvl] = genIndexLoad(builder, loc, slicePosBuffer[tid][lvl].back(), pLoPtr); highs[tid][lvl] = @@ -1755,14 +1751,14 @@ auto depth = remDepOnLevel(tid, curLevel - 1); assert(sliceSizes[tid][lvl].size() >= depth); Value sz = *(sliceSizes[tid][lvl].rbegin() + depth - 1); - bufSize = builder.create(loc, bufSize, sz); + bufSize = MULI(bufSize, sz); } // For a pair of [pLo, pHi]. Note that we can not compress pHi because slice // creates segments in the index buffer so that the pHi for the current // level is no longer the pLo for the next level. - bufSize = builder.create(loc, bufSize, c2); + bufSize = MULI(bufSize, c2); // Additional two metadata {memSize, idx} at head. - bufSize = builder.create(loc, bufSize, c2); + bufSize = ADDI(bufSize, c2); llvm::for_each( slicePosBuffer[tid][lvl], [bufSize, loc, &builder](Value &cache) { cache = genAlloca(builder, loc, bufSize, builder.getIndexType()); @@ -1828,7 +1824,7 @@ OpBuilder::InsertionGuard guard(builder); // Take the fast path if minCrd > offset builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - reduc[2] = builder.create(loc, absOffset, c1); + reduc[2] = ADDI(absOffset, c1); // Yield offset + 1. builder.create(loc, reduc); @@ -1851,8 +1847,7 @@ Type idxTp = builder.getIndexType(); Value pLo = genIndexLoad(builder, loc, sPtrBuf, ivs.front()); Value pHi = - genIndexLoad(builder, loc, sPtrBuf, - builder.create(loc, ivs.front(), c1)); + genIndexLoad(builder, loc, sPtrBuf, ADDI(ivs.front(), c1)); // // if pLo < pHi // coord = load[pLo] @@ -1874,7 +1869,7 @@ /* if coord == minCrd */ { builder.setInsertionPointToStart( &ifEqual.getThenRegion().front()); - Value newPlo = builder.create(loc, pLo, c1); + Value newPlo = ADDI(pLo, c1); // Updates the cache. builder.create(loc, newPlo, sPtrBuf, ivs.front()); @@ -1919,11 +1914,10 @@ builder.setInsertionPointAfter(forOp.loops.front()); // minOffset = minCrd + 1 >= size ? minCrd + 1 - size : c0 - Value tmp = builder.create(loc, forOp.results.front(), c1); - Value minOffset = builder.create( - loc, tmp, sliceSizes[tid][lvl][info.depth - 1]); + Value tmp = ADDI(forOp.results.front(), c1); + Value minOffset = SUBI(tmp, sliceSizes[tid][lvl][info.depth - 1]); Value p = CMPI(uge, tmp, sliceSizes[tid][lvl][info.depth - 1]); - minOffset = builder.create(loc, p, minOffset, c0); + minOffset = SELECT(p, minOffset, c0); SmallVector yields; yields.assign(forOp.results.begin(), forOp.results.end()); yields.push_back(minOffset); @@ -1935,26 +1929,22 @@ // the next offset should at least be offset + 1; Value minOffset = ifOp.getResults()[2]; - Value nxOffset = builder.create(loc, info.offset, c1); + Value nxOffset = ADDI(info.offset, c1); Value maxPred = CMPI(ugt, minOffset, nxOffset); - Value nextAbsOffset = - builder.create(loc, maxPred, minOffset, nxOffset); + Value nextAbsOffset = SELECT(maxPred, minOffset, nxOffset); - Value sliceUB = builder.create( - loc, nextAbsOffset, sliceSizes[tid][lvl][info.depth - 1]); + Value sliceUB = ADDI(nextAbsOffset, sliceSizes[tid][lvl][info.depth - 1]); // FIXME: this only works if the parsent is the tensor, we should use the // parents slice size + parent offset. assert(info.depth - 1 == 0); // nextNonEmpty = nextNonEmpty && slice upper bound <= parent upperbound. - nextNonEmpty = builder.create( - loc, nextNonEmpty, CMPI(ule, sliceUB, lvlSizes[tid][lvl])); + nextNonEmpty = ANDI(nextNonEmpty, CMPI(ule, sliceUB, lvlSizes[tid][lvl])); // FIXME: compute relative offset. assert(info.depth - 1 == 0); Value nextRelOffset = nextAbsOffset; - nextRelOffset = - builder.create(loc, nextNonEmpty, nextRelOffset, c0); + nextRelOffset = SELECT(nextNonEmpty, nextRelOffset, c0); operands.push_back(nextNonEmpty); operands.push_back(nextMinCrd); @@ -2018,4 +2008,8 @@ #undef CMPI #undef ADDI +#undef ANDI +#undef MULI +#undef SELECT +#undef SUBI #undef C_IDX