diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -195,16 +195,39 @@ llvm_unreachable("Unknown element type"); } +/// Generates the code to read the value from tensor[ivs], and conditionally +/// stores the indices ivs to the memory in ind. The generated code looks like +/// the following and the insertion point after this routine is inside the +/// if-then branch behind the assignment to ind. This is to ensure that the +/// addEltX call generated after is inside the if-then branch. +/// if (tensor[ivs]!=0) { +/// ind = ivs +static Value genIndexAndValueForDense(ConversionPatternRewriter &rewriter, + Operation *op, Type eltType, Value tensor, + Value ind, ValueRange ivs) { + Location loc = op->getLoc(); + Value val = rewriter.create(loc, tensor, ivs); + Value cond = genIsNonzero(rewriter, loc, eltType, val); + scf::IfOp ifOp = rewriter.create(loc, cond, /*else*/ false); + rewriter.setInsertionPointToStart(&ifOp.thenRegion().front()); + unsigned i = 0; + for (auto iv : ivs) { + Value idx = rewriter.create(loc, rewriter.getIndexAttr(i++)); + rewriter.create(loc, iv, ind, idx); + } + return val; +} + /// Generates a call that adds one element to a coordinate scheme. /// In particular, this generates code like the following: /// val = a[i1,..,ik]; /// if val != 0 /// t->add(val, [i1,..,ik], [p1,..,pk]); static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op, - Value ptr, Value tensor, Value ind, Value perm, - ValueRange ivs) { + Type eltType, Value ptr, Value val, Value ind, + Value perm) { + Location loc = op->getLoc(); StringRef name; - Type eltType = tensor.getType().cast().getElementType(); if (eltType.isF64()) name = "addEltF64"; else if (eltType.isF32()) @@ -219,16 +242,6 @@ name = "addEltI8"; else llvm_unreachable("Unknown element type"); - Location loc = op->getLoc(); - Value val = rewriter.create(loc, tensor, ivs); - Value cond = genIsNonzero(rewriter, loc, eltType, val); - scf::IfOp ifOp = rewriter.create(loc, cond, /*else*/ false); - rewriter.setInsertionPointToStart(&ifOp.thenRegion().front()); - unsigned i = 0; - for (auto iv : ivs) { - Value idx = rewriter.create(loc, rewriter.getIndexAttr(i++)); - rewriter.create(loc, iv, ind, idx); - } SmallVector params; params.push_back(ptr); params.push_back(val); @@ -240,6 +253,41 @@ params); } +/// If the tensor is a sparse constant, generates and returns the pair of +/// the constants for the indices and the values. +static Optional> +genSplitSparseConstant(ConversionPatternRewriter &rewriter, ConvertOp op, + Value tensor) { + if (auto constOp = tensor.getDefiningOp()) { + if (auto attr = constOp.value().dyn_cast()) { + Location loc = op->getLoc(); + DenseElementsAttr indicesAttr = attr.getIndices(); + Value indices = rewriter.create(loc, indicesAttr); + DenseElementsAttr valuesAttr = attr.getValues(); + Value values = rewriter.create(loc, valuesAttr); + return std::make_pair(indices, values); + } + } + return {}; +} + +/// Generates the code to copy the index at indices[ivs] to ind, and return +/// the value at value[ivs]. +static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter, + Operation *op, Value indices, + Value values, Value ind, ValueRange ivs, + unsigned rank) { + Location loc = op->getLoc(); + for (unsigned i = 0; i < rank; i++) { + Value idx = rewriter.create(loc, rewriter.getIndexAttr(i)); + Value val = rewriter.create(loc, indices, + ValueRange{ivs[0], idx}); + val = rewriter.create(loc, val, rewriter.getIndexType()); + rewriter.create(loc, val, ind, idx); + } + return rewriter.create(loc, values, ivs[0]); +} + //===----------------------------------------------------------------------===// // Conversion rules. //===----------------------------------------------------------------------===// @@ -328,15 +376,26 @@ // TODO: sparse => dense return failure(); } - // This is a dense => sparse conversion, which is handled as follows: + // This is a dense => sparse conversion or a sparse constant in COO => + // sparse conversion, which is handled as follows: // t = newSparseCOO() + // ...code to fill the COO tensor t... + // s = newSparseTensor(t) + // + // To fill the COO tensor from a dense tensor: // for i1 in dim1 // .. // for ik in dimk // val = a[i1,..,ik] // if val != 0 // t->add(val, [i1,..,ik], [p1,..,pk]) - // s = newSparseTensor(t) + // + // To fill the COO tensor from a sparse constant in COO format: + // for i in range(NNZ) + // val = values[i] + // [i1,..,ik] = indices[i] + // t->add(val, [i1,..,ik], [p1,..,pk]) + // // Note that the dense tensor traversal code is actually implemented // using MLIR IR to avoid having to expose too much low-level // memref traversal details to the runtime support library. @@ -349,7 +408,6 @@ MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType()); Value perm; Value ptr = genNewCall(rewriter, op, encDst, 2, perm); - Value tensor = operands[0]; Value arg = rewriter.create( loc, rewriter.getIndexAttr(shape.getRank())); Value ind = rewriter.create(loc, memTp, ValueRange{arg}); @@ -358,16 +416,37 @@ SmallVector st; Value zero = rewriter.create(loc, rewriter.getIndexAttr(0)); Value one = rewriter.create(loc, rewriter.getIndexAttr(1)); - for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) { + Value tensor = operands[0]; + auto indicesValues = genSplitSparseConstant(rewriter, op, tensor); + Value indices; + Value values; + if (indicesValues.hasValue()) { + indices = indicesValues->first; + values = indicesValues->second; lo.push_back(zero); - hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, tensor, i)); + hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, values, 0)); st.push_back(one); + } else { + for (unsigned i = 0, rank = shape.getRank(); i < rank; i++) { + lo.push_back(zero); + hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, tensor, i)); + st.push_back(one); + } } + Type eltType = shape.getElementType(); + unsigned rank = shape.getRank(); scf::buildLoopNest(rewriter, op.getLoc(), lo, hi, st, {}, [&](OpBuilder &builder, Location loc, ValueRange ivs, ValueRange args) -> scf::ValueVector { - genAddEltCall(rewriter, op, ptr, tensor, ind, perm, - ivs); + Value val; + if (indicesValues.hasValue()) + val = genIndexAndValueForSparse( + rewriter, op, indices, values, ind, ivs, rank); + else + val = genIndexAndValueForDense(rewriter, op, eltType, + tensor, ind, ivs); + genAddEltCall(rewriter, op, eltType, ptr, val, ind, + perm); return {}; }); rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, ptr)); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -114,8 +114,8 @@ }); // The following operations and dialects may be introduced by the // rewriting rules, and are therefore marked as legal. - target.addLegalOp(); + target.addLegalOp(); target.addLegalDialect(); // Populate with rules and apply rewriting rules. diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir --- a/mlir/test/Dialect/SparseTensor/conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion.mlir @@ -182,6 +182,45 @@ return %0 : tensor<2x4xf64, #SparseMatrix> } +#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }> + +// CHECK-LABEL: func @entry() -> !llvm.ptr { +// CHECK: %[[C1:.*]] = constant 1 : i32 +// CHECK: %[[Offset:.*]] = constant dense<[0, 1]> : tensor<2xi64> +// CHECK: %[[Dims:.*]] = constant dense<[8, 7]> : tensor<2xi64> +// CHECK: %[[Base:.*]] = constant dense<[0, 1]> : tensor<2xi8> +// CHECK: %[[I2:.*]] = constant 2 : index +// CHECK: %[[SparseV:.*]] = constant dense<[1.000000e+00, 5.000000e+00]> : tensor<2xf32> +// CHECK: %[[SparseI:.*]] = constant dense<{{\[\[}}0, 0], [1, 6]]> : tensor<2x2xi64> +// CHECK: %[[I1:.*]] = constant 1 : index +// CHECK: %[[I0:.*]] = constant 0 : index +// CHECK: %[[C2:.*]] = constant 2 : i32 +// CHECK: %[[BaseD:.*]] = tensor.cast %[[Base]] : tensor<2xi8> to tensor +// CHECK: %[[DimsD:.*]] = tensor.cast %[[Dims]] : tensor<2xi64> to tensor +// CHECK: %[[OffsetD:.*]] = tensor.cast %[[Offset]] : tensor<2xi64> to tensor +// CHECK: %[[TCOO:.*]] = call @newSparseTensor(%[[BaseD]], %[[DimsD]], %[[OffsetD]], %{{.*}}, %{{.*}}, %{{.*}}, %[[C2]], %{{.}}) +// CHECK: %[[Index:.*]] = memref.alloca() : memref<2xindex> +// CHECK: %[[IndexD:.*]] = memref.cast %[[Index]] : memref<2xindex> to memref +// CHECK: scf.for %[[IV:.*]] = %[[I0]] to %[[I2]] step %[[I1]] { +// CHECK: %[[VAL0:.*]] = tensor.extract %[[SparseI]]{{\[}}%[[IV]], %[[I0]]] : tensor<2x2xi64> +// CHECK: %[[VAL1:.*]] = index_cast %[[VAL0]] : i64 to index +// CHECK: memref.store %[[VAL1]], %[[Index]]{{\[}}%[[I0]]] : memref<2xindex> +// CHECK: %[[VAL2:.*]] = tensor.extract %[[SparseI]]{{\[}}%[[IV]], %[[I1]]] : tensor<2x2xi64> +// CHECK: %[[VAL3:.*]] = index_cast %[[VAL2]] : i64 to index +// CHECK: memref.store %[[VAL3]], %[[Index]]{{\[}}%[[I1]]] : memref<2xindex> +// CHECK: %[[VAL4:.*]] = tensor.extract %[[SparseV]]{{\[}}%[[IV]]] : tensor<2xf32> +// CHECK: call @addEltF32(%[[TCOO]], %[[VAL4]], %[[IndexD]], %[[OffsetD]]) +// CHECK: } +// CHECK: %[[T:.*]] = call @newSparseTensor(%[[BaseD]], %[[DimsD]], %[[OffsetD]], %{{.*}}, %{{.*}}, %[[C1]], %{{.*}}) +// CHECK: return %[[T]] : !llvm.ptr +func @entry() -> tensor<8x7xf32, #CSR>{ + // Initialize a tensor. + %0 = constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32> + // Convert the tensor to a sparse tensor. + %1 = sparse_tensor.convert %0 : tensor<8x7xf32> to tensor<8x7xf32, #CSR> + return %1 : tensor<8x7xf32, #CSR> +} + // CHECK-LABEL: func @sparse_convert_3d( // CHECK-SAME: %[[A:.*]]: tensor) -> !llvm.ptr // CHECK-DAG: %[[C0:.*]] = constant 0 : index diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse-constant_to_sparse_tensor.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse-constant_to_sparse_tensor.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse-constant_to_sparse_tensor.mlir @@ -0,0 +1,51 @@ +// RUN: mlir-opt %s \ +// RUN: --sparsification --sparse-tensor-conversion \ +// RUN: --convert-vector-to-scf --convert-scf-to-std \ +// RUN: --func-bufferize --tensor-constant-bufferize --tensor-bufferize \ +// RUN: --std-bufferize --finalizing-bufferize \ +// RUN: --convert-vector-to-llvm --convert-memref-to-llvm --convert-std-to-llvm --reconcile-unrealized-casts | \ +// RUN: mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +#Tensor1 = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed", "compressed"] +}> + +// +// Integration tests for conversions from sparse constants to sparse tensors. +// +module { + func @entry() { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %d0 = constant 0.0 : f64 + + // A tensor in COO format. + %ti = constant sparse<[[0, 0], [0, 7], [1, 2], [4, 2], [5, 3], [6, 4], [6, 6], [9, 7]], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]> : tensor<10x8xf64> + + // Convert the tensor in COO format to a sparse tensor with annotation #Tensor1. + %ts = sparse_tensor.convert %ti : tensor<10x8xf64> to tensor<10x8xf64, #Tensor1> + + // CHECK: ( 0, 1, 4, 5, 6, 9 ) + %i0 = sparse_tensor.indices %ts, %c0 : tensor<10x8xf64, #Tensor1> to memref + %i0r = vector.transfer_read %i0[%c0], %c0: memref, vector<6xindex> + vector.print %i0r : vector<6xindex> + + // CHECK: ( 0, 7, 2, 2, 3, 4, 6, 7 ) + %i1 = sparse_tensor.indices %ts, %c1 : tensor<10x8xf64, #Tensor1> to memref + %i1r = vector.transfer_read %i1[%c0], %c0: memref, vector<8xindex> + vector.print %i1r : vector<8xindex> + + // CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8 ) + %v = sparse_tensor.values %ts : tensor<10x8xf64, #Tensor1> to memref + %vr = vector.transfer_read %v[%c0], %d0: memref, vector<8xf64> + vector.print %vr : vector<8xf64> + + return + } +} +