diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -570,7 +570,7 @@
/// We normalized sparse tensor encoding attribute by always using
/// ordered/unique DLT such that "compressed-nu-no" and "compressed-nu" (as well
/// as other variants) lead to the same storage specifier type, and stripping
-/// irrelevant fields that does not alter the sparse tensor memory layout.
+/// irrelevant fields that do not alter the sparse tensor memory layout.
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
  SmallVector<DimLevelType> dlts;
@@ -582,13 +582,10 @@
      AffineMap(), // dimOrdering (irrelavant to storage speicifer)
      AffineMap(), // highLvlOrdering (irrelavant to storage specifer)
      // Always use `index` for memSize and lvlSize instead of reusing
-      // `getPosWidth`/`getCrdWidth`.
-      // It allows us to reuse the same SSA value for different bitwidth,
-      // It also avoids casting between index/integer (returned by DimOp)
-      0, 0,
-      // FIXME: we should keep the slice information, for now it is okay as only
-      // constant can be used for slice
-      ArrayRef<SparseTensorDimSliceAttr>{} /*enc.getDimSlices()*/);
+      // `getPosWidth` and `getCrdWidth`. This allows us to reuse the same SSA
+      // value for different bitwidths, and it avoids casting between index and
+      // integer (returned by DimOp).
+      0, 0, enc.getDimSlices());
}

StorageSpecifierType
@@ -620,11 +617,10 @@
  const auto enc = md.getType().getEncoding();
  const Level lvlRank = enc.getLvlRank();

-  // TODO:
-  // if (mdKind == StorageSpecifierKind::DimOffset ||
-  //     mdKind == StorageSpecifierKind::DimStride)
-  //   if (!enc.isSlice())
-  //     return op->emitError("requested slice data on non-slice tensor");
+  if (mdKind == StorageSpecifierKind::DimOffset ||
+      mdKind == StorageSpecifierKind::DimStride)
+    if (!enc.isSlice())
+      return op->emitError("requested slice data on non-slice tensor");

  if (mdKind != StorageSpecifierKind::ValMemSize) {
    if (!lvl)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -364,6 +364,15 @@
/// Generates code to retrieve the values size for the sparse tensor.
Value genValMemSize(OpBuilder &builder, Location loc, Value tensor);

+/// Generates code to retrieve the slice offset for the sparse tensor slice,
+/// returning a constant if the offset is statically known.
+Value createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc, Value tensor,
+                                unsigned dim);
+
+/// Generates code to retrieve the slice stride for the sparse tensor slice,
+/// returning a constant if the stride is statically known.
+Value createOrFoldSliceStrideOp(OpBuilder &builder, Location loc, Value tensor,
+                                unsigned dim);
} // namespace sparse_tensor
} // namespace mlir
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -694,3 +694,23 @@
                                      Value tensor) {
  return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
}
+
+Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
+                                               Value tensor, unsigned dim) {
+  auto enc = getSparseTensorEncoding(tensor.getType());
+  assert(enc && enc.isSlice());
+  std::optional<int64_t> offset = enc.getStaticDimSliceOffset(dim);
+  if (offset.has_value())
+    return constantIndex(builder, loc, *offset);
+  return builder.create<ToSliceOffsetOp>(loc, tensor, APInt(64, dim));
+}
+
+Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
+                                               Value tensor, unsigned dim) {
+  auto enc = getSparseTensorEncoding(tensor.getType());
+  assert(enc && enc.isSlice());
+  std::optional<int64_t> stride = enc.getStaticDimSliceStride(dim);
+  if (stride.has_value())
+    return constantIndex(builder, loc, *stride);
+  return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -202,6 +202,13 @@
  Value genSparseCrd(OpBuilder &builder, Location loc, size_t tid,
                     size_t dstLvl);
+  /// Generates a predicate to determine whether the transformed coordinate is
+  /// within the given slice. Returns a pair of the transformed coordinate and
+  /// the predicate (std::pair<Value, Value>).
+  std::pair<Value, Value> genSliceLegitPredicate(OpBuilder &builder,
+                                                 Location loc, Value coord,
+                                                 unsigned tid, unsigned lvl);
+
  bool isOutputTensor(size_t tid) {
    return hasOutput && tid == tensors.size() - 1;
  }
@@ -278,6 +285,9 @@
  /// Whether the sparse input is a slice.
  std::vector<bool> isSparseSlices;
+  /// Values related to slices.
+  std::vector<std::vector<Value>> sliceOffsets;
+  std::vector<std::vector<Value>> sliceStrides;

  /// Loop Stack, stores the information of all the nested loops that are
  /// alive.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -43,29 +43,23 @@
  return load;
}

-// TODO: Support dynamic sized slice.
-static Value getSliceOffset(OpBuilder &builder, Location loc,
-                            SparseTensorEncodingAttr enc, unsigned lvl) {
-  return constantIndex(builder, loc, *enc.getStaticLvlSliceOffset(lvl));
+static Value genSliceOffset(OpBuilder &builder, Location loc, Value tensor,
+                            unsigned lvl) {
+  auto enc = getSparseTensorEncoding(tensor.getType());
+  return createOrFoldSliceOffsetOp(builder, loc, tensor, toOrigDim(enc, lvl));
}

-static Value getSliceSize(OpBuilder &builder, Location loc,
-                          SparseTensorEncodingAttr enc, unsigned lvl) {
-  return constantIndex(builder, loc, *enc.getStaticLvlSliceSize(lvl));
-}
-
-static Value getSliceStride(OpBuilder &builder, Location loc,
-                            SparseTensorEncodingAttr enc, unsigned lvl) {
-  return constantIndex(builder, loc, *enc.getStaticLvlSliceStride(lvl));
+static Value genSliceStride(OpBuilder &builder, Location loc, Value tensor,
+                            unsigned lvl) {
+  auto enc = getSparseTensorEncoding(tensor.getType());
+  return createOrFoldSliceStrideOp(builder, loc, tensor, toOrigDim(enc, lvl));
}

// Converts a coordinate relative to the slice to the coordinate relative
// to the underlying tensor.
static Value toSliceCoord(OpBuilder &builder, Location loc, Value v,
-                          SparseTensorEncodingAttr enc, unsigned lvl) {
-
-  Value stride = getSliceStride(builder, loc, enc, lvl);
-  Value offset = getSliceOffset(builder, loc, enc, lvl);
+                          Value offset, Value stride, Value tensor,
+                          unsigned lvl) {
  // iv = iv * stride + offset
  v = builder.create<arith::MulIOp>(loc, v, stride);
  v = builder.create<arith::AddIOp>(loc, v, offset);
@@ -75,39 +69,57 @@
// Converts a coordinate relative to the underlying tensor to the coordinate
// relative to the slice, returns a extra reminder value
static std::pair<Value, Value> fromSliceCrd(OpBuilder &builder, Location loc,
-                                            Value v,
-                                            SparseTensorEncodingAttr enc,
+                                            Value iv, Value offset,
+                                            Value stride, Value tensor,
                                             unsigned lvl) {
-  Value stride = getSliceStride(builder, loc, enc, lvl);
-  Value offset = getSliceOffset(builder, loc, enc, lvl);
  // iv = (iv - offset) / stride
-  v = builder.create<arith::SubIOp>(loc, v, offset);
-  Value rem = builder.create<arith::RemUIOp>(loc, v, stride);
-  v = builder.create<arith::DivUIOp>(loc, v, stride);
-  return std::make_pair(v, rem);
+  iv = builder.create<arith::SubIOp>(loc, iv, offset);
+  Value rem = builder.create<arith::RemUIOp>(loc, iv, stride);
+  iv = builder.create<arith::DivUIOp>(loc, iv, stride);
+  return std::make_pair(iv, rem);
}

-static std::pair<Value, Value>
-genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
-                       SparseTensorEncodingAttr enc, unsigned lvl) {
-  std::pair<Value, Value> trans = fromSliceCrd(builder, loc, crd, enc, lvl);
-  // First, crd >= offset (TODO: seems unsigned >= 0 won't be folded, skip
-  // the check if the offset is zero).
-  auto geOffset =
-      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, crd,
-                                    getSliceOffset(builder, loc, enc, lvl));
+std::pair<Value, Value>
+LoopEmitter::genSliceLegitPredicate(OpBuilder &builder, Location loc, Value crd,
+                                    unsigned tid, unsigned lvl) {
+  assert(isSparseSlices[tid]);
+  Value slice = tensors[tid];
+  Value offset = sliceOffsets[tid][lvl];
+  Value stride = sliceStrides[tid][lvl];
+  auto enc = getSparseTensorEncoding(slice.getType());
+
+  std::pair<Value, Value> trans =
+      fromSliceCrd(builder, loc, crd, offset, stride, slice, lvl);
+
+  SmallVector<Value, 3> conds; // at most 3 conditions
+
+  // First, coord >= offset (skip the check if offset is known to be 0).
+  if (auto staticOffset = enc.getStaticLvlSliceOffset(lvl);
+      !(staticOffset.has_value() && *staticOffset == 0)) {
+    auto geOffset = builder.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::uge, crd, offset);
+    conds.push_back(geOffset);
+  }
+
  // Second, coord_in_slice < length
-  auto ltLength =
-      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, trans.first,
-                                    getSliceSize(builder, loc, enc, lvl));
+  auto ltLength = builder.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::ult, trans.first, lvlSizes[tid][lvl]);
+  conds.push_back(ltLength);
+
+  // Third, rem == 0 (skip the check if stride is known to be 1).
+  if (auto staticStride = enc.getStaticLvlSliceStride(lvl);
+      !(staticStride.has_value() && *staticStride == 1)) {
+    auto fitStride = builder.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, trans.second,
+        constantIndex(builder, loc, 0));
+    conds.push_back(fitStride);
+  }

-  // Third, rem == 0; confirmed that (a % 1) will be folded to 0
-  auto fitStride =
-      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, trans.second,
-                                    constantIndex(builder, loc, 0));
+  // Must meet all conditions to be a valid coordinate in the slice.
+  auto pred = conds.front();
+  for (unsigned i = 1, e = conds.size(); i < e; i++)
+    pred = builder.create<arith::AndIOp>(loc, pred, conds[i]);

-  auto pred = builder.create<arith::AndIOp>(loc, geOffset, ltLength);
-  pred = builder.create<arith::AndIOp>(loc, pred, fitStride);
  return {trans.first, pred};
}

@@ -119,10 +131,9 @@
                     size_t dim, Value iv) {
  Value p = dim == 0 ? constantIndex(builder, loc, 0) : pidxs[tid][dim - 1];
  Value mul = builder.create<arith::MulIOp>(loc, highs[tid][dim], p);
-  if (isSparseSlices[tid]) {
-    auto enc = getSparseTensorEncoding(tensors[tid].getType());
-    iv = toSliceCoord(builder, loc, iv, enc, dim);
-  }
+  if (isSparseSlices[tid])
+    iv = toSliceCoord(builder, loc, iv, sliceOffsets[tid][dim],
+                      sliceStrides[tid][dim], tensors[tid], dim);
  Value add = builder.create<arith::AddIOp>(loc, mul, iv);
  return add;
}
@@ -204,6 +215,8 @@
  this->isSparseOut = isSparseOut;
  this->tensors.assign(ts.begin(), ts.end());
  this->isSparseSlices.assign(tensors.size(), false);
+  this->sliceOffsets.assign(tensors.size(), std::vector<Value>());
+  this->sliceStrides.assign(tensors.size(), std::vector<Value>());
  this->dimTypes.assign(tensors.size(), std::vector<DimLevelType>());
  this->pidxs.assign(tensors.size(), std::vector<Value>());
  this->segHi.assign(tensors.size(), std::vector<Value>());
@@ -246,6 +259,8 @@
    dimTypes[tid].assign(rank, DimLevelType::Dense);

  // Initialize using empty value.
+  sliceOffsets[tid].assign(rank, Value());
+  sliceStrides[tid].assign(rank, Value());
  pidxs[tid].assign(rank, Value());
  segHi[tid].assign(rank, Value());
  coord[tid].assign(rank, Value());
@@ -300,11 +315,15 @@
      assert(isDenseDLT(dlt));
    }

-    // Find upper bound in current dimension.
    // FIXME: `toOrigDim` is deprecated
    const Dimension d = toOrigDim(enc, l);
-    lvlSizes[t][l] = highs[t][l] =
-        mlir::linalg::createOrFoldDimOp(builder, loc, tensor, d);
+    Value up = mlir::linalg::createOrFoldDimOp(builder, loc, tensor, d);
+    // Find upper bound in current dimension.
+    highs[t][l] = lvlSizes[t][l] = up;
+    if (isSparseSlices[t]) {
+      sliceOffsets[t][l] = genSliceOffset(builder, loc, tensors[t], l);
+      sliceStrides[t][l] = genSliceStride(builder, loc, tensors[t], l);
+    }
  }

  // Perform the required bufferization. Dense inputs materialize
@@ -405,7 +424,6 @@
    isSparseInput = isSparseInput || isSparse;
  }

-  auto enc = getSparseTensorEncoding(tensors[tid].getType());
  const auto reassoc = getCollapseReassociation(tid, dim);
  // TODO: support dynamic slices.
// Uses the first dimension here to build the loop bound (which is also the
@@ -468,7 +486,7 @@
      for (Value red : reduc)
        types.push_back(red.getType());

-      auto [trans, pred] = genSliceLegitPredicate(builder, loc, crd, enc, dim);
+      auto [trans, pred] = genSliceLegitPredicate(builder, loc, crd, tid, dim);
      bool hasReduc = !types.empty();
      scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, pred,
                                                 /*else*/ hasReduc);
@@ -660,11 +678,8 @@
        isSingletonDLT(dimTypes[tid][dim])) {
      coord[tid][dim] = genSparseCrd(builder, loc, tid, dim);
      if (isSparseSlices[tid]) {
-        Value load =
-            genIndexLoad(builder, loc, crdBuffer[tid][dim], pidxs[tid][dim]);
-        auto enc = getSparseTensorEncoding(tensors[tid].getType());
        auto [trans, pred] =
-            genSliceLegitPredicate(builder, loc, load, enc, dim);
+            genSliceLegitPredicate(builder, loc, coord[tid][dim], tid, dim);
        slicesPreds.emplace_back(pred, i);
        // Updates to the relative coordinate to the slice.
        coord[tid][dim] = trans;
@@ -679,7 +694,7 @@
    // Generates a list of if statments
    //   pidx = in_slice ? pidx : pidx + 1
    // TODO: instead of always picking pidx + 1, we should set pidx = high to
-    // break to loop the coordinates is larger than the slice size.
+    // break the loop if the coordinate is larger than the slice size.
    for (auto [pred, idx] : slicesPreds) {
      Value nextPidx = builder.create<arith::AddIOp>(
          loc, yields[idx], constantIndex(builder, loc, 1));
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseStorageSpecifierToLLVM.cpp
@@ -130,17 +130,18 @@
/// Builds IR extracting the pos-th offset from the descriptor.
Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc,
                                        Dimension dim) const {
-  return builder.create<LLVM::ExtractValueOp>(
-      loc, value,
-      ArrayRef<int64_t>({kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}));
+  return extractField(
+      builder, loc,
+      ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)});
}

/// Builds IR inserting the pos-th offset into the descriptor.
void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc,
                                          Dimension dim, Value size) {
-  value = builder.create<LLVM::InsertValueOp>(
-      loc, value, size,
-      ArrayRef<int64_t>({kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)}));
+  insertField(
+      builder, loc,
+      ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)},
+      size);
}

/// Builds IR extracting the `lvl`-th level-size from the descriptor.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -97,6 +97,17 @@
  return forOp;
}

+/// Returns a value from the mixed slice info (either an attribute for a
+/// statically known value or a dynamic value). Creates a constant op if needed.
+static Value valueFromMixedSliceInfo(OpBuilder &builder, Location loc,
+                                     OpFoldResult ofr) {
+  if (Value v = ofr.dyn_cast<Value>())
+    return v;
+
+  auto attr = ofr.get<Attribute>().cast<IntegerAttr>();
+  return builder.create<arith::ConstantIndexOp>(loc, attr.getInt());
+}
+
/// Gets the dimension size for the given sparse tensor at the given
/// original dimension 'dim'.
static Value sizeFromTensorAtDim(OpBuilder &builder, Location loc,
@@ -697,6 +708,23 @@
  }
};

+template <typename Op, StorageSpecifierKind kind>
+class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
+public:
+  using OpConversionPattern<Op>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Simply lowers to a specifier.get operation.
+    auto desc = getDescriptorFromTensorTuple(adaptor.getSlice());
+    auto v = desc.getSpecifierField(rewriter, op.getLoc(), kind,
+                                    op.getDim().getZExtValue());
+
+    rewriter.replaceOp(op, v);
+    return success();
+  }
+};
+
/// Sparse codegen rule for trivial tensor casts.
class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
public:
@@ -1099,13 +1127,15 @@
  }
};

-class SparseExtractSliceCoverter
+class SparseExtractSliceConverter
    : public OpConversionPattern<tensor::ExtractSliceOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *ctx = op.getContext();
    auto srcEnc = getSparseTensorEncoding(op.getSourceType());
    auto dstEnc = getSparseTensorEncoding(op.getResult().getType());
    if (!srcEnc && !dstEnc)
@@ -1119,16 +1149,42 @@
    assert(srcEnc.getPosWidth() == dstEnc.getPosWidth());
    assert(srcEnc.getCrdWidth() == dstEnc.getCrdWidth());

-    // TODO: support dynamic slices.
-    for (int i = 0, e = op.getSourceType().getRank(); i < e; i++) {
-      assert(op.getStaticStrides()[i] == dstEnc.getStaticDimSliceStride(i));
-      assert(op.getStaticOffsets()[i] == dstEnc.getStaticDimSliceOffset(i));
-      assert(op.getStaticSizes()[i] == dstEnc.getStaticDimSliceSize(i));
+    SmallVector<Value> fields;
+    auto desc = getMutDescriptorFromTensorTuple(adaptor.getSource(), fields);
+
+    auto newSpec = rewriter.create<StorageSpecifierInitOp>(
+        loc, StorageSpecifierType::get(ctx, dstEnc), desc.getSpecifier());
+    desc.setSpecifier(newSpec);
+
+    // Fills in slice information.
+    unsigned dim = 0;
+    for (auto [offset, size, stride] : llvm::zip(
+             op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
+      Value offsetV = valueFromMixedSliceInfo(rewriter, loc, offset);
+      Value sizeV = valueFromMixedSliceInfo(rewriter, loc, size);
+      Value strideV = valueFromMixedSliceInfo(rewriter, loc, stride);
+      // TODO: We could probably set only the dynamic values here, but that
+      // would require us to fill in the hole when casting a static slice to a
+      // dynamic slice.
+      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimOffset,
+                             dim, offsetV);
+
+      // FIXME: we need to distinguish level sizes and dimension sizes for
+      // slices here. Maybe we should store slice level sizes in a different
+      // array instead of reusing it.
+      assert(srcEnc.hasIdDimOrdering());
+      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::LvlSize, dim,
+                             sizeV);
+      desc.setSpecifierField(rewriter, loc, StorageSpecifierKind::DimStride,
+                             dim, strideV);
+      dim++;
    }

-    // TODO: create a new specifer for slices (need to encode slice metadata).
-    // It does not matter now because only constant offset/stride are allowed.
-    rewriter.replaceOp(op, adaptor.getSource());
+    // NOTE: we cannot generate the tuple directly from the descriptor here, as
+    // the descriptor holds the original type, while we want the slice type
+    // here (they share every memref but have an updated specifier).
+ rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(), + desc.getFields())); return success(); } }; @@ -1449,13 +1505,18 @@ patterns.add( - typeConverter, patterns.getContext()); + SparseInsertConverter, + SparseSliceGetterOpConverter, + SparseSliceGetterOpConverter, + SparseToPositionsConverter, SparseToCoordinatesConverter, + SparseToCoordinatesBufferConverter, SparseToValuesConverter, + SparseConvertConverter, SparseNewOpConverter, + SparseNumberOfEntriesConverter>(typeConverter, + patterns.getContext()); patterns.add(typeConverter, patterns.getContext(), enableBufferInitialization); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.h @@ -403,6 +403,9 @@ fields[fidx] = v; } + void setSpecifier(Value newSpec) { fields.back() = newSpec; } + + // FIXME: see note [CLARIFY_DIM_LVL]. void setSpecifierField(OpBuilder &builder, Location loc, StorageSpecifierKind kind, std::optional lvl, Value v) { diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -259,16 +259,16 @@ return %0 : index } -//// ----- -// -//#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> -// -//func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 { -// // _e_xpected-error@+1 {{requested slice data on non-slice tensor}} -// %0 = sparse_tensor.storage_specifier.get %arg0 dim_offset at 0 -// : !sparse_tensor.storage_specifier<#SparseVector> to i64 -// return %0 : i64 -//} +// ----- + +#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> + +func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 { + // expected-error@+1 {{requested slice data on non-slice tensor}} + %0 = sparse_tensor.storage_specifier.get %arg0 dim_offset at 0 + : !sparse_tensor.storage_specifier<#SparseVector> + return %0 : index +} // ----- diff --git a/mlir/test/Dialect/SparseTensor/sparse_extract_slice.mlir b/mlir/test/Dialect/SparseTensor/sparse_extract_slice.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/sparse_extract_slice.mlir @@ -0,0 +1,33 @@ +// RUN: mlir-opt %s --sparse-tensor-codegen --cse | FileCheck %s + +#CSR = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ] +}> + +#CSR_SLICE = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + slice = [ (0, 4, 1), (0, 8, 1) ] +}> + +// CHECK-LABEL: func.func @sparse_slice( +// CHECK-SAME: %[[VAL_0:.*0]]: memref, +// CHECK-SAME: %[[VAL_1:.*1]]: memref, +// CHECK-SAME: %[[VAL_2:.*2]]: memref, +// CHECK-SAME: %[[VAL_3:.*3]]: !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>) +// CHECK: %[[VAL_4:.*]] = sparse_tensor.storage_specifier.init with %[[VAL_3]] +// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_6:.*]] = arith.constant 4 : index +// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_8:.*]] = sparse_tensor.storage_specifier.set %[[VAL_4]] dim_offset at 0 with %[[VAL_5]] +// CHECK: %[[VAL_9:.*]] = sparse_tensor.storage_specifier.set %[[VAL_8]] lvl_sz at 0 with %[[VAL_6]] +// CHECK: %[[VAL_10:.*]] = 
sparse_tensor.storage_specifier.set %[[VAL_9]] dim_stride at 0 with %[[VAL_7]] +// CHECK: %[[VAL_11:.*]] = arith.constant 8 : index +// CHECK: %[[VAL_12:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] dim_offset at 1 with %[[VAL_5]] +// CHECK: %[[VAL_13:.*]] = sparse_tensor.storage_specifier.set %[[VAL_12]] lvl_sz at 1 with %[[VAL_11]] +// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_13]] dim_stride at 1 with %[[VAL_7]] +// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_14]] +func.func @sparse_slice(%t1 : tensor<8x8xf64, #CSR>) -> tensor<4x8xf64, #CSR_SLICE> { + %a1 = tensor.extract_slice %t1[0, 0][4, 8][1, 1] : tensor<8x8xf64, #CSR> to + tensor<4x8xf64, #CSR_SLICE> + return %a1 : tensor<4x8xf64, #CSR_SLICE> +} diff --git a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=true" | FileCheck %s +// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=true" --canonicalize | FileCheck %s // CHECK-LABEL: func.func @sparse_foreach_constant // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index @@ -27,3 +27,115 @@ } return } + +#CSR_SLICE = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed", "compressed" ], + slice = [ (0, 4, 1), (2, 4, 1) ] +}> + +#CSR_SLICE_DYN = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed", "compressed" ], + slice = [ (?, ?, ?), (?, ?, ?) ] +}> + + +// CHECK-LABEL: func.func @foreach_print_slice_dyn( +// CHECK-SAME: %[[VAL_0:.*]]: tensor +// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_3]]{{\[}}%[[VAL_2]]] : memref +// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_2]] { +// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_16]]] : memref +// CHECK: %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_6]] : index +// CHECK: %[[VAL_19:.*]] = arith.remui %[[VAL_18]], %[[VAL_7]] : index +// CHECK: %[[VAL_20:.*]] = arith.divui %[[VAL_18]], %[[VAL_7]] : index +// CHECK: %[[VAL_21:.*]] = arith.cmpi uge, %[[VAL_17]], %[[VAL_6]] : index +// CHECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_5]] : index +// CHECK: %[[VAL_23:.*]] = arith.cmpi eq, %[[VAL_19]], %[[VAL_1]] : index +// CHECK: %[[VAL_24:.*]] = arith.andi %[[VAL_21]], %[[VAL_22]] : i1 +// CHECK: %[[VAL_25:.*]] = arith.andi %[[VAL_24]], %[[VAL_23]] : i1 +// CHECK: scf.if %[[VAL_25]] { +// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref +// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_16]], %[[VAL_2]] : index +// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_27]]] : memref +// CHECK: scf.for %[[VAL_29:.*]] = %[[VAL_26]] to %[[VAL_28]] step %[[VAL_2]] { +// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_29]]] : memref +// CHECK: %[[VAL_31:.*]] = arith.subi %[[VAL_30]], %[[VAL_11]] : index +// CHECK: %[[VAL_32:.*]] = arith.remui %[[VAL_31]], %[[VAL_12]] : index +// CHECK: %[[VAL_33:.*]] = arith.divui %[[VAL_31]], %[[VAL_12]] : index +// CHECK: %[[VAL_34:.*]] = arith.cmpi uge, %[[VAL_30]], %[[VAL_11]] : index +// CHECK: %[[VAL_35:.*]] = arith.cmpi ult, %[[VAL_33]], %[[VAL_10]] : index +// CHECK: %[[VAL_36:.*]] = arith.cmpi eq, %[[VAL_32]], %[[VAL_1]] : index +// CHECK: %[[VAL_37:.*]] = arith.andi %[[VAL_34]], %[[VAL_35]] : i1 +// CHECK: %[[VAL_38:.*]] = arith.andi %[[VAL_37]], 
%[[VAL_36]] : i1 +// CHECK: scf.if %[[VAL_38]] { +// CHECK: %[[VAL_39:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_29]]] : memref +// CHECK: "test.use"(%[[VAL_39]]) : (f64) -> () +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: return +// +func.func @foreach_print_slice_dyn(%A: tensor) { + sparse_tensor.foreach in %A : tensor do { + ^bb0(%1: index, %2: index, %v: f64) : + "test.use" (%v) : (f64) -> () + } + return +} + +// CHECK-LABEL: func.func @foreach_print_slice( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4xf64, +// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<4x4xf64, +// CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<4x4xf64, +// CHECK-DAG: %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4xf64, +// CHECK-DAG: %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf64, +// CHECK-DAG: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf64, +// CHECK-DAG: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref +// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] { +// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref +// CHECK: %[[VAL_14:.*]] = arith.cmpi ult, %[[VAL_13]], %[[VAL_1]] : index +// CHECK: scf.if %[[VAL_14]] { +// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref +// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_12]], %[[VAL_4]] : index +// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref +// CHECK: scf.for %[[VAL_18:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_4]] { +// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref +// CHECK: %[[VAL_20:.*]] = arith.subi %[[VAL_19]], %[[VAL_2]] : index +// CHECK: %[[VAL_21:.*]] = arith.cmpi uge, %[[VAL_19]], %[[VAL_2]] : index +// CHECK: %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_1]] : index +// CHECK: %[[VAL_23:.*]] = arith.andi %[[VAL_21]], %[[VAL_22]] : i1 +// CHECK: scf.if %[[VAL_23]] { +// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref +// CHECK: "test.use"(%[[VAL_24]]) : (f64) -> () +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: return +// +func.func @foreach_print_slice(%A: tensor<4x4xf64, #CSR_SLICE>) { + sparse_tensor.foreach in %A : tensor<4x4xf64, #CSR_SLICE> do { + ^bb0(%1: index, %2: index, %v: f64) : + "test.use" (%v) : (f64) -> () + } + return +} \ No newline at end of file diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_foreach_slices.mlir @@ -2,7 +2,7 @@ // DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \ // DEFINE: mlir-cpu-runner \ // DEFINE: -e entry -entry-point-result=void \ -// DEFINE: -shared-libs=%mlir_c_runner_utils | \ +// DEFINE: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ // DEFINE: FileCheck %s // // RUN: %{command} @@ -18,6 +18,12 @@ slice = [ (1, 4, 1), (1, 4, 2) ] }> 
+#CSR_SLICE_DYN = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + slice = [ (?, ?, ?), (?, ?, ?) ] +}> + + module { func.func @foreach_print_non_slice(%A: tensor<4x4xf64, #CSR>) { sparse_tensor.foreach in %A : tensor<4x4xf64, #CSR> do { @@ -39,8 +45,22 @@ return } + func.func @foreach_print_slice_dyn(%A: tensor) { + sparse_tensor.foreach in %A : tensor do { + ^bb0(%1: index, %2: index, %v: f64) : + vector.print %1: index + vector.print %2: index + vector.print %v: f64 + } + return + } + func.func @entry() { %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %sa = arith.constant dense<[ [ 0.0, 2.1, 0.0, 0.0, 0.0, 6.1, 0.0, 0.0 ], [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ], @@ -52,6 +72,7 @@ [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 ] ]> : tensor<8x8xf64> + %tmp = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #CSR> %a = tensor.extract_slice %tmp[1, 1][4, 4][1, 2] : tensor<8x8xf64, #CSR> to tensor<4x4xf64, #CSR_SLICE> @@ -72,7 +93,7 @@ %dense = tensor.extract_slice %sa[1, 1][4, 4][1, 2] : tensor<8x8xf64> to tensor<4x4xf64> %b = sparse_tensor.convert %dense : tensor<4x4xf64> to tensor<4x4xf64, #CSR> - // Foreach on sparse tensor instead of slice should yield the same result. + // Foreach on sparse tensor instead of slice they should yield the same result. // // CHECK-NEXT: 1 // CHECK-NEXT: 0 @@ -86,8 +107,28 @@ // call @foreach_print_non_slice(%b) : (tensor<4x4xf64, #CSR>) -> () - bufferization.dealloc_tensor %b : tensor<4x4xf64, #CSR> + // The same slice, but with dynamic encoding. + // TODO: Investigates why reusing the same %tmp above would cause bufferization + // errors. + %tmp1 = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #CSR> + %a_dyn = tensor.extract_slice %tmp1[%c1, %c1][%c4, %c4][%c1, %c2] : + tensor<8x8xf64, #CSR> to tensor + // + // CHECK-NEXT: 1 + // CHECK-NEXT: 0 + // CHECK-NEXT: 2.3 + // CHECK-NEXT: 2 + // CHECK-NEXT: 3 + // CHECK-NEXT: 1 + // CHECK-NEXT: 3 + // CHECK-NEXT: 2 + // CHECK-NEXT: 2.1 + // + call @foreach_print_slice_dyn(%a_dyn) : (tensor) -> () + bufferization.dealloc_tensor %tmp : tensor<8x8xf64, #CSR> + bufferization.dealloc_tensor %tmp1 : tensor<8x8xf64, #CSR> + bufferization.dealloc_tensor %b : tensor<4x4xf64, #CSR> return } } diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir @@ -38,10 +38,32 @@ slice = [ (0, 4, 2), (1, 4, 1) ] }> +#CSR_SLICE_dyn = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + slice = [ (?, 4, ?), (?, 4, ?) ] +}> + +#DCSR_SLICE_dyn = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed", "compressed" ], + slice = [ (?, 4, ?), (?, 4, ?) 
] +}> + + module { func.func private @printMemrefF64(%ptr : tensor<*xf64>) func.func private @printMemref1dF64(%ptr : memref) attributes { llvm.emit_c_interface } + // + // Computes C = A x B with all matrices dynamic sparse slice (SpMSpM) in CSR and DCSR + // + func.func @matmul_dyn(%A: tensor<4x4xf64, #CSR_SLICE_dyn>, + %B: tensor<4x4xf64, #DCSR_SLICE_dyn>) -> tensor<4x4xf64, #CSR> { + %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR> + %D = linalg.matmul + ins(%A, %B: tensor<4x4xf64, #CSR_SLICE_dyn>, tensor<4x4xf64, #DCSR_SLICE_dyn>) + outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> + return %D: tensor<4x4xf64, #CSR> + } // // Computes C = A x B with one matrix CSR sparse slices and the other DSCR sparse slice. @@ -83,7 +105,9 @@ // Main driver. // func.func @entry() { - %c0 = arith.constant 0 : index + %c_0 = arith.constant 0 : index + %c_1 = arith.constant 1 : index + %c_2 = arith.constant 2 : index %f0 = arith.constant 0.0 : f64 %sa = arith.constant dense<[ @@ -158,11 +182,27 @@ %4 = call @matmul1(%s2, %s1) : (tensor<4x4xf64, #CSR_SLICE_1>, tensor<4x4xf64, #DCSR_SLICE_1>) -> tensor<4x4xf64, #CSR> - %c4 = sparse_tensor.convert %4 : tensor<4x4xf64, #CSR> to tensor<4x4xf64> %c4u = tensor.cast %c4 : tensor<4x4xf64> to tensor<*xf64> call @printMemrefF64(%c4u) : (tensor<*xf64>) -> () + // slice x slice (same as above, but with dynamic stride information) + // + // CHECK: [2.3, 0, 0, 0], + // CHECK-NEXT: [6.9, 0, 0, 0], + // CHECK-NEXT: [0, 0, 0, 0], + // CHECK-NEXT: [12.6, 0, 0, 0]] + // + %s1_dyn = tensor.extract_slice %tmp[%c_0, %c_1][4, 4][%c_2, %c_1] : tensor<8x8xf64, #DCSR> to tensor<4x4xf64, #DCSR_SLICE_dyn> + %s2_dyn = tensor.extract_slice %b1[%c_0, %c_0][4, 4][%c_2, %c_1] : tensor<8x4xf64, #CSR> to tensor<4x4xf64, #CSR_SLICE_dyn> + %dyn_4 = call @matmul_dyn(%s2_dyn, %s1_dyn) + : (tensor<4x4xf64, #CSR_SLICE_dyn>, + tensor<4x4xf64, #DCSR_SLICE_dyn>) -> tensor<4x4xf64, #CSR> + + %c4_dyn = sparse_tensor.convert %dyn_4 : tensor<4x4xf64, #CSR> to tensor<4x4xf64> + %c4u_dyn = tensor.cast %c4_dyn : tensor<4x4xf64> to tensor<*xf64> + call @printMemrefF64(%c4u_dyn) : (tensor<*xf64>) -> () + // sparse slices should generate the same result as dense slices // // CHECK: [2.3, 0, 0, 0], @@ -179,7 +219,7 @@ %du = tensor.cast %r : tensor<4x4xf64> to tensor<*xf64> call @printMemrefF64(%du) : (tensor<*xf64>) -> () - // Releases resources. + // Releases resources (we do not need to deallocate slices). bufferization.dealloc_tensor %b1 : tensor<8x4xf64, #CSR> bufferization.dealloc_tensor %t1 : tensor<8x8xf64, #CSR> bufferization.dealloc_tensor %b : tensor<8x4xf64, #DCSR> @@ -187,6 +227,7 @@ bufferization.dealloc_tensor %4 : tensor<4x4xf64, #CSR> bufferization.dealloc_tensor %3 : tensor<4x4xf64, #CSR> bufferization.dealloc_tensor %2 : tensor<4x4xf64, #DCSR> + bufferization.dealloc_tensor %dyn_4 : tensor<4x4xf64, #CSR> return }