diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -58,9 +58,9 @@ let description = [{ Returns the pointers array of the sparse storage scheme at the given dimension for the given sparse tensor. This is similar to the - `buffer_cast` operation in the sense that it provides a bridge + `memref.buffer_cast` operation in the sense that it provides a bridge between a tensor world view and a bufferized world view. Unlike the - `buffer_cast` operation, however, this sparse operation actually + `memref.buffer_cast` operation, however, this sparse operation actually lowers into a call into a support library to obtain access to the pointers array. @@ -82,9 +82,9 @@ let description = [{ Returns the indices array of the sparse storage scheme at the given dimension for the given sparse tensor. This is similar to the - `buffer_cast` operation in the sense that it provides a bridge + `memref.buffer_cast` operation in the sense that it provides a bridge between a tensor world view and a bufferized world view. Unlike the - `buffer_cast` operation, however, this sparse operation actually + `memref.buffer_cast` operation, however, this sparse operation actually lowers into a call into a support library to obtain access to the indices array. @@ -106,9 +106,9 @@ let description = [{ Returns the values array of the sparse storage scheme for the given sparse tensor, independent of the actual dimension. This is similar to - the `buffer_cast` operation in the sense that it provides a bridge + the `memref.buffer_cast` operation in the sense that it provides a bridge between a tensor world view and a bufferized world view.
Unlike the - `buffer_cast` operation, however, this sparse operation actually + `memref.buffer_cast` operation, however, this sparse operation actually lowers into a call into a support library to obtain access to the values array. @@ -121,4 +121,34 @@ let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)"; } +def SparseTensor_ToTensorOp : SparseTensor_Op<"tensor", [NoSideEffect]>, + Arguments<(ins Variadic>:$memrefs)>, + Results<(outs AnyTensor:$result)> { + let summary = "Reconstructs tensor from array(s)"; + let description = [{ + Reconstructs the sparse tensor from the sparse storage scheme array(s). + This is similar to the `memref.tensor_load` operation in the sense that it + provides a bridge between a bufferized world view and a tensor world + view. Unlike the `memref.tensor_load` operation, however, this sparse + operation is used only temporarily to maintain a correctly typed + intermediate representation during progressive bufferization. Eventually + the operation is folded away. + + The input arrays are defined unambiguously by the sparsity annotations + (pointers and indices for overhead storage in every compressed dimension, + followed by one final values array). Currently, the verifier only accepts + an all-dense annotated result reconstructed from a single values array.
+ + Examples: + + ```mlir + %1 = sparse_tensor.tensor %0 : memref to tensor<64x64xf64, #Dense> + + %3 = sparse_tensor.tensor %0, %1, %2 : + memref, memref, memref to tensor<10x10xf32, #CSR> + ``` + }]; + let assemblyFormat = "$memrefs attr-dict `:` type($memrefs) `to` type($result)"; +} + #endif // SPARSETENSOR_OPS diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -243,6 +243,21 @@ return success(); } +// TODO: generalize this beyond all-dense linearized "sparse" tensors +static LogicalResult verify(ToTensorOp op) { + if (op.getNumOperands() != 1) + return op.emitError("expected single values array"); + if (auto e = getSparseTensorEncoding(op.result().getType())) { + auto dlt = e.getDimLevelType(); + for (unsigned i = 0, sz = dlt.size(); i < sz; i++) { + if (dlt[i] != SparseTensorEncodingAttr::DimLevelType::Dense) + return op.emitError("unexpected non-dense dimension"); + } + return success(); + } + return op.emitError("expected a sparse tensor as result"); +} + //===----------------------------------------------------------------------===// // TensorDialect Methods. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -270,6 +270,26 @@ } }; +/// Sparse conversion rule for tensor reconstruction. +class SparseTensorToTensorConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + // Simply fold the operator into the pointer to the sparse storage scheme.
+ // TODO: generalize this beyond all-dense linearized "sparse" tensors + LogicalResult + matchAndRewrite(ToTensorOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (auto call = operands[0].getDefiningOp()) { + Value arg = call.getOperand(0); + if (arg.getType().isa()) { + rewriter.replaceOp(op, arg); + return success(); + } + } + return failure(); + } +}; + } // namespace /// Populates the given patterns list with conversion rules required for @@ -278,6 +298,7 @@ RewritePatternSet &patterns) { patterns.add( - typeConverter, patterns.getContext()); + SparseTensorToIndicesConverter, SparseTensorToValuesConverter, + SparseTensorToTensorConverter>(typeConverter, + patterns.getContext()); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -258,6 +258,11 @@ /// Returns true if bit corresponds to queried dim. bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); } + /// Returns true if bit corresponds to index of output tensor. + bool isOutTensor(unsigned b, unsigned i) const { + return tensor(b) == outTensor && index(b) == i; + } + /// Returns true if tensor access at given index has queried dim. bool isDim(unsigned t, unsigned i, Dim d) const { assert(t < numTensors && i < numLoops); @@ -367,15 +372,17 @@ if (!map.isProjectedPermutation()) return false; auto enc = getSparseTensorEncoding(t->get().getType()); - if (enc) { + if (enc) annotated = true; - if (t == lhs) - return false; // TODO: handle sparse outputs - } assert(map.getNumResults() == op.getRank(t)); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { unsigned idx = map.getDimPosition(perm(enc, d)); - merger.setDim(t->getOperandNumber(), idx, toDim(enc, d)); + Dim dim = toDim(enc, d); + merger.setDim(t->getOperandNumber(), idx, dim); + // Accept only all-dense annotated "sparse" output.
+ // TODO: support truly sparse outputs too + if (t == lhs && dim != Dim::kDense) + return false; } } return annotated; @@ -556,7 +563,12 @@ /// Local bufferization of all dense and sparse data structures. /// This code enables testing the first prototype sparse compiler. // TODO: replace this with a proliferated bufferization strategy -static void genBuffers(Merger &merger, CodeGen &codegen, +static bool genBuffers(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op) { Location loc = op.getLoc(); assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1); + // Reject an annotated output that cannot be updated in place before any IR + // is created, since a rewriting pattern that fails must leave the IR intact. + OpOperand *lhs = op.getOutputOperand(0); + if (getSparseTensorEncoding(lhs->get().getType()) && !getInPlace(lhs->get())) + return false; @@ -564,6 +576,7 @@ // same bounds on loop indices, and obtain dense or sparse buffer(s). SmallVector args; for (OpOperand *t : op.getInputAndOutputOperands()) { + unsigned tensor = t->getOperandNumber(); auto shape = op.getShape(t); auto map = op.getTiedIndexingMap(t); auto enc = getSparseTensorEncoding(t->get().getType()); @@ -572,7 +585,7 @@ for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { unsigned idx = map.getDimPosition(perm(enc, d)); // Handle sparse storage schemes. - if (merger.isDim(t->getOperandNumber(), idx, Dim::kSparse)) { + if (merger.isDim(tensor, idx, Dim::kSparse)) { auto dynShape = {ShapedType::kDynamicSize}; auto ptrTp = MemRefType::get( dynShape, genIntType(rewriter, enc.getPointerBitWidth())); @@ -580,9 +593,9 @@ dynShape, genIntType(rewriter, enc.getIndexBitWidth())); Value dim = rewriter.create(loc, d); - // Generate sparse primitives to obtains pointer and indices. - codegen.pointers[t->getOperandNumber()][idx] = + // Generate sparse primitives to obtain pointers and indices. 
+ codegen.pointers[tensor][idx] = rewriter.create(loc, ptrTp, t->get(), dim); - codegen.indices[t->getOperandNumber()][idx] = + codegen.indices[tensor][idx] = rewriter.create(loc, indTp, t->get(), dim); } // Find lower and upper bound in current dimension.
@@ -593,27 +606,31 @@ } else { up = rewriter.create(loc, shape[d]); } - codegen.sizes[idx] = codegen.highs[t->getOperandNumber()][idx] = up; + codegen.sizes[idx] = codegen.highs[tensor][idx] = up; } - // Perform the required bufferization. All dense inputs materialize - // from the input tensor. The dense output tensor needs special - // handling. Sparse inputs use a sparse primitive to obtain the values. + // Perform the required bufferization. Dense inputs materialize + // from the input tensors. Dense outputs need special handling. + // Sparse inputs use sparse primitives to obtain the values. + // We also accept in-place all-dense annotated "sparse" outputs. Type elementType = getElementTypeOrSelf(t->get().getType()); if (!enc) { + // Non-annotated dense tensors. auto denseTp = MemRefType::get(shape, elementType); - if (t->getOperandNumber() < op.getNumInputs()) - codegen.buffers[t->getOperandNumber()] = + if (tensor < op.getNumInputs()) + codegen.buffers[tensor] = rewriter.create(loc, denseTp, t->get()); else - codegen.buffers[t->getOperandNumber()] = + codegen.buffers[tensor] = genOutputBuffer(codegen, rewriter, op, denseTp, args); } else { + // Annotated sparse tensors (a non-in-place output was rejected above). auto dynShape = {ShapedType::kDynamicSize}; auto sparseTp = MemRefType::get(dynShape, elementType); - codegen.buffers[t->getOperandNumber()] = + codegen.buffers[tensor] = rewriter.create(loc, sparseTp, t->get()); } } + return true; } /// Constructs vector type. @@ -705,35 +722,38 @@ } // Actual load. SmallVector args; - OpOperand *tensor = merger.exp(exp).e0 < op.getNumInputs() - ?
op.getInputOperand(merger.exp(exp).e0) - : op.getOutputOperand(0); - auto map = op.getTiedIndexingMap(tensor); - auto enc = getSparseTensorEncoding(tensor->get().getType()); - for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { - unsigned idx = map.getDimPosition(perm(enc, d)); - args.push_back(codegen.loops[idx]); // universal dense index - if (enc) { - args.clear(); - args.push_back( - codegen.pidxs[tensor->getOperandNumber()][idx]); // position index + OpOperand *t = merger.exp(exp).e0 < op.getNumInputs() + ? op.getInputOperand(merger.exp(exp).e0) + : op.getOutputOperand(0); + unsigned tensor = t->getOperandNumber(); + auto map = op.getTiedIndexingMap(t); + auto enc = getSparseTensorEncoding(t->get().getType()); + unsigned rank = map.getNumResults(); + if (enc) { + unsigned idx = map.getDimPosition(perm(enc, rank - 1)); + assert(codegen.pidxs[tensor][idx] != nullptr); + args.push_back(codegen.pidxs[tensor][idx]); // position index + } else { + for (unsigned d = 0; d < rank; d++) { + unsigned idx = map.getDimPosition(d); + args.push_back(codegen.loops[idx]); // universal dense index } } Location loc = op.getLoc(); - Value ptr = codegen.buffers[tensor->getOperandNumber()]; + Value ptr = codegen.buffers[tensor]; if (codegen.curVecLength > 1) return genVectorLoad(codegen, rewriter, ptr, args); return rewriter.create(loc, ptr, args); } -/// Generates a store on a dense tensor. +/// Generates a store on a dense or sparse tensor. static void genTensorStore(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, - OpOperand *tensor, Value rhs) { + OpOperand *t, Value rhs) { Location loc = op.getLoc(); // Test if this is a scalarized reduction. OpOperand *lhs = op.getOutputOperand(0); - if (lhs == tensor && codegen.redVal) { + if (lhs == t && codegen.redVal) { if (codegen.curVecLength > 1) rhs = rewriter.create(loc, codegen.curVecMask, rhs, codegen.redVal); @@ -742,13 +762,21 @@ } // Actual store.
SmallVector args; - auto map = op.getTiedIndexingMap(tensor); - assert(!getSparseTensorEncoding(tensor->get().getType())); - for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { - unsigned idx = map.getDimPosition(d); - args.push_back(codegen.loops[idx]); // universal dense index + unsigned tensor = t->getOperandNumber(); + auto map = op.getTiedIndexingMap(t); + auto enc = getSparseTensorEncoding(t->get().getType()); + unsigned rank = map.getNumResults(); + if (enc) { + unsigned idx = map.getDimPosition(perm(enc, rank - 1)); + assert(codegen.pidxs[tensor][idx] != nullptr); + args.push_back(codegen.pidxs[tensor][idx]); // position index + } else { + for (unsigned d = 0; d < rank; d++) { + unsigned idx = map.getDimPosition(d); + args.push_back(codegen.loops[idx]); // universal dense index + } } - Value ptr = codegen.buffers[tensor->getOperandNumber()]; + Value ptr = codegen.buffers[tensor]; if (codegen.curVecLength > 1) genVectorStore(codegen, rewriter, rhs, ptr, args); else @@ -770,7 +798,7 @@ // elements into 64-bit loses some performance since the 32-bit indexed // gather/scatter is more efficient than the 64-bit index variant (if the // negative 32-bit index space is unused, the enableSIMDIndex32 flag can - // preserve this performance)). For 64-bit values, there is no good way - // to state that the indices are unsigned, with creates the potential of + // preserve this performance). For 64-bit values, there is no good way + // to state that the indices are unsigned, which creates the potential of // incorrect address calculations in the unlikely case we need such // extremely large offsets. @@ -1191,9 +1219,12 @@ codegen.loops[idx] = min; } - // Initialize dense positions. + // Initialize dense positions. Note that we generate dense indices of the + // output tensor unconditionally, since they may not appear in the lattice, + // but may be needed for linearized codegen.
for (unsigned b = 0, be = locals.size(); b < be; b++) { - if (locals[b] && merger.isDim(b, Dim::kDense)) { + if ((locals[b] || merger.isOutTensor(b, idx)) && + merger.isDim(b, Dim::kDense)) { unsigned tensor = merger.tensor(b); assert(idx == merger.index(b)); unsigned pat = at; @@ -1359,6 +1390,19 @@ codegen.loops[idx] = Value(); } +/// Converts the result computed by the sparse kernel into the required form. +static void genResult(CodeGen &codegen, PatternRewriter &rewriter, + linalg::GenericOp op) { + RankedTensorType resType = op.getOutputTensorTypes()[0]; + Value result = codegen.buffers.back(); + if (getSparseTensorEncoding(resType)) + result = rewriter.create(op.getLoc(), resType, result); + else + result = + rewriter.create(op.getLoc(), resType, result); + rewriter.replaceOp(op, result); +} + namespace { /// Sparse rewriting rule for generic Lingalg operation. @@ -1398,11 +1442,10 @@ // Recursively generates code. CodeGen codegen(options, numTensors, numLoops); - genBuffers(merger, codegen, rewriter, op); + if (!genBuffers(merger, codegen, rewriter, op)) + return failure(); // could not bufferize genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0); - Value result = rewriter.create( - op.getLoc(), codegen.buffers.back()); - rewriter.replaceOp(op, result); + genResult(codegen, rewriter, op); return success(); } diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir --- a/mlir/test/Dialect/SparseTensor/conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion.mlir @@ -1,5 +1,9 @@ // RUN: mlir-opt %s --sparse-tensor-conversion | FileCheck %s +#DenseVector = #sparse_tensor.encoding<{ + dimLevelType = ["dense"] +}> + #SparseVector = #sparse_tensor.encoding<{ dimLevelType = ["compressed"] }> @@ -185,3 +189,12 @@ %0 = sparse_tensor.values %arg0: tensor<128xi8, #SparseVector> to memref return %0 : memref } + +// CHECK-LABEL: func @sparse_reconstruct( +// CHECK-SAME: %[[A:.*]]: !llvm.ptr +//
CHECK: return %[[A]] : !llvm.ptr +func @sparse_reconstruct(%arg0: tensor<128xf32, #DenseVector> {linalg.inplaceable = true}) -> tensor<128xf32, #DenseVector> { + %0 = sparse_tensor.values %arg0 : tensor<128xf32, #DenseVector> to memref + %1 = sparse_tensor.tensor %0 : memref to tensor<128xf32, #DenseVector> + return %1 : tensor<128xf32, #DenseVector> +} diff --git a/mlir/test/Dialect/SparseTensor/dense.mlir b/mlir/test/Dialect/SparseTensor/dense.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/dense.mlir @@ -0,0 +1,200 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py +// RUN: mlir-opt %s -sparsification | FileCheck %s + +// Test to demonstrate the difference between non-annotated dense tensors +// and all-dense-annotated "sparse" tensors. The former class remains as +// two-dimensional tensors that are bufferized by subsequent passes. The +// latter class is linearized into one-dimensional buffers that are backed +// by the runtime support library. + +#DenseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense" ] }> + +#trait_2d = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (i,j)> // X (out) + ], + iterator_types = ["parallel", "parallel"], + doc = "X(i,j) = A(i,j) + 1" +} + +#trait_3d = { + indexing_maps = [ + affine_map<(i,j,k) -> (i,j,k)>, // A + affine_map<(i,j,k) -> (i,j)> // X (out) + ], + iterator_types = ["parallel", "parallel", "reduction"], + doc = "X(i,j) += A(i,j,k)" +} + +// +// Test with an all-dense-annotated "sparse" matrix as input and +// a non-annotated dense matrix as output that is not in-placeable. +// This results in an explicit allocation to facilitate output.
+// +// CHECK-LABEL: func @dense1( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32> {linalg.inplaceable = false}) -> tensor<32x16xf32> { +// CHECK: %[[VAL_2:.*]] = constant 1.000000e+00 : f32 +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 16 : index +// CHECK: %[[VAL_5:.*]] = constant 0 : index +// CHECK: %[[VAL_6:.*]] = constant 1 : index +// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref +// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16xf32> +// CHECK: %[[VAL_9:.*]] = memref.alloc() : memref<32x16xf32> +// CHECK: linalg.copy(%[[VAL_8]], %[[VAL_9]]) : memref<32x16xf32>, memref<32x16xf32> +// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] { +// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { +// CHECK: %[[VAL_12:.*]] = muli %[[VAL_10]], %[[VAL_4]] : index +// CHECK: %[[VAL_13:.*]] = addi %[[VAL_12]], %[[VAL_11]] : index +// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref +// CHECK: %[[VAL_15:.*]] = addf %[[VAL_14]], %[[VAL_2]] : f32 +// CHECK: memref.store %[[VAL_15]], %[[VAL_9]]{{\[}}%[[VAL_10]], %[[VAL_11]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_16:.*]] = memref.tensor_load %[[VAL_9]] : memref<32x16xf32> +// CHECK: return %[[VAL_16]] : tensor<32x16xf32> +// CHECK: } +func @dense1(%arga: tensor<32x16xf32, #DenseMatrix>, + %argx: tensor<32x16xf32> {linalg.inplaceable = false}) + -> tensor<32x16xf32> { + %c = constant 1.0 : f32 + %0 = linalg.generic #trait_2d + ins(%arga: tensor<32x16xf32, #DenseMatrix>) + outs(%argx: tensor<32x16xf32>) { + ^bb(%a: f32, %x: f32): + %1 = addf %a, %c : f32 + linalg.yield %1 : f32 + } -> tensor<32x16xf32> + return %0 : tensor<32x16xf32> +} + +// +// Test with an all-dense-annotated "sparse" matrix as input and +// a
non-annotated dense matrix as output that is in-placeable. +// This allows updating the dense output in place. +// +// CHECK-LABEL: func @dense2( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32> {linalg.inplaceable = true}) -> tensor<32x16xf32> { +// CHECK: %[[VAL_2:.*]] = constant 1.000000e+00 : f32 +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 16 : index +// CHECK: %[[VAL_5:.*]] = constant 0 : index +// CHECK: %[[VAL_6:.*]] = constant 1 : index +// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref +// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32x16xf32> +// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] { +// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { +// CHECK: %[[VAL_11:.*]] = muli %[[VAL_9]], %[[VAL_4]] : index +// CHECK: %[[VAL_12:.*]] = addi %[[VAL_11]], %[[VAL_10]] : index +// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref +// CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[VAL_2]] : f32 +// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_9]], %[[VAL_10]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32x16xf32> +// CHECK: return %[[VAL_15]] : tensor<32x16xf32> +// CHECK: } +func @dense2(%arga: tensor<32x16xf32, #DenseMatrix>, + %argx: tensor<32x16xf32> {linalg.inplaceable = true}) + -> tensor<32x16xf32> { + %c = constant 1.0 : f32 + %0 = linalg.generic #trait_2d + ins(%arga: tensor<32x16xf32, #DenseMatrix>) + outs(%argx: tensor<32x16xf32>) { + ^bb(%a: f32, %x: f32): + %1 = addf %a, %c : f32 + linalg.yield %1 : f32 + } -> tensor<32x16xf32> + return %0 : tensor<32x16xf32> +} + +// +// Test with a non-annotated dense matrix as input and +// an all-dense annotated "sparse" matrix as output.
+// The rewriting would fail if argx was not in-placeable. +// +// CHECK-LABEL: func @dense3( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {linalg.inplaceable = true}) -> tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> { +// CHECK: %[[VAL_2:.*]] = constant 1.000000e+00 : f32 +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 16 : index +// CHECK: %[[VAL_5:.*]] = constant 0 : index +// CHECK: %[[VAL_6:.*]] = constant 1 : index +// CHECK: %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_0]] : memref<32x16xf32> +// CHECK: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref +// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] { +// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { +// CHECK: %[[VAL_11:.*]] = muli %[[VAL_9]], %[[VAL_4]] : index +// CHECK: %[[VAL_12:.*]] = addi %[[VAL_11]], %[[VAL_10]] : index +// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_9]], %[[VAL_10]]] : memref<32x16xf32> +// CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[VAL_2]] : f32 +// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_15:.*]] = sparse_tensor.tensor %[[VAL_8]] : memref to tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> +// CHECK: return %[[VAL_15]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> +// CHECK: } +func @dense3(%arga: tensor<32x16xf32>, + %argx: tensor<32x16xf32, #DenseMatrix> {linalg.inplaceable = true}) + -> tensor<32x16xf32, #DenseMatrix> { + %c = constant 1.0 : f32 + %0 = linalg.generic #trait_2d + ins(%arga: tensor<32x16xf32>) + outs(%argx: tensor<32x16xf32, #DenseMatrix>) { + ^bb(%a: f32, %x: f32): + %1 = addf %a, %c : f32 + linalg.yield %1 : f32 + } -> tensor<32x16xf32, #DenseMatrix> + return %0 : tensor<32x16xf32, #DenseMatrix> +} + + +// +// Test with a
non-annotated dense matrix as input and +// an all-dense annotated "sparse" matrix as output. +// The rewriting would fail if argx was not in-placeable. +// The missing innermost "k" index (due to a reduction) is accounted +// for by scalarizing the reduction operation for the output tensor. +// +// CHECK-LABEL: func @dense4( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {linalg.inplaceable = true}) -> tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> { +// CHECK: %[[VAL_2:.*]] = constant 8 : index +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 16 : index +// CHECK: %[[VAL_5:.*]] = constant 0 : index +// CHECK: %[[VAL_6:.*]] = constant 1 : index +// CHECK: %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_0]] : memref<32x16x8xf32> +// CHECK: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref +// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] { +// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { +// CHECK: %[[VAL_11:.*]] = muli %[[VAL_9]], %[[VAL_4]] : index +// CHECK: %[[VAL_12:.*]] = addi %[[VAL_11]], %[[VAL_10]] : index +// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref +// CHECK: %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_6]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f32) { +// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_9]], %[[VAL_10]], %[[VAL_15]]] : memref<32x16x8xf32> +// CHECK: %[[VAL_18:.*]] = addf %[[VAL_16]], %[[VAL_17]] : f32 +// CHECK: scf.yield %[[VAL_18]] : f32 +// CHECK: } +// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_20:.*]] = sparse_tensor.tensor %[[VAL_8]] : memref to tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> +// CHECK: return %[[VAL_20]] : tensor<32x16xf32,
#sparse_tensor.encoding<{{.*}}>> +// CHECK: } +func @dense4(%arga: tensor<32x16x8xf32>, + %argx: tensor<32x16xf32, #DenseMatrix> {linalg.inplaceable = true}) + -> tensor<32x16xf32, #DenseMatrix> { + %0 = linalg.generic #trait_3d + ins(%arga: tensor<32x16x8xf32>) + outs(%argx: tensor<32x16xf32, #DenseMatrix>) { + ^bb(%a: f32, %x: f32): + %1 = addf %x, %a : f32 + linalg.yield %1 : f32 + } -> tensor<32x16xf32, #DenseMatrix> + return %0 : tensor<32x16xf32, #DenseMatrix> +} diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -85,3 +85,34 @@ %0 = sparse_tensor.values %arg0 : tensor to memref return %0 : memref } + +// ----- + +func @sparse_to_unannotated_tensor(%arg0: memref) -> tensor<16x32xf64> { + // expected-error@+1 {{expected a sparse tensor as result}} + %0 = sparse_tensor.tensor %arg0 : memref to tensor<16x32xf64> + return %0 : tensor<16x32xf64> +} + +// ----- + +#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["dense","compressed"]}> + +func @sparse_to_sparse_tensor(%arg0: memref) -> tensor<16x32xf64, #SparseMatrix> { + // expected-error@+1 {{unexpected non-dense dimension}} + %0 = sparse_tensor.tensor %arg0 : memref to tensor<16x32xf64, #SparseMatrix> + return %0 : tensor<16x32xf64, #SparseMatrix> +} + +// ----- + +#DenseMatrix = #sparse_tensor.encoding<{dimLevelType = ["dense","dense"]}> + +func @sparse_to_tensor(%arg0: memref, + %arg1: memref, + %arg2: memref) -> tensor<16x32xf64, #DenseMatrix> { + // expected-error@+1 {{expected single values array}} + %0 = sparse_tensor.tensor %arg0, %arg1, %arg2 + : memref, memref, memref to tensor<16x32xf64, #DenseMatrix> + return %0 : tensor<16x32xf64, #DenseMatrix> +} diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++
b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -53,3 +53,16 @@ %0 = sparse_tensor.values %arg0 : tensor<128xf64, #SparseVector> to memref return %0 : memref } + +// ----- + +#DenseMatrix = #sparse_tensor.encoding<{dimLevelType = ["dense","dense"]}> + +// CHECK-LABEL: func @sparse_to_tensor( +// CHECK-SAME: %[[A:.*]]: memref) +// CHECK: %[[T:.*]] = sparse_tensor.tensor %[[A]] : memref to tensor<16x32xf64, #{{.*}}> +// CHECK: return %[[T]] : tensor<16x32xf64, #{{.*}}> +func @sparse_to_tensor(%arg0: memref) -> tensor<16x32xf64, #DenseMatrix> { + %0 = sparse_tensor.tensor %arg0 : memref to tensor<16x32xf64, #DenseMatrix> + return %0 : tensor<16x32xf64, #DenseMatrix> +}