diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -22,16 +22,20 @@ using namespace mlir::sparse_tensor; //===----------------------------------------------------------------------===// -// File local helper functions. +// File local shorthand macros //===----------------------------------------------------------------------===// #define CMPI(p, l, r) \ - (builder.create(loc, arith::CmpIPredicate::p, l, r) \ + (builder.create(loc, arith::CmpIPredicate::p, (l), (r)) \ .getResult()) -#define ADDI(lhs, rhs) (builder.create(loc, lhs, rhs)) +#define C_IDX(v) (constantIndex(builder, loc, (v))) +#define ADDI(lhs, rhs) (builder.create(loc, (lhs), (rhs))) +#define MULI(lhs, rhs) (builder.create(loc, (lhs), (rhs))) -#define C_IDX(v) (constantIndex(builder, loc, v)) +//===----------------------------------------------------------------------===// +// File local helper functions. +//===----------------------------------------------------------------------===// /// Generates a pointer/index load from the sparse storage scheme. Narrower /// data types need to be zero extended before casting the value into the @@ -73,9 +77,7 @@ static Value toSliceCrd(OpBuilder &builder, Location loc, Value crd, Value offset, Value stride, Value tensor, Level lvl) { // tensorCrd = sliceCrd * stride + offset - crd = builder.create(loc, crd, stride); - crd = builder.create(loc, crd, offset); - return crd; + return ADDI(MULI(crd, stride), offset); } /// Generates code to compute the (absolute) offset of the slice based on the @@ -88,7 +90,7 @@ Value geSize = CMPI(uge, minCrd, size); Value pred = builder.create(loc, isNonEmpty, geSize); // offset - Value mp1 = builder.create(loc, minCrd, C_IDX(1)); + Value mp1 = ADDI(minCrd, C_IDX(1)); Value mms = builder.create(loc, mp1, size); // This is the absolute offset related to the underly tensor. return builder.create(loc, pred, mms, C_IDX(0)); @@ -156,11 +158,11 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, TensorId tid, Level lvl, Value crd) { Value pos = lvl == 0 ? C_IDX(0) : posits[tid][lvl - 1]; - Value mul = builder.create(loc, highs[tid][lvl], pos); + Value mul = MULI(highs[tid][lvl], pos); if (isSparseSlices[tid]) crd = toSliceCrd(builder, loc, crd, sliceOffsets[tid][lvl], sliceStrides[tid][lvl], tensors[tid], lvl); - Value add = builder.create(loc, mul, crd); + Value add = ADDI(mul, crd); return add; } @@ -198,7 +200,7 @@ /*afterBuilder=*/ [](OpBuilder &builder, Location loc, ValueRange ivs) { // pos ++ - Value nextPos = builder.create(loc, ivs[0], C_IDX(1)); + Value nextPos = ADDI(ivs[0], C_IDX(1)); builder.create(loc, nextPos); }); // Return the segment high. @@ -218,10 +220,9 @@ const Value pos = posits[tid][dstLvl]; const Value off = genIndexLoad(builder, loc, mem, pos); // Linearized the coordinates within the same collapse reassociation. - crd = builder.create(loc, crd, off); + crd = ADDI(crd, off); if (i != reassocSize - 1) { - crd = builder.create(loc, crd, - this->lvlSizes[tid][reassoc[i + 1]]); + crd = MULI(crd, this->lvlSizes[tid][reassoc[i + 1]]); } } return crd; @@ -455,7 +456,7 @@ Value size = c0; for (unsigned e = depLvls.size() - 1; e >= 1; e--) { auto [dt, dd] = depLvls[e]; - size = builder.create(loc, size, lvlSizes[dt][dd]); + size = ADDI(size, lvlSizes[dt][dd]); sliceSizes[t][lvl][e - 1] = size; } } @@ -504,7 +505,7 @@ // TODO: we could probably use an SSA value for it. Value sPtrBuf = slicePosBuffer[tid][lvl].back(); Value curP = genIndexLoad(builder, loc, sPtrBuf, c1); - Value nexP = builder.create(loc, curP, c2); + Value nexP = ADDI(curP, c2); builder.create(loc, nexP, sPtrBuf, c1); } } @@ -525,15 +526,13 @@ } case AffineExprKind::Add: { auto binOp = a.cast(); - return builder.create( - loc, genAffine(builder, loc, binOp.getLHS()), - genAffine(builder, loc, binOp.getRHS())); + return ADDI(genAffine(builder, loc, binOp.getLHS()), + genAffine(builder, loc, binOp.getRHS())); } case AffineExprKind::Mul: { auto binOp = a.cast(); - return builder.create( - loc, genAffine(builder, loc, binOp.getLHS()), - genAffine(builder, loc, binOp.getRHS())); + return MULI(genAffine(builder, loc, binOp.getLHS()), + genAffine(builder, loc, binOp.getRHS())); } case AffineExprKind::Constant: { int64_t c = a.cast().getValue(); @@ -744,7 +743,7 @@ condLvl.push_back(lvl); } else { hi = builder.create(loc, lvlSizes[tid][lvl], sliceSz); - hi = builder.create(loc, hi, C_IDX(1)); + hi = ADDI(hi, C_IDX(1)); } } else { condTid.push_back(tid); @@ -761,9 +760,7 @@ bool fullyReduc = depFullyReduced(t, l); SliceInfo &info = sliceStack[t].back(); if (fullyReduc) { - posits[t][l] = - genAddress(builder, loc, t, l, - builder.create(loc, info.offset, iv)); + posits[t][l] = genAddress(builder, loc, t, l, ADDI(info.offset, iv)); } else { // Puts sliced dense loop into LoopInfo so that LoopEmitter knows how to // exit it. @@ -973,7 +970,7 @@ // // This "idx" is the index into `llvm::zip(tids, lvls)` for (auto [pred, idx] : slicesPreds) { - Value nextPos = builder.create(loc, yields[idx], C_IDX(1)); + Value nextPos = ADDI(yields[idx], C_IDX(1)); yields[idx] = builder.create(loc, pred, yields[idx], nextPos); } @@ -1104,7 +1101,7 @@ const Value pLo = srcLvl == 0 ? c0 : posits[tid][srcLvl - 1]; posits[tid][srcLvl] = genIndexLoad(builder, loc, mem, pLo); - const Value pHi = builder.create(loc, pLo, c1); + const Value pHi = ADDI(pLo, c1); highs[tid][srcLvl] = genIndexLoad(builder, loc, mem, pHi); return; } @@ -1122,7 +1119,7 @@ highs[tid][srcLvl] = (!isUniqueDLT(lvlTypes[tid][srcLvl - 1]) && parentSegHi) ? parentSegHi - : builder.create(loc, pLo, c1); + : ADDI(pLo, c1); return; } } @@ -1313,7 +1310,7 @@ // segment high. Value add = !isUniqueDLT(lvlTypes[tid][reassoc.back()]) ? segHi[tid][reassoc.back()] - : builder.create(loc, pos, one); + : ADDI(pos, one); operands.push_back(builder.create(loc, cmp, add, pos)); // Following loops continue iteration from the break point of the @@ -1345,7 +1342,7 @@ if (operands.size() + delta < whileOp.getNumResults()) { assert(operands.size() + delta + 1 == whileOp.getNumResults()); // The last one is the universial index. - operands.push_back(builder.create(loc, iv, one)); + operands.push_back(ADDI(iv, one)); // update the loop starting point of current loop sequence loopSeqStack.back().first = whileOp->getResult(o++); } @@ -1402,8 +1399,7 @@ TensorId tid, Level lvl, size_t depth, ValueRange userReduc, bool genYield, LoopBodyBuilder bodyBuilder) { Value c1 = C_IDX(1); - Value sliceHi = - builder.create(loc, offset, sliceSizes[tid][lvl].back()); + Value sliceHi = ADDI(offset, sliceSizes[tid][lvl].back()); SmallVector reduc = { loopLo, // loop lower bounds @@ -1458,7 +1454,7 @@ // Insertion point restored to after ifOp. SmallVector yields; // Increase induction variable. - yields.push_back(builder.create(loc, iv, c1)); + yields.push_back(ADDI(iv, c1)); yields.push_back(cont); yields.append(ifOp.getResults().begin(), ifOp.getResults().end()); builder.create(loc, yields); @@ -1540,24 +1536,25 @@ Value offset = slice->offset; Value sliceSz = sliceSizes[tid][sliceLvl][slice->depth - 1]; lbs.push_back(offset); - ubs.push_back(builder.create(loc, offset, sliceSz)); + ubs.push_back(ADDI(offset, sliceSz)); steps.push_back(c1); lvlSzs.push_back(lvlSizes[tid][sliceLvl]); } - auto denseNest = scf::buildLoopNest( - builder, loc, lbs, ubs, steps, innerArgs, - [&innerArgs, &lvlSzs, &ip, - &pos](OpBuilder &builder, Location loc, ValueRange ivs, - ValueRange iterArgs) -> scf::ValueVector { - for (auto em : llvm::enumerate(ivs)) { - // Linearizes postion: pos = (pos * lvlsize) + iv; - pos = builder.create(loc, pos, lvlSzs[em.index()]); - pos = builder.create(loc, pos, em.value()); - } - ip = builder.saveInsertionPoint(); - innerArgs.assign(iterArgs.begin(), iterArgs.end()); - return innerArgs; - }); + auto denseNest = + scf::buildLoopNest(builder, loc, lbs, ubs, steps, innerArgs, + [&innerArgs, &lvlSzs, &ip, &pos]( + OpBuilder &builder, Location loc, ValueRange ivs, + ValueRange iterArgs) -> scf::ValueVector { + for (auto em : llvm::enumerate(ivs)) { + // Linearizes postion: pos = (pos * lvlsize) + + // iv; + pos = MULI(pos, lvlSzs[em.index()]); + pos = ADDI(pos, em.value()); + } + ip = builder.saveInsertionPoint(); + innerArgs.assign(iterArgs.begin(), iterArgs.end()); + return innerArgs; + }); } // Generates user request loop body. builder.restoreInsertionPoint(ip); @@ -1648,7 +1645,7 @@ Value &minCrd = reduc[1]; Value &curMemSz = reduc[2]; - Value pHi = builder.create(loc, iv, c1); + Value pHi = ADDI(iv, c1); Value sPLo = genIndexLoad(builder, loc, positionsBuffers[tid][lvl], iv); Value sPHi = genIndexLoad(builder, loc, positionsBuffers[tid][lvl], pHi); @@ -1675,10 +1672,10 @@ minCrd = ifNonEmpty.getResult(0); // filles in builder.create(loc, sPLo, sPtrBuf, curMemSz); - Value nxtMemSize = builder.create(loc, curMemSz, c1); + Value nxtMemSize = ADDI(curMemSz, c1); builder.create(loc, sPHi, sPtrBuf, nxtMemSize); // curMemSize += 2 - curMemSz = builder.create(loc, curMemSz, c2); + curMemSz = ADDI(curMemSz, c2); }); unsigned depth = levelReducedDep[tid][lvl]; @@ -1708,8 +1705,8 @@ // algorithm to (co)iterates over the slice. Value pLoPtr = genIndexLoad(builder, loc, slicePosBuffer[tid][lvl].back(), c1); - pLoPtr = builder.create(loc, pLoPtr, c2); - Value pHiPtr = builder.create(loc, pLoPtr, c1); + pLoPtr = ADDI(pLoPtr, c2); + Value pHiPtr = ADDI(pLoPtr, c1); posits[tid][lvl] = genIndexLoad(builder, loc, slicePosBuffer[tid][lvl].back(), pLoPtr); highs[tid][lvl] = @@ -1755,14 +1752,14 @@ auto depth = remDepOnLevel(tid, curLevel - 1); assert(sliceSizes[tid][lvl].size() >= depth); Value sz = *(sliceSizes[tid][lvl].rbegin() + depth - 1); - bufSize = builder.create(loc, bufSize, sz); + bufSize = MULI(bufSize, sz); } // For a pair of [pLo, pHi]. Note that we can not compress pHi because slice // creates segments in the index buffer so that the pHi for the current // level is no longer the pLo for the next level. - bufSize = builder.create(loc, bufSize, c2); + bufSize = MULI(bufSize, c2); // Additional two metadata {memSize, idx} at head. - bufSize = builder.create(loc, bufSize, c2); + bufSize = ADDI(bufSize, c2); llvm::for_each( slicePosBuffer[tid][lvl], [bufSize, loc, &builder](Value &cache) { cache = genAlloca(builder, loc, bufSize, builder.getIndexType()); @@ -1828,7 +1825,7 @@ OpBuilder::InsertionGuard guard(builder); // Take the fast path if minCrd > offset builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); - reduc[2] = builder.create(loc, absOffset, c1); + reduc[2] = ADDI(absOffset, c1); // Yield offset + 1. builder.create(loc, reduc); @@ -1851,8 +1848,7 @@ Type idxTp = builder.getIndexType(); Value pLo = genIndexLoad(builder, loc, sPtrBuf, ivs.front()); Value pHi = - genIndexLoad(builder, loc, sPtrBuf, - builder.create(loc, ivs.front(), c1)); + genIndexLoad(builder, loc, sPtrBuf, ADDI(ivs.front(), c1)); // // if pLo < pHi // coord = load[pLo] @@ -1874,7 +1870,7 @@ /* if coord == minCrd */ { builder.setInsertionPointToStart( &ifEqual.getThenRegion().front()); - Value newPlo = builder.create(loc, pLo, c1); + Value newPlo = ADDI(pLo, c1); // Updates the cache. builder.create(loc, newPlo, sPtrBuf, ivs.front()); @@ -1919,7 +1915,7 @@ builder.setInsertionPointAfter(forOp.loops.front()); // minOffset = minCrd + 1 >= size ? minCrd + 1 - size : c0 - Value tmp = builder.create(loc, forOp.results.front(), c1); + Value tmp = ADDI(forOp.results.front(), c1); Value minOffset = builder.create( loc, tmp, sliceSizes[tid][lvl][info.depth - 1]); Value p = CMPI(uge, tmp, sliceSizes[tid][lvl][info.depth - 1]); @@ -1935,13 +1931,12 @@ // the next offset should at least be offset + 1; Value minOffset = ifOp.getResults()[2]; - Value nxOffset = builder.create(loc, info.offset, c1); + Value nxOffset = ADDI(info.offset, c1); Value maxPred = CMPI(ugt, minOffset, nxOffset); Value nextAbsOffset = builder.create(loc, maxPred, minOffset, nxOffset); - Value sliceUB = builder.create( - loc, nextAbsOffset, sliceSizes[tid][lvl][info.depth - 1]); + Value sliceUB = ADDI(nextAbsOffset, sliceSizes[tid][lvl][info.depth - 1]); // FIXME: this only works if the parsent is the tensor, we should use the // parents slice size + parent offset. @@ -2018,4 +2013,5 @@ #undef CMPI #undef ADDI +#undef MULI #undef C_IDX