diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -388,7 +388,7 @@
              !highs[t][l]);
       const auto lvlTp = lvlTypes[t][l];
       // Handle sparse storage schemes.
-      if (isCompressedDLT(lvlTp)) {
+      if (isCompressedDLT(lvlTp) || isCompressedWithHiDLT(lvlTp)) {
         // Generate sparse primitives to obtain positions and coordinates.
         positionsBuffers[t][l] = genToPositions(builder, loc, tensor, l);
         coordinatesBuffers[t][l] =
@@ -557,6 +557,7 @@
     OpBuilder &builder, Location loc, TensorId tid, Level dstLvl, Value lo,
     Value hi, MutableArrayRef<Value> reduc, bool isParallel) {
   bool isSparseCond = isCompressedDLT(lvlTypes[tid][dstLvl]) ||
+                      isCompressedWithHiDLT(lvlTypes[tid][dstLvl]) ||
                       isSingletonDLT(lvlTypes[tid][dstLvl]);

   const auto reassoc = getCollapseReassociation(tid, dstLvl);
@@ -695,7 +696,7 @@
     auto lvlType = lvlTypes[t][l];
     // Must be a recognizable DLT.
     assert(isDenseDLT(lvlType) || isCompressedDLT(lvlType) ||
-           isSingletonDLT(lvlType));
+           isCompressedWithHiDLT(lvlType) || isSingletonDLT(lvlType));

     // This is a slice-driven loop on sparse level.
     if (!dependentLvlMap[t][l].empty() && !isDenseDLT(lvlType)) {
@@ -901,7 +902,8 @@
     // TODO: support coiteration with slice driven tensors.
     const auto lvlTp = lvlTypes[tid][lvl];
     assert(dependentLvlMap[tid][lvl].empty() && "TODO: not yet implemented");
-    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
+    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
+        isCompressedWithHiDLT(lvlTp)) {
       const auto reassoc = getCollapseReassociation(tid, lvl);
       for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
         if (!isUniqueDLT(lvlTypes[tid][reassoc[i]])) {
@@ -941,7 +943,8 @@
   for (auto [t, lvl] : llvm::zip(tids, lvls)) {
     const TensorId tid = t; // Why `t` can not be captured by lambda?
     const auto lvlTp = lvlTypes[tid][lvl];
-    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
+    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
+        isCompressedWithHiDLT(lvlTp)) {
       const auto reassoc = getCollapseReassociation(tid, lvl);
       assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
       for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
@@ -974,7 +977,8 @@
   for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
     // Prepares for next level.
     const auto lvlTp = lvlTypes[tid][lvl];
-    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
+    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
+        isCompressedWithHiDLT(lvlTp)) {
       coords[tid][lvl] = genSparseCrd(builder, loc, tid, lvl);
       if (isSparseSlices[tid]) {
         auto [trans, pred] =
@@ -1023,7 +1027,8 @@
   if (!needsUniv) {
     for (auto [tid, lvl] : llvm::zip(tids, lvls)) {
       const auto lvlTp = lvlTypes[tid][lvl];
-      if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
+      if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
+          isCompressedWithHiDLT(lvlTp)) {
         const auto crd = coords[tid][lvl];
         if (min) {
           Value cmp = CMPI(ult, coords[tid][lvl], min);
@@ -1117,12 +1122,14 @@
     // Either the first level, or the previous level has been set.
     /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
     assert(srcLvl == 0 || posits[tid][srcLvl - 1]);
-    if (!isCompressedDLT(lvlTp) && !isSingletonDLT(lvlTp))
+    if (isDenseDLT(lvlTp))
       continue;
-    if (isCompressedDLT(lvlTp)) {
+    if (isCompressedDLT(lvlTp) || isCompressedWithHiDLT(lvlTp)) {
       const Value mem = positionsBuffers[tid][srcLvl];

-      const Value pLo = srcLvl == 0 ? c0 : posits[tid][srcLvl - 1];
+      Value pLo = srcLvl == 0 ? c0 : posits[tid][srcLvl - 1];
+      if (isCompressedWithHiDLT(lvlTp))
+        pLo = builder.create<arith::MulIOp>(loc, pLo, C_IDX(2));
       posits[tid][srcLvl] = genIndexLoad(builder, loc, mem, pLo);

       const Value pHi = ADDI(pLo, c1);
@@ -1321,7 +1328,8 @@
   Value one = C_IDX(1);
   for (auto [tid, dstLvl] : llvm::zip(loopInfo.tids, loopInfo.lvls)) {
     const auto lvlTp = lvlTypes[tid][dstLvl];
-    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp)) {
+    if (isCompressedDLT(lvlTp) || isSingletonDLT(lvlTp) ||
+        isCompressedWithHiDLT(lvlTp)) {
       const auto reassoc = getCollapseReassociation(tid, dstLvl);
       assert(reassoc.size() == 1 || isUniqueCOOType(tensors[tid].getType()));
       for (unsigned i = 0, e = reassoc.size() - 1; i < e; i++) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -532,6 +532,8 @@
     const Level lvlRank = stt.getLvlRank();
     for (Level l = 0; l < lvlRank; l++) {
       const auto dlt = stt.getLvlType(l);
+      if (isCompressedWithHiDLT(dlt))
+        llvm_unreachable("TODO: Not yet implemented");
       if (isCompressedDLT(dlt)) {
         // Compressed dimensions need a position cleanup for all entries
         // that were not visited during the insertion pass.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageLayout.cpp
@@ -145,7 +145,7 @@
   // As a result, the compound type can be constructed directly in the given
   // order.
   const auto dlt = lvlTypes[l];
-  if (isCompressedDLT(dlt)) {
+  if (isCompressedDLT(dlt) || isCompressedWithHiDLT(dlt)) {
     RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt);
     RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt);
   } else if (isSingletonDLT(dlt)) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -794,7 +794,8 @@
     const TensorId tid = env.makeTensorId(t.getOperandNumber());
     for (LoopId i = 0; i < numLoops; i++) {
       const auto dltI = env.dlt(tid, i);
-      if (isCompressedDLT(dltI) || isSingletonDLT(dltI)) {
+      if (isCompressedDLT(dltI) || isCompressedWithHiDLT(dltI) ||
+          isSingletonDLT(dltI)) {
        for (LoopId j = 0; j < numLoops; j++)
          if (isUndefDLT(env.dlt(tid, j))) {
            adjM[i][j] = true;
@@ -1410,7 +1411,7 @@
             DimLevelType dlt, bool /*unused*/) {
           assert(ldx == env.merger().loop(b));
           Value clause;
-          if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) {
+          if (isCompressedDLT(dlt) || isSingletonDLT(dlt) || isCompressedWithHiDLT(dlt)) {
             assert(lvl.has_value());
             const Value crd = env.emitter().getCoords()[tid][*lvl];
             const Value lvar = env.getLoopVar(ldx);
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -418,7 +418,7 @@
     // Slice on dense level has `locate` property as well, and can be optimized.
     if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
       const auto dlt = getDimLevelType(b);
-      if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt)) {
+      if (!isCompressedDLT(dlt) && !isSingletonDLT(dlt) && !isCompressedWithHiDLT(dlt)) {
         if (reset)
           simple.reset(b);
         reset = true;
@@ -585,7 +585,7 @@
 bool Merger::hasAnySparse(const BitVector &bits) const {
   for (TensorLoopId b : bits.set_bits()) {
     const auto dlt = getDimLevelType(b);
-    if (isCompressedDLT(dlt) || isSingletonDLT(dlt))
+    if (isCompressedDLT(dlt) || isSingletonDLT(dlt) || isCompressedWithHiDLT(dlt))
       return true;
   }
   return hasSparseIdxReduction(bits);
diff --git a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
@@ -138,4 +138,34 @@
     "test.use" (%v) : (f64) -> ()
   }
   return
-}
\ No newline at end of file
+}
+
+#BCOO = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed-hi-nu", "singleton" ]
+}>
+
+// CHECK-LABEL: func.func @foreach_bcoo(
+// CHECK-SAME:    %[[VAL_0:.*]]: tensor<4x4x4xf64, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK-DAG:     %[[VAL_1:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4x4xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK-DAG:     %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4x4xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
+// CHECK:         scf.for %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] {
+// CHECK:           %[[VAL_8:.*]] = arith.muli %[[VAL_7]], %[[VAL_4]] : index
+// CHECK:           %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : index
+// CHECK:           %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_9]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_2]] to %[[VAL_10]] step %[[VAL_3]] {
+// CHECK:             %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xf64>
+// CHECK:             "test.use"(%[[VAL_12]]) : (f64) -> ()
+// CHECK:           }
+// CHECK:         }
+// CHECK:         return
+func.func @foreach_bcoo(%A: tensor<4x4x4xf64, #BCOO>) {
+  sparse_tensor.foreach in %A : tensor<4x4x4xf64, #BCOO> do {
+  ^bb0(%1: index, %2: index, %3: index, %v: f64) :
+    "test.use" (%v) : (f64) -> ()
+  }
+  return
+}
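
For reference (not part of the patch, so strip this note before applying): the LoopEmitter change assumes the "compressed-hi" position layout in which every parent entry carries an explicit (lo, hi) pair, so the positions of segment p live at pos[2*p] and pos[2*p + 1] rather than the overlapping pos[p], pos[p + 1] of an ordinary "compressed" level. The minimal standalone C++ sketch below illustrates the two indexing schemes; the function names and the uint64_t position type are illustrative and do not come from the MLIR sources.

  #include <cstdint>
  #include <utility>
  #include <vector>

  // Ordinary "compressed" level: segment p spans [pos[p], pos[p+1]),
  // so consecutive segments share a boundary entry.
  static std::pair<uint64_t, uint64_t>
  segmentCompressed(const std::vector<uint64_t> &pos, uint64_t p) {
    return {pos[p], pos[p + 1]};
  }

  // "compressed-hi" level: each parent entry owns its own (lo, hi) pair,
  // so segment p spans [pos[2*p], pos[2*p+1]). This is the layout behind
  // the `pLo = pLo * 2` adjustment (with `pHi = pLo + 1`) in the
  // LoopEmitter hunk above.
  static std::pair<uint64_t, uint64_t>
  segmentCompressedHi(const std::vector<uint64_t> &pos, uint64_t p) {
    return {pos[2 * p], pos[2 * p + 1]};
  }

Storing an explicit upper bound per segment lets segments leave gaps between one another, which appears to be the point of using compressed-hi for the blocked BCOO encoding exercised by the new foreach_bcoo test.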