diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp @@ -52,6 +52,7 @@ namespace { enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI }; +enum class Dim { kSparse, kDense, kUndef }; /// Tensor expression. Represents a MLIR expression in tensor index notation. /// For tensors, e0 denotes the tensor index. For invariants, the IR value is @@ -81,8 +82,13 @@ bits.set(b); } LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {} - /// Conjunction of tensor loop indices as bitvector. + /// Conjunction of tensor loop indices as bitvector. This represents + /// all indices involved in the tensor expression llvm::BitVector bits; + /// Simplified conjunction of tensor loop indices as bitvector. This + /// represents a simplified condition under which this tensor expression + /// must execute. Pre-computed during codegen to avoid repeated eval. + llvm::BitVector simple; /// Index of the tensor expresssion. unsigned exp; }; @@ -93,8 +99,14 @@ /// independently from the basic algorithm if bottlenecks are identified. class Merger { public: + /// Constructs a merger for the given number of tensors and loops. The + /// user supplies the number of tensors involved in the kernel, with the + /// last tensor in this set denoting the output tensor. The merger adds an + /// additional synthetic tensor at the end of this set to represent all + /// invariant expressions in the kernel. Merger(unsigned t, unsigned l) - : numTensors(t), numLoops(l), isSparse(t, std::vector(l, false)) {} + : outTensor(t - 1), numTensors(t + 1), numLoops(l), + dims(t + 1, std::vector(l, Dim::kUndef)) {} /// Adds a tensor expression. Returns its index. unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) { @@ -132,8 +144,8 @@ return p; } - /// Conjunctive merge of L1 and L2 is conjunction of cartesian product. - /// Returns the index of the new set. + /// Conjunctive merge of two lattice sets L0 and L1 is conjunction of + /// cartesian product. Returns the index of the new set. unsigned takeConj(Kind kind, unsigned s0, unsigned s1) { unsigned s = addSet(); for (unsigned p0 : latSets[s0]) @@ -142,7 +154,7 @@ return s; } - /// Disjunctive merge of L0 and L1 is (L0 /\_op L1, L0, L1). + /// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1). /// Returns the index of the new set. unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) { unsigned s = takeConj(kind, s0, s1); @@ -156,26 +168,27 @@ /// Optimizes the iteration lattice points in the given set. This /// method should be called right before code generation to avoid /// generating redundant loops and conditions. - unsigned optimize(unsigned s0) { + unsigned optimizeSet(unsigned s0) { unsigned s = addSet(); assert(latSets[s0].size() != 0); unsigned p0 = latSets[s0][0]; for (unsigned p1 : latSets[s0]) { bool add = true; + llvm::BitVector simple = simplifyCond(s0, p1); if (p0 != p1) { // Is this a straightforward copy? unsigned e = latPoints[p1].exp; - if (exp(e).kind == Kind::kTensor && exp(e).e0 == numTensors - 1) + if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor) continue; - // Is any dense index exhausted? + // Only dense exhausted? 
llvm::BitVector tmp = latPoints[p1].bits; tmp ^= latPoints[p0].bits; - if (hasAnyOf(tmp, false)) + if (!hasAnyDimOf(tmp, Dim::kSparse)) continue; - // Is this a direct duplication of an earlier conjunction? + // Duplication of an earlier conjunction? for (unsigned p2 : latSets[s]) { - tmp = latPoints[p1].bits; - tmp ^= latPoints[p2].bits; + tmp = simple; + tmp ^= latPoints[p2].simple; if (tmp.count() == 0) { add = false; break; @@ -183,13 +196,49 @@ } assert(!add || latGT(p0, p1)); } - if (add) + if (add) { latSets[s].push_back(p1); + latPoints[latSets[s].back()].simple = simple; + } } return s; } - // Returns true if Li > Lj. + /// Simplifies the conditions in a conjunction of a given lattice point + /// within the given set using just two basic rules: + /// (1) multiple dense conditions are reduced to single dense, and + /// (2) a *singleton* sparse/dense is reduced to sparse/random access. + llvm::BitVector simplifyCond(unsigned s, unsigned p0) { + // First determine if this lattice point is a *singleton*, i.e., + // the last point in a lattice, no other is less than this one. + bool isSingleton = true; + for (unsigned p1 : latSets[s]) { + if (p0 != p1 && latGT(p0, p1)) { + unsigned e = latPoints[p1].exp; + if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor) + continue; + llvm::BitVector tmp = latPoints[p1].bits; + tmp ^= latPoints[p0].bits; + if (hasAnyDimOf(tmp, Dim::kSparse)) { + isSingleton = false; + break; + } + } + } + // Now apply the two basic rules. + llvm::BitVector simple = latPoints[p0].bits; + bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse); + for (unsigned b = 0, be = simple.size(); b < be; b++) { + if (simple[b] && !isDim(b, Dim::kSparse)) { + if (reset) + simple.reset(b); + reset = true; + } + } + return simple; + } + + /// Returns true if Li > Lj. bool latGT(unsigned i, unsigned j) const { const llvm::BitVector &bitsi = latPoints[i].bits; const llvm::BitVector &bitsj = latPoints[j].bits; @@ -203,40 +252,41 @@ return false; } - // Bit translation. + /// Bit translation. unsigned tensor(unsigned b) const { return b % numTensors; } unsigned index(unsigned b) const { return b / numTensors; } - // Returns true if bit corresponds to sparse access. - bool isSparseBit(unsigned b) const { - return isSparseAccess(tensor(b), index(b)); - } + /// Returns true if bit corresponds to queried dim. + bool isDim(unsigned b, Dim d) const { return isDim(tensor(b), index(b), d); } - // Returns true if tensor access at given index is sparse. - bool isSparseAccess(unsigned t, unsigned i) const { + /// Returns true if tensor access at given index has queried dim. + bool isDim(unsigned t, unsigned i, Dim d) const { assert(t < numTensors && i < numLoops); - return isSparse[t][i]; + return dims[t][i] == d; } - // Returns true if any set bit corresponds to sparse/dense access. - bool hasAnyOf(const llvm::BitVector &bits, bool sparse) const { + /// Returns true if any set bit corresponds to queried dim. + bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const { for (unsigned b = 0, be = bits.size(); b < be; b++) - if (bits[b] && isSparseBit(b) == sparse) + if (bits[b] && isDim(b, d)) return true; return false; } - // Getters. - std::vector> &sparse() { return isSparse; } + // Setter + void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; } + + /// Getters. 
TensorExp &exp(unsigned e) { return tensorExps[e]; } LatPoint &lat(unsigned l) { return latPoints[l]; } SmallVector &set(unsigned s) { return latSets[s]; } private: + const unsigned outTensor; const unsigned numTensors; const unsigned numLoops; - std::vector> isSparse; + std::vector> dims; llvm::SmallVector tensorExps; llvm::SmallVector latPoints; llvm::SmallVector, 8> latSets; @@ -251,34 +301,39 @@ indices(numTensors, std::vector(numLoops)), highs(numTensors, std::vector(numLoops)), pidxs(numTensors, std::vector(numLoops)), - idxs(numTensors, std::vector(numLoops)) {} - // Sparsification options. + idxs(numTensors, std::vector(numLoops)), redExp(-1u), redVal() {} + /// Sparsification options. linalg::SparsificationOptions options; - // Universal dense indices and upper bounds (by index). The loops array - // is updated with the value of the universal dense index in the current - // loop. The sizes array is set once with the inferred dimension sizes. + /// Universal dense indices and upper bounds (by index). The loops array + /// is updated with the value of the universal dense index in the current + /// loop. The sizes array is set once with the inferred dimension sizes. std::vector loops; std::vector sizes; - // Buffers for storing dense and sparse numerical values (by tensor). - // This array is set once during bufferization of all tensors. + /// Buffers for storing dense and sparse numerical values (by tensor). + /// This array is set once during bufferization of all tensors. std::vector buffers; - // Sparse storage schemes (1-D): pointers and indices (by tensor and index). - // This array is set once during bufferization of all sparse tensors. + /// Sparse storage schemes (1-D): pointers and indices (by tensor and index). + /// This array is set once during bufferization of all sparse tensors. std::vector> pointers; std::vector> indices; - // Sparse iteration information (by tensor and index). These arrays - // are updated to remain current within the current loop. + /// Sparse iteration information (by tensor and index). These arrays + /// are updated to remain current within the current loop. std::vector> highs; std::vector> pidxs; std::vector> idxs; + /// Current reduction, updated during code generation. When indices of a + /// reduction are exhausted, all inner loops can "scalarize" the reduction. + // TODO: currently only done for (a chain of) innermost for-loops, where it + // is most effective; we could generalize to more outer and while-loops. + unsigned redExp; + Value redVal; }; } // namespace /// Helper method to inspect sparse annotations in the linalg operation. /// Fills the per-dimension sparsity information for all tensors. -static void findSparseAnnotations(linalg::GenericOp op, - std::vector> &isSparse) { +static void findSparseAnnotations(Merger &merger, linalg::GenericOp op) { unsigned numTensors = op.getNumInputsAndOutputs(); ArrayAttr sparseAttr = op.sparseAttr(); for (unsigned t = 0; t < numTensors; t++) { @@ -287,13 +342,15 @@ // For each tensor, we accept a per-dimension Sparse or Dense annotation. // This is translated to the loop index that indexes that dimension. 
unsigned rank = op.getShapedType(t).getRank(); - for (unsigned d = 0; d < rank; d++) + for (unsigned d = 0; d < rank; d++) { + unsigned idx = map.getDimPosition(d); if (isSparseDim(dimAttr[d])) { - unsigned idx = map.getDimPosition(d); - isSparse[t][idx] = true; + merger.setDim(t, idx, Dim::kSparse); } else { assert(isDenseDim(dimAttr[d])); + merger.setDim(t, idx, Dim::kDense); } + } } } @@ -406,11 +463,11 @@ Kind kind = merger.exp(exp).kind; if (kind == Kind::kTensor || kind == Kind::kInvariant) { // Either the index is really used in the tensor expression, or it is - // set to the "non-existing dense index" in that dimension. Invariant - // expressions borrow the output tensor indices. + // set to the undefined index in that dimension. An invariant expression + // is set to a synthetic tensor with undefined indices only. unsigned s = merger.addSet(); unsigned t = kind == Kind::kTensor ? merger.exp(exp).e0 - : op.getNumInputsAndOutputs() - 1; + : op.getNumInputsAndOutputs(); merger.set(s).push_back(merger.addLat(t, idx, exp)); return s; } @@ -468,7 +525,7 @@ for (unsigned d = 0, rank = shape.size(); d < rank; d++) { unsigned i = map.getDimPosition(d); // Handle sparse storage schemes. - if (merger.isSparseAccess(t, i)) { + if (merger.isDim(t, i, Dim::kSparse)) { allDense = false; auto dynShape = {ShapedType::kDynamicSize}; auto ptrTp = MemRefType::get( @@ -514,10 +571,8 @@ unsigned exp) { // Test if the load was hoisted to a higher loop nest. Value val = merger.exp(exp).val; - if (val) { - merger.exp(exp).val = Value(); // reset + if (val) return val; - } // Actual load. SmallVector args; unsigned tensor = merger.exp(exp).e0; @@ -526,7 +581,7 @@ for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) { unsigned idx = map.getDimPosition(i); args.push_back(codegen.loops[idx]); // universal dense index - if (sparse || merger.isSparseAccess(tensor, idx)) { + if (sparse || merger.isDim(tensor, idx, Dim::kSparse)) { sparse = true; args.clear(); args.push_back(codegen.pidxs[tensor][idx]); // position index @@ -541,6 +596,13 @@ static void genTensorStore(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, unsigned tensor, Value rhs) { + // Test if this is a scalarized reduction. + unsigned lhs = op.getNumInputsAndOutputs() - 1; + if (lhs == tensor && codegen.redVal) { + codegen.redVal = rhs; + return; + } + // Actual load. SmallVector args; auto map = op.getIndexingMap(tensor); for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) { @@ -594,27 +656,35 @@ /// Hoists loop invariant tensor loads for which indices have been exhausted. static void genInvariants(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, - unsigned exp) { + unsigned exp, unsigned ldx, bool hoist) { if (merger.exp(exp).kind == Kind::kTensor) { - unsigned lhs = op.getNumInputsAndOutputs() - 1; + // Inspect tensor indices. + bool atLevel = ldx == -1u; unsigned tensor = merger.exp(exp).e0; - if (tensor == lhs) - return; // TODO: scalarize reduction as well (using scf.yield) auto map = op.getIndexingMap(tensor); for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) { unsigned idx = map.getDimPosition(i); if (!codegen.loops[idx]) return; // still in play + else if (idx == ldx) + atLevel = true; + } + // All exhausted at this level (atLevel denotes exactly at this level). + unsigned lhs = op.getNumInputsAndOutputs() - 1; + if (lhs == tensor) { + codegen.redExp = hoist ? exp : -1u; + } else if (atLevel) { + merger.exp(exp).val = + hoist ? 
genTensorLoad(merger, codegen, rewriter, op, exp) : Value(); } - // All exhausted at this level. - merger.exp(exp).val = genTensorLoad(merger, codegen, rewriter, op, exp); - } else if (merger.exp(exp).kind != Kind::kInvariant) { // Traverse into the binary operations. Note that we only hoist // tensor loads, since subsequent MLIR/LLVM passes know how to // deal with all other kinds of derived loop invariants. - genInvariants(merger, codegen, rewriter, op, merger.exp(exp).e0); - genInvariants(merger, codegen, rewriter, op, merger.exp(exp).e1); + unsigned e0 = merger.exp(exp).e0; + unsigned e1 = merger.exp(exp).e1; + genInvariants(merger, codegen, rewriter, op, e0, ldx, hoist); + genInvariants(merger, codegen, rewriter, op, e1, ldx, hoist); } } @@ -633,7 +703,7 @@ if (inits[b]) { unsigned tensor = merger.tensor(b); assert(idx == merger.index(b)); - if (merger.isSparseBit(b)) { + if (merger.isDim(b, Dim::kSparse)) { // Initialize sparse index. unsigned pat = at; for (; pat != 0; pat--) { @@ -672,7 +742,7 @@ // is marked "parallel" is a candidate. Whether it is actually converted to // a parallel operation depends on the requested strategy. auto iteratorTypes = op.iterator_types().getValue(); - bool isSparse = merger.isSparseBit(fb); + bool isSparse = merger.isDim(fb, Dim::kSparse); bool isParallel = linalg::isParallelIteratorType(iteratorTypes[idx]); switch (codegen.options.parallelizationStrategy) { case linalg::SparseParallelizationStrategy::kNone: @@ -716,8 +786,22 @@ return parOp; } - // Emit a sequential loop. - scf::ForOp forOp = rewriter.create(loc, lo, hi, step); + // Emit a sequential loop, potentially with a scalarized reduction. + bool scalarRed = isInner && codegen.redExp != -1u; + SmallVector operands; + if (scalarRed) { + Value load = + codegen.redVal + ? codegen.redVal // chained with previous for-loop + : genTensorLoad(merger, codegen, rewriter, op, codegen.redExp); + operands.push_back(load); + } + scf::ForOp forOp = rewriter.create(loc, lo, hi, step, operands); + if (scalarRed) { + codegen.redVal = merger.exp(codegen.redExp).val = + forOp.getRegionIterArgs().front(); + } + // Assign induction variable to sparse or dense index. if (isSparse) codegen.pidxs[tensor][idx] = forOp.getInductionVar(); else @@ -736,7 +820,7 @@ // Construct the while-loop with a parameter for each index. Type indexType = rewriter.getIndexType(); for (unsigned b = 0, be = indices.size(); b < be; b++) { - if (indices[b] && merger.isSparseBit(b)) { + if (indices[b] && merger.isDim(b, Dim::kSparse)) { unsigned tensor = merger.tensor(b); assert(idx == merger.index(b)); types.push_back(indexType); @@ -758,7 +842,7 @@ Value cond; unsigned o = 0; for (unsigned b = 0, be = indices.size(); b < be; b++) { - if (indices[b] && merger.isSparseBit(b)) { + if (indices[b] && merger.isDim(b, Dim::kSparse)) { unsigned tensor = merger.tensor(b); assert(idx == merger.index(b)); Value op1 = before->getArgument(o); @@ -804,7 +888,7 @@ // Initialize sparse indices. Value min; for (unsigned b = 0, be = locals.size(); b < be; b++) { - if (locals[b] && merger.isSparseBit(b)) { + if (locals[b] && merger.isDim(b, Dim::kSparse)) { unsigned tensor = merger.tensor(b); assert(idx == merger.index(b)); Value ptr = codegen.indices[tensor][idx]; @@ -831,11 +915,9 @@ // Initialize dense positions. 
for (unsigned b = 0, be = locals.size(); b < be; b++) { - if (locals[b] && !merger.isSparseBit(b)) { + if (locals[b] && merger.isDim(b, Dim::kDense)) { unsigned tensor = merger.tensor(b); assert(idx == merger.index(b)); - if (!codegen.highs[tensor][idx]) - continue; // unused dimension unsigned pat = at; for (; pat != 0; pat--) if (codegen.pidxs[tensor][topSort[pat - 1]]) @@ -858,8 +940,8 @@ unsigned o = 0; SmallVector operands; Value one = rewriter.create(loc, 1); - for (unsigned b = 0, be = induction.size(); b < be; b++) - if (induction[b] && merger.isSparseBit(b)) { + for (unsigned b = 0, be = induction.size(); b < be; b++) { + if (induction[b] && merger.isDim(b, Dim::kSparse)) { unsigned tensor = merger.tensor(b); assert(idx == merger.index(b)); Value op1 = codegen.idxs[tensor][idx]; @@ -870,6 +952,7 @@ operands.push_back(rewriter.create(loc, cmp, add, op3)); codegen.pidxs[tensor][idx] = results[o++]; } + } if (needsUniv) { operands.push_back(rewriter.create(loc, codegen.loops[idx], one)); codegen.loops[idx] = results[o++]; @@ -879,19 +962,17 @@ } /// Generates a single if-statement within a while-loop. -static void genIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, - linalg::GenericOp op, unsigned idx, - llvm::BitVector &conditions, scf::IfOp &ifOp) { +static scf::IfOp genIf(Merger &merger, CodeGen &codegen, + PatternRewriter &rewriter, linalg::GenericOp op, + unsigned idx, llvm::BitVector &conditions) { Location loc = op.getLoc(); - if (ifOp) - rewriter.setInsertionPointToStart(&ifOp.elseRegion().front()); Value cond; for (unsigned b = 0, be = conditions.size(); b < be; b++) { if (conditions[b]) { unsigned tensor = merger.tensor(b); assert(idx == merger.index(b)); Value clause; - if (merger.isSparseBit(b)) { + if (merger.isDim(b, Dim::kSparse)) { Value op1 = codegen.idxs[tensor][idx]; Value op2 = codegen.loops[idx]; clause = rewriter.create(loc, CmpIPredicate::eq, op1, op2); @@ -901,25 +982,9 @@ cond = cond ? rewriter.create(loc, cond, clause) : clause; } } - ifOp = rewriter.create(loc, cond, /*else*/ true); + scf::IfOp ifOp = rewriter.create(loc, cond, /*else*/ true); rewriter.setInsertionPointToStart(&ifOp.thenRegion().front()); -} - -/// Optimize the loop indices of Li with two rules rules: -/// (1) convert multiple dense to single dense, and -/// (2) convert singleton sparse/dense to sparse/random access. -static void optimizeIndices(Merger merger, unsigned lsize, - llvm::BitVector &indices) { - if (merger.hasAnyOf(indices, false)) { - bool reset = lsize == 1 && merger.hasAnyOf(indices, true); - for (unsigned b = 0, be = indices.size(); b < be; b++) { - if (indices[b] && !merger.isSparseBit(b)) { - if (reset) - indices.reset(b); - reset = true; - } - } - } + return ifOp; } /// Recursively generates code while computing iteration lattices in order @@ -940,43 +1005,51 @@ // Then emit initialization code for the loop sequence at this level. // We maintain the universal dense index if dense indices are still // in play for a non-singleton loop sequence. + // Location loc = op.getLoc(); unsigned idx = topSort[at]; - unsigned lts = merger.optimize(buildLattices(merger, op, exp, idx)); + unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx)); unsigned lsize = merger.set(lts).size(); assert(lsize != 0); unsigned l0 = merger.set(lts)[0]; - LatPoint lat0 = merger.lat(l0); - genInvariants(merger, codegen, rewriter, op, exp); - bool needsUniv = - genInit(merger, codegen, rewriter, op, topSort, at, lat0.bits) && - lsize > 1; + unsigned ldx = at == 0 ? 
-1u : topSort[at - 1]; + genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/true); + bool needsUniv = genInit(merger, codegen, rewriter, op, topSort, at, + merger.lat(l0).bits) && + lsize > 1; // Emit a loop for every lattice point L0 >= Li. - for (unsigned li : merger.set(lts)) { - LatPoint lati = merger.lat(li); + for (unsigned i = 0; i < lsize; i++) { + unsigned li = merger.set(lts)[i]; // Emit loop. - llvm::BitVector indices = lati.bits; - optimizeIndices(merger, lsize, indices); + llvm::BitVector indices = merger.lat(li).simple; Operation *loop = genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices); - genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, lati.bits); + genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, + merger.lat(li).bits); // Visit all lattices points with Li >= Lj to generate the // loop-body, possibly with if statements for coiteration. bool isWhile = dyn_cast(loop) != nullptr; - scf::IfOp ifOp; - for (unsigned lj : merger.set(lts)) { + for (unsigned j = 0; j < lsize; j++) { + unsigned lj = merger.set(lts)[j]; + unsigned ej = merger.lat(lj).exp; if (li == lj || merger.latGT(li, lj)) { - LatPoint latj = merger.lat(lj); - llvm::BitVector tmp = latj.bits; - tmp ^= lati.bits; - if (merger.hasAnyOf(tmp, false)) - continue; // dense exhausted within if/else + if (li != lj) { + llvm::BitVector tmp = merger.lat(lj).bits; + tmp ^= merger.lat(li).bits; + if (!merger.hasAnyDimOf(tmp, Dim::kSparse)) + continue; // only dense exhausted within if/else + } // Recurse into body of each branch. - if (isWhile) - genIf(merger, codegen, rewriter, op, idx, latj.bits, ifOp); - genStmt(merger, codegen, rewriter, op, topSort, latj.exp, at + 1); + if (isWhile) { + scf::IfOp ifOp = + genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple); + genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1); + rewriter.setInsertionPointToStart(&ifOp.elseRegion().front()); + } else { + genStmt(merger, codegen, rewriter, op, topSort, ej, at + 1); + } } } @@ -985,13 +1058,26 @@ scf::WhileOp whileOp = cast(loop); rewriter.setInsertionPointToEnd(&whileOp.after().front()); genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv, - lati.bits, whileOp.results()); + merger.lat(li).bits, whileOp.results()); } else { needsUniv = false; + if (codegen.redVal) { + rewriter.create(op.getLoc(), codegen.redVal); + codegen.redVal = loop->getResult(0); + } } rewriter.setInsertionPointAfter(loop); } + + // Wrap-up loop sequence. + Value red = codegen.redVal; + if (red) { + codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain + unsigned lhs = op.getNumInputsAndOutputs() - 1; + genTensorStore(merger, codegen, rewriter, op, lhs, red); + } codegen.loops[idx] = Value(); + genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false); } namespace { @@ -1012,7 +1098,7 @@ unsigned numTensors = op.getNumInputsAndOutputs(); unsigned numLoops = op.iterator_types().getValue().size(); Merger merger(numTensors, numLoops); - findSparseAnnotations(op, merger.sparse()); + findSparseAnnotations(merger, op); // Computes a topologically sorted iteration graph to ensure // tensors are visited in natural index order. Fails on cycles. 
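Note on the change above (illustration only, not part of the patch): the new `simple` bitvector is precomputed by `simplifyCond`, which (1) collapses multiple dense conditions of a lattice point into a single dense condition, and (2) for a *singleton* lattice point that still carries a sparse condition, drops the dense conditions entirely so codegen emits a plain loop over the sparse dimension with random access into the dense operands instead of a while-loop co-iteration. The following is a minimal standalone C++ sketch of just that bit manipulation; it uses std::bitset instead of llvm::BitVector, hard-codes a two-tensor/one-loop example, and only mirrors the Dim enum and the b % numTensors / b / numTensors bit encoding from the code above. Everything else in it is illustrative and not taken from the patch.

#include <bitset>
#include <cstdio>
#include <vector>

enum class Dim { kSparse, kDense, kUndef };

int main() {
  // Two tensors t0, t1 and a single loop index i0. As in the Merger above,
  // bit b encodes tensor b % numTensors at loop index b / numTensors,
  // so bit 0 = (t0, i0) and bit 1 = (t1, i0).
  const unsigned numTensors = 2;
  std::vector<std::vector<Dim>> dims = {{Dim::kSparse}, {Dim::kDense}};
  auto isSparse = [&](unsigned b) {
    return dims[b % numTensors][b / numTensors] == Dim::kSparse;
  };

  // Conjunction (t0,i0) /\ (t1,i0): one sparse and one dense condition.
  std::bitset<2> bits;
  bits.set(0); // sparse access t0(i0)
  bits.set(1); // dense access t1(i0)

  // Assume this lattice point is a singleton (no other point below it with
  // a sparse dimension), as simplifyCond() would have determined.
  bool isSingleton = true;

  // Mirror of hasAnyDimOf(simple, Dim::kSparse).
  std::bitset<2> simple = bits;
  bool hasSparse = false;
  for (unsigned b = 0; b < simple.size(); b++)
    hasSparse = hasSparse || (simple.test(b) && isSparse(b));

  // Rule (2): a singleton with a sparse condition drops *all* dense
  // conditions (reset starts out true). Rule (1): otherwise the first dense
  // condition is kept and only the remaining dense ones are dropped.
  bool reset = isSingleton && hasSparse;
  for (unsigned b = 0; b < simple.size(); b++) {
    if (simple.test(b) && !isSparse(b)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }

  // Prints "11 -> 01" (leftmost character is bit 1): the dense test on t1
  // is dropped, so codegen emits a plain loop over the sparse t0 with t1
  // accessed randomly, instead of a while-loop co-iteration.
  std::printf("%s -> %s\n", bits.to_string().c_str(),
              simple.to_string().c_str());
  return 0;
}
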
diff --git a/mlir/test/Dialect/Linalg/sparse_1d.mlir b/mlir/test/Dialect/Linalg/sparse_1d.mlir --- a/mlir/test/Dialect/Linalg/sparse_1d.mlir +++ b/mlir/test/Dialect/Linalg/sparse_1d.mlir @@ -636,6 +636,198 @@ return %0 : tensor<32xf32> } +#trait_two_way_inv = { + indexing_maps = [ + affine_map<(i) -> (i)>, // a + affine_map<(i) -> (i)>, // b + affine_map<(i) -> (i)> // x (out) + ], + sparse = [ + [ "S" ], // a + [ "S" ], // b + [ "D" ] // x + ], + iterator_types = ["parallel"], + doc = "x(i) = a(i) * c + b(i) * c" +} + +// CHECK-LABEL: func @two_way_inv( +// CHECK-SAME: %[[VAL_0:.*0]]: tensor<16xf32>, +// CHECK-SAME: %[[VAL_1:.*1]]: tensor<16xf32>, +// CHECK-SAME: %[[VAL_2:.*2]]: f32) -> tensor<16xf32> { +// CHECK: %[[VAL_3:.*]] = constant 999 : index +// CHECK: %[[VAL_4:.*]] = constant 0 : index +// CHECK: %[[VAL_5:.*]] = constant 1 : index +// CHECK: %[[VAL_6:.*]] = alloca(%[[VAL_3]]) : memref +// CHECK: %[[VAL_7:.*]] = alloca(%[[VAL_3]]) : memref +// CHECK: %[[VAL_8:.*]] = alloca(%[[VAL_3]]) : memref +// CHECK: %[[VAL_9:.*]] = alloca(%[[VAL_3]]) : memref +// CHECK: %[[VAL_10:.*]] = alloca(%[[VAL_3]]) : memref +// CHECK: %[[VAL_11:.*]] = alloca(%[[VAL_3]]) : memref +// CHECK: %[[VAL_12:.*]] = alloca() : memref<16xf32> +// CHECK: %[[VAL_13:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_14:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref +// CHECK: %[[VAL_15:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_16:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_5]]] : memref +// CHECK: %[[VAL_17:.*]]:3 = scf.while (%[[VAL_18:.*]] = %[[VAL_13]], %[[VAL_19:.*]] = %[[VAL_15]], %[[VAL_20:.*]] = %[[VAL_4]]) : (index, index, index) -> (index, index, index) { +// CHECK: %[[VAL_21:.*]] = cmpi "ult", %[[VAL_18]], %[[VAL_14]] : index +// CHECK: %[[VAL_22:.*]] = cmpi "ult", %[[VAL_19]], %[[VAL_16]] : index +// CHECK: %[[VAL_23:.*]] = and %[[VAL_21]], %[[VAL_22]] : i1 +// CHECK: scf.condition(%[[VAL_23]]) %[[VAL_18]], %[[VAL_19]], %[[VAL_20]] : index, index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_24:.*]]: index, %[[VAL_25:.*]]: index, %[[VAL_26:.*]]: index): +// CHECK: %[[VAL_27:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_28:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref +// CHECK: %[[VAL_29:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index +// CHECK: %[[VAL_30:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_26]] : index +// CHECK: %[[VAL_31:.*]] = and %[[VAL_29]], %[[VAL_30]] : i1 +// CHECK: scf.if %[[VAL_31]] { +// CHECK: %[[VAL_32:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_32]], %[[VAL_2]] : f32 +// CHECK: %[[VAL_34:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_25]]] : memref +// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_34]], %[[VAL_2]] : f32 +// CHECK: %[[VAL_36:.*]] = addf %[[VAL_33]], %[[VAL_35]] : f32 +// CHECK: store %[[VAL_36]], %[[VAL_12]]{{\[}}%[[VAL_26]]] : memref<16xf32> +// CHECK: } else { +// CHECK: %[[VAL_37:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index +// CHECK: scf.if %[[VAL_37]] { +// CHECK: %[[VAL_38:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_39:.*]] = mulf %[[VAL_38]], %[[VAL_2]] : f32 +// CHECK: store %[[VAL_39]], %[[VAL_12]]{{\[}}%[[VAL_26]]] : memref<16xf32> +// CHECK: } else { +// CHECK: %[[VAL_40:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_26]] : index +// CHECK: scf.if %[[VAL_40]] { +// CHECK: %[[VAL_41:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_25]]] : memref +// CHECK: %[[VAL_42:.*]] = mulf %[[VAL_41]], %[[VAL_2]] : f32 +// CHECK: store %[[VAL_42]], 
%[[VAL_12]]{{\[}}%[[VAL_26]]] : memref<16xf32> +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_43:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index +// CHECK: %[[VAL_44:.*]] = addi %[[VAL_24]], %[[VAL_5]] : index +// CHECK: %[[VAL_45:.*]] = select %[[VAL_43]], %[[VAL_44]], %[[VAL_24]] : index +// CHECK: %[[VAL_46:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_26]] : index +// CHECK: %[[VAL_47:.*]] = addi %[[VAL_25]], %[[VAL_5]] : index +// CHECK: %[[VAL_48:.*]] = select %[[VAL_46]], %[[VAL_47]], %[[VAL_25]] : index +// CHECK: %[[VAL_49:.*]] = addi %[[VAL_26]], %[[VAL_5]] : index +// CHECK: scf.yield %[[VAL_45]], %[[VAL_48]], %[[VAL_49]] : index, index, index +// CHECK: } +// CHECK: scf.for %[[VAL_50:.*]] = %[[VAL_51:.*]]#0 to %[[VAL_14]] step %[[VAL_5]] { +// CHECK: %[[VAL_52:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_50]]] : memref +// CHECK: %[[VAL_53:.*]] = mulf %[[VAL_52]], %[[VAL_2]] : f32 +// CHECK: store %[[VAL_53]], %[[VAL_12]]{{\[}}%[[VAL_51]]#2] : memref<16xf32> +// CHECK: } +// CHECK: scf.for %[[VAL_54:.*]] = %[[VAL_55:.*]]#1 to %[[VAL_16]] step %[[VAL_5]] { +// CHECK: %[[VAL_56:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_54]]] : memref +// CHECK: %[[VAL_57:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_54]]] : memref +// CHECK: %[[VAL_58:.*]] = mulf %[[VAL_57]], %[[VAL_2]] : f32 +// CHECK: store %[[VAL_58]], %[[VAL_12]]{{\[}}%[[VAL_56]]] : memref<16xf32> +// CHECK: } +// CHECK: %[[VAL_59:.*]] = tensor_load %[[VAL_12]] : memref<16xf32> +// CHECK: return %[[VAL_59]] : tensor<16xf32> +// CHECK: } +func @two_way_inv(%arga: tensor<16xf32>, + %argb: tensor<16xf32>, %argc: f32) -> tensor<16xf32> { + %0 = linalg.generic #trait_two_way_inv + ins(%arga, %argb : tensor<16xf32>, tensor<16xf32>) { + ^bb(%a : f32, %b : f32): + %0 = mulf %a, %argc : f32 + %1 = mulf %b, %argc : f32 + %2 = addf %0, %1 : f32 + linalg.yield %2: f32 + } -> tensor<16xf32> + return %0 : tensor<16xf32> +} + +// CHECK-LABEL: func @two_way_inv_alt( +// CHECK-SAME: %[[VAL_0:.*0]]: tensor<16xf32>, +// CHECK-SAME: %[[VAL_1:.*1]]: tensor<16xf32>, +// CHECK-SAME: %[[VAL_2:.*2]]: f32) -> tensor<16xf32> { +// CHECK: %[[VAL_3:.*]] = constant 999 : index +// CHECK: %[[VAL_4:.*]] = constant 0 : index +// CHECK: %[[VAL_5:.*]] = constant 1 : index +// CHECK: %[[VAL_6:.*]] = alloca(%[[VAL_3]]) : memref +// CHECK: %[[VAL_7:.*]] = alloca(%[[VAL_3]]) : memref +// CHECK: %[[VAL_8:.*]] = alloca(%[[VAL_3]]) : memref +// CHECK: %[[VAL_9:.*]] = alloca(%[[VAL_3]]) : memref +// CHECK: %[[VAL_10:.*]] = alloca(%[[VAL_3]]) : memref +// CHECK: %[[VAL_11:.*]] = alloca(%[[VAL_3]]) : memref +// CHECK: %[[VAL_12:.*]] = alloca() : memref<16xf32> +// CHECK: %[[VAL_13:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_14:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref +// CHECK: %[[VAL_15:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_16:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_5]]] : memref +// CHECK: %[[VAL_17:.*]]:3 = scf.while (%[[VAL_18:.*]] = %[[VAL_13]], %[[VAL_19:.*]] = %[[VAL_15]], %[[VAL_20:.*]] = %[[VAL_4]]) : (index, index, index) -> (index, index, index) { +// CHECK: %[[VAL_21:.*]] = cmpi "ult", %[[VAL_18]], %[[VAL_14]] : index +// CHECK: %[[VAL_22:.*]] = cmpi "ult", %[[VAL_19]], %[[VAL_16]] : index +// CHECK: %[[VAL_23:.*]] = and %[[VAL_21]], %[[VAL_22]] : i1 +// CHECK: scf.condition(%[[VAL_23]]) %[[VAL_18]], %[[VAL_19]], %[[VAL_20]] : index, index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_24:.*]]: index, %[[VAL_25:.*]]: index, %[[VAL_26:.*]]: index): +// CHECK: %[[VAL_27:.*]] = load 
%[[VAL_7]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_28:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref +// CHECK: %[[VAL_29:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index +// CHECK: %[[VAL_30:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_26]] : index +// CHECK: %[[VAL_31:.*]] = and %[[VAL_29]], %[[VAL_30]] : i1 +// CHECK: scf.if %[[VAL_31]] { +// CHECK: %[[VAL_32:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_33:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_25]]] : memref +// CHECK: %[[VAL_34:.*]] = addf %[[VAL_32]], %[[VAL_33]] : f32 +// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_34]], %[[VAL_2]] : f32 +// CHECK: store %[[VAL_35]], %[[VAL_12]]{{\[}}%[[VAL_26]]] : memref<16xf32> +// CHECK: } else { +// CHECK: %[[VAL_36:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index +// CHECK: scf.if %[[VAL_36]] { +// CHECK: %[[VAL_37:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_38:.*]] = mulf %[[VAL_37]], %[[VAL_2]] : f32 +// CHECK: store %[[VAL_38]], %[[VAL_12]]{{\[}}%[[VAL_26]]] : memref<16xf32> +// CHECK: } else { +// CHECK: %[[VAL_39:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_26]] : index +// CHECK: scf.if %[[VAL_39]] { +// CHECK: %[[VAL_40:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_25]]] : memref +// CHECK: %[[VAL_41:.*]] = mulf %[[VAL_40]], %[[VAL_2]] : f32 +// CHECK: store %[[VAL_41]], %[[VAL_12]]{{\[}}%[[VAL_26]]] : memref<16xf32> +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_42:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index +// CHECK: %[[VAL_43:.*]] = addi %[[VAL_24]], %[[VAL_5]] : index +// CHECK: %[[VAL_44:.*]] = select %[[VAL_42]], %[[VAL_43]], %[[VAL_24]] : index +// CHECK: %[[VAL_45:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_26]] : index +// CHECK: %[[VAL_46:.*]] = addi %[[VAL_25]], %[[VAL_5]] : index +// CHECK: %[[VAL_47:.*]] = select %[[VAL_45]], %[[VAL_46]], %[[VAL_25]] : index +// CHECK: %[[VAL_48:.*]] = addi %[[VAL_26]], %[[VAL_5]] : index +// CHECK: scf.yield %[[VAL_44]], %[[VAL_47]], %[[VAL_48]] : index, index, index +// CHECK: } +// CHECK: scf.for %[[VAL_49:.*]] = %[[VAL_50:.*]]#0 to %[[VAL_14]] step %[[VAL_5]] { +// CHECK: %[[VAL_51:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_49]]] : memref +// CHECK: %[[VAL_52:.*]] = mulf %[[VAL_51]], %[[VAL_2]] : f32 +// CHECK: store %[[VAL_52]], %[[VAL_12]]{{\[}}%[[VAL_50]]#2] : memref<16xf32> +// CHECK: } +// CHECK: scf.for %[[VAL_53:.*]] = %[[VAL_54:.*]]#1 to %[[VAL_16]] step %[[VAL_5]] { +// CHECK: %[[VAL_55:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_53]]] : memref +// CHECK: %[[VAL_56:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_53]]] : memref +// CHECK: %[[VAL_57:.*]] = mulf %[[VAL_56]], %[[VAL_2]] : f32 +// CHECK: store %[[VAL_57]], %[[VAL_12]]{{\[}}%[[VAL_55]]] : memref<16xf32> +// CHECK: } +// CHECK: %[[VAL_58:.*]] = tensor_load %[[VAL_12]] : memref<16xf32> +// CHECK: return %[[VAL_58]] : tensor<16xf32> +// CHECK: } +func @two_way_inv_alt(%arga: tensor<16xf32>, + %argb: tensor<16xf32>, %argc: f32) -> tensor<16xf32> { + // Same kernel, but now expressed as "x(i) = (a(i) + b(i)) * c". 
+ %0 = linalg.generic #trait_two_way_inv + ins(%arga, %argb : tensor<16xf32>, tensor<16xf32>) { + ^bb(%a : f32, %b : f32): + %0 = addf %a, %b : f32 + %1 = mulf %0, %argc : f32 + linalg.yield %1: f32 + } -> tensor<16xf32> + return %0 : tensor<16xf32> +} + #trait_sum_reduction = { indexing_maps = [ affine_map<(i) -> (i)>, // a @@ -646,7 +838,7 @@ [ ] // x ], iterator_types = ["reduction"], - doc = "x = SUM_i a(i)" + doc = "x += SUM_i a(i)" } // CHECK-LABEL: func @sum_reduction( @@ -661,14 +853,15 @@ // CHECK: %[[VAL_8:.*]] = alloca() : memref // CHECK: %[[VAL_9:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref // CHECK: %[[VAL_10:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref -// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] { -// CHECK: %[[VAL_12:.*]] = load %[[VAL_8]][] : memref -// CHECK: %[[VAL_13:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref -// CHECK: %[[VAL_14:.*]] = addf %[[VAL_12]], %[[VAL_13]] : f32 -// CHECK: store %[[VAL_14]], %[[VAL_8]][] : memref +// CHECK: %[[VAL_11:.*]] = load %[[VAL_8]][] : memref +// CHECK: %[[VAL_12:.*]] = scf.for %[[VAL_13:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] iter_args(%[[VAL_14:.*]] = %[[VAL_11]]) -> (f32) { +// CHECK: %[[VAL_15:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref +// CHECK: %[[VAL_16:.*]] = addf %[[VAL_14]], %[[VAL_15]] : f32 +// CHECK: scf.yield %[[VAL_16]] : f32 // CHECK: } -// CHECK: %[[VAL_15:.*]] = tensor_load %[[VAL_8]] : memref -// CHECK: return %[[VAL_15]] : tensor +// CHECK: store %[[VAL_17:.*]], %[[VAL_8]][] : memref +// CHECK: %[[VAL_18:.*]] = tensor_load %[[VAL_8]] : memref +// CHECK: return %[[VAL_18]] : tensor // CHECK: } func @sum_reduction(%arga: tensor, %argx: tensor) -> tensor { %0 = linalg.generic #trait_sum_reduction @@ -680,3 +873,233 @@ } -> tensor return %0 : tensor } + +#trait_sum_reduction_ss = { + indexing_maps = [ + affine_map<(i) -> (i)>, // a + affine_map<(i) -> (i)>, // b + affine_map<(i)-> ()> // x (scalar out) + ], + sparse = [ + [ "S" ], // a + [ "S" ], // b + [ ] // x + ], + iterator_types = ["reduction"], + doc = "x += SUM_i a(i) + b(i)" +} + +// CHECK-LABEL: func @sum_reduction_ss( +// CHECK-SAME: %[[VAL_0:.*0]]: tensor<16xf32>, +// CHECK-SAME: %[[VAL_1:.*1]]: tensor<16xf32>, +// CHECK-SAME: %[[VAL_2:.*2]]: tensor) -> tensor { +// CHECK: %[[VAL_3:.*]] = constant 999 : index +// CHECK: %[[VAL_4:.*]] = constant 0 : index +// CHECK: %[[VAL_5:.*]] = constant 1 : index +// CHECK: %[[VAL_6:.*]] = alloca(%[[VAL_3]]) : memref +// CHECK: %[[VAL_7:.*]] = alloca(%[[VAL_3]]) : memref +// CHECK: %[[VAL_8:.*]] = alloca(%[[VAL_3]]) : memref +// CHECK: %[[VAL_9:.*]] = alloca(%[[VAL_3]]) : memref +// CHECK: %[[VAL_10:.*]] = alloca(%[[VAL_3]]) : memref +// CHECK: %[[VAL_11:.*]] = alloca(%[[VAL_3]]) : memref +// CHECK: %[[VAL_12:.*]] = alloca() : memref +// CHECK: %[[VAL_13:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_14:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref +// CHECK: %[[VAL_15:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_16:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_5]]] : memref +// CHECK: %[[VAL_17:.*]]:3 = scf.while (%[[VAL_18:.*]] = %[[VAL_13]], %[[VAL_19:.*]] = %[[VAL_15]], %[[VAL_20:.*]] = %[[VAL_4]]) : (index, index, index) -> (index, index, index) { +// CHECK: %[[VAL_21:.*]] = cmpi "ult", %[[VAL_18]], %[[VAL_14]] : index +// CHECK: %[[VAL_22:.*]] = cmpi "ult", %[[VAL_19]], %[[VAL_16]] : index +// CHECK: %[[VAL_23:.*]] = and %[[VAL_21]], %[[VAL_22]] : i1 +// CHECK: scf.condition(%[[VAL_23]]) %[[VAL_18]], %[[VAL_19]], 
%[[VAL_20]] : index, index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_24:.*]]: index, %[[VAL_25:.*]]: index, %[[VAL_26:.*]]: index): +// CHECK: %[[VAL_27:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_28:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref +// CHECK: %[[VAL_29:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index +// CHECK: %[[VAL_30:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_26]] : index +// CHECK: %[[VAL_31:.*]] = and %[[VAL_29]], %[[VAL_30]] : i1 +// CHECK: scf.if %[[VAL_31]] { +// CHECK: %[[VAL_32:.*]] = load %[[VAL_12]][] : memref +// CHECK: %[[VAL_33:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_34:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_25]]] : memref +// CHECK: %[[VAL_35:.*]] = addf %[[VAL_33]], %[[VAL_34]] : f32 +// CHECK: %[[VAL_36:.*]] = addf %[[VAL_32]], %[[VAL_35]] : f32 +// CHECK: store %[[VAL_36]], %[[VAL_12]][] : memref +// CHECK: } else { +// CHECK: %[[VAL_37:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index +// CHECK: scf.if %[[VAL_37]] { +// CHECK: %[[VAL_38:.*]] = load %[[VAL_12]][] : memref +// CHECK: %[[VAL_39:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_40:.*]] = addf %[[VAL_38]], %[[VAL_39]] : f32 +// CHECK: store %[[VAL_40]], %[[VAL_12]][] : memref +// CHECK: } else { +// CHECK: %[[VAL_41:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_26]] : index +// CHECK: scf.if %[[VAL_41]] { +// CHECK: %[[VAL_42:.*]] = load %[[VAL_12]][] : memref +// CHECK: %[[VAL_43:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_25]]] : memref +// CHECK: %[[VAL_44:.*]] = addf %[[VAL_42]], %[[VAL_43]] : f32 +// CHECK: store %[[VAL_44]], %[[VAL_12]][] : memref +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_45:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index +// CHECK: %[[VAL_46:.*]] = addi %[[VAL_24]], %[[VAL_5]] : index +// CHECK: %[[VAL_47:.*]] = select %[[VAL_45]], %[[VAL_46]], %[[VAL_24]] : index +// CHECK: %[[VAL_48:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_26]] : index +// CHECK: %[[VAL_49:.*]] = addi %[[VAL_25]], %[[VAL_5]] : index +// CHECK: %[[VAL_50:.*]] = select %[[VAL_48]], %[[VAL_49]], %[[VAL_25]] : index +// CHECK: %[[VAL_51:.*]] = addi %[[VAL_26]], %[[VAL_5]] : index +// CHECK: scf.yield %[[VAL_47]], %[[VAL_50]], %[[VAL_51]] : index, index, index +// CHECK: } +// CHECK: %[[VAL_52:.*]] = load %[[VAL_12]][] : memref +// CHECK: %[[VAL_53:.*]] = scf.for %[[VAL_54:.*]] = %[[VAL_55:.*]]#0 to %[[VAL_14]] step %[[VAL_5]] iter_args(%[[VAL_56:.*]] = %[[VAL_52]]) -> (f32) { +// CHECK: %[[VAL_57:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_54]]] : memref +// CHECK: %[[VAL_58:.*]] = addf %[[VAL_56]], %[[VAL_57]] : f32 +// CHECK: scf.yield %[[VAL_58]] : f32 +// CHECK: } +// CHECK: %[[VAL_59:.*]] = scf.for %[[VAL_60:.*]] = %[[VAL_61:.*]]#1 to %[[VAL_16]] step %[[VAL_5]] iter_args(%[[VAL_62:.*]] = %[[VAL_63:.*]]) -> (f32) { +// CHECK: %[[VAL_64:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_60]]] : memref +// CHECK: %[[VAL_65:.*]] = addf %[[VAL_62]], %[[VAL_64]] : f32 +// CHECK: scf.yield %[[VAL_65]] : f32 +// CHECK: } +// CHECK: store %[[VAL_66:.*]], %[[VAL_12]][] : memref +// CHECK: %[[VAL_67:.*]] = tensor_load %[[VAL_12]] : memref +// CHECK: return %[[VAL_67]] : tensor +// CHECK: } +func @sum_reduction_ss(%arga: tensor<16xf32>, + %argb: tensor<16xf32>, + %argx: tensor) -> tensor { + // Just for testing. This case would be better expressed + // as two separate reductions kernels. 
+ %0 = linalg.generic #trait_sum_reduction_ss + ins(%arga, %argb: tensor<16xf32>, tensor<16xf32>) + init(%argx : tensor) { + ^bb(%a : f32, %b : f32, %x : f32): + %0 = addf %a, %b : f32 + %1 = addf %x, %0 : f32 + linalg.yield %1: f32 + } -> tensor + return %0 : tensor +} + +#trait_sum_reduction_inv_ss = { + indexing_maps = [ + affine_map<(i) -> (i)>, // a + affine_map<(i) -> ()>, // b + affine_map<(i) -> (i)>, // c + affine_map<(i) -> ()> // x (out) + ], + sparse = [ + [ "S" ], // a + [ ], // b + [ "S" ], // c + [ ] // x + ], + iterator_types = ["reduction"], + doc = "x += SUM_i a(i) * b + c(i)" +} + +// CHECK-LABEL: func @sum_reduction_inv( +// CHECK-SAME: %[[VAL_0:.*0]]: tensor<16xf32>, +// CHECK-SAME: %[[VAL_1:.*1]]: tensor, +// CHECK-SAME: %[[VAL_2:.*2]]: tensor<16xf32>, +// CHECK-SAME: %[[VAL_3:.*3]]: tensor) -> tensor { +// CHECK: %[[VAL_4:.*]] = constant 999 : index +// CHECK: %[[VAL_5:.*]] = constant 0 : index +// CHECK: %[[VAL_6:.*]] = constant 1 : index +// CHECK: %[[VAL_7:.*]] = alloca(%[[VAL_4]]) : memref +// CHECK: %[[VAL_8:.*]] = alloca(%[[VAL_4]]) : memref +// CHECK: %[[VAL_9:.*]] = alloca(%[[VAL_4]]) : memref +// CHECK: %[[VAL_10:.*]] = alloca() : memref +// CHECK: %[[VAL_11:.*]] = alloca(%[[VAL_4]]) : memref +// CHECK: %[[VAL_12:.*]] = alloca(%[[VAL_4]]) : memref +// CHECK: %[[VAL_13:.*]] = alloca(%[[VAL_4]]) : memref +// CHECK: %[[VAL_14:.*]] = alloca() : memref +// CHECK: %[[VAL_15:.*]] = load %[[VAL_10]][] : memref +// CHECK: %[[VAL_16:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref +// CHECK: %[[VAL_17:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref +// CHECK: %[[VAL_18:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref +// CHECK: %[[VAL_19:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_6]]] : memref +// CHECK: %[[VAL_20:.*]]:3 = scf.while (%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]], %[[VAL_23:.*]] = %[[VAL_5]]) : (index, index, index) -> (index, index, index) { +// CHECK: %[[VAL_24:.*]] = cmpi "ult", %[[VAL_21]], %[[VAL_17]] : index +// CHECK: %[[VAL_25:.*]] = cmpi "ult", %[[VAL_22]], %[[VAL_19]] : index +// CHECK: %[[VAL_26:.*]] = and %[[VAL_24]], %[[VAL_25]] : i1 +// CHECK: scf.condition(%[[VAL_26]]) %[[VAL_21]], %[[VAL_22]], %[[VAL_23]] : index, index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_27:.*]]: index, %[[VAL_28:.*]]: index, %[[VAL_29:.*]]: index): +// CHECK: %[[VAL_30:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_27]]] : memref +// CHECK: %[[VAL_31:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_28]]] : memref +// CHECK: %[[VAL_32:.*]] = cmpi "eq", %[[VAL_30]], %[[VAL_29]] : index +// CHECK: %[[VAL_33:.*]] = cmpi "eq", %[[VAL_31]], %[[VAL_29]] : index +// CHECK: %[[VAL_34:.*]] = and %[[VAL_32]], %[[VAL_33]] : i1 +// CHECK: scf.if %[[VAL_34]] { +// CHECK: %[[VAL_35:.*]] = load %[[VAL_14]][] : memref +// CHECK: %[[VAL_36:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_27]]] : memref +// CHECK: %[[VAL_37:.*]] = mulf %[[VAL_36]], %[[VAL_15]] : f32 +// CHECK: %[[VAL_38:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_28]]] : memref +// CHECK: %[[VAL_39:.*]] = addf %[[VAL_37]], %[[VAL_38]] : f32 +// CHECK: %[[VAL_40:.*]] = addf %[[VAL_35]], %[[VAL_39]] : f32 +// CHECK: store %[[VAL_40]], %[[VAL_14]][] : memref +// CHECK: } else { +// CHECK: %[[VAL_41:.*]] = cmpi "eq", %[[VAL_30]], %[[VAL_29]] : index +// CHECK: scf.if %[[VAL_41]] { +// CHECK: %[[VAL_42:.*]] = load %[[VAL_14]][] : memref +// CHECK: %[[VAL_43:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_27]]] : memref +// CHECK: %[[VAL_44:.*]] = mulf %[[VAL_43]], %[[VAL_15]] : f32 +// CHECK: %[[VAL_45:.*]] = addf %[[VAL_42]], %[[VAL_44]] : f32 +// CHECK: store 
%[[VAL_45]], %[[VAL_14]][] : memref +// CHECK: } else { +// CHECK: %[[VAL_46:.*]] = cmpi "eq", %[[VAL_31]], %[[VAL_29]] : index +// CHECK: scf.if %[[VAL_46]] { +// CHECK: %[[VAL_47:.*]] = load %[[VAL_14]][] : memref +// CHECK: %[[VAL_48:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_28]]] : memref +// CHECK: %[[VAL_49:.*]] = addf %[[VAL_47]], %[[VAL_48]] : f32 +// CHECK: store %[[VAL_49]], %[[VAL_14]][] : memref +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_50:.*]] = cmpi "eq", %[[VAL_30]], %[[VAL_29]] : index +// CHECK: %[[VAL_51:.*]] = addi %[[VAL_27]], %[[VAL_6]] : index +// CHECK: %[[VAL_52:.*]] = select %[[VAL_50]], %[[VAL_51]], %[[VAL_27]] : index +// CHECK: %[[VAL_53:.*]] = cmpi "eq", %[[VAL_31]], %[[VAL_29]] : index +// CHECK: %[[VAL_54:.*]] = addi %[[VAL_28]], %[[VAL_6]] : index +// CHECK: %[[VAL_55:.*]] = select %[[VAL_53]], %[[VAL_54]], %[[VAL_28]] : index +// CHECK: %[[VAL_56:.*]] = addi %[[VAL_29]], %[[VAL_6]] : index +// CHECK: scf.yield %[[VAL_52]], %[[VAL_55]], %[[VAL_56]] : index, index, index +// CHECK: } +// CHECK: %[[VAL_57:.*]] = load %[[VAL_14]][] : memref +// CHECK: %[[VAL_58:.*]] = scf.for %[[VAL_59:.*]] = %[[VAL_60:.*]]#0 to %[[VAL_17]] step %[[VAL_6]] iter_args(%[[VAL_61:.*]] = %[[VAL_57]]) -> (f32) { +// CHECK: %[[VAL_62:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_59]]] : memref +// CHECK: %[[VAL_63:.*]] = mulf %[[VAL_62]], %[[VAL_15]] : f32 +// CHECK: %[[VAL_64:.*]] = addf %[[VAL_61]], %[[VAL_63]] : f32 +// CHECK: scf.yield %[[VAL_64]] : f32 +// CHECK: } +// CHECK: %[[VAL_65:.*]] = scf.for %[[VAL_66:.*]] = %[[VAL_67:.*]]#1 to %[[VAL_19]] step %[[VAL_6]] iter_args(%[[VAL_68:.*]] = %[[VAL_69:.*]]) -> (f32) { +// CHECK: %[[VAL_70:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_66]]] : memref +// CHECK: %[[VAL_71:.*]] = addf %[[VAL_68]], %[[VAL_70]] : f32 +// CHECK: scf.yield %[[VAL_71]] : f32 +// CHECK: } +// CHECK: store %[[VAL_72:.*]], %[[VAL_14]][] : memref +// CHECK: %[[VAL_73:.*]] = tensor_load %[[VAL_14]] : memref +// CHECK: return %[[VAL_73]] : tensor +// CHECK: } +func @sum_reduction_inv(%arga: tensor<16xf32>, + %argb: tensor, + %argc: tensor<16xf32>, + %argx: tensor) -> tensor { + // Just for testing. This case would be better expressed + // as two separate reductions kernels. 
+ %0 = linalg.generic #trait_sum_reduction_inv_ss + ins(%arga, %argb, %argc : tensor<16xf32>, tensor, tensor<16xf32>) + init(%argx : tensor) { + ^bb(%a : f32, %b : f32, %c : f32, %x : f32): + %0 = mulf %a, %b : f32 + %1 = addf %0, %c : f32 + %2 = addf %x, %1 : f32 + linalg.yield %2: f32 + } -> tensor + return %0 : tensor +} diff --git a/mlir/test/Dialect/Linalg/sparse_2d.mlir b/mlir/test/Dialect/Linalg/sparse_2d.mlir --- a/mlir/test/Dialect/Linalg/sparse_2d.mlir +++ b/mlir/test/Dialect/Linalg/sparse_2d.mlir @@ -1012,7 +1012,7 @@ [ "D" ] // x ], iterator_types = ["parallel", "reduction"], - doc = "x(i) += A(i,j) * b(j)" + doc = "x(i) += SUM_j A(i,j) * b(j)" } // CHECK-LABEL: func @matvec( @@ -1032,18 +1032,19 @@ // CHECK: %[[VAL_13:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref // CHECK: %[[VAL_14:.*]] = addi %[[VAL_12]], %[[VAL_6]] : index // CHECK: %[[VAL_15:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref -// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_6]] { -// CHECK: %[[VAL_17:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref -// CHECK: %[[VAL_18:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref -// CHECK: %[[VAL_19:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_17]]] : memref<32xf32> -// CHECK: %[[VAL_20:.*]] = mulf %[[VAL_18]], %[[VAL_19]] : f32 -// CHECK: %[[VAL_21:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<16xf32> -// CHECK: %[[VAL_22:.*]] = addf %[[VAL_20]], %[[VAL_21]] : f32 -// CHECK: store %[[VAL_22]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<16xf32> +// CHECK: %[[VAL_16:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<16xf32> +// CHECK: %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_6]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (f32) { +// CHECK: %[[VAL_20:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref +// CHECK: %[[VAL_21:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref +// CHECK: %[[VAL_22:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_20]]] : memref<32xf32> +// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_21]], %[[VAL_22]] : f32 +// CHECK: %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_19]] : f32 +// CHECK: scf.yield %[[VAL_24]] : f32 // CHECK: } +// CHECK: store %[[VAL_25:.*]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<16xf32> // CHECK: } -// CHECK: %[[VAL_23:.*]] = tensor_load %[[VAL_11]] : memref<16xf32> -// CHECK: return %[[VAL_23]] : tensor<16xf32> +// CHECK: %[[VAL_26:.*]] = tensor_load %[[VAL_11]] : memref<16xf32> +// CHECK: return %[[VAL_26]] : tensor<16xf32> // CHECK: } func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf32>) -> tensor<16xf32> { %0 = linalg.generic #trait_matvec @@ -1059,20 +1060,20 @@ #trait_sum_reduction = { indexing_maps = [ - affine_map<(i,j) -> (i,j)>, // a - affine_map<(i,j) -> ()> // x (scalar out) + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> ()> // x (scalar out) ], sparse = [ - [ "D","S" ], // a + [ "D", "S" ], // A [ ] // x ], iterator_types = ["reduction", "reduction"], - doc = "x = SUM_ij a(i,j)" + doc = "x += SUM_ij A(i,j)" } // CHECK-LABEL: func @sum_reduction( -// CHECK-SAME: %[[VAL_0:.*0]]: tensor<10x20xf32>, -// CHECK-SAME: %[[VAL_1:.*1]]: tensor) -> tensor { +// CHECK-SAME: %[[VAL_0:.*]]: tensor<10x20xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor { // CHECK: %[[VAL_2:.*]] = constant 999 : index // CHECK: %[[VAL_3:.*]] = constant 10 : index // CHECK: %[[VAL_4:.*]] = constant 0 : index @@ -1085,15 +1086,16 @@ // CHECK: %[[VAL_11:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref // CHECK: %[[VAL_12:.*]] = addi %[[VAL_10]], %[[VAL_5]] : index // 
CHECK: %[[VAL_13:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref -// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_11]] to %[[VAL_13]] step %[[VAL_5]] { -// CHECK: %[[VAL_15:.*]] = load %[[VAL_9]][] : memref -// CHECK: %[[VAL_16:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_14]]] : memref -// CHECK: %[[VAL_17:.*]] = addf %[[VAL_15]], %[[VAL_16]] : f32 -// CHECK: store %[[VAL_17]], %[[VAL_9]][] : memref +// CHECK: %[[VAL_14:.*]] = load %[[VAL_9]][] : memref +// CHECK: %[[VAL_15:.*]] = scf.for %[[VAL_16:.*]] = %[[VAL_11]] to %[[VAL_13]] step %[[VAL_5]] iter_args(%[[VAL_17:.*]] = %[[VAL_14]]) -> (f32) { +// CHECK: %[[VAL_18:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref +// CHECK: %[[VAL_19:.*]] = addf %[[VAL_17]], %[[VAL_18]] : f32 +// CHECK: scf.yield %[[VAL_19]] : f32 // CHECK: } +// CHECK: store %[[VAL_20:.*]], %[[VAL_9]][] : memref // CHECK: } -// CHECK: %[[VAL_18:.*]] = tensor_load %[[VAL_9]] : memref -// CHECK: return %[[VAL_18]] : tensor +// CHECK: %[[VAL_21:.*]] = tensor_load %[[VAL_9]] : memref +// CHECK: return %[[VAL_21]] : tensor // CHECK: } func @sum_reduction(%arga: tensor<10x20xf32>, %argx: tensor) -> tensor { %0 = linalg.generic #trait_sum_reduction @@ -1170,7 +1172,7 @@ [ "D", "D" ] // X ], iterator_types = ["parallel", "parallel", "reduction"], - doc = "X(i,j) = S(i,j) SUM_k A(i,k) B(k,j)" + doc = "X(i,j) += S(i,j) SUM_k A(i,k) B(k,j)" } // CHECK-LABEL: func @sampled_dense_dense( @@ -1234,3 +1236,235 @@ } -> tensor return %0 : tensor } + +#trait_sum_kernel_with_inv = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (i,j)>, // B + affine_map<(i,j) -> (i,j)>, // C + affine_map<(i,j) -> (i)>, // d + affine_map<(i,j) -> ()>, // e + affine_map<(i,j) -> (i)> // x (out) + ], + sparse = [ + [ "S", "S" ], // A + [ "D", "S" ], // B + [ "D", "S" ], // C + [ "D" ], // d + [ ], // e + [ "D" ] // x + ], + iterator_types = ["parallel", "reduction"], + doc = "x(i) = SUM_j A(i,j) * B(i,j) * d(i) * e + C(i,j)" +} + +// CHECK-LABEL: func @sum_kernel_with_inv( +// CHECK-SAME: %[[VAL_0:.*0]]: tensor, +// CHECK-SAME: %[[VAL_1:.*1]]: tensor, +// CHECK-SAME: %[[VAL_2:.*2]]: tensor, +// CHECK-SAME: %[[VAL_3:.*3]]: tensor, +// CHECK-SAME: %[[VAL_4:.*4]]: tensor, +// CHECK-SAME: %[[VAL_5:.*5]]: tensor) -> tensor { +// CHECK: %[[VAL_6:.*]] = constant 999 : index +// CHECK: %[[VAL_7:.*]] = constant 0 : index +// CHECK: %[[VAL_8:.*]] = constant true +// CHECK: %[[VAL_9:.*]] = constant 1 : index +// CHECK: %[[VAL_10:.*]] = alloca(%[[VAL_6]]) : memref +// CHECK: %[[VAL_11:.*]] = alloca(%[[VAL_6]]) : memref +// CHECK: %[[VAL_12:.*]] = alloca(%[[VAL_6]]) : memref +// CHECK: %[[VAL_13:.*]] = alloca(%[[VAL_6]]) : memref +// CHECK: %[[VAL_14:.*]] = alloca(%[[VAL_6]]) : memref +// CHECK: %[[VAL_15:.*]] = alloca(%[[VAL_6]]) : memref +// CHECK: %[[VAL_16:.*]] = alloca(%[[VAL_6]]) : memref +// CHECK: %[[VAL_17:.*]] = alloca(%[[VAL_6]]) : memref +// CHECK: %[[VAL_18:.*]] = alloca(%[[VAL_6]]) : memref +// CHECK: %[[VAL_19:.*]] = alloca(%[[VAL_6]]) : memref +// CHECK: %[[VAL_20:.*]] = alloca(%[[VAL_6]]) : memref +// CHECK: %[[VAL_21:.*]] = dim %[[VAL_3]], %[[VAL_7]] : tensor +// CHECK: %[[VAL_22:.*]] = alloca(%[[VAL_21]]) : memref +// CHECK: %[[VAL_23:.*]] = alloca() : memref +// CHECK: %[[VAL_24:.*]] = dim %[[VAL_5]], %[[VAL_7]] : tensor +// CHECK: %[[VAL_25:.*]] = alloca(%[[VAL_24]]) : memref +// CHECK: %[[VAL_26:.*]] = load %[[VAL_23]][] : memref +// CHECK: %[[VAL_27:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_7]]] : memref +// CHECK: %[[VAL_28:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_9]]] : memref +// CHECK: 
%[[VAL_29:.*]]:2 = scf.while (%[[VAL_30:.*]] = %[[VAL_27]], %[[VAL_31:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_32:.*]] = cmpi "ult", %[[VAL_30]], %[[VAL_28]] : index
+// CHECK: scf.condition(%[[VAL_32]]) %[[VAL_30]], %[[VAL_31]] : index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_33:.*]]: index, %[[VAL_34:.*]]: index):
+// CHECK: %[[VAL_35:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_33]]] : memref
+// CHECK: %[[VAL_36:.*]] = cmpi "eq", %[[VAL_35]], %[[VAL_34]] : index
+// CHECK: scf.if %[[VAL_36]] {
+// CHECK: %[[VAL_37:.*]] = load %[[VAL_22]]{{\[}}%[[VAL_34]]] : memref
+// CHECK: %[[VAL_38:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_33]]] : memref
+// CHECK: %[[VAL_39:.*]] = addi %[[VAL_33]], %[[VAL_9]] : index
+// CHECK: %[[VAL_40:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_39]]] : memref
+// CHECK: %[[VAL_41:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_34]]] : memref
+// CHECK: %[[VAL_42:.*]] = addi %[[VAL_34]], %[[VAL_9]] : index
+// CHECK: %[[VAL_43:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_42]]] : memref
+// CHECK: %[[VAL_44:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_34]]] : memref
+// CHECK: %[[VAL_45:.*]] = addi %[[VAL_34]], %[[VAL_9]] : index
+// CHECK: %[[VAL_46:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_45]]] : memref
+// CHECK: %[[VAL_47:.*]]:4 = scf.while (%[[VAL_48:.*]] = %[[VAL_38]], %[[VAL_49:.*]] = %[[VAL_41]], %[[VAL_50:.*]] = %[[VAL_44]], %[[VAL_51:.*]] = %[[VAL_7]]) : (index, index, index, index) -> (index, index, index, index) {
+// CHECK: %[[VAL_52:.*]] = cmpi "ult", %[[VAL_48]], %[[VAL_40]] : index
+// CHECK: %[[VAL_53:.*]] = cmpi "ult", %[[VAL_49]], %[[VAL_43]] : index
+// CHECK: %[[VAL_54:.*]] = and %[[VAL_52]], %[[VAL_53]] : i1
+// CHECK: %[[VAL_55:.*]] = cmpi "ult", %[[VAL_50]], %[[VAL_46]] : index
+// CHECK: %[[VAL_56:.*]] = and %[[VAL_54]], %[[VAL_55]] : i1
+// CHECK: scf.condition(%[[VAL_56]]) %[[VAL_48]], %[[VAL_49]], %[[VAL_50]], %[[VAL_51]] : index, index, index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_57:.*]]: index, %[[VAL_58:.*]]: index, %[[VAL_59:.*]]: index, %[[VAL_60:.*]]: index):
+// CHECK: %[[VAL_61:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_57]]] : memref
+// CHECK: %[[VAL_62:.*]] = load %[[VAL_16]]{{\[}}%[[VAL_58]]] : memref
+// CHECK: %[[VAL_63:.*]] = load %[[VAL_19]]{{\[}}%[[VAL_59]]] : memref
+// CHECK: %[[VAL_64:.*]] = cmpi "eq", %[[VAL_61]], %[[VAL_60]] : index
+// CHECK: %[[VAL_65:.*]] = cmpi "eq", %[[VAL_62]], %[[VAL_60]] : index
+// CHECK: %[[VAL_66:.*]] = and %[[VAL_64]], %[[VAL_65]] : i1
+// CHECK: %[[VAL_67:.*]] = cmpi "eq", %[[VAL_63]], %[[VAL_60]] : index
+// CHECK: %[[VAL_68:.*]] = and %[[VAL_66]], %[[VAL_67]] : i1
+// CHECK: scf.if %[[VAL_68]] {
+// CHECK: %[[VAL_69:.*]] = load %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref
+// CHECK: %[[VAL_70:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_57]]] : memref
+// CHECK: %[[VAL_71:.*]] = load %[[VAL_17]]{{\[}}%[[VAL_58]]] : memref
+// CHECK: %[[VAL_72:.*]] = mulf %[[VAL_70]], %[[VAL_71]] : f32
+// CHECK: %[[VAL_73:.*]] = mulf %[[VAL_72]], %[[VAL_37]] : f32
+// CHECK: %[[VAL_74:.*]] = mulf %[[VAL_73]], %[[VAL_26]] : f32
+// CHECK: %[[VAL_75:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_59]]] : memref
+// CHECK: %[[VAL_76:.*]] = addf %[[VAL_74]], %[[VAL_75]] : f32
+// CHECK: %[[VAL_77:.*]] = addf %[[VAL_69]], %[[VAL_76]] : f32
+// CHECK: store %[[VAL_77]], %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref
+// CHECK: } else {
+// CHECK: %[[VAL_78:.*]] = cmpi "eq", %[[VAL_61]], %[[VAL_60]] : index
+// CHECK: %[[VAL_79:.*]] = cmpi "eq", %[[VAL_62]], %[[VAL_60]] : index
+// CHECK: %[[VAL_80:.*]] = and %[[VAL_78]], %[[VAL_79]] : i1
+// CHECK: scf.if %[[VAL_80]] {
+// CHECK: %[[VAL_81:.*]] = load %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref
+// CHECK: %[[VAL_82:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_57]]] : memref
+// CHECK: %[[VAL_83:.*]] = load %[[VAL_17]]{{\[}}%[[VAL_58]]] : memref
+// CHECK: %[[VAL_84:.*]] = mulf %[[VAL_82]], %[[VAL_83]] : f32
+// CHECK: %[[VAL_85:.*]] = mulf %[[VAL_84]], %[[VAL_37]] : f32
+// CHECK: %[[VAL_86:.*]] = mulf %[[VAL_85]], %[[VAL_26]] : f32
+// CHECK: %[[VAL_87:.*]] = addf %[[VAL_81]], %[[VAL_86]] : f32
+// CHECK: store %[[VAL_87]], %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref
+// CHECK: } else {
+// CHECK: %[[VAL_88:.*]] = cmpi "eq", %[[VAL_63]], %[[VAL_60]] : index
+// CHECK: scf.if %[[VAL_88]] {
+// CHECK: %[[VAL_89:.*]] = load %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref
+// CHECK: %[[VAL_90:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_59]]] : memref
+// CHECK: %[[VAL_91:.*]] = addf %[[VAL_89]], %[[VAL_90]] : f32
+// CHECK: store %[[VAL_91]], %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref
+// CHECK: } else {
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_92:.*]] = cmpi "eq", %[[VAL_61]], %[[VAL_60]] : index
+// CHECK: %[[VAL_93:.*]] = addi %[[VAL_57]], %[[VAL_9]] : index
+// CHECK: %[[VAL_94:.*]] = select %[[VAL_92]], %[[VAL_93]], %[[VAL_57]] : index
+// CHECK: %[[VAL_95:.*]] = cmpi "eq", %[[VAL_62]], %[[VAL_60]] : index
+// CHECK: %[[VAL_96:.*]] = addi %[[VAL_58]], %[[VAL_9]] : index
+// CHECK: %[[VAL_97:.*]] = select %[[VAL_95]], %[[VAL_96]], %[[VAL_58]] : index
+// CHECK: %[[VAL_98:.*]] = cmpi "eq", %[[VAL_63]], %[[VAL_60]] : index
+// CHECK: %[[VAL_99:.*]] = addi %[[VAL_59]], %[[VAL_9]] : index
+// CHECK: %[[VAL_100:.*]] = select %[[VAL_98]], %[[VAL_99]], %[[VAL_59]] : index
+// CHECK: %[[VAL_101:.*]] = addi %[[VAL_60]], %[[VAL_9]] : index
+// CHECK: scf.yield %[[VAL_94]], %[[VAL_97]], %[[VAL_100]], %[[VAL_101]] : index, index, index, index
+// CHECK: }
+// CHECK: %[[VAL_102:.*]]:3 = scf.while (%[[VAL_103:.*]] = %[[VAL_104:.*]]#0, %[[VAL_105:.*]] = %[[VAL_104]]#1, %[[VAL_106:.*]] = %[[VAL_104]]#3) : (index, index, index) -> (index, index, index) {
+// CHECK: %[[VAL_107:.*]] = cmpi "ult", %[[VAL_103]], %[[VAL_40]] : index
+// CHECK: %[[VAL_108:.*]] = cmpi "ult", %[[VAL_105]], %[[VAL_43]] : index
+// CHECK: %[[VAL_109:.*]] = and %[[VAL_107]], %[[VAL_108]] : i1
+// CHECK: scf.condition(%[[VAL_109]]) %[[VAL_103]], %[[VAL_105]], %[[VAL_106]] : index, index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_110:.*]]: index, %[[VAL_111:.*]]: index, %[[VAL_112:.*]]: index):
+// CHECK: %[[VAL_113:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_110]]] : memref
+// CHECK: %[[VAL_114:.*]] = load %[[VAL_16]]{{\[}}%[[VAL_111]]] : memref
+// CHECK: %[[VAL_115:.*]] = cmpi "eq", %[[VAL_113]], %[[VAL_112]] : index
+// CHECK: %[[VAL_116:.*]] = cmpi "eq", %[[VAL_114]], %[[VAL_112]] : index
+// CHECK: %[[VAL_117:.*]] = and %[[VAL_115]], %[[VAL_116]] : i1
+// CHECK: scf.if %[[VAL_117]] {
+// CHECK: %[[VAL_118:.*]] = load %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref
+// CHECK: %[[VAL_119:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_110]]] : memref
+// CHECK: %[[VAL_120:.*]] = load %[[VAL_17]]{{\[}}%[[VAL_111]]] : memref
+// CHECK: %[[VAL_121:.*]] = mulf %[[VAL_119]], %[[VAL_120]] : f32
+// CHECK: %[[VAL_122:.*]] = mulf %[[VAL_121]], %[[VAL_37]] : f32
+// CHECK: %[[VAL_123:.*]] = mulf %[[VAL_122]], %[[VAL_26]] : f32
+// CHECK: %[[VAL_124:.*]] = addf %[[VAL_118]], %[[VAL_123]] : f32
+// CHECK: store %[[VAL_124]], %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref
+// CHECK: } else {
+// CHECK: }
+// CHECK: %[[VAL_125:.*]] = cmpi "eq", %[[VAL_113]], %[[VAL_112]] : index
+// CHECK: %[[VAL_126:.*]] = addi %[[VAL_110]], %[[VAL_9]] : index
+// CHECK: %[[VAL_127:.*]] = select %[[VAL_125]], %[[VAL_126]], %[[VAL_110]] : index
+// CHECK: %[[VAL_128:.*]] = cmpi "eq", %[[VAL_114]], %[[VAL_112]] : index
+// CHECK: %[[VAL_129:.*]] = addi %[[VAL_111]], %[[VAL_9]] : index
+// CHECK: %[[VAL_130:.*]] = select %[[VAL_128]], %[[VAL_129]], %[[VAL_111]] : index
+// CHECK: %[[VAL_131:.*]] = addi %[[VAL_112]], %[[VAL_9]] : index
+// CHECK: scf.yield %[[VAL_127]], %[[VAL_130]], %[[VAL_131]] : index, index, index
+// CHECK: }
+// CHECK: %[[VAL_132:.*]] = load %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref
+// CHECK: %[[VAL_133:.*]] = scf.for %[[VAL_134:.*]] = %[[VAL_135:.*]]#2 to %[[VAL_46]] step %[[VAL_9]] iter_args(%[[VAL_136:.*]] = %[[VAL_132]]) -> (f32) {
+// CHECK: %[[VAL_137:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_134]]] : memref
+// CHECK: %[[VAL_138:.*]] = addf %[[VAL_136]], %[[VAL_137]] : f32
+// CHECK: scf.yield %[[VAL_138]] : f32
+// CHECK: }
+// CHECK: store %[[VAL_139:.*]], %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref
+// CHECK: } else {
+// CHECK: scf.if %[[VAL_8]] {
+// CHECK: %[[VAL_140:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_34]]] : memref
+// CHECK: %[[VAL_141:.*]] = addi %[[VAL_34]], %[[VAL_9]] : index
+// CHECK: %[[VAL_142:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_141]]] : memref
+// CHECK: %[[VAL_143:.*]] = load %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref
+// CHECK: %[[VAL_144:.*]] = scf.for %[[VAL_145:.*]] = %[[VAL_140]] to %[[VAL_142]] step %[[VAL_9]] iter_args(%[[VAL_146:.*]] = %[[VAL_143]]) -> (f32) {
+// CHECK: %[[VAL_147:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_145]]] : memref
+// CHECK: %[[VAL_148:.*]] = addf %[[VAL_146]], %[[VAL_147]] : f32
+// CHECK: scf.yield %[[VAL_148]] : f32
+// CHECK: }
+// CHECK: store %[[VAL_149:.*]], %[[VAL_25]]{{\[}}%[[VAL_34]]] : memref
+// CHECK: } else {
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_150:.*]] = cmpi "eq", %[[VAL_35]], %[[VAL_34]] : index
+// CHECK: %[[VAL_151:.*]] = addi %[[VAL_33]], %[[VAL_9]] : index
+// CHECK: %[[VAL_152:.*]] = select %[[VAL_150]], %[[VAL_151]], %[[VAL_33]] : index
+// CHECK: %[[VAL_153:.*]] = addi %[[VAL_34]], %[[VAL_9]] : index
+// CHECK: scf.yield %[[VAL_152]], %[[VAL_153]] : index, index
+// CHECK: }
+// CHECK: scf.for %[[VAL_154:.*]] = %[[VAL_155:.*]]#1 to %[[VAL_24]] step %[[VAL_9]] {
+// CHECK: %[[VAL_156:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_154]]] : memref
+// CHECK: %[[VAL_157:.*]] = addi %[[VAL_154]], %[[VAL_9]] : index
+// CHECK: %[[VAL_158:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_157]]] : memref
+// CHECK: %[[VAL_159:.*]] = load %[[VAL_25]]{{\[}}%[[VAL_154]]] : memref
+// CHECK: %[[VAL_160:.*]] = scf.for %[[VAL_161:.*]] = %[[VAL_156]] to %[[VAL_158]] step %[[VAL_9]] iter_args(%[[VAL_162:.*]] = %[[VAL_159]]) -> (f32) {
+// CHECK: %[[VAL_163:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_161]]] : memref
+// CHECK: %[[VAL_164:.*]] = addf %[[VAL_162]], %[[VAL_163]] : f32
+// CHECK: scf.yield %[[VAL_164]] : f32
+// CHECK: }
+// CHECK: store %[[VAL_165:.*]], %[[VAL_25]]{{\[}}%[[VAL_154]]] : memref
+// CHECK: }
+// CHECK: %[[VAL_166:.*]] = tensor_load %[[VAL_25]] : memref
+// CHECK: return %[[VAL_166]] : tensor
+// CHECK: }
+func @sum_kernel_with_inv(%arga: tensor,
+                          %argb: tensor,
+                          %argc: tensor,
+                          %argd: tensor,
+                          %arge: tensor,
+                          %argx: tensor) -> tensor {
+  %0 = linalg.generic #trait_sum_kernel_with_inv
+    ins(%arga, %argb, %argc, %argd, %arge : tensor,
+        tensor,
+        tensor,
+        tensor,
+        tensor)
+    init(%argx : tensor) {
+    ^bb(%a : f32, %b : f32, %c : f32, %d : f32, %e : f32, %x : f32):
+      %0 = mulf %a, %b : f32
+      %1 = mulf %0, %d : f32
+      %2 = mulf %1, %e : f32
+      %3 = addf %2, %c : f32
+      %4 = addf %x, %3 : f32
+      linalg.yield %4: f32
+  } -> tensor
+  return %0 : tensor
+}
diff --git a/mlir/test/Dialect/Linalg/sparse_3d.mlir b/mlir/test/Dialect/Linalg/sparse_3d.mlir
--- a/mlir/test/Dialect/Linalg/sparse_3d.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_3d.mlir
@@ -1160,7 +1160,7 @@
     [ "D", "D" ]  // A
   ],
   iterator_types = ["parallel", "parallel", "reduction", "reduction"],
-  doc = "A(i,j) = SUM_k,l B(i,k,l) * C(k,j) * D(l,j)"
+  doc = "A(i,j) += SUM_k,l B(i,k,l) * C(k,j) * D(l,j)"
 }
 
 // CHECK-LABEL: func @kernel_3d(
@@ -1223,17 +1223,18 @@
   } -> tensor
   return %0 : tensor
 }
+
 #trait_sum_reduction = {
   indexing_maps = [
-    affine_map<(i,j,k) -> (i,j,k)>,  // a
+    affine_map<(i,j,k) -> (i,j,k)>,  // A
     affine_map<(i,j,k) -> ()>        // x (scalar out)
   ],
   sparse = [
-    [ "S", "S", "S" ],  // a
+    [ "S", "S", "S" ],  // A
     [ ]                 // x
   ],
   iterator_types = ["reduction", "reduction", "reduction"],
-  doc = "x = SUM_ijk a(i,j,k)"
+  doc = "x += SUM_ijk A(i,j,k)"
 }
 
 // CHECK-LABEL: func @sum_reduction(
@@ -1260,16 +1261,17 @@
 // CHECK: %[[VAL_20:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref
 // CHECK: %[[VAL_21:.*]] = addi %[[VAL_19]], %[[VAL_4]] : index
 // CHECK: %[[VAL_22:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref
-// CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_4]] {
-// CHECK: %[[VAL_24:.*]] = load %[[VAL_12]][] : memref
-// CHECK: %[[VAL_25:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_23]]] : memref
-// CHECK: %[[VAL_26:.*]] = addf %[[VAL_24]], %[[VAL_25]] : f32
-// CHECK: store %[[VAL_26]], %[[VAL_12]][] : memref
+// CHECK: %[[VAL_23:.*]] = load %[[VAL_12]][] : memref
+// CHECK: %[[VAL_24:.*]] = scf.for %[[VAL_25:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_4]] iter_args(%[[VAL_26:.*]] = %[[VAL_23]]) -> (f32) {
+// CHECK: %[[VAL_27:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_25]]] : memref
+// CHECK: %[[VAL_28:.*]] = addf %[[VAL_26]], %[[VAL_27]] : f32
+// CHECK: scf.yield %[[VAL_28]] : f32
 // CHECK: }
+// CHECK: store %[[VAL_29:.*]], %[[VAL_12]][] : memref
 // CHECK: }
 // CHECK: }
-// CHECK: %[[VAL_27:.*]] = tensor_load %[[VAL_12]] : memref
-// CHECK: return %[[VAL_27]] : tensor
+// CHECK: %[[VAL_30:.*]] = tensor_load %[[VAL_12]] : memref
+// CHECK: return %[[VAL_30]] : tensor
 // CHECK: }
 func @sum_reduction(%arga: tensor<10x20x30xf32>, %argx: tensor) -> tensor {
   %0 = linalg.generic #trait_sum_reduction
@@ -1282,21 +1284,80 @@
   return %0 : tensor
 }
 
+#trait_sum_reduction_inv = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,j,k)>,  // A
+    affine_map<(i,j,k) -> (i)>,      // b
+    affine_map<(i,j,k) -> ()>        // x (scalar out)
+  ],
+  sparse = [
+    [ "D", "D", "D" ],  // A
+    [ "D" ],            // b
+    [ ]                 // x
+  ],
+  iterator_types = ["reduction", "reduction", "reduction"],
+  doc = "x += SUM_i A(i,j,k) * b(i)"
+}
+
+// CHECK-LABEL: func @sum_reduction_inv(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor) -> tensor {
+// CHECK: %[[VAL_3:.*]] = constant 2 : index
+// CHECK: %[[VAL_4:.*]] = constant 0 : index
+// CHECK: %[[VAL_5:.*]] = constant 1 : index
+// CHECK: %[[VAL_6:.*]] = dim %[[VAL_0]], %[[VAL_4]] : tensor
+// CHECK: %[[VAL_7:.*]] = dim %[[VAL_0]], %[[VAL_5]] : tensor
+// CHECK: %[[VAL_8:.*]] = dim %[[VAL_0]], %[[VAL_3]] : tensor
+// CHECK: %[[VAL_9:.*]] = alloca(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) : memref
+// CHECK: %[[VAL_10:.*]] = dim %[[VAL_1]], %[[VAL_4]] : tensor
+// CHECK: %[[VAL_11:.*]] = alloca(%[[VAL_10]]) : memref
+// CHECK: %[[VAL_12:.*]] = alloca() : memref
+// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_10]] step %[[VAL_5]] {
+// CHECK: %[[VAL_14:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_13]]] : memref
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_4]] to %[[VAL_7]] step %[[VAL_5]] {
+// CHECK: %[[VAL_16:.*]] = load %[[VAL_12]][] : memref
+// CHECK: %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_4]] to %[[VAL_8]] step %[[VAL_5]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (f32) {
+// CHECK: %[[VAL_20:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_13]], %[[VAL_15]], %[[VAL_18]]] : memref
+// CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_14]] : f32
+// CHECK: %[[VAL_22:.*]] = addf %[[VAL_19]], %[[VAL_21]] : f32
+// CHECK: scf.yield %[[VAL_22]] : f32
+// CHECK: }
+// CHECK: store %[[VAL_23:.*]], %[[VAL_12]][] : memref
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_24:.*]] = tensor_load %[[VAL_12]] : memref
+// CHECK: return %[[VAL_24]] : tensor
+// CHECK: }
+func @sum_reduction_inv(%arga: tensor,
+                        %argb: tensor,
+                        %argx: tensor) -> tensor {
+  %0 = linalg.generic #trait_sum_reduction_inv
+    ins(%arga, %argb : tensor, tensor)
+    init(%argx : tensor) {
+    ^bb(%a : f32, %b : f32, %x : f32):
+      %0 = mulf %a, %b : f32
+      %1 = addf %x, %0 : f32
+      linalg.yield %1: f32
+  } -> tensor
+  return %0 : tensor
+}
+
 #trait_invariants = {
   indexing_maps = [
     affine_map<(i,j,k) -> (i)>,      // a
     affine_map<(i,j,k) -> (j)>,      // b
     affine_map<(i,j,k) -> (k)>,      // c
-    affine_map<(i,j,k) -> (i,j,k)>   // x
+    affine_map<(i,j,k) -> (i,j,k)>   // X (out)
   ],
   sparse = [
     [ "D" ],            // a
     [ "D" ],            // b
     [ "D" ],            // c
-    [ "D", "D", "D" ]   // x
+    [ "D", "D", "D" ]   // X
   ],
   iterator_types = ["parallel", "parallel", "parallel"],
-  doc = "x(i,j,k) = a(i) * b(j) * c(k)"
+  doc = "X(i,j,k) = a(i) * b(j) * c(k)"
 }
 
 // CHECK-LABEL: func @invariants(