diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -573,8 +573,22 @@ rewriter.create(loc, cond, before->getArguments()); Block *after = rewriter.createBlock(&whileOp.getAfter(), {}, noTypes); rewriter.setInsertionPointToStart(after); + + bool hasDenseDim = llvm::any_of( + enc.getDimLevelType(), [](DimLevelType dlt) { return isDenseDLT(dlt); }); + if (hasDenseDim) { + Value elemV = rewriter.create(loc, elemPtr); + /* NOTE(review): despite its name, isZero is assigned from genIsNonzero(); it is the condition of the IfOp (else flag false) whose then-block becomes the insertion point for the loop body emitted by bodyBuilder below. Consider renaming - TODO confirm the helper's polarity. */ Value isZero = genIsNonzero(rewriter, loc, elemV); + scf::IfOp ifOp = rewriter.create(loc, isZero, /*else*/ false); + rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); + } // Callback here to build loop body. bodyBuilder(rewriter, loc, srcIdx, elemPtr); + + // Exit the scope from the IfOp. + if (hasDenseDim) + rewriter.setInsertionPointToEnd(after); + rewriter.create(loc); // Finish generating loop. rewriter.setInsertionPointAfter(whileOp); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" @@ -149,9 +150,11 @@ // TODO: Maybe pick the bitwidth based on input/output tensors (probably the // largest one among them) in the original operation instead of using the // default value. + /* Both bit widths fall back to 0 when encSrc tests false, instead of calling accessors on it unconditionally as the removed line did. */ unsigned pointerBitWidth = encSrc ? 
encSrc.getPointerBitWidth() : 0; + unsigned indexBitWidth = encSrc ? encSrc.getIndexBitWidth() : 0; auto enc = SparseTensorEncodingAttr::get( ctx, dims, AffineMap::getMultiDimIdentityMap(rank, ctx), AffineMap(), - encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth()); + pointerBitWidth, indexBitWidth); return RankedTensorType::get(src.getShape(), src.getElementType(), enc); } @@ -428,10 +431,24 @@ PatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto rtp = op.getType().cast(); - // TODO: Build the output shape if needed. - assert(rtp.hasStaticShape()); - auto rank = rtp.getRank(); size_t conDim = op.getDimension().getZExtValue(); + /* Collects one size Value per dynamic dimension of the result type rtp; the list is handed to the create() that produces cooBuffer below. */ SmallVector dynSizes; + if (!rtp.hasStaticShape()) { + ArrayRef rShape = rtp.getShape(); + for (const auto &d : llvm::enumerate(rShape)) { + if (d.value() == ShapedType::kDynamicSize) { + Value v = + createOrFoldDimOp(rewriter, loc, op.getOperand(0), d.index()); + /* NOTE(review): the result of this create() is not bound to anything in the visible hunk, and v above already holds createOrFoldDimOp for the same operand and dimension - looks like a leftover; confirm and drop. */ rewriter.create(loc, op.getOperand(0), d.index()); + /* NOTE(review): this loop folds the same dimension of every remaining operand into v for each dynamic result dimension, with no check that d.index() equals conDim; for a dimension other than the concatenation one the operands' own size would be expected rather than an accumulation - TODO confirm intent. */ for (const auto &opnd : op.getOperands().drop_front()) { + Value t = createOrFoldDimOp(rewriter, loc, opnd, d.index()); + v = rewriter.create(loc, v, t); + } + dynSizes.push_back(v); + } + } + } + // %t = concatenate %s1, %s2, %s3 {dim = 1} // ==> // %tmp = bufferization.alloc_tensor : unordered COO @@ -441,13 +458,11 @@ // %t = sparse_tensor.cast %tmp auto cooTp = getUnorderedCOOFromType(rtp); auto cooBuffer = - rewriter.create(loc, cooTp, ValueRange()).getResult(); - + rewriter.create(loc, cooTp, dynSizes).getResult(); + auto rank = rtp.getRank(); Value offset = constantIndex(rewriter, loc, 0); ForeachOp foreachOp; for (Value input : op.getInputs()) { - // Builds the indexing map. - // Build a for op for each input tensor to append new values into the // output tensor. 
foreachOp = rewriter.create( @@ -456,16 +471,22 @@ ValueRange reduc) { SmallVector indices; for (int64_t i = 0; i < rank; i++) { - uint64_t dim = - toStoredDim(getSparseTensorEncoding(input.getType()), i); - Value idx = args[dim]; + Value idx = args[i]; if (i == static_cast(conDim)) // transform coordinates on matching dim idx = builder.create(loc, idx, offset); indices.push_back(idx); } - auto t = builder.create(loc, v, reduc.front(), indices); - builder.create(loc, t); + /* The element v is guarded by genIsNonzero(): the then-region builds t from v, reduc.front() and indices; the else-region's create() takes reduc.front() unchanged; the final create() after the IfOp takes ifOp.getResult(0). NOTE(review): this callback switches between builder and rewriter for op creation and insertion-point changes; whether both denote the same builder is not visible in this hunk - TODO confirm, or use builder throughout. */ Value cond = genIsNonzero(rewriter, loc, v); + scf::IfOp ifOp = builder.create( + loc, TypeRange(reduc.front().getType()), cond, /*else*/ true); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + Value t = builder.create(loc, v, reduc.front(), indices); + rewriter.create(loc, t); + rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + rewriter.create(loc, reduc.front()); + rewriter.setInsertionPointAfter(ifOp); + rewriter.create(loc, ifOp.getResult(0)); }); // Accumulates the offset. Note that only static-shaped inputs are allowed // by concatenate op verifier, which saves us from computing the offset @@ -661,7 +682,7 @@ for (uint64_t i = 0; i < rank; i++) { uint64_t orgDim = toOrigDim(encSrc, i); xs[toStoredDim(encDst, orgDim)] = rewriter.create( - loc, indTp, src, rewriter.getIndexAttr(orgDim)); + /* The index attribute is now the loop variable i rather than orgDim; the destination slot in xs is still selected by toStoredDim(encDst, orgDim). */ loc, indTp, src, rewriter.getIndexAttr(i)); } // Retrieve NNZ. 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/concatenate.mlir @@ -1,4 +1,11 @@ -// RUN: mlir-opt %s --sparse-compiler | \ +// RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=true | \ +// RUN: mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s +// +// Do the same run, but now with direct IR generation. +// RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=false | \ // RUN: mlir-cpu-runner \ // RUN: -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ @@ -169,9 +176,12 @@ %v = vector.transfer_read %c[%c0, %c0], %du: tensor<9x4xf64>, vector<9x4xf64> vector.print %v : vector<9x4xf64> + %n = sparse_tensor.number_of_entries %A : tensor<9x4xf64, #MAT_C_C> + vector.print %n : index + %1 = sparse_tensor.values %A : tensor<9x4xf64, #MAT_C_C> to memref - %2 = vector.transfer_read %1[%c0], %du: memref, vector<36xf64> - vector.print %2 : vector<36xf64> + %2 = vector.transfer_read %1[%c0], %du: memref, vector<18xf64> + vector.print %2 : vector<18xf64> return } @@ -184,9 +194,12 @@ %v = vector.transfer_read %c[%c0, %c0], %du: tensor<9x4xf64>, vector<9x4xf64> vector.print %v : vector<9x4xf64> + %n = sparse_tensor.number_of_entries %A : tensor<9x4xf64, #MAT_C_C_P> + vector.print %n : index + %1 = sparse_tensor.values %A : tensor<9x4xf64, #MAT_C_C_P> to memref - %2 = vector.transfer_read %1[%c0], %du: memref, vector<36xf64> - vector.print %2 : vector<36xf64> + %2 = vector.transfer_read %1[%c0], %du: memref, vector<18xf64> + vector.print %2 : vector<18xf64> return } @@ -209,9 +222,12 @@ %v = vector.transfer_read %c[%c0, %c0], %du: tensor<4x9xf64>, vector<4x9xf64> vector.print %v : 
vector<4x9xf64> + %n = sparse_tensor.number_of_entries %A : tensor<4x9xf64, #MAT_C_C> + vector.print %n : index + %1 = sparse_tensor.values %A : tensor<4x9xf64, #MAT_C_C> to memref - %2 = vector.transfer_read %1[%c0], %du: memref, vector<36xf64> - vector.print %2 : vector<36xf64> + %2 = vector.transfer_read %1[%c0], %du: memref, vector<18xf64> + vector.print %2 : vector<18xf64> return } @@ -224,9 +240,12 @@ %v = vector.transfer_read %c[%c0, %c0], %du: tensor, vector<4x9xf64> vector.print %v : vector<4x9xf64> + %n = sparse_tensor.number_of_entries %A : tensor + vector.print %n : index + %1 = sparse_tensor.values %A : tensor to memref - %2 = vector.transfer_read %1[%c0], %du: memref, vector<36xf64> - vector.print %2 : vector<36xf64> + %2 = vector.transfer_read %1[%c0], %du: memref, vector<18xf64> + vector.print %2 : vector<18xf64> return } @@ -239,9 +258,12 @@ %v = vector.transfer_read %c[%c0, %c0], %du: tensor<4x9xf64>, vector<4x9xf64> vector.print %v : vector<4x9xf64> + %n = sparse_tensor.number_of_entries %A : tensor<4x9xf64, #MAT_C_C_P> + vector.print %n : index + %1 = sparse_tensor.values %A : tensor<4x9xf64, #MAT_C_C_P> to memref - %2 = vector.transfer_read %1[%c0], %du: memref, vector<36xf64> - vector.print %2 : vector<36xf64> + %2 = vector.transfer_read %1[%c0], %du: memref, vector<18xf64> + vector.print %2 : vector<18xf64> return } @@ -297,7 +319,8 @@ %sm44dc_dyn = sparse_tensor.convert %m44 : tensor<4x4xf64> to tensor // CHECK: ( ( 1, 0, 3, 0 ), ( 0, 2, 0, 0 ), ( 1, 0, 1, 1 ), ( 0, 0.5, 0, 0 ), ( 1, 5, 2, 0 ), ( 0, 0, 1.5, 1 ), ( 0, 3.5, 0, 0 ), ( 1, 5, 2, 0 ), ( 1, 0.5, 0, 0 ) ) - // CHECK-NEXT: ( 1, 3, 2, 1, 0, 1, 1, 0, 0.5, 0, 0, 1, 5, 2, 0, 1.5, 1, 3.5, 1, 5, 2, 1, 0.5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: 18 + // CHECK-NEXT: ( 1, 3, 2, 1, 1, 1, 0.5, 1, 5, 2, 1.5, 1, 3.5, 1, 5, 2, 1, 0.5 ) %0 = call @concat_sparse_sparse(%sm24cc, %sm34cd, %sm44dc) : (tensor<2x4xf64, #MAT_C_C>, tensor<3x4xf64, #MAT_C_D>, tensor<4x4xf64, 
#MAT_D_C>) -> tensor<9x4xf64, #MAT_C_C> call @dump_mat_9x4(%0) : (tensor<9x4xf64, #MAT_C_C>) -> () @@ -308,7 +331,8 @@ call @dump_mat_dense_9x4(%1) : (tensor<9x4xf64>) -> () // CHECK-NEXT: ( ( 1, 0, 3, 0 ), ( 0, 2, 0, 0 ), ( 1, 0, 1, 1 ), ( 0, 0.5, 0, 0 ), ( 1, 5, 2, 0 ), ( 0, 0, 1.5, 1 ), ( 0, 3.5, 0, 0 ), ( 1, 5, 2, 0 ), ( 1, 0.5, 0, 0 ) ) - // CHECK-NEXT: ( 1, 3, 2, 1, 0, 1, 1, 0, 0.5, 0, 0, 1, 5, 2, 0, 1.5, 1, 3.5, 1, 5, 2, 1, 0.5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: 18 + // CHECK-NEXT: ( 1, 3, 2, 1, 1, 1, 0.5, 1, 5, 2, 1.5, 1, 3.5, 1, 5, 2, 1, 0.5 ) %2 = call @concat_mix_sparse(%m24, %sm34cd, %sm44dc) : (tensor<2x4xf64>, tensor<3x4xf64, #MAT_C_D>, tensor<4x4xf64, #MAT_D_C>) -> tensor<9x4xf64, #MAT_C_C> call @dump_mat_9x4(%2) : (tensor<9x4xf64, #MAT_C_C>) -> () @@ -319,7 +343,8 @@ call @dump_mat_dense_9x4(%3) : (tensor<9x4xf64>) -> () // CHECK-NEXT: ( ( 1, 0, 3, 0 ), ( 0, 2, 0, 0 ), ( 1, 0, 1, 1 ), ( 0, 0.5, 0, 0 ), ( 1, 5, 2, 0 ), ( 0, 0, 1.5, 1 ), ( 0, 3.5, 0, 0 ), ( 1, 5, 2, 0 ), ( 1, 0.5, 0, 0 ) ) - // CHECK-NEXT: ( 1, 1, 0, 1, 1, 1, 2, 0, 0.5, 5, 3.5, 5, 0.5, 3, 1, 0, 2, 1.5, 2, 1, 0, 0, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: 18 + // CHECK-NEXT: ( 1, 1, 1, 1, 1, 2, 0.5, 5, 3.5, 5, 0.5, 3, 1, 2, 1.5, 2, 1, 1 ) %4 = call @concat_sparse_sparse_perm(%sm24ccp, %sm34cd, %sm44dc) : (tensor<2x4xf64, #MAT_C_C_P>, tensor<3x4xf64, #MAT_C_D>, tensor<4x4xf64, #MAT_D_C>) -> tensor<9x4xf64, #MAT_C_C_P> call @dump_mat_perm_9x4(%4) : (tensor<9x4xf64, #MAT_C_C_P>) -> () @@ -330,7 +355,8 @@ call @dump_mat_dense_9x4(%5) : (tensor<9x4xf64>) -> () // CHECK-NEXT: ( ( 1, 0, 3, 0 ), ( 0, 2, 0, 0 ), ( 1, 0, 1, 1 ), ( 0, 0.5, 0, 0 ), ( 1, 5, 2, 0 ), ( 0, 0, 1.5, 1 ), ( 0, 3.5, 0, 0 ), ( 1, 5, 2, 0 ), ( 1, 0.5, 0, 0 ) ) - // CHECK-NEXT: ( 1, 3, 2, 1, 0, 1, 1, 0, 0.5, 0, 0, 1, 5, 2, 0, 1.5, 1, 3.5, 1, 5, 2, 1, 0.5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: 18 + // CHECK-NEXT: ( 1, 3, 2, 1, 
1, 1, 0.5, 1, 5, 2, 1.5, 1, 3.5, 1, 5, 2, 1, 0.5 ) %6 = call @concat_mix_sparse_perm(%m24, %sm34cdp, %sm44dc) : (tensor<2x4xf64>, tensor<3x4xf64, #MAT_C_D_P>, tensor<4x4xf64, #MAT_D_C>) -> tensor<9x4xf64, #MAT_C_C> call @dump_mat_9x4(%6) : (tensor<9x4xf64, #MAT_C_C>) -> () @@ -341,7 +367,8 @@ call @dump_mat_dense_9x4(%7) : (tensor<9x4xf64>) -> () // CHECK-NEXT: ( ( 1, 0, 1, 0, 1, 0, 0, 1.5, 1 ), ( 3.1, 0, 1, 0, 0.5, 0, 3.5, 0, 0 ), ( 0, 2, 0, 0, 1, 1, 5, 2, 0 ), ( 0, 0, 5, 2, 0, 1, 0.5, 0, 0 ) ) - // CHECK-NEXT: ( 1, 1, 0, 1, 1.5, 1, 3.1, 1, 0, 0.5, 3.5, 2, 0, 0, 1, 1, 5, 2, 5, 2, 0, 1, 0.5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: 18 + // CHECK-NEXT: ( 1, 1, 1, 1.5, 1, 3.1, 1, 0.5, 3.5, 2, 1, 1, 5, 2, 5, 2, 1, 0.5 ) %8 = call @concat_sparse_sparse_dim1(%sm42cc, %sm43cd, %sm44dc) : (tensor<4x2xf64, #MAT_C_C>, tensor<4x3xf64, #MAT_C_D>, tensor<4x4xf64, #MAT_D_C>) -> tensor<4x9xf64, #MAT_C_C> call @dump_mat_4x9(%8) : (tensor<4x9xf64, #MAT_C_C>) -> () @@ -352,7 +379,8 @@ call @dump_mat_dense_4x9(%9) : (tensor<4x9xf64>) -> () // CHECK-NEXT: ( ( 1, 0, 1, 0, 1, 0, 0, 1.5, 1 ), ( 3.1, 0, 1, 0, 0.5, 0, 3.5, 0, 0 ), ( 0, 2, 0, 0, 1, 1, 5, 2, 0 ), ( 0, 0, 5, 2, 0, 1, 0.5, 0, 0 ) ) - // CHECK-NEXT: ( 1, 1, 0, 1, 1.5, 1, 3.1, 1, 0, 0.5, 3.5, 2, 0, 0, 1, 1, 5, 2, 5, 2, 0, 1, 0.5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: 18 + // CHECK-NEXT: ( 1, 1, 1, 1.5, 1, 3.1, 1, 0.5, 3.5, 2, 1, 1, 5, 2, 5, 2, 1, 0.5 ) %10 = call @concat_mix_sparse_dim1(%m42, %sm43cd, %sm44dc) : (tensor<4x2xf64>, tensor<4x3xf64, #MAT_C_D>, tensor<4x4xf64, #MAT_D_C>) -> tensor<4x9xf64, #MAT_C_C> call @dump_mat_4x9(%10) : (tensor<4x9xf64, #MAT_C_C>) -> () @@ -363,7 +391,8 @@ call @dump_mat_dense_4x9(%11) : (tensor<4x9xf64>) -> () // CHECK-NEXT: ( ( 1, 0, 1, 0, 1, 0, 0, 1.5, 1 ), ( 3.1, 0, 1, 0, 0.5, 0, 3.5, 0, 0 ), ( 0, 2, 0, 0, 1, 1, 5, 2, 0 ), ( 0, 0, 5, 2, 0, 1, 0.5, 0, 0 ) ) - // CHECK-NEXT: ( 1, 3.1, 2, 1, 1, 0, 5, 0, 0, 0, 2, 1, 0.5, 1, 0, 1, 1, 
3.5, 5, 0.5, 1.5, 2, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: 18 + // CHECK-NEXT: ( 1, 3.1, 2, 1, 1, 5, 2, 1, 0.5, 1, 1, 1, 3.5, 5, 0.5, 1.5, 2, 1 ) %12 = call @concat_sparse_sparse_perm_dim1(%sm42ccp, %sm43cd, %sm44dc) : (tensor<4x2xf64, #MAT_C_C_P>, tensor<4x3xf64, #MAT_C_D>, tensor<4x4xf64, #MAT_D_C>) -> tensor<4x9xf64, #MAT_C_C_P> call @dump_mat_perm_4x9(%12) : (tensor<4x9xf64, #MAT_C_C_P>) -> () @@ -374,7 +403,8 @@ call @dump_mat_dense_4x9(%13) : (tensor<4x9xf64>) -> () // CHECK-NEXT: ( ( 1, 0, 1, 0, 1, 0, 0, 1.5, 1 ), ( 3.1, 0, 1, 0, 0.5, 0, 3.5, 0, 0 ), ( 0, 2, 0, 0, 1, 1, 5, 2, 0 ), ( 0, 0, 5, 2, 0, 1, 0.5, 0, 0 ) ) - // CHECK-NEXT: ( 1, 1, 0, 1, 1.5, 1, 3.1, 1, 0, 0.5, 3.5, 2, 0, 0, 1, 1, 5, 2, 5, 2, 0, 1, 0.5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: 18 + // CHECK-NEXT: ( 1, 1, 1, 1.5, 1, 3.1, 1, 0.5, 3.5, 2, 1, 1, 5, 2, 5, 2, 1, 0.5 ) %14 = call @concat_mix_sparse_perm_dim1(%m42, %sm43cdp, %sm44dc) : (tensor<4x2xf64>, tensor<4x3xf64, #MAT_C_D_P>, tensor<4x4xf64, #MAT_D_C>) -> tensor<4x9xf64, #MAT_C_C> call @dump_mat_4x9(%14) : (tensor<4x9xf64, #MAT_C_C>) -> () @@ -385,7 +415,8 @@ call @dump_mat_dense_4x9(%15) : (tensor<4x9xf64>) -> () // CHECK-NEXT: ( ( 1, 0, 1, 0, 1, 0, 0, 1.5, 1 ), ( 3.1, 0, 1, 0, 0.5, 0, 3.5, 0, 0 ), ( 0, 2, 0, 0, 1, 1, 5, 2, 0 ), ( 0, 0, 5, 2, 0, 1, 0.5, 0, 0 ) ) - // CHECK-NEXT: ( 1, 1, 0, 1, 1.5, 1, 3.1, 1, 0, 0.5, 3.5, 2, 0, 0, 1, 1, 5, 2, 5, 2, 0, 1, 0.5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: 18 + // CHECK-NEXT: ( 1, 1, 1, 1.5, 1, 3.1, 1, 0.5, 3.5, 2, 1, 1, 5, 2, 5, 2, 1, 0.5 ) %16 = call @concat_mix_sparse_dyn(%m42, %sm43cd, %sm44dc) : (tensor<4x2xf64>, tensor<4x3xf64, #MAT_C_D>, tensor<4x4xf64, #MAT_D_C>) -> tensor call @dump_mat_dyn(%16) : (tensor) -> ()