Please use GitHub pull requests for new patches. Avoid migrating existing patches. Phabricator shutdown timeline
Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
Show First 20 Lines • Show All 207 Lines • ▼ Show 20 Lines | : outTensor(numInputOutputTensors - 1), | ||||
numTensors(numInputOutputTensors + 1), numNativeLoops(numNativeLoops), | numTensors(numInputOutputTensors + 1), numNativeLoops(numNativeLoops), | ||||
numLoops(numNativeLoops + numFilterLoops), hasSparseOut(false), | numLoops(numNativeLoops + numFilterLoops), hasSparseOut(false), | ||||
lvlTypes(numTensors, | lvlTypes(numTensors, | ||||
std::vector<DimLevelType>(numLoops, DimLevelType::Undef)), | std::vector<DimLevelType>(numLoops, DimLevelType::Undef)), | ||||
loopToLvl(numTensors, | loopToLvl(numTensors, | ||||
std::vector<std::optional<Level>>(numLoops, std::nullopt)), | std::vector<std::optional<Level>>(numLoops, std::nullopt)), | ||||
lvlToLoop(numTensors, | lvlToLoop(numTensors, | ||||
std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)), | std::vector<std::optional<LoopId>>(maxLvlRank, std::nullopt)), | ||||
loopToDependencies(numLoops, std::vector<std::optional<Level>>( | loopToDependencies( | ||||
numLoops, std::vector<std::optional<std::pair<Level, DimLevelType>>>( | |||||
numTensors, std::nullopt)), | numTensors, std::nullopt)), | ||||
levelToDependentIdx(numTensors, std::vector<std::vector<LoopId>>( | levelToDependentLoop(numTensors, std::vector<std::vector<LoopId>>( | ||||
maxLvlRank, std::vector<LoopId>())), | maxLvlRank, std::vector<LoopId>())), | ||||
loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {} | loopBounds(numLoops, std::make_pair(numTensors, numLoops)) {} | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// Lattice methods. | // Lattice methods. | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
ExprId Merger::addTensorExp(TensorId t) { | ExprId Merger::addTensorExp(TensorId t) { | ||||
assert(isValidTensorId(t)); | assert(isValidTensorId(t)); | ||||
▲ Show 20 Lines • Show All 163 Lines • ▼ Show 20 Lines | BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) { | ||||
bool isSingleton = true; | bool isSingleton = true; | ||||
for (const LatPointId p1 : set(s0)) { | for (const LatPointId p1 : set(s0)) { | ||||
if (p0 != p1 && latGT(p0, p1)) { | if (p0 != p1 && latGT(p0, p1)) { | ||||
isSingleton = false; | isSingleton = false; | ||||
break; | break; | ||||
} | } | ||||
} | } | ||||
BitVector simple(lat(p0).bits); | BitVector simple(latPoints[p0].bits); | ||||
bool reset = | bool reset = isSingleton && hasAnySparse(simple); | ||||
isSingleton && (hasAnySparse(simple) || hasSparseIdxReduction(simple)); | const TensorLoopId be = simple.size(); | ||||
// `be`, `b`, and `offset` are `TensorLoopId` in spirit; but we avoid | TensorLoopId offset = 0; // relative to the end | ||||
// using that class in this function because we need to do a bunch of | |||||
// arithmetic on them, so using the newtype would introduce too much | |||||
// boilerplate. | |||||
const unsigned be = simple.size(); | |||||
unsigned offset = 0; // relative to the end | |||||
if (!reset) | if (!reset) | ||||
// Starts resetting from a dense level, so that the first bit (if kept) | // Starts resetting from a dense level, so that the first bit (if kept) | ||||
// is not undefined level-type. | // is not undefined level-type. | ||||
for (unsigned b = 0; b < be; b++) { | for (unsigned b = 0; b < be; b++) { | ||||
if (simple[b] && isDenseDLT(getDimLevelType(TensorLoopId{b}))) { | if (simple[b] && isDenseDLT(getDimLevelType(TensorLoopId{b}))) { | ||||
offset = be - b - 1; // relative to the end | offset = be - b - 1; // relative to the end | ||||
break; | break; | ||||
} | } | ||||
} | } | ||||
// Now apply the two basic rules. We also iterate the bits reversely to always | // Now apply the two basic rules. We also iterate the bits reversely to always | ||||
// keep the rightmost bit (which could possibly be a synthetic tensor). | // keep the rightmost bit (which could possibly be a synthetic tensor). | ||||
for (unsigned b = be - 1 - offset, i = 0; i < be; | for (unsigned b = be - 1 - offset, i = 0; i < be; | ||||
b = b == 0 ? be - 1 : b - 1, i++) { | b = b == 0 ? be - 1 : b - 1, i++) { | ||||
// FIXME: better name? also slice on dense level has locate property as | // Slice on dense level has `locate` property as well, and can be optimized. | ||||
// well. Handle it correctly! | if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) { | ||||
if (simple[b] && !isLvlWithNonTrivialIdxExp(TensorLoopId{b})) { | const auto dlt = getDimLevelType(b); | ||||
aartbik: has the `locate` property as well | |||||
const auto dlt = getDimLevelType(TensorLoopId{b}); | |||||
if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt)) { | if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt)) { | ||||
if (reset) | if (reset) | ||||
simple.reset(b); | simple.reset(b); | ||||
reset = true; | reset = true; | ||||
} | } | ||||
} | } | ||||
} | } | ||||
return simple; | return simple; | ||||
} | } | ||||
// Returns true iff the iteration bits of lattice point `i` form a strict
// superset of those of lattice point `j`.
bool Merger::latGT(LatPointId i, LatPointId j) const {
  const BitVector &bitsi = lat(i).bits;
  const BitVector &bitsj = lat(j).bits;
  assert(bitsi.size() == bitsj.size());
  // A strict superset must have strictly more bits set.
  if (bitsi.count() <= bitsj.count())
    return false;
  // Every bit set in `j` must also be set in `i`.
  for (TensorLoopId b : bitsj.set_bits())
    if (!bitsi[b])
      return false;
  return true;
}
// Returns true iff lattice points `i` and `j` differ only in dense
// (non-sparse) iteration bits.
bool Merger::onlyDenseDiff(LatPointId i, LatPointId j) const {
  // XOR yields exactly the bits on which the two points disagree.
  BitVector diff(latPoints[j].bits);
  diff ^= latPoints[i].bits;
  return !hasAnySparse(diff);
}
bool Merger::expContainsTensor(ExprId e, TensorId t) const { | bool Merger::expContainsTensor(ExprId e, TensorId t) const { | ||||
const auto &expr = exp(e); | const auto &expr = exp(e); | ||||
// First we check `expIsTensor`. | // First we check `expIsTensor`. | ||||
if (expr.kind == TensorExp::Kind::kTensor) | if (expr.kind == TensorExp::Kind::kTensor) | ||||
return expr.tensor == t; | return expr.tensor == t; | ||||
▲ Show 20 Lines • Show All 122 Lines • ▼ Show 20 Lines | bool Merger::isSingleCondition(TensorId t, ExprId e) const { | ||||
case TensorExp::Kind::kBinary: | case TensorExp::Kind::kBinary: | ||||
case TensorExp::Kind::kReduce: | case TensorExp::Kind::kReduce: | ||||
return false; | return false; | ||||
} | } | ||||
llvm_unreachable("unexpected kind"); | llvm_unreachable("unexpected kind"); | ||||
} | } | ||||
bool Merger::hasAnySparse(const BitVector &bits) const { | bool Merger::hasAnySparse(const BitVector &bits) const { | ||||
for (TensorLoopId b = 0, be = bits.size(); b < be; b++) | for (TensorLoopId b : bits.set_bits()) { | ||||
if (bits[b]) { | |||||
const auto dlt = getDimLevelType(b); | const auto dlt = getDimLevelType(b); | ||||
if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) | if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) | ||||
return true; | return true; | ||||
} | } | ||||
return false; | return hasSparseIdxReduction(bits); | ||||
} | } | ||||
bool Merger::hasSparseIdxReduction(const BitVector &bits) const { | bool Merger::hasSparseIdxReduction(const BitVector &bits) const { | ||||
// TODO: return false on dense levels. | for (TensorLoopId b : bits.set_bits()) | ||||
for (unsigned b = 0, be = bits.size(); b < be; b++) | if (isSparseLvlWithNonTrivialIdxExp(b)) | ||||
if (bits[b] && isLvlWithNonTrivialIdxExp(b)) | |||||
return true; | return true; | ||||
return false; | return false; | ||||
} | } | ||||
#ifndef NDEBUG | #ifndef NDEBUG | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// Print methods (for debugging). | // Print methods (for debugging). | ||||
▲ Show 20 Lines • Show All 875 Lines • Show Last 20 Lines |
has the locate property as well