diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -122,7 +122,7 @@
   /// invariant expressions in the kernel.
   Merger(unsigned t, unsigned l)
       : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l),
-        dims(t + 1, std::vector(l, Dim::kUndef)) {}
+        hasSparseOut(false), dims(t + 1, std::vector(l, Dim::kUndef)) {}
 
   /// Adds a tensor expression. Returns its index.
   unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value());
@@ -200,6 +200,9 @@
   /// Dimension setter.
   void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
 
+  // Has sparse output tensor setter.
+  void setHasSparseOut(bool s) { hasSparseOut = s; }
+
   /// Convenience getters to immediately access the stored nodes.
   /// Typically it is inadvisible to keep the reference around, as in
   /// "TensorExpr &te = merger.exp(e))", since insertions into the merger
@@ -230,6 +233,7 @@
                  Value v1);
 
 private:
+  /// Private helpers.
   bool maybeZero(unsigned e) const;
   bool isInvariant(unsigned e) const;
   Type inferType(unsigned e, Value src);
@@ -237,11 +241,12 @@
   /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
   Optional buildTensorExp(linalg::GenericOp op, Value v);
 
+  /// Merger data structures.
   const unsigned outTensor;
   const unsigned syntheticTensor;
   const unsigned numTensors;
   const unsigned numLoops;
-
+  bool hasSparseOut;
   std::vector> dims;
   llvm::SmallVector tensorExps;
   llvm::SmallVector latPoints;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -46,15 +46,15 @@
 // Code generation.
 struct CodeGen {
   CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops,
-          OpOperand *op)
+          OpOperand *op, unsigned nest)
       : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
         pointers(numTensors, std::vector(numLoops)),
         indices(numTensors, std::vector(numLoops)),
         highs(numTensors, std::vector(numLoops)),
         pidxs(numTensors, std::vector(numLoops)),
         idxs(numTensors, std::vector(numLoops)), redExp(-1u), redVal(),
-        redKind(kNoReduc), sparseOut(op), lexIdx(), curVecLength(1),
-        curVecMask() {}
+        redKind(kNoReduc), sparseOut(op), outerParNest(nest), lexIdx(),
+        curVecLength(1), curVecMask() {}
   /// Sparsification options.
   SparsificationOptions options;
   /// Universal dense indices and upper bounds (by index). The loops array
@@ -79,8 +79,11 @@
   unsigned redExp;
   Value redVal;
   Reduction redKind;
-  // Sparse tensor as output.
+  // Sparse tensor as output. Implemented either through direct injective
+  // insertion in lexicographic index order (where indices are updated
+  // in the temporary array `lexIdx`) or TODO: access pattern expansion
   OpOperand *sparseOut;
+  unsigned outerParNest;
   Value lexIdx;
   // Current vector length and mask.
   unsigned curVecLength;
@@ -288,10 +291,13 @@
 
 /// Returns true when the tensor expression is admissable for codegen.
 /// Since all sparse input tensors are admissable, we just need to check
-/// whether the output tensor in the tensor expression codegen is admissable.
-/// Sets `sparseOut` when a "truly dynamic" sparse tensor output occurs.
+/// whether the out tensor in the tensor expression codegen is admissable.
+/// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
+/// nesting depth when a "truly dynamic" sparse tensor output occurs.
 static bool isAdmissableTensorExp(Merger &merger, linalg::GenericOp op,
-                                  unsigned exp, OpOperand **sparseOut) {
+                                  std::vector &topSort, unsigned exp,
+                                  OpOperand **sparseOut,
+                                  unsigned &outerParNest) {
   OpOperand *lhs = op.getOutputOperand(0);
   unsigned tensor = lhs->getOperandNumber();
   auto enc = getSparseTensorEncoding(lhs->get().getType());
@@ -302,7 +308,8 @@
   // An all-dense annotated "sparse" output tensor becomes a linearized random
   // access 1-dim memref. Also admissable since insertions cannot occur.
   bool allDense = true;
-  unsigned numLoops = op.iterator_types().getValue().size();
+  auto iteratorTypes = op.iterator_types().getValue();
+  unsigned numLoops = iteratorTypes.size();
   for (unsigned i = 0; i < numLoops; i++)
     if (merger.isDim(tensor, i, Dim::kSparse)) {
       allDense = false;
@@ -319,15 +326,20 @@
   // Accept "truly dynamic" if the output tensor materializes uninitialized
   // into the computation and insertions occur in lexicographic index order.
   if (isMaterializing(lhs->get())) {
-    // In this first sparse tensor output implementation, this is enforced by
-    // rejecting any reduction loops (since the sparse parallel loops give a
-    // lexicographically sorted and injective view into that tensor).
-    // TODO: generalize to include reductions
-    for (auto attr : op.iterator_types())
-      if (isReductionIterator(attr))
-        return false;
-    *sparseOut = lhs;
-    return true;
+    unsigned nest = 0;
+    for (unsigned i = 0; i < numLoops; i++) {
+      if (isReductionIterator(iteratorTypes[topSort[i]]))
+        break; // terminate at first reduction
+      nest++;
+    }
+    // Determine admissable dynamic insertion situations:
+    // (1) fully injective, since there are no reductions,
+    // (2) admissable 1-d expansion in innermost dimension. TODO: accept
+    if (nest == op.getRank(lhs)) {
+      *sparseOut = lhs;
+      outerParNest = nest;
+      return true;
+    }
   }
   return false;
 }
@@ -704,9 +716,15 @@
       return genVectorInvariantValue(codegen, rewriter, val);
     return val;
   }
+  // Insertion (a sparse tensor output "loads" as zero).
+  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
+  if (t == codegen.sparseOut) {
+    Type tp = getElementTypeOrSelf(t->get().getType());
+    return rewriter.create(op.getLoc(), tp,
+                           rewriter.getZeroAttr(tp));
+  }
   // Actual load.
   SmallVector args;
-  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
   Value ptr = genSubscript(codegen, rewriter, op, t, args);
   if (codegen.curVecLength > 1)
     return genVectorLoad(codegen, rewriter, ptr, args);
@@ -1515,11 +1533,14 @@
 
     // Rejects an inadmissable tensor expression.
     OpOperand *sparseOut = nullptr;
-    if (!isAdmissableTensorExp(merger, op, exp, &sparseOut))
+    unsigned outerParNest = 0;
+    if (!isAdmissableTensorExp(merger, op, topSort, exp, &sparseOut,
+                               outerParNest))
       return failure();
 
     // Recursively generates code.
-    CodeGen codegen(options, numTensors, numLoops, sparseOut);
+    merger.setHasSparseOut(sparseOut != nullptr);
+    CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest);
     genBuffers(merger, codegen, rewriter, op);
     genStmt(merger, codegen, rewriter, op, topSort, exp, 0);
     genResult(merger, codegen, rewriter, op);
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -415,9 +415,13 @@
   case kInvariant: {
     // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
-    // is set to a synthetic tensor with undefined indices only.
+    // and a truly dynamic sparse output tensor are set to a synthetic tensor
+    // with undefined indices only to ensure the iteration space is not
+    // skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = kind == kTensor ? tensorExps[e].tensor : syntheticTensor;
+    if (hasSparseOut && t == outTensor)
+      t = syntheticTensor;
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
@@ -593,8 +597,8 @@
     }
   }
   // Construct binary operations if subexpressions can be built.
-  // TODO: see buildLattices() for an explanation of rejecting
-  // certain division and shift operations
+  // See buildLattices() for an explanation of rejecting certain
+  // division and shift operations
   if (def->getNumOperands() == 2) {
     auto x = buildTensorExp(op, def->getOperand(0));
     auto y = buildTensorExp(op, def->getOperand(1));
diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir
@@ -11,6 +11,10 @@
   dimOrdering = affine_map<(i,j) -> (i,j)>
 }>
 
+#SparseTensor = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed", "compressed" ]
+}>
+
 #trait_scale_inpl = {
   indexing_maps = [
     affine_map<(i,j) -> (i,j)> // X (out)
@@ -182,3 +186,161 @@
   } -> tensor<10x20xf32, #DCSR>
   return %0 : tensor<10x20xf32, #DCSR>
 }
+
+#trait_sumred = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,j,k)>, // A
+    affine_map<(i,j,k) -> (i,j,k)>, // B
+    affine_map<(i,j,k) -> (i,j)>    // X (out)
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  doc = "X(i,j) = SUM_k A(i,j,k) * B(i,j,k)"
+}
+
+// CHECK-LABEL: func @sumred(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor>)
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor>
+// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.init{{\[}}%[[VAL_6]], %[[VAL_7]]] : tensor>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor> to memref
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor> to memref
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor> to memref
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor> to memref
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor> to memref
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor> to memref
+// CHECK: %[[VAL_15:.*]] = sparse_tensor.values %[[VAL_0]] : tensor> to memref
+// CHECK: %[[VAL_16:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_2]] : tensor> to memref
+// CHECK: %[[VAL_17:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_2]] : tensor> to memref
+// CHECK: %[[VAL_18:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor> to memref
+// CHECK: %[[VAL_19:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor> to memref
+// CHECK: %[[VAL_20:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor> to memref
+// CHECK: %[[VAL_21:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor> to memref
+// CHECK: %[[VAL_22:.*]] = sparse_tensor.values %[[VAL_1]] : tensor> to memref
+// CHECK: %[[VAL_23:.*]] = memref.alloca(%[[VAL_4]]) : memref
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_2]]] : memref
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_2]]] : memref
+// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_3]]] : memref
+// CHECK: %[[VAL_28:.*]]:2 = scf.while (%[[VAL_29:.*]] = %[[VAL_24]], %[[VAL_30:.*]] = %[[VAL_26]]) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_31:.*]] = arith.cmpi ult, %[[VAL_29]], %[[VAL_25]] : index
+// CHECK: %[[VAL_32:.*]] = arith.cmpi ult, %[[VAL_30]], %[[VAL_27]] : index
+// CHECK: %[[VAL_33:.*]] = arith.andi %[[VAL_31]], %[[VAL_32]] : i1
+// CHECK: scf.condition(%[[VAL_33]]) %[[VAL_29]], %[[VAL_30]] : index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_34:.*]]: index, %[[VAL_35:.*]]: index):
+// CHECK: %[[VAL_36:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_34]]] : memref
+// CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_35]]] : memref
+// CHECK: %[[VAL_38:.*]] = arith.cmpi ult, %[[VAL_37]], %[[VAL_36]] : index
+// CHECK: %[[VAL_39:.*]] = select %[[VAL_38]], %[[VAL_37]], %[[VAL_36]] : index
+// CHECK: memref.store %[[VAL_39]], %[[VAL_23]]{{\[}}%[[VAL_2]]] : memref
+// CHECK: %[[VAL_40:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_39]] : index
+// CHECK: %[[VAL_41:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_39]] : index
+// CHECK: %[[VAL_42:.*]] = arith.andi %[[VAL_40]], %[[VAL_41]] : i1
+// CHECK: scf.if %[[VAL_42]] {
+// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_34]]] : memref
+// CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_34]], %[[VAL_3]] : index
+// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_44]]] : memref
+// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_35]]] : memref
+// CHECK: %[[VAL_47:.*]] = arith.addi %[[VAL_35]], %[[VAL_3]] : index
+// CHECK: %[[VAL_48:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_47]]] : memref
+// CHECK: %[[VAL_49:.*]]:2 = scf.while (%[[VAL_50:.*]] = %[[VAL_43]], %[[VAL_51:.*]] = %[[VAL_46]]) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_52:.*]] = arith.cmpi ult, %[[VAL_50]], %[[VAL_45]] : index
+// CHECK: %[[VAL_53:.*]] = arith.cmpi ult, %[[VAL_51]], %[[VAL_48]] : index
+// CHECK: %[[VAL_54:.*]] = arith.andi %[[VAL_52]], %[[VAL_53]] : i1
+// CHECK: scf.condition(%[[VAL_54]]) %[[VAL_50]], %[[VAL_51]] : index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_55:.*]]: index, %[[VAL_56:.*]]: index):
+// CHECK: %[[VAL_57:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_55]]] : memref
+// CHECK: %[[VAL_58:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_56]]] : memref
+// CHECK: %[[VAL_59:.*]] = arith.cmpi ult, %[[VAL_58]], %[[VAL_57]] : index
+// CHECK: %[[VAL_60:.*]] = select %[[VAL_59]], %[[VAL_58]], %[[VAL_57]] : index
+// CHECK: memref.store %[[VAL_60]], %[[VAL_23]]{{\[}}%[[VAL_3]]] : memref
+// CHECK: %[[VAL_61:.*]] = arith.cmpi eq, %[[VAL_57]], %[[VAL_60]] : index
+// CHECK: %[[VAL_62:.*]] = arith.cmpi eq, %[[VAL_58]], %[[VAL_60]] : index
+// CHECK: %[[VAL_63:.*]] = arith.andi %[[VAL_61]], %[[VAL_62]] : i1
+// CHECK: scf.if %[[VAL_63]] {
+// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_55]]] : memref
+// CHECK: %[[VAL_65:.*]] = arith.addi %[[VAL_55]], %[[VAL_3]] : index
+// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_65]]] : memref
+// CHECK: %[[VAL_67:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_56]]] : memref
+// CHECK: %[[VAL_68:.*]] = arith.addi %[[VAL_56]], %[[VAL_3]] : index
+// CHECK: %[[VAL_69:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_68]]] : memref
+// CHECK: %[[VAL_70:.*]]:3 = scf.while (%[[VAL_71:.*]] = %[[VAL_64]], %[[VAL_72:.*]] = %[[VAL_67]], %[[VAL_73:.*]] = %[[VAL_5]]) : (index, index, i32) -> (index, index, i32) {
+// CHECK: %[[VAL_74:.*]] = arith.cmpi ult, %[[VAL_71]], %[[VAL_66]] : index
+// CHECK: %[[VAL_75:.*]] = arith.cmpi ult, %[[VAL_72]], %[[VAL_69]] : index
+// CHECK: %[[VAL_76:.*]] = arith.andi %[[VAL_74]], %[[VAL_75]] : i1
+// CHECK: scf.condition(%[[VAL_76]]) %[[VAL_71]], %[[VAL_72]], %[[VAL_73]] : index, index, i32
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_77:.*]]: index, %[[VAL_78:.*]]: index, %[[VAL_79:.*]]: i32):
+// CHECK: %[[VAL_80:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_77]]] : memref
+// CHECK: %[[VAL_81:.*]] = memref.load %[[VAL_21]]{{\[}}%[[VAL_78]]] : memref
+// CHECK: %[[VAL_82:.*]] = arith.cmpi ult, %[[VAL_81]], %[[VAL_80]] : index
+// CHECK: %[[VAL_83:.*]] = select %[[VAL_82]], %[[VAL_81]], %[[VAL_80]] : index
+// CHECK: memref.store %[[VAL_83]], %[[VAL_23]]{{\[}}%[[VAL_4]]] : memref
+// CHECK: %[[VAL_84:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_83]] : index
+// CHECK: %[[VAL_85:.*]] = arith.cmpi eq, %[[VAL_81]], %[[VAL_83]] : index
+// CHECK: %[[VAL_86:.*]] = arith.andi %[[VAL_84]], %[[VAL_85]] : i1
+// CHECK: %[[VAL_87:.*]] = scf.if %[[VAL_86]] -> (i32) {
+// CHECK: %[[VAL_88:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_77]]] : memref
+// CHECK: %[[VAL_89:.*]] = memref.load %[[VAL_22]]{{\[}}%[[VAL_78]]] : memref
+// CHECK: %[[VAL_90:.*]] = arith.muli %[[VAL_88]], %[[VAL_89]] : i32
+// CHECK: %[[VAL_91:.*]] = arith.addi %[[VAL_79]], %[[VAL_90]] : i32
+// CHECK: scf.yield %[[VAL_91]] : i32
+// CHECK: } else {
+// CHECK: scf.yield %[[VAL_79]] : i32
+// CHECK: }
+// CHECK: %[[VAL_92:.*]] = arith.cmpi eq, %[[VAL_80]], %[[VAL_83]] : index
+// CHECK: %[[VAL_93:.*]] = arith.addi %[[VAL_77]], %[[VAL_3]] : index
+// CHECK: %[[VAL_94:.*]] = select %[[VAL_92]], %[[VAL_93]], %[[VAL_77]] : index
+// CHECK: %[[VAL_95:.*]] = arith.cmpi eq, %[[VAL_81]], %[[VAL_83]] : index
+// CHECK: %[[VAL_96:.*]] = arith.addi %[[VAL_78]], %[[VAL_3]] : index
+// CHECK: %[[VAL_97:.*]] = select %[[VAL_95]], %[[VAL_96]], %[[VAL_78]] : index
+// CHECK: scf.yield %[[VAL_94]], %[[VAL_97]], %[[VAL_98:.*]] : index, index, i32
+// CHECK: }
+// CHECK: sparse_tensor.lex_insert %[[VAL_8]], %[[VAL_23]], %[[VAL_99:.*]]#2 : tensor, memref, i32
+// CHECK: } else {
+// CHECK: }
+// CHECK: %[[VAL_100:.*]] = arith.cmpi eq, %[[VAL_57]], %[[VAL_60]] : index
+// CHECK: %[[VAL_101:.*]] = arith.addi %[[VAL_55]], %[[VAL_3]] : index
+// CHECK: %[[VAL_102:.*]] = select %[[VAL_100]], %[[VAL_101]], %[[VAL_55]] : index
+// CHECK: %[[VAL_103:.*]] = arith.cmpi eq, %[[VAL_58]], %[[VAL_60]] : index
+// CHECK: %[[VAL_104:.*]] = arith.addi %[[VAL_56]], %[[VAL_3]] : index
+// CHECK: %[[VAL_105:.*]] = select %[[VAL_103]], %[[VAL_104]], %[[VAL_56]] : index
+// CHECK: scf.yield %[[VAL_102]], %[[VAL_105]] : index, index
+// CHECK: }
+// CHECK: } else {
+// CHECK: }
+// CHECK: %[[VAL_106:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_39]] : index
+// CHECK: %[[VAL_107:.*]] = arith.addi %[[VAL_34]], %[[VAL_3]] : index
+// CHECK: %[[VAL_108:.*]] = select %[[VAL_106]], %[[VAL_107]], %[[VAL_34]] : index
+// CHECK: %[[VAL_109:.*]] = arith.cmpi eq, %[[VAL_37]], %[[VAL_39]] : index
+// CHECK: %[[VAL_110:.*]] = arith.addi %[[VAL_35]], %[[VAL_3]] : index
+// CHECK: %[[VAL_111:.*]] = select %[[VAL_109]], %[[VAL_110]], %[[VAL_35]] : index
+// CHECK: scf.yield %[[VAL_108]], %[[VAL_111]] : index, index
+// CHECK: }
+// CHECK: %[[VAL_112:.*]] = sparse_tensor.load %[[VAL_8]] hasInserts : tensor
+// CHECK: return %[[VAL_112]] : tensor
+// CHECK: }
+func @sumred(%arga: tensor,
+             %argb: tensor) -> tensor {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arga, %c0 : tensor
+  %d1 = tensor.dim %arga, %c1 : tensor
+  %xinit = sparse_tensor.init [%d0, %d1] : tensor
+  %0 = linalg.generic #trait_sumred
+    ins(%arga, %argb: tensor,
+                      tensor)
+    outs(%xinit: tensor) {
+      ^bb(%a: i32, %b: i32, %x: i32):
+        %0 = arith.muli %a, %b : i32
+        %1 = arith.addi %x, %0 : i32
+        linalg.yield %1 : i32
+  } -> tensor
+  return %0 : tensor
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir
@@ -0,0 +1,99 @@
+// RUN: mlir-opt %s \
+// RUN:   --sparsification --sparse-tensor-conversion \
+// RUN:   --linalg-bufferize --convert-linalg-to-loops \
+// RUN:   --convert-vector-to-scf --convert-scf-to-std \
+// RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
+// RUN:   --std-bufferize --finalizing-bufferize --lower-affine \
+// RUN:   --convert-vector-to-llvm --convert-memref-to-llvm --convert-math-to-llvm \
+// RUN:   --convert-std-to-llvm --reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ]
+}>
+
+#SparseTensor = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed", "compressed" ]
+}>
+
+#redsum = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,j,k)>, // A
+    affine_map<(i,j,k) -> (i,j,k)>, // B
+    affine_map<(i,j,k) -> (i,j)>    // X (out)
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  doc = "X(i,j) = SUM_k A(i,j,k) * B(i,j,k)"
+}
+
+module {
+  func @redsum(%arga: tensor,
+               %argb: tensor)
+      -> tensor {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %d0 = tensor.dim %arga, %c0 : tensor
+    %d1 = tensor.dim %arga, %c1 : tensor
+    %xinit = sparse_tensor.init [%d0, %d1] : tensor
+    %0 = linalg.generic #redsum
+      ins(%arga, %argb: tensor,
+                        tensor)
+      outs(%xinit: tensor) {
+        ^bb(%a: i32, %b: i32, %x: i32):
+          %0 = arith.muli %a, %b : i32
+          %1 = arith.addi %x, %0 : i32
+          linalg.yield %1 : i32
+    } -> tensor
+    return %0 : tensor
+  }
+
+  // Driver method to call and verify tensor kernel.
+  func @entry() {
+    %c0 = arith.constant 0 : index
+    %i0 = arith.constant -1 : i32
+
+    // Setup very sparse 3-d tensors.
+    %t1 = arith.constant sparse<
+       [ [1,1,3], [2,0,0], [2,2,1], [2,2,2], [2,2,3] ], [ 1, 2, 3, 4, 5 ]
+    > : tensor<3x3x4xi32>
+    %t2 = arith.constant sparse<
+       [ [1,0,0], [1,1,3], [2,2,1], [2,2,3] ], [ 6, 7, 8, 9 ]
+    > : tensor<3x3x4xi32>
+    %st1 = sparse_tensor.convert %t1
+      : tensor<3x3x4xi32> to tensor
+    %st2 = sparse_tensor.convert %t2
+      : tensor<3x3x4xi32> to tensor
+
+
+    // Call kernel.
+    %0 = call @redsum(%st1, %st2)
+      : (tensor,
+         tensor) -> tensor
+
+    //
+    // Verify results. Only two entries stored in result. Correct structure.
+    //
+    // CHECK: ( 7, 69, -1, -1 )
+    // CHECK-NEXT: ( ( 0, 0, 0 ), ( 0, 7, 0 ), ( 0, 0, 69 ) )
+    //
+    %val = sparse_tensor.values %0
+      : tensor to memref
+    %vv = vector.transfer_read %val[%c0], %i0: memref, vector<4xi32>
+    vector.print %vv : vector<4xi32>
+    %dm = sparse_tensor.convert %0
+      : tensor to tensor
+    %db = bufferization.to_memref %dm : memref
+    %vm = vector.transfer_read %db[%c0, %c0], %i0: memref, vector<3x3xi32>
+    vector.print %vm : vector<3x3xi32>
+
+    // Release the resources.
+    sparse_tensor.release %st1 : tensor
+    sparse_tensor.release %st2 : tensor
+    sparse_tensor.release %0 : tensor
+    memref.dealloc %db : memref
+    return
+  }
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_vector_ops.mlir
@@ -144,7 +144,7 @@
     return %0 : tensor
   }
 
-  // Dumps just the values array of the sparse vector.
+  // Dumps a sparse vector.
  func @dump(%arg0: tensor) {
    // Dump the values array to verify only sparse contents are stored.
    %c0 = arith.constant 0 : index