diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td @@ -293,6 +293,8 @@ I32EnumAttrCase<"PtrMemSize", 1, "ptr_mem_sz">, I32EnumAttrCase<"IdxMemSize", 2, "idx_mem_sz">, I32EnumAttrCase<"ValMemSize", 3, "val_mem_sz">, + I32EnumAttrCase<"DimOffset", 4, "dim_offset">, + I32EnumAttrCase<"DimStride", 5, "dim_stride">, ]> { let genSpecializedAttr = 0; let cppNamespace = SparseTensor_Dialect.cppNamespace; diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -236,20 +236,34 @@ } def SparseTensor_StorageSpecifierInitOp : SparseTensor_Op<"storage_specifier.init", [Pure]>, + Arguments<(ins Optional:$source)>, Results<(outs SparseTensorStorageSpecifier:$result)> { let summary = ""; let description = [{ Returns an initial storage specifier value. A storage specifier value holds the sizes for tensor dimensions, pointer arrays, index arrays, and the value array. + If this is a specifier for slices, it also holds the extra strides/offsets for tensor + dimensions. Example: ```mlir %0 = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier<#CSR> + %1 = sparse_tensor.storage_specifier.init with %src + : !sparse_tensor.storage_specifier<#CSR> to + !sparse_tensor.storage_specifier<#CSR_SLICE> ``` }]; - let assemblyFormat = "attr-dict `:` qualified(type($result))"; + let builders = [ + OpBuilder<(ins "Type":$result), + [{ + build($_builder, $_state, result, Value()); + }]> + ]; + + let assemblyFormat = "attr-dict (`with` $source^)? 
`:` qualified(type($result))" + "(`to` qualified(type($source))^)?"; } def SparseTensor_GetStorageSpecifierOp : SparseTensor_Op<"storage_specifier.get", [Pure]>, diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -512,10 +512,7 @@ enc.getContext(), dlts, AffineMap(), // dimOrdering (irrelavant to storage speicifer) AffineMap(), // highLvlOrdering (irrelavant to storage specifer) - enc.getPointerBitWidth(), enc.getIndexBitWidth(), - // FIXME: we should keep the slice information, for now it is okay as only - // constant can be used for slice - ArrayRef{} /*enc.getDimSlices()*/); + enc.getPointerBitWidth(), enc.getIndexBitWidth(), enc.getDimSlices()); } StorageSpecifierType diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -339,6 +339,15 @@ /// Generates code to retrieve the values size for the sparse tensor. Value genValMemSize(OpBuilder &builder, Location loc, Value tensor); +/// Generates code to retrieve the slice offset for the sparse tensor slice; +/// returns a constant if the offset is statically known. +Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor, + unsigned dim); + +/// Generates code to retrieve the slice stride for the sparse tensor slice; +/// returns a constant if the stride is statically known.
+Value createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, Value tensor, + unsigned dim); } // namespace sparse_tensor } // namespace mlir diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -556,7 +556,26 @@ Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc, Value tensor) { - SmallVector fields; - auto desc = getMutDescriptorFromTensorTuple(tensor, fields); + auto desc = getDescriptorFromTensorTuple(tensor); return desc.getValMemSize(builder, loc); -} \ No newline at end of file +} + +Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, + Value tensor, unsigned dim) { + auto enc = getSparseTensorEncoding(tensor.getType()); + assert(enc && enc.isSlice()); + std::optional offset = enc.getStaticDimSliceOffset(dim); + if (offset) // Statically known: fold to a constant. + return constantIndex(builder, loc, *offset); + return builder.create(loc, tensor, APInt(64, dim)); +} + +Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, + Value tensor, unsigned dim) { + auto enc = getSparseTensorEncoding(tensor.getType()); + assert(enc && enc.isSlice()); + std::optional stride = enc.getStaticDimSliceStride(dim); + if (stride) // Statically known: fold to a constant. + return constantIndex(builder, loc, *stride); + return builder.create(loc, tensor, APInt(64, dim)); +} diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h @@ -190,6 +190,12 @@ Value genAddress(OpBuilder &builder, Location loc, size_t tid, size_t dim, Value iv); + /// Generates a predicate to determine whether the transformed coordinate is + /// on the given slice.
+ /// Returns std::pair + std::pair genSliceLegitPredicate(OpBuilder &builder, + Location loc, Value coord, + unsigned tid, unsigned lvl); bool isOutputTensor(size_t tid) { return hasOutput && tid == tensors.size() - 1; } @@ -249,8 +255,12 @@ bool isSparseOut; /// Input and (optional) output tensors. std::vector tensors; + /// Values related to slices. std::vector isSparseSlices; - /// The dim type array for each tensor. + std::vector> sliceOffsets; + std::vector> sliceStrides; + /// The dim and dim type array for each tensor. + std::vector> dims; std::vector> dimTypes; /// Sparse iteration information (by tensor and dim). These arrays /// are updated to remain current within the current loop. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp @@ -42,27 +42,21 @@ return load; } -// TODO: Support dynamic sized slice.
-static Value getSliceOffset(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc, unsigned lvl) { - return constantIndex(builder, loc, *enc.getStaticLvlSliceOffset(lvl)); +static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor, + unsigned lvl) { + auto enc = getSparseTensorEncoding(tensor.getType()); + return createOrFoldSliceOffsetOp(builder, loc, tensor, toOrigDim(enc, lvl)); } -static Value getSliceSize(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc, unsigned lvl) { - return constantIndex(builder, loc, *enc.getStaticLvlSliceSize(lvl)); -} - -static Value getSliceStride(OpBuilder &builder, Location loc, - SparseTensorEncodingAttr enc, unsigned lvl) { - return constantIndex(builder, loc, *enc.getStaticLvlSliceStride(lvl)); +static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor, + unsigned lvl) { + auto enc = getSparseTensorEncoding(tensor.getType()); + return createOrFoldSliceStrideOp(builder, loc, tensor, toOrigDim(enc, lvl)); } static Value toSliceCoord(OpBuilder &builder, Location loc, Value v, - SparseTensorEncodingAttr enc, unsigned lvl) { - - Value stride = getSliceStride(builder, loc, enc, lvl); - Value offset = getSliceOffset(builder, loc, enc, lvl); + Value offset, Value stride, Value tensor, + unsigned lvl) { // iv = iv * stride + offset v = builder.create(loc, v, stride); v = builder.create(loc, v, offset); @@ -70,39 +64,57 @@ } static std::pair fromSliceCoord(OpBuilder &builder, Location loc, - Value v, - SparseTensorEncodingAttr enc, + Value iv, Value offset, + Value stride, Value tensor, unsigned lvl) { - Value stride = getSliceStride(builder, loc, enc, lvl); - Value offset = getSliceOffset(builder, loc, enc, lvl); // iv = (iv - offset) / stride - v = builder.create(loc, v, offset); - Value rem = builder.create(loc, v, stride); - v = builder.create(loc, v, stride); - return std::make_pair(v, rem); + iv = builder.create(loc, iv, offset); + Value rem = builder.create(loc, iv, 
stride); + iv = builder.create(loc, iv, stride); + return std::make_pair(iv, rem); } -static std::pair -genSliceLegitPredicate(OpBuilder &builder, Location loc, Value coord, - SparseTensorEncodingAttr enc, unsigned lvl) { - std::pair trans = fromSliceCoord(builder, loc, coord, enc, lvl); - // 1st, coord >= offset (TODO: seems unsigned >= 0 won't be folded, skip - // the check if the offset is zero). - auto geOffset = - builder.create(loc, arith::CmpIPredicate::uge, coord, - getSliceOffset(builder, loc, enc, lvl)); +std::pair +LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, + Value coord, unsigned tid, unsigned lvl) { + assert(isSparseSlices[tid]); + Value slice = tensors[tid]; + Value offset = sliceOffsets[tid][lvl]; + Value stride = sliceStrides[tid][lvl]; + auto enc = getSparseTensorEncoding(slice.getType()); + + std::pair trans = + fromSliceCoord(builder, loc, coord, offset, stride, slice, lvl); + + SmallVector conds; // at most 3 conditions + + if (auto staticOffset = enc.getStaticLvlSliceOffset(lvl); + !(staticOffset.has_value() && *staticOffset == 0)) { + // 1st, coord >= offset (skip the check if offset is known to be 0). + auto geOffset = builder.create( + loc, arith::CmpIPredicate::uge, coord, offset); + conds.push_back(geOffset); + } + // 2nd, coord_in_slice < length - auto ltLength = - builder.create(loc, arith::CmpIPredicate::ult, trans.first, - getSliceSize(builder, loc, enc, lvl)); + auto ltLength = builder.create(loc, arith::CmpIPredicate::ult, + trans.first, dims[tid][lvl]); + conds.push_back(ltLength); + + if (auto staticStride = enc.getStaticLvlSliceStride(lvl); + !(staticStride.has_value() && *staticStride == 1)) { + // 3rd, rem == 0 (skip the check if stride is known to be 1). 
+ auto fitStride = builder.create( + loc, arith::CmpIPredicate::eq, trans.second, + constantIndex(builder, loc, 0)); + conds.push_back(fitStride); + } - // 3rd, rem == 0; confirmed that (a % 1) will be folded to 0 - auto fitStride = - builder.create(loc, arith::CmpIPredicate::eq, trans.second, - constantIndex(builder, loc, 0)); + // Must meet all condition to be a valid coordinate in slice. + auto pred = conds.front(); + for (unsigned i = 1, e = conds.size(); i < e; i++) + pred = builder.create(loc, pred, conds[i]); - auto pred = builder.create(loc, geOffset, ltLength); - pred = builder.create(loc, pred, fitStride); return {trans.first, pred}; } @@ -114,10 +126,9 @@ size_t dim, Value iv) { Value p = dim == 0 ? constantIndex(builder, loc, 0) : pidxs[tid][dim - 1]; Value mul = builder.create(loc, highs[tid][dim], p); - if (isSparseSlices[tid]) { - auto enc = getSparseTensorEncoding(tensors[tid].getType()); - iv = toSliceCoord(builder, loc, iv, enc, dim); - } + if (isSparseSlices[tid]) + iv = toSliceCoord(builder, loc, iv, sliceOffsets[tid][dim], + sliceStrides[tid][dim], tensors[tid], dim); Value add = builder.create(loc, mul, iv); return add; } @@ -136,6 +147,9 @@ this->isSparseOut = isSparseOut; this->tensors.assign(tensors.begin(), tensors.end()); this->isSparseSlices.assign(tensors.size(), false); + this->sliceOffsets.assign(tensors.size(), std::vector()); + this->sliceStrides.assign(tensors.size(), std::vector()); + this->dims.assign(tensors.size(), std::vector()); this->dimTypes.assign(tensors.size(), std::vector()); this->pidxs.assign(tensors.size(), std::vector()); this->coord.assign(tensors.size(), std::vector()); @@ -164,6 +178,9 @@ dimTypes[tid].assign(rank, DimLevelType::Dense); // Initialize using empty value. 
+ sliceOffsets[tid].assign(rank, Value()); + sliceStrides[tid].assign(rank, Value()); + dims[tid].assign(rank, Value()); pidxs[tid].assign(rank, Value()); coord[tid].assign(rank, Value()); highs[tid].assign(rank, Value()); @@ -174,8 +191,8 @@ // FIXME: This map should be maintained outside loop emitter. for (unsigned i = 0, e = topSort.size(); i < e; i++) { // This is an inverse map of the topologically sorted loop index from - // sparsifier. This is needed to map the AffineDimExpr back to the loopStack - // index used in loop emitter. + // sparsifier. This is needed to map the AffineDimExpr back to the + // loopStack index used in loop emitter. sparsiferLoopLvlMap[topSort[i]] = i; } } @@ -215,13 +232,16 @@ // Find upper bound in current dimension. unsigned p = toOrigDim(enc, d); Value up = mlir::linalg::createOrFoldDimOp(builder, loc, tensor, p); - highs[t][d] = up; + highs[t][d] = dims[t][d] = up; + if (isSparseSlices[t]) { + sliceOffsets[t][d] = genSliceOffset(builder, loc, tensors[t], d); + sliceStrides[t][d] = genSliceStride(builder, loc, tensors[t], d); + } } // Perform the required bufferization. Dense inputs materialize - // from the input tensors. Sparse inputs use sparse primitives to obtain the - // values. - // Delegates extra output initialization to clients. + // from the input tensors. Sparse inputs use sparse primitives to obtain + // the values. Delegates extra output initialization to clients. bool isOutput = isOutputTensor(t); Type elementType = rtp.getElementType(); if (!enc) { @@ -236,14 +256,15 @@ valBuffer[t] = denseVal; } else { // Annotated sparse tensors. - // We also need the value buffer for annotated all dense `sparse` tensor. + // We also need the value buffer for annotated all dense `sparse` + // tensor. 
auto dynShape = {ShapedType::kDynamic}; auto sparseTp = MemRefType::get(dynShape, elementType); valBuffer[t] = builder.create(loc, sparseTp, tensor); } // NOTE: we can also prepare for 0 dim here in advance, this will hosit - // some loop preparation from tensor iteration, but will also (undesirably) - // hosit the code ouside if conditions. + // some loop preparation from tensor iteration, but will also + // (undesirably) hosit the code ouside if conditions. } } @@ -301,8 +322,8 @@ assert(isDenseDLT(dimType) || isCompressedDLT(dimType) || isSingletonDLT(dimType)); bool isSparse = isCompressedDLT(dimType) || isSingletonDLT(dimType); - // We can at most have one sparse input, otherwise, a while loop is required - // to co-iterate multiple sparse tensors. + // We can at most have one sparse input, otherwise, a while loop is + // required to co-iterate multiple sparse tensors. assert(!isSparseInput || !isSparse); if (isSparse) { tid = t; @@ -311,7 +332,6 @@ isSparseInput = isSparseInput || isSparse; } - auto enc = getSparseTensorEncoding(tensors[tid].getType()); // TODO: support dynamic slices. Value step = constantIndex(builder, loc, 1); Value lo = isSparseInput ? pidxs[tid][dim] // current offset @@ -328,12 +348,12 @@ iv = parOp.getInductionVars()[0]; // In-place update on the reduction variable vector. - // Note that the init vals is not the actual reduction variables but instead - // used as a `special handle` to (temporarily) represent them. The + // Note that the init vals is not the actual reduction variables but + // instead used as a `special handle` to (temporarily) represent them. The // expression on init vals will be moved into scf.reduce and replaced with // the block arguments when exiting the loop (see exitForLoop). This is - // needed as we can not build the actual reduction block and get the actual - // reduction varaible before users fill parallel loop body. 
+ // needed as we can not build the actual reduction block and get the + // actual reduction variable before users fill parallel loop body. for (int i = 0, e = reduc.size(); i < e; i++) reduc[i] = parOp.getInitVals()[i]; loop = parOp; @@ -368,7 +388,7 @@ for (Value red : reduc) types.push_back(red.getType()); - auto [trans, pred] = genSliceLegitPredicate(builder, loc, c, enc, dim); + auto [trans, pred] = genSliceLegitPredicate(builder, loc, c, tid, dim); bool hasReduc = !types.empty(); scf::IfOp ifOp = builder.create(loc, types, pred, /*else*/ hasReduc); @@ -416,8 +436,9 @@ // TODO: We should instead use a whileOp for filter loop to allow early // break when exceeding (for ordered dimensions). - // TODO: There are many other potiential opportunities that we might apply in - // the future. E.g., we could use binary search to located the pointer index. + // TODO: There are many other potential opportunities that we might apply + // in the future. E.g., we could use binary search to locate the pointer + // index. scf::ForOp forOp = builder.create(loc, lo, hi, step, reduc); // In-place update on the reduction variable vector. @@ -542,9 +563,8 @@ Value load = genIndexLoad(builder, loc, ptr, s); coord[tid][dim] = load; if (isSparseSlices[tid]) { - auto enc = getSparseTensorEncoding(tensors[tid].getType()); auto [trans, pred] = - genSliceLegitPredicate(builder, loc, load, enc, dim); + genSliceLegitPredicate(builder, loc, load, tid, dim); slicesPreds.emplace_back(pred, i); // Updates to the relative coordinate to the slice. coord[tid][dim] = trans; @@ -559,7 +579,7 @@ // Generates a list of if statments // pidx = in_slice ? pidx : pidx + 1 // TODO: instead of always picking pidx + 1, we should set pidx = high to - // break to loop the coordinates is larger than the slice size. + // break the loop if the coordinate is larger than the slice size.
for (auto [pred, idx] : slicesPreds) { Value nextPidx = builder.create( loc, yields[idx], constantIndex(builder, loc, 1)); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp @@ -20,18 +20,28 @@ // Helper methods. //===----------------------------------------------------------------------===// -static SmallVector getSpecifierFields(StorageSpecifierType tp) { +static SmallVector getSpecifierFields(StorageSpecifierType tp) { MLIRContext *ctx = tp.getContext(); auto enc = tp.getEncoding(); unsigned rank = enc.getDimLevelType().size(); - SmallVector result; + SmallVector result; auto indexType = tp.getSizesType(); auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, rank); auto memSizes = LLVM::LLVMArrayType::get(ctx, indexType, getNumDataFieldsFromEncoding(enc)); result.push_back(dimSizes); result.push_back(memSizes); + + if (enc.isSlice()) { + // Extra fields are required for the slice information + auto dimOffset = LLVM::LLVMArrayType::get(ctx, indexType, rank); + auto dimStride = LLVM::LLVMArrayType::get(ctx, indexType, rank); + + result.push_back(dimOffset); + result.push_back(dimStride); + } + return result; } @@ -46,6 +56,8 @@ constexpr uint64_t kDimSizePosInSpecifier = 0; constexpr uint64_t kMemSizePosInSpecifier = 1; +constexpr uint64_t kDimOffsetPosInSpecifier = 2; +constexpr uint64_t kDimStridePosInSpecifier = 3; class SpecifierStructBuilder : public StructBuilder { public: @@ -54,34 +66,63 @@ } // Undef value for dimension sizes, all zero value for memory sizes. 
- static Value getInitValue(OpBuilder &builder, Location loc, Type structType); + static Value getInitValue(OpBuilder &builder, Location loc, Type structType, + Value source); + + Value dimOffset(OpBuilder &builder, Location loc, unsigned dim) const; + void setDimOffset(OpBuilder &builder, Location loc, unsigned dim, Value size); - Value dimSize(OpBuilder &builder, Location loc, unsigned dim); + Value dimSize(OpBuilder &builder, Location loc, unsigned dim) const; void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value size); - Value memSize(OpBuilder &builder, Location loc, unsigned pos); + Value dimStride(OpBuilder &builder, Location loc, unsigned dim) const; + void setDimStride(OpBuilder &builder, Location loc, unsigned dim, Value size); + + Value memSize(OpBuilder &builder, Location loc, unsigned pos) const; void setMemSize(OpBuilder &builder, Location loc, unsigned pos, Value size); + + Value memSizeArray(OpBuilder &builder, Location loc) const; + void setMemSizeArray(OpBuilder &builder, Location loc, Value array); }; Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc, - Type structType) { + Type structType, Value source) { Value metaData = builder.create(loc, structType); SpecifierStructBuilder md(metaData); - auto memSizeArrayType = structType.cast() - .getBody()[kMemSizePosInSpecifier] - .cast(); + if (!source) { + auto memSizeArrayType = structType.cast() + .getBody()[kMemSizePosInSpecifier] + .cast(); + + Value zero = constantZero(builder, loc, memSizeArrayType.getElementType()); + // Fill memSizes array with zero. 
+ for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++) + md.setMemSize(builder, loc, i, zero); + } else { + // We copy the non-slice information (memory sizes array) from source. + SpecifierStructBuilder sourceMd(source); + md.setMemSizeArray(builder, loc, sourceMd.memSizeArray(builder, loc)); + } + return md; +} - Value zero = constantZero(builder, loc, memSizeArrayType.getElementType()); - // Fill memSizes array with zero. - for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++) - md.setMemSize(builder, loc, i, zero); +/// Builds IR extracting the dim-th offset from the descriptor. +Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc, + unsigned dim) const { + return builder.create( + loc, value, ArrayRef({kDimOffsetPosInSpecifier, dim})); +} - return md; +/// Builds IR inserting the dim-th offset into the descriptor. +void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc, + unsigned dim, Value size) { + value = builder.create( + loc, value, size, ArrayRef({kDimOffsetPosInSpecifier, dim})); +} /// Builds IR inserting the pos-th size into the descriptor. Value SpecifierStructBuilder::dimSize(OpBuilder &builder, Location loc, - unsigned dim) { + unsigned dim) const { return builder.create( loc, value, ArrayRef({kDimSizePosInSpecifier, dim})); } @@ -93,9 +134,23 @@ loc, value, size, ArrayRef({kDimSizePosInSpecifier, dim})); } +/// Builds IR extracting the dim-th stride from the descriptor. +Value SpecifierStructBuilder::dimStride(OpBuilder &builder, Location loc, + unsigned dim) const { + return builder.create( + loc, value, ArrayRef({kDimStridePosInSpecifier, dim})); +} + +/// Builds IR inserting the dim-th stride into the descriptor. +void SpecifierStructBuilder::setDimStride(OpBuilder &builder, Location loc, + unsigned dim, Value size) { + value = builder.create( + loc, value, size, ArrayRef({kDimStridePosInSpecifier, dim})); +} + /// Builds IR extracting the pos-th memory size into the descriptor.
Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc, - unsigned pos) { + unsigned pos) const { return builder.create( loc, value, ArrayRef({kMemSizePosInSpecifier, pos})); } @@ -107,6 +162,20 @@ loc, value, size, ArrayRef({kMemSizePosInSpecifier, pos})); } +/// Builds IR extracting the memory size array from the descriptor. +Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder, + Location loc) const { + return builder.create(loc, value, + kMemSizePosInSpecifier); +} + +/// Builds IR inserting the memory size array into the descriptor. +void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc, + Value array) { + value = builder.create(loc, value, array, + kMemSizePosInSpecifier); +} + } // namespace //===----------------------------------------------------------------------===// @@ -132,22 +201,40 @@ matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SpecifierStructBuilder spec(adaptor.getSpecifier()); - Value v; - if (op.getSpecifierKind() == StorageSpecifierKind::DimSize) { - v = Base::onDimSize(rewriter, op, spec, - op.getDim().value().getZExtValue()); - } else { + switch (op.getSpecifierKind()) { + case StorageSpecifierKind::DimOffset: { + Value v = + Base::onDimOffset(rewriter, op, spec, (*op.getDim()).getZExtValue()); + rewriter.replaceOp(op, v); + return success(); + } + case StorageSpecifierKind::DimSize: { + Value v = + Base::onDimSize(rewriter, op, spec, (*op.getDim()).getZExtValue()); + rewriter.replaceOp(op, v); + return success(); + } + case StorageSpecifierKind::DimStride: { + Value v = + Base::onDimStride(rewriter, op, spec, (*op.getDim()).getZExtValue()); + rewriter.replaceOp(op, v); + return success(); + } + case StorageSpecifierKind::IdxMemSize: + case StorageSpecifierKind::PtrMemSize: + case StorageSpecifierKind::ValMemSize: { auto enc = op.getSpecifier().getType().getEncoding(); StorageLayout layout(enc); Optional dim = std::nullopt; if 
(op.getDim()) - dim = op.getDim().value().getZExtValue(); + dim = (*op.getDim()).getZExtValue(); unsigned idx = layout.getMemRefFieldIndex(op.getSpecifierKind(), dim); - v = Base::onMemSize(rewriter, op, spec, idx); + Value v = Base::onMemSize(rewriter, op, spec, idx); + rewriter.replaceOp(op, v); + return success(); } - - rewriter.replaceOp(op, v); - return success(); + } + llvm_unreachable("unrecognized specifer kind"); } }; @@ -155,12 +242,25 @@ : public SpecifierGetterSetterOpConverter { using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter; + + static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op, + SpecifierStructBuilder &spec, unsigned d) { + spec.setDimOffset(builder, op.getLoc(), d, op.getValue()); + return spec; + } + static Value onDimSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, unsigned d) { spec.setDimSize(builder, op.getLoc(), d, op.getValue()); return spec; } + static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op, + SpecifierStructBuilder &spec, unsigned d) { + spec.setDimStride(builder, op.getLoc(), d, op.getValue()); + return spec; + } + static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op, SpecifierStructBuilder &spec, unsigned i) { spec.setMemSize(builder, op.getLoc(), i, op.getValue()); @@ -172,12 +272,24 @@ : public SpecifierGetterSetterOpConverter { using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter; + + static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op, + const SpecifierStructBuilder &spec, unsigned d) { + return spec.dimOffset(builder, op.getLoc(), d); + } + static Value onDimSize(OpBuilder &builder, GetStorageSpecifierOp op, - SpecifierStructBuilder &spec, unsigned d) { + const SpecifierStructBuilder &spec, unsigned d) { return spec.dimSize(builder, op.getLoc(), d); } + + static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op, + const SpecifierStructBuilder &spec, unsigned d) { + 
return spec.dimStride(builder, op.getLoc(), d); + } + static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op, - SpecifierStructBuilder &spec, unsigned i) { + const SpecifierStructBuilder &spec, unsigned i) { return spec.memSize(builder, op.getLoc(), i); } }; @@ -190,8 +302,9 @@ matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type llvmType = getTypeConverter()->convertType(op.getResult().getType()); - rewriter.replaceOp(op, SpecifierStructBuilder::getInitValue( - rewriter, op.getLoc(), llvmType)); + rewriter.replaceOp( + op, SpecifierStructBuilder::getInitValue( + rewriter, op.getLoc(), llvmType, adaptor.getSource())); return success(); } }; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -100,6 +100,18 @@ return forOp; } +/// Returns a value from the mixed slice info (either a attribute for statically +/// known value or a dynamic value). Creates a constantOp if needed. +static Value valueFromMixedSliceInfo(OpBuilder &builder, Location loc, + OpFoldResult ofr) { + Value v = ofr.dyn_cast(); + if (!v) { + auto attr = ofr.get().cast(); + v = builder.create(loc, attr.getInt()); + } + return v; +} + /// Gets the dimension size for the given sparse tensor at the given /// original dimension 'dim'. Returns std::nullopt if no sparse encoding is /// attached to the given tensor type. @@ -648,6 +660,23 @@ } }; +template +class SparseSliceGetterOpConver : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Simply lowers to specifer.get xxx. 
+ auto desc = getDescriptorFromTensorTuple(adaptor.getSlice()); + auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind, + op.getDim().getZExtValue()); + + rewriter.replaceOp(op, v); + return success(); + } +}; + /// Sparse codegen rule for trivial tensor casts. class SparseCastConverter : public OpConversionPattern { public: @@ -919,9 +948,9 @@ uint64_t dim = op.getDimension().getZExtValue(); Value field = desc.getIdxMemRefOrView(rewriter, loc, dim); - // Insert a cast to bridge the actual type to the user expected type. If the - // actual type and the user expected type aren't compatible, the compiler or - // the runtime will issue an error. + // Insert a cast to bridge the actual type to the user expected type. If + // the actual type and the user expected type aren't compatible, the + // compiler or the runtime will issue an error. Type resType = op.getResult().getType(); if (resType != field.getType()) field = rewriter.create(loc, resType, field); @@ -994,12 +1023,14 @@ LogicalResult matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *ctx = op.getContext(); auto srcEnc = getSparseTensorEncoding(op.getSourceType()); auto dstEnc = getSparseTensorEncoding(op.getResult().getType()); if (!srcEnc && !dstEnc) return failure(); - // We should probably check these in ExtractSliceOp::verify. + // FIXME: We should probably check these in ExtractSliceOp::verify. assert(srcEnc && dstEnc && dstEnc.isSlice()); assert(srcEnc.getDimLevelType() == dstEnc.getDimLevelType()); assert(srcEnc.getDimOrdering() == dstEnc.getDimOrdering()); @@ -1007,16 +1038,37 @@ assert(srcEnc.getPointerBitWidth() == dstEnc.getPointerBitWidth()); assert(srcEnc.getIndexBitWidth() == dstEnc.getIndexBitWidth()); - // TODO: support dynamic slices. 
- for (int i = 0, e = op.getSourceType().getRank(); i < e; i++) { - assert(op.getStaticStrides()[i] == dstEnc.getStaticDimSliceStride(i)); - assert(op.getStaticOffsets()[i] == dstEnc.getStaticDimSliceOffset(i)); - assert(op.getStaticSizes()[i] == dstEnc.getStaticDimSliceSize(i)); + SmallVector fields; + auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields); + + auto newSpec = rewriter.create( + loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier()); + desc.setSpecifier(newSpec); + + // Fills in slice information. + unsigned dim = 0; + for (auto [offset, size, stride] : llvm::zip( + op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) { + Value offsetV = valueFromMixedSliceInfo(rewriter, loc, offset); + Value sizeV = valueFromMixedSliceInfo(rewriter, loc, size); + Value strideV = valueFromMixedSliceInfo(rewriter, loc, stride); + // TODO: We could probably only set dynamic value here. But it would + // requires us to fill the hole when casting a static slice to dynamic + // slice. + desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset, + dim, offsetV); + desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimSize, dim, + sizeV); + desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride, + dim, strideV); + dim++; } - // TODO: create a new specifer for slices (need to encode slice metadata). - // It does not matter now because only constant offset/stride are allowed. - rewriter.replaceOp(op, adaptor.getSource()); + // NOTE: we can not generates tuples directly from descriptor here, as the + // descriptor is holding the original type, yet we want the slice type + // here (they shared every memref but with an updated specifier). 
+ rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(), + desc.getFields())); return success(); } }; @@ -1051,11 +1103,15 @@ SparseCastConverter, SparseTensorDeallocConverter, SparseExtractSliceCoverter, SparseTensorLoadConverter, SparseExpandConverter, SparseCompressConverter, - SparseInsertConverter, SparseToPointersConverter, - SparseToIndicesConverter, SparseToIndicesBufferConverter, - SparseToValuesConverter, SparseConvertConverter, - SparseNumberOfEntriesConverter>(typeConverter, - patterns.getContext()); + SparseInsertConverter, + SparseSliceGetterOpConver, + SparseSliceGetterOpConver, + SparseToPointersConverter, SparseToIndicesConverter, + SparseToIndicesBufferConverter, SparseToValuesConverter, + SparseConvertConverter, SparseNumberOfEntriesConverter>( + typeConverter, patterns.getContext()); patterns.add(typeConverter, patterns.getContext(), enableBufferInitialization); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h @@ -241,6 +241,8 @@ /// Getters: get the value for required field. 
/// + Value getSpecifier() const { return fields.back(); } + Value getSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional dim) const { @@ -337,6 +339,8 @@ fields[fidx] = v; } + void setSpecifier(Value newSpec) { fields.back() = newSpec; } + void setSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional dim, Value v) { diff --git a/mlir/test/Dialect/SparseTensor/sparse_extract_slice.mlir b/mlir/test/Dialect/SparseTensor/sparse_extract_slice.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/sparse_extract_slice.mlir @@ -0,0 +1,39 @@ +// RUN: mlir-opt %s --sparse-tensor-codegen --cse | FileCheck %s + +#CSR = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ] +}> + +#CSR_SLICE = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + slice = [ (0, 4, 1), (0, 8, 1) ] +}> + +// +// CHECK-LABEL: func.func @sparse_slice( +// CHECK-SAME: %[[VAL_0:.*0]]: memref, +// CHECK-SAME: %[[VAL_1:.*1]]: memref, +// CHECK-SAME: %[[VAL_2:.*2]]: memref, +// CHECK-SAME: %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier +// CHECK-DAG: %[[VAL_4:.*]] = sparse_tensor.storage_specifier.init with %[[VAL_3]] : !sparse_tensor.storage_specifier +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_8:.*]] = arith.index_cast %[[VAL_5]] : index to i64 +// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.storage_specifier.set %[[VAL_4]] dim_offset at 0 with %[[VAL_8]] : i64, !sparse_tensor.storage_specifier +// CHECK-DAG: %[[VAL_10:.*]] = arith.index_cast %[[VAL_6]] : index to i64 +// CHECK-DAG: %[[VAL_11:.*]] = sparse_tensor.storage_specifier.set %[[VAL_9]] dim_sz at 0 with %[[VAL_10]] : i64, !sparse_tensor.storage_specifier +// CHECK-DAG: %[[VAL_12:.*]] = arith.index_cast %[[VAL_7]] : index to i64 +// CHECK-DAG: %[[VAL_13:.*]] = 
sparse_tensor.storage_specifier.set %[[VAL_11]] dim_stride at 0 with %[[VAL_12]] : i64, !sparse_tensor.storage_specifier +// CHECK-DAG: %[[VAL_14:.*]] = arith.constant 8 : index +// CHECK-DAG: %[[VAL_15:.*]] = sparse_tensor.storage_specifier.set %[[VAL_13]] dim_offset at 1 with %[[VAL_8]] : i64, !sparse_tensor.storage_specifier +// CHECK-DAG: %[[VAL_16:.*]] = arith.index_cast %[[VAL_14]] : index to i64 +// CHECK-DAG: %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_15]] dim_sz at 1 with %[[VAL_16]] : i64, !sparse_tensor.storage_specifier +// CHECK-DAG: %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]] dim_stride at 1 with %[[VAL_12]] : i64, !sparse_tensor.storage_specifier +// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_18]] : memref, memref, memref, !sparse_tensor.storage_specifier +// +func.func @sparse_slice(%t1 : tensor<8x8xf64, #CSR>) -> tensor<4x8xf64, #CSR_SLICE> { + %a1 = tensor.extract_slice %t1[0, 0][4, 8][1, 1] : tensor<8x8xf64, #CSR> to + tensor<4x8xf64, #CSR_SLICE> + return %a1 : tensor<4x8xf64, #CSR_SLICE> +} diff --git a/mlir/test/Dialect/SparseTensor/sparse_slice_foreach.mlir b/mlir/test/Dialect/SparseTensor/sparse_slice_foreach.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/sparse_slice_foreach.mlir @@ -0,0 +1,113 @@ +// RUN: mlir-opt %s --post-sparsification-rewrite -allow-unregistered-dialect --canonicalize --cse | FileCheck %s + +#CSR_SLICE = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed", "compressed" ], + slice = [ (0, 4, 1), (2, 4, 1) ] +}> + +#CSR_SLICE_DYN = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed", "compressed" ], + slice = [ (?, ?, ?), (?, ?, ?) 
] +}> + + +// CHECK-LABEL: func.func @foreach_print_slice_dyn( +// CHECK-SAME: %[[VAL_0:.*]]: tensor +// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_2]]] : memref +// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_2]] { +// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_16]]] : memref +// CHECK: %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_6]] : index +// CHECK: %[[VAL_19:.*]] = arith.remui %[[VAL_18]], %[[VAL_7]] : index +// CHECK: %[[VAL_20:.*]] = arith.divui %[[VAL_18]], %[[VAL_7]] : index +// CHECK: %[[VAL_21:.*]] = arith.cmpi uge, %[[VAL_17]], %[[VAL_6]] : index +// CHECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_5]] : index +// CHECK: %[[VAL_23:.*]] = arith.cmpi eq, %[[VAL_19]], %[[VAL_1]] : index +// CHECK: %[[VAL_24:.*]] = arith.andi %[[VAL_21]], %[[VAL_22]] : i1 +// CHECK: %[[VAL_25:.*]] = arith.andi %[[VAL_24]], %[[VAL_23]] : i1 +// CHECK: scf.if %[[VAL_25]] { +// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref +// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_16]], %[[VAL_2]] : index +// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_27]]] : memref +// CHECK: scf.for %[[VAL_29:.*]] = %[[VAL_26]] to %[[VAL_28]] step %[[VAL_2]] { +// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_29]]] : memref +// CHECK: %[[VAL_31:.*]] = arith.subi %[[VAL_30]], %[[VAL_11]] : index +// CHECK: %[[VAL_32:.*]] = arith.remui %[[VAL_31]], %[[VAL_12]] : index +// CHECK: %[[VAL_33:.*]] = arith.divui %[[VAL_31]], %[[VAL_12]] : index +// CHECK: %[[VAL_34:.*]] = arith.cmpi uge, %[[VAL_30]], %[[VAL_11]] : index +// CHECK: %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_33]], %[[VAL_10]] : index +// CHECK: %[[VAL_36:.*]] = arith.cmpi eq, %[[VAL_32]], %[[VAL_1]] : index +// CHECK: %[[VAL_37:.*]] = arith.andi %[[VAL_34]], %[[VAL_35]] : i1 +// CHECK: %[[VAL_38:.*]] = arith.andi %[[VAL_37]], %[[VAL_36]] : i1 +// CHECK: scf.if %[[VAL_38]] { +// CHECK: %[[VAL_39:.*]] = memref.load 
%[[VAL_13]]{{\[}}%[[VAL_29]]] : memref +// CHECK: "use"(%[[VAL_39]]) : (f64) -> () +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: return +// +func.func @foreach_print_slice_dyn(%A: tensor) { + sparse_tensor.foreach in %A : tensor do { + ^bb0(%1: index, %2: index, %v: f64) : + "use" (%v) : (f64) -> () + } + return +} + +// CHECK-LABEL: func.func @foreach_print_slice( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4xf64, +// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 0 : index} : tensor<4x4xf64, +// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 0 : index} : tensor<4x4xf64, +// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<4x4xf64, +// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<4x4xf64, +// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf64, +// CHECK-DAG: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref +// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] { +// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref +// CHECK: %[[VAL_14:.*]] = arith.cmpi ult, %[[VAL_13]], %[[VAL_1]] : index +// CHECK: scf.if %[[VAL_14]] { +// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref +// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_12]], %[[VAL_4]] : index +// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref +// CHECK: scf.for %[[VAL_18:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_4]] { +// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref +// CHECK: %[[VAL_20:.*]] 
= arith.subi %[[VAL_19]], %[[VAL_2]] : index +// CHECK: %[[VAL_21:.*]] = arith.cmpi uge, %[[VAL_19]], %[[VAL_2]] : index +// CHECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_1]] : index +// CHECK: %[[VAL_23:.*]] = arith.andi %[[VAL_21]], %[[VAL_22]] : i1 +// CHECK: scf.if %[[VAL_23]] { +// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref +// CHECK: "use"(%[[VAL_24]]) : (f64) -> () +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: return +// +func.func @foreach_print_slice(%A: tensor<4x4xf64, #CSR_SLICE>) { + sparse_tensor.foreach in %A : tensor<4x4xf64, #CSR_SLICE> do { + ^bb0(%1: index, %2: index, %v: f64) : + "use" (%v) : (f64) -> () + } + return +} \ No newline at end of file diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir @@ -18,6 +18,11 @@ slice = [ (1, 4, 1), (1, 4, 2) ] }> +#CSR_SLICE_DYN = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + slice = [ (?, ?, ?), (?, ?, ?) 
] +}> + module { func.func @foreach_print_non_slice(%A: tensor<4x4xf64, #CSR>) { sparse_tensor.foreach in %A : tensor<4x4xf64, #CSR> do { @@ -39,8 +44,22 @@ return } + func.func @foreach_print_slice_dyn(%A: tensor) { + sparse_tensor.foreach in %A : tensor do { + ^bb0(%1: index, %2: index, %v: f64) : + vector.print %1: index + vector.print %2: index + vector.print %v: f64 + } + return + } + func.func @entry() { %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %sa = arith.constant dense<[ [ 0.0, 2.1, 0.0, 0.0, 0.0, 6.1, 0.0, 0.0 ], [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ], @@ -73,6 +92,25 @@ // call @foreach_print_slice(%a) : (tensor<4x4xf64, #CSR_SLICE>) -> () + // The same slice, but with dynamic encoding. + %a_dyn = tensor.extract_slice %tmp[%c1, %c1][%c4, %c4][%c1, %c2] : + tensor<8x8xf64, #CSR> to tensor + // + // CHECK: 1 + // CHECK-NEXT: 0 + // CHECK-NEXT: 2.3 + // CHECK-NEXT: 2 + // CHECK-NEXT: 3 + // CHECK-NEXT: 1 + // CHECK-NEXT: 3 + // CHECK-NEXT: 0 + // CHECK-NEXT: 2.1 + // CHECK-NEXT: 3 + // CHECK-NEXT: 2 + // CHECK-NEXT: 6.1 + // + call @foreach_print_slice_dyn(%a_dyn) : (tensor) -> () + %dense = tensor.extract_slice %sa[1, 1][4, 4][1, 2] : tensor<8x8xf64> to tensor<4x4xf64> %b = sparse_tensor.convert %dense : tensor<4x4xf64> to tensor<4x4xf64, #CSR> diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir @@ -38,10 +38,32 @@ slice = [ (0, 4, 2), (1, 4, 1) ] }> +#CSR_SLICE_dyn = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + slice = [ (?, 4, ?), (?, 4, ?) ] +}> + +#DCSR_SLICE_dyn = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed", "compressed" ], + slice = [ (?, 4, ?), (?, 4, ?) 
] +}> + + module { func.func private @printMemrefF64(%ptr : tensor<*xf64>) func.func private @printMemref1dF64(%ptr : memref) attributes { llvm.emit_c_interface } + // + // Computes C = A x B with all matrices dynamic sparse slice (SpMSpM) in CSR and DCSR + // + func.func @matmul_dyn(%A: tensor<4x4xf64, #CSR_SLICE_dyn>, + %B: tensor<4x4xf64, #DCSR_SLICE_dyn>) -> tensor<4x4xf64, #CSR> { + %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR> + %D = linalg.matmul + ins(%A, %B: tensor<4x4xf64, #CSR_SLICE_dyn>, tensor<4x4xf64, #DCSR_SLICE_dyn>) + outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> + return %D: tensor<4x4xf64, #CSR> + } // // Computes C = A x B with all matrices sparse slice (SpMSpM) in CSR and DCSR @@ -80,7 +102,9 @@ // Main driver. // func.func @entry() { - %c0 = arith.constant 0 : index + %c_0 = arith.constant 0 : index + %c_1 = arith.constant 1 : index + %c_2 = arith.constant 2 : index %f0 = arith.constant 0.0 : f64 %sa = arith.constant dense<[ @@ -155,11 +179,27 @@ %4 = call @matmul1(%s2, %s1) : (tensor<4x4xf64, #CSR_SLICE_1>, tensor<4x4xf64, #DCSR_SLICE_1>) -> tensor<4x4xf64, #CSR> - %c4 = sparse_tensor.convert %4 : tensor<4x4xf64, #CSR> to tensor<4x4xf64> %c4u = tensor.cast %c4 : tensor<4x4xf64> to tensor<*xf64> call @printMemrefF64(%c4u) : (tensor<*xf64>) -> () + // slice x slice (same as above, but with dynamic stride information) + // + // CHECK: [2.3, 0, 0, 0], + // CHECK-NEXT: [6.9, 0, 0, 0], + // CHECK-NEXT: [0, 0, 0, 0], + // CHECK-NEXT: [12.6, 0, 0, 0]] + // + %s1_dyn = tensor.extract_slice %tmp[%c_0, %c_1][4, 4][%c_2, %c_1] : tensor<8x8xf64, #DCSR> to tensor<4x4xf64, #DCSR_SLICE_dyn> + %s2_dyn = tensor.extract_slice %b1[%c_0, %c_0][4, 4][%c_2, %c_1] : tensor<8x4xf64, #CSR> to tensor<4x4xf64, #CSR_SLICE_dyn> + %dyn_4 = call @matmul_dyn(%s2_dyn, %s1_dyn) + : (tensor<4x4xf64, #CSR_SLICE_dyn>, + tensor<4x4xf64, #DCSR_SLICE_dyn>) -> tensor<4x4xf64, #CSR> + + %c4_dyn = sparse_tensor.convert %dyn_4 : tensor<4x4xf64, #CSR> to 
tensor<4x4xf64> + %c4u_dyn = tensor.cast %c4_dyn : tensor<4x4xf64> to tensor<*xf64> + call @printMemrefF64(%c4u_dyn) : (tensor<*xf64>) -> () + // sparse slices should generate the same result as dense slices // // CHECK: [2.3, 0, 0, 0],