diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -170,8 +170,8 @@
 private:
   struct LoopLevelInfo {
     LoopLevelInfo(ArrayRef<size_t> tids, ArrayRef<size_t> dims, Operation *loop,
-                  Value iv, StringAttr loopTag)
-        : tids(tids), dims(dims), loop(loop), iv(iv) {
+                  Block *userBlock, Value iv, StringAttr loopTag)
+        : tids(tids), dims(dims), loop(loop), userCodeBlock(userBlock), iv(iv) {
       // Attached a special tag to loop emitter generated loop.
       if (loopTag)
         loop->setAttr(LoopEmitter::getLoopEmitterLoopAttrName(), loopTag);
@@ -181,8 +181,9 @@
     const llvm::SmallVector<size_t> tids;
     // The corresponding dims for the tensors
     const llvm::SmallVector<size_t> dims;
-    const Operation *loop; // the loop operation
-    const Value iv;        // the induction variable for the loop
+    const Operation *loop;      // the loop operation
+    Block *const userCodeBlock; // the block holding the user-generated code.
+    const Value iv;             // the induction variable for the loop
   };

   /// Linearizes address for dense dimension (i.e., p = (i * d0) + j).
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -82,6 +82,30 @@
   return std::make_pair(v, rem);
 }

+static std::pair<Value, Value>
+genSliceLegitPredicate(OpBuilder &builder, Location loc, Value coord,
+                       SparseTensorEncodingAttr enc, unsigned lvl) {
+  std::pair<Value, Value> trans = fromSliceCoord(builder, loc, coord, enc, lvl);
+  // 1st, coord >= offset (TODO: an unsigned >= 0 comparison seems not to be
+  // folded; skip the check when the offset is zero).
+  auto geOffset =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, coord,
+                                    getSliceOffset(builder, loc, enc, lvl));
+  // 2nd, coord_in_slice < length
+  auto ltLength =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, trans.first,
+                                    getSliceLength(builder, loc, enc, lvl));
+
+  // 3rd, rem == 0; confirmed that (a % 1) will be folded to 0
+  auto fitStride =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, trans.second,
+                                    constantIndex(builder, loc, 0));
+
+  auto pred = builder.create<arith::AndIOp>(loc, geOffset, ltLength);
+  pred = builder.create<arith::AndIOp>(loc, pred, fitStride);
+  return {trans.first, pred};
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse tensor loop emitter class implementations
 //===----------------------------------------------------------------------===//
@@ -296,8 +320,6 @@
   }

   auto enc = getSparseTensorEncoding(tensors[tid].getType());
-  // Use the stride as the step for a dense level of a sparse tensor slices.
-  // TODO: support dynamic slices.
   Value step = constantIndex(builder, loc, 1);
   Value lo = isSparseInput ? pidxs[tid][dim]      // current offset
                            : loopSeqStack.back(); // universal index
@@ -349,31 +371,14 @@
   if (isSparseSlices[tid] && isSparseInput) {
     // For sparse level slices, we need to filter out invalid coordinates that
     // are not included in the slice.
-    std::pair<Value, Value> trans = fromSliceCoord(builder, loc, c, enc, dim);
     SmallVector<Type> types;
     for (Value red : reduc)
       types.push_back(red.getType());

-    // 1st, coord >= offset (TODO: seems unsigned >= 0 won't be folded, skip
-    // the check if the offset is zero).
-    auto geOff =
-        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, c,
-                                      getSliceOffset(builder, loc, enc, dim));
-    // 2nd, coords < length
-    auto ltLen = builder.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ult, trans.first,
-        getSliceLength(builder, loc, enc, dim));
-
-    // 3rd, rem == 0; confirmed that (a % 1) will be folded to 0
-    auto fitStride = builder.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::eq, trans.second,
-        constantIndex(builder, loc, 0));
-
-    auto pred = builder.create<arith::AndIOp>(loc, geOff, ltLen);
-    pred = builder.create<arith::AndIOp>(loc, pred, fitStride);
+    auto [trans, pred] = genSliceLegitPredicate(builder, loc, c, enc, dim);
     bool hasReduc = !types.empty();
-    scf::IfOp ifOp =
-        builder.create<scf::IfOp>(loc, types, pred, /*else*/ hasReduc);
+    scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, pred,
+                                               /*else*/ hasReduc);
     if (hasReduc) {
       // scf.for (a) -> v
       //  %s = scf.if (a) -> v
@@ -388,7 +393,7 @@
     }
     // Set the insert point to matched branch.
     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-    c = trans.first;
+    c = trans;
   }

   assert(c);
@@ -396,7 +401,7 @@
   // NOTE: we can also prepare for next dim here in advance
   // Push the loop into stack
   loopStack.emplace_back(ArrayRef(tid), ArrayRef(dim), loop,
-                         coord[tid][dim], loopTag);
+                         builder.getInsertionBlock(), coord[tid][dim], loopTag);
   // Emit extra locals.
   emitExtraLocalsForTensorsAtDenseDims(builder, loc, tids, dims);
@@ -466,7 +471,7 @@
   // NOTE: we can also prepare for next dim here in advance
   // Push the loop into stack
   loopStack.emplace_back(ArrayRef(tid), ArrayRef(dim), forOp,
-                         coord[tid][dim], nullptr);
+                         builder.getInsertionBlock(), coord[tid][dim], nullptr);
   return forOp;
 }
@@ -532,7 +537,9 @@
   // Generates while body.
   builder.setInsertionPointToStart(&whileOp.getAfter().front());
-  Value min;
+
+  SmallVector<std::pair<Value, unsigned>> slicesPreds;
+  unsigned i = 0;
   for (auto [tid, dim] : llvm::zip(tids, dims)) {
     // Prepares for next level.
     if (isCompressedDLT(dimTypes[tid][dim]) ||
@@ -541,26 +548,73 @@
       Value s = pidxs[tid][dim];
       Value load = genIndexLoad(builder, loc, ptr, s);
       coord[tid][dim] = load;
-      if (!needsUniv) {
+      if (isSparseSlices[tid]) {
+        auto enc = getSparseTensorEncoding(tensors[tid].getType());
+        auto [trans, pred] =
+            genSliceLegitPredicate(builder, loc, load, enc, dim);
+        slicesPreds.emplace_back(pred, i);
+        // Update the coordinate to be relative to the slice.
+        coord[tid][dim] = trans;
+      }
+      i++;
+    }
+  }
+
+  if (!slicesPreds.empty()) {
+    // Skip invalid loop iterations when the slice coordinates are inapplicable.
+    SmallVector<Value> yields(after->getArguments());
+    // Generates a list of selects:
+    //   pidx = in_slice ? pidx : pidx + 1
+    // TODO: instead of always picking pidx + 1, we should set pidx = high to
+    // break out of the loop when the coordinate is larger than the slice size.
+    for (auto [pred, idx] : slicesPreds) {
+      Value nextPidx = builder.create<arith::AddIOp>(
+          loc, yields[idx], constantIndex(builder, loc, 1));
+      yields[idx] =
+          builder.create<arith::SelectOp>(loc, pred, yields[idx], nextPidx);
+    }
+
+    Value pred = slicesPreds.front().first;
+    for (int i = 1, e = slicesPreds.size(); i < e; i++) {
+      pred = builder.create<arith::AndIOp>(loc, pred, slicesPreds[i].first);
+    }
+    auto ifOp = builder.create<scf::IfOp>(loc, types, pred, /*else*/ true);
+    ifOp->setAttr(getLoopEmitterLoopAttrName(),
+                  StringAttr::get(builder.getContext(), "slice"));
+    builder.create<scf::YieldOp>(loc, ifOp->getResults());
+    assert(types.size() == yields.size());
+    // If not all slices are legit, yield the updated pidx values.
+    builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+    builder.create<scf::YieldOp>(loc, yields);
+
+    // If all slices are legit, start the user-generated code.
+    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  }
+
+  Value min;
+  // Finds the minimum coordinate.
+  if (!needsUniv) {
+    for (auto [tid, dim] : llvm::zip(tids, dims)) {
+      if (isCompressedDLT(dimTypes[tid][dim]) ||
+          isSingletonDLT(dimTypes[tid][dim])) {
         if (min) {
           Value cmp = builder.create<arith::CmpIOp>(
-              loc, arith::CmpIPredicate::ult, load, min);
-          min = builder.create<arith::SelectOp>(loc, cmp, load, min);
+              loc, arith::CmpIPredicate::ult, coord[tid][dim], min);
+          min = builder.create<arith::SelectOp>(loc, cmp, coord[tid][dim], min);
         } else {
-          min = load;
+          min = coord[tid][dim];
         }
       }
     }
-  }
-
-  if (needsUniv) {
+  } else {
     assert(!min);
     // Otherwise, universal index is the minimal pidx.
     min = after->getArguments().back();
   }

   // Sets up the loop stack.
-  loopStack.emplace_back(tids, dims, whileOp, min, loopTag);
+  loopStack.emplace_back(tids, dims, whileOp, builder.getInsertionBlock(), min,
+                         loopTag);
   assert(loopStack.size() == loopSeqStack.size());

   // Emits extra locals
@@ -638,6 +692,7 @@
 void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
                               MutableArrayRef<Value> reduc) {
   LoopLevelInfo &loopInfo = loopStack.back();
+  rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);
   auto &dims = loopStack.back().dims;
   auto &tids = loopStack.back().tids;
   auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop);
@@ -718,12 +773,12 @@

 void LoopEmitter::exitCoIterationLoop(OpBuilder &builder, Location loc,
                                       MutableArrayRef<Value> reduc) {
-  auto whileOp = llvm::cast<scf::WhileOp>(loopStack.back().loop);
-  auto &dims = loopStack.back().dims;
-  auto &tids = loopStack.back().tids;
-  Value iv = loopStack.back().iv;
-  // Generation while loop induction at the end.
-  builder.setInsertionPointToEnd(&whileOp.getAfter().front());
+  const LoopLevelInfo &loopInfo = loopStack.back();
+  auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
+  builder.setInsertionPointToEnd(loopInfo.userCodeBlock);
+  auto &dims = loopInfo.dims;
+  auto &tids = loopInfo.tids;
+  Value iv = loopInfo.iv;
   // Finalize the induction. Note that the induction could be performed
   // in the individual if-branches to avoid re-evaluating the conditions.
 // However, that would result in a rather elaborate forest of yield
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1100,6 +1100,11 @@
   if (env.isReduc() || env.isExpand() || env.getInsertionChain()) {
     while (auto ifOp = dyn_cast_or_null<scf::IfOp>(
                builder.getInsertionBlock()->getParentOp())) {
+      if (auto attr = ifOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName());
+          attr && attr == StringAttr::get(ifOp->getContext(), "slice")) {
+        // Break on the scf.if generated for slice filtering.
+        break;
+      }
       unsigned y = 0;
       SmallVector<Value> yields;
       if (env.isReduc()) {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
@@ -1302,9 +1302,9 @@
 // CHECK:           ^bb0(%[[VAL_34:.*]]: index, %[[VAL_35:.*]]: index, %[[VAL_36:.*]]: index, %[[VAL_37:.*]]: f64):
 // CHECK:             %[[VAL_38:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_34]]] : memref<?xindex>
 // CHECK:             %[[VAL_39:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_35]]] : memref<?xindex>
+// CHECK:             %[[VAL_42:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_36]]] : memref<?xindex>
 // CHECK:             %[[VAL_40:.*]] = arith.cmpi ult, %[[VAL_39]], %[[VAL_38]] : index
 // CHECK:             %[[VAL_41:.*]] = arith.select %[[VAL_40]], %[[VAL_39]], %[[VAL_38]] : index
-// CHECK:             %[[VAL_42:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_36]]] : memref<?xindex>
 // CHECK:             %[[VAL_43:.*]] = arith.cmpi ult, %[[VAL_42]], %[[VAL_41]] : index
 // CHECK:             %[[VAL_44:.*]] = arith.select %[[VAL_43]], %[[VAL_42]], %[[VAL_41]] : index
 // CHECK:             %[[VAL_45:.*]] = arith.cmpi eq, %[[VAL_38]], %[[VAL_44]] : index
diff --git a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_2d.mlir
@@ -1130,9 +1130,9 @@
 // CHECK:           ^bb0(%[[VAL_57:.*]]: index, %[[VAL_58:.*]]: index, %[[VAL_59:.*]]: index, %[[VAL_60:.*]]: f32):
 // CHECK:             %[[VAL_61:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_57]]] : memref<?xindex>
 // CHECK:             %[[VAL_62:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_58]]] : memref<?xindex>
+// CHECK:             %[[VAL_65:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_59]]] : memref<?xindex>
 // CHECK:             %[[VAL_63:.*]] = arith.cmpi ult, %[[VAL_62]], %[[VAL_61]] : index
 // CHECK:             %[[VAL_64:.*]] = arith.select %[[VAL_63]], %[[VAL_62]], %[[VAL_61]] : index
-// CHECK:             %[[VAL_65:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_59]]] : memref<?xindex>
 // CHECK:             %[[VAL_66:.*]] = arith.cmpi ult, %[[VAL_65]], %[[VAL_64]] : index
 // CHECK:             %[[VAL_67:.*]] = arith.select %[[VAL_66]], %[[VAL_65]], %[[VAL_64]] : index
 // CHECK:             %[[VAL_68:.*]] = arith.cmpi eq, %[[VAL_61]], %[[VAL_67]] : index
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul_slice.mlir
@@ -0,0 +1,181 @@
+// DEFINE: %{option} = enable-runtime-library=false
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE:   mlir-cpu-runner \
+// DEFINE:   -e entry -entry-point-result=void \
+// DEFINE:   -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext,%mlir_lib_dir/libmlir_runner_utils%shlibext | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+//
+
+// TODO: support lib path.
+
+#DCSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ]
+}>
+
+#DCSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ],
+  slice = [ [0, 4, 1], [0, 8, 1] ]
+}>
+
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ]
+}>
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ [0, 4, 1], [0, 8, 1] ]
+}>
+
+#CSR_SLICE_1 = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ [0, 4, 2], [0, 4, 1] ]
+}>
+
+#DCSR_SLICE_1 = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ],
+  slice = [ [0, 4, 2], [1, 4, 1] ]
+}>
+
+module {
+  func.func private @printMemrefF64(%ptr : tensor<*xf64>)
+  func.func private @printMemref1dF64(%ptr : memref<?xf64>) attributes { llvm.emit_c_interface }
+
+  //
+  // Computes C = A x B with all matrices sparse slices (SpMSpM) in CSR and DCSR.
+  //
+  func.func @matmul1(%A: tensor<4x4xf64, #CSR_SLICE_1>,
+                     %B: tensor<4x4xf64, #DCSR_SLICE_1>) -> tensor<4x4xf64, #CSR> {
+    %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
+    %D = linalg.matmul
+      ins(%A, %B: tensor<4x4xf64, #CSR_SLICE_1>, tensor<4x4xf64, #DCSR_SLICE_1>)
+         outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
+    return %D: tensor<4x4xf64, #CSR>
+  }
+
+  //
+  // Computes C = A x B with all matrices sparse (SpMSpM) in CSR.
+  //
+  func.func @matmul2(%A: tensor<4x8xf64, #CSR_SLICE>,
+                     %B: tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> {
+    %C = bufferization.alloc_tensor() : tensor<4x4xf64, #CSR>
+    %D = linalg.matmul
+      ins(%A, %B: tensor<4x8xf64, #CSR_SLICE>, tensor<8x4xf64, #CSR>)
+         outs(%C: tensor<4x4xf64, #CSR>) -> tensor<4x4xf64, #CSR>
+    return %D: tensor<4x4xf64, #CSR>
+  }
+
+  func.func @matmul3(%A: tensor<4x8xf64, #DCSR_SLICE>,
+                     %B: tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> {
+    %C = bufferization.alloc_tensor() : tensor<4x4xf64, #DCSR>
+    %D = linalg.matmul
+      ins(%A, %B: tensor<4x8xf64, #DCSR_SLICE>, tensor<8x4xf64, #DCSR>)
+         outs(%C: tensor<4x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR>
+    return %D: tensor<4x4xf64, #DCSR>
+  }
+
+  //
+  // Main driver.
+  //
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+    %f0 = arith.constant 0.0 : f64
+
+    %sa = arith.constant dense<[
+        [ 0.0, 2.1, 0.0, 0.0, 0.0, 6.1, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 2.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 ],
+        [ 0.0, 2.1, 0.0, 0.0, 0.0, 6.1, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 2.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0 ]
+    ]> : tensor<8x8xf64>
+    %sb = arith.constant dense<[
+        [ 0.0, 0.0, 0.0, 1.0 ],
+        [ 0.0, 0.0, 2.0, 0.0 ],
+        [ 0.0, 3.0, 0.0, 0.0 ],
+        [ 4.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 0.0, 0.0, 0.0 ],
+        [ 0.0, 5.0, 0.0, 0.0 ],
+        [ 0.0, 0.0, 6.0, 0.0 ],
+        [ 0.0, 0.0, 7.0, 8.0 ]
+    ]> : tensor<8x4xf64>
+    %zero = arith.constant dense<0.0> : tensor<4x4xf64>
+
+    // Convert all these matrices to sparse format.
+ %tmp = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #DCSR> + %a = tensor.extract_slice %tmp[0, 0][4, 8][1, 1] : tensor<8x8xf64, #DCSR> to tensor<4x8xf64, #DCSR_SLICE> + %b = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #DCSR> + + %2 = call @matmul3(%a, %b) + : (tensor<4x8xf64, #DCSR_SLICE>, + tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> + + // DCSR test + // + // CHECK: [0, 30.5, 4.2, 0], + // CHECK-NEXT: [0, 0, 0, 0], + // CHECK-NEXT: [0, 0, 4.6, 0], + // CHECK-NEXT: [0, 0, 7, 8] + // + %c2 = sparse_tensor.convert %2 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64> + %c2u = tensor.cast %c2 : tensor<4x4xf64> to tensor<*xf64> + call @printMemrefF64(%c2u) : (tensor<*xf64>) -> () + + %t1 = sparse_tensor.convert %sa : tensor<8x8xf64> to tensor<8x8xf64, #CSR> + %a1 = tensor.extract_slice %t1[0, 0][4, 8][1, 1] : tensor<8x8xf64, #CSR> to tensor<4x8xf64, #CSR_SLICE> + %b1 = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #CSR> + %3 = call @matmul2(%a1, %b1) + : (tensor<4x8xf64, #CSR_SLICE>, + tensor<8x4xf64, #CSR>) -> tensor<4x4xf64, #CSR> + + // CSR test + // + // CHECK: [0, 30.5, 4.2, 0], + // CHECK-NEXT: [0, 0, 0, 0], + // CHECK-NEXT: [0, 0, 4.6, 0], + // CHECK-NEXT: [0, 0, 7, 8] + // + %c3 = sparse_tensor.convert %3 : tensor<4x4xf64, #CSR> to tensor<4x4xf64> + %c3u = tensor.cast %c3 : tensor<4x4xf64> to tensor<*xf64> + call @printMemrefF64(%c3u) : (tensor<*xf64>) -> () + + // slice x slice + // + // CHECK: [2.3, 0, 0, 0], + // CHECK-NEXT: [6.9, 0, 0, 0], + // CHECK-NEXT: [0, 0, 0, 0], + // CHECK-NEXT: [12.6, 0, 0, 0]] + // + %s1 = tensor.extract_slice %tmp[0, 1][4, 4][2, 1] : tensor<8x8xf64, #DCSR> to tensor<4x4xf64, #DCSR_SLICE_1> + %s2 = tensor.extract_slice %b1[0, 0][4, 4][2, 1] : tensor<8x4xf64, #CSR> to tensor<4x4xf64, #CSR_SLICE_1> + %4 = call @matmul1(%s2, %s1) + : (tensor<4x4xf64, #CSR_SLICE_1>, + tensor<4x4xf64, #DCSR_SLICE_1>) -> tensor<4x4xf64, #CSR> + + %c4 = sparse_tensor.convert %4 : tensor<4x4xf64, #CSR> to tensor<4x4xf64> + %c4u = tensor.cast %c4 : tensor<4x4xf64> to tensor<*xf64> + call @printMemrefF64(%c4u) : (tensor<*xf64>) -> () + + // sparse slices should generate the same result as dense slices + // + // CHECK: [2.3, 0, 0, 0], + // CHECK-NEXT: [6.9, 0, 0, 0], + // CHECK-NEXT: [0, 0, 0, 0], + // CHECK-NEXT: [12.6, 0, 0, 0]] + // + %ds1 = tensor.extract_slice %sa[0, 1][4, 4][2, 1] : tensor<8x8xf64> to tensor<4x4xf64> + %ds2 = tensor.extract_slice %sb[0, 0][4, 4][2, 1] : tensor<8x4xf64> to tensor<4x4xf64> + + %d = bufferization.alloc_tensor() copy(%zero) : tensor<4x4xf64> + %r = linalg.matmul ins(%ds2, %ds1: tensor<4x4xf64>, tensor<4x4xf64>) + outs(%d: tensor<4x4xf64>) -> tensor<4x4xf64> + %du = tensor.cast %r : tensor<4x4xf64> to tensor<*xf64> + call @printMemrefF64(%du) : (tensor<*xf64>) -> () + + return + } +}
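
Reviewer note (not part of the patch): the predicate that `genSliceLegitPredicate` builds from `arith.cmpi`/`arith.andi` ops boils down to simple integer arithmetic over the slice triple `(offset, length, stride)`. The sketch below models that check with plain integers under stated assumptions; the helper name `toSliceCoord` and the example values are hypothetical and exist only for this illustration.

```cpp
// Plain-integer model of the slice-legitimacy check emitted by the patch:
// a stored coordinate belongs to a slice (offset, length, stride) iff it is
// at or past the offset, lands exactly on the stride, and its translated
// coordinate (coord - offset) / stride is below the slice length.
#include <cassert>
#include <cstdint>
#include <optional>

// Returns the slice-relative coordinate if `coord` is inside the slice,
// std::nullopt otherwise. Mirrors the uge/ult/eq comparisons combined by andi.
std::optional<uint64_t> toSliceCoord(uint64_t coord, uint64_t offset,
                                     uint64_t length, uint64_t stride) {
  if (coord < offset)                           // 1st: coord >= offset
    return std::nullopt;
  uint64_t trans = (coord - offset) / stride;   // translated coordinate
  uint64_t rem = (coord - offset) % stride;     // distance from the stride grid
  if (trans >= length)                          // 2nd: coord_in_slice < length
    return std::nullopt;
  if (rem != 0)                                 // 3rd: rem == 0
    return std::nullopt;
  return trans;
}

int main() {
  // Example slice with offset 1, length 4, stride 2 (illustrative values).
  assert(toSliceCoord(5, 1, 4, 2) == 2u); // maps to relative coordinate 2
  assert(!toSliceCoord(0, 1, 4, 2));      // before the offset, filtered out
  assert(!toSliceCoord(4, 1, 4, 2));      // off the stride grid, filtered out
  assert(!toSliceCoord(9, 1, 4, 2));      // beyond the slice length
  return 0;
}
```

In the `scf.for` path the patch wraps the loop body in an `scf.if` guarded by this predicate; in the `scf.while` co-iteration path a failing predicate instead yields pidx + 1 for the offending tensors through the `scf.if` tagged "slice", which the new check in Sparsification.cpp deliberately stops at when propagating reduction and insertion-chain yields.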