diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h @@ -57,6 +57,8 @@ // d0 and d1 (for affine expression reduction). // If the list is empty, it means that there is no affine expression on the // input [tid, dim]. + // NOTE: the order of the returned list should be consistent with the + // topological order of the iteration graph. using DependentDimGetter = function_ref>(unsigned, unsigned)>; @@ -108,10 +110,7 @@ ArrayRef dims); // exit the current loop sequence, this will reset universal index to 0. - void exitCurrentLoopSeq() { - assert(loopSeqStack.size() == loopStack.size() + 1); - loopSeqStack.pop_back(); - } + void exitCurrentLoopSeq(OpBuilder &builder, Location loc); // TODO: Gets rid of `dim` in the argument list? Track the dimension we // are currently at internally. Then it would be enterNextDimForTensor. @@ -180,9 +179,13 @@ private: struct LoopLevelInfo { - LoopLevelInfo(ArrayRef tids, ArrayRef dims, Operation *loop, + LoopLevelInfo(ArrayRef tids, ArrayRef dims, + ArrayRef slicedTids, ArrayRef slicedDims, + ArrayRef sliceResolved, Operation *loop, Block *userBlock, Value iv, StringAttr loopTag) - : tids(tids), dims(dims), loop(loop), userCodeBlock(userBlock), iv(iv) { + : tids(tids), dims(dims), slicedTids(slicedTids), + slicedDims(slicedDims), sliceResolved(sliceResolved), loop(loop), + userCodeBlock(userBlock), iv(iv) { // Attached a special tag to loop emitter generated loop. 
if (loopTag) loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag); @@ -192,11 +195,36 @@ const llvm::SmallVector tids; // The corresponding dims for the tensors const llvm::SmallVector dims; + // The set of slice-driven tensors that the loop is operating on + const llvm::SmallVector slicedTids; + // The corresponding dims for the slice-driven tensors + const llvm::SmallVector slicedDims; + // Whether the slice on each slice-driven tensor is fully resolved + const llvm::SmallVector sliceResolved; const Operation *loop; // the loop operation Block *const userCodeBlock; // the block holding user' generated code. const Value iv; // the induction variable for the loop }; + struct SliceInfo { + SliceInfo(Value baseSlice, Value minCoord, Value offset, Value isNonEmpty, + std::optional slicedOnLvl, unsigned depth) + : baseSlice(baseSlice), minCoord(minCoord), offset(offset), + isNonEmpty(isNonEmpty), slicedOnLvl(slicedOnLvl), depth(depth) { + assert(!slicedOnLvl || minCoord); + } + + // Whether this is the initial tensor, i.e., not yet sliced on any level + bool isInitialTensor() const { return !slicedOnLvl.has_value(); } + + Value baseSlice; // the current slices being reduced + Value minCoord; // the minimal coordinates of the slice on lvl. + Value offset; // the offset of the current slice. + Value isNonEmpty; // whether the slice is non-empty. + std::optional slicedOnLvl; // the level on which the slice is done + unsigned depth; // the depth (relative to dependentDimMap[tid][lvl]). + }; + /// Linearizes address for dense dimension (i.e., p = (i * d0) + j). 
Value genAddress(OpBuilder &builder, Location loc, size_t tid, size_t dim, Value iv); @@ -225,6 +253,11 @@ ArrayRef tids, ArrayRef dims); + Operation *emitForLoopOverTensorAtDim(OpBuilder &builder, Location loc, + size_t tid, size_t dim, + MutableArrayRef reduc, + bool isParallel); + /// Exits a for loop, returns the reduction results, e.g., /// For sequential for loops: /// %ret = for () { @@ -256,6 +289,68 @@ void exitWhileLoop(OpBuilder &builder, Location loc, MutableArrayRef reduc); + // + // Slice-driven loop related methods. + // + + /// Retrieves the most recent slice on lvl. To reduce an affine expression like + /// d0 + d1 + d2, we need two slices (one of size d1 + d2, and the other of + /// size d2). This method returns the latter slice (of size d2), which is + /// also the final slice on the level. + SliceInfo &getFinalSliceOnLvl(size_t tid, size_t lvl); + + /// Gets the total number of constraints that are needed to fully resolve the + /// dependent dimensions on tensor[tid]. + size_t sliceTotalConstraints(size_t tid); + + /// Whether the tid is fully resolved, i.e., all the dependent dimensions are + /// reduced by slice offsets. + bool sliceFullyResolved(size_t tid); + + /// Generates a whileOp to iterate over a subset of coordinates on tid on lvl + /// using the pHi and pLo provided, the loop breaks on the first coordinate + /// that exceeds the slice boundary (i.e., coord >= slice.offset + + /// slice.size). + std::pair + genSliceLvlTraverseLoop(OpBuilder &builder, Location loc, Value pLo, + Value pHi, Value offset, size_t tid, size_t lvl, + size_t depth, ValueRange userReduc, bool genYield, + /*bodyBuilder=*/ + llvm::function_ref)>); + + /// Generates a nested loop that iterates over tid on all the coordinates on + /// lvl. 
+ ValueRange genSliceAllLvlTraverseLoop( + OpBuilder &builder, Location loc, Value offset, size_t tid, size_t lvl, + size_t depth, ValueRange userReduc, + /*bodyBody=*/ + llvm::function_ref)>); + + /// Generates code to get the first non-empty slice of tid on lvl. + /// return true if has already been resolved. + bool genSliceBegin(OpBuilder &builder, Location loc, size_t tid, size_t lvl); + + /// Generates code to get the next non-empty slices of tid on lvl. + void genSliceNextInduction(OpBuilder &builder, Location loc, + const Operation *whileOp, size_t tid, size_t lvl, + SmallVectorImpl &operands, + unsigned &retIdx); + + /// Generates a slice-driven while loop like follows. + /// + /// curSlice = getFirstNonEmptySlice(tensor). + /// + /// while(isNonEmpty) { + /// ..user code.. + /// isNonEmpty, curSlice = getNextNonEmptySlice(curSlice) + /// } + Operation *emitSliceDrivenLoopOverTensorAtDim(OpBuilder &builder, + Location loc, size_t tid, + size_t lvl, + MutableArrayRef reduc); + /// A optional string attribute that should be attached to the loop /// generated by loop emitter, it might help following passes to identify /// loops that operates on sparse tensors more easily. @@ -281,25 +376,46 @@ std::vector> ptrBuffer; // to_pointers std::vector> idxBuffer; // to_indices std::vector valBuffer; // to_value - - // Map from [tid, dim] to a list of dependent [tid, dim]. - // See comments for `DependentDimGetter`. - std::vector>>> - dependentDimMap; - // Loop Stack, stores the information of all the nested loops that are // alive. std::vector loopStack; // Loop Sequence Stack, stores the unversial index for the current loop - // sequence. - std::vector loopSeqStack; + // sequence. and a list of tids which was taken sliced. + // TODO: maybe we should have a LoopSeqInfo + std::vector< + std::pair>>> + loopSeqStack; // Maps AffineDimExpr to the index of the loop in loopStack. // TODO: We should probably use a callback function here to make it more // general. 
std::vector sparsiferLoopLvlMap; + // + // Slice-driven loops related fields. + // + + // Map from [tid, dim] to a list of dependent [tid, dim]. + // See comments for `DependentDimGetter`. + std::vector>>> + dependentDimMap; + + // The cached pointer buffer for the slices, they serve the same purpose as + // ptrBuffer for compressed dimensions. But they always starts with the first + // pidx pointing to coord > slice.offset to avoid iteration from the + // beginning. + std::vector>> slicePtrBuffer; + + // The cached size for each slices. + std::vector>> sliceSizes; + + // The number of resolved constraints so far. + std::vector sliceResolvedConstraints; + + // sliceStack[tid] holds the generated slice stack on tid. + std::vector> sliceStack; + // TODO: not yet used, it should track the current level for each tensor // to help eliminate `dim` paramters from above APIs. // std::vector curLv; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" using namespace mlir; using namespace mlir::sparse_tensor; @@ -23,6 +24,14 @@ // File local helper functions. //===----------------------------------------------------------------------===// +/// Extracts a corresponding vector of type from a ValueRange. +static SmallVector getTypesFromValues(ValueRange vs) { + SmallVector ret; + for (auto v : vs) + ret.push_back(v.getType()); + return ret; +} + /// Generates a pointer/index load from the sparse storage scheme. Narrower /// data types need to be zero extended before casting the value into the /// index type used for looping and indexing. 
@@ -74,6 +83,60 @@ return std::make_pair(iv, rem); } +/// Helper method to generate a tensor.extract_slice operation with the given +/// dynamic offset and size on dim; every other dim keeps the extent of src. +static Value genExtractSliceWithOffsetOnDim(OpBuilder &builder, Location loc, + Value src, Value dynOffset, + Value sz, unsigned dim) { + + RankedTensorType srcTp = src.getType().cast(); + int rank = srcTp.getRank(); + + SmallVector offsets(rank, 0); + SmallVector strides(rank, 1); + SmallVector sizes(srcTp.getShape()); + SmallVector dynSizes; + + offsets[dim] = ShapedType::kDynamic; + sizes[dim] = ShapedType::kDynamic; + + auto srcEncoding = getSparseTensorEncoding(srcTp); + SmallVector sliceAttrs; + + for (unsigned i = 0, e = sizes.size(); i < e; i++) { + // Infers the slice attribute array (sets offset/size to dynamic on the + // slicing dimension). + int offset = i == dim ? SparseTensorDimSliceAttr::kDynamic : 0; + int size = (i == dim || ShapedType::isDynamic(sizes[i])) + ? SparseTensorDimSliceAttr::kDynamic + : sizes[i]; + sliceAttrs.push_back(SparseTensorDimSliceAttr::get(srcTp.getContext(), + /*offset=*/offset, + /*size=*/size, + /*stride=*/1)); + if (ShapedType::isDynamic(sizes[i])) { + if (dim == i) + dynSizes.push_back(sz); + else + dynSizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, i)); + } + } + + // Keeps original encodings but attaches slice attribute. 
+ auto encoding = SparseTensorEncodingAttr::get( + srcTp.getContext(), srcEncoding.getDimLevelType(), + srcEncoding.getDimOrdering(), srcEncoding.getHigherOrdering(), + srcEncoding.getPointerBitWidth(), srcEncoding.getIndexBitWidth(), + sliceAttrs); + + auto retTp = RankedTensorType::get(sizes, srcTp.getElementType(), encoding); + return builder + .create(loc, retTp, src, ValueRange{dynOffset}, + dynSizes, ValueRange{}, offsets, sizes, + strides) + .getResult(); +} + std::pair LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value coord, unsigned tid, unsigned lvl) { @@ -126,9 +189,10 @@ size_t dim, Value iv) { Value p = dim == 0 ? constantIndex(builder, loc, 0) : pidxs[tid][dim - 1]; Value mul = builder.create(loc, highs[tid][dim], p); - if (isSparseSlices[tid]) + if (isSparseSlices[tid]) { iv = toSliceCoord(builder, loc, iv, sliceOffsets[tid][dim], sliceStrides[tid][dim], tensors[tid], dim); + } Value add = builder.create(loc, mul, iv); return add; } @@ -164,6 +228,11 @@ this->dependentDimMap.assign( tensors.size(), std::vector>>()); + this->slicePtrBuffer.assign(tensors.size(), + std::vector>()); + this->sliceSizes.assign(tensors.size(), std::vector>()); + this->sliceStack.assign(tensors.size(), std::vector()); + this->sliceResolvedConstraints.assign(tensors.size(), 0); for (size_t tid = 0, e = tensors.size(); tid < e; tid++) { @@ -193,11 +262,25 @@ ptrBuffer[tid].assign(rank, Value()); idxBuffer[tid].assign(rank, Value()); + // Slice-driven loops related initialization. 
dependentDimMap[tid].assign(rank, std::vector>()); - if (dimGetter) - for (unsigned i = 0; i < rank; i++) + slicePtrBuffer[tid].assign(rank, std::vector()); + sliceSizes[tid].assign(rank, std::vector()); + sliceStack[tid].emplace_back(tensors[tid], /*minCoord=*/Value(), + /*offset=*/Value(), /*isNonEmpty*/ Value(), + std::nullopt, 0); + if (dimGetter) { + for (unsigned i = 0; i < rank; i++) { dependentDimMap[tid][i] = dimGetter(tid, i); + unsigned depends = dependentDimMap[tid][i].size(); + if (depends != 0) { + // We need depends - 1 slices to fully resolve the affine expression. + slicePtrBuffer[tid][i].assign(depends - 1, nullptr); + sliceSizes[tid][i].assign(depends - 1, nullptr); + } + } + } } // FIXME: This map should be maintained outside loop emitter. @@ -275,17 +358,121 @@ // some loop preparation from tensor iteration, but will also // (undesirably) hosit the code ouside if conditions. } + + Type indexType = builder.getIndexType(); + Value c0 = constantZero(builder, loc, indexType); + Value c2 = constantIndex(builder, loc, 2); + // TODO: We should probably use integer with pointer bitwidth for the cache. + MemRefType cacheTp = MemRefType::get({ShapedType::kDynamic}, indexType); + // Generate caches required to fast compute next-non-empty slices with + // increasing offset for slice-base loop. + // We need to start a separate loop here because the cache size depends on the + // dimension size computed in the aboves loops. + for (size_t t = 0, e = tensors.size(); t < e; t++) { + auto rtp = tensors[t].getType().dyn_cast(); + if (!rtp) + continue; + + // for a pair of [pLo, pHi]. Note that we can not compress pHi because slice + // creates segments in the index buffer so that the pHi for the current dim + // is no longer the pLo for the next dim. + Value pIdxSize = c2; + auto rank = rtp.getRank(); + for (unsigned lvl = 0; lvl < rank; lvl++) { + if (!dependentDimMap[t][lvl].empty()) { + // Needs at least two operands to form a non-trivial affine expression. 
+ ArrayRef> dependedDim = + dependentDimMap[t][lvl]; + assert(dependedDim.size() > 1); + + Value size = c0; + for (unsigned e = dependedDim.size() - 1; e >= 1; e--) { + auto [dt, dd] = dependedDim[e]; + size = builder.create(loc, size, dims[dt][dd]); + sliceSizes[t][lvl][e - 1] = size; + } + + // No cache for dense level, they can be simply increased by one. + auto dlt = dimTypes[t][lvl]; + + if (!isDenseDLT(dlt)) { + llvm::for_each(slicePtrBuffer[t][lvl], [cacheTp, pIdxSize, c2, loc, + &builder](Value &cache) { + cache = builder.create( + loc, cacheTp, + // Additional two metadata {memSize, idx} at head. + builder.create(loc, pIdxSize, c2).getResult()); + }); + } + } + + // Accumulates the size required to cache the pLo for the slice. + // E.g., if we want to cache the pIdx for slice on the second + // level. We at most need a memref. + // NOTE: this is apparently an over-approximation when the previous + // level is compressed, and we can compute a precise memory size + // inside the loops. But that would also require us to allocate/free + // memory in loops. + // TODO: Maybe using allocaScopeOp inside the loop to resolve the issue? + if (!dependentDimMap[t][lvl].empty()) { + auto [dt, dd] = dependentDimMap[t][lvl].back(); + pIdxSize = builder.create(loc, pIdxSize, dims[dt][dd]); + } else { + // This level does not need to be sliced, the final size of the slice + // on the level will be the same as the current size. + pIdxSize = builder.create(loc, pIdxSize, dims[t][lvl]); + } + } + } } void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc, ArrayRef tids, ArrayRef dims) { assert(loopSeqStack.size() == loopStack.size()); - // Universal Index starts from 0. - loopSeqStack.emplace_back(constantIndex(builder, loc, 0)); // Prepares for all the tensors used in the current loop sequence. 
- for (auto [tid, dim] : llvm::zip(tids, dims)) - prepareLoopOverTensorAtDim(builder, loc, tid, dim); + std::vector> slicedTids; + for (auto [tid, dim] : llvm::zip(tids, dims)) { + if (!dependentDimMap[tid][dim].empty()) { + bool fullyRes = genSliceBegin(builder, loc, tid, dim); + slicedTids.emplace_back(tid, dim, fullyRes); + } else { + prepareLoopOverTensorAtDim(builder, loc, tid, dim); + } + } + + // Universal Index starts from 0. + loopSeqStack.emplace_back(constantIndex(builder, loc, 0), + std::move(slicedTids)); +} + +void LoopEmitter::exitCurrentLoopSeq(OpBuilder &builder, Location loc) { + assert(loopSeqStack.size() == loopStack.size() + 1); + + const std::vector> &slicedTids = + loopSeqStack.back().second; + + // Pop out outdated slices. + for (auto [tid, lvl, res] : slicedTids) { + if (!res) { + assert(sliceStack[tid].back().slicedOnLvl == lvl); + sliceStack[tid].pop_back(); + // There is an additional item in sliceStack for the input tensor. + assert(sliceResolvedConstraints[tid] + 1 == sliceStack[tid].size()); + } else { + Value c1 = constantIndex(builder, loc, 1); + Value c2 = constantIndex(builder, loc, 2); + + // pIdx += 2, we finished the current lvl, advance the pointer index of + // the previous level by two to skip the [pLo, pHi] for current level. + // TODO: we could probably use an SSA value for it. + Value sPtrBuf = slicePtrBuffer[tid][lvl].back(); + Value curP = genIndexLoad(builder, loc, sPtrBuf, c1); + Value nexP = builder.create(loc, curP, c2); + builder.create(loc, nexP, sPtrBuf, c1); + } + } + loopSeqStack.pop_back(); } Value LoopEmitter::genAffine(OpBuilder &builder, AffineExpr a, Location loc) { @@ -315,35 +502,18 @@ } } -Operation *LoopEmitter::enterLoopOverTensorAtDim( - OpBuilder &builder, Location loc, ArrayRef tids, - ArrayRef dims, MutableArrayRef reduc, bool isParallel) { - // TODO: support multiple return on parallel for? 
- assert(!isParallel || reduc.size() <= 1); - bool isSparseInput = false; - size_t tid = tids.front(), dim = dims.front(); - for (auto [t, d] : llvm::zip(tids, dims)) { - assert(dimTypes[t].size() > d); // Must be a valid tid, dim pair - assert(!coord[t][d]); // We cannot re-enter the same level - auto dimType = dimTypes[t][d]; - // Must be a recognizable DLT. - assert(isDenseDLT(dimType) || isCompressedDLT(dimType) || - isSingletonDLT(dimType)); - bool isSparse = isCompressedDLT(dimType) || isSingletonDLT(dimType); - // We can at most have one sparse input, otherwise, a while loop is - // required to co-iterate multiple sparse tensors. - assert(!isSparseInput || !isSparse); - if (isSparse) { - tid = t; - dim = d; - } - isSparseInput = isSparseInput || isSparse; - } +Operation *LoopEmitter::emitForLoopOverTensorAtDim(OpBuilder &builder, + Location loc, size_t tid, + size_t dim, + MutableArrayRef reduc, + bool isParallel) { + bool isSparseCond = + isCompressedDLT(dimTypes[tid][dim]) || isSingletonDLT(dimTypes[tid][dim]); // TODO: support dynamic slices. Value step = constantIndex(builder, loc, 1); - Value lo = isSparseInput ? pidxs[tid][dim] // current offset - : loopSeqStack.back(); // universal index + Value lo = isSparseCond ? pidxs[tid][dim] // current offset + : loopSeqStack.back().first; // universal index Value hi = highs[tid][dim]; Operation *loop = nullptr; @@ -379,7 +549,7 @@ assert(loop && iv); Value c; - if (isSparseInput) { + if (isSparseCond) { pidxs[tid][dim] = iv; // Generating a load on the indices array yields the coordinate. Value ptr = idxBuffer[tid][dim]; @@ -389,7 +559,7 @@ c = iv; } - if (isSparseSlices[tid] && isSparseInput) { + if (isSparseSlices[tid] && isSparseCond) { // For sparse level slices, we need to filter out invalid coordinates that // are not included in the slice. 
SmallVector types; @@ -419,14 +589,106 @@ assert(c); coord[tid][dim] = c; - // NOTE: we can also prepare for next dim here in advance - // Push the loop into stack - loopStack.emplace_back(ArrayRef(tid), ArrayRef(dim), loop, - builder.getInsertionBlock(), coord[tid][dim], loopTag); + return loop; +} + +Operation *LoopEmitter::enterLoopOverTensorAtDim( + OpBuilder &builder, Location loc, ArrayRef tids, + ArrayRef dims, MutableArrayRef reduc, bool isParallel) { + // TODO: support multiple return on parallel for? + assert(!isParallel || reduc.size() <= 1); + bool isSparseCond = false, isSliceCond = false; + size_t tid = tids.front(), dim = dims.front(); + + for (auto [t, d] : llvm::zip(tids, dims)) { + assert(dimTypes[t].size() > d); // Must be a valid tid, dim pair + assert(!coord[t][d] || // We cannot re-enter the same level + !dependentDimMap[t][d].empty()); // unless it is a slice-driver loop + auto dimType = dimTypes[t][d]; + // Must be a recognizable DLT. + assert(isDenseDLT(dimType) || isCompressedDLT(dimType) || + isSingletonDLT(dimType)); + + // This is a slice-driven loop. + if (!dependentDimMap[t][d].empty()) { + assert(!isSliceCond && !isSparseCond); + isSliceCond = true; + tid = t; + dim = d; + continue; + } + + bool isSparse = isCompressedDLT(dimType) || isSingletonDLT(dimType); + // We can at most have one sparse input, otherwise, a while loop is + // required to co-iterate multiple sparse tensors. + assert(!isSparseCond || !isSparse); + assert(!isSliceCond || !isSparseCond); + if (isSparse) { + tid = t; + dim = d; + } + isSparseCond = isSparseCond || isSparse; + } + + // if the slice is fully reduced, we can now use TACO-based algorithm to + // iterate it. 
+ Operation *l = nullptr; + if (isSliceCond) { + bool fullyResolved = sliceFullyResolved(tid); + if (!fullyResolved) { + l = emitSliceDrivenLoopOverTensorAtDim(builder, loc, tid, dim, reduc); + } else { + const SliceInfo &info = getFinalSliceOnLvl(tid, dim); + Value offset = info.offset; + unsigned depth = info.depth - 1; + Operation *insertPoint = nullptr; + // TODO: we should generalize the method to support iteration over for + // normal slices as well to allow early break. + l = genSliceLvlTraverseLoop( + builder, loc, pidxs[tid][dim], highs[tid][dim], offset, tid, dim, + depth, reduc, + /*genYield=*/false, // unaware of the yield values from user yet + [this, tid, dim, reduc, offset, + &insertPoint](OpBuilder &builder, Location loc, Value iv, + MutableArrayRef innerReduc) { + assert(innerReduc.size() == reduc.size()); + // Updates users' reduction variable inplace + for (unsigned i = 0, e = reduc.size(); i < e; i++) + reduc[i] = innerReduc[i]; + // Loads the coordinates. + Value absC = + genIndexLoad(builder, loc, idxBuffer[tid][dim], iv); + + // We need to substract the offset to get relative coordinates. + // TODO: how to assert relC >=0 during runtime? + insertPoint = builder.create(loc, absC, offset); + pidxs[tid][dim] = iv; + coord[tid][dim] = insertPoint->getResult(0); + }) + .first; + // We did not finish the loop body, reset the insertion point and delegate + // to user. + builder.setInsertionPointAfter(insertPoint); + } + // NOTE: we can also prepare for next dim here in advance + // Pushes the loop into stack. + loopStack.emplace_back( + ArrayRef(), ArrayRef(), ArrayRef(tid), + ArrayRef(dim), ArrayRef(fullyResolved), l, + builder.getInsertionBlock(), coord[tid][dim], loopTag); + } else { + l = emitForLoopOverTensorAtDim(builder, loc, tid, dim, reduc, isParallel); + // NOTE: we can also prepare for next dim here in advance + // Pushes the loop into stack. 
+ loopStack.emplace_back(ArrayRef(tid), ArrayRef(dim), + ArrayRef(), ArrayRef(), + ArrayRef(), l, builder.getInsertionBlock(), + coord[tid][dim], loopTag); + } + // Emit extra locals. emitExtraLocalsForTensorsAtDenseDims(builder, loc, tids, dims); - - return loop; + return l; } Operation *LoopEmitter::enterFilterLoopOverTensorAtDim( @@ -492,8 +754,10 @@ // NOTE: we can also prepare for next dim here in advance // Push the loop into stack - loopStack.emplace_back(ArrayRef(tid), ArrayRef(dim), forOp, - builder.getInsertionBlock(), coord[tid][dim], nullptr); + loopStack.emplace_back(ArrayRef(tid), ArrayRef(dim), + ArrayRef(), ArrayRef(), + ArrayRef(), forOp, builder.getInsertionBlock(), + coord[tid][dim], nullptr); return forOp; } @@ -508,12 +772,17 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims( OpBuilder &builder, Location loc, ArrayRef tids, ArrayRef dims, bool needsUniv, MutableArrayRef reduc) { + + // NOTE: make sure that the slice driven tensor-related reduction variable + // appears first than normal tensors. assert(tids.size() == dims.size()); SmallVector types; SmallVector operands; // Construct the while-loop with a parameter for each index. Type indexType = builder.getIndexType(); for (auto [tid, dim] : llvm::zip(tids, dims)) { + // TODO: support coiteration with slice driven tensors. + assert(dependentDimMap[tid][dim].empty() && "TODO: not yet implemented"); if (isCompressedDLT(dimTypes[tid][dim]) || isSingletonDLT(dimTypes[tid][dim])) { assert(pidxs[tid][dim]); @@ -529,7 +798,7 @@ if (needsUniv) { types.push_back(indexType); // Update universal index. - operands.push_back(loopSeqStack.back()); + operands.push_back(loopSeqStack.back().first); } assert(types.size() == operands.size()); scf::WhileOp whileOp = builder.create(loc, types, operands); @@ -634,8 +903,9 @@ } // Sets up the loop stack. 
- loopStack.emplace_back(tids, dims, whileOp, builder.getInsertionBlock(), min, - loopTag); + loopStack.emplace_back(tids, dims, ArrayRef(), ArrayRef(), + ArrayRef(), whileOp, builder.getInsertionBlock(), + min, loopTag); assert(loopStack.size() == loopSeqStack.size()); // Emits extra locals @@ -800,6 +1070,7 @@ auto &dims = loopInfo.dims; auto &tids = loopInfo.tids; Value iv = loopInfo.iv; + // Finalize the induction. Note that the induction could be performed // in the individual if-branches to avoid re-evaluating the conditions. // However, that would result in a rather elaborate forest of yield @@ -807,6 +1078,34 @@ // after the if-statements more closely resembles code generated by TACO. unsigned o = 0; SmallVector operands; + unsigned delta = 0; + for (auto [tid, dim, resolved] : llvm::zip( + loopInfo.slicedTids, loopInfo.slicedDims, loopInfo.sliceResolved)) { + if (!resolved) { + genSliceNextInduction(builder, loc, whileOp, tid, dim, operands, o); + sliceResolvedConstraints[tid]--; + } else { + // TODO: We need to distinguish coiterate loop with slice-driven loop and + // fully reduced while op for iterating one slices. + // since we didn't implement coiteration, this must be iteration just + // on fully resolved slice. + assert(loopInfo.slicedTids.size() == 1 && loopInfo.tids.empty()); + // The if guard to filter out out-range coordinates. + assert(llvm::isa(builder.getInsertionBlock()->getParentOp())); + pidxs[tid][dim] = whileOp->getResult(o++); + // FIXME: we are not using continue here since we do not support + // coiteration on slices. But it need to be treated similarly as the + // universal index. + o++; // skip continue flag. + // Since we did not push two results from whileOp. The size of the + // operands vector is smaller than the actual number of return values from + // the whileOp. + // It is because we are actually generate yield in the IfOp inside the + // whileOp to only iterates over inbound coordinates within the slices. 
+ delta += 2; + } + }; + Value one = constantIndex(builder, loc, 1); for (auto [tid, dim] : llvm::zip(tids, dims)) { if (isCompressedDLT(dimTypes[tid][dim]) || @@ -834,15 +1133,15 @@ } // An (optional) universal index. - if (operands.size() < whileOp.getNumResults()) { - assert(operands.size() + 1 == whileOp.getNumResults()); + if (operands.size() + delta < whileOp.getNumResults()) { + assert(operands.size() + delta + 1 == whileOp.getNumResults()); // The last one is the universial index. operands.push_back(builder.create(loc, iv, one)); // update the loop starting point of current loop sequence - loopSeqStack.back() = whileOp->getResult(o++); + loopSeqStack.back().first = whileOp->getResult(o++); } - assert(o == operands.size()); + assert(o == operands.size() + delta); builder.create(loc, operands); builder.setInsertionPointAfter(whileOp); } @@ -863,3 +1162,568 @@ assert(loopStack.size() == loopSeqStack.size()); loopStack.pop_back(); } + +//===----------------------------------------------------------------------===// +// Slice-driven loop related methods. 
+//===----------------------------------------------------------------------===// + +LoopEmitter::SliceInfo &LoopEmitter::getFinalSliceOnLvl(size_t tid, + size_t lvl) { + for (auto it = sliceStack[tid].rbegin(), ie = sliceStack[tid].rend(); it < ie; + it++) { + if (it->slicedOnLvl == lvl) { + assert(it->depth == dependentDimMap[tid][lvl].size() - 1); + return *it; + } + } + + llvm_unreachable("Failed to find sliceInfo"); +} + +size_t LoopEmitter::sliceTotalConstraints(size_t tid) { + size_t numConstraints = 0; + for (const auto &lvlDeps : dependentDimMap[tid]) { + if (!lvlDeps.empty()) { + assert(lvlDeps.size() >= 2); + numConstraints += lvlDeps.size() - 1; + } + } + return numConstraints; +} + +bool LoopEmitter::sliceFullyResolved(size_t tid) { + return sliceTotalConstraints(tid) == sliceResolvedConstraints[tid]; +} + +std::pair LoopEmitter::genSliceLvlTraverseLoop( + OpBuilder &builder, Location loc, Value loopLo, Value loopHi, Value offset, + size_t tid, size_t lvl, size_t depth, ValueRange userReduc, bool genYield, + llvm::function_ref)> + bodyBuilder) { + Value c1 = constantIndex(builder, loc, 1); + Value sliceHi = + builder.create(loc, offset, sliceSizes[tid][lvl].back()); + + SmallVector reduc = { + loopLo, // loop lower bounds + constantI1(builder, loc, true), // continue + }; + // Append user required reduction value. + reduc.append(userReduc.begin(), userReduc.end()); + SmallVector types = getTypesFromValues(reduc); + + scf::WhileOp whileOp = builder.create( + loc, types, reduc, + /*beforeBuilder=*/ + [loopHi](OpBuilder &builder, Location loc, ValueRange args) { + Value lo = args[0]; + Value cont = args[1]; + Value inBound = builder.create( + loc, arith::CmpIPredicate::ult, lo, loopHi); + Value cond = builder.create(loc, cont, inBound); + // continue if not yet break nor out of bound. 
+ builder.create(loc, cond, args); + }, + /*afterBuilder=*/ + [this, c1, tid, lvl, sliceHi, genYield, + bodyBuilder](OpBuilder &builder, Location loc, ValueRange args) { + Value iv = args[0]; + Value coord = genIndexLoad(builder, loc, idxBuffer[tid][lvl], iv); + // If coord < sliceHi + Value cont = builder.create( + loc, arith::CmpIPredicate::ult, coord, sliceHi); + + SmallVector types = getTypesFromValues(args.drop_front(2)); + auto ifOp = builder.create(loc, types, cont, true); + { + // Skips the 2 reduction variables (position, continue) maintained by us. + SmallVector ifRet = args.drop_front(2); + assert(ifRet.size() == args.size() - 2); + + OpBuilder::InsertionGuard guard(builder); + // If not in slice. + // Break the while loop (by setting continue to false) + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + builder.create(loc, ifRet); + + // If this is a legit coordinate in the slice + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + bodyBuilder(builder, loc, iv, ifRet); + if (genYield) { + builder.setInsertionPointToEnd(&ifOp.getThenRegion().front()); + builder.create(loc, ifRet); + } + } + // Marks this special ifOp to avoid sparsification finalizing it. + ifOp->setAttr(getLoopEmitterLoopAttrName(), + StringAttr::get(builder.getContext(), "slice")); + // Insertion point restored to after ifOp. + SmallVector yields; + // Increase induction variable. 
+ yields.push_back(builder.create(loc, iv, c1)); + yields.push_back(cont); + yields.append(ifOp.getResults().begin(), ifOp.getResults().end()); + builder.create(loc, yields); + }); + + builder.setInsertionPointAfter(whileOp); + return std::make_pair(whileOp, whileOp.getResults().drop_front(2)); +} + +ValueRange LoopEmitter::genSliceAllLvlTraverseLoop( + OpBuilder &builder, Location loc, Value offset, size_t tid, size_t lvl, + size_t depth, ValueRange userReduc, + llvm::function_ref)> + bodyBuilder) { + + Value c0 = constantIndex(builder, loc, 0); + Value c1 = constantIndex(builder, loc, 1); + Value c2 = constantIndex(builder, loc, 2); + + // TODO: it only works on all compressed tensor. + Value sPtrBuf = slicePtrBuffer[tid][lvl][depth]; + Value pSt = c2; // pointer starting index + Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize + + auto forOp = + scf::buildLoopNest( + builder, loc, pSt, mSz, c2, userReduc, + [this, c1, depth, tid, lvl, offset, sPtrBuf, + bodyBuilder](OpBuilder &builder, Location loc, ValueRange ivs, + ValueRange iterArgs) -> scf::ValueVector { + // generate traversal for each level. + Value loopLo = genIndexLoad(builder, loc, sPtrBuf, ivs.front()); + Value loopHi = genIndexLoad( + builder, loc, sPtrBuf, + builder.create(loc, ivs.front(), c1)); + return genSliceLvlTraverseLoop(builder, loc, loopLo, loopHi, offset, + tid, lvl, depth, iterArgs, true, + bodyBuilder) + .second; + }) + .loops.front(); + + // Insert after current while operation. + builder.setInsertionPointAfter(forOp); + return forOp.getResults(); +} + +bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, size_t tid, + size_t lvl) { + + Value c0 = constantIndex(builder, loc, 0); + Value c1 = constantIndex(builder, loc, 1); + Value c2 = constantIndex(builder, loc, 2); + Value c3 = constantIndex(builder, loc, 3); + Value c4 = constantIndex(builder, loc, 4); + + if (sliceFullyResolved(tid)) { + // If constraints on the tensor is fully resolved. 
We do not need to + // generate a slice begin any more; instead we fall back to the TACO-based + // algorithm to (co)iterate over the slice. The index cached at slot 1 of + // the last slice pointer buffer, advanced by c2, selects the slot holding + // the low pointer; the high pointer is read from the following slot. + Value pLoPtr = + genIndexLoad(builder, loc, slicePtrBuffer[tid][lvl].back(), c1); + pLoPtr = builder.create(loc, pLoPtr, c2); + Value pHiPtr = builder.create(loc, pLoPtr, c1); + pidxs[tid][lvl] = + genIndexLoad(builder, loc, slicePtrBuffer[tid][lvl].back(), pLoPtr); + highs[tid][lvl] = + genIndexLoad(builder, loc, slicePtrBuffer[tid][lvl].back(), pHiPtr); + return true; + } + + // Only when the level is sorted, the next-non-empty slice can be computed + // efficiently. + assert(isOrderedDLT(dimTypes[tid][lvl])); + if (isDenseDLT(dimTypes[tid][lvl]) || isSingletonDLT(dimTypes[tid][lvl])) + llvm_unreachable("TODO: dense level should be easy to support, while " + "singleton level requres more efforts"); + + assert(!dependentDimMap[tid][lvl].empty()); + assert(!sliceStack[tid].empty()); + + const SliceInfo &sliceInfo = sliceStack[tid].back(); + auto baseEnc = getSparseTensorEncoding(sliceInfo.baseSlice.getType()); + + Value size, minCoord, isNonEmpty; + unsigned depth = 0; + if (sliceInfo.isInitialTensor()) { + // The input tensor is itself a slice; not yet handled. + if (baseEnc.isSlice()) + llvm_unreachable("TODO: not yet implemented"); + + assert(lvl == 0); // must be reducing the affine expression on the first lvl. + // Fills out slicePtrBuffer[tid][0][0] with [/*memSize =*/4, 0, 0, pHi] + Value sPtrBuf = slicePtrBuffer[tid][0][0]; + Value pHi = genIndexLoad(builder, loc, ptrBuffer[tid][0], c1); + builder.create(loc, c4, sPtrBuf, c0); // memSize = 4 + builder.create(loc, c0, sPtrBuf, c1); // index = 0 + builder.create(loc, c0, sPtrBuf, c2); // pLo = 0; + builder.create(loc, pHi, sPtrBuf, c3); // loaded pHi. + + size = sliceSizes[tid][0][0]; + // This is a non-empty tensor if 0 < pHi. + isNonEmpty = + builder.create(loc, arith::CmpIPredicate::ult, c0, pHi); + // The minimal coord must be the first one on an ordered level.
+ // FIXME: Technically we should load the coord only when the slice is + // non-empty. However, we assume that even for empty sparse tensors a + // non-empty ptr/idx buffer is allocated for each level, so skipping the + // ifOp here does not cause an out-of-bounds load. + minCoord = genIndexLoad(builder, loc, idxBuffer[tid][0], c0); + depth = 1; + } else { + unsigned prevLvl = *sliceInfo.slicedOnLvl; + assert(lvl >= prevLvl); + if (lvl != prevLvl + 1) { + // Either lvl = prevSlicedLvl, i.e., t[d0 + d1 + d2,...] (more than one + // variable needs to be reduced on the same level). + // Or lvl > prevSliceLvl + 1, i.e., t[..., d2, d3 + d4] (having a + // simple dim expression in between). + llvm_unreachable("TODO: not yet implemented"); + } else { + assert(slicePtrBuffer[tid][prevLvl].size() == sliceInfo.depth); + Value sPtrBuf = slicePtrBuffer[tid][lvl][0]; + + SmallVector reduc = { + constantI1(builder, loc, false), // isNonEmpty + dims[tid][lvl], // minCoord + c2, // memSize + }; + ValueRange result = genSliceAllLvlTraverseLoop( + builder, loc, sliceInfo.offset, tid, prevLvl, sliceInfo.depth - 1, + reduc, + [this, c1, c2, tid, lvl, sPtrBuf](OpBuilder &builder, Location loc, + Value iv, + MutableArrayRef reduc) { + Value &isNonEmpty = reduc[0]; + Value &minCoord = reduc[1]; + Value &curMemSize = reduc[2]; + + Value pHi = builder.create(loc, iv, c1); + Value sPLo = genIndexLoad(builder, loc, ptrBuffer[tid][lvl], iv); + Value sPHi = genIndexLoad(builder, loc, ptrBuffer[tid][lvl], pHi); + + // isNonEmpty = isNonEmpty || lvlNonEmpty + Value lvlNonEmpty = builder.create( + loc, arith::CmpIPredicate::ult, sPLo, sPHi); + isNonEmpty = + builder.create(loc, lvlNonEmpty, isNonEmpty); + + // Update minimal coordinate.
+ auto ifNonEmpty = builder.create( + loc, builder.getIndexType(), lvlNonEmpty, true); + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(ifNonEmpty.thenBlock()); + Value curC = + genIndexLoad(builder, loc, idxBuffer[tid][lvl], sPLo); + Value isCurSmaller = builder.create( + loc, arith::CmpIPredicate::ult, curC, minCoord); + Value newMin = builder.create(loc, isCurSmaller, + curC, minCoord); + builder.create(loc, newMin); + builder.setInsertionPointToStart(ifNonEmpty.elseBlock()); + builder.create(loc, minCoord); + } + minCoord = ifNonEmpty.getResult(0); + + // Cache the (sPLo, sPHi) pair in sPtrBuf at slot curMemSize and at the + // slot right after it. + builder.create(loc, sPLo, sPtrBuf, curMemSize); + Value nxtMemSize = + builder.create(loc, curMemSize, c1); + builder.create(loc, sPHi, sPtrBuf, nxtMemSize); + + // curMemSize += 2 + curMemSize = builder.create(loc, curMemSize, c2); + }); + + size = sliceSizes[tid][lvl][0]; + isNonEmpty = result[0]; + minCoord = result[1]; + depth = 1; + + // Write the two metadata slots [memSize, idx]. + // TODO: we might be able to use an SSA value for memSize here to avoid + // memory operation. + builder.create(loc, result[2], sPtrBuf, c0); + builder.create(loc, c0, sPtrBuf, c1); + } + } + + assert(depth > 0 && size && isNonEmpty && minCoord && depth); + // NOTE(review): the trailing `&& depth` above is redundant with `depth > 0`. + // Compute the minimal offset viable for a non-empty tensor. + // offset = isNonEmpty && minCoord >= size ? minCoord - size + 1 : 0; + // NOTE: minCoord is invalid when isNonEmpty = false, in which case + // the computed slices are meaningless. + // FIXME: support relative offset computation. + Value geSize = builder.create(loc, arith::CmpIPredicate::uge, + minCoord, size); + Value pred = builder.create(loc, isNonEmpty, geSize); + + Value mp1 = builder.create(loc, minCoord, c1); + Value mms = builder.create(loc, mp1, size); + // This is the absolute offset relative to the underlying tensor. + Value absOffset = builder.create(loc, pred, mms, c0); + // This is the relative offset relative to the base slice.
+ Value relOffset = absOffset; + uint64_t dim = toOrigDim(baseEnc, lvl); + Value newSlice = genExtractSliceWithOffsetOnDim( + builder, loc, sliceInfo.baseSlice, relOffset, size, dim); + sliceStack[tid].emplace_back(newSlice, minCoord, absOffset, isNonEmpty, lvl, + depth); + return false; +} + +/// Generates the code that forwards the slice on (tid, lvl) to the next +/// candidate slice of the slice-driven loop `op`. Four values are appended to +/// `operands`, in this order: the next isNonEmpty flag, the next slice, the +/// next minimal coordinate and the next absolute offset. The top entry of +/// sliceStack[tid] is then rebound to the four results of `op` starting at +/// `retIdx`, which is advanced past them. Only compressed levels are handled. +void LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc, + const Operation *op, size_t tid, + size_t lvl, + SmallVectorImpl &operands, + unsigned &retIdx) { + if (!isCompressedDLT(dimTypes[tid][lvl])) + llvm_unreachable("TODO"); + + // Otherwise, generate code to compute the next non-empty slice. + Value c0 = constantIndex(builder, loc, 0); + Value c1 = constantIndex(builder, loc, 1); + Value c2 = constantIndex(builder, loc, 2); + + auto whileOp = llvm::cast(op); + SliceInfo &info = sliceStack[tid].back(); + assert(info.slicedOnLvl == lvl); + + // + // We forward to the next non empty slice by + // if (minCoord > offset) { + // offset += 1 + // } else { + // minCoord = nextMinInSlice(); + // offset = minCoord - size + 1; + // } + // + // if (offset + size > parents.size) + // isNonEmpty = false; + // + Value absOffset = info.offset; + // Reset the cached slice pointer index (slot 1) of the last buffer on every + // level up to `lvl`: the resolved slices are invalidated once we move + // forward to the next slice. + for (unsigned i = 0; i <= lvl; i++) + builder.create(loc, c0, slicePtrBuffer[tid][i].back(), c1); + + SmallVector reduc = {info.minCoord, info.isNonEmpty, absOffset}; + SmallVector types = getTypesFromValues(reduc); + Value sPtrBuf = slicePtrBuffer[tid][lvl][info.depth - 1]; + Value fastPathP = builder.create( + loc, arith::CmpIPredicate::ugt, info.minCoord, absOffset); + auto ifOp = builder.create(loc, types, fastPathP, true); + { + OpBuilder::InsertionGuard guard(builder); + // Take the fast path if minCoord > offset + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + reduc[2] = builder.create(loc, absOffset, c1); + // Yield offset + 1. + builder.create(loc, reduc); + + // Else, take the slow path.
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + reduc[2] = absOffset; // restore value. + Value pSt = c2; // pointer starting index + Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize + reduc[0] = dims[tid][lvl]; // next min coord + reduc[1] = constantI1(builder, loc, false); // isNonEmpty + auto loopArgs = static_cast(reduc).drop_back(); + auto forOp = scf::buildLoopNest( + builder, loc, pSt, mSz, c2, loopArgs, + [this, tid, lvl, c1, sPtrBuf, + &info](OpBuilder &builder, Location loc, ValueRange ivs, + ValueRange iterArgs) -> scf::ValueVector { + Value curMinCoord = iterArgs[0]; + Value isNonEmpty = iterArgs[1]; + + Type idxTp = builder.getIndexType(); + Value pLo = genIndexLoad(builder, loc, sPtrBuf, ivs.front()); + Value pHi = + genIndexLoad(builder, loc, sPtrBuf, + builder.create(loc, ivs.front(), c1)); + // + // if pLo < pHi + // coord = load[pLo] + // if coord == minCoord + // pLo += 1 + // + // if pLo < pHi + // curMinCoord = min(curMinCoord, load[pLo]) + // + Value pred = builder.create( + loc, arith::CmpIPredicate::ult, pLo, pHi); + auto advPLo = builder.create(loc, idxTp, pred, true); + /* if pLo < pHi */ { + builder.setInsertionPointToStart(&advPLo.getThenRegion().front()); + // coord = load[pLo] + Value coord = genIndexLoad(builder, loc, idxBuffer[tid][lvl], pLo); + Value pred = builder.create( + loc, arith::CmpIPredicate::eq, coord, info.minCoord); + auto ifEqual = builder.create(loc, idxTp, pred, true); + /* if coord == minCoord */ { + builder.setInsertionPointToStart( + &ifEqual.getThenRegion().front()); + Value newPlo = builder.create(loc, pLo, c1); + // Updates the cache.
+ builder.create(loc, newPlo, sPtrBuf, + ivs.front()); + builder.create(loc, newPlo); + } + /* else coord != minCoord */ { + builder.setInsertionPointToStart( + &ifEqual.getElseRegion().front()); + builder.create(loc, pLo); + } + builder.setInsertionPointAfter(ifEqual); + builder.create(loc, ifEqual.getResults()); + } + /* else pLo >= pHi */ { + builder.setInsertionPointToStart(&advPLo.getElseRegion().front()); + builder.create(loc, pLo); + } + + builder.setInsertionPointAfter(advPLo); + pLo = advPLo.getResult(0); + Value lvlNonEmpty = builder.create( + loc, arith::CmpIPredicate::ult, pLo, pHi); + // Update minCoords + auto newMin = + builder.create(loc, idxTp, lvlNonEmpty, true); + builder.setInsertionPointToStart(&newMin.getThenRegion().front()); + builder.create( + loc, genIndexLoad(builder, loc, idxBuffer[tid][lvl], pLo)); + + builder.setInsertionPointToStart(&newMin.getElseRegion().front()); + builder.create(loc, curMinCoord); + builder.setInsertionPointAfter(newMin); + + // isNonEmpty = isNonEmpty || lvlNonEmpty + isNonEmpty = + builder.create(loc, lvlNonEmpty, isNonEmpty); + curMinCoord = builder.create( + loc, + builder.create(loc, arith::CmpIPredicate::ult, + newMin.getResult(0), curMinCoord), + newMin.getResult(0), curMinCoord); + return {curMinCoord, isNonEmpty}; + }); + + builder.setInsertionPointAfter(forOp.loops.front()); + // minOffset = minCoord + 1 >= size ?
minCoord + 1 - size : c0 + Value tmp = builder.create(loc, forOp.results.front(), c1); + Value minOffset = builder.create( + loc, tmp, sliceSizes[tid][lvl][info.depth - 1]); + Value p = + builder.create(loc, arith::CmpIPredicate::uge, tmp, + sliceSizes[tid][lvl][info.depth - 1]); + minOffset = builder.create(loc, p, minOffset, c0); + SmallVector yields; + yields.assign(forOp.results.begin(), forOp.results.end()); + yields.push_back(minOffset); + builder.create(loc, yields); + } + + Value nextMinCoord = ifOp.getResults()[0]; + Value nextNonEmpty = ifOp.getResults()[1]; + + // the next offset should at least be offset + 1; + Value minOffset = ifOp.getResults()[2]; + Value nxOffset = builder.create(loc, info.offset, c1); + Value maxPred = builder.create(loc, arith::CmpIPredicate::ugt, + minOffset, nxOffset); + Value nextAbsOffset = + builder.create(loc, maxPred, minOffset, nxOffset); + + Value sliceUB = builder.create( + loc, nextAbsOffset, sliceSizes[tid][lvl][info.depth - 1]); + + // FIXME: this only works if the parent is the tensor, we should use the + // parent's slice size + parent offset. + assert(info.depth - 1 == 0); + // nextNonEmpty = nextNonEmpty && slice upper bound <= parent upperbound. + nextNonEmpty = builder.create( + loc, nextNonEmpty, + builder.create(loc, arith::CmpIPredicate::ule, sliceUB, + dims[tid][lvl])); + // FIXME: compute relative offset.
+ assert(info.depth - 1 == 0); + Value nextRelOffset = nextAbsOffset; + nextRelOffset = + builder.create(loc, nextNonEmpty, nextRelOffset, c0); + + uint64_t dim = + toOrigDim(getSparseTensorEncoding(tensors[tid].getType()), lvl); + + // The next slice is re-extracted from the parent slice, i.e. the entry right + // below `info` on this tensor's own slice stack. sliceStack is indexed by + // tensor id first, so the depth must come from sliceStack[tid].size() and + // not from sliceStack.size() (the number of per-tensor stacks). + assert(sliceStack[tid].size() >= 2 && "a sliced level must have a parent"); + const SliceInfo &parent = sliceStack[tid][sliceStack[tid].size() - 2]; + Value nextSlice = genExtractSliceWithOffsetOnDim( + builder, loc, parent.baseSlice, nextRelOffset, + sliceSizes[tid][lvl][info.depth - 1], dim); + + operands.push_back(nextNonEmpty); + operands.push_back(nextSlice); + operands.push_back(nextMinCoord); + operands.push_back(nextAbsOffset); // we push the absolute offset. + + // Update the slice stack. + info.isNonEmpty = whileOp.getResult(retIdx++); + info.baseSlice = whileOp.getResult(retIdx++); + info.minCoord = whileOp.getResult(retIdx++); + info.offset = whileOp.getResult(retIdx++); +} + +Operation *LoopEmitter::emitSliceDrivenLoopOverTensorAtDim( + OpBuilder &builder, Location loc, size_t tid, size_t lvl, + MutableArrayRef reduc) { + assert(!sliceFullyResolved(tid)); + SliceInfo &sliceInfo = sliceStack[tid].back(); + assert(sliceInfo.slicedOnLvl == lvl); + + // NOTE: The order matters! + constexpr size_t numMetaReduc = 4; // number of reduction maintained by us. + SmallVector operands{sliceInfo.isNonEmpty, sliceInfo.baseSlice, + sliceInfo.minCoord, sliceInfo.offset}; + // Append user-required reduction values.
+ operands.append(reduc.begin(), reduc.end()); + assert(operands.size() == numMetaReduc + reduc.size()); + + SmallVector types = getTypesFromValues(operands); + + auto whileOp = builder.create( + loc, types, operands, + /*beforeBuilder=*/ + [](OpBuilder &builder, Location loc, ValueRange args) { + builder.create(loc, /*isNonEmpty*/ args[0], args); + }, + /*afterBuilder=*/ + [this, tid, lvl, reduc, &sliceInfo](OpBuilder &builder, Location loc, + ValueRange args) { + assert(args.size() == reduc.size() + numMetaReduc); + sliceInfo.isNonEmpty = args[0]; + sliceInfo.baseSlice = args[1]; + sliceInfo.minCoord = args[2]; + sliceInfo.offset = args[3]; + // The slice offset is the coordinate. + Value c = sliceInfo.offset; + if (sliceInfo.depth > 1) { + // Coord is the offset relative to its parent. + // Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1] + llvm_unreachable("TODO: not yet implement"); + } + coord[tid][lvl] = c; + + // Rebind the caller's reduction values to the block arguments that + // follow the numMetaReduc metadata arguments. + for (unsigned i = 0, e = reduc.size(); i < e; i++) + reduc[i] = args[i + numMetaReduc]; + }); + + // Increment the number of resolved constraints on tid. + sliceResolvedConstraints[tid]++; + // Continue inserting at the end of the first block of the while loop's + // "after" region. + builder.setInsertionPointToEnd(&whileOp.getAfter().front()); + return whileOp; +} diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -997,7 +997,7 @@ // Link the reduction chain. Note that loop emitter update the reducValue // in place.
loopEmitter.exitCurrentLoop(rewriter, loc, reducValue); - loopEmitter.exitCurrentLoopSeq(); + loopEmitter.exitCurrentLoopSeq(rewriter, loc); } // Replace the foreach operator with the value returned by the outtermost diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -58,9 +58,12 @@ class AffineDimFinder : public AffineExprVisitor { public: explicit AffineDimFinder(linalg::GenericOp op) - : iterTypes(op.getIteratorTypesArray()) {} + : iterTypes(op.getIteratorTypes()) {} void visitDimExpr(AffineDimExpr expr) { - if (pickedDim == nullptr || pickIterType == iterTypes[expr.getPosition()]) { + if (pickedDim == nullptr || + pickIterType == iterTypes[expr.getPosition()] + .cast() + .getValue()) { pickedDim = expr; } } @@ -79,13 +82,13 @@ /// The iterator type that we want. utils::IteratorType pickIterType; /// The mapping between dim=>iterator type. - SmallVector iterTypes; + ArrayAttr iterTypes; }; /// A helper class that visits an affine expression and tries to find an /// AffineDimExpr to which the corresponding iterator from a GenericOp matches /// the desired iterator type. -struct AffineDimCollector : public AffineExprVisitor { +struct AffineDimCollector : public AffineExprVisitor { void visitDimExpr(AffineDimExpr expr) { dims.push_back(expr); } SmallVector dims; }; @@ -1544,7 +1547,7 @@ static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp, unsigned at, unsigned idx, unsigned ldx) { assert(env.getLoopIdxValue(idx) == nullptr); - env.emitter().exitCurrentLoopSeq(); + env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc()); // Unmark bookkeeping of invariants and loop index. genInvariants(env, builder, exp, ldx, /*atStart=*/false); // Finalize access pattern expansion for sparse tensor output. 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir @@ -0,0 +1,300 @@ +// RUN: mlir-opt %s --sparsification="enable-slice-affine=true" --cse | FileCheck %s + +#map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> + +#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> +// CHECK-LABEL: func.func @conv2d_all_sparse_CSR( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<3x3xi32>) -> tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> { +// CHECK: %[[VAL_2:.*]] = arith.constant 8 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_4:.*]] = arith.constant 4 : index +// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_7:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_9:.*]] = arith.constant true +// CHECK: %[[VAL_10:.*]] = arith.constant false +// CHECK: %[[VAL_11:.*]] = bufferization.alloc_tensor() : tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> +// CHECK: %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} +// CHECK: %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} +// CHECK: %[[VAL_14:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} +// CHECK: %[[VAL_15:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} +// CHECK: %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] +// CHECK: %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_1]] : 
memref<3x3xi32> +// CHECK: %[[VAL_18:.*]] = memref.alloca(%[[VAL_4]]) : memref +// CHECK: %[[VAL_19:.*]] = memref.alloca(%[[VAL_2]]) : memref +// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref +// CHECK: memref.store %[[VAL_4]], %[[VAL_18]]{{\[}}%[[VAL_5]]] : memref +// CHECK: memref.store %[[VAL_5]], %[[VAL_18]]{{\[}}%[[VAL_6]]] : memref +// CHECK: memref.store %[[VAL_5]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref +// CHECK: memref.store %[[VAL_20]], %[[VAL_18]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_21:.*]] = arith.cmpi ugt, %[[VAL_20]], %[[VAL_5]] : index +// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_5]]] : memref +// CHECK: %[[VAL_23:.*]] = arith.cmpi uge, %[[VAL_22]], %[[VAL_3]] : index +// CHECK: %[[VAL_24:.*]] = arith.andi %[[VAL_21]], %[[VAL_23]] : i1 +// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_22]], %[[VAL_6]] : index +// CHECK: %[[VAL_26:.*]] = arith.subi %[[VAL_25]], %[[VAL_3]] : index +// CHECK: %[[VAL_27:.*]] = arith.select %[[VAL_24]], %[[VAL_26]], %[[VAL_5]] : index +// CHECK: %[[VAL_28:.*]] = tensor.extract_slice %[[VAL_0]]{{\[}}%[[VAL_27]], 0] {{\[}}%[[VAL_3]], 8] [1, 1] +// CHECK: %[[VAL_29:.*]]:5 = scf.while (%[[VAL_30:.*]] = %[[VAL_21]], +// CHECK-SAME: %[[VAL_31:.*]] = %[[VAL_28]], +// CHECK-SAME: %[[VAL_32:.*]] = %[[VAL_22]], +// CHECK-SAME: %[[VAL_33:.*]] = %[[VAL_27]], +// CHECK-SAME: %[[VAL_34:.*]] = %[[VAL_11]]) +// CHECK: scf.condition(%[[VAL_30]]) %[[VAL_30]], %[[VAL_31]], %[[VAL_32]], %[[VAL_33]], %[[VAL_34]] +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_35:.*]]: i1, +// CHECK-SAME: %[[VAL_36:.*]]: tensor>, +// CHECK-SAME: %[[VAL_37:.*]]: index, %[[VAL_38:.*]]: index, +// CHECK-SAME: %[[VAL_39:.*]]: tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>): +// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_5]]] : memref +// +// !!!!! Code below is for slice non empty begin. 
+// +// CHECK: %[[VAL_41:.*]]:3 = scf.for %[[VAL_42:.*]] = %[[VAL_7]] to %[[VAL_40]] step %[[VAL_7]] iter_args(%[[VAL_43:.*]] = %[[VAL_10]], %[[VAL_44:.*]] = %[[VAL_2]], %[[VAL_45:.*]] = %[[VAL_7]]) -> (i1, index, index) { +// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_42]]] : memref +// CHECK: %[[VAL_47:.*]] = arith.addi %[[VAL_42]], %[[VAL_6]] : index +// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_47]]] : memref +// CHECK: %[[VAL_49:.*]] = arith.addi %[[VAL_38]], %[[VAL_3]] : index +// CHECK: %[[VAL_50:.*]]:5 = scf.while (%[[VAL_51:.*]] = %[[VAL_46]], %[[VAL_52:.*]] = %[[VAL_9]], %[[VAL_53:.*]] = %[[VAL_43]], %[[VAL_54:.*]] = %[[VAL_44]], %[[VAL_55:.*]] = %[[VAL_45]]) : (index, i1, i1, index, index) -> (index, i1, i1, index, index) { +// CHECK: %[[VAL_56:.*]] = arith.cmpi ult, %[[VAL_51]], %[[VAL_48]] : index +// CHECK: %[[VAL_57:.*]] = arith.andi %[[VAL_52]], %[[VAL_56]] : i1 +// CHECK: scf.condition(%[[VAL_57]]) %[[VAL_51]], %[[VAL_52]], %[[VAL_53]], %[[VAL_54]], %[[VAL_55]] : index, i1, i1, index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_58:.*]]: index, %[[VAL_59:.*]]: i1, %[[VAL_60:.*]]: i1, %[[VAL_61:.*]]: index, %[[VAL_62:.*]]: index): +// CHECK: %[[VAL_63:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_58]]] : memref +// CHECK: %[[VAL_64:.*]] = arith.cmpi ult, %[[VAL_63]], %[[VAL_49]] : index +// CHECK: %[[VAL_65:.*]]:3 = scf.if %[[VAL_64]] -> (i1, index, index) { +// CHECK: %[[VAL_66:.*]] = arith.addi %[[VAL_58]], %[[VAL_6]] : index +// CHECK: %[[VAL_67:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_58]]] : memref +// CHECK: %[[VAL_68:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_66]]] : memref +// CHECK: %[[VAL_69:.*]] = arith.cmpi ult, %[[VAL_67]], %[[VAL_68]] : index +// CHECK: %[[VAL_70:.*]] = arith.ori %[[VAL_69]], %[[VAL_60]] : i1 +// CHECK: %[[VAL_71:.*]] = scf.if %[[VAL_69]] -> (index) { +// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_67]]] : memref +// CHECK: %[[VAL_73:.*]] = arith.cmpi ult, 
%[[VAL_72]], %[[VAL_61]] : index +// CHECK: %[[VAL_74:.*]] = arith.select %[[VAL_73]], %[[VAL_72]], %[[VAL_61]] : index +// CHECK: scf.yield %[[VAL_74]] : index +// CHECK: } else { +// CHECK: scf.yield %[[VAL_61]] : index +// CHECK: } +// CHECK: memref.store %[[VAL_67]], %[[VAL_19]]{{\[}}%[[VAL_62]]] : memref +// CHECK: %[[VAL_75:.*]] = arith.addi %[[VAL_62]], %[[VAL_6]] : index +// CHECK: memref.store %[[VAL_68]], %[[VAL_19]]{{\[}}%[[VAL_75]]] : memref +// CHECK: %[[VAL_76:.*]] = arith.addi %[[VAL_62]], %[[VAL_7]] : index +// CHECK: scf.yield %[[VAL_70]], %[[VAL_77:.*]], %[[VAL_76]] : i1, index, index +// CHECK: } else { +// CHECK: scf.yield %[[VAL_60]], %[[VAL_61]], %[[VAL_62]] : i1, index, index +// CHECK: } {"Emitted from" = "slice"} +// CHECK: %[[VAL_78:.*]] = arith.addi %[[VAL_58]], %[[VAL_6]] : index +// CHECK: scf.yield %[[VAL_78]], %[[VAL_64]], %[[VAL_79:.*]]#0, %[[VAL_79]]#1, %[[VAL_79]]#2 : index, i1, i1, index, index +// CHECK: } +// CHECK: scf.yield %[[VAL_80:.*]]#2, %[[VAL_80]]#3, %[[VAL_80]]#4 : i1, index, index +// CHECK: } +// CHECK: memref.store %[[VAL_81:.*]]#2, %[[VAL_19]]{{\[}}%[[VAL_5]]] : memref +// CHECK: memref.store %[[VAL_5]], %[[VAL_19]]{{\[}}%[[VAL_6]]] : memref +// CHECK: %[[VAL_82:.*]] = arith.cmpi uge, %[[VAL_81]]#1, %[[VAL_3]] : index +// CHECK: %[[VAL_83:.*]] = arith.andi %[[VAL_81]]#0, %[[VAL_82]] : i1 +// CHECK: %[[VAL_84:.*]] = arith.addi %[[VAL_81]]#1, %[[VAL_6]] : index +// CHECK: %[[VAL_85:.*]] = arith.subi %[[VAL_84]], %[[VAL_3]] : index +// CHECK: %[[VAL_86:.*]] = arith.select %[[VAL_83]], %[[VAL_85]], %[[VAL_5]] : index +// CHECK: %[[VAL_87:.*]] = tensor.extract_slice %[[VAL_36]][0, %[[VAL_86]]] {{\[}}%[[VAL_2]], %[[VAL_3]]] [1, 1] : tensor> to tensor> +// CHECK: %[[VAL_88:.*]]:5 = scf.while (%[[VAL_89:.*]] = %[[VAL_81]]#0, %[[VAL_90:.*]] = %[[VAL_87]], %[[VAL_91:.*]] = %[[VAL_81]]#1, %[[VAL_92:.*]] = %[[VAL_86]], %[[VAL_93:.*]] = %[[VAL_39]]) : (i1, tensor>, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{ 
dimLevelType = [ "compressed", "compressed" ] }>>) -> (i1, tensor>, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) { +// CHECK: scf.condition(%[[VAL_89]]) %[[VAL_89]], %[[VAL_90]], %[[VAL_91]], %[[VAL_92]], %[[VAL_93]] : i1, tensor>, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_94:.*]]: i1, %[[VAL_95:.*]]: tensor>, %[[VAL_96:.*]]: index, %[[VAL_97:.*]]: index, %[[VAL_98:.*]]: tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>): +// CHECK: %[[VAL_99:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_6]]] : memref +// CHECK: %[[VAL_100:.*]] = arith.addi %[[VAL_99]], %[[VAL_7]] : index +// CHECK: %[[VAL_101:.*]] = arith.addi %[[VAL_100]], %[[VAL_6]] : index +// CHECK: %[[VAL_102:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_100]]] : memref +// CHECK: %[[VAL_103:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_101]]] : memref +// CHECK: %[[VAL_104:.*]] = arith.addi %[[VAL_38]], %[[VAL_3]] : index +// CHECK: %[[VAL_105:.*]]:4 = scf.while (%[[VAL_106:.*]] = %[[VAL_102]], %[[VAL_107:.*]] = %[[VAL_9]], %[[VAL_108:.*]] = %[[VAL_8]], %[[VAL_109:.*]] = %[[VAL_98]]) : (index, i1, i32, tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) -> (index, i1, i32, tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) { +// CHECK: %[[VAL_110:.*]] = arith.cmpi ult, %[[VAL_106]], %[[VAL_103]] : index +// CHECK: %[[VAL_111:.*]] = arith.andi %[[VAL_107]], %[[VAL_110]] : i1 +// CHECK: scf.condition(%[[VAL_111]]) %[[VAL_106]], %[[VAL_107]], %[[VAL_108]], %[[VAL_109]] : index, i1, i32, tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_112:.*]]: index, %[[VAL_113:.*]]: i1, %[[VAL_114:.*]]: i32, %[[VAL_115:.*]]: tensor<6x6xi32, #sparse_tensor.encoding<{ 
dimLevelType = [ "compressed", "compressed" ] }>>): +// CHECK: %[[VAL_116:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_112]]] : memref +// CHECK: %[[VAL_117:.*]] = arith.cmpi ult, %[[VAL_116]], %[[VAL_104]] : index +// CHECK: %[[VAL_118:.*]]:2 = scf.if %[[VAL_117]] -> (i32, tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) { +// CHECK: %[[VAL_119:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_112]]] : memref +// CHECK: %[[VAL_120:.*]] = arith.subi %[[VAL_119]], %[[VAL_38]] : index +// CHECK: %[[VAL_121:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_6]]] : memref +// CHECK: %[[VAL_122:.*]] = arith.addi %[[VAL_121]], %[[VAL_7]] : index +// CHECK: %[[VAL_123:.*]] = arith.addi %[[VAL_122]], %[[VAL_6]] : index +// CHECK: %[[VAL_124:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_122]]] : memref +// CHECK: %[[VAL_125:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_123]]] : memref +// CHECK: %[[VAL_126:.*]] = arith.addi %[[VAL_97]], %[[VAL_3]] : index +// CHECK: %[[VAL_127:.*]]:4 = scf.while (%[[VAL_128:.*]] = %[[VAL_124]], %[[VAL_129:.*]] = %[[VAL_9]], %[[VAL_130:.*]] = %[[VAL_114]], %[[VAL_131:.*]] = %[[VAL_115]]) : (index, i1, i32, tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) -> (index, i1, i32, tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) { +// CHECK: %[[VAL_132:.*]] = arith.cmpi ult, %[[VAL_128]], %[[VAL_125]] : index +// CHECK: %[[VAL_133:.*]] = arith.andi %[[VAL_129]], %[[VAL_132]] : i1 +// CHECK: scf.condition(%[[VAL_133]]) %[[VAL_128]], %[[VAL_129]], %[[VAL_130]], %[[VAL_131]] : index, i1, i32, tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> +// CHECK: } do { +// +// !!!!! 
Code below is the actually convolution kernel +// +// CHECK: ^bb0(%[[VAL_134:.*]]: index, %[[VAL_135:.*]]: i1, %[[VAL_136:.*]]: i32, %[[VAL_137:.*]]: tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>): +// CHECK: %[[VAL_138:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_134]]] : memref +// CHECK: %[[VAL_139:.*]] = arith.cmpi ult, %[[VAL_138]], %[[VAL_126]] : index +// CHECK: %[[VAL_140:.*]]:2 = scf.if %[[VAL_139]] -> (i32, tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>) { +// CHECK: %[[VAL_141:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_134]]] : memref +// CHECK: %[[VAL_142:.*]] = arith.subi %[[VAL_141]], %[[VAL_97]] : index +// CHECK: %[[VAL_143:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_134]]] : memref +// CHECK: %[[VAL_144:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_120]], %[[VAL_142]]] : memref<3x3xi32> +// CHECK: %[[VAL_145:.*]] = arith.muli %[[VAL_143]], %[[VAL_144]] : i32 +// CHECK: %[[VAL_146:.*]] = arith.addi %[[VAL_136]], %[[VAL_145]] : i32 +// CHECK: scf.yield %[[VAL_146]], %[[VAL_137]] : i32, tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> +// CHECK: } else { +// CHECK: scf.yield %[[VAL_136]], %[[VAL_137]] : i32, tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> +// CHECK: } {"Emitted from" = "slice"} +// CHECK: %[[VAL_147:.*]] = arith.addi %[[VAL_134]], %[[VAL_6]] : index +// CHECK: scf.yield %[[VAL_147]], %[[VAL_139]], %[[VAL_148:.*]]#0, %[[VAL_148]]#1 : index, i1, i32, tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> +// CHECK: } attributes {"Emitted from" = "linalg.generic"} +// CHECK: %[[VAL_149:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_6]]] : memref +// CHECK: %[[VAL_150:.*]] = arith.addi %[[VAL_149]], %[[VAL_7]] : index +// CHECK: memref.store %[[VAL_150]], %[[VAL_19]]{{\[}}%[[VAL_6]]] : memref +// CHECK: scf.yield %[[VAL_151:.*]]#2, 
%[[VAL_151]]#3 : i32, tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> +// CHECK: } else { +// CHECK: scf.yield %[[VAL_114]], %[[VAL_115]] : i32, tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> +// CHECK: } {"Emitted from" = "slice"} +// CHECK: %[[VAL_152:.*]] = arith.addi %[[VAL_112]], %[[VAL_6]] : index +// CHECK: scf.yield %[[VAL_152]], %[[VAL_117]], %[[VAL_153:.*]]#0, %[[VAL_153]]#1 : index, i1, i32, tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> +// CHECK: } attributes {"Emitted from" = "linalg.generic"} +// CHECK: %[[VAL_154:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_6]]] : memref +// CHECK: %[[VAL_155:.*]] = arith.addi %[[VAL_154]], %[[VAL_7]] : index +// CHECK: memref.store %[[VAL_155]], %[[VAL_18]]{{\[}}%[[VAL_6]]] : memref +// CHECK: %[[VAL_156:.*]] = sparse_tensor.insert %[[VAL_157:.*]]#2 into %[[VAL_157]]#3{{\[}}%[[VAL_38]], %[[VAL_97]]] : tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> +// CHECK: memref.store %[[VAL_5]], %[[VAL_18]]{{\[}}%[[VAL_6]]] : memref +// CHECK: memref.store %[[VAL_5]], %[[VAL_19]]{{\[}}%[[VAL_6]]] : memref +// CHECK: %[[VAL_158:.*]] = arith.cmpi ugt, %[[VAL_96]], %[[VAL_97]] : index +// CHECK: %[[VAL_159:.*]]:3 = scf.if %[[VAL_158]] -> (index, i1, index) { +// CHECK: %[[VAL_160:.*]] = arith.addi %[[VAL_97]], %[[VAL_6]] : index +// CHECK: scf.yield %[[VAL_96]], %[[VAL_94]], %[[VAL_160]] : index, i1, index +// CHECK: } else { +// CHECK: %[[VAL_161:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_5]]] : memref +// CHECK: %[[VAL_162:.*]]:2 = scf.for %[[VAL_163:.*]] = %[[VAL_7]] to %[[VAL_161]] step %[[VAL_7]] iter_args(%[[VAL_164:.*]] = %[[VAL_2]], %[[VAL_165:.*]] = %[[VAL_10]]) -> (index, i1) { +// CHECK: %[[VAL_166:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_163]]] : memref +// CHECK: %[[VAL_167:.*]] = arith.addi %[[VAL_163]], %[[VAL_6]] : index +// CHECK: 
%[[VAL_168:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_167]]] : memref +// CHECK: %[[VAL_169:.*]] = arith.cmpi ult, %[[VAL_166]], %[[VAL_168]] : index +// CHECK: %[[VAL_170:.*]] = scf.if %[[VAL_169]] -> (index) { +// CHECK: %[[VAL_171:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_166]]] : memref +// CHECK: %[[VAL_172:.*]] = arith.cmpi eq, %[[VAL_171]], %[[VAL_96]] : index +// CHECK: %[[VAL_173:.*]] = scf.if %[[VAL_172]] -> (index) { +// CHECK: %[[VAL_174:.*]] = arith.addi %[[VAL_166]], %[[VAL_6]] : index +// CHECK: memref.store %[[VAL_174]], %[[VAL_19]]{{\[}}%[[VAL_163]]] : memref +// CHECK: scf.yield %[[VAL_174]] : index +// CHECK: } else { +// CHECK: scf.yield %[[VAL_166]] : index +// CHECK: } +// CHECK: scf.yield %[[VAL_175:.*]] : index +// CHECK: } else { +// CHECK: scf.yield %[[VAL_166]] : index +// CHECK: } +// CHECK: %[[VAL_176:.*]] = arith.cmpi ult, %[[VAL_177:.*]], %[[VAL_168]] : index +// CHECK: %[[VAL_178:.*]] = scf.if %[[VAL_176]] -> (index) { +// CHECK: %[[VAL_179:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_177]]] : memref +// CHECK: scf.yield %[[VAL_179]] : index +// CHECK: } else { +// CHECK: scf.yield %[[VAL_164]] : index +// CHECK: } +// CHECK: %[[VAL_180:.*]] = arith.ori %[[VAL_176]], %[[VAL_165]] : i1 +// CHECK: %[[VAL_181:.*]] = arith.cmpi ult, %[[VAL_182:.*]], %[[VAL_164]] : index +// CHECK: %[[VAL_183:.*]] = arith.select %[[VAL_181]], %[[VAL_182]], %[[VAL_164]] : index +// CHECK: scf.yield %[[VAL_183]], %[[VAL_180]] : index, i1 +// CHECK: } +// CHECK: %[[VAL_184:.*]] = arith.addi %[[VAL_185:.*]]#0, %[[VAL_6]] : index +// CHECK: %[[VAL_186:.*]] = arith.subi %[[VAL_184]], %[[VAL_3]] : index +// CHECK: %[[VAL_187:.*]] = arith.cmpi uge, %[[VAL_184]], %[[VAL_3]] : index +// CHECK: %[[VAL_188:.*]] = arith.select %[[VAL_187]], %[[VAL_186]], %[[VAL_5]] : index +// CHECK: scf.yield %[[VAL_185]]#0, %[[VAL_185]]#1, %[[VAL_188]] : index, i1, index +// CHECK: } +// +// !!!!! Code below is for slice non empty next. 
+// +// CHECK: %[[VAL_189:.*]] = arith.addi %[[VAL_97]], %[[VAL_6]] : index +// CHECK: %[[VAL_190:.*]] = arith.cmpi ugt, %[[VAL_191:.*]]#2, %[[VAL_189]] : index +// CHECK: %[[VAL_192:.*]] = arith.select %[[VAL_190]], %[[VAL_191]]#2, %[[VAL_189]] : index +// CHECK: %[[VAL_193:.*]] = arith.addi %[[VAL_192]], %[[VAL_3]] : index +// CHECK: %[[VAL_194:.*]] = arith.cmpi ule, %[[VAL_193]], %[[VAL_2]] : index +// CHECK: %[[VAL_195:.*]] = arith.andi %[[VAL_191]]#1, %[[VAL_194]] : i1 +// CHECK: %[[VAL_196:.*]] = arith.select %[[VAL_195]], %[[VAL_192]], %[[VAL_5]] : index +// CHECK: %[[VAL_197:.*]] = tensor.extract_slice %[[VAL_36]][0, %[[VAL_196]]] {{\[}}%[[VAL_2]], %[[VAL_3]]] [1, 1] : tensor> to tensor> +// CHECK: scf.yield %[[VAL_195]], %[[VAL_197]], %[[VAL_191]]#0, %[[VAL_192]], %[[VAL_156]] : i1, tensor>, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> +// CHECK: } attributes {"Emitted from" = "linalg.generic"} +// CHECK: memref.store %[[VAL_5]], %[[VAL_18]]{{\[}}%[[VAL_6]]] : memref +// CHECK: %[[VAL_198:.*]] = arith.cmpi ugt, %[[VAL_37]], %[[VAL_38]] : index +// CHECK: %[[VAL_199:.*]]:3 = scf.if %[[VAL_198]] -> (index, i1, index) { +// CHECK: %[[VAL_200:.*]] = arith.addi %[[VAL_38]], %[[VAL_6]] : index +// CHECK: scf.yield %[[VAL_37]], %[[VAL_35]], %[[VAL_200]] : index, i1, index +// CHECK: } else { +// CHECK: %[[VAL_201:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_5]]] : memref +// CHECK: %[[VAL_202:.*]]:2 = scf.for %[[VAL_203:.*]] = %[[VAL_7]] to %[[VAL_201]] step %[[VAL_7]] iter_args(%[[VAL_204:.*]] = %[[VAL_2]], %[[VAL_205:.*]] = %[[VAL_10]]) -> (index, i1) { +// CHECK: %[[VAL_206:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_203]]] : memref +// CHECK: %[[VAL_207:.*]] = arith.addi %[[VAL_203]], %[[VAL_6]] : index +// CHECK: %[[VAL_208:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_207]]] : memref +// CHECK: %[[VAL_209:.*]] = arith.cmpi ult, %[[VAL_206]], %[[VAL_208]] : index +// CHECK: %[[VAL_210:.*]] = scf.if 
%[[VAL_209]] -> (index) { +// CHECK: %[[VAL_211:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_206]]] : memref +// CHECK: %[[VAL_212:.*]] = arith.cmpi eq, %[[VAL_211]], %[[VAL_37]] : index +// CHECK: %[[VAL_213:.*]] = scf.if %[[VAL_212]] -> (index) { +// CHECK: %[[VAL_214:.*]] = arith.addi %[[VAL_206]], %[[VAL_6]] : index +// CHECK: memref.store %[[VAL_214]], %[[VAL_18]]{{\[}}%[[VAL_203]]] : memref +// CHECK: scf.yield %[[VAL_214]] : index +// CHECK: } else { +// CHECK: scf.yield %[[VAL_206]] : index +// CHECK: } +// CHECK: scf.yield %[[VAL_215:.*]] : index +// CHECK: } else { +// CHECK: scf.yield %[[VAL_206]] : index +// CHECK: } +// CHECK: %[[VAL_216:.*]] = arith.cmpi ult, %[[VAL_217:.*]], %[[VAL_208]] : index +// CHECK: %[[VAL_218:.*]] = scf.if %[[VAL_216]] -> (index) { +// CHECK: %[[VAL_219:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_217]]] : memref +// CHECK: scf.yield %[[VAL_219]] : index +// CHECK: } else { +// CHECK: scf.yield %[[VAL_204]] : index +// CHECK: } +// CHECK: %[[VAL_220:.*]] = arith.ori %[[VAL_216]], %[[VAL_205]] : i1 +// CHECK: %[[VAL_221:.*]] = arith.cmpi ult, %[[VAL_222:.*]], %[[VAL_204]] : index +// CHECK: %[[VAL_223:.*]] = arith.select %[[VAL_221]], %[[VAL_222]], %[[VAL_204]] : index +// CHECK: scf.yield %[[VAL_223]], %[[VAL_220]] : index, i1 +// CHECK: } +// CHECK: %[[VAL_224:.*]] = arith.addi %[[VAL_225:.*]]#0, %[[VAL_6]] : index +// CHECK: %[[VAL_226:.*]] = arith.subi %[[VAL_224]], %[[VAL_3]] : index +// CHECK: %[[VAL_227:.*]] = arith.cmpi uge, %[[VAL_224]], %[[VAL_3]] : index +// CHECK: %[[VAL_228:.*]] = arith.select %[[VAL_227]], %[[VAL_226]], %[[VAL_5]] : index +// CHECK: scf.yield %[[VAL_225]]#0, %[[VAL_225]]#1, %[[VAL_228]] : index, i1, index +// CHECK: } +// CHECK: %[[VAL_229:.*]] = arith.addi %[[VAL_38]], %[[VAL_6]] : index +// CHECK: %[[VAL_230:.*]] = arith.cmpi ugt, %[[VAL_231:.*]]#2, %[[VAL_229]] : index +// CHECK: %[[VAL_232:.*]] = arith.select %[[VAL_230]], %[[VAL_231]]#2, %[[VAL_229]] : index +// CHECK: %[[VAL_233:.*]] = 
arith.addi %[[VAL_232]], %[[VAL_3]] : index +// CHECK: %[[VAL_234:.*]] = arith.cmpi ule, %[[VAL_233]], %[[VAL_2]] : index +// CHECK: %[[VAL_235:.*]] = arith.andi %[[VAL_231]]#1, %[[VAL_234]] : i1 +// CHECK: %[[VAL_236:.*]] = arith.select %[[VAL_235]], %[[VAL_232]], %[[VAL_5]] : index +// CHECK: %[[VAL_237:.*]] = tensor.extract_slice %[[VAL_36]]{{\[}}%[[VAL_236]], 0] {{\[}}%[[VAL_3]], 8] [1, 1] : tensor> to tensor> +// CHECK: scf.yield %[[VAL_235]], %[[VAL_237]], %[[VAL_231]]#0, %[[VAL_232]], %[[VAL_238:.*]]#4 : i1, tensor>, index, index, tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> +// CHECK: } attributes {"Emitted from" = "linalg.generic"} +// CHECK: %[[VAL_239:.*]] = sparse_tensor.load %[[VAL_240:.*]]#4 hasInserts : tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> +// CHECK: return %[[VAL_239]] : tensor<6x6xi32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>> +// CHECK: } +func.func @conv2d_all_sparse_CSR(%arg0: tensor<8x8xi32, #DCSR>, + %arg1: tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR> { + %0 = bufferization.alloc_tensor() : tensor<6x6xi32, #DCSR> + %1 = linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction", "reduction"]} + ins(%arg0, %arg1 : tensor<8x8xi32, #DCSR>, tensor<3x3xi32>) + outs(%0 : tensor<6x6xi32, #DCSR>) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %2 = arith.muli %in, %in_0 : i32 + %3 = arith.addi %out, %2 : i32 + linalg.yield %3 : i32 + } -> tensor<6x6xi32, #DCSR> + return %1 : tensor<6x6xi32, #DCSR> +} diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_2d_slice_based.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_2d_slice_based.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_2d_slice_based.mlir @@ -0,0 +1,81 @@ +// DEFINE: %{option} = "enable-slice-affine=true 
enable-runtime-library=false" +// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \ +// DEFINE: mlir-cpu-runner \ +// DEFINE: -e entry -entry-point-result=void \ +// DEFINE: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ +// DEFINE: FileCheck %s +// +// RUN: %{command} + +#map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> + +#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> + +module { + func.func @conv2d_all_sparse_CSR(%arg0: tensor<8x8xi32, #DCSR>, %arg1: tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR> { + %0 = bufferization.alloc_tensor() : tensor<6x6xi32, #DCSR> + %1 = linalg.generic { + indexing_maps = [#map, #map1, #map2], + iterator_types = ["parallel", "parallel", "reduction", "reduction"]} + ins(%arg0, %arg1 : tensor<8x8xi32, #DCSR>, tensor<3x3xi32>) + outs(%0 : tensor<6x6xi32, #DCSR>) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %2 = arith.muli %in, %in_0 : i32 + %3 = arith.addi %out, %2 : i32 + linalg.yield %3 : i32 + } -> tensor<6x6xi32, #DCSR> + return %1 : tensor<6x6xi32, #DCSR> + } + + func.func @entry() { + %c0 = arith.constant 0 : index + %i0 = arith.constant 0 : i32 + + // A typical edge detection filter. NOTE(review): the sparse_tensor.convert of %filter below has identical source and result types (no encoding), so %sparse_filter_CSR stays dense - confirm this is intended. 
+ %filter = arith.constant dense<[ + [ 1, 0, -1 ], + [ 0, 0, 0 ], + [ -1, 0, 1 ] + ]> : tensor<3x3xi32> + + %input = arith.constant dense<[ + [ 1, 2, 3, 4, 0, 6, 7, 8 ], + [ 2, 2, 4, 4, 0, 0, 6, 8 ], + [ 2, 2, 4, 4, 0, 0, 6, 8 ], + [ 2, 2, 3, 4, 0, 0, 7, 8 ], + [ 1, 3, 3, 4, 0, 0, 6, 8 ], + [ 3, 2, 3, 4, 0, 0, 7, 8 ], + [ 1, 3, 3, 4, 3, 6, 6, 8 ], + [ 1, 3, 3, 4, 3, 0, 7, 8 ] + ]> : tensor<8x8xi32> + + %sparse_filter_CSR = sparse_tensor.convert %filter + : tensor<3x3xi32> to tensor<3x3xi32> + + %sparse_input_CSR = sparse_tensor.convert %input + : tensor<8x8xi32> to tensor<8x8xi32, #DCSR> + + %3 = call @conv2d_all_sparse_CSR(%sparse_input_CSR, %sparse_filter_CSR) + : (tensor<8x8xi32, #DCSR>, + tensor<3x3xi32>) -> tensor<6x6xi32, #DCSR> + + %out = sparse_tensor.convert %3 + : tensor<6x6xi32, #DCSR> to tensor<6x6xi32> + // Expected 6x6 output, matched against the vector.print of %v2 below. + // CHECK: ( ( 0, 0, -1, -6, -1, 6 ), + // CHECK-SAME: ( -1, 0, 1, 0, 1, 0 ), + // CHECK-SAME: ( 0, -1, 1, 0, 0, 0 ), + // CHECK-SAME: ( -1, 0, 0, 0, 0, 0 ), + // CHECK-SAME: ( 0, 0, 3, 6, -3, -6 ), + // CHECK-SAME: ( 2, -1, 3, 0, -3, 0 ) ) + // NOTE(review): this function contains no bufferization.dealloc_tensor for %sparse_input_CSR or %3, whereas sparse_conv_3d_slice_based.mlir frees its tensors - confirm whether they should be released here. 
+ %v2 = vector.transfer_read %out[%c0, %c0], %i0 + : tensor<6x6xi32>, vector<6x6xi32> + vector.print %v2 : vector<6x6xi32> + + return + } + +} diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d_slice_based.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d_slice_based.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_3d_slice_based.mlir @@ -0,0 +1,97 @@ +// DEFINE: %{option} = "enable-slice-affine=true enable-runtime-library=false" +// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \ +// DEFINE: mlir-cpu-runner \ +// DEFINE: -e entry -entry-point-result=void \ +// DEFINE: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ +// DEFINE: FileCheck %s +// +// RUN: %{command} + +#CCC = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed", "compressed", "compressed" ] +}> + +func.func 
@alloc_3d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %f : f32) -> tensor { + %buf = bufferization.alloc_tensor(%s1, %s2, %s3) : tensor + %ret = linalg.fill ins(%f : f32) outs(%buf : tensor) -> tensor + return %ret : tensor +} + +func.func @conv_3d_CCC(%arg0: tensor, %arg1: tensor) -> tensor { + %c6 = arith.constant 6 : index + %s = bufferization.alloc_tensor(%c6, %c6, %c6) : tensor + %ret = linalg.conv_3d + ins (%arg0, %arg1: tensor, tensor) + outs (%s: tensor) -> tensor + return %ret : tensor +} + +func.func @entry() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %c6 = arith.constant 6 : index + %c8 = arith.constant 8 : index + %f10 = arith.constant 10.00000e+00 : f32 + %val = arith.constant 2.00000e+00 : f32 + %zero = arith.constant 0.00000e+00 : f32 + + %filter3D = call @alloc_3d_filled_f32(%c3, %c3, %c3, %val) : (index, index, index, f32) -> (tensor) + %in3D_tmp = call @alloc_3d_filled_f32(%c8, %c8, %c8, %val) : (index, index, index, f32) -> (tensor) + %in3D = tensor.insert %f10 into %in3D_tmp[%c0, %c3, %c0] : tensor + %out3D = call @alloc_3d_filled_f32(%c6, %c6, %c6, %zero) : (index, index, index, f32) -> (tensor) + + %in3D_CCC = sparse_tensor.convert %in3D + : tensor to tensor + %CCC_ret = call @conv_3d_CCC(%in3D_CCC, %filter3D) : (tensor, tensor) -> (tensor) + // CHECK: ( ( ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 124, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 124, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 124, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ), + // CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ), + // CHECK-SAME: ( ( 108, 108, 
108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ), + // CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ), + // CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ), + // CHECK-SAME: ( ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ), + // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ) ) + %1 = sparse_tensor.convert %CCC_ret + : tensor to tensor + %v1 = vector.transfer_read %1[%c0, %c0, %c0], %zero + : tensor, vector<6x6x6xf32> + vector.print %v1 : vector<6x6x6xf32> + + // Free the resources: %in3D, %filter3D, %in3D_CCC and %CCC_ret. NOTE(review): %out3D (not used after its definition) and %1 are not freed here - confirm whether that is intended. + bufferization.dealloc_tensor %in3D : tensor + bufferization.dealloc_tensor %filter3D : tensor + + bufferization.dealloc_tensor %in3D_CCC : tensor + bufferization.dealloc_tensor %CCC_ret : tensor + + return +}