diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -243,19 +243,10 @@
   return success();
 }
 
-// TODO: generalize this beyond all-dense linearized "sparse" tensors
 static LogicalResult verify(ToTensorOp op) {
-  if (op.getNumOperands() != 1)
-    return op.emitError("expected single values array");
-  if (auto e = getSparseTensorEncoding(op.result().getType())) {
-    auto dlt = e.getDimLevelType();
-    for (unsigned i = 0, sz = dlt.size(); i < sz; i++) {
-      if (dlt[i] != SparseTensorEncodingAttr::DimLevelType::Dense)
-        return op.emitError("unexpected non-dense dimension");
-    }
-    return success();
-  }
-  return op.emitError("expected a sparse tensor as result");
+  if (!getSparseTensorEncoding(op.result().getType()))
+    return op.emitError("expected a sparse tensor as result");
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -276,17 +276,27 @@
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   // Simply fold the operator into the pointer to the sparse storage scheme.
-  // TODO: generalize this beyond all-dense linearized "sparse" tensors
   matchAndRewrite(ToTensorOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    if (auto call = operands[0].getDefiningOp<CallOp>()) {
-      Value arg = call.getOperand(0);
-      if (arg.getType().isa<LLVM::LLVMPointerType>()) {
-        rewriter.replaceOp(op, arg);
-        return success();
+    // Check that all arguments of the tensor reconstruction operator are calls
+    // into the support library that query exactly the same opaque pointer.
+    Value ptr;
+    for (Value op : operands) {
+      if (auto call = op.getDefiningOp<CallOp>()) {
+        Value arg = call.getOperand(0);
+        if (!arg.getType().isa<LLVM::LLVMPointerType>())
+          return failure();
+        if (!ptr)
+          ptr = arg;
+        else if (arg != ptr)
+          return failure();
       }
     }
-    return failure();
+    // If a single opaque pointer is found, perform the folding.
+    if (!ptr)
+      return failure();
+    rewriter.replaceOp(op, ptr);
+    return success();
   }
 };
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -366,7 +366,6 @@
 /// Fills the per-dimension sparsity information for all tensors.
 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
   bool annotated = false;
-  OpOperand *lhs = op.getOutputOperand(0);
   for (OpOperand *t : op.getInputAndOutputOperands()) {
     auto map = op.getTiedIndexingMap(t);
     if (!map.isProjectedPermutation())
@@ -377,12 +376,7 @@
     assert(map.getNumResults() == op.getRank(t));
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       unsigned idx = map.getDimPosition(perm(enc, d));
-      Dim dim = toDim(enc, d);
       merger.setDim(t->getOperandNumber(), idx, toDim(enc, d));
-      // Accept only all-dense annotated "sparse" output.
-      // TODO: support truly sparse outputs too
-      if (t == lhs && dim != Dim::kDense)
-        return false;
     }
   }
   return annotated;
 }
@@ -496,6 +490,55 @@
   return None;
 }
 
+/// Returns true if the given tensor co-iterates with conjunction only.
+/// For the output tensor, this defines a "simply dynamic" operation.
+/// For instance: A(I) = A(I) * B(I) * C(I)
+static bool isConjunction(Merger &merger, unsigned tensor, unsigned exp) {
+  switch (merger.exp(exp).kind) {
+  case Kind::kTensor:
+    return merger.exp(exp).e0 == tensor;
+  case Kind::kMulF:
+  case Kind::kMulI:
+    return isConjunction(merger, tensor, merger.exp(exp).e0) ||
+           isConjunction(merger, tensor, merger.exp(exp).e1);
+  default:
+    return false;
+  }
+}
+
+/// Returns true when the tensor expression is admissable for codegen.
+/// Since all sparse input tensors are admissable, we just need to check
+/// whether the output tensor in the tensor expression is admissable.
+static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
+                                  unsigned exp) {
+  OpOperand *lhs = op.getOutputOperand(0);
+  unsigned tensor = lhs->getOperandNumber();
+  auto enc = getSparseTensorEncoding(lhs->get().getType());
+  // A non-annotated output tensor is assumed dense, and becomes a random
+  // access n-dim memref. Admissable since "creation" cannot occur.
+  if (!enc)
+    return true;
+  // An all-dense annotated "sparse" output tensor becomes a linearized random
+  // access 1-dim memref. Also admissable since "creation" cannot occur.
+  bool allDense = true;
+  unsigned numLoops = op.iterator_types().getValue().size();
+  for (unsigned i = 0; i < numLoops; i++)
+    if (merger.isDim(tensor, i, Dim::kSparse)) {
+      allDense = false;
+      break;
+    }
+  if (allDense)
+    return true;
+  // A tensor expression with a sparse output tensor that changes its values
+  // but not its nonzero structure, an operation called "simply dynamic" in
+  // [Bik96,Ch9], is also admissable without special codegen.
+  if (isConjunction(merger, tensor, exp))
+    return true;
+  // Reject for now since this requires changes to the nonzero structure.
+  // TODO: implement "workspaces" [Kjolstad2019]
+  return false;
+}
+
 /// Builds the iteration lattices in a bottom-up traversal given the remaining
 /// tensor (sub)expression and the next loop index in the iteration graph.
 static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
@@ -1390,15 +1433,34 @@
 }
 
 /// Converts the result computed by the sparse kernel into the required form.
-static void genResult(CodeGen &codegen, PatternRewriter &rewriter,
-                      linalg::GenericOp op) {
-  RankedTensorType resType = op.getOutputTensorTypes()[0];
-  Value result = codegen.buffers.back();
-  if (getSparseTensorEncoding(resType))
-    result = rewriter.create<ToTensorOp>(op.getLoc(), resType, result);
-  else
-    result =
-        rewriter.create<memref::TensorLoadOp>(op.getLoc(), resType, result);
+static void genResult(Merger &merger, CodeGen &codegen,
+                      PatternRewriter &rewriter, linalg::GenericOp op) {
+  Location loc = op.getLoc();
+  OpOperand *lhs = op.getOutputOperand(0);
+  Type resType = lhs->get().getType();
+  unsigned tensor = lhs->getOperandNumber();
+  auto map = op.getTiedIndexingMap(lhs);
+  auto enc = getSparseTensorEncoding(resType);
+  Value result = codegen.buffers.back(); // value array
+  if (enc) {
+    // The sparse annotation unambiguously defines the arrays needed
+    // to "reconstruct" the sparse tensor from the storage scheme
+    // (even though lowering should never need this eventually).
+    SmallVector<Value, 4> args;
+    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
+      unsigned idx = map.getDimPosition(perm(enc, d));
+      if (merger.isDim(tensor, idx, Dim::kSparse)) {
+        args.push_back(codegen.pointers[tensor][idx]);
+        args.push_back(codegen.indices[tensor][idx]);
+      }
+    }
+    args.push_back(result);
+    result = rewriter.create<ToTensorOp>(loc, resType, args);
+  } else {
+    // To "reconstruct" a non-annotated tensor, simply load it
+    // from the bufferized value.
+    result = rewriter.create<memref::TensorLoadOp>(loc, resType, result);
+  }
   rewriter.replaceOp(op, result);
 }
 
@@ -1437,12 +1499,16 @@
     if (!exp.hasValue())
       return failure(); // build failure
 
+    // Reject an inadmissable tensor expression.
+    if (!isAdmissableTensorExp(merger, op, exp.getValue()))
+      return failure();
+
     // Recursively generates code.
     CodeGen codegen(options, numTensors, numLoops);
     if (!genBuffers(merger, codegen, rewriter, op))
      return failure(); // could not bufferize
     genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
-    genResult(codegen, rewriter, op);
+    genResult(merger, codegen, rewriter, op);
     return success();
   }
 
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -93,26 +93,3 @@
   %0 = sparse_tensor.tensor %arg0 : memref<?xf64> to tensor<16x32xf64>
   return %0 : tensor<16x32xf64>
 }
-
-// -----
-
-#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["dense","compressed"]}>
-
-func @sparse_to_sparse_tensor(%arg0: memref<?xf64>) -> tensor<16x32xf64, #SparseMatrix> {
-  // expected-error@+1 {{unexpected non-dense dimension}}
-  %0 = sparse_tensor.tensor %arg0 : memref<?xf64> to tensor<16x32xf64, #SparseMatrix>
-  return %0 : tensor<16x32xf64, #SparseMatrix>
-}
-
-// -----
-
-#DenseMatrix = #sparse_tensor.encoding<{dimLevelType = ["dense","dense"]}>
-
-func @sparse_to_tensor(%arg0: memref<?xindex>,
-                       %arg1: memref<?xindex>,
-                       %arg2: memref<?xf64>) -> tensor<16x32xf64, #DenseMatrix> {
-  // expected-error@+1 {{expected single values array}}
-  %0 = sparse_tensor.tensor %arg0, %arg1, %arg2
-     : memref<?xindex>, memref<?xindex>, memref<?xf64> to tensor<16x32xf64, #DenseMatrix>
-  return %0 : tensor<16x32xf64, #DenseMatrix>
-}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -0,0 +1,133 @@
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (i,j)>
+}>
+
+#DCSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (i,j)>
+}>
+
+#trait_scale = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>  // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) = X(i,j) * 2"
+}
+
+// CHECK-LABEL: func @sparse_simply_dynamic1(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK: %[[VAL_1:.*]] = constant 2.000000e+00 : f32
+// CHECK: %[[VAL_2:.*]] = constant 0 : index
+// CHECK: %[[VAL_3:.*]] = constant 1 : index
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xf32>
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_3]] {
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_3]] : index
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_3]] {
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xf32>
+// CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_1]] : f32
+// CHECK: memref.store %[[VAL_17]], %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_18:.*]] = sparse_tensor.tensor %[[VAL_4]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]], %[[VAL_8]] : memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32> to tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: return %[[VAL_18]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: }
+func @sparse_simply_dynamic1(%argx: tensor<32x16xf32, #DCSR> {linalg.inplaceable = true}) -> tensor<32x16xf32, #DCSR> {
+  %c = constant 2.0 : f32
+  %0 = linalg.generic #trait_scale
+    outs(%argx: tensor<32x16xf32, #DCSR>) {
+      ^bb(%x: f32):
+        %1 = mulf %x, %c : f32
+        linalg.yield %1 : f32
+  } -> tensor<32x16xf32, #DCSR>
+  return %0 : tensor<32x16xf32, #DCSR>
+}
+
+#trait_elt_wise_mult = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) = A(i,j) * X(i,j)"
+}
+
+// CHECK-LABEL: func @sparse_simply_dynamic2(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK: %[[VAL_2:.*]] = constant 0 : index
+// CHECK: %[[VAL_3:.*]] = constant 1 : index
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xf32>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_2]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_2]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xf32>
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_12]] to %[[VAL_13]] step %[[VAL_3]] {
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_15]]] : memref<?xindex>
+// CHECK: %[[VAL_17:.*]] = addi %[[VAL_15]], %[[VAL_3]] : index
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK: %[[VAL_20:.*]] = addi %[[VAL_14]], %[[VAL_3]] : index
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]]:2 = scf.while (%[[VAL_23:.*]] = %[[VAL_16]], %[[VAL_24:.*]] = %[[VAL_19]]) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_25:.*]] = cmpi ult, %[[VAL_23]], %[[VAL_18]] : index
+// CHECK: %[[VAL_26:.*]] = cmpi ult, %[[VAL_24]], %[[VAL_21]] : index
+// CHECK: %[[VAL_27:.*]] = and %[[VAL_25]], %[[VAL_26]] : i1
+// CHECK: scf.condition(%[[VAL_27]]) %[[VAL_23]], %[[VAL_24]] : index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_28:.*]]: index, %[[VAL_29:.*]]: index):
+// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_28]]] : memref<?xindex>
+// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_29]]] : memref<?xindex>
+// CHECK: %[[VAL_32:.*]] = cmpi ult, %[[VAL_31]], %[[VAL_30]] : index
+// CHECK: %[[VAL_33:.*]] = select %[[VAL_32]], %[[VAL_31]], %[[VAL_30]] : index
+// CHECK: %[[VAL_34:.*]] = cmpi eq, %[[VAL_30]], %[[VAL_33]] : index
+// CHECK: %[[VAL_35:.*]] = cmpi eq, %[[VAL_31]], %[[VAL_33]] : index
+// CHECK: %[[VAL_36:.*]] = and %[[VAL_34]], %[[VAL_35]] : i1
+// CHECK: scf.if %[[VAL_36]] {
+// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref<?xf32>
+// CHECK: %[[VAL_38:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_28]]] : memref<?xf32>
+// CHECK: %[[VAL_39:.*]] = mulf %[[VAL_37]], %[[VAL_38]] : f32
+// CHECK: memref.store %[[VAL_39]], %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref<?xf32>
+// CHECK: } else {
+// CHECK: }
+// CHECK: %[[VAL_40:.*]] = cmpi eq, %[[VAL_30]], %[[VAL_33]] : index
+// CHECK: %[[VAL_41:.*]] = addi %[[VAL_28]], %[[VAL_3]] : index
+// CHECK: %[[VAL_42:.*]] = select %[[VAL_40]], %[[VAL_41]], %[[VAL_28]] : index
+// CHECK: %[[VAL_43:.*]] = cmpi eq, %[[VAL_31]], %[[VAL_33]] : index
+// CHECK: %[[VAL_44:.*]] = addi %[[VAL_29]], %[[VAL_3]] : index
+// CHECK: %[[VAL_45:.*]] = select %[[VAL_43]], %[[VAL_44]], %[[VAL_29]] : index
+// CHECK: scf.yield %[[VAL_42]], %[[VAL_45]] : index, index
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_46:.*]] = sparse_tensor.tensor %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]] : memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32> to tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: return %[[VAL_46]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: }
+func @sparse_simply_dynamic2(%arga: tensor<32x16xf32, #CSR>,
+                             %argx: tensor<32x16xf32, #DCSR> {linalg.inplaceable = true}) -> tensor<32x16xf32, #DCSR> {
+  %0 = linalg.generic #trait_elt_wise_mult
+    ins(%arga: tensor<32x16xf32, #CSR>)
+    outs(%argx: tensor<32x16xf32, #DCSR>) {
+      ^bb(%a: f32, %x: f32):
+        %1 = mulf %x, %a : f32
+        linalg.yield %1 : f32
+  } -> tensor<32x16xf32, #DCSR>
+  return %0 : tensor<32x16xf32, #DCSR>
+}