diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -101,6 +101,10 @@ /// Returns null-attribute for any type without an encoding. SparseTensorEncodingAttr getSparseTensorEncoding(Type type); +/// Returns true iff the given sparse tensor encoding attribute has a trailing +/// COO region starting at the given level. +bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, bool isUnique); + /// Returns true iff the given type is a COO type where the last level /// is unique. bool isUniqueCOOType(Type tp); diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -54,8 +54,9 @@ } def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>, - Arguments<(ins 1DTensorOf<[AnyType]>:$values, - 2DTensorOf<[AnySignlessIntegerOrIndex]>:$coordinates)>, + Arguments<(ins TensorOf<[AnyType]>:$values, + TensorOf<[AnySignlessIntegerOrIndex]>:$coordinates, + OptionalAttr:$batched_lvls)>, Results<(outs AnySparseTensor: $result)> { let summary = "Returns a sparse tensor from the given (values, coordinates) pair"; @@ -77,6 +78,8 @@ supplies the level-coords for each element in `values`. - `values : tensor` supplies the corresponding values for each entry in `coordinates`. + - `batched_lvls : optional` + supplies the number of leading levels that are batched. This operation can be used to materialize a sparse tensor from external sources; e.g., when passing two numpy arrays from Python. @@ -92,10 +95,29 @@ // of 3x4 matrix |0.0, 0.0, 2.2, 3.3| // |0.0, 0.0, 0.0, 0.0| ``` + + If `batched_lvls` is provided, the operations materializes a batched sparse tensor. + Example: + + ```mlir + %values = arith.constant dense<[[ 1.1, 2.2, 3.3 ], + [ 1.2, 2.3, 0.0 ]]> : tensor<2x3xf64> + %coordinates = arith.constant dense<[[ [0], [1], [2] ], + [ [1], [2], [3] ]> : tensor<2x3x1xindex> + %st = sparse_tensor.pack %values, %coordinates batched_lvls=1 + : tensor<2x3xf64>, tensor<2x3x1xindex> to tensor<2x4xf64, #BCOO> + // yields BCOO format |1.1, 2.2, 3.3, 0.0| + // of 2x4 matrix |0.0, 1.2, 2.3, 0.0| + ``` + }]; + + let extraClassDeclaration = [{ + /// Returns the number of leading levels that are batched. + unsigned getNumBatchedLvls(); }]; let assemblyFormat = - "$values `,` $coordinates attr-dict" + "$values `,` $coordinates (`batched_lvls` `=` $batched_lvls^)? attr-dict" "`:` type($values) `,` type($coordinates) `to` type($result)"; let hasVerifier = 1; diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -451,9 +451,10 @@ /// Returns true iff the given sparse tensor encoding attribute has a trailing /// COO region starting at the given level. 
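+/// A trailing COO region is a compressed (or compressed-with-hi) level at
+/// `startLvl` followed only by singleton levels; when `isUnique` is set, the
+/// last level of the region must additionally be unique.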
-static bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, - bool isUnique) { - if (!enc || !enc.isCompressedLvl(startLvl)) +bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc, + Level startLvl, bool isUnique) { + if (!enc || + !(enc.isCompressedLvl(startLvl) || enc.isCompressedWithHiLvl(startLvl))) return false; const Level lvlRank = enc.getLvlRank(); for (Level l = startLvl + 1; l < lvlRank; ++l) @@ -647,43 +648,55 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, SparseTensorType tensorTp, RankedTensorType valuesTp, - RankedTensorType coordinatesTp) { + RankedTensorType coordinatesTp, + IntegerAttr batchedLvls) { + unsigned nBatched = batchedLvls ? batchedLvls.getValue().getZExtValue() : 0; if (requiresStaticShape && !tensorTp.hasStaticDimShape()) return op->emitError("the sparse-tensor must have static shape"); if (!tensorTp.hasEncoding()) return op->emitError("the sparse-tensor must have an encoding attribute"); if (!tensorTp.isIdentity()) return op->emitError("the sparse-tensor must have the identity mapping"); - if (!isUniqueCOOType(tensorTp)) + if (!isCOOType(tensorTp.getEncoding(), nBatched, true)) return op->emitError("the sparse-tensor must have a COO type"); - if (coordinatesTp.getRank() != 2) - return op->emitError("coordinates must have rank 2"); + if (coordinatesTp.getRank() != 2 + nBatched) + return op->emitError("coordinates must have rank 2 + batched_lvls"); if (requiresStaticShape && !coordinatesTp.hasStaticShape()) return op->emitError("coordinates must have static shape"); if (coordinatesTp.getElementType() != tensorTp.getCrdType()) return op->emitError("input/output coordinate-types don't match"); - if (valuesTp.getRank() != 1) - return op->emitError("values must have rank 1"); + if (valuesTp.getRank() != 1 + nBatched) + return op->emitError("values must have rank 1 + batched_lvls"); if (requiresStaticShape && !valuesTp.hasStaticShape()) return op->emitError("values must have static shape"); if (valuesTp.getElementType() != tensorTp.getElementType()) return op->emitError("input/output element-types don't match"); - const auto valuesNSE = valuesTp.getShape()[0]; - const auto coordsNSE = coordinatesTp.getShape()[0]; + for (unsigned i = 0; i < nBatched; i++) { + const auto valBatch = valuesTp.getShape()[i]; + const auto crdBatch = coordinatesTp.getShape()[i]; + if (ShapedType::isDynamic(valBatch) || ShapedType::isDynamic(crdBatch) || + crdBatch != valBatch) { + return op->emitError( + "values/coordinates batched level sizes don't match statically"); + } + } + + const auto valuesNSE = valuesTp.getShape()[nBatched]; + const auto coordsNSE = coordinatesTp.getShape()[nBatched]; if (!ShapedType::isDynamic(valuesNSE) && !ShapedType::isDynamic(coordsNSE) && valuesNSE != coordsNSE) return op->emitError("values/coordinates number-of-elements don't match"); // NOTE: We use `getLvlRank` because the `coordinatesTp` is for // level-coordinates (cf., the op documentation). - const DynSize coordsRank = coordinatesTp.getShape()[1]; + const DynSize coordsRank = coordinatesTp.getShape()[1 + nBatched]; const Level tensorRank = tensorTp.getLvlRank(); // FIXME: replace the `operator!=` with our backported `safelyNE`. 
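+  // Example (illustration only, mirroring the batched example in the op
+  // documentation): with batched_lvls = 1, `values : tensor<2x3xf64>` has
+  // rank 1 + 1 and `coordinates : tensor<2x3x1xindex>` has rank 2 + 1, the
+  // leading batch sizes (2) agree, nse (3) is read from dimension nBatched,
+  // and the trailing coordinate dimension (1) must equal the level-rank of
+  // the result minus nBatched, which is what the check below enforces.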
   if (!ShapedType::isDynamic(coordsRank) &&
-      coordsRank != static_cast<DynSize>(tensorRank))
+      coordsRank != static_cast<DynSize>(tensorRank) - nBatched)
     return op->emitError("input/output level-ranks don't match");
 
   return success();
@@ -693,14 +706,20 @@
   const auto valuesTp = getRankedTensorType(getValues());
   const auto coordinatesTp = getRankedTensorType(getCoordinates());
   const auto resTp = getSparseTensorType(getResult());
-  return verifyPackUnPack(*this, true, resTp, valuesTp, coordinatesTp);
+  return verifyPackUnPack(*this, true, resTp, valuesTp, coordinatesTp,
+                          getBatchedLvlsAttr());
+}
+
+unsigned PackOp::getNumBatchedLvls() {
+  return getBatchedLvls().has_value() ? getBatchedLvls()->getZExtValue() : 0;
 }
 
 LogicalResult UnpackOp::verify() {
   const auto valuesTp = getRankedTensorType(getValues());
   const auto coordinatesTp = getRankedTensorType(getCoordinates());
   const auto srcTp = getSparseTensorType(getTensor());
-  return verifyPackUnPack(*this, false, srcTp, valuesTp, coordinatesTp);
+  return verifyPackUnPack(*this, false, srcTp, valuesTp, coordinatesTp,
+                          nullptr);
 }
 
 LogicalResult ConvertOp::verify() {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -138,7 +138,7 @@
   AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
                                             const AnalysisState &state) const {
     assert(op->getNumResults() == 1);
-    assert(isUniqueCOOType(op->getResultTypes()[0].cast<RankedTensorType>()));
+    // assert(isUniqueCOOType(op->getResultTypes()[0].cast<RankedTensorType>()));
     // PackOp reuses the input tensors as values/coordinates instead of
     // creating new ones when packing into a COO format.
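+    // With batched levels the result of a PackOp is no longer a unique COO
+    // type starting at level 0: the COO region only begins after the leading
+    // batch levels (cf. isCOOType with a non-zero startLvl), so the stricter
+    // unique-COO assertion above no longer applies.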
return {{op->getOpResult(0), BufferRelation::Equivalent}}; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -1231,23 +1231,110 @@ } }; +static void populateCompressedWithHiPosArray(OpBuilder &builder, Location loc, + ArrayRef batchDimSzs, + Value posMemRef, unsigned nse, + PackOp op) { + SmallVector lbs, ubs, steps; + Value c0 = constantIndex(builder, loc, 0); + Value c1 = constantIndex(builder, loc, 1); + Value c2 = constantIndex(builder, loc, 2); + for (unsigned dimSz : batchDimSzs) { + lbs.push_back(c0); + ubs.push_back(constantIndex(builder, loc, dimSz)); + steps.push_back(c1); + } + auto tensorType = op.getValues().getType(); + auto memrefType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + Value batV = builder.create(loc, memrefType, + op.getValues()); + scf::buildLoopNest( + builder, loc, lbs, ubs, steps, + [&ubs, c0, c1, c2, nse, batV, posMemRef](OpBuilder &builder, Location loc, + ValueRange ivs) { + // Linearize index variables + Value crd = constantIndex(builder, loc, 0); + for (int i = 0; i < ivs.size(); i++) { + crd = builder.create(loc, crd, ivs[i]); + if (i != ivs.size() - 1) + crd = builder.create(loc, crd, ubs[i + 1]); + } + Value len = constantIndex(builder, loc, nse); + Value pLo = builder.create(loc, crd, len); + SmallVector indices(ivs.begin(), ivs.end()); + auto whileOp = builder.create( + loc, TypeRange{builder.getIndexType()}, ValueRange{len}, + [&indices, c0, c1, batV](OpBuilder &builder, Location loc, + ValueRange vs) { + Value curLen = vs.front(); + Value pred = builder.create( + loc, arith::CmpIPredicate::eq, curLen, c0); + auto ifOp = builder.create( + loc, TypeRange{builder.getI1Type()}, pred, true); + { + OpBuilder::InsertionGuard guard(builder); + // if len == 0. + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + builder.create(loc, + constantI1(builder, loc, false)); + // Else branch. + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + indices.push_back( + builder.create(loc, curLen, c1)); + Value val = builder.create(loc, batV, indices); + indices.pop_back(); + Value cont = builder.create( + loc, arith::CmpFPredicate::OEQ, val, + constantZero(builder, loc, val.getType())); + builder.create(loc, cont); + } + builder.create(loc, ifOp.getResults()[0], vs); + }, + [c1](OpBuilder &builder, Location loc, ValueRange vs) { + // len --; + Value nxLen = builder.create(loc, vs.front(), c1); + builder.create(loc, nxLen); + }); + len = whileOp.getResults()[0]; + Value pHi = builder.create(loc, pLo, len); + // Stores position lower bound. + Value idx = builder.create(loc, crd, c2); + genStore(builder, loc, pLo, posMemRef, idx); + // Stores position upper bound. 
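+          // Positions are laid out pairwise per batch: posMemRef[2*crd] holds
+          // pLo and posMemRef[2*crd+1] holds pHi, where pHi = pLo + len and
+          // len is nse minus the number of trailing zero values in this batch.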
+ idx = builder.create(loc, idx, c1); + genStore(builder, loc, pHi, posMemRef, idx); + }); +} + struct SparsePackOpConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(PackOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - + const unsigned batchedLvls = op.getNumBatchedLvls(); + unsigned nse = op.getValues().getType().getDimSize(batchedLvls); const auto stt = getSparseTensorType(op.getResult()); - assert(isUniqueCOOType(stt)); + assert(isCOOType(stt.getEncoding(), batchedLvls, true)); + + unsigned batchedCount = 1; + SmallVector batchDimSzs; + batchDimSzs.reserve(batchedLvls); + for (unsigned i = 0; i < batchedLvls; i++) { + // Should already be guaranteed by verifier. + assert(!ShapedType::isDynamic(stt.getDimShape()[i])); + batchedCount *= stt.getDimShape()[i]; + batchDimSzs.push_back(stt.getDimShape()[i]); + } SmallVector fields; Location loc = op.getLoc(); foreachFieldAndTypeInSparseTensor( stt, - [&rewriter, &fields, &op, stt, + [&rewriter, &fields, &op, &batchDimSzs, nse, batchedCount, stt, loc](Type fType, FieldIndex fIdx, SparseTensorFieldKind fKind, - Level /*lvl*/, DimLevelType /*dlt*/) -> bool { + Level /*lvl*/, DimLevelType dlt) -> bool { assert(fields.size() == fIdx); Value field; switch (fKind) { @@ -1259,34 +1346,38 @@ // By creating a constant value for it, we avoid the complexity of // memory management. const auto posTp = stt.getPosType(); - auto tensorType = RankedTensorType::get({2}, posTp); - auto memrefType = MemRefType::get(tensorType.getShape(), - tensorType.getElementType()); - auto cstPtr = rewriter.create( - loc, tensorType, - DenseElementsAttr::get( - tensorType, - ArrayRef{ - IntegerAttr::get(posTp, 0), - IntegerAttr::get( - posTp, op.getValues().getType().getShape()[0])})); - field = rewriter.create(loc, memrefType, - cstPtr); + if (isCompressedDLT(dlt)) { + RankedTensorType tensorType; + SmallVector posAttr; + tensorType = RankedTensorType::get({batchedCount + 1}, posTp); + posAttr.push_back(IntegerAttr::get(posTp, 0)); + for (unsigned i = 0; i < batchedCount; i++) { + // The postion memref will have values as + // [0, nse, 2 * nse, ..., batchedCount * nse] + posAttr.push_back(IntegerAttr::get(posTp, nse * (i + 1))); + } + MemRefType memrefType = MemRefType::get( + tensorType.getShape(), tensorType.getElementType()); + auto cstPtr = rewriter.create( + loc, tensorType, DenseElementsAttr::get(tensorType, posAttr)); + field = rewriter.create( + loc, memrefType, cstPtr); + } else { + assert(isCompressedWithHiDLT(dlt) && !batchDimSzs.empty()); + MemRefType posMemTp = MemRefType::get({batchedCount * 2}, posTp); + field = rewriter.create(loc, posMemTp); + populateCompressedWithHiPosArray(rewriter, loc, batchDimSzs, + field, nse, op); + } break; } case SparseTensorFieldKind::CrdMemRef: { auto tensorType = op.getCoordinates().getType(); auto memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - auto crdMemRef = rewriter.create( + field = rewriter.create( op->getLoc(), memrefType, op.getCoordinates()); - ReassociationIndices reassociation; - for (int i = 0, e = tensorType.getRank(); i < e; i++) - reassociation.push_back(i); - // Flattened the indices buffer to rank 1. 
-          field = rewriter.create<memref::CollapseShapeOp>(
-              loc, crdMemRef, ArrayRef<ReassociationIndices>(reassociation));
           break;
         }
         case SparseTensorFieldKind::ValMemRef: {
@@ -1300,6 +1391,17 @@
           }
           assert(field);
+          if (auto memrefTp = field.getType().dyn_cast<MemRefType>();
+              memrefTp && memrefTp.getRank() > 1) {
+            ReassociationIndices reassociation;
+            for (int i = 0, e = memrefTp.getRank(); i < e; i++)
+              reassociation.push_back(i);
+            // Flattens the buffer to rank 1. The value buffer might need to be
+            // collapsed as well due to batching.
+            field = rewriter.create<memref::CollapseShapeOp>(
+                loc, field, ArrayRef<ReassociationIndices>(reassociation));
+          }
+
           if (fType != field.getType())
             field = rewriter.create<memref::CastOp>(loc, fType, field);
           fields.push_back(field);
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -36,7 +36,7 @@
 func.func @invalid_pack_data(%values: tensor<6x1xf64>,
                              %coordinates: tensor<6x1xi32>)
                             -> tensor<100xf64, #SparseVector> {
-  // expected-error@+1 {{'sparse_tensor.pack' op operand #0 must be 1D tensor of any type values}}
+  // expected-error@+1 {{values must have rank 1 + batched_lvls}}
   %0 = sparse_tensor.pack %values, %coordinates
      : tensor<6x1xf64>, tensor<6x1xi32> to tensor<100xf64, #SparseVector>
   return %0 : tensor<100xf64, #SparseVector>
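For completeness, the `#BCOO` encoding referenced in the op documentation is not spelled out in this patch. A plausible definition and a well-formed batched pack that satisfies the new verifier rules could look as follows; this is a sketch only, where the `dimLevelType` key and the `compressed-hi` level name are assumptions about the dialect snapshot this change targets, and `@pack_bcoo` is just an illustrative name:

```mlir
#BCOO = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed-hi" ]
}>

func.func @pack_bcoo(%values: tensor<2x3xf64>,
                     %coordinates: tensor<2x3x1xindex>)
    -> tensor<2x4xf64, #BCOO> {
  // One leading batched (dense) level; values/coordinates carry one extra
  // leading dimension of size 2 for the batch, and nse = 3 per batch.
  %0 = sparse_tensor.pack %values, %coordinates batched_lvls=1
     : tensor<2x3xf64>, tensor<2x3x1xindex> to tensor<2x4xf64, #BCOO>
  return %0 : tensor<2x4xf64, #BCOO>
}
```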