diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -354,13 +354,17 @@
   explicit SparseTensorLoopEmitter(ValueRange tensors,
                                    StringAttr loopTag = nullptr,
                                    bool hasOutput = false,
-                                   bool isSparseOut = false);
+                                   bool isSparseOut = false,
+                                   ArrayRef<unsigned> topSort = {});
 
   /// Starts a loop emitting session by generating all the buffers needed to
   /// iterate tensors.
   void initializeLoopEmit(OpBuilder &builder, Location loc,
                           OutputUpdater updater = nullptr);
 
+  /// Generates a list of operations to compute the affine expression.
+  Value genAffine(OpBuilder &builder, AffineExpr a, Location loc);
+
   /// Enters a new loop sequence, the loops within the same sequence starts from
   /// the break points of previous loop instead of starting over from 0.
   /// e.g.,
@@ -544,6 +548,11 @@
   // sequence.
   std::vector<Value> loopSeqStack;
 
+  // Maps AffineDimExpr to the index of the loop in loopStack.
+  // TODO: We should probably use a callback function here to make it more
+  // general.
+  std::vector<unsigned> sparsiferLoopLvlMap;
+
   // TODO: not yet used, it should track the current level for each tensor
   // to help eliminate `dim` paramters from above APIs.
   // std::vector<size_t> curLv;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -97,12 +97,14 @@
 SparseTensorLoopEmitter::SparseTensorLoopEmitter(ValueRange tensors,
                                                  StringAttr loopTag,
                                                  bool hasOutput,
-                                                 bool isSparseOut)
+                                                 bool isSparseOut,
+                                                 ArrayRef<unsigned> topSort)
     : loopTag(loopTag), hasOutput(hasOutput), isSparseOut(isSparseOut),
       tensors(tensors.begin(), tensors.end()), dimTypes(tensors.size()),
       pidxs(tensors.size()), coord(tensors.size()), highs(tensors.size()),
       ptrBuffer(tensors.size()), idxBuffer(tensors.size()),
-      valBuffer(tensors.size()), loopStack() {
+      valBuffer(tensors.size()), loopStack(),
+      sparsiferLoopLvlMap(topSort.size(), 0) {
   for (size_t tid = 0, e = tensors.size(); tid < e; tid++) {
     auto t = tensors[tid];
     // a scalar or 0-dimension tensors
@@ -126,6 +128,13 @@
     ptrBuffer[tid].assign(rank, Value());
     idxBuffer[tid].assign(rank, Value());
   }
+
+  for (unsigned i = 0, e = topSort.size(); i < e; i++) {
+    // This is an inverse map of the topologically sorted loop index from
+    // sparsifier. This is needed to map the AffineDimExpr back to the loopStack
+    // index used in loop emitter.
+    sparsiferLoopLvlMap[topSort[i]] = i;
+  }
 }
 
 void SparseTensorLoopEmitter::initializeLoopEmit(
@@ -216,6 +225,34 @@
     prepareLoopOverTensorAtDim(builder, loc, tid, dim);
 }
 
+Value SparseTensorLoopEmitter::genAffine(OpBuilder &builder, AffineExpr a,
+                                         Location loc) {
+  switch (a.getKind()) {
+  case AffineExprKind::DimId: {
+    unsigned idx = a.cast<AffineDimExpr>().getPosition();
+    return loopStack[sparsiferLoopLvlMap[idx]].iv;
+  }
+  case AffineExprKind::Add: {
+    auto binOp = a.cast<AffineBinaryOpExpr>();
+    return builder.create<arith::AddIOp>(
+        loc, genAffine(builder, binOp.getLHS(), loc),
+        genAffine(builder, binOp.getRHS(), loc));
+  }
+  case AffineExprKind::Mul: {
+    auto binOp = a.cast<AffineBinaryOpExpr>();
+    return builder.create<arith::MulIOp>(
+        loc, genAffine(builder, binOp.getLHS(), loc),
+        genAffine(builder, binOp.getRHS(), loc));
+  }
+  case AffineExprKind::Constant: {
+    int64_t c = a.cast<AffineConstantExpr>().getValue();
+    return constantIndex(builder, loc, c);
+  }
+  default:
+    llvm_unreachable("unexpected affine subscript");
+  }
+}
+
 Operation *SparseTensorLoopEmitter::enterLoopOverTensorAtDim(
     OpBuilder &builder, Location loc, size_t tid, size_t dim,
     MutableArrayRef<Value> reduc, bool isParallel, ArrayRef<size_t> extraTids,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -61,7 +61,7 @@
             tensors,
             StringAttr::get(context, linalg::GenericOp::getOperationName()),
             /*hasOutput=*/true,
-            /*isSparseOut=*/op != nullptr),
+            /*isSparseOut=*/op != nullptr, ts),
         sparseOut(op), outerParNest(nest), topSort(ts) {
     if (op)
       insChain = op->get();
@@ -485,38 +485,6 @@
   });
 }
 
-/// Generates an affine expression.
-//
-// TODO: generalize for sparse tensor subscripts
-//
-static Value genAffine(CodeGen &codegen, OpBuilder &builder, AffineExpr a,
-                       Location loc) {
-  switch (a.getKind()) {
-  case AffineExprKind::DimId: {
-    unsigned idx = a.cast<AffineDimExpr>().getPosition();
-    return codegen.getLoopIdxValue(idx); // universal dense index
-  }
-  case AffineExprKind::Add: {
-    auto binOp = a.cast<AffineBinaryOpExpr>();
-    return builder.create<arith::AddIOp>(
-        loc, genAffine(codegen, builder, binOp.getLHS(), loc),
-        genAffine(codegen, builder, binOp.getRHS(), loc));
-  }
-  case AffineExprKind::Mul: {
-    auto binOp = a.cast<AffineBinaryOpExpr>();
-    return builder.create<arith::MulIOp>(
-        loc, genAffine(codegen, builder, binOp.getLHS(), loc),
-        genAffine(codegen, builder, binOp.getRHS(), loc));
-  }
-  case AffineExprKind::Constant: {
-    int64_t c = a.cast<AffineConstantExpr>().getValue();
-    return constantIndex(builder, loc, c);
-  }
-  default:
-    llvm_unreachable("unexpected affine subscript");
-  }
-}
-
 /// Generates index for load/store on sparse tensor.
 static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) {
   auto map = op.getMatchingIndexingMap(t);
@@ -546,7 +514,7 @@
   } else {
     for (unsigned d = 0; d < rank; d++) {
       AffineExpr a = map.getResult(d);
-      args.push_back(genAffine(codegen, builder, a, op.getLoc()));
+      args.push_back(codegen.loopEmitter.genAffine(builder, a, op.getLoc()));
     }
   }
   return codegen.loopEmitter.getValBuffer()[tensor];