Please use GitHub pull requests for new patches. Avoid migrating existing patches. Phabricator shutdown timeline
Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Show First 20 Lines • Show All 77 Lines • ▼ Show 20 Lines | |||||
} | } | ||||
/// A helper class that visits an affine expression and tries to find an | /// A helper class that visits an affine expression and tries to find an | ||||
/// AffineDimExpr to which the corresponding iterator from a GenericOp matches | /// AffineDimExpr to which the corresponding iterator from a GenericOp matches | ||||
/// the desired iterator type. | /// the desired iterator type. | ||||
class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> { | class AffineDimFinder : public AffineExprVisitor<AffineDimFinder> { | ||||
public: | public: | ||||
explicit AffineDimFinder(linalg::GenericOp op) | explicit AffineDimFinder(linalg::GenericOp op) | ||||
: iterTypes(op.getIteratorTypesArray()) {} | : iterTypes(op.getIteratorTypes()) {} | ||||
// Overrides method from AffineExprVisitor. | |||||
void visitDimExpr(AffineDimExpr expr) { | void visitDimExpr(AffineDimExpr expr) { | ||||
aartbik: I know this was already there, but can we use override here to make it more clear that we… | |||||
I added a comment, this is non-vritual function, so I did not use override here. Peiming: I added a comment, this is non-vritual function, so I did not use override here. | |||||
if (pickedDim == nullptr || pickIterType == iterTypes[expr.getPosition()]) { | if (pickedDim == nullptr || | ||||
pickIterType == iterTypes[expr.getPosition()] | |||||
.cast<linalg::IteratorTypeAttr>() | |||||
.getValue()) { | |||||
pickedDim = expr; | pickedDim = expr; | ||||
} | } | ||||
} | } | ||||
/// Set the desired iterator type that we want to pick. | /// Set the desired iterator type that we want to pick. | ||||
void setPickedIterType(utils::IteratorType iterType) { | void setPickedIterType(utils::IteratorType iterType) { | ||||
pickIterType = iterType; | pickIterType = iterType; | ||||
} | } | ||||
/// Get the desired AffineDimExpr. | /// Get the desired AffineDimExpr. | ||||
AffineDimExpr getDimExpr() const { return pickedDim.cast<AffineDimExpr>(); } | AffineDimExpr getDimExpr() const { return pickedDim.cast<AffineDimExpr>(); } | ||||
private: | private: | ||||
/// The picked AffineDimExpr after visit. This must be stored as | /// The picked AffineDimExpr after visit. This must be stored as | ||||
/// `AffineExpr` rather than `AffineDimExpr`, because the latter | /// `AffineExpr` rather than `AffineDimExpr`, because the latter | ||||
/// doesn't have a default ctor. | /// doesn't have a default ctor. | ||||
AffineExpr pickedDim; | AffineExpr pickedDim; | ||||
/// The iterator type that we want. | /// The iterator type that we want. | ||||
utils::IteratorType pickIterType; | utils::IteratorType pickIterType; | ||||
/// The mapping between dim=>iterator type. | /// The mapping between dim=>iterator type. | ||||
SmallVector<utils::IteratorType> iterTypes; | ArrayAttr iterTypes; | ||||
}; | }; | ||||
// Flattens an affine expression into a list of AffineDimExprs. | // Flattens an affine expression into a list of AffineDimExprs. | ||||
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> { | struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> { | ||||
// Overrides method from AffineExprVisitor. | |||||
void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); } | void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); } | ||||
SmallVector<AffineDimExpr> dims; | SmallVector<AffineDimExpr> dims; | ||||
}; | }; | ||||
} // namespace | } // namespace | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// Sparse compiler analysis methods. | // Sparse compiler analysis methods. | ||||
▲ Show 20 Lines • Show All 179 Lines • ▼ Show 20 Lines | if (isSubExp) { | ||||
// E.g., | // E.g., | ||||
// `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0] | // `d0 + d1` for indexing t0[lvl0] and `d0 + d2` for indexing t1[lvl0] | ||||
// d0_1 = getNextSliceOffset t0 along lvl0 | // d0_1 = getNextSliceOffset t0 along lvl0 | ||||
// d0_2 = getNextSliceOffset t1 along lvl0 | // d0_2 = getNextSliceOffset t1 along lvl0 | ||||
// if d0_1 == d0_2 then d0 = d0_1 = d0_1 | // if d0_1 == d0_2 then d0 = d0_1 = d0_1 | ||||
// else increase min(d0_1, d0_2). | // else increase min(d0_1, d0_2). | ||||
return false; | return false; | ||||
} | } | ||||
merger.setLoopDependentTensorLevel(ldx, tensor, lvl); | merger.setLoopDependentTensorLevel(ldx, tensor, lvl, dlt); | ||||
} | } | ||||
return true; | return true; | ||||
} | } | ||||
case AffineExprKind::Constant: | case AffineExprKind::Constant: | ||||
case AffineExprKind::Mul: | case AffineExprKind::Mul: | ||||
// TODO: Support Mul and Constant AffineExp for slice-based codegen | // TODO: Support Mul and Constant AffineExp for slice-based codegen | ||||
return false; | return false; | ||||
case AffineExprKind::Add: { | case AffineExprKind::Add: { | ||||
▲ Show 20 Lines • Show All 451 Lines • ▼ Show 20 Lines | static bool computeIterationGraph(CodegenEnv &env, SortMask mask, | ||||
std::vector<std::vector<bool>> adjM(numLoops, | std::vector<std::vector<bool>> adjM(numLoops, | ||||
std::vector<bool>(numLoops, false)); | std::vector<bool>(numLoops, false)); | ||||
std::vector<unsigned> inDegree(numLoops, 0); // in-degree of each node. | std::vector<unsigned> inDegree(numLoops, 0); // in-degree of each node. | ||||
const auto iteratorTypes = env.op().getIteratorTypesArray(); | const auto iteratorTypes = env.op().getIteratorTypesArray(); | ||||
// Iterate over the indexing maps of every tensor in the tensor expression. | // Iterate over the indexing maps of every tensor in the tensor expression. | ||||
for (OpOperand &t : env.op()->getOpOperands()) { | for (OpOperand &t : env.op()->getOpOperands()) { | ||||
// Get map and encoding. | // Get map and encoding. | ||||
const auto enc = getSparseTensorEncoding(t.get().getType()); | const auto enc = getSparseTensorEncoding(t.get().getType()); | ||||
assert(env.op().getMatchingIndexingMap(&t).getNumDims() + | |||||
getNumNonTrivialIdxExpOnSparseLvls(env.op()) == | |||||
numLoops); | |||||
// Skips dense inputs/outputs when not requested. | // Skips dense inputs/outputs when not requested. | ||||
const bool isDenseInput = !enc && env.op().isDpsInput(&t); | const bool isDenseInput = !enc && env.op().isDpsInput(&t); | ||||
const bool isDenseOutput = !enc && !isDenseInput; | const bool isDenseOutput = !enc && !isDenseInput; | ||||
if ((isDenseInput && !includesDenseInput(mask)) || | if ((isDenseInput && !includesDenseInput(mask)) || | ||||
(isDenseOutput && !includesDenseOutput(mask))) | (isDenseOutput && !includesDenseOutput(mask))) | ||||
continue; | continue; | ||||
// Push unrelated loops into sparse iteration space, so these | // Push unrelated loops into sparse iteration space, so these | ||||
▲ Show 20 Lines • Show All 704 Lines • ▼ Show 20 Lines | static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp, | ||||
}); | }); | ||||
env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tids, lvls); | env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tids, lvls); | ||||
// Maintain the universal index only if it is actually | // Maintain the universal index only if it is actually | ||||
// consumed by a subsequent lattice point. | // consumed by a subsequent lattice point. | ||||
if (needsUniv) { | if (needsUniv) { | ||||
for (const LatPointId li : env.set(lts).drop_front()) | for (const LatPointId li : env.set(lts).drop_front()) | ||||
if (!env.merger().hasAnySparse(env.lat(li).simple) && | if (!env.merger().hasAnySparse(env.lat(li).simple)) | ||||
!env.merger().hasSparseIdxReduction(env.lat(li).simple)) | |||||
return true; | return true; | ||||
} | } | ||||
return false; | return false; | ||||
} | } | ||||
static void genConstantDenseAddressFromLevel(CodegenEnv &env, | static void genConstantDenseAddressFromLevel(CodegenEnv &env, | ||||
OpBuilder &builder, TensorId tid, | OpBuilder &builder, TensorId tid, | ||||
Level startLvl) { | Level startLvl) { | ||||
▲ Show 20 Lines • Show All 164 Lines • ▼ Show 20 Lines | if (tid != env.merger().getOutTensorID()) | ||||
genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1); | genConstantDenseAddressFromLevel(env, builder, tid, lvl + 1); | ||||
} | } | ||||
return std::make_pair(loop, isSingleCond); | return std::make_pair(loop, isSingleCond); | ||||
} | } | ||||
/// Ends a single loop in current sequence. Returns new values for needsUniv. | /// Ends a single loop in current sequence. Returns new values for needsUniv. | ||||
static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop, | static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop, | ||||
LoopId idx, LatPointId li, bool needsUniv) { | LoopId idx, LatPointId li, bool needsUniv, | ||||
// End a while-loop. | bool isSingleCond) { | ||||
if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) { | |||||
finalizeWhileOp(env, rewriter, idx, needsUniv, whileOp); | if (isSingleCond) { | ||||
} else if (auto forOp = dyn_cast<scf::ForOp>(loop)) { | // Either a for-loop or a while-loop that iterates over a slice. | ||||
I would start with the same comment as in the else, and state it in the affirmative rather than the speculative // End either a for-loop or a while-loop that iterates over a slice. aartbik: I would start with the same comment as in the else, and state it in the affirmative rather than… | |||||
// Any iteration of a reduction for-loop creates a valid lex insert. | // Any iteration creates a valid lex insert. | ||||
if (env.isReduc() && env.getValidLexInsert()) | if (env.isReduc() && env.getValidLexInsert()) | ||||
env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true)); | env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true)); | ||||
} else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) { | |||||
// End a while-loop. | |||||
finalizeWhileOp(env, rewriter, idx, needsUniv, whileOp); | |||||
} else { | } else { | ||||
needsUniv = false; | needsUniv = false; | ||||
} | } | ||||
env.genLoopBoundary([&](MutableArrayRef<Value> reduc) { | env.genLoopBoundary([&](MutableArrayRef<Value> reduc) { | ||||
env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc); | env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc); | ||||
return std::nullopt; | return std::nullopt; | ||||
}); | }); | ||||
return needsUniv; | return needsUniv; | ||||
} | } | ||||
/// Ends a loop sequence at given level. | /// Ends a loop sequence at given level. | ||||
static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp, | static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp, | ||||
LoopOrd at, LoopId idx, LoopId ldx) { | unsigned at, unsigned idx, unsigned ldx) { | ||||
assert(!env.getLoopVar(idx)); | assert(!env.getLoopVar(idx)); | ||||
env.emitter().exitCurrentLoopSeq(); | env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc()); | ||||
// Unmark bookkeeping of invariants and loop index. | // Unmark bookkeeping of invariants and loop index. | ||||
genInvariants(env, builder, exp, ldx, /*atStart=*/false); | genInvariants(env, builder, exp, ldx, /*atStart=*/false); | ||||
// Finalize access pattern expansion for sparse tensor output. | // Finalize access pattern expansion for sparse tensor output. | ||||
genExpand(env, builder, at, /*atStart=*/false); | genExpand(env, builder, at, /*atStart=*/false); | ||||
} | } | ||||
/// Recursively generates code while computing iteration lattices in order | /// Recursively generates code while computing iteration lattices in order | ||||
/// to manage the complexity of implementing co-iteration over unions | /// to manage the complexity of implementing co-iteration over unions | ||||
▲ Show 20 Lines • Show All 48 Lines • ▼ Show 20 Lines | for (unsigned j = 0; j < lsize; j++) { | ||||
endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput); | endIf(env, rewriter, ifOp, loop, redInput, cntInput, insInput); | ||||
} else { | } else { | ||||
genStmt(env, rewriter, ej, at + 1); | genStmt(env, rewriter, ej, at + 1); | ||||
} | } | ||||
} | } | ||||
} | } | ||||
// End a loop. | // End a loop. | ||||
needsUniv = endLoop(env, rewriter, loop, idx, li, needsUniv); | needsUniv = endLoop(env, rewriter, loop, idx, li, needsUniv, isSingleCond); | ||||
} | } | ||||
// End a loop sequence. | // End a loop sequence. | ||||
endLoopSeq(env, rewriter, exp, at, idx, ldx); | endLoopSeq(env, rewriter, exp, at, idx, ldx); | ||||
} | } | ||||
/// Converts the result computed by the sparse kernel into the required form. | /// Converts the result computed by the sparse kernel into the required form. | ||||
static void genResult(CodegenEnv &env, RewriterBase &rewriter) { | static void genResult(CodegenEnv &env, RewriterBase &rewriter) { | ||||
▲ Show 20 Lines • Show All 176 Lines • Show Last 20 Lines |
I know this was already there, but can we use override here to make it more clear that we implementing the base visitor class?
Or at the very least group all overrides into a
// Visitor method overrides.
...
section?