diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -772,6 +772,12 @@ const FrozenRewritePatternList &stage2Patterns,
     function_ref stage3Lambda = nullptr);
 
+//===----------------------------------------------------------------------===//
+// Support for sparse tensor code generation (sparsification of linalg ops).
+//===----------------------------------------------------------------------===//
+void populateSparsificationPatterns(MLIRContext *context,
+                                    OwningRewritePatternList &patterns);
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@
   Interchange.cpp
   Loops.cpp
   Promotion.cpp
+  Sparsification.cpp
   Tiling.cpp
   Transforms.cpp
   Vectorization.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -0,0 +1,829 @@
+//===- Sparsification.cpp - Implementation of linalg sparsification -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of annotated linalg operations to sparse code.
+//
+// The concept of letting a compiler generate sparse code automatically was
+// pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and
+// formalized to tensor algebra by [Kjolstad17,20] for the Sparse Tensor
+// Algebra Compiler (TACO). The implementation in this file closely follows
+// the "sparse iteration theory" that forms the foundation of TACO. A rewriting
+// rule is applied to each tensor expression in linalg (MLIR's tensor index
+// notation) where the sparsity of tensors is indicated with annotations using
+// a per-dimension specification of sparse/dense storage together with a
+// specification of the order on the dimensions. Subsequently, a topologically
+// sorted iteration graph is constructed to ensure that all tensors are visited
+// in natural index order. Next, iteration lattices are constructed for the
+// tensor expression for every index in topological order. Lastly, these
+// iteration lattices drive actual sparse code generation, which consists of
+// a tedious but relatively straightforward 1:1 mapping from iteration lattices
+// to combinations of for-loops, while-loops, and if-statements.
+//
+// [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations.
+// PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php).
+// [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou,
+// David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler.
+// Proceedings of the ACM on Programming Languages, October 2017.
+// [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation.
+// PhD thesis, MIT, February 2020 (tensor-compiler.org).
+//
+// Implementation detail: we use llvm::SmallVector for vectors with
+// variable lengths and std::vector for vectors with fixed lengths.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+
+using namespace mlir;
+
+namespace {
+
+enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };
+
+/// Tensor expression in tensor index notation. For tensors, e0 is the tensor
+/// index; for invariants, e0 is the argument number of an outer block argument.
+/// For binary operations, e0 and e1 are the indices of the child expressions.
+struct TensorExp {
+  TensorExp(Kind k, unsigned x, unsigned y) : kind(k), e0(x), e1(y) {}
+  Kind kind;
+  unsigned e0;
+  unsigned e1;
+};
+
+/// Lattice point. Each lattice point consists of a conjunction of tensor
+/// loop indices (encoded in a bitvector) and the index of the corresponding
+/// tensor expression.
+struct LatPoint {
+  LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
+    bits.set(b);
+  }
+  LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
+  llvm::BitVector bits;
+  unsigned exp;
+};
+
+/// A class to handle all iteration lattice operations. This class abstracts
+/// away from some implementation details of storing iteration lattices and
+/// tensor expressions. This allows for fine-tuning performance characteristics
+/// independently from the basic algorithm if bottlenecks are identified.
+class Merger {
+public:
+  Merger(unsigned t, unsigned l)
+      : numTensors(t), numLoops(l), isSparse(t, std::vector(l, false)) {}
+
+  /// Adds a tensor expression. Returns its index.
+  unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u) {
+    unsigned e = tensorExps.size();
+    tensorExps.push_back(TensorExp(k, e0, e1));
+    return e;
+  }
+
+  /// Adds an iteration lattice point. Returns its index.
+  unsigned addLat(unsigned t, unsigned i, unsigned e) {
+    assert(t < numTensors && i < numLoops);
+    unsigned p = latPoints.size();
+    latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
+    return p;
+  }
+
+  /// Adds a new, initially empty, set. Returns its index.
+  unsigned addSet() {
+    unsigned s = latSets.size();
+    latSets.emplace_back(SmallVector());
+    return s;
+  }
+
+  /// Computes a single conjunction of two lattice points by taking the "union"
+  /// of loop indices (effectively constructing a larger "intersection" of those
+  /// indices) with a newly constructed tensor (sub)expression of given kind.
+  /// Returns the index of the new lattice point.
+  unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
+    unsigned p = latPoints.size();
+    llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
+    nb |= latPoints[p1].bits;
+    unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
+    latPoints.push_back(LatPoint(nb, e));
+    return p;
+  }
+
+  /// Conjunctive merge of L0 and L1 is conjunction of cartesian product.
+  /// Returns the index of the new set.
+  unsigned takeConj(Kind kind, unsigned s0, unsigned s1) {
+    unsigned s = addSet();
+    for (unsigned p0 : latSets[s0])
+      for (unsigned p1 : latSets[s1])
+        latSets[s].push_back(conjLatPoint(kind, p0, p1));
+    return s;
+  }
+
+  /// Disjunctive merge of L0 and L1 is (L0 /\_op L1, L0, L1).
+  /// Returns the index of the new set.
+  unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) {
+    unsigned s = takeConj(kind, s0, s1);
+    for (unsigned p : latSets[s0])
+      latSets[s].push_back(p);
+    for (unsigned p : latSets[s1])
+      latSets[s].push_back(p);
+    return s;
+  }
+
+  /// Optimizes the iteration lattice points in set s0 into a new set, which is
+  /// returned. The first point p0 is always kept; any other point is dropped
+  /// when its bits differ from p0 in a dense access (dense exhausted) or equal
+  /// the bits of the previously kept point (direct duplicate).
+  unsigned optimize(unsigned s0) {
+    unsigned s = addSet();
+    assert(latSets[s0].size() != 0);
+    unsigned p0 = latSets[s0][0];
+    llvm::BitVector tmp;
+    llvm::BitVector last = latPoints[p0].bits;
+    for (unsigned p1 : latSets[s0]) {
+      if (p0 != p1) {
+        tmp = latPoints[p1].bits;
+        tmp ^= latPoints[p0].bits;
+        if (hasAnyOf(tmp, false))
+          continue; // dense exhausted?
+        tmp = latPoints[p1].bits;
+        tmp ^= last;
+        if (tmp.count() == 0)
+          continue; // direct dup?
+        assert(latGT(p0, p1));
+        last = latPoints[p1].bits;
+      }
+      latSets[s].push_back(p1);
+    }
+    return s;
+  }
+
+  /// Returns true if Li > Lj, i.e. the bits of Li strictly contain those of Lj.
+  bool latGT(unsigned i, unsigned j) const {
+    const llvm::BitVector &bitsi = latPoints[i].bits;
+    const llvm::BitVector &bitsj = latPoints[j].bits;
+    assert(bitsi.size() == bitsj.size());
+    if (bitsi.count() > bitsj.count()) {
+      for (unsigned b = 0, be = bitsj.size(); b < be; b++)
+        if (bitsj[b] && !bitsi[b])
+          return false;
+      return true;
+    }
+    return false;
+  }
+
+  // Bit translation: bit b encodes loop index * numTensors + tensor.
+  unsigned tensor(unsigned b) const { return b % numTensors; }
+  unsigned index(unsigned b) const { return b / numTensors; }
+
+  /// Returns true if bit corresponds to sparse access.
+  bool isSparseBit(unsigned b) const {
+    return isSparseAccess(tensor(b), index(b));
+  }
+
+  /// Returns true if tensor access at given index is sparse.
+  bool isSparseAccess(unsigned t, unsigned i) const {
+    assert(t < numTensors && i < numLoops);
+    return isSparse[t][i];
+  }
+
+  /// Returns true if any set bit corresponds to sparse/dense access.
+  bool hasAnyOf(const llvm::BitVector &bits, bool sparse) const {
+    for (unsigned b = 0, be = bits.size(); b < be; b++)
+      if (bits[b] && isSparseBit(b) == sparse)
+        return true;
+    return false;
+  }
+
+  // Getters (returned references may be invalidated when storage grows).
+  std::vector> &sparse() { return isSparse; }
+  TensorExp &exp(unsigned e) { return tensorExps[e]; }
+  LatPoint &lat(unsigned l) { return latPoints[l]; }
+  SmallVector &set(unsigned s) { return latSets[s]; }
+
+private:
+  const unsigned numTensors;
+  const unsigned numLoops;
+
+  std::vector> isSparse;
+  llvm::SmallVector tensorExps;
+  llvm::SmallVector latPoints;
+  llvm::SmallVector, 8> latSets;
+};
+
+/// Code generation state (values by tensor and/or loop index).
+struct CodeGen {
+  CodeGen(unsigned numTensors, unsigned numLoops)
+      : loops(numLoops), buffers(numTensors),
+        pointers(numTensors, std::vector(numLoops)),
+        indices(numTensors, std::vector(numLoops)),
+        sizes(numTensors, std::vector(numLoops)),
+        highs(numTensors, std::vector(numLoops)),
+        pidxs(numTensors, std::vector(numLoops)),
+        idxs(numTensors, std::vector(numLoops)) {}
+  // Universal dense indices (by index).
+  std::vector loops;
+  // Buffers for storing dense and sparse numerical values (by tensor).
+  std::vector buffers;
+  // Sparse storage schemes (1-D): pointers and indices (by tensor and index).
+  std::vector> pointers;
+  std::vector> indices;
+  // Dimension information (by tensor and index).
+  std::vector> sizes;
+  // Sparse iteration information (by tensor and index).
+  std::vector> highs;
+  std::vector> pidxs;
+  std::vector> idxs;
+};
+
+} // namespace
+
+/// Helper method to inspect sparse annotations in the linalg operation.
+/// Marks isSparse[t][i] for every "S" dimension of tensor t indexed by loop i.
+static void findSparseAnnotations(linalg::GenericOp op,
+                                  std::vector> &isSparse) {
+  unsigned numTensors = op.getNumInputsAndOutputs();
+  ArrayAttr sparseAttr = op.template getAttrOfType("sparse");
+  assert(sparseAttr && sparseAttr.size() == numTensors &&
+         "expected one sparse annotation per tensor");
+  for (unsigned t = 0; t < numTensors; t++) {
+    auto map = (op.indexing_maps()[t]).cast().getValue();
+    auto dimAttr = sparseAttr[t].cast();
+    // For each tensor, we accept a per-dimension Sparse or Dense annotation.
+    // This is translated to the loop index that indexes that dimension.
+    unsigned rank = op.getShapedType(t).getRank();
+    for (unsigned d = 0; d < rank; d++) {
+      unsigned i = map.getResult(d).cast().getPosition();
+      auto annotation = dimAttr[d].dyn_cast_or_null();
+      assert(annotation &&
+             (annotation.getValue() == "S" || annotation.getValue() == "D") &&
+             "expected an \"S\" or \"D\" annotation");
+      if (annotation && annotation.getValue() == "S")
+        isSparse[t][i] = true;
+    }
+  }
+}
+
+/// A DFS helper to compute a topological sort. Note that recursion is
+/// bounded by the number of implicit loops, which is always small.
+/// Returns false when a cycle is detected.
+static bool topSortDFS(unsigned i, std::vector &visit,
+                       std::vector &topSort,
+                       std::vector> &adjM) {
+  if (visit[i] != 0)
+    return visit[i] != 1; // 1 denotes cycle!
+  visit[i] = 1;
+  for (unsigned j = 0, e = visit.size(); j < e; j++)
+    if (adjM[i][j])
+      if (!topSortDFS(j, visit, topSort, adjM))
+        return false;
+  visit[i] = 2;
+  topSort.push_back(i);
+  return true;
+}
+
+/// Computes a topologically sorted iteration graph for the linalg operation.
+/// Ensures all tensors are visited in natural index order. This is essential
+/// for sparse storage formats since these only support access along fixed
+/// dimensions. Even for dense storage formats, however, the natural index
+/// order yields innermost unit-stride access with better spatial locality.
+static bool computeIterationGraph(linalg::GenericOp op,
+                                  std::vector &topSort) {
+  // Set up an n x n from/to adjacency matrix of the iteration graph
+  // for the implicit loop indices i_0 .. i_n-1.
+  unsigned n = op.getNumLoops();
+  std::vector> adjM(n, std::vector(n, false));
+
+  // Iterate over the indexing maps of every tensor in the tensor expression.
+  for (auto imap : llvm::enumerate(op.indexing_maps())) {
+    auto map = imap.value().template cast().getValue();
+    assert(map.getNumDims() == n);
+    // At the moment, we take the index variables in the tensor access
+    // expression in the order in which they appear (conceptually a
+    // "row-major" layout of every tensor). So, a tensor access A_ijk
+    // forces the ordering i < j < k on the loop indices.
+    // TODO: support affine map to define alternative dimension orders.
+    for (unsigned d = 1, e = map.getNumResults(); d < e; d++) {
+      unsigned f = map.getResult(d - 1).cast().getPosition();
+      unsigned t = map.getResult(d).cast().getPosition();
+      adjM[f][t] = true;
+    }
+  }
+
+  // Topologically sort the iteration graph to determine loop order.
+  // Report failure for a cyclic iteration graph.
+  topSort.reserve(n);
+  std::vector visit(n, 0);
+  for (unsigned i = 0; i < n; i++)
+    if (visit[i] == 0)
+      if (!topSortDFS(i, visit, topSort, adjM))
+        return false; // cycle!
+  std::reverse(std::begin(topSort), std::end(topSort));
+  return true;
+}
+
+/// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
+/// This simplifies constructing (sub)expressions during iteration lattice
+/// building (compared to using the SSA representation everywhere).
+static Optional buildTensorExp(Merger &merger, linalg::GenericOp op,
+                                         Value val) {
+  Operation *def = val.getDefiningOp();
+  if (auto arg = val.dyn_cast()) {
+    unsigned argN = arg.getArgNumber();
+    if (arg.getOwner()->getParentOp() == op) {
+      // Any parameter of the generic op is considered a tensor,
+      // indexed by the implicit loop bounds.
+      auto map = (op.indexing_maps()[argN]).cast().getValue();
+      if (map.isProjectedPermutation())
+        return merger.addExp(Kind::kTensor, argN);
+    } else {
+      // Any parameter of a higher op is invariant in the tensor expression.
+      return merger.addExp(Kind::kInvariant, argN);
+    }
+  } else if (def->getNumOperands() == 2) {
+    // Construct binary operations if subexpressions could be built.
+    auto x = buildTensorExp(merger, op, def->getOperand(0));
+    auto y = buildTensorExp(merger, op, def->getOperand(1));
+    if (x.hasValue() && y.hasValue()) {
+      unsigned e0 = x.getValue();
+      unsigned e1 = y.getValue();
+      if (isa(def))
+        return merger.addExp(Kind::kMulF, e0, e1);
+      if (isa(def))
+        return merger.addExp(Kind::kMulI, e0, e1);
+      if (isa(def))
+        return merger.addExp(Kind::kAddF, e0, e1);
+      if (isa(def))
+        return merger.addExp(Kind::kAddI, e0, e1);
+    }
+  }
+  // Cannot build (yet).
+  return None;
+}
+
+/// Builds the iteration lattices in a bottom-up traversal given the remaining
+/// tensor (sub)expression and the next loop index in the iteration graph.
+static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
+                              unsigned exp, unsigned idx) {
+  Kind kind = merger.exp(exp).kind;
+  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
+    unsigned s = merger.addSet();
+    // Determine if current loop index idx is used in the tensor expression.
+    if (kind == Kind::kTensor) {
+      unsigned tensor = merger.exp(exp).e0;
+      auto map = (op.indexing_maps()[tensor]).cast().getValue();
+      for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
+        if (map.getResult(i).isFunctionOfDim(idx)) {
+          merger.set(s).push_back(merger.addLat(tensor, idx, exp));
+          return s;
+        }
+      }
+    }
+    // Otherwise, set to universal dense index (using output tensor).
+    unsigned lhs = op.getNumInputsAndOutputs() - 1;
+    merger.set(s).push_back(merger.addLat(lhs, idx, exp));
+    return s;
+  }
+  unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx);
+  unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx);
+  switch (kind) {
+  case Kind::kTensor:
+  case Kind::kInvariant:
+    llvm_unreachable("handled above");
+  case Kind::kMulF:
+  case Kind::kMulI:
+    return merger.takeConj(kind, s0, s1);
+  case Kind::kAddF:
+  case Kind::kAddI:
+    return merger.takeDisj(kind, s0, s1);
+  }
+  llvm_unreachable("unexpected expression kind");
+}
+
+/// Local bufferization of all dense and sparse data structures.
+/// This code enables testing the first prototype sparse compiler. It allocates
+/// codegen.buffers for all tensors, codegen.pointers/indices for each sparse
+/// dimension, and records every dimension size in codegen.sizes.
+// TODO: replace this with a proper bufferization strategy
+static void genBuffers(Merger &merger, CodeGen &codegen,
+                       PatternRewriter &rewriter, linalg::GenericOp op) {
+  Location loc = op.getLoc();
+  unsigned numTensors = op.getNumInputsAndOutputs();
+  unsigned numInputs = op.getNumInputs();
+  assert(numTensors == numInputs + 1);
+  Type indexType = rewriter.getIndexType();
+
+  // For now, set all unknown dimensions to 999.
+  // TODO: compute these values (using sparsity or by reading tensor)
+  Value unknown = rewriter.create(loc, 999);
+
+  // For every tensor, find lower and upper bound on dimensions, set the
+  // same bounds on loop indices, and allocate dense or sparse buffer(s).
+  SmallVector args;
+  for (unsigned t = 0; t < numTensors; t++) {
+    auto tensorType = op.getShapedType(t);
+    auto shape = tensorType.getShape();
+    auto map = (op.indexing_maps()[t]).cast().getValue();
+    // Scan all dimensions of current tensor.
+    bool allDense = true;
+    args.clear();
+    for (unsigned d = 0, rank = shape.size(); d < rank; d++) {
+      unsigned i = map.getResult(d).cast().getPosition();
+      // Handle sparse storage schemes.
+      if (merger.isSparseAccess(t, i)) {
+        allDense = false;
+        auto dynTp = MemRefType::get({ShapedType::kDynamicSize}, indexType);
+        codegen.pointers[t][i] = rewriter.create(loc, dynTp, unknown);
+        codegen.indices[t][i] = rewriter.create(loc, dynTp, unknown);
+      }
+      // Find upper bound in current dimension (loops start at zero).
+      Value up;
+      if (shape[d] == TensorType::kDynamicSize) {
+        // For the output tensor, we may need to infer the upper bound.
+        // For all others, we look at the incoming argument.
+        if (t == numInputs && !op.getNumInitTensors()) {
+          for (unsigned t2 = 0; t2 < t; t2++)
+            if (codegen.sizes[t2][i]) {
+              up = codegen.sizes[t2][i];
+              break;
+            }
+          assert(up && "cannot infer the upper bound of an output dimension");
+        } else {
+          Value arg = t < numInputs ? op.getInput(t) : op.getInitTensor(0);
+          up = rewriter.create(loc, arg, d);
+        }
+        args.push_back(up);
+      } else {
+        up = rewriter.create(loc, shape[d]);
+      }
+      codegen.sizes[t][i] = up;
+    }
+    // Allocate dense or sparse buffer for numerical values.
+    if (allDense) {
+      auto denseTp = MemRefType::get(shape, tensorType.getElementType());
+      codegen.buffers[t] = rewriter.create(loc, denseTp, args);
+    } else {
+      auto sparseTp = MemRefType::get({ShapedType::kDynamicSize},
+                                      tensorType.getElementType());
+      codegen.buffers[t] = rewriter.create(loc, sparseTp, unknown);
+    }
+  }
+}
+
+/// Generates a load on a dense tensor or on the 1-D values of a sparse tensor.
+static Value genTensorLoad(Merger &merger, CodeGen &codegen,
+                           PatternRewriter &rewriter, linalg::GenericOp op,
+                           unsigned tensor) {
+  SmallVector args;
+  auto map = (op.indexing_maps()[tensor]).cast().getValue();
+  bool sparse = false;
+  for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
+    unsigned idx = map.getResult(i).cast().getPosition();
+    args.push_back(codegen.loops[idx]); // universal dense index
+    if (sparse || merger.isSparseAccess(tensor, idx)) {
+      sparse = true;
+      args.clear();
+      args.push_back(codegen.pidxs[tensor][idx]); // position index
+    }
+  }
+  return rewriter.create(op.getLoc(), codegen.buffers[tensor], args);
+}
+
+/// Generates a store on a dense tensor.
+static void genTensorStore(Merger &merger, CodeGen &codegen,
+                           PatternRewriter &rewriter, linalg::GenericOp op,
+                           unsigned tensor, Value rhs) {
+  SmallVector args;
+  auto map = (op.indexing_maps()[tensor]).cast().getValue();
+  for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
+    unsigned idx = map.getResult(i).cast().getPosition();
+    args.push_back(codegen.loops[idx]); // universal dense index
+  }
+  rewriter.create(op.getLoc(), rhs, codegen.buffers[tensor], args);
+}
+
+/// Recursively generates tensor expression.
+static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
+                    linalg::GenericOp op, unsigned exp) {
+  Kind kind = merger.exp(exp).kind;
+  if (kind == Kind::kTensor)
+    return genTensorLoad(merger, codegen, rewriter, op, merger.exp(exp).e0);
+  // NOTE(review): only the argument number of an invariant is recorded, so this
+  // assumes it is an argument of the entry block of the region enclosing op.
+  if (kind == Kind::kInvariant)
+    return op.getParentRegion()->front().getArgument(merger.exp(exp).e0);
+  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
+  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
+  switch (kind) {
+  case Kind::kTensor:
+  case Kind::kInvariant:
+    llvm_unreachable("handled above");
+  case Kind::kMulF:
+    return rewriter.create(op.getLoc(), v0, v1);
+  case Kind::kMulI:
+    return rewriter.create(op.getLoc(), v0, v1);
+  case Kind::kAddF:
+    return rewriter.create(op.getLoc(), v0, v1);
+  case Kind::kAddI:
+    return rewriter.create(op.getLoc(), v0, v1);
+  }
+  llvm_unreachable("unexpected expression kind");
+}
+
+/// Recursively generates code while computing iteration lattices in order
+/// to manage the complexity of implementing co-iteration over unions
+/// and intersections of sparse iteration spaces.
+static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
+                    linalg::GenericOp op, std::vector &topSort,
+                    unsigned exp, unsigned at) {
+  // At each leaf, assign remaining tensor (sub)expression to output tensor.
+  if (at == topSort.size()) {
+    unsigned lhs = op.getNumInputsAndOutputs() - 1;
+    Value rhs = genExp(merger, codegen, rewriter, op, exp);
+    genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
+    return;
+  }
+
+  // Construct iteration lattices for current loop index, with L0 at top.
+  // Sets are accessed by position below: the recursive genStmt() calls append
+  // new sets to the merger, which may reallocate (and invalidate) its storage.
+  unsigned idx = topSort[at];
+  unsigned lts = merger.optimize(buildLattices(merger, op, exp, idx));
+  unsigned lsize = merger.set(lts).size();
+  assert(lsize != 0);
+  unsigned l0 = merger.set(lts)[0];
+  LatPoint lat0 = merger.lat(l0);
+
+  // Initialize sparse positions.
+  Location loc = op.getLoc();
+  Type indexType = rewriter.getIndexType();
+  Value zero = rewriter.create(loc, 0);
+  Value one = rewriter.create(loc, 1);
+  bool needsUniv = false;
+  for (unsigned b = 0, be = lat0.bits.size(); b < be; b++)
+    if (lat0.bits[b]) {
+      unsigned tensor = merger.tensor(b);
+      assert(idx == merger.index(b));
+      if (merger.isSparseBit(b)) {
+        unsigned pat = at;
+        for (; pat != 0; pat--)
+          if (codegen.pidxs[tensor][topSort[pat - 1]])
+            break;
+        Value ptr = codegen.pointers[tensor][idx];
+        Value p = (pat == 0) ? zero : codegen.pidxs[tensor][topSort[pat - 1]];
+        codegen.pidxs[tensor][idx] = rewriter.create(loc, ptr, p);
+        p = rewriter.create(loc, p, one);
+        codegen.highs[tensor][idx] = rewriter.create(loc, ptr, p);
+      } else if (lsize > 1) {
+        needsUniv = true;
+      }
+    }
+
+  // Initialize the universal dense index.
+  codegen.loops[idx] = zero;
+
+  // Emit a for-loop or a while-loop for every lattice point L0 >= Li.
+  llvm::BitVector tmp;
+  for (unsigned i = 0; i < lsize; i++) {
+    unsigned li = merger.set(lts)[i];
+    // Optimize the loop indices of Li with some simple rules:
+    //   + convert multiple dense to single dense,
+    //   + convert singleton sparse/dense to sparse/random access.
+    LatPoint lati = merger.lat(li); // copy: recursion may grow lattice storage
+    tmp = lati.bits;
+    if (merger.hasAnyOf(tmp, false)) {
+      bool reset = lsize == 1 && merger.hasAnyOf(tmp, true);
+      bool first = true;
+      for (unsigned b = 0, be = tmp.size(); b < be; b++)
+        if (tmp[b] && !merger.isSparseBit(b)) {
+          if (reset || !first)
+            tmp.reset(b);
+          first = false;
+        }
+    }
+
+    // Then emit either a for-loop or a while-loop for Li.
+    SmallVector types;
+    SmallVector operands;
+    scf::ForOp forOp;
+    scf::WhileOp whileOp;
+    if (tmp.count() == 1) {
+      // Emit a for-loop.
+      unsigned fb = tmp.find_first();
+      unsigned tensor = merger.tensor(fb);
+      assert(idx == merger.index(fb));
+      if (merger.isSparseBit(fb)) {
+        forOp = rewriter.create(loc, codegen.pidxs[tensor][idx],
+                                           codegen.highs[tensor][idx], one);
+        codegen.pidxs[tensor][idx] = forOp.getInductionVar();
+      } else {
+        forOp = rewriter.create(loc, codegen.loops[idx],
+                                           codegen.sizes[tensor][idx], one);
+        codegen.loops[idx] = forOp.getInductionVar();
+      }
+      rewriter.setInsertionPointToStart(forOp.getBody());
+    } else {
+      // Incoming arguments for while-loop.
+      for (unsigned b = 0, be = tmp.size(); b < be; b++)
+        if (tmp[b] && merger.isSparseBit(b)) {
+          unsigned tensor = merger.tensor(b);
+          types.push_back(indexType);
+          operands.push_back(codegen.pidxs[tensor][idx]);
+        }
+      if (needsUniv) {
+        types.push_back(indexType);
+        operands.push_back(codegen.loops[idx]);
+      }
+      // Emit a while-loop.
+      whileOp = rewriter.create(loc, types, operands);
+      Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
+      Block *after = rewriter.createBlock(&whileOp.after(), {}, types);
+
+      // Build the "before" region.
+      rewriter.setInsertionPointToStart(&whileOp.before().front());
+      Value cond;
+      unsigned o = 0;
+      for (unsigned b = 0, be = tmp.size(); b < be; b++)
+        if (tmp[b] && merger.isSparseBit(b)) {
+          unsigned tensor = merger.tensor(b);
+          Value op1 = before->getArgument(o);
+          Value op2 = codegen.highs[tensor][idx];
+          Value opc =
+              rewriter.create(loc, CmpIPredicate::ult, op1, op2);
+          cond = cond ? rewriter.create(loc, cond, opc) : opc;
+          operands[o] = codegen.pidxs[tensor][idx] = after->getArgument(o);
+          o++;
+        }
+      if (needsUniv) {
+        operands[o] = codegen.loops[idx] = after->getArgument(o);
+        o++;
+      }
+      assert(o == operands.size());
+      rewriter.create(loc, cond, before->getArguments());
+      // Then continue with the "after" region.
+      rewriter.setInsertionPointToStart(&whileOp.after().front());
+    }
+
+    // Initialize sparse indices.
+    Value min;
+    for (unsigned b = 0, be = tmp.size(); b < be; b++)
+      if (tmp[b] && merger.isSparseBit(b)) {
+        unsigned tensor = merger.tensor(b);
+        Value ld = rewriter.create(loc, codegen.indices[tensor][idx],
+                                           codegen.pidxs[tensor][idx]);
+        codegen.idxs[tensor][idx] = ld;
+        if (!needsUniv) {
+          if (min) {
+            Value cmp =
+                rewriter.create(loc, CmpIPredicate::ult, ld, min);
+            min = rewriter.create(loc, cmp, ld, min);
+          } else {
+            min = ld;
+          }
+        }
+      }
+
+    // Merge dense universal index over minimum.
+    if (min)
+      codegen.loops[idx] = min;
+
+    // Initialize dense positions.
+    for (unsigned b = 0, be = lati.bits.size(); b < be; b++) {
+      unsigned tensor = merger.tensor(b);
+      if (lati.bits[b] && !merger.isSparseBit(b) &&
+          codegen.sizes[tensor][idx]) {
+        unsigned pat = at;
+        for (; pat != 0; pat--)
+          if (codegen.pidxs[tensor][topSort[pat - 1]])
+            break;
+        Value p = (pat == 0) ? zero : codegen.pidxs[tensor][topSort[pat - 1]];
+        Value m = rewriter.create(loc, codegen.sizes[tensor][idx], p);
+        codegen.pidxs[tensor][idx] =
+            rewriter.create(loc, m, codegen.loops[idx]);
+      }
+    }
+
+    // Visit all lattices points with Li >= Lj.
+    scf::IfOp ifOp;
+    for (unsigned j = 0; j < lsize; j++) {
+      unsigned lj = merger.set(lts)[j];
+      if (li != lj && !merger.latGT(li, lj))
+        continue;
+      LatPoint latj = merger.lat(lj); // copy: recursion may grow storage
+      tmp = latj.bits;
+      tmp ^= lati.bits;
+      if (merger.hasAnyOf(tmp, false))
+        continue; // dense exhausted within if/else
+      // Emit if-statement.
+      if (whileOp) {
+        if (ifOp)
+          rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
+        Value cond;
+        for (unsigned b = 0, be = latj.bits.size(); b < be; b++)
+          if (latj.bits[b]) {
+            unsigned tensor = merger.tensor(b);
+            Value clause;
+            if (merger.isSparseBit(b)) {
+              Value op1 = codegen.idxs[tensor][idx];
+              Value op2 = codegen.loops[idx];
+              clause =
+                  rewriter.create(loc, CmpIPredicate::eq, op1, op2);
+            } else {
+              clause = rewriter.create(loc, 1, 1); // true
+            }
+            cond = cond ? rewriter.create(loc, cond, clause) : clause;
+          }
+        ifOp = rewriter.create(loc, cond, true);
+        rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
+      }
+      // Recurse into body of each if-branch.
+      genStmt(merger, codegen, rewriter, op, topSort, latj.exp, at + 1);
+    }
+
+    // Wrap-up induction and restore insertion point.
+    if (forOp) {
+      rewriter.setInsertionPointAfter(forOp);
+      needsUniv = false;
+    } else {
+      rewriter.setInsertionPointToEnd(&whileOp.after().front());
+      unsigned o = 0;
+      for (unsigned b = 0, be = lati.bits.size(); b < be; b++)
+        if (lati.bits[b] && merger.isSparseBit(b)) {
+          unsigned tensor = merger.tensor(b);
+          Value op1 = codegen.idxs[tensor][idx];
+          Value op2 = codegen.loops[idx];
+          Value cmp = rewriter.create(loc, CmpIPredicate::eq, op1, op2);
+          Value add = rewriter.create(loc, operands[o], one);
+          operands[o] = rewriter.create(loc, cmp, add, operands[o]);
+          codegen.pidxs[tensor][idx] = whileOp.results()[o++];
+        }
+      if (needsUniv) {
+        operands[o] = rewriter.create(loc, operands[o], one);
+        codegen.loops[idx] = whileOp.results()[o++];
+      }
+      assert(o == operands.size());
+      rewriter.create(loc, operands);
+      rewriter.setInsertionPointAfter(whileOp);
+    }
+  }
+}
+
+namespace {
+
+/// Sparse rewriting rule for generic Linalg operation.
+struct GenericOpSparsifier : public OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.getAttr("sparse"))
+      return failure();
+
+    // Detects sparse annotations and translate the per-dimension sparsity
+    // information for all tensors to loop indices in the kernel.
+    unsigned numTensors = op.getNumInputsAndOutputs();
+    unsigned numLoops = op.iterator_types().getValue().size();
+    Merger merger(numTensors, numLoops);
+    findSparseAnnotations(op, merger.sparse());
+
+    // Accept only single, dense result.
+    if (op.getNumOutputs() != 1 ||
+        std::any_of(merger.sparse().back().begin(),
+                    merger.sparse().back().end(), [](bool b) { return b; }))
+      return failure();
+
+    // Computes a topologically sorted iteration graph to ensure
+    // tensors are visited in natural index order. Fails on cycles.
+    // This assumes that higher-level passes have already put the
+    // tensors in each tensor expression in a feasible order.
+    // TODO: try again without *dense* constraints on failure or
+    //       even try to insert sparse reorderings to resolve cycles
+    std::vector topSort;
+    if (!computeIterationGraph(op, topSort))
+      return failure();
+
+    // Finds the terminating yield statement and builds the tensor
+    // expression for the Linalg operation in SSA form.
+    auto &region = op.region();
+    if (!llvm::hasSingleElement(region))
+      return failure(); // single block only
+    Operation *yield = region.front().getTerminator();
+    Optional exp = buildTensorExp(merger, op, yield->getOperand(0));
+    if (!exp.hasValue())
+      return failure(); // build failure
+
+    // Recursively generates code.
+    CodeGen codegen(numTensors, numLoops);
+    genBuffers(merger, codegen, rewriter, op);
+    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
+    Value result =
+        rewriter.create(op.getLoc(), codegen.buffers.back());
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+} // namespace
+
+/// Populates the given patterns list with rewriting rules required for
+/// the sparsification of linear algebra operations.
+void mlir::linalg::populateSparsificationPatterns( + MLIRContext *context, OwningRewritePatternList &patterns) { + patterns.insert(context); +} diff --git a/mlir/test/Dialect/Linalg/sparse_1d.mlir b/mlir/test/Dialect/Linalg/sparse_1d.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/sparse_1d.mlir @@ -0,0 +1,637 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py +// RUN: mlir-opt %s -test-sparsification | FileCheck %s + +#trait_d = { + indexing_maps = [ + affine_map<(i) -> (i)>, // a + affine_map<(i) -> (i)> // x (out) + ], + sparse = [ + [ "D" ], // a + [ "D" ] // x + ], + iterator_types = ["parallel"], + doc = "x(i) = a(i) OP b" +} + +// CHECK-LABEL: func @add_d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: f32) -> tensor<32xf32> { +// CHECK: %[[VAL_2:.*]] = constant 32 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = alloc() : memref<32xf32> +// CHECK: %[[VAL_6:.*]] = alloc() : memref<32xf32> +// CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] { +// CHECK: %[[VAL_8:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_7]]] : memref<32xf32> +// CHECK: %[[VAL_9:.*]] = addf %[[VAL_8]], %[[VAL_1]] : f32 +// CHECK: store %[[VAL_9]], %[[VAL_6]]{{\[}}%[[VAL_7]]] : memref<32xf32> +// CHECK: } +// CHECK: %[[VAL_10:.*]] = tensor_load %[[VAL_6]] : memref<32xf32> +// CHECK: return %[[VAL_10]] : tensor<32xf32> +// CHECK: } +func @add_d(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> { + %0 = linalg.generic #trait_d + ins(%arga: tensor<32xf32>) { + ^bb(%a: f32): + %0 = addf %a, %argb : f32 + linalg.yield %0 : f32 + } -> tensor<32xf32> + return %0 : tensor<32xf32> +} + +// CHECK-LABEL: func @mul_d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: f32) -> tensor<32xf32> { +// CHECK: %[[VAL_2:.*]] = constant 32 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: 
%[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = alloc() : memref<32xf32> +// CHECK: %[[VAL_6:.*]] = alloc() : memref<32xf32> +// CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] { +// CHECK: %[[VAL_8:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_7]]] : memref<32xf32> +// CHECK: %[[VAL_9:.*]] = mulf %[[VAL_8]], %[[VAL_1]] : f32 +// CHECK: store %[[VAL_9]], %[[VAL_6]]{{\[}}%[[VAL_7]]] : memref<32xf32> +// CHECK: } +// CHECK: %[[VAL_10:.*]] = tensor_load %[[VAL_6]] : memref<32xf32> +// CHECK: return %[[VAL_10]] : tensor<32xf32> +// CHECK: } +func @mul_d(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> { + %0 = linalg.generic #trait_d + ins(%arga: tensor<32xf32>) { + ^bb(%a: f32): + %0 = mulf %a, %argb : f32 + linalg.yield %0 : f32 + } -> tensor<32xf32> + return %0 : tensor<32xf32> +} + +#trait_s = { + indexing_maps = [ + affine_map<(i) -> (i)>, // a + affine_map<(i) -> (i)> // x (out) + ], + sparse = [ + [ "S" ], // a + [ "D" ] // x + ], + iterator_types = ["parallel"], + doc = "x(i) = a(i) OP b" +} + +// CHECK-LABEL: func @add_s( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: f32) -> tensor<32xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 0 : index +// CHECK: %[[VAL_5:.*]] = constant 1 : index +// CHECK: %[[VAL_6:.*]] = constant true +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc() : memref<32xf32> +// CHECK: %[[VAL_11:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_12:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref +// CHECK: %[[VAL_13:.*]]:2 = scf.while (%[[VAL_14:.*]] = %[[VAL_11]], %[[VAL_15:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_16:.*]] = cmpi "ult", %[[VAL_14]], %[[VAL_12]] : index +// CHECK: 
scf.condition(%[[VAL_16]]) %[[VAL_14]], %[[VAL_15]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_17:.*]]: index, %[[VAL_18:.*]]: index): +// CHECK: %[[VAL_19:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref +// CHECK: %[[VAL_20:.*]] = cmpi "eq", %[[VAL_19]], %[[VAL_18]] : index +// CHECK: scf.if %[[VAL_20]] { +// CHECK: %[[VAL_21:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_17]]] : memref +// CHECK: %[[VAL_22:.*]] = addf %[[VAL_21]], %[[VAL_1]] : f32 +// CHECK: store %[[VAL_22]], %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref<32xf32> +// CHECK: } else { +// CHECK: scf.if %[[VAL_6]] { +// CHECK: store %[[VAL_1]], %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref<32xf32> +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_23:.*]] = cmpi "eq", %[[VAL_19]], %[[VAL_18]] : index +// CHECK: %[[VAL_24:.*]] = addi %[[VAL_17]], %[[VAL_5]] : index +// CHECK: %[[VAL_25:.*]] = select %[[VAL_23]], %[[VAL_24]], %[[VAL_17]] : index +// CHECK: %[[VAL_26:.*]] = addi %[[VAL_18]], %[[VAL_5]] : index +// CHECK: scf.yield %[[VAL_25]], %[[VAL_26]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_27:.*]] = %[[VAL_28:.*]]#1 to %[[VAL_3]] step %[[VAL_5]] { +// CHECK: store %[[VAL_1]], %[[VAL_10]]{{\[}}%[[VAL_27]]] : memref<32xf32> +// CHECK: } +// CHECK: %[[VAL_29:.*]] = tensor_load %[[VAL_10]] : memref<32xf32> +// CHECK: return %[[VAL_29]] : tensor<32xf32> +// CHECK: } +func @add_s(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> { + %0 = linalg.generic #trait_s + ins(%arga: tensor<32xf32>) { + ^bb(%a: f32): + %0 = addf %a, %argb : f32 + linalg.yield %0 : f32 + } -> tensor<32xf32> + return %0 : tensor<32xf32> +} + +// CHECK-LABEL: func @repeated_add_s( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32>) -> tensor<32xf32> { +// CHECK: %[[VAL_1:.*]] = constant 999 : index +// CHECK: %[[VAL_2:.*]] = constant 0 : index +// CHECK: %[[VAL_3:.*]] = constant 1 : index +// CHECK: %[[VAL_4:.*]] = alloc(%[[VAL_1]]) : memref +// CHECK: %[[VAL_5:.*]] = alloc(%[[VAL_1]]) : memref +// CHECK: %[[VAL_6:.*]] 
= alloc(%[[VAL_1]]) : memref +// CHECK: %[[VAL_7:.*]] = alloc() : memref<32xf32> +// CHECK: %[[VAL_8:.*]] = load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref +// CHECK: %[[VAL_9:.*]] = load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref +// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] { +// CHECK: %[[VAL_11:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref +// CHECK: %[[VAL_12:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref +// CHECK: %[[VAL_13:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref +// CHECK: %[[VAL_14:.*]] = addf %[[VAL_12]], %[[VAL_13]] : f32 +// CHECK: %[[VAL_15:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref +// CHECK: %[[VAL_16:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref +// CHECK: %[[VAL_17:.*]] = addf %[[VAL_15]], %[[VAL_16]] : f32 +// CHECK: %[[VAL_18:.*]] = addf %[[VAL_14]], %[[VAL_17]] : f32 +// CHECK: store %[[VAL_18]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<32xf32> +// CHECK: } +// CHECK: %[[VAL_19:.*]] = tensor_load %[[VAL_7]] : memref<32xf32> +// CHECK: return %[[VAL_19]] : tensor<32xf32> +// CHECK: } +func @repeated_add_s(%arga: tensor<32xf32>) -> tensor<32xf32> { + %0 = linalg.generic #trait_s + ins(%arga: tensor<32xf32>) { + ^bb(%a: f32): + %0 = addf %a, %a : f32 // same tensor + %1 = addf %a, %a : f32 // should yield + %2 = addf %0, %1 : f32 // one guard + linalg.yield %2 : f32 + } -> tensor<32xf32> + return %0 : tensor<32xf32> +} + +// CHECK-LABEL: func @mul_s( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: f32) -> tensor<32xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_6:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc() : memref<32xf32> +// CHECK: %[[VAL_9:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_10:.*]] = load 
%[[VAL_5]]{{\[}}%[[VAL_4]]] : memref +// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] { +// CHECK: %[[VAL_12:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref +// CHECK: %[[VAL_13:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref +// CHECK: %[[VAL_14:.*]] = mulf %[[VAL_13]], %[[VAL_1]] : f32 +// CHECK: store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xf32> +// CHECK: } +// CHECK: %[[VAL_15:.*]] = tensor_load %[[VAL_8]] : memref<32xf32> +// CHECK: return %[[VAL_15]] : tensor<32xf32> +// CHECK: } +func @mul_s(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> { + %0 = linalg.generic #trait_s + ins(%arga: tensor<32xf32>) { + ^bb(%a: f32): + %0 = mulf %a, %argb : f32 + linalg.yield %0 : f32 + } -> tensor<32xf32> + return %0 : tensor<32xf32> +} + +#trait_dd = { + indexing_maps = [ + affine_map<(i) -> (i)>, // a + affine_map<(i) -> (i)>, // b + affine_map<(i) -> (i)> // x (out) + ], + sparse = [ + [ "D" ], // a + [ "D" ], // b + [ "D" ] // x + ], + iterator_types = ["parallel"], + doc = "x(i) = a(i) OP b(i)" +} + +// CHECK-LABEL: func @add_dd( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf32>) -> tensor<32xf32> { +// CHECK: %[[VAL_2:.*]] = constant 32 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = alloc() : memref<32xf32> +// CHECK: %[[VAL_6:.*]] = alloc() : memref<32xf32> +// CHECK: %[[VAL_7:.*]] = alloc() : memref<32xf32> +// CHECK: scf.for %[[VAL_8:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] { +// CHECK: %[[VAL_9:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_8]]] : memref<32xf32> +// CHECK: %[[VAL_10:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_8]]] : memref<32xf32> +// CHECK: %[[VAL_11:.*]] = addf %[[VAL_9]], %[[VAL_10]] : f32 +// CHECK: store %[[VAL_11]], %[[VAL_7]]{{\[}}%[[VAL_8]]] : memref<32xf32> +// CHECK: } +// CHECK: %[[VAL_12:.*]] = tensor_load %[[VAL_7]] : memref<32xf32> +// CHECK: return %[[VAL_12]] : 
tensor<32xf32> +// CHECK: } +func @add_dd(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> { + %0 = linalg.generic #trait_dd + ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>) { + ^bb(%a: f32, %b: f32): + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32xf32> + return %0 : tensor<32xf32> +} + +// CHECK-LABEL: func @mul_dd( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf32>) -> tensor<32xf32> { +// CHECK: %[[VAL_2:.*]] = constant 32 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = alloc() : memref<32xf32> +// CHECK: %[[VAL_6:.*]] = alloc() : memref<32xf32> +// CHECK: %[[VAL_7:.*]] = alloc() : memref<32xf32> +// CHECK: scf.for %[[VAL_8:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] { +// CHECK: %[[VAL_9:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_8]]] : memref<32xf32> +// CHECK: %[[VAL_10:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_8]]] : memref<32xf32> +// CHECK: %[[VAL_11:.*]] = mulf %[[VAL_9]], %[[VAL_10]] : f32 +// CHECK: store %[[VAL_11]], %[[VAL_7]]{{\[}}%[[VAL_8]]] : memref<32xf32> +// CHECK: } +// CHECK: %[[VAL_12:.*]] = tensor_load %[[VAL_7]] : memref<32xf32> +// CHECK: return %[[VAL_12]] : tensor<32xf32> +// CHECK: } +func @mul_dd(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> { + %0 = linalg.generic #trait_dd + ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>) { + ^bb(%a: f32, %b: f32): + %0 = mulf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32xf32> + return %0 : tensor<32xf32> +} + +#trait_ds = { + indexing_maps = [ + affine_map<(i) -> (i)>, // a + affine_map<(i) -> (i)>, // b + affine_map<(i) -> (i)> // x (out) + ], + sparse = [ + [ "D" ], // a + [ "S" ], // b + [ "D" ] // x + ], + iterator_types = ["parallel"], + doc = "x(i) = a(i) OP b(i)" +} + +// CHECK-LABEL: func @add_ds( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf32>) -> tensor<32xf32> { +// CHECK: 
%[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 0 : index +// CHECK: %[[VAL_5:.*]] = constant 1 : index +// CHECK: %[[VAL_6:.*]] = constant true +// CHECK: %[[VAL_7:.*]] = alloc() : memref<32xf32> +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc() : memref<32xf32> +// CHECK: %[[VAL_12:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_13:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_5]]] : memref +// CHECK: %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_17:.*]] = cmpi "ult", %[[VAL_15]], %[[VAL_13]] : index +// CHECK: scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index): +// CHECK: %[[VAL_20:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref +// CHECK: %[[VAL_21:.*]] = cmpi "eq", %[[VAL_20]], %[[VAL_19]] : index +// CHECK: scf.if %[[VAL_21]] { +// CHECK: %[[VAL_22:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_19]]] : memref<32xf32> +// CHECK: %[[VAL_23:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref +// CHECK: %[[VAL_24:.*]] = addf %[[VAL_22]], %[[VAL_23]] : f32 +// CHECK: store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xf32> +// CHECK: } else { +// CHECK: scf.if %[[VAL_6]] { +// CHECK: %[[VAL_25:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_19]]] : memref<32xf32> +// CHECK: store %[[VAL_25]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xf32> +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_26:.*]] = cmpi "eq", %[[VAL_20]], %[[VAL_19]] : index +// CHECK: %[[VAL_27:.*]] = addi %[[VAL_18]], %[[VAL_5]] : index +// CHECK: %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index +// CHECK: %[[VAL_29:.*]] = addi %[[VAL_19]], %[[VAL_5]] : index 
+// CHECK: scf.yield %[[VAL_28]], %[[VAL_29]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_30:.*]] = %[[VAL_31:.*]]#1 to %[[VAL_3]] step %[[VAL_5]] { +// CHECK: %[[VAL_32:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_30]]] : memref<32xf32> +// CHECK: store %[[VAL_32]], %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<32xf32> +// CHECK: } +// CHECK: %[[VAL_33:.*]] = tensor_load %[[VAL_11]] : memref<32xf32> +// CHECK: return %[[VAL_33]] : tensor<32xf32> +// CHECK: } +func @add_ds(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> { + %0 = linalg.generic #trait_ds + ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>) { + ^bb(%a: f32, %b: f32): + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32xf32> + return %0 : tensor<32xf32> +} + +// CHECK-LABEL: func @mul_ds( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf32>) -> tensor<32xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = alloc() : memref<32xf32> +// CHECK: %[[VAL_6:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc() : memref<32xf32> +// CHECK: %[[VAL_10:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_11:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref +// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] { +// CHECK: %[[VAL_13:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref +// CHECK: %[[VAL_14:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_13]]] : memref<32xf32> +// CHECK: %[[VAL_15:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref +// CHECK: %[[VAL_16:.*]] = mulf %[[VAL_14]], %[[VAL_15]] : f32 +// CHECK: store %[[VAL_16]], %[[VAL_9]]{{\[}}%[[VAL_13]]] : memref<32xf32> +// CHECK: } +// CHECK: %[[VAL_17:.*]] = tensor_load %[[VAL_9]] : memref<32xf32> +// CHECK: return %[[VAL_17]] : 
tensor<32xf32> +// CHECK: } +func @mul_ds(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> { + %0 = linalg.generic #trait_ds + ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>) { + ^bb(%a: f32, %b: f32): + %0 = mulf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32xf32> + return %0 : tensor<32xf32> +} + +#trait_sd = { + indexing_maps = [ + affine_map<(i) -> (i)>, // a + affine_map<(i) -> (i)>, // b + affine_map<(i) -> (i)> // x (out) + ], + sparse = [ + [ "S" ], // a + [ "D" ], // b + [ "D" ] // x + ], + iterator_types = ["parallel"], + doc = "x(i) = a(i) OP b(i)" +} + +// CHECK-LABEL: func @add_sd( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf32>) -> tensor<32xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 0 : index +// CHECK: %[[VAL_5:.*]] = constant 1 : index +// CHECK: %[[VAL_6:.*]] = constant true +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc() : memref<32xf32> +// CHECK: %[[VAL_11:.*]] = alloc() : memref<32xf32> +// CHECK: %[[VAL_12:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_13:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref +// CHECK: %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_17:.*]] = cmpi "ult", %[[VAL_15]], %[[VAL_13]] : index +// CHECK: scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index): +// CHECK: %[[VAL_20:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref +// CHECK: %[[VAL_21:.*]] = cmpi "eq", %[[VAL_20]], %[[VAL_19]] : index +// CHECK: scf.if %[[VAL_21]] { +// CHECK: %[[VAL_22:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref +// 
CHECK: %[[VAL_23:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf32> +// CHECK: %[[VAL_24:.*]] = addf %[[VAL_22]], %[[VAL_23]] : f32 +// CHECK: store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xf32> +// CHECK: } else { +// CHECK: scf.if %[[VAL_6]] { +// CHECK: %[[VAL_25:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf32> +// CHECK: store %[[VAL_25]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xf32> +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_26:.*]] = cmpi "eq", %[[VAL_20]], %[[VAL_19]] : index +// CHECK: %[[VAL_27:.*]] = addi %[[VAL_18]], %[[VAL_5]] : index +// CHECK: %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index +// CHECK: %[[VAL_29:.*]] = addi %[[VAL_19]], %[[VAL_5]] : index +// CHECK: scf.yield %[[VAL_28]], %[[VAL_29]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_30:.*]] = %[[VAL_31:.*]]#1 to %[[VAL_3]] step %[[VAL_5]] { +// CHECK: %[[VAL_32:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_30]]] : memref<32xf32> +// CHECK: store %[[VAL_32]], %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<32xf32> +// CHECK: } +// CHECK: %[[VAL_33:.*]] = tensor_load %[[VAL_11]] : memref<32xf32> +// CHECK: return %[[VAL_33]] : tensor<32xf32> +// CHECK: } +func @add_sd(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> { + %0 = linalg.generic #trait_sd + ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>) { + ^bb(%a: f32, %b: f32): + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32xf32> + return %0 : tensor<32xf32> +} + +// CHECK-LABEL: func @mul_sd( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf32>) -> tensor<32xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_6:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc() : 
memref<32xf32> +// CHECK: %[[VAL_9:.*]] = alloc() : memref<32xf32> +// CHECK: %[[VAL_10:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_11:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref +// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] { +// CHECK: %[[VAL_13:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref +// CHECK: %[[VAL_14:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref +// CHECK: %[[VAL_15:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_13]]] : memref<32xf32> +// CHECK: %[[VAL_16:.*]] = mulf %[[VAL_14]], %[[VAL_15]] : f32 +// CHECK: store %[[VAL_16]], %[[VAL_9]]{{\[}}%[[VAL_13]]] : memref<32xf32> +// CHECK: } +// CHECK: %[[VAL_17:.*]] = tensor_load %[[VAL_9]] : memref<32xf32> +// CHECK: return %[[VAL_17]] : tensor<32xf32> +// CHECK: } +func @mul_sd(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> { + %0 = linalg.generic #trait_sd + ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>) { + ^bb(%a: f32, %b: f32): + %0 = mulf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32xf32> + return %0 : tensor<32xf32> +} + +#trait_ss = { + indexing_maps = [ + affine_map<(i) -> (i)>, // a + affine_map<(i) -> (i)>, // b + affine_map<(i) -> (i)> // x (out) + ], + sparse = [ + [ "S" ], // a + [ "S" ], // b + [ "D" ] // x + ], + iterator_types = ["parallel"], + doc = "x(i) = a(i) OP b(i)" +} + +// CHECK-LABEL: func @add_ss( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf32>) -> tensor<32xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_6:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = 
alloc() : memref<32xf32> +// CHECK: %[[VAL_12:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_13:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_14:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_15:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_16:.*]]:2 = scf.while (%[[VAL_17:.*]] = %[[VAL_12]], %[[VAL_18:.*]] = %[[VAL_14]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_19:.*]] = cmpi "ult", %[[VAL_17]], %[[VAL_13]] : index +// CHECK: %[[VAL_20:.*]] = cmpi "ult", %[[VAL_18]], %[[VAL_15]] : index +// CHECK: %[[VAL_21:.*]] = and %[[VAL_19]], %[[VAL_20]] : i1 +// CHECK: scf.condition(%[[VAL_21]]) %[[VAL_17]], %[[VAL_18]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index): +// CHECK: %[[VAL_24:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_22]]] : memref +// CHECK: %[[VAL_25:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref +// CHECK: %[[VAL_26:.*]] = cmpi "ult", %[[VAL_25]], %[[VAL_24]] : index +// CHECK: %[[VAL_27:.*]] = select %[[VAL_26]], %[[VAL_25]], %[[VAL_24]] : index +// CHECK: %[[VAL_28:.*]] = cmpi "eq", %[[VAL_24]], %[[VAL_27]] : index +// CHECK: %[[VAL_29:.*]] = cmpi "eq", %[[VAL_25]], %[[VAL_27]] : index +// CHECK: %[[VAL_30:.*]] = and %[[VAL_28]], %[[VAL_29]] : i1 +// CHECK: scf.if %[[VAL_30]] { +// CHECK: %[[VAL_31:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_22]]] : memref +// CHECK: %[[VAL_32:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_23]]] : memref +// CHECK: %[[VAL_33:.*]] = addf %[[VAL_31]], %[[VAL_32]] : f32 +// CHECK: store %[[VAL_33]], %[[VAL_11]]{{\[}}%[[VAL_27]]] : memref<32xf32> +// CHECK: } else { +// CHECK: %[[VAL_34:.*]] = cmpi "eq", %[[VAL_24]], %[[VAL_27]] : index +// CHECK: scf.if %[[VAL_34]] { +// CHECK: %[[VAL_35:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_22]]] : memref +// CHECK: store %[[VAL_35]], %[[VAL_11]]{{\[}}%[[VAL_27]]] : memref<32xf32> +// CHECK: } else { +// CHECK: %[[VAL_36:.*]] = cmpi "eq", %[[VAL_25]], %[[VAL_27]] : index +// 
CHECK: scf.if %[[VAL_36]] { +// CHECK: %[[VAL_37:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_23]]] : memref +// CHECK: store %[[VAL_37]], %[[VAL_11]]{{\[}}%[[VAL_27]]] : memref<32xf32> +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_38:.*]] = cmpi "eq", %[[VAL_24]], %[[VAL_27]] : index +// CHECK: %[[VAL_39:.*]] = addi %[[VAL_22]], %[[VAL_4]] : index +// CHECK: %[[VAL_40:.*]] = select %[[VAL_38]], %[[VAL_39]], %[[VAL_22]] : index +// CHECK: %[[VAL_41:.*]] = cmpi "eq", %[[VAL_25]], %[[VAL_27]] : index +// CHECK: %[[VAL_42:.*]] = addi %[[VAL_23]], %[[VAL_4]] : index +// CHECK: %[[VAL_43:.*]] = select %[[VAL_41]], %[[VAL_42]], %[[VAL_23]] : index +// CHECK: scf.yield %[[VAL_40]], %[[VAL_43]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_44:.*]] = %[[VAL_45:.*]]#0 to %[[VAL_13]] step %[[VAL_4]] { +// CHECK: %[[VAL_46:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_44]]] : memref +// CHECK: %[[VAL_47:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_44]]] : memref +// CHECK: store %[[VAL_47]], %[[VAL_11]]{{\[}}%[[VAL_46]]] : memref<32xf32> +// CHECK: } +// CHECK: scf.for %[[VAL_48:.*]] = %[[VAL_49:.*]]#1 to %[[VAL_15]] step %[[VAL_4]] { +// CHECK: %[[VAL_50:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_48]]] : memref +// CHECK: %[[VAL_51:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_48]]] : memref +// CHECK: store %[[VAL_51]], %[[VAL_11]]{{\[}}%[[VAL_50]]] : memref<32xf32> +// CHECK: } +// CHECK: %[[VAL_52:.*]] = tensor_load %[[VAL_11]] : memref<32xf32> +// CHECK: return %[[VAL_52]] : tensor<32xf32> +// CHECK: } +func @add_ss(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> { + %0 = linalg.generic #trait_ss + ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>) { + ^bb(%a: f32, %b: f32): + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32xf32> + return %0 : tensor<32xf32> +} + +// CHECK-LABEL: func @mul_ss( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf32>) -> tensor<32xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : 
index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_6:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc() : memref<32xf32> +// CHECK: %[[VAL_12:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_13:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_14:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_15:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_16:.*]]:2 = scf.while (%[[VAL_17:.*]] = %[[VAL_12]], %[[VAL_18:.*]] = %[[VAL_14]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_19:.*]] = cmpi "ult", %[[VAL_17]], %[[VAL_13]] : index +// CHECK: %[[VAL_20:.*]] = cmpi "ult", %[[VAL_18]], %[[VAL_15]] : index +// CHECK: %[[VAL_21:.*]] = and %[[VAL_19]], %[[VAL_20]] : i1 +// CHECK: scf.condition(%[[VAL_21]]) %[[VAL_17]], %[[VAL_18]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index): +// CHECK: %[[VAL_24:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_22]]] : memref +// CHECK: %[[VAL_25:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref +// CHECK: %[[VAL_26:.*]] = cmpi "ult", %[[VAL_25]], %[[VAL_24]] : index +// CHECK: %[[VAL_27:.*]] = select %[[VAL_26]], %[[VAL_25]], %[[VAL_24]] : index +// CHECK: %[[VAL_28:.*]] = cmpi "eq", %[[VAL_24]], %[[VAL_27]] : index +// CHECK: %[[VAL_29:.*]] = cmpi "eq", %[[VAL_25]], %[[VAL_27]] : index +// CHECK: %[[VAL_30:.*]] = and %[[VAL_28]], %[[VAL_29]] : i1 +// CHECK: scf.if %[[VAL_30]] { +// CHECK: %[[VAL_31:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_22]]] : memref +// CHECK: %[[VAL_32:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_23]]] : memref +// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_31]], %[[VAL_32]] : 
f32 +// CHECK: store %[[VAL_33]], %[[VAL_11]]{{\[}}%[[VAL_27]]] : memref<32xf32> +// CHECK: } else { +// CHECK: } +// CHECK: %[[VAL_34:.*]] = cmpi "eq", %[[VAL_24]], %[[VAL_27]] : index +// CHECK: %[[VAL_35:.*]] = addi %[[VAL_22]], %[[VAL_4]] : index +// CHECK: %[[VAL_36:.*]] = select %[[VAL_34]], %[[VAL_35]], %[[VAL_22]] : index +// CHECK: %[[VAL_37:.*]] = cmpi "eq", %[[VAL_25]], %[[VAL_27]] : index +// CHECK: %[[VAL_38:.*]] = addi %[[VAL_23]], %[[VAL_4]] : index +// CHECK: %[[VAL_39:.*]] = select %[[VAL_37]], %[[VAL_38]], %[[VAL_23]] : index +// CHECK: scf.yield %[[VAL_36]], %[[VAL_39]] : index, index +// CHECK: } +// CHECK: %[[VAL_40:.*]] = tensor_load %[[VAL_11]] : memref<32xf32> +// CHECK: return %[[VAL_40]] : tensor<32xf32> +// CHECK: } +func @mul_ss(%arga: tensor<32xf32>, %argb: tensor<32xf32>) -> tensor<32xf32> { + %0 = linalg.generic #trait_ss + ins(%arga, %argb: tensor<32xf32>, tensor<32xf32>) { + ^bb(%a: f32, %b: f32): + %0 = mulf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32xf32> + return %0 : tensor<32xf32> +} diff --git a/mlir/test/Dialect/Linalg/sparse_2d.mlir b/mlir/test/Dialect/Linalg/sparse_2d.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/sparse_2d.mlir @@ -0,0 +1,1058 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py +// RUN: mlir-opt %s -test-sparsification | FileCheck %s + +#trait_dd = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (i,j)>, // B + affine_map<(i,j) -> (i,j)> // X (out) + ], + sparse = [ + [ "D", "D" ], // A + [ "D", "D" ], // B + [ "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel"], + doc = "X(i,j) = A(i,j) OP B(i,j)" +} + +// CHECK-LABEL: func @add_dd( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> { +// CHECK: %[[VAL_2:.*]] = constant 32 : index +// CHECK: %[[VAL_3:.*]] = constant 16 : index +// CHECK: %[[VAL_4:.*]] = constant 0 : index +// CHECK: 
%[[VAL_5:.*]] = constant 1 : index +// CHECK: %[[VAL_6:.*]] = alloc() : memref<32x16xf32> +// CHECK: %[[VAL_7:.*]] = alloc() : memref<32x16xf32> +// CHECK: %[[VAL_8:.*]] = alloc() : memref<32x16xf32> +// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_4]] to %[[VAL_2]] step %[[VAL_5]] { +// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { +// CHECK: %[[VAL_11:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_9]], %[[VAL_10]]] : memref<32x16xf32> +// CHECK: %[[VAL_12:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_9]], %[[VAL_10]]] : memref<32x16xf32> +// CHECK: %[[VAL_13:.*]] = addf %[[VAL_11]], %[[VAL_12]] : f32 +// CHECK: store %[[VAL_13]], %[[VAL_8]]{{\[}}%[[VAL_9]], %[[VAL_10]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_14:.*]] = tensor_load %[[VAL_8]] : memref<32x16xf32> +// CHECK: return %[[VAL_14]] : tensor<32x16xf32> +// CHECK: } +func @add_dd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> { + %0 = linalg.generic #trait_dd + ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) { + ^bb(%a: f32, %b: f32): + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16xf32> + return %0 : tensor<32x16xf32> +} + +// CHECK-LABEL: func @mul_dd( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> { +// CHECK: %[[VAL_2:.*]] = constant 32 : index +// CHECK: %[[VAL_3:.*]] = constant 16 : index +// CHECK: %[[VAL_4:.*]] = constant 0 : index +// CHECK: %[[VAL_5:.*]] = constant 1 : index +// CHECK: %[[VAL_6:.*]] = alloc() : memref<32x16xf32> +// CHECK: %[[VAL_7:.*]] = alloc() : memref<32x16xf32> +// CHECK: %[[VAL_8:.*]] = alloc() : memref<32x16xf32> +// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_4]] to %[[VAL_2]] step %[[VAL_5]] { +// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { +// CHECK: %[[VAL_11:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_9]], %[[VAL_10]]] : memref<32x16xf32> +// CHECK: %[[VAL_12:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_9]], 
%[[VAL_10]]] : memref<32x16xf32> +// CHECK: %[[VAL_13:.*]] = mulf %[[VAL_11]], %[[VAL_12]] : f32 +// CHECK: store %[[VAL_13]], %[[VAL_8]]{{\[}}%[[VAL_9]], %[[VAL_10]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_14:.*]] = tensor_load %[[VAL_8]] : memref<32x16xf32> +// CHECK: return %[[VAL_14]] : tensor<32x16xf32> +// CHECK: } +func @mul_dd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> { + %0 = linalg.generic #trait_dd + ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) { + ^bb(%a: f32, %b: f32): + %0 = mulf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16xf32> + return %0 : tensor<32x16xf32> +} + +#trait_ds = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (i,j)>, // B + affine_map<(i,j) -> (i,j)> // X (out) + ], + sparse = [ + [ "D", "S" ], // A + [ "D", "D" ], // B + [ "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel"], + doc = "X(i,j) = A(i,j) OP B(i,j)" +} + +// CHECK-LABEL: func @add_ds( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 16 : index +// CHECK: %[[VAL_5:.*]] = constant 0 : index +// CHECK: %[[VAL_6:.*]] = constant 1 : index +// CHECK: %[[VAL_7:.*]] = constant true +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc() : memref<32x16xf32> +// CHECK: %[[VAL_12:.*]] = alloc() : memref<32x16xf32> +// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] { +// CHECK: %[[VAL_14:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_13]]] : memref +// CHECK: %[[VAL_15:.*]] = addi %[[VAL_13]], %[[VAL_6]] : index +// CHECK: %[[VAL_16:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref +// CHECK: %[[VAL_17:.*]]:2 
= scf.while (%[[VAL_18:.*]] = %[[VAL_14]], %[[VAL_19:.*]] = %[[VAL_5]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_20:.*]] = cmpi "ult", %[[VAL_18]], %[[VAL_16]] : index +// CHECK: scf.condition(%[[VAL_20]]) %[[VAL_18]], %[[VAL_19]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_21:.*]]: index, %[[VAL_22:.*]]: index): +// CHECK: %[[VAL_23:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref +// CHECK: %[[VAL_24:.*]] = cmpi "eq", %[[VAL_23]], %[[VAL_22]] : index +// CHECK: scf.if %[[VAL_24]] { +// CHECK: %[[VAL_25:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref +// CHECK: %[[VAL_26:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_13]], %[[VAL_22]]] : memref<32x16xf32> +// CHECK: %[[VAL_27:.*]] = addf %[[VAL_25]], %[[VAL_26]] : f32 +// CHECK: store %[[VAL_27]], %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_22]]] : memref<32x16xf32> +// CHECK: } else { +// CHECK: scf.if %[[VAL_7]] { +// CHECK: %[[VAL_28:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_13]], %[[VAL_22]]] : memref<32x16xf32> +// CHECK: store %[[VAL_28]], %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_22]]] : memref<32x16xf32> +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_29:.*]] = cmpi "eq", %[[VAL_23]], %[[VAL_22]] : index +// CHECK: %[[VAL_30:.*]] = addi %[[VAL_21]], %[[VAL_6]] : index +// CHECK: %[[VAL_31:.*]] = select %[[VAL_29]], %[[VAL_30]], %[[VAL_21]] : index +// CHECK: %[[VAL_32:.*]] = addi %[[VAL_22]], %[[VAL_6]] : index +// CHECK: scf.yield %[[VAL_31]], %[[VAL_32]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_33:.*]] = %[[VAL_34:.*]]#1 to %[[VAL_4]] step %[[VAL_6]] { +// CHECK: %[[VAL_35:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_13]], %[[VAL_33]]] : memref<32x16xf32> +// CHECK: store %[[VAL_35]], %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_33]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_36:.*]] = tensor_load %[[VAL_12]] : memref<32x16xf32> +// CHECK: return %[[VAL_36]] : tensor<32x16xf32> +// CHECK: } +func @add_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> 
tensor<32x16xf32> { + %0 = linalg.generic #trait_ds + ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) { + ^bb(%a: f32, %b: f32): + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16xf32> + return %0 : tensor<32x16xf32> +} + +// CHECK-LABEL: func @mul_ds( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 0 : index +// CHECK: %[[VAL_5:.*]] = constant 1 : index +// CHECK: %[[VAL_6:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc() : memref<32x16xf32> +// CHECK: %[[VAL_10:.*]] = alloc() : memref<32x16xf32> +// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { +// CHECK: %[[VAL_12:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref +// CHECK: %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_5]] : index +// CHECK: %[[VAL_14:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref +// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_5]] { +// CHECK: %[[VAL_16:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref +// CHECK: %[[VAL_17:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref +// CHECK: %[[VAL_18:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_11]], %[[VAL_16]]] : memref<32x16xf32> +// CHECK: %[[VAL_19:.*]] = mulf %[[VAL_17]], %[[VAL_18]] : f32 +// CHECK: store %[[VAL_19]], %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_16]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_20:.*]] = tensor_load %[[VAL_10]] : memref<32x16xf32> +// CHECK: return %[[VAL_20]] : tensor<32x16xf32> +// CHECK: } +func @mul_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> { + %0 = linalg.generic #trait_ds + ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) { + ^bb(%a: f32, %b: f32): + %0 = 
mulf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16xf32> + return %0 : tensor<32x16xf32> +} + +#trait_sd = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (i,j)>, // B + affine_map<(i,j) -> (i,j)> // X (out) + ], + sparse = [ + [ "S", "D" ], // A + [ "D", "D" ], // B + [ "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel"], + doc = "X(i,j) = A(i,j) OP B(i,j)" +} + +// CHECK-LABEL: func @add_sd( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 16 : index +// CHECK: %[[VAL_5:.*]] = constant true +// CHECK: %[[VAL_6:.*]] = constant 0 : index +// CHECK: %[[VAL_7:.*]] = constant 1 : index +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc() : memref<32x16xf32> +// CHECK: %[[VAL_12:.*]] = alloc() : memref<32x16xf32> +// CHECK: %[[VAL_13:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_6]]] : memref +// CHECK: %[[VAL_14:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_7]]] : memref +// CHECK: %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_13]], %[[VAL_17:.*]] = %[[VAL_6]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_18:.*]] = cmpi "ult", %[[VAL_16]], %[[VAL_14]] : index +// CHECK: scf.condition(%[[VAL_18]]) %[[VAL_16]], %[[VAL_17]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_19:.*]]: index, %[[VAL_20:.*]]: index): +// CHECK: %[[VAL_21:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref +// CHECK: %[[VAL_22:.*]] = cmpi "eq", %[[VAL_21]], %[[VAL_20]] : index +// CHECK: scf.if %[[VAL_22]] { +// CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] { +// CHECK: %[[VAL_24:.*]] = muli %[[VAL_19]], %[[VAL_4]] : index +// CHECK: %[[VAL_25:.*]] = addi %[[VAL_24]], 
%[[VAL_23]] : index +// CHECK: %[[VAL_26:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref +// CHECK: %[[VAL_27:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_20]], %[[VAL_23]]] : memref<32x16xf32> +// CHECK: %[[VAL_28:.*]] = addf %[[VAL_26]], %[[VAL_27]] : f32 +// CHECK: store %[[VAL_28]], %[[VAL_12]]{{\[}}%[[VAL_20]], %[[VAL_23]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } else { +// CHECK: scf.if %[[VAL_5]] { +// CHECK: scf.for %[[VAL_29:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] { +// CHECK: %[[VAL_30:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_20]], %[[VAL_29]]] : memref<32x16xf32> +// CHECK: store %[[VAL_30]], %[[VAL_12]]{{\[}}%[[VAL_20]], %[[VAL_29]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_31:.*]] = cmpi "eq", %[[VAL_21]], %[[VAL_20]] : index +// CHECK: %[[VAL_32:.*]] = addi %[[VAL_19]], %[[VAL_7]] : index +// CHECK: %[[VAL_33:.*]] = select %[[VAL_31]], %[[VAL_32]], %[[VAL_19]] : index +// CHECK: %[[VAL_34:.*]] = addi %[[VAL_20]], %[[VAL_7]] : index +// CHECK: scf.yield %[[VAL_33]], %[[VAL_34]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_35:.*]] = %[[VAL_36:.*]]#1 to %[[VAL_3]] step %[[VAL_7]] { +// CHECK: scf.for %[[VAL_37:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] { +// CHECK: %[[VAL_38:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_35]], %[[VAL_37]]] : memref<32x16xf32> +// CHECK: store %[[VAL_38]], %[[VAL_12]]{{\[}}%[[VAL_35]], %[[VAL_37]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_39:.*]] = tensor_load %[[VAL_12]] : memref<32x16xf32> +// CHECK: return %[[VAL_39]] : tensor<32x16xf32> +// CHECK: } +func @add_sd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> { + %0 = linalg.generic #trait_sd + ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) { + ^bb(%a: f32, %b: f32): + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16xf32> + return %0 : tensor<32x16xf32> +} + +// CHECK-LABEL: func @mul_sd( +// CHECK-SAME: %[[VAL_0:.*]]: 
tensor<32x16xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 16 : index +// CHECK: %[[VAL_4:.*]] = constant 0 : index +// CHECK: %[[VAL_5:.*]] = constant 1 : index +// CHECK: %[[VAL_6:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc() : memref<32x16xf32> +// CHECK: %[[VAL_10:.*]] = alloc() : memref<32x16xf32> +// CHECK: %[[VAL_11:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_12:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref +// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_11]] to %[[VAL_12]] step %[[VAL_5]] { +// CHECK: %[[VAL_14:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref +// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { +// CHECK: %[[VAL_16:.*]] = muli %[[VAL_13]], %[[VAL_3]] : index +// CHECK: %[[VAL_17:.*]] = addi %[[VAL_16]], %[[VAL_15]] : index +// CHECK: %[[VAL_18:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref +// CHECK: %[[VAL_19:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_14]], %[[VAL_15]]] : memref<32x16xf32> +// CHECK: %[[VAL_20:.*]] = mulf %[[VAL_18]], %[[VAL_19]] : f32 +// CHECK: store %[[VAL_20]], %[[VAL_10]]{{\[}}%[[VAL_14]], %[[VAL_15]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_21:.*]] = tensor_load %[[VAL_10]] : memref<32x16xf32> +// CHECK: return %[[VAL_21]] : tensor<32x16xf32> +// CHECK: } +func @mul_sd(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> { + %0 = linalg.generic #trait_sd + ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) { + ^bb(%a: f32, %b: f32): + %0 = mulf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16xf32> + return %0 : tensor<32x16xf32> +} + +#trait_ss = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (i,j)>, // B + affine_map<(i,j) -> (i,j)> // X (out) 
+ ], + sparse = [ + [ "S", "S" ], // A + [ "D", "D" ], // B + [ "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel"], + doc = "X(i,j) = A(i,j) OP B(i,j)" +} + +// CHECK-LABEL: func @add_ss( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 16 : index +// CHECK: %[[VAL_5:.*]] = constant true +// CHECK: %[[VAL_6:.*]] = constant 0 : index +// CHECK: %[[VAL_7:.*]] = constant 1 : index +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_12:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_13:.*]] = alloc() : memref<32x16xf32> +// CHECK: %[[VAL_14:.*]] = alloc() : memref<32x16xf32> +// CHECK: %[[VAL_15:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_6]]] : memref +// CHECK: %[[VAL_16:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_7]]] : memref +// CHECK: %[[VAL_17:.*]]:2 = scf.while (%[[VAL_18:.*]] = %[[VAL_15]], %[[VAL_19:.*]] = %[[VAL_6]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_20:.*]] = cmpi "ult", %[[VAL_18]], %[[VAL_16]] : index +// CHECK: scf.condition(%[[VAL_20]]) %[[VAL_18]], %[[VAL_19]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_21:.*]]: index, %[[VAL_22:.*]]: index): +// CHECK: %[[VAL_23:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref +// CHECK: %[[VAL_24:.*]] = cmpi "eq", %[[VAL_23]], %[[VAL_22]] : index +// CHECK: scf.if %[[VAL_24]] { +// CHECK: %[[VAL_25:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref +// CHECK: %[[VAL_26:.*]] = addi %[[VAL_21]], %[[VAL_7]] : index +// CHECK: %[[VAL_27:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_26]]] : memref +// CHECK: %[[VAL_28:.*]]:2 = scf.while (%[[VAL_29:.*]] = %[[VAL_25]], %[[VAL_30:.*]] = %[[VAL_6]]) : (index, index) -> 
(index, index) { +// CHECK: %[[VAL_31:.*]] = cmpi "ult", %[[VAL_29]], %[[VAL_27]] : index +// CHECK: scf.condition(%[[VAL_31]]) %[[VAL_29]], %[[VAL_30]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_32:.*]]: index, %[[VAL_33:.*]]: index): +// CHECK: %[[VAL_34:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_32]]] : memref +// CHECK: %[[VAL_35:.*]] = cmpi "eq", %[[VAL_34]], %[[VAL_33]] : index +// CHECK: scf.if %[[VAL_35]] { +// CHECK: %[[VAL_36:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref +// CHECK: %[[VAL_37:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_22]], %[[VAL_33]]] : memref<32x16xf32> +// CHECK: %[[VAL_38:.*]] = addf %[[VAL_36]], %[[VAL_37]] : f32 +// CHECK: store %[[VAL_38]], %[[VAL_14]]{{\[}}%[[VAL_22]], %[[VAL_33]]] : memref<32x16xf32> +// CHECK: } else { +// CHECK: scf.if %[[VAL_5]] { +// CHECK: %[[VAL_39:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_22]], %[[VAL_33]]] : memref<32x16xf32> +// CHECK: store %[[VAL_39]], %[[VAL_14]]{{\[}}%[[VAL_22]], %[[VAL_33]]] : memref<32x16xf32> +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_40:.*]] = cmpi "eq", %[[VAL_34]], %[[VAL_33]] : index +// CHECK: %[[VAL_41:.*]] = addi %[[VAL_32]], %[[VAL_7]] : index +// CHECK: %[[VAL_42:.*]] = select %[[VAL_40]], %[[VAL_41]], %[[VAL_32]] : index +// CHECK: %[[VAL_43:.*]] = addi %[[VAL_33]], %[[VAL_7]] : index +// CHECK: scf.yield %[[VAL_42]], %[[VAL_43]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_44:.*]] = %[[VAL_45:.*]]#1 to %[[VAL_4]] step %[[VAL_7]] { +// CHECK: %[[VAL_46:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_22]], %[[VAL_44]]] : memref<32x16xf32> +// CHECK: store %[[VAL_46]], %[[VAL_14]]{{\[}}%[[VAL_22]], %[[VAL_44]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } else { +// CHECK: scf.if %[[VAL_5]] { +// CHECK: scf.for %[[VAL_47:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] { +// CHECK: %[[VAL_48:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_22]], %[[VAL_47]]] : memref<32x16xf32> +// CHECK: store %[[VAL_48]], %[[VAL_14]]{{\[}}%[[VAL_22]], %[[VAL_47]]] : memref<32x16xf32> 
+// CHECK: } +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_49:.*]] = cmpi "eq", %[[VAL_23]], %[[VAL_22]] : index +// CHECK: %[[VAL_50:.*]] = addi %[[VAL_21]], %[[VAL_7]] : index +// CHECK: %[[VAL_51:.*]] = select %[[VAL_49]], %[[VAL_50]], %[[VAL_21]] : index +// CHECK: %[[VAL_52:.*]] = addi %[[VAL_22]], %[[VAL_7]] : index +// CHECK: scf.yield %[[VAL_51]], %[[VAL_52]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_53:.*]] = %[[VAL_54:.*]]#1 to %[[VAL_3]] step %[[VAL_7]] { +// CHECK: scf.for %[[VAL_55:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] { +// CHECK: %[[VAL_56:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_53]], %[[VAL_55]]] : memref<32x16xf32> +// CHECK: store %[[VAL_56]], %[[VAL_14]]{{\[}}%[[VAL_53]], %[[VAL_55]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_57:.*]] = tensor_load %[[VAL_14]] : memref<32x16xf32> +// CHECK: return %[[VAL_57]] : tensor<32x16xf32> +// CHECK: } +func @add_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> { + %0 = linalg.generic #trait_ss + ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) { + ^bb(%a: f32, %b: f32): + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16xf32> + return %0 : tensor<32x16xf32> +} + +// CHECK-LABEL: func @mul_ss( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_6:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc() : memref<32x16xf32> +// CHECK: %[[VAL_11:.*]] = alloc() : memref<32x16xf32> +// CHECK: %[[VAL_12:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref +// 
CHECK: %[[VAL_13:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref +// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_12]] to %[[VAL_13]] step %[[VAL_4]] { +// CHECK: %[[VAL_15:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref +// CHECK: %[[VAL_16:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref +// CHECK: %[[VAL_17:.*]] = addi %[[VAL_14]], %[[VAL_4]] : index +// CHECK: %[[VAL_18:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref +// CHECK: scf.for %[[VAL_19:.*]] = %[[VAL_16]] to %[[VAL_18]] step %[[VAL_4]] { +// CHECK: %[[VAL_20:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref +// CHECK: %[[VAL_21:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref +// CHECK: %[[VAL_22:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_15]], %[[VAL_20]]] : memref<32x16xf32> +// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_21]], %[[VAL_22]] : f32 +// CHECK: store %[[VAL_23]], %[[VAL_11]]{{\[}}%[[VAL_15]], %[[VAL_20]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_24:.*]] = tensor_load %[[VAL_11]] : memref<32x16xf32> +// CHECK: return %[[VAL_24]] : tensor<32x16xf32> +// CHECK: } +func @mul_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> { + %0 = linalg.generic #trait_ss + ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) { + ^bb(%a: f32, %b: f32): + %0 = mulf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16xf32> + return %0 : tensor<32x16xf32> +} + +#trait_ss_ss = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (i,j)>, // B + affine_map<(i,j) -> (i,j)> // X (out) + ], + sparse = [ + [ "S", "S" ], // A + [ "S", "S" ], // B + [ "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel"], + doc = "X(i,j) = A(i,j) OP B(i,j)" +} + +// CHECK-LABEL: func @add_ss_ss( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// 
CHECK: %[[VAL_5:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_6:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_12:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_13:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_14:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_15:.*]] = alloc() : memref<32x16xf32> +// CHECK: %[[VAL_16:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_17:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_18:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_19:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_20:.*]]:2 = scf.while (%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_23:.*]] = cmpi "ult", %[[VAL_21]], %[[VAL_17]] : index +// CHECK: %[[VAL_24:.*]] = cmpi "ult", %[[VAL_22]], %[[VAL_19]] : index +// CHECK: %[[VAL_25:.*]] = and %[[VAL_23]], %[[VAL_24]] : i1 +// CHECK: scf.condition(%[[VAL_25]]) %[[VAL_21]], %[[VAL_22]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index): +// CHECK: %[[VAL_28:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref +// CHECK: %[[VAL_29:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_27]]] : memref +// CHECK: %[[VAL_30:.*]] = cmpi "ult", %[[VAL_29]], %[[VAL_28]] : index +// CHECK: %[[VAL_31:.*]] = select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index +// CHECK: %[[VAL_32:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_31]] : index +// CHECK: %[[VAL_33:.*]] = cmpi "eq", %[[VAL_29]], %[[VAL_31]] : index +// CHECK: %[[VAL_34:.*]] = and %[[VAL_32]], %[[VAL_33]] : i1 +// CHECK: scf.if %[[VAL_34]] { +// CHECK: %[[VAL_35:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_26]]] : memref +// 
CHECK: %[[VAL_36:.*]] = addi %[[VAL_26]], %[[VAL_4]] : index +// CHECK: %[[VAL_37:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_36]]] : memref +// CHECK: %[[VAL_38:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref +// CHECK: %[[VAL_39:.*]] = addi %[[VAL_27]], %[[VAL_4]] : index +// CHECK: %[[VAL_40:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_39]]] : memref +// CHECK: %[[VAL_41:.*]]:2 = scf.while (%[[VAL_42:.*]] = %[[VAL_35]], %[[VAL_43:.*]] = %[[VAL_38]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_44:.*]] = cmpi "ult", %[[VAL_42]], %[[VAL_37]] : index +// CHECK: %[[VAL_45:.*]] = cmpi "ult", %[[VAL_43]], %[[VAL_40]] : index +// CHECK: %[[VAL_46:.*]] = and %[[VAL_44]], %[[VAL_45]] : i1 +// CHECK: scf.condition(%[[VAL_46]]) %[[VAL_42]], %[[VAL_43]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_47:.*]]: index, %[[VAL_48:.*]]: index): +// CHECK: %[[VAL_49:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_47]]] : memref +// CHECK: %[[VAL_50:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_48]]] : memref +// CHECK: %[[VAL_51:.*]] = cmpi "ult", %[[VAL_50]], %[[VAL_49]] : index +// CHECK: %[[VAL_52:.*]] = select %[[VAL_51]], %[[VAL_50]], %[[VAL_49]] : index +// CHECK: %[[VAL_53:.*]] = cmpi "eq", %[[VAL_49]], %[[VAL_52]] : index +// CHECK: %[[VAL_54:.*]] = cmpi "eq", %[[VAL_50]], %[[VAL_52]] : index +// CHECK: %[[VAL_55:.*]] = and %[[VAL_53]], %[[VAL_54]] : i1 +// CHECK: scf.if %[[VAL_55]] { +// CHECK: %[[VAL_56:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_47]]] : memref +// CHECK: %[[VAL_57:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_48]]] : memref +// CHECK: %[[VAL_58:.*]] = addf %[[VAL_56]], %[[VAL_57]] : f32 +// CHECK: store %[[VAL_58]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_52]]] : memref<32x16xf32> +// CHECK: } else { +// CHECK: %[[VAL_59:.*]] = cmpi "eq", %[[VAL_49]], %[[VAL_52]] : index +// CHECK: scf.if %[[VAL_59]] { +// CHECK: %[[VAL_60:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_47]]] : memref +// CHECK: store %[[VAL_60]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_52]]] : memref<32x16xf32> +// CHECK: } else { +// CHECK: 
%[[VAL_61:.*]] = cmpi "eq", %[[VAL_50]], %[[VAL_52]] : index +// CHECK: scf.if %[[VAL_61]] { +// CHECK: %[[VAL_62:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_48]]] : memref +// CHECK: store %[[VAL_62]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_52]]] : memref<32x16xf32> +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_63:.*]] = cmpi "eq", %[[VAL_49]], %[[VAL_52]] : index +// CHECK: %[[VAL_64:.*]] = addi %[[VAL_47]], %[[VAL_4]] : index +// CHECK: %[[VAL_65:.*]] = select %[[VAL_63]], %[[VAL_64]], %[[VAL_47]] : index +// CHECK: %[[VAL_66:.*]] = cmpi "eq", %[[VAL_50]], %[[VAL_52]] : index +// CHECK: %[[VAL_67:.*]] = addi %[[VAL_48]], %[[VAL_4]] : index +// CHECK: %[[VAL_68:.*]] = select %[[VAL_66]], %[[VAL_67]], %[[VAL_48]] : index +// CHECK: scf.yield %[[VAL_65]], %[[VAL_68]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_69:.*]] = %[[VAL_70:.*]]#0 to %[[VAL_37]] step %[[VAL_4]] { +// CHECK: %[[VAL_71:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_69]]] : memref +// CHECK: %[[VAL_72:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_69]]] : memref +// CHECK: store %[[VAL_72]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_71]]] : memref<32x16xf32> +// CHECK: } +// CHECK: scf.for %[[VAL_73:.*]] = %[[VAL_74:.*]]#1 to %[[VAL_40]] step %[[VAL_4]] { +// CHECK: %[[VAL_75:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_73]]] : memref +// CHECK: %[[VAL_76:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_73]]] : memref +// CHECK: store %[[VAL_76]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_75]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } else { +// CHECK: %[[VAL_77:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_31]] : index +// CHECK: scf.if %[[VAL_77]] { +// CHECK: %[[VAL_78:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_26]]] : memref +// CHECK: %[[VAL_79:.*]] = addi %[[VAL_26]], %[[VAL_4]] : index +// CHECK: %[[VAL_80:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_79]]] : memref +// CHECK: scf.for %[[VAL_81:.*]] = %[[VAL_78]] to %[[VAL_80]] step %[[VAL_4]] { +// CHECK: %[[VAL_82:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_81]]] : memref +// CHECK: 
%[[VAL_83:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_81]]] : memref +// CHECK: store %[[VAL_83]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_82]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } else { +// CHECK: %[[VAL_84:.*]] = cmpi "eq", %[[VAL_29]], %[[VAL_31]] : index +// CHECK: scf.if %[[VAL_84]] { +// CHECK: %[[VAL_85:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref +// CHECK: %[[VAL_86:.*]] = addi %[[VAL_27]], %[[VAL_4]] : index +// CHECK: %[[VAL_87:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_86]]] : memref +// CHECK: scf.for %[[VAL_88:.*]] = %[[VAL_85]] to %[[VAL_87]] step %[[VAL_4]] { +// CHECK: %[[VAL_89:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_88]]] : memref +// CHECK: %[[VAL_90:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_88]]] : memref +// CHECK: store %[[VAL_90]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_89]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_91:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_31]] : index +// CHECK: %[[VAL_92:.*]] = addi %[[VAL_26]], %[[VAL_4]] : index +// CHECK: %[[VAL_93:.*]] = select %[[VAL_91]], %[[VAL_92]], %[[VAL_26]] : index +// CHECK: %[[VAL_94:.*]] = cmpi "eq", %[[VAL_29]], %[[VAL_31]] : index +// CHECK: %[[VAL_95:.*]] = addi %[[VAL_27]], %[[VAL_4]] : index +// CHECK: %[[VAL_96:.*]] = select %[[VAL_94]], %[[VAL_95]], %[[VAL_27]] : index +// CHECK: scf.yield %[[VAL_93]], %[[VAL_96]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_97:.*]] = %[[VAL_98:.*]]#0 to %[[VAL_17]] step %[[VAL_4]] { +// CHECK: %[[VAL_99:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_97]]] : memref +// CHECK: %[[VAL_100:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_97]]] : memref +// CHECK: %[[VAL_101:.*]] = addi %[[VAL_97]], %[[VAL_4]] : index +// CHECK: %[[VAL_102:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_101]]] : memref +// CHECK: scf.for %[[VAL_103:.*]] = %[[VAL_100]] to %[[VAL_102]] step %[[VAL_4]] { +// CHECK: %[[VAL_104:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_103]]] : memref +// CHECK: %[[VAL_105:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_103]]] : memref 
+// CHECK: store %[[VAL_105]], %[[VAL_15]]{{\[}}%[[VAL_99]], %[[VAL_104]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } +// CHECK: scf.for %[[VAL_106:.*]] = %[[VAL_107:.*]]#1 to %[[VAL_19]] step %[[VAL_4]] { +// CHECK: %[[VAL_108:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_106]]] : memref +// CHECK: %[[VAL_109:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_106]]] : memref +// CHECK: %[[VAL_110:.*]] = addi %[[VAL_106]], %[[VAL_4]] : index +// CHECK: %[[VAL_111:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_110]]] : memref +// CHECK: scf.for %[[VAL_112:.*]] = %[[VAL_109]] to %[[VAL_111]] step %[[VAL_4]] { +// CHECK: %[[VAL_113:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_112]]] : memref +// CHECK: %[[VAL_114:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_112]]] : memref +// CHECK: store %[[VAL_114]], %[[VAL_15]]{{\[}}%[[VAL_108]], %[[VAL_113]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_115:.*]] = tensor_load %[[VAL_15]] : memref<32x16xf32> +// CHECK: return %[[VAL_115]] : tensor<32x16xf32> +// CHECK: } +func @add_ss_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> { + %0 = linalg.generic #trait_ss_ss + ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) { + ^bb(%a: f32, %b: f32): + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16xf32> + return %0 : tensor<32x16xf32> +} + +// CHECK-LABEL: func @mul_ss_ss( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_6:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: 
%[[VAL_12:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_13:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_14:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_15:.*]] = alloc() : memref<32x16xf32> +// CHECK: %[[VAL_16:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_17:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_18:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_19:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_20:.*]]:2 = scf.while (%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_23:.*]] = cmpi "ult", %[[VAL_21]], %[[VAL_17]] : index +// CHECK: %[[VAL_24:.*]] = cmpi "ult", %[[VAL_22]], %[[VAL_19]] : index +// CHECK: %[[VAL_25:.*]] = and %[[VAL_23]], %[[VAL_24]] : i1 +// CHECK: scf.condition(%[[VAL_25]]) %[[VAL_21]], %[[VAL_22]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index): +// CHECK: %[[VAL_28:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref +// CHECK: %[[VAL_29:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_27]]] : memref +// CHECK: %[[VAL_30:.*]] = cmpi "ult", %[[VAL_29]], %[[VAL_28]] : index +// CHECK: %[[VAL_31:.*]] = select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index +// CHECK: %[[VAL_32:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_31]] : index +// CHECK: %[[VAL_33:.*]] = cmpi "eq", %[[VAL_29]], %[[VAL_31]] : index +// CHECK: %[[VAL_34:.*]] = and %[[VAL_32]], %[[VAL_33]] : i1 +// CHECK: scf.if %[[VAL_34]] { +// CHECK: %[[VAL_35:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_26]]] : memref +// CHECK: %[[VAL_36:.*]] = addi %[[VAL_26]], %[[VAL_4]] : index +// CHECK: %[[VAL_37:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_36]]] : memref +// CHECK: %[[VAL_38:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref +// CHECK: %[[VAL_39:.*]] = addi %[[VAL_27]], %[[VAL_4]] : index +// CHECK: %[[VAL_40:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_39]]] : memref +// CHECK: %[[VAL_41:.*]]:2 = scf.while 
(%[[VAL_42:.*]] = %[[VAL_35]], %[[VAL_43:.*]] = %[[VAL_38]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_44:.*]] = cmpi "ult", %[[VAL_42]], %[[VAL_37]] : index +// CHECK: %[[VAL_45:.*]] = cmpi "ult", %[[VAL_43]], %[[VAL_40]] : index +// CHECK: %[[VAL_46:.*]] = and %[[VAL_44]], %[[VAL_45]] : i1 +// CHECK: scf.condition(%[[VAL_46]]) %[[VAL_42]], %[[VAL_43]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_47:.*]]: index, %[[VAL_48:.*]]: index): +// CHECK: %[[VAL_49:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_47]]] : memref +// CHECK: %[[VAL_50:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_48]]] : memref +// CHECK: %[[VAL_51:.*]] = cmpi "ult", %[[VAL_50]], %[[VAL_49]] : index +// CHECK: %[[VAL_52:.*]] = select %[[VAL_51]], %[[VAL_50]], %[[VAL_49]] : index +// CHECK: %[[VAL_53:.*]] = cmpi "eq", %[[VAL_49]], %[[VAL_52]] : index +// CHECK: %[[VAL_54:.*]] = cmpi "eq", %[[VAL_50]], %[[VAL_52]] : index +// CHECK: %[[VAL_55:.*]] = and %[[VAL_53]], %[[VAL_54]] : i1 +// CHECK: scf.if %[[VAL_55]] { +// CHECK: %[[VAL_56:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_47]]] : memref +// CHECK: %[[VAL_57:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_48]]] : memref +// CHECK: %[[VAL_58:.*]] = mulf %[[VAL_56]], %[[VAL_57]] : f32 +// CHECK: store %[[VAL_58]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_52]]] : memref<32x16xf32> +// CHECK: } else { +// CHECK: } +// CHECK: %[[VAL_59:.*]] = cmpi "eq", %[[VAL_49]], %[[VAL_52]] : index +// CHECK: %[[VAL_60:.*]] = addi %[[VAL_47]], %[[VAL_4]] : index +// CHECK: %[[VAL_61:.*]] = select %[[VAL_59]], %[[VAL_60]], %[[VAL_47]] : index +// CHECK: %[[VAL_62:.*]] = cmpi "eq", %[[VAL_50]], %[[VAL_52]] : index +// CHECK: %[[VAL_63:.*]] = addi %[[VAL_48]], %[[VAL_4]] : index +// CHECK: %[[VAL_64:.*]] = select %[[VAL_62]], %[[VAL_63]], %[[VAL_48]] : index +// CHECK: scf.yield %[[VAL_61]], %[[VAL_64]] : index, index +// CHECK: } +// CHECK: } else { +// CHECK: } +// CHECK: %[[VAL_65:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_31]] : index +// CHECK: %[[VAL_66:.*]] = addi %[[VAL_26]], 
%[[VAL_4]] : index +// CHECK: %[[VAL_67:.*]] = select %[[VAL_65]], %[[VAL_66]], %[[VAL_26]] : index +// CHECK: %[[VAL_68:.*]] = cmpi "eq", %[[VAL_29]], %[[VAL_31]] : index +// CHECK: %[[VAL_69:.*]] = addi %[[VAL_27]], %[[VAL_4]] : index +// CHECK: %[[VAL_70:.*]] = select %[[VAL_68]], %[[VAL_69]], %[[VAL_27]] : index +// CHECK: scf.yield %[[VAL_67]], %[[VAL_70]] : index, index +// CHECK: } +// CHECK: %[[VAL_71:.*]] = tensor_load %[[VAL_15]] : memref<32x16xf32> +// CHECK: return %[[VAL_71]] : tensor<32x16xf32> +// CHECK: } +func @mul_ss_ss(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> { + %0 = linalg.generic #trait_ss_ss + ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) { + ^bb(%a: f32, %b: f32): + %0 = mulf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16xf32> + return %0 : tensor<32x16xf32> +} + +#trait_sd_ds = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (i,j)>, // B + affine_map<(i,j) -> (i,j)> // X (out) + ], + sparse = [ + [ "S", "D" ], // A + [ "D", "S" ], // B + [ "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel"], + doc = "X(i,j) = A(i,j) OP B(i,j)" +} + +// CHECK-LABEL: func @add_sd_ds( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_6:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_12:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_13:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_14:.*]] = 
alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_15:.*]] = alloc() : memref<32x16xf32> +// CHECK: %[[VAL_16:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_17:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_18:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_19:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_20:.*]]:2 = scf.while (%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_23:.*]] = cmpi "ult", %[[VAL_21]], %[[VAL_17]] : index +// CHECK: %[[VAL_24:.*]] = cmpi "ult", %[[VAL_22]], %[[VAL_19]] : index +// CHECK: %[[VAL_25:.*]] = and %[[VAL_23]], %[[VAL_24]] : i1 +// CHECK: scf.condition(%[[VAL_25]]) %[[VAL_21]], %[[VAL_22]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index): +// CHECK: %[[VAL_28:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref +// CHECK: %[[VAL_29:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_27]]] : memref +// CHECK: %[[VAL_30:.*]] = cmpi "ult", %[[VAL_29]], %[[VAL_28]] : index +// CHECK: %[[VAL_31:.*]] = select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index +// CHECK: %[[VAL_32:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_31]] : index +// CHECK: %[[VAL_33:.*]] = cmpi "eq", %[[VAL_29]], %[[VAL_31]] : index +// CHECK: %[[VAL_34:.*]] = and %[[VAL_32]], %[[VAL_33]] : i1 +// CHECK: scf.if %[[VAL_34]] { +// CHECK: %[[VAL_35:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_26]]] : memref +// CHECK: %[[VAL_36:.*]] = addi %[[VAL_26]], %[[VAL_4]] : index +// CHECK: %[[VAL_37:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_36]]] : memref +// CHECK: %[[VAL_38:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref +// CHECK: %[[VAL_39:.*]] = addi %[[VAL_27]], %[[VAL_4]] : index +// CHECK: %[[VAL_40:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_39]]] : memref +// CHECK: %[[VAL_41:.*]]:2 = scf.while (%[[VAL_42:.*]] = %[[VAL_35]], %[[VAL_43:.*]] = %[[VAL_38]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_44:.*]] = cmpi "ult", 
%[[VAL_42]], %[[VAL_37]] : index +// CHECK: %[[VAL_45:.*]] = cmpi "ult", %[[VAL_43]], %[[VAL_40]] : index +// CHECK: %[[VAL_46:.*]] = and %[[VAL_44]], %[[VAL_45]] : i1 +// CHECK: scf.condition(%[[VAL_46]]) %[[VAL_42]], %[[VAL_43]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_47:.*]]: index, %[[VAL_48:.*]]: index): +// CHECK: %[[VAL_49:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_47]]] : memref +// CHECK: %[[VAL_50:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_48]]] : memref +// CHECK: %[[VAL_51:.*]] = cmpi "ult", %[[VAL_50]], %[[VAL_49]] : index +// CHECK: %[[VAL_52:.*]] = select %[[VAL_51]], %[[VAL_50]], %[[VAL_49]] : index +// CHECK: %[[VAL_53:.*]] = cmpi "eq", %[[VAL_49]], %[[VAL_52]] : index +// CHECK: %[[VAL_54:.*]] = cmpi "eq", %[[VAL_50]], %[[VAL_52]] : index +// CHECK: %[[VAL_55:.*]] = and %[[VAL_53]], %[[VAL_54]] : i1 +// CHECK: scf.if %[[VAL_55]] { +// CHECK: %[[VAL_56:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_47]]] : memref +// CHECK: %[[VAL_57:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_48]]] : memref +// CHECK: %[[VAL_58:.*]] = addf %[[VAL_56]], %[[VAL_57]] : f32 +// CHECK: store %[[VAL_58]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_52]]] : memref<32x16xf32> +// CHECK: } else { +// CHECK: %[[VAL_59:.*]] = cmpi "eq", %[[VAL_49]], %[[VAL_52]] : index +// CHECK: scf.if %[[VAL_59]] { +// CHECK: %[[VAL_60:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_47]]] : memref +// CHECK: store %[[VAL_60]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_52]]] : memref<32x16xf32> +// CHECK: } else { +// CHECK: %[[VAL_61:.*]] = cmpi "eq", %[[VAL_50]], %[[VAL_52]] : index +// CHECK: scf.if %[[VAL_61]] { +// CHECK: %[[VAL_62:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_48]]] : memref +// CHECK: store %[[VAL_62]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_52]]] : memref<32x16xf32> +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_63:.*]] = cmpi "eq", %[[VAL_49]], %[[VAL_52]] : index +// CHECK: %[[VAL_64:.*]] = addi %[[VAL_47]], %[[VAL_4]] : index +// CHECK: %[[VAL_65:.*]] = select %[[VAL_63]], %[[VAL_64]], 
%[[VAL_47]] : index +// CHECK: %[[VAL_66:.*]] = cmpi "eq", %[[VAL_50]], %[[VAL_52]] : index +// CHECK: %[[VAL_67:.*]] = addi %[[VAL_48]], %[[VAL_4]] : index +// CHECK: %[[VAL_68:.*]] = select %[[VAL_66]], %[[VAL_67]], %[[VAL_48]] : index +// CHECK: scf.yield %[[VAL_65]], %[[VAL_68]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_69:.*]] = %[[VAL_70:.*]]#0 to %[[VAL_37]] step %[[VAL_4]] { +// CHECK: %[[VAL_71:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_69]]] : memref +// CHECK: %[[VAL_72:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_69]]] : memref +// CHECK: store %[[VAL_72]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_71]]] : memref<32x16xf32> +// CHECK: } +// CHECK: scf.for %[[VAL_73:.*]] = %[[VAL_74:.*]]#1 to %[[VAL_40]] step %[[VAL_4]] { +// CHECK: %[[VAL_75:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_73]]] : memref +// CHECK: %[[VAL_76:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_73]]] : memref +// CHECK: store %[[VAL_76]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_75]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } else { +// CHECK: %[[VAL_77:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_31]] : index +// CHECK: scf.if %[[VAL_77]] { +// CHECK: %[[VAL_78:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_26]]] : memref +// CHECK: %[[VAL_79:.*]] = addi %[[VAL_26]], %[[VAL_4]] : index +// CHECK: %[[VAL_80:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_79]]] : memref +// CHECK: scf.for %[[VAL_81:.*]] = %[[VAL_78]] to %[[VAL_80]] step %[[VAL_4]] { +// CHECK: %[[VAL_82:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_81]]] : memref +// CHECK: %[[VAL_83:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_81]]] : memref +// CHECK: store %[[VAL_83]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_82]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } else { +// CHECK: %[[VAL_84:.*]] = cmpi "eq", %[[VAL_29]], %[[VAL_31]] : index +// CHECK: scf.if %[[VAL_84]] { +// CHECK: %[[VAL_85:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref +// CHECK: %[[VAL_86:.*]] = addi %[[VAL_27]], %[[VAL_4]] : index +// CHECK: %[[VAL_87:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_86]]] : memref +// CHECK: scf.for 
%[[VAL_88:.*]] = %[[VAL_85]] to %[[VAL_87]] step %[[VAL_4]] { +// CHECK: %[[VAL_89:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_88]]] : memref +// CHECK: %[[VAL_90:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_88]]] : memref +// CHECK: store %[[VAL_90]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_89]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_91:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_31]] : index +// CHECK: %[[VAL_92:.*]] = addi %[[VAL_26]], %[[VAL_4]] : index +// CHECK: %[[VAL_93:.*]] = select %[[VAL_91]], %[[VAL_92]], %[[VAL_26]] : index +// CHECK: %[[VAL_94:.*]] = cmpi "eq", %[[VAL_29]], %[[VAL_31]] : index +// CHECK: %[[VAL_95:.*]] = addi %[[VAL_27]], %[[VAL_4]] : index +// CHECK: %[[VAL_96:.*]] = select %[[VAL_94]], %[[VAL_95]], %[[VAL_27]] : index +// CHECK: scf.yield %[[VAL_93]], %[[VAL_96]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_97:.*]] = %[[VAL_98:.*]]#0 to %[[VAL_17]] step %[[VAL_4]] { +// CHECK: %[[VAL_99:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_97]]] : memref +// CHECK: %[[VAL_100:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_97]]] : memref +// CHECK: %[[VAL_101:.*]] = addi %[[VAL_97]], %[[VAL_4]] : index +// CHECK: %[[VAL_102:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_101]]] : memref +// CHECK: scf.for %[[VAL_103:.*]] = %[[VAL_100]] to %[[VAL_102]] step %[[VAL_4]] { +// CHECK: %[[VAL_104:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_103]]] : memref +// CHECK: %[[VAL_105:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_103]]] : memref +// CHECK: store %[[VAL_105]], %[[VAL_15]]{{\[}}%[[VAL_99]], %[[VAL_104]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } +// CHECK: scf.for %[[VAL_106:.*]] = %[[VAL_107:.*]]#1 to %[[VAL_19]] step %[[VAL_4]] { +// CHECK: %[[VAL_108:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_106]]] : memref +// CHECK: %[[VAL_109:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_106]]] : memref +// CHECK: %[[VAL_110:.*]] = addi %[[VAL_106]], %[[VAL_4]] : index +// CHECK: %[[VAL_111:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_110]]] : memref +// CHECK: scf.for 
%[[VAL_112:.*]] = %[[VAL_109]] to %[[VAL_111]] step %[[VAL_4]] { +// CHECK: %[[VAL_113:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_112]]] : memref +// CHECK: %[[VAL_114:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_112]]] : memref +// CHECK: store %[[VAL_114]], %[[VAL_15]]{{\[}}%[[VAL_108]], %[[VAL_113]]] : memref<32x16xf32> +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_115:.*]] = tensor_load %[[VAL_15]] : memref<32x16xf32> +// CHECK: return %[[VAL_115]] : tensor<32x16xf32> +// CHECK: } +func @add_sd_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> { + %0 = linalg.generic #trait_ss_ss + ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) { + ^bb(%a: f32, %b: f32): + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16xf32> + return %0 : tensor<32x16xf32> +} + +// CHECK-LABEL: func @mul_sd_ds( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16xf32>) -> tensor<32x16xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_6:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_12:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_13:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_14:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_15:.*]] = alloc() : memref<32x16xf32> +// CHECK: %[[VAL_16:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_17:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_18:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_19:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_4]]] : memref +// CHECK: 
%[[VAL_20:.*]]:2 = scf.while (%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_23:.*]] = cmpi "ult", %[[VAL_21]], %[[VAL_17]] : index +// CHECK: %[[VAL_24:.*]] = cmpi "ult", %[[VAL_22]], %[[VAL_19]] : index +// CHECK: %[[VAL_25:.*]] = and %[[VAL_23]], %[[VAL_24]] : i1 +// CHECK: scf.condition(%[[VAL_25]]) %[[VAL_21]], %[[VAL_22]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index): +// CHECK: %[[VAL_28:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_26]]] : memref +// CHECK: %[[VAL_29:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_27]]] : memref +// CHECK: %[[VAL_30:.*]] = cmpi "ult", %[[VAL_29]], %[[VAL_28]] : index +// CHECK: %[[VAL_31:.*]] = select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : index +// CHECK: %[[VAL_32:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_31]] : index +// CHECK: %[[VAL_33:.*]] = cmpi "eq", %[[VAL_29]], %[[VAL_31]] : index +// CHECK: %[[VAL_34:.*]] = and %[[VAL_32]], %[[VAL_33]] : i1 +// CHECK: scf.if %[[VAL_34]] { +// CHECK: %[[VAL_35:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_26]]] : memref +// CHECK: %[[VAL_36:.*]] = addi %[[VAL_26]], %[[VAL_4]] : index +// CHECK: %[[VAL_37:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_36]]] : memref +// CHECK: %[[VAL_38:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref +// CHECK: %[[VAL_39:.*]] = addi %[[VAL_27]], %[[VAL_4]] : index +// CHECK: %[[VAL_40:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_39]]] : memref +// CHECK: %[[VAL_41:.*]]:2 = scf.while (%[[VAL_42:.*]] = %[[VAL_35]], %[[VAL_43:.*]] = %[[VAL_38]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_44:.*]] = cmpi "ult", %[[VAL_42]], %[[VAL_37]] : index +// CHECK: %[[VAL_45:.*]] = cmpi "ult", %[[VAL_43]], %[[VAL_40]] : index +// CHECK: %[[VAL_46:.*]] = and %[[VAL_44]], %[[VAL_45]] : i1 +// CHECK: scf.condition(%[[VAL_46]]) %[[VAL_42]], %[[VAL_43]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_47:.*]]: index, %[[VAL_48:.*]]: index): +// CHECK: %[[VAL_49:.*]] = load 
%[[VAL_8]]{{\[}}%[[VAL_47]]] : memref +// CHECK: %[[VAL_50:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_48]]] : memref +// CHECK: %[[VAL_51:.*]] = cmpi "ult", %[[VAL_50]], %[[VAL_49]] : index +// CHECK: %[[VAL_52:.*]] = select %[[VAL_51]], %[[VAL_50]], %[[VAL_49]] : index +// CHECK: %[[VAL_53:.*]] = cmpi "eq", %[[VAL_49]], %[[VAL_52]] : index +// CHECK: %[[VAL_54:.*]] = cmpi "eq", %[[VAL_50]], %[[VAL_52]] : index +// CHECK: %[[VAL_55:.*]] = and %[[VAL_53]], %[[VAL_54]] : i1 +// CHECK: scf.if %[[VAL_55]] { +// CHECK: %[[VAL_56:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_47]]] : memref +// CHECK: %[[VAL_57:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_48]]] : memref +// CHECK: %[[VAL_58:.*]] = mulf %[[VAL_56]], %[[VAL_57]] : f32 +// CHECK: store %[[VAL_58]], %[[VAL_15]]{{\[}}%[[VAL_31]], %[[VAL_52]]] : memref<32x16xf32> +// CHECK: } else { +// CHECK: } +// CHECK: %[[VAL_59:.*]] = cmpi "eq", %[[VAL_49]], %[[VAL_52]] : index +// CHECK: %[[VAL_60:.*]] = addi %[[VAL_47]], %[[VAL_4]] : index +// CHECK: %[[VAL_61:.*]] = select %[[VAL_59]], %[[VAL_60]], %[[VAL_47]] : index +// CHECK: %[[VAL_62:.*]] = cmpi "eq", %[[VAL_50]], %[[VAL_52]] : index +// CHECK: %[[VAL_63:.*]] = addi %[[VAL_48]], %[[VAL_4]] : index +// CHECK: %[[VAL_64:.*]] = select %[[VAL_62]], %[[VAL_63]], %[[VAL_48]] : index +// CHECK: scf.yield %[[VAL_61]], %[[VAL_64]] : index, index +// CHECK: } +// CHECK: } else { +// CHECK: } +// CHECK: %[[VAL_65:.*]] = cmpi "eq", %[[VAL_28]], %[[VAL_31]] : index +// CHECK: %[[VAL_66:.*]] = addi %[[VAL_26]], %[[VAL_4]] : index +// CHECK: %[[VAL_67:.*]] = select %[[VAL_65]], %[[VAL_66]], %[[VAL_26]] : index +// CHECK: %[[VAL_68:.*]] = cmpi "eq", %[[VAL_29]], %[[VAL_31]] : index +// CHECK: %[[VAL_69:.*]] = addi %[[VAL_27]], %[[VAL_4]] : index +// CHECK: %[[VAL_70:.*]] = select %[[VAL_68]], %[[VAL_69]], %[[VAL_27]] : index +// CHECK: scf.yield %[[VAL_67]], %[[VAL_70]] : index, index +// CHECK: } +// CHECK: %[[VAL_71:.*]] = tensor_load %[[VAL_15]] : memref<32x16xf32> +// CHECK: return %[[VAL_71]] : 
tensor<32x16xf32> +// CHECK: } +func @mul_sd_ds(%arga: tensor<32x16xf32>, %argb: tensor<32x16xf32>) -> tensor<32x16xf32> { + %0 = linalg.generic #trait_ss_ss + ins(%arga, %argb: tensor<32x16xf32>, tensor<32x16xf32>) { + ^bb(%a: f32, %b: f32): + %0 = mulf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16xf32> + return %0 : tensor<32x16xf32> +} + +#trait_matvec = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (j)>, // b + affine_map<(i,j) -> (i)> // x (out) + ], + sparse = [ + [ "D", "S" ], // A + [ "D" ], // b + [ "D" ] // x + ], + iterator_types = ["parallel", "reduction"], + doc = "x(i) += A(i,j) * b(j)" +} + +// CHECK-LABEL: func @matvec( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<16xf32>) -> tensor<16xf32> { +// CHECK: %[[VAL_3:.*]] = constant 999 : index +// CHECK: %[[VAL_4:.*]] = constant 16 : index +// CHECK: %[[VAL_5:.*]] = constant 0 : index +// CHECK: %[[VAL_6:.*]] = constant 1 : index +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_3]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_3]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_3]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc() : memref<32xf32> +// CHECK: %[[VAL_11:.*]] = alloc() : memref<16xf32> +// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { +// CHECK: %[[VAL_13:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref +// CHECK: %[[VAL_14:.*]] = addi %[[VAL_12]], %[[VAL_6]] : index +// CHECK: %[[VAL_15:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref +// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_6]] { +// CHECK: %[[VAL_17:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref +// CHECK: %[[VAL_18:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref +// CHECK: %[[VAL_19:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_17]]] : memref<32xf32> +// CHECK: %[[VAL_20:.*]] = mulf %[[VAL_18]], %[[VAL_19]] : f32 +// CHECK: %[[VAL_21:.*]] = load 
%[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<16xf32> +// CHECK: %[[VAL_22:.*]] = addf %[[VAL_20]], %[[VAL_21]] : f32 +// CHECK: store %[[VAL_22]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<16xf32> +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_23:.*]] = tensor_load %[[VAL_11]] : memref<16xf32> +// CHECK: return %[[VAL_23]] : tensor<16xf32> +// CHECK: } +func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf32>) -> tensor<16xf32> { + %0 = linalg.generic #trait_matvec + ins(%argA, %argb : tensor<16x32xf32>, tensor<32xf32>) + init(%argx : tensor<16xf32>) { + ^bb(%A: f32, %b: f32, %x: f32): + %0 = mulf %A, %b : f32 + %1 = addf %0, %x : f32 + linalg.yield %1 : f32 + } -> tensor<16xf32> + return %0 : tensor<16xf32> +} diff --git a/mlir/test/Dialect/Linalg/sparse_3d.mlir b/mlir/test/Dialect/Linalg/sparse_3d.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/sparse_3d.mlir @@ -0,0 +1,1227 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py +// RUN: mlir-opt %s -test-sparsification | FileCheck %s + +#trait_ddd = { + indexing_maps = [ + affine_map<(i,j,k) -> (i,j,k)>, // A + affine_map<(i,j,k) -> (i,j,k)>, // B + affine_map<(i,j,k) -> (i,j,k)> // X (out) + ], + sparse = [ + [ "D", "D", "D" ], // A + [ "D", "D", "D" ], // B + [ "D", "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel", "parallel"], + doc = "X(i,j,k) = A(i,j,k) OP B(i,j,k)" +} + +// CHECK-LABEL: func @add_ddd( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { +// CHECK: %[[VAL_2:.*]] = constant 32 : index +// CHECK: %[[VAL_3:.*]] = constant 16 : index +// CHECK: %[[VAL_4:.*]] = constant 8 : index +// CHECK: %[[VAL_5:.*]] = constant 0 : index +// CHECK: %[[VAL_6:.*]] = constant 1 : index +// CHECK: %[[VAL_7:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_8:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_9:.*]] = alloc() : memref<32x16x8xf32> +// 
CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_6]] { +// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] { +// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { +// CHECK: %[[VAL_13:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_10]], %[[VAL_11]], %[[VAL_12]]] : memref<32x16x8xf32> +// CHECK: %[[VAL_14:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_10]], %[[VAL_11]], %[[VAL_12]]] : memref<32x16x8xf32> +// CHECK: %[[VAL_15:.*]] = addf %[[VAL_13]], %[[VAL_14]] : f32 +// CHECK: store %[[VAL_15]], %[[VAL_9]]{{\[}}%[[VAL_10]], %[[VAL_11]], %[[VAL_12]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_16:.*]] = tensor_load %[[VAL_9]] : memref<32x16x8xf32> +// CHECK: return %[[VAL_16]] : tensor<32x16x8xf32> +// CHECK: } +func @add_ddd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { + %0 = linalg.generic #trait_ddd + ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) { + ^bb(%a: f32, %b: f32): + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16x8xf32> + return %0 : tensor<32x16x8xf32> +} + +// CHECK-LABEL: func @mul_ddd( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { +// CHECK: %[[VAL_2:.*]] = constant 32 : index +// CHECK: %[[VAL_3:.*]] = constant 16 : index +// CHECK: %[[VAL_4:.*]] = constant 8 : index +// CHECK: %[[VAL_5:.*]] = constant 0 : index +// CHECK: %[[VAL_6:.*]] = constant 1 : index +// CHECK: %[[VAL_7:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_8:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_9:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_6]] { +// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] { +// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { +// CHECK: %[[VAL_13:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_10]], 
%[[VAL_11]], %[[VAL_12]]] : memref<32x16x8xf32> +// CHECK: %[[VAL_14:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_10]], %[[VAL_11]], %[[VAL_12]]] : memref<32x16x8xf32> +// CHECK: %[[VAL_15:.*]] = mulf %[[VAL_13]], %[[VAL_14]] : f32 +// CHECK: store %[[VAL_15]], %[[VAL_9]]{{\[}}%[[VAL_10]], %[[VAL_11]], %[[VAL_12]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_16:.*]] = tensor_load %[[VAL_9]] : memref<32x16x8xf32> +// CHECK: return %[[VAL_16]] : tensor<32x16x8xf32> +// CHECK: } +func @mul_ddd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { + %0 = linalg.generic #trait_ddd + ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) { + ^bb(%a: f32, %b: f32): + %0 = mulf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16x8xf32> + return %0 : tensor<32x16x8xf32> +} + +#trait_dds = { + indexing_maps = [ + affine_map<(i,j,k) -> (i,j,k)>, // A + affine_map<(i,j,k) -> (i,j,k)>, // B + affine_map<(i,j,k) -> (i,j,k)> // X (out) + ], + sparse = [ + [ "D", "D", "S" ], // A + [ "D", "D", "D" ], // B + [ "D", "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel", "parallel"], + doc = "X(i,j,k) = A(i,j,k) OP B(i,j,k)" +} + +// CHECK-LABEL: func @add_dds( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 16 : index +// CHECK: %[[VAL_5:.*]] = constant 8 : index +// CHECK: %[[VAL_6:.*]] = constant 0 : index +// CHECK: %[[VAL_7:.*]] = constant 1 : index +// CHECK: %[[VAL_8:.*]] = constant true +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_12:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_13:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: scf.for %[[VAL_14:.*]] = 
%[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] { +// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] { +// CHECK: %[[VAL_16:.*]] = muli %[[VAL_14]], %[[VAL_4]] : index +// CHECK: %[[VAL_17:.*]] = addi %[[VAL_16]], %[[VAL_15]] : index +// CHECK: %[[VAL_18:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_17]]] : memref +// CHECK: %[[VAL_19:.*]] = addi %[[VAL_17]], %[[VAL_7]] : index +// CHECK: %[[VAL_20:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref +// CHECK: %[[VAL_21:.*]]:2 = scf.while (%[[VAL_22:.*]] = %[[VAL_18]], %[[VAL_23:.*]] = %[[VAL_6]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_24:.*]] = cmpi "ult", %[[VAL_22]], %[[VAL_20]] : index +// CHECK: scf.condition(%[[VAL_24]]) %[[VAL_22]], %[[VAL_23]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_25:.*]]: index, %[[VAL_26:.*]]: index): +// CHECK: %[[VAL_27:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref +// CHECK: %[[VAL_28:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index +// CHECK: scf.if %[[VAL_28]] { +// CHECK: %[[VAL_29:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_25]]] : memref +// CHECK: %[[VAL_30:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_14]], %[[VAL_15]], %[[VAL_26]]] : memref<32x16x8xf32> +// CHECK: %[[VAL_31:.*]] = addf %[[VAL_29]], %[[VAL_30]] : f32 +// CHECK: store %[[VAL_31]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_15]], %[[VAL_26]]] : memref<32x16x8xf32> +// CHECK: } else { +// CHECK: scf.if %[[VAL_8]] { +// CHECK: %[[VAL_32:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_14]], %[[VAL_15]], %[[VAL_26]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_32]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_15]], %[[VAL_26]]] : memref<32x16x8xf32> +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_33:.*]] = cmpi "eq", %[[VAL_27]], %[[VAL_26]] : index +// CHECK: %[[VAL_34:.*]] = addi %[[VAL_25]], %[[VAL_7]] : index +// CHECK: %[[VAL_35:.*]] = select %[[VAL_33]], %[[VAL_34]], %[[VAL_25]] : index +// CHECK: %[[VAL_36:.*]] = addi %[[VAL_26]], %[[VAL_7]] : index +// CHECK: scf.yield %[[VAL_35]], 
%[[VAL_36]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_37:.*]] = %[[VAL_38:.*]]#1 to %[[VAL_5]] step %[[VAL_7]] { +// CHECK: %[[VAL_39:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_14]], %[[VAL_15]], %[[VAL_37]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_39]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_15]], %[[VAL_37]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_40:.*]] = tensor_load %[[VAL_13]] : memref<32x16x8xf32> +// CHECK: return %[[VAL_40]] : tensor<32x16x8xf32> +// CHECK: } +func @add_dds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { + %0 = linalg.generic #trait_dds + ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) { + ^bb(%a: f32, %b: f32): + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16x8xf32> + return %0 : tensor<32x16x8xf32> +} + +// CHECK-LABEL: func @mul_dds( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 16 : index +// CHECK: %[[VAL_5:.*]] = constant 0 : index +// CHECK: %[[VAL_6:.*]] = constant 1 : index +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_11:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] { +// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { +// CHECK: %[[VAL_14:.*]] = muli %[[VAL_12]], %[[VAL_4]] : index +// CHECK: %[[VAL_15:.*]] = addi %[[VAL_14]], %[[VAL_13]] : index +// CHECK: %[[VAL_16:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref +// CHECK: %[[VAL_17:.*]] = addi %[[VAL_15]], %[[VAL_6]] : index +// CHECK: %[[VAL_18:.*]] = load 
%[[VAL_7]]{{\[}}%[[VAL_17]]] : memref +// CHECK: scf.for %[[VAL_19:.*]] = %[[VAL_16]] to %[[VAL_18]] step %[[VAL_6]] { +// CHECK: %[[VAL_20:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref +// CHECK: %[[VAL_21:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref +// CHECK: %[[VAL_22:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_20]]] : memref<32x16x8xf32> +// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_21]], %[[VAL_22]] : f32 +// CHECK: store %[[VAL_23]], %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_20]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_24:.*]] = tensor_load %[[VAL_11]] : memref<32x16x8xf32> +// CHECK: return %[[VAL_24]] : tensor<32x16x8xf32> +// CHECK: } +func @mul_dds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { + %0 = linalg.generic #trait_dds + ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) { + ^bb(%a: f32, %b: f32): + %0 = mulf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16x8xf32> + return %0 : tensor<32x16x8xf32> +} + +#trait_dsd = { + indexing_maps = [ + affine_map<(i,j,k) -> (i,j,k)>, // A + affine_map<(i,j,k) -> (i,j,k)>, // B + affine_map<(i,j,k) -> (i,j,k)> // X (out) + ], + sparse = [ + [ "D", "S", "D" ], // A + [ "D", "D", "D" ], // B + [ "D", "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel", "parallel"], + doc = "X(i,j,k) = A(i,j,k) OP B(i,j,k)" +} + +// CHECK-LABEL: func @add_dsd( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 16 : index +// CHECK: %[[VAL_5:.*]] = constant 8 : index +// CHECK: %[[VAL_6:.*]] = constant true +// CHECK: %[[VAL_7:.*]] = constant 0 : index +// CHECK: %[[VAL_8:.*]] = constant 1 : index +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) 
: memref +// CHECK: %[[VAL_11:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_12:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_13:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_7]] to %[[VAL_3]] step %[[VAL_8]] { +// CHECK: %[[VAL_15:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_14]]] : memref +// CHECK: %[[VAL_16:.*]] = addi %[[VAL_14]], %[[VAL_8]] : index +// CHECK: %[[VAL_17:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref +// CHECK: %[[VAL_18:.*]]:2 = scf.while (%[[VAL_19:.*]] = %[[VAL_15]], %[[VAL_20:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_21:.*]] = cmpi "ult", %[[VAL_19]], %[[VAL_17]] : index +// CHECK: scf.condition(%[[VAL_21]]) %[[VAL_19]], %[[VAL_20]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index): +// CHECK: %[[VAL_24:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_22]]] : memref +// CHECK: %[[VAL_25:.*]] = cmpi "eq", %[[VAL_24]], %[[VAL_23]] : index +// CHECK: scf.if %[[VAL_25]] { +// CHECK: scf.for %[[VAL_26:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_27:.*]] = muli %[[VAL_22]], %[[VAL_5]] : index +// CHECK: %[[VAL_28:.*]] = addi %[[VAL_27]], %[[VAL_26]] : index +// CHECK: %[[VAL_29:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_28]]] : memref +// CHECK: %[[VAL_30:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_14]], %[[VAL_23]], %[[VAL_26]]] : memref<32x16x8xf32> +// CHECK: %[[VAL_31:.*]] = addf %[[VAL_29]], %[[VAL_30]] : f32 +// CHECK: store %[[VAL_31]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_23]], %[[VAL_26]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } else { +// CHECK: scf.if %[[VAL_6]] { +// CHECK: scf.for %[[VAL_32:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_33:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_14]], %[[VAL_23]], %[[VAL_32]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_33]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_23]], %[[VAL_32]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } else { +// CHECK: } +// CHECK: } 
+// CHECK: %[[VAL_34:.*]] = cmpi "eq", %[[VAL_24]], %[[VAL_23]] : index +// CHECK: %[[VAL_35:.*]] = addi %[[VAL_22]], %[[VAL_8]] : index +// CHECK: %[[VAL_36:.*]] = select %[[VAL_34]], %[[VAL_35]], %[[VAL_22]] : index +// CHECK: %[[VAL_37:.*]] = addi %[[VAL_23]], %[[VAL_8]] : index +// CHECK: scf.yield %[[VAL_36]], %[[VAL_37]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_38:.*]] = %[[VAL_39:.*]]#1 to %[[VAL_4]] step %[[VAL_8]] { +// CHECK: scf.for %[[VAL_40:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_41:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_14]], %[[VAL_38]], %[[VAL_40]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_41]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_38]], %[[VAL_40]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_42:.*]] = tensor_load %[[VAL_13]] : memref<32x16x8xf32> +// CHECK: return %[[VAL_42]] : tensor<32x16x8xf32> +// CHECK: } +func @add_dsd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { + %0 = linalg.generic #trait_dsd + ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) { + ^bb(%a: f32, %b: f32): + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16x8xf32> + return %0 : tensor<32x16x8xf32> +} + +// CHECK-LABEL: func @mul_dsd( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 8 : index +// CHECK: %[[VAL_5:.*]] = constant 0 : index +// CHECK: %[[VAL_6:.*]] = constant 1 : index +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_11:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] 
{ +// CHECK: %[[VAL_13:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref +// CHECK: %[[VAL_14:.*]] = addi %[[VAL_12]], %[[VAL_6]] : index +// CHECK: %[[VAL_15:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_14]]] : memref +// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_6]] { +// CHECK: %[[VAL_17:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref +// CHECK: scf.for %[[VAL_18:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { +// CHECK: %[[VAL_19:.*]] = muli %[[VAL_16]], %[[VAL_4]] : index +// CHECK: %[[VAL_20:.*]] = addi %[[VAL_19]], %[[VAL_18]] : index +// CHECK: %[[VAL_21:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_20]]] : memref +// CHECK: %[[VAL_22:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_12]], %[[VAL_17]], %[[VAL_18]]] : memref<32x16x8xf32> +// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_21]], %[[VAL_22]] : f32 +// CHECK: store %[[VAL_23]], %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_17]], %[[VAL_18]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_24:.*]] = tensor_load %[[VAL_11]] : memref<32x16x8xf32> +// CHECK: return %[[VAL_24]] : tensor<32x16x8xf32> +// CHECK: } +func @mul_dsd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { + %0 = linalg.generic #trait_dsd + ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) { + ^bb(%a: f32, %b: f32): + %0 = mulf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16x8xf32> + return %0 : tensor<32x16x8xf32> +} + +#trait_dss = { + indexing_maps = [ + affine_map<(i,j,k) -> (i,j,k)>, // A + affine_map<(i,j,k) -> (i,j,k)>, // B + affine_map<(i,j,k) -> (i,j,k)> // X (out) + ], + sparse = [ + [ "D", "S", "S" ], // A + [ "D", "D", "D" ], // B + [ "D", "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel", "parallel"], + doc = "X(i,j,k) = A(i,j,k) OP B(i,j,k)" +} + +// CHECK-LABEL: func @add_dss( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 
: index +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 16 : index +// CHECK: %[[VAL_5:.*]] = constant 8 : index +// CHECK: %[[VAL_6:.*]] = constant true +// CHECK: %[[VAL_7:.*]] = constant 0 : index +// CHECK: %[[VAL_8:.*]] = constant 1 : index +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_12:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_13:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_14:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_15:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_7]] to %[[VAL_3]] step %[[VAL_8]] { +// CHECK: %[[VAL_17:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref +// CHECK: %[[VAL_18:.*]] = addi %[[VAL_16]], %[[VAL_8]] : index +// CHECK: %[[VAL_19:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref +// CHECK: %[[VAL_20:.*]]:2 = scf.while (%[[VAL_21:.*]] = %[[VAL_17]], %[[VAL_22:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_23:.*]] = cmpi "ult", %[[VAL_21]], %[[VAL_19]] : index +// CHECK: scf.condition(%[[VAL_23]]) %[[VAL_21]], %[[VAL_22]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_24:.*]]: index, %[[VAL_25:.*]]: index): +// CHECK: %[[VAL_26:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_27:.*]] = cmpi "eq", %[[VAL_26]], %[[VAL_25]] : index +// CHECK: scf.if %[[VAL_27]] { +// CHECK: %[[VAL_28:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_29:.*]] = addi %[[VAL_24]], %[[VAL_8]] : index +// CHECK: %[[VAL_30:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref +// CHECK: %[[VAL_31:.*]]:2 = scf.while (%[[VAL_32:.*]] = %[[VAL_28]], %[[VAL_33:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_34:.*]] = cmpi "ult", %[[VAL_32]], %[[VAL_30]] : index +// CHECK: scf.condition(%[[VAL_34]]) %[[VAL_32]], %[[VAL_33]] : index, index 
+// CHECK: } do { +// CHECK: ^bb0(%[[VAL_35:.*]]: index, %[[VAL_36:.*]]: index): +// CHECK: %[[VAL_37:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_35]]] : memref +// CHECK: %[[VAL_38:.*]] = cmpi "eq", %[[VAL_37]], %[[VAL_36]] : index +// CHECK: scf.if %[[VAL_38]] { +// CHECK: %[[VAL_39:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_35]]] : memref +// CHECK: %[[VAL_40:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_16]], %[[VAL_25]], %[[VAL_36]]] : memref<32x16x8xf32> +// CHECK: %[[VAL_41:.*]] = addf %[[VAL_39]], %[[VAL_40]] : f32 +// CHECK: store %[[VAL_41]], %[[VAL_15]]{{\[}}%[[VAL_16]], %[[VAL_25]], %[[VAL_36]]] : memref<32x16x8xf32> +// CHECK: } else { +// CHECK: scf.if %[[VAL_6]] { +// CHECK: %[[VAL_42:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_16]], %[[VAL_25]], %[[VAL_36]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_42]], %[[VAL_15]]{{\[}}%[[VAL_16]], %[[VAL_25]], %[[VAL_36]]] : memref<32x16x8xf32> +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_43:.*]] = cmpi "eq", %[[VAL_37]], %[[VAL_36]] : index +// CHECK: %[[VAL_44:.*]] = addi %[[VAL_35]], %[[VAL_8]] : index +// CHECK: %[[VAL_45:.*]] = select %[[VAL_43]], %[[VAL_44]], %[[VAL_35]] : index +// CHECK: %[[VAL_46:.*]] = addi %[[VAL_36]], %[[VAL_8]] : index +// CHECK: scf.yield %[[VAL_45]], %[[VAL_46]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_47:.*]] = %[[VAL_48:.*]]#1 to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_49:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_16]], %[[VAL_25]], %[[VAL_47]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_49]], %[[VAL_15]]{{\[}}%[[VAL_16]], %[[VAL_25]], %[[VAL_47]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } else { +// CHECK: scf.if %[[VAL_6]] { +// CHECK: scf.for %[[VAL_50:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_51:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_16]], %[[VAL_25]], %[[VAL_50]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_51]], %[[VAL_15]]{{\[}}%[[VAL_16]], %[[VAL_25]], %[[VAL_50]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } else { +// CHECK: } 
+// CHECK: } +// CHECK: %[[VAL_52:.*]] = cmpi "eq", %[[VAL_26]], %[[VAL_25]] : index +// CHECK: %[[VAL_53:.*]] = addi %[[VAL_24]], %[[VAL_8]] : index +// CHECK: %[[VAL_54:.*]] = select %[[VAL_52]], %[[VAL_53]], %[[VAL_24]] : index +// CHECK: %[[VAL_55:.*]] = addi %[[VAL_25]], %[[VAL_8]] : index +// CHECK: scf.yield %[[VAL_54]], %[[VAL_55]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_56:.*]] = %[[VAL_57:.*]]#1 to %[[VAL_4]] step %[[VAL_8]] { +// CHECK: scf.for %[[VAL_58:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_59:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_16]], %[[VAL_56]], %[[VAL_58]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_59]], %[[VAL_15]]{{\[}}%[[VAL_16]], %[[VAL_56]], %[[VAL_58]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_60:.*]] = tensor_load %[[VAL_15]] : memref<32x16x8xf32> +// CHECK: return %[[VAL_60]] : tensor<32x16x8xf32> +// CHECK: } +func @add_dss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { + %0 = linalg.generic #trait_dss + ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) { + ^bb(%a: f32, %b: f32): + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16x8xf32> + return %0 : tensor<32x16x8xf32> +} + +// CHECK-LABEL: func @mul_dss( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 0 : index +// CHECK: %[[VAL_5:.*]] = constant 1 : index +// CHECK: %[[VAL_6:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_12:.*]] = alloc() : memref<32x16x8xf32> 
+// CHECK: scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { +// CHECK: %[[VAL_14:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref +// CHECK: %[[VAL_15:.*]] = addi %[[VAL_13]], %[[VAL_5]] : index +// CHECK: %[[VAL_16:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref +// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_14]] to %[[VAL_16]] step %[[VAL_5]] { +// CHECK: %[[VAL_18:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref +// CHECK: %[[VAL_19:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref +// CHECK: %[[VAL_20:.*]] = addi %[[VAL_17]], %[[VAL_5]] : index +// CHECK: %[[VAL_21:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref +// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_19]] to %[[VAL_21]] step %[[VAL_5]] { +// CHECK: %[[VAL_23:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_22]]] : memref +// CHECK: %[[VAL_24:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_22]]] : memref +// CHECK: %[[VAL_25:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_13]], %[[VAL_18]], %[[VAL_23]]] : memref<32x16x8xf32> +// CHECK: %[[VAL_26:.*]] = mulf %[[VAL_24]], %[[VAL_25]] : f32 +// CHECK: store %[[VAL_26]], %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_18]], %[[VAL_23]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_27:.*]] = tensor_load %[[VAL_12]] : memref<32x16x8xf32> +// CHECK: return %[[VAL_27]] : tensor<32x16x8xf32> +// CHECK: } +func @mul_dss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { + %0 = linalg.generic #trait_dss + ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) { + ^bb(%a: f32, %b: f32): + %0 = mulf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16x8xf32> + return %0 : tensor<32x16x8xf32> +} + +#trait_sdd = { + indexing_maps = [ + affine_map<(i,j,k) -> (i,j,k)>, // A + affine_map<(i,j,k) -> (i,j,k)>, // B + affine_map<(i,j,k) -> (i,j,k)> // X (out) + ], + sparse = [ + [ "S", "D", "D" ], // A + [ "D", "D", "D" ], // B + [ "D", "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel", "parallel"], + doc = "X(i,j,k) = 
A(i,j,k) OP B(i,j,k)" +} + +// CHECK-LABEL: func @add_sdd( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 16 : index +// CHECK: %[[VAL_5:.*]] = constant 8 : index +// CHECK: %[[VAL_6:.*]] = constant true +// CHECK: %[[VAL_7:.*]] = constant 0 : index +// CHECK: %[[VAL_8:.*]] = constant 1 : index +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_12:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_13:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_14:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref +// CHECK: %[[VAL_15:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_8]]] : memref +// CHECK: %[[VAL_16:.*]]:2 = scf.while (%[[VAL_17:.*]] = %[[VAL_14]], %[[VAL_18:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_19:.*]] = cmpi "ult", %[[VAL_17]], %[[VAL_15]] : index +// CHECK: scf.condition(%[[VAL_19]]) %[[VAL_17]], %[[VAL_18]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_20:.*]]: index, %[[VAL_21:.*]]: index): +// CHECK: %[[VAL_22:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_20]]] : memref +// CHECK: %[[VAL_23:.*]] = cmpi "eq", %[[VAL_22]], %[[VAL_21]] : index +// CHECK: scf.if %[[VAL_23]] { +// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] { +// CHECK: %[[VAL_25:.*]] = muli %[[VAL_20]], %[[VAL_4]] : index +// CHECK: %[[VAL_26:.*]] = addi %[[VAL_25]], %[[VAL_24]] : index +// CHECK: scf.for %[[VAL_27:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_28:.*]] = muli %[[VAL_26]], %[[VAL_5]] : index +// CHECK: %[[VAL_29:.*]] = addi %[[VAL_28]], %[[VAL_27]] : index +// CHECK: %[[VAL_30:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref +// CHECK: %[[VAL_31:.*]] = load 
%[[VAL_12]]{{\[}}%[[VAL_21]], %[[VAL_24]], %[[VAL_27]]] : memref<32x16x8xf32> +// CHECK: %[[VAL_32:.*]] = addf %[[VAL_30]], %[[VAL_31]] : f32 +// CHECK: store %[[VAL_32]], %[[VAL_13]]{{\[}}%[[VAL_21]], %[[VAL_24]], %[[VAL_27]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } else { +// CHECK: scf.if %[[VAL_6]] { +// CHECK: scf.for %[[VAL_33:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] { +// CHECK: scf.for %[[VAL_34:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_35:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_21]], %[[VAL_33]], %[[VAL_34]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_35]], %[[VAL_13]]{{\[}}%[[VAL_21]], %[[VAL_33]], %[[VAL_34]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_36:.*]] = cmpi "eq", %[[VAL_22]], %[[VAL_21]] : index +// CHECK: %[[VAL_37:.*]] = addi %[[VAL_20]], %[[VAL_8]] : index +// CHECK: %[[VAL_38:.*]] = select %[[VAL_36]], %[[VAL_37]], %[[VAL_20]] : index +// CHECK: %[[VAL_39:.*]] = addi %[[VAL_21]], %[[VAL_8]] : index +// CHECK: scf.yield %[[VAL_38]], %[[VAL_39]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_40:.*]] = %[[VAL_41:.*]]#1 to %[[VAL_3]] step %[[VAL_8]] { +// CHECK: scf.for %[[VAL_42:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] { +// CHECK: scf.for %[[VAL_43:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_44:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_40]], %[[VAL_42]], %[[VAL_43]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_44]], %[[VAL_13]]{{\[}}%[[VAL_40]], %[[VAL_42]], %[[VAL_43]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_45:.*]] = tensor_load %[[VAL_13]] : memref<32x16x8xf32> +// CHECK: return %[[VAL_45]] : tensor<32x16x8xf32> +// CHECK: } +func @add_sdd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { + %0 = linalg.generic #trait_sdd + ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) { + ^bb(%a: f32, %b: f32): + %0 = 
addf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16x8xf32> + return %0 : tensor<32x16x8xf32> +} + +// CHECK-LABEL: func @mul_sdd( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 16 : index +// CHECK: %[[VAL_4:.*]] = constant 8 : index +// CHECK: %[[VAL_5:.*]] = constant 0 : index +// CHECK: %[[VAL_6:.*]] = constant 1 : index +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_11:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_12:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref +// CHECK: %[[VAL_13:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref +// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_12]] to %[[VAL_13]] step %[[VAL_6]] { +// CHECK: %[[VAL_15:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_14]]] : memref +// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] { +// CHECK: %[[VAL_17:.*]] = muli %[[VAL_14]], %[[VAL_3]] : index +// CHECK: %[[VAL_18:.*]] = addi %[[VAL_17]], %[[VAL_16]] : index +// CHECK: scf.for %[[VAL_19:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { +// CHECK: %[[VAL_20:.*]] = muli %[[VAL_18]], %[[VAL_4]] : index +// CHECK: %[[VAL_21:.*]] = addi %[[VAL_20]], %[[VAL_19]] : index +// CHECK: %[[VAL_22:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref +// CHECK: %[[VAL_23:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_15]], %[[VAL_16]], %[[VAL_19]]] : memref<32x16x8xf32> +// CHECK: %[[VAL_24:.*]] = mulf %[[VAL_22]], %[[VAL_23]] : f32 +// CHECK: store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_15]], %[[VAL_16]], %[[VAL_19]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_25:.*]] = tensor_load %[[VAL_11]] : memref<32x16x8xf32> +// CHECK: return %[[VAL_25]] : 
tensor<32x16x8xf32> +// CHECK: } +func @mul_sdd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { + %0 = linalg.generic #trait_sdd + ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) { + ^bb(%a: f32, %b: f32): + %0 = mulf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16x8xf32> + return %0 : tensor<32x16x8xf32> +} + +#trait_sds = { + indexing_maps = [ + affine_map<(i,j,k) -> (i,j,k)>, // A + affine_map<(i,j,k) -> (i,j,k)>, // B + affine_map<(i,j,k) -> (i,j,k)> // X (out) + ], + sparse = [ + [ "S", "D", "S" ], // A + [ "D", "D", "D" ], // B + [ "D", "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel", "parallel"], + doc = "X(i,j,k) = A(i,j,k) OP B(i,j,k)" +} + +// CHECK-LABEL: func @add_sds( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 16 : index +// CHECK: %[[VAL_5:.*]] = constant 8 : index +// CHECK: %[[VAL_6:.*]] = constant true +// CHECK: %[[VAL_7:.*]] = constant 0 : index +// CHECK: %[[VAL_8:.*]] = constant 1 : index +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_12:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_13:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_14:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_15:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_16:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref +// CHECK: %[[VAL_17:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_8]]] : memref +// CHECK: %[[VAL_18:.*]]:2 = scf.while (%[[VAL_19:.*]] = %[[VAL_16]], %[[VAL_20:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_21:.*]] = cmpi "ult", %[[VAL_19]], %[[VAL_17]] : index +// CHECK: scf.condition(%[[VAL_21]]) 
%[[VAL_19]], %[[VAL_20]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index): +// CHECK: %[[VAL_24:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_22]]] : memref +// CHECK: %[[VAL_25:.*]] = cmpi "eq", %[[VAL_24]], %[[VAL_23]] : index +// CHECK: scf.if %[[VAL_25]] { +// CHECK: scf.for %[[VAL_26:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] { +// CHECK: %[[VAL_27:.*]] = muli %[[VAL_22]], %[[VAL_4]] : index +// CHECK: %[[VAL_28:.*]] = addi %[[VAL_27]], %[[VAL_26]] : index +// CHECK: %[[VAL_29:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_28]]] : memref +// CHECK: %[[VAL_30:.*]] = addi %[[VAL_28]], %[[VAL_8]] : index +// CHECK: %[[VAL_31:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref +// CHECK: %[[VAL_32:.*]]:2 = scf.while (%[[VAL_33:.*]] = %[[VAL_29]], %[[VAL_34:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_35:.*]] = cmpi "ult", %[[VAL_33]], %[[VAL_31]] : index +// CHECK: scf.condition(%[[VAL_35]]) %[[VAL_33]], %[[VAL_34]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_36:.*]]: index, %[[VAL_37:.*]]: index): +// CHECK: %[[VAL_38:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_36]]] : memref +// CHECK: %[[VAL_39:.*]] = cmpi "eq", %[[VAL_38]], %[[VAL_37]] : index +// CHECK: scf.if %[[VAL_39]] { +// CHECK: %[[VAL_40:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_36]]] : memref +// CHECK: %[[VAL_41:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_23]], %[[VAL_26]], %[[VAL_37]]] : memref<32x16x8xf32> +// CHECK: %[[VAL_42:.*]] = addf %[[VAL_40]], %[[VAL_41]] : f32 +// CHECK: store %[[VAL_42]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_26]], %[[VAL_37]]] : memref<32x16x8xf32> +// CHECK: } else { +// CHECK: scf.if %[[VAL_6]] { +// CHECK: %[[VAL_43:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_23]], %[[VAL_26]], %[[VAL_37]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_43]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_26]], %[[VAL_37]]] : memref<32x16x8xf32> +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_44:.*]] = cmpi "eq", %[[VAL_38]], %[[VAL_37]] : 
index +// CHECK: %[[VAL_45:.*]] = addi %[[VAL_36]], %[[VAL_8]] : index +// CHECK: %[[VAL_46:.*]] = select %[[VAL_44]], %[[VAL_45]], %[[VAL_36]] : index +// CHECK: %[[VAL_47:.*]] = addi %[[VAL_37]], %[[VAL_8]] : index +// CHECK: scf.yield %[[VAL_46]], %[[VAL_47]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_48:.*]] = %[[VAL_49:.*]]#1 to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_50:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_23]], %[[VAL_26]], %[[VAL_48]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_50]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_26]], %[[VAL_48]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } else { +// CHECK: scf.if %[[VAL_6]] { +// CHECK: scf.for %[[VAL_51:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] { +// CHECK: scf.for %[[VAL_52:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_53:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_23]], %[[VAL_51]], %[[VAL_52]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_53]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_51]], %[[VAL_52]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_54:.*]] = cmpi "eq", %[[VAL_24]], %[[VAL_23]] : index +// CHECK: %[[VAL_55:.*]] = addi %[[VAL_22]], %[[VAL_8]] : index +// CHECK: %[[VAL_56:.*]] = select %[[VAL_54]], %[[VAL_55]], %[[VAL_22]] : index +// CHECK: %[[VAL_57:.*]] = addi %[[VAL_23]], %[[VAL_8]] : index +// CHECK: scf.yield %[[VAL_56]], %[[VAL_57]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_58:.*]] = %[[VAL_59:.*]]#1 to %[[VAL_3]] step %[[VAL_8]] { +// CHECK: scf.for %[[VAL_60:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] { +// CHECK: scf.for %[[VAL_61:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_62:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_58]], %[[VAL_60]], %[[VAL_61]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_62]], %[[VAL_15]]{{\[}}%[[VAL_58]], %[[VAL_60]], %[[VAL_61]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: 
%[[VAL_63:.*]] = tensor_load %[[VAL_15]] : memref<32x16x8xf32> +// CHECK: return %[[VAL_63]] : tensor<32x16x8xf32> +// CHECK: } +func @add_sds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { + %0 = linalg.generic #trait_sds + ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) { + ^bb(%a: f32, %b: f32): + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16x8xf32> + return %0 : tensor<32x16x8xf32> +} + +// CHECK-LABEL: func @mul_sds( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 16 : index +// CHECK: %[[VAL_4:.*]] = constant 0 : index +// CHECK: %[[VAL_5:.*]] = constant 1 : index +// CHECK: %[[VAL_6:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_12:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_13:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_14:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref +// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_5]] { +// CHECK: %[[VAL_16:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref +// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { +// CHECK: %[[VAL_18:.*]] = muli %[[VAL_15]], %[[VAL_3]] : index +// CHECK: %[[VAL_19:.*]] = addi %[[VAL_18]], %[[VAL_17]] : index +// CHECK: %[[VAL_20:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref +// CHECK: %[[VAL_21:.*]] = addi %[[VAL_19]], %[[VAL_5]] : index +// CHECK: %[[VAL_22:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_21]]] : memref +// CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_5]] { +// CHECK: 
%[[VAL_24:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref +// CHECK: %[[VAL_25:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_23]]] : memref +// CHECK: %[[VAL_26:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_24]]] : memref<32x16x8xf32> +// CHECK: %[[VAL_27:.*]] = mulf %[[VAL_25]], %[[VAL_26]] : f32 +// CHECK: store %[[VAL_27]], %[[VAL_12]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_24]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_28:.*]] = tensor_load %[[VAL_12]] : memref<32x16x8xf32> +// CHECK: return %[[VAL_28]] : tensor<32x16x8xf32> +// CHECK: } +func @mul_sds(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { + %0 = linalg.generic #trait_sds + ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) { + ^bb(%a: f32, %b: f32): + %0 = mulf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16x8xf32> + return %0 : tensor<32x16x8xf32> +} + +#trait_ssd = { + indexing_maps = [ + affine_map<(i,j,k) -> (i,j,k)>, // A + affine_map<(i,j,k) -> (i,j,k)>, // B + affine_map<(i,j,k) -> (i,j,k)> // X (out) + ], + sparse = [ + [ "S", "S", "D" ], // A + [ "D", "D", "D" ], // B + [ "D", "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel", "parallel"], + doc = "X(i,j,k) = A(i,j,k) OP B(i,j,k)" +} + +// CHECK-LABEL: func @add_ssd( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 16 : index +// CHECK: %[[VAL_5:.*]] = constant 8 : index +// CHECK: %[[VAL_6:.*]] = constant true +// CHECK: %[[VAL_7:.*]] = constant 0 : index +// CHECK: %[[VAL_8:.*]] = constant 1 : index +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_12:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: 
%[[VAL_13:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_14:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_15:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_16:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref +// CHECK: %[[VAL_17:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_8]]] : memref +// CHECK: %[[VAL_18:.*]]:2 = scf.while (%[[VAL_19:.*]] = %[[VAL_16]], %[[VAL_20:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_21:.*]] = cmpi "ult", %[[VAL_19]], %[[VAL_17]] : index +// CHECK: scf.condition(%[[VAL_21]]) %[[VAL_19]], %[[VAL_20]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index): +// CHECK: %[[VAL_24:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_22]]] : memref +// CHECK: %[[VAL_25:.*]] = cmpi "eq", %[[VAL_24]], %[[VAL_23]] : index +// CHECK: scf.if %[[VAL_25]] { +// CHECK: %[[VAL_26:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_22]]] : memref +// CHECK: %[[VAL_27:.*]] = addi %[[VAL_22]], %[[VAL_8]] : index +// CHECK: %[[VAL_28:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_27]]] : memref +// CHECK: %[[VAL_29:.*]]:2 = scf.while (%[[VAL_30:.*]] = %[[VAL_26]], %[[VAL_31:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_32:.*]] = cmpi "ult", %[[VAL_30]], %[[VAL_28]] : index +// CHECK: scf.condition(%[[VAL_32]]) %[[VAL_30]], %[[VAL_31]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_33:.*]]: index, %[[VAL_34:.*]]: index): +// CHECK: %[[VAL_35:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_33]]] : memref +// CHECK: %[[VAL_36:.*]] = cmpi "eq", %[[VAL_35]], %[[VAL_34]] : index +// CHECK: scf.if %[[VAL_36]] { +// CHECK: scf.for %[[VAL_37:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_38:.*]] = muli %[[VAL_33]], %[[VAL_5]] : index +// CHECK: %[[VAL_39:.*]] = addi %[[VAL_38]], %[[VAL_37]] : index +// CHECK: %[[VAL_40:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_39]]] : memref +// CHECK: %[[VAL_41:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_23]], %[[VAL_34]], %[[VAL_37]]] : memref<32x16x8xf32> +// 
CHECK: %[[VAL_42:.*]] = addf %[[VAL_40]], %[[VAL_41]] : f32 +// CHECK: store %[[VAL_42]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_34]], %[[VAL_37]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } else { +// CHECK: scf.if %[[VAL_6]] { +// CHECK: scf.for %[[VAL_43:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_44:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_23]], %[[VAL_34]], %[[VAL_43]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_44]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_34]], %[[VAL_43]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_45:.*]] = cmpi "eq", %[[VAL_35]], %[[VAL_34]] : index +// CHECK: %[[VAL_46:.*]] = addi %[[VAL_33]], %[[VAL_8]] : index +// CHECK: %[[VAL_47:.*]] = select %[[VAL_45]], %[[VAL_46]], %[[VAL_33]] : index +// CHECK: %[[VAL_48:.*]] = addi %[[VAL_34]], %[[VAL_8]] : index +// CHECK: scf.yield %[[VAL_47]], %[[VAL_48]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_49:.*]] = %[[VAL_50:.*]]#1 to %[[VAL_4]] step %[[VAL_8]] { +// CHECK: scf.for %[[VAL_51:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_52:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_23]], %[[VAL_49]], %[[VAL_51]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_52]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_49]], %[[VAL_51]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } else { +// CHECK: scf.if %[[VAL_6]] { +// CHECK: scf.for %[[VAL_53:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] { +// CHECK: scf.for %[[VAL_54:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_55:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_23]], %[[VAL_53]], %[[VAL_54]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_55]], %[[VAL_15]]{{\[}}%[[VAL_23]], %[[VAL_53]], %[[VAL_54]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_56:.*]] = cmpi "eq", %[[VAL_24]], %[[VAL_23]] : index +// CHECK: %[[VAL_57:.*]] = addi %[[VAL_22]], %[[VAL_8]] : index +// 
CHECK: %[[VAL_58:.*]] = select %[[VAL_56]], %[[VAL_57]], %[[VAL_22]] : index +// CHECK: %[[VAL_59:.*]] = addi %[[VAL_23]], %[[VAL_8]] : index +// CHECK: scf.yield %[[VAL_58]], %[[VAL_59]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_60:.*]] = %[[VAL_61:.*]]#1 to %[[VAL_3]] step %[[VAL_8]] { +// CHECK: scf.for %[[VAL_62:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] { +// CHECK: scf.for %[[VAL_63:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_64:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_60]], %[[VAL_62]], %[[VAL_63]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_64]], %[[VAL_15]]{{\[}}%[[VAL_60]], %[[VAL_62]], %[[VAL_63]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_65:.*]] = tensor_load %[[VAL_15]] : memref<32x16x8xf32> +// CHECK: return %[[VAL_65]] : tensor<32x16x8xf32> +// CHECK: } +func @add_ssd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { + %0 = linalg.generic #trait_ssd + ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) { + ^bb(%a: f32, %b: f32): + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16x8xf32> + return %0 : tensor<32x16x8xf32> +} + +// CHECK-LABEL: func @mul_ssd( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 8 : index +// CHECK: %[[VAL_4:.*]] = constant 0 : index +// CHECK: %[[VAL_5:.*]] = constant 1 : index +// CHECK: %[[VAL_6:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_7:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_12:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_13:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref +// 
CHECK: %[[VAL_14:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref +// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_5]] { +// CHECK: %[[VAL_16:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref +// CHECK: %[[VAL_17:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref +// CHECK: %[[VAL_18:.*]] = addi %[[VAL_15]], %[[VAL_5]] : index +// CHECK: %[[VAL_19:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref +// CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_19]] step %[[VAL_5]] { +// CHECK: %[[VAL_21:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_20]]] : memref +// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { +// CHECK: %[[VAL_23:.*]] = muli %[[VAL_20]], %[[VAL_3]] : index +// CHECK: %[[VAL_24:.*]] = addi %[[VAL_23]], %[[VAL_22]] : index +// CHECK: %[[VAL_25:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_26:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_16]], %[[VAL_21]], %[[VAL_22]]] : memref<32x16x8xf32> +// CHECK: %[[VAL_27:.*]] = mulf %[[VAL_25]], %[[VAL_26]] : f32 +// CHECK: store %[[VAL_27]], %[[VAL_12]]{{\[}}%[[VAL_16]], %[[VAL_21]], %[[VAL_22]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_28:.*]] = tensor_load %[[VAL_12]] : memref<32x16x8xf32> +// CHECK: return %[[VAL_28]] : tensor<32x16x8xf32> +// CHECK: } +func @mul_ssd(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { + %0 = linalg.generic #trait_ssd + ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) { + ^bb(%a: f32, %b: f32): + %0 = mulf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16x8xf32> + return %0 : tensor<32x16x8xf32> +} + +#trait_sss = { + indexing_maps = [ + affine_map<(i,j,k) -> (i,j,k)>, // A + affine_map<(i,j,k) -> (i,j,k)>, // B + affine_map<(i,j,k) -> (i,j,k)> // X (out) + ], + sparse = [ + [ "S", "S", "S" ], // A + [ "D", "D", "D" ], // B + [ "D", "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel", "parallel"], + doc = "X(i,j,k) = A(i,j,k) OP 
B(i,j,k)" +} + +// CHECK-LABEL: func @add_sss( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 32 : index +// CHECK: %[[VAL_4:.*]] = constant 16 : index +// CHECK: %[[VAL_5:.*]] = constant 8 : index +// CHECK: %[[VAL_6:.*]] = constant true +// CHECK: %[[VAL_7:.*]] = constant 0 : index +// CHECK: %[[VAL_8:.*]] = constant 1 : index +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_12:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_13:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_14:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_15:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_16:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_17:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_18:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref +// CHECK: %[[VAL_19:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_8]]] : memref +// CHECK: %[[VAL_20:.*]]:2 = scf.while (%[[VAL_21:.*]] = %[[VAL_18]], %[[VAL_22:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_23:.*]] = cmpi "ult", %[[VAL_21]], %[[VAL_19]] : index +// CHECK: scf.condition(%[[VAL_23]]) %[[VAL_21]], %[[VAL_22]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_24:.*]]: index, %[[VAL_25:.*]]: index): +// CHECK: %[[VAL_26:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_27:.*]] = cmpi "eq", %[[VAL_26]], %[[VAL_25]] : index +// CHECK: scf.if %[[VAL_27]] { +// CHECK: %[[VAL_28:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_29:.*]] = addi %[[VAL_24]], %[[VAL_8]] : index +// CHECK: %[[VAL_30:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref +// CHECK: %[[VAL_31:.*]]:2 = scf.while (%[[VAL_32:.*]] = %[[VAL_28]], %[[VAL_33:.*]] = %[[VAL_7]]) : 
(index, index) -> (index, index) { +// CHECK: %[[VAL_34:.*]] = cmpi "ult", %[[VAL_32]], %[[VAL_30]] : index +// CHECK: scf.condition(%[[VAL_34]]) %[[VAL_32]], %[[VAL_33]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_35:.*]]: index, %[[VAL_36:.*]]: index): +// CHECK: %[[VAL_37:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_35]]] : memref +// CHECK: %[[VAL_38:.*]] = cmpi "eq", %[[VAL_37]], %[[VAL_36]] : index +// CHECK: scf.if %[[VAL_38]] { +// CHECK: %[[VAL_39:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_35]]] : memref +// CHECK: %[[VAL_40:.*]] = addi %[[VAL_35]], %[[VAL_8]] : index +// CHECK: %[[VAL_41:.*]] = load %[[VAL_13]]{{\[}}%[[VAL_40]]] : memref +// CHECK: %[[VAL_42:.*]]:2 = scf.while (%[[VAL_43:.*]] = %[[VAL_39]], %[[VAL_44:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) { +// CHECK: %[[VAL_45:.*]] = cmpi "ult", %[[VAL_43]], %[[VAL_41]] : index +// CHECK: scf.condition(%[[VAL_45]]) %[[VAL_43]], %[[VAL_44]] : index, index +// CHECK: } do { +// CHECK: ^bb0(%[[VAL_46:.*]]: index, %[[VAL_47:.*]]: index): +// CHECK: %[[VAL_48:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_46]]] : memref +// CHECK: %[[VAL_49:.*]] = cmpi "eq", %[[VAL_48]], %[[VAL_47]] : index +// CHECK: scf.if %[[VAL_49]] { +// CHECK: %[[VAL_50:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_46]]] : memref +// CHECK: %[[VAL_51:.*]] = load %[[VAL_16]]{{\[}}%[[VAL_25]], %[[VAL_36]], %[[VAL_47]]] : memref<32x16x8xf32> +// CHECK: %[[VAL_52:.*]] = addf %[[VAL_50]], %[[VAL_51]] : f32 +// CHECK: store %[[VAL_52]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_36]], %[[VAL_47]]] : memref<32x16x8xf32> +// CHECK: } else { +// CHECK: scf.if %[[VAL_6]] { +// CHECK: %[[VAL_53:.*]] = load %[[VAL_16]]{{\[}}%[[VAL_25]], %[[VAL_36]], %[[VAL_47]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_53]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_36]], %[[VAL_47]]] : memref<32x16x8xf32> +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_54:.*]] = cmpi "eq", %[[VAL_48]], %[[VAL_47]] : index +// CHECK: %[[VAL_55:.*]] = addi %[[VAL_46]], %[[VAL_8]] : 
index +// CHECK: %[[VAL_56:.*]] = select %[[VAL_54]], %[[VAL_55]], %[[VAL_46]] : index +// CHECK: %[[VAL_57:.*]] = addi %[[VAL_47]], %[[VAL_8]] : index +// CHECK: scf.yield %[[VAL_56]], %[[VAL_57]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_58:.*]] = %[[VAL_59:.*]]#1 to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_60:.*]] = load %[[VAL_16]]{{\[}}%[[VAL_25]], %[[VAL_36]], %[[VAL_58]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_60]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_36]], %[[VAL_58]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } else { +// CHECK: scf.if %[[VAL_6]] { +// CHECK: scf.for %[[VAL_61:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_62:.*]] = load %[[VAL_16]]{{\[}}%[[VAL_25]], %[[VAL_36]], %[[VAL_61]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_62]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_36]], %[[VAL_61]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_63:.*]] = cmpi "eq", %[[VAL_37]], %[[VAL_36]] : index +// CHECK: %[[VAL_64:.*]] = addi %[[VAL_35]], %[[VAL_8]] : index +// CHECK: %[[VAL_65:.*]] = select %[[VAL_63]], %[[VAL_64]], %[[VAL_35]] : index +// CHECK: %[[VAL_66:.*]] = addi %[[VAL_36]], %[[VAL_8]] : index +// CHECK: scf.yield %[[VAL_65]], %[[VAL_66]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_67:.*]] = %[[VAL_68:.*]]#1 to %[[VAL_4]] step %[[VAL_8]] { +// CHECK: scf.for %[[VAL_69:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_70:.*]] = load %[[VAL_16]]{{\[}}%[[VAL_25]], %[[VAL_67]], %[[VAL_69]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_70]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_67]], %[[VAL_69]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } else { +// CHECK: scf.if %[[VAL_6]] { +// CHECK: scf.for %[[VAL_71:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] { +// CHECK: scf.for %[[VAL_72:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_73:.*]] = load %[[VAL_16]]{{\[}}%[[VAL_25]], 
%[[VAL_71]], %[[VAL_72]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_73]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_71]], %[[VAL_72]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } else { +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_74:.*]] = cmpi "eq", %[[VAL_26]], %[[VAL_25]] : index +// CHECK: %[[VAL_75:.*]] = addi %[[VAL_24]], %[[VAL_8]] : index +// CHECK: %[[VAL_76:.*]] = select %[[VAL_74]], %[[VAL_75]], %[[VAL_24]] : index +// CHECK: %[[VAL_77:.*]] = addi %[[VAL_25]], %[[VAL_8]] : index +// CHECK: scf.yield %[[VAL_76]], %[[VAL_77]] : index, index +// CHECK: } +// CHECK: scf.for %[[VAL_78:.*]] = %[[VAL_79:.*]]#1 to %[[VAL_3]] step %[[VAL_8]] { +// CHECK: scf.for %[[VAL_80:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] { +// CHECK: scf.for %[[VAL_81:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] { +// CHECK: %[[VAL_82:.*]] = load %[[VAL_16]]{{\[}}%[[VAL_78]], %[[VAL_80]], %[[VAL_81]]] : memref<32x16x8xf32> +// CHECK: store %[[VAL_82]], %[[VAL_17]]{{\[}}%[[VAL_78]], %[[VAL_80]], %[[VAL_81]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_83:.*]] = tensor_load %[[VAL_17]] : memref<32x16x8xf32> +// CHECK: return %[[VAL_83]] : tensor<32x16x8xf32> +// CHECK: } +func @add_sss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { + %0 = linalg.generic #trait_sss + ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) { + ^bb(%a: f32, %b: f32): + %0 = addf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16x8xf32> + return %0 : tensor<32x16x8xf32> +} + +// CHECK-LABEL: func @mul_sss( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> { +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_6:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_7:.*]] = 
alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc(%[[VAL_2]]) : memref +// CHECK: %[[VAL_12:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_13:.*]] = alloc() : memref<32x16x8xf32> +// CHECK: %[[VAL_14:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_15:.*]] = load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref +// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_4]] { +// CHECK: %[[VAL_17:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_16]]] : memref +// CHECK: %[[VAL_18:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref +// CHECK: %[[VAL_19:.*]] = addi %[[VAL_16]], %[[VAL_4]] : index +// CHECK: %[[VAL_20:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_19]]] : memref +// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_20]] step %[[VAL_4]] { +// CHECK: %[[VAL_22:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_21]]] : memref +// CHECK: %[[VAL_23:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref +// CHECK: %[[VAL_24:.*]] = addi %[[VAL_21]], %[[VAL_4]] : index +// CHECK: %[[VAL_25:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_24]]] : memref +// CHECK: scf.for %[[VAL_26:.*]] = %[[VAL_23]] to %[[VAL_25]] step %[[VAL_4]] { +// CHECK: %[[VAL_27:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_26]]] : memref +// CHECK: %[[VAL_28:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_26]]] : memref +// CHECK: %[[VAL_29:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_17]], %[[VAL_22]], %[[VAL_27]]] : memref<32x16x8xf32> +// CHECK: %[[VAL_30:.*]] = mulf %[[VAL_28]], %[[VAL_29]] : f32 +// CHECK: store %[[VAL_30]], %[[VAL_13]]{{\[}}%[[VAL_17]], %[[VAL_22]], %[[VAL_27]]] : memref<32x16x8xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_31:.*]] = tensor_load %[[VAL_13]] : memref<32x16x8xf32> +// CHECK: return %[[VAL_31]] : tensor<32x16x8xf32> +// CHECK: } +func @mul_sss(%arga: tensor<32x16x8xf32>, %argb: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> 
{ + %0 = linalg.generic #trait_sss + ins(%arga, %argb: tensor<32x16x8xf32>, tensor<32x16x8xf32>) { + ^bb(%a: f32, %b: f32): + %0 = mulf %a, %b : f32 + linalg.yield %0 : f32 + } -> tensor<32x16x8xf32> + return %0 : tensor<32x16x8xf32> +} + +#trait_kernel_3d = { + indexing_maps = [ + affine_map<(i,j,k,l) -> (i,k,l)>, // B + affine_map<(i,j,k,l) -> (k,j)>, // C + affine_map<(i,j,k,l) -> (l,j)>, // D + affine_map<(i,j,k,l) -> (i,j)> // A (out) + ], + sparse = [ + [ "D", "D", "S" ], // B + [ "D", "D" ], // C + [ "D", "D" ], // D + [ "D", "D" ] // A + ], + iterator_types = ["parallel", "parallel", "reduction", "reduction"], + doc = "A(i,j) = SUM_k,l B(i,k,l) * C(k,j) * D(l,j)" +} + +// CHECK-LABEL: func @kernel_3d( +// CHECK-SAME: %[[VAL_0:.*0]]: tensor, +// CHECK-SAME: %[[VAL_1:.*1]]: tensor, +// CHECK-SAME: %[[VAL_2:.*2]]: tensor, +// CHECK-SAME: %[[VAL_3:.*3]]: tensor) -> tensor { +// CHECK: %[[VAL_4:.*]] = constant 999 : index +// CHECK: %[[VAL_5:.*]] = constant 0 : index +// CHECK: %[[VAL_6:.*]] = constant 1 : index +// CHECK: %[[VAL_7:.*]] = dim %[[VAL_1]], %[[VAL_5]] : tensor +// CHECK: %[[VAL_8:.*]] = dim %[[VAL_1]], %[[VAL_6]] : tensor +// CHECK: %[[VAL_9:.*]] = alloc(%[[VAL_4]]) : memref +// CHECK: %[[VAL_10:.*]] = alloc(%[[VAL_4]]) : memref +// CHECK: %[[VAL_11:.*]] = alloc(%[[VAL_4]]) : memref +// CHECK: %[[VAL_12:.*]] = dim %[[VAL_2]], %[[VAL_5]] : tensor +// CHECK: %[[VAL_13:.*]] = dim %[[VAL_2]], %[[VAL_6]] : tensor +// CHECK: %[[VAL_14:.*]] = alloc(%[[VAL_12]], %[[VAL_13]]) : memref +// CHECK: %[[VAL_15:.*]] = dim %[[VAL_3]], %[[VAL_5]] : tensor +// CHECK: %[[VAL_16:.*]] = dim %[[VAL_3]], %[[VAL_6]] : tensor +// CHECK: %[[VAL_17:.*]] = alloc(%[[VAL_15]], %[[VAL_16]]) : memref +// CHECK: %[[VAL_18:.*]] = dim %[[VAL_0]], %[[VAL_5]] : tensor +// CHECK: %[[VAL_19:.*]] = dim %[[VAL_0]], %[[VAL_6]] : tensor +// CHECK: %[[VAL_20:.*]] = alloc(%[[VAL_18]], %[[VAL_19]]) : memref +// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_5]] to %[[VAL_7]] step %[[VAL_6]] { +// 
CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_5]] to %[[VAL_8]] step %[[VAL_6]] { +// CHECK: %[[VAL_23:.*]] = muli %[[VAL_8]], %[[VAL_21]] : index +// CHECK: %[[VAL_24:.*]] = addi %[[VAL_23]], %[[VAL_22]] : index +// CHECK: %[[VAL_25:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_26:.*]] = addi %[[VAL_24]], %[[VAL_6]] : index +// CHECK: %[[VAL_27:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_26]]] : memref +// CHECK: scf.for %[[VAL_28:.*]] = %[[VAL_25]] to %[[VAL_27]] step %[[VAL_6]] { +// CHECK: %[[VAL_29:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_28]]] : memref +// CHECK: scf.for %[[VAL_30:.*]] = %[[VAL_5]] to %[[VAL_13]] step %[[VAL_6]] { +// CHECK: %[[VAL_31:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_28]]] : memref +// CHECK: %[[VAL_32:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_22]], %[[VAL_30]]] : memref +// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_31]], %[[VAL_32]] : f32 +// CHECK: %[[VAL_34:.*]] = load %[[VAL_17]]{{\[}}%[[VAL_29]], %[[VAL_30]]] : memref +// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_33]], %[[VAL_34]] : f32 +// CHECK: %[[VAL_36:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_21]], %[[VAL_30]]] : memref +// CHECK: %[[VAL_37:.*]] = addf %[[VAL_35]], %[[VAL_36]] : f32 +// CHECK: store %[[VAL_37]], %[[VAL_20]]{{\[}}%[[VAL_21]], %[[VAL_30]]] : memref +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_38:.*]] = tensor_load %[[VAL_20]] : memref +// CHECK: return %[[VAL_38]] : tensor +// CHECK: } +func @kernel_3d(%arga: tensor, + %argb: tensor, + %argc: tensor, + %argd: tensor) -> tensor { + %0 = linalg.generic #trait_kernel_3d + ins(%argb, %argc, %argd : tensor, tensor, tensor) + init(%arga : tensor) { + ^bb(%b: f32, %c: f32, %d : f32, %a : f32): + %0 = mulf %b, %c : f32 + %1 = mulf %0, %d : f32 + %2 = addf %1, %a : f32 + linalg.yield %2 : f32 + } -> tensor + return %0 : tensor +} diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ 
-29,6 +29,7 @@ TestMemRefDependenceCheck.cpp TestMemRefStrideCalculation.cpp TestSCFUtils.cpp + TestSparsification.cpp TestVectorTransforms.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Transforms/TestSparsification.cpp b/mlir/test/lib/Transforms/TestSparsification.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestSparsification.cpp @@ -0,0 +1,45 @@ +//===- TestSparsification.cpp - Test sparsification of tensors ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { + +/// Test pass that greedily applies the linalg sparsification patterns. +struct TestSparsification + : public PassWrapper<TestSparsification, FunctionPass> { + // Generated sparse loops use scf.for/scf.while/scf.if, so load SCF. + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert<scf::SCFDialect>(); + } + void runOnFunction() override { + auto *ctx = &getContext(); + OwningRewritePatternList patterns; + linalg::populateSparsificationPatterns(ctx, patterns); + applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); + } +}; + +} // end anonymous namespace + +namespace mlir { +namespace test { + +/// Registers the test pass under the `test-sparsification` argument. +void registerTestSparsification() { + PassRegistration<TestSparsification> sparsificationPass( + "test-sparsification", + "Test automatic generation of sparse tensor code"); +} + +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -85,6 +85,7 @@ void registerTestPreparationPassWithAllowedMemrefResults(); void registerTestRecursiveTypesPass(); void registerTestSCFUtilsPass(); +void registerTestSparsification(); void registerTestVectorConversions(); } // namespace
test } // namespace mlir @@ -148,6 +149,7 @@ test::registerTestOpaqueLoc(); test::registerTestRecursiveTypesPass(); test::registerTestSCFUtilsPass(); + test::registerTestSparsification(); test::registerTestVectorConversions(); } #endif