diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms BufferizableOpInterfaceImpl.cpp + CodegenEnv.cpp CodegenUtils.cpp SparseBufferRewriting.cpp SparseTensorCodegen.cpp diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h @@ -0,0 +1,136 @@ +//===- CodegenEnv.h - Code generation environment class --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines the code generation environment class. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENENV_H_ +#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENENV_H_ + +#include "CodegenUtils.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" +#include "mlir/Dialect/SparseTensor/Utils/Merger.h" + +namespace mlir { +namespace sparse_tensor { + +/// The code generation environment class aggregates a number of data +/// structures that are needed during the code generation phase of +/// sparsification. This environment simplifies passing around such +/// data during sparsification (rather than passing around all the +/// individual components where needed). Furthermore, it provides +/// a number of delegate and convenience methods that keep some of the +/// implementation details transparent to sparsification. +class CodegenEnv { +public: + CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts, + unsigned numTensors, unsigned numLoops, unsigned numFilterLoops); + + // Start emitting. + void startEmit(SparseTensorLoopEmitter *le); + + // Delegate methods to merger. + TensorExp &exp(unsigned e) { return merger.exp(e); } + LatPoint &lat(unsigned l) { return merger.lat(l); } + SmallVector<unsigned> &set(unsigned s) { return merger.set(s); } + DimLevelType dimLevelType(unsigned t, unsigned i) const { + return merger.getDimLevelType(t, i); + } + DimLevelType dimLevelType(unsigned b) const { + return merger.getDimLevelType(b); + } + bool isFilterLoop(unsigned i) const { return merger.isFilterLoop(i); } + + // Delegate methods to loop emitter. + Value getLoopIV(unsigned i) const { return loopEmitter->getLoopIV(i); } + const std::vector<Value> &getValBuffer() const { + return loopEmitter->getValBuffer(); + } + + // Convenience method to slice topsort. + ArrayRef<unsigned> getTopSortSlice(size_t n, size_t m) const { + return ArrayRef<unsigned>(topSort).slice(n, m); + } + + // Convenience method to get current loop stack. + ArrayRef<unsigned> getLoopCurStack() const { + return getTopSortSlice(0, loopEmitter->getCurrentDepth()); + } + + // Convenience method to get the IV of the given loop index.
+ Value getLoopIdxValue(size_t loopIdx) const { + for (unsigned lv = 0, lve = topSort.size(); lv < lve; lv++) + if (topSort[lv] == loopIdx) + return getLoopIV(lv); + llvm_unreachable("invalid loop index"); + } + + // + // Reductions. + // + + void startReduc(unsigned exp, Value val); + void updateReduc(Value val); + bool isReduc() const { return redExp != -1u; } + Value getReduc() const { return redVal; } + Value endReduc(); + + void startCustomReduc(unsigned exp); + bool isCustomReduc() const { return redCustom != -1u; } + Value getCustomRedId(); + void endCustomReduc(); + +public: + // + // TODO: make this section private too, using similar refactoring as for reduc + // + + // Linalg operation. + linalg::GenericOp linalgOp; + + // Sparsification options. + SparsificationOptions options; + + // Topological sort. + std::vector<unsigned> topSort; + + // Merger helper class. + Merger merger; + + // Loop emitter helper class (keep reference in scope!). + // TODO: move emitter constructor up in time? + SparseTensorLoopEmitter *loopEmitter; + + // Sparse tensor as output. Implemented either through direct injective + // insertion in lexicographic index order or through access pattern expansion + // in the innermost loop nest (`expValues` through `expCount`). + OpOperand *sparseOut; + unsigned outerParNest; + Value insChain; // bookkeeping for insertion chain + Value expValues; + Value expFilled; + Value expAdded; + Value expCount; + +private: + // Bookkeeping for reductions (up-to-date value of the reduction, and indices + // into the merger's expression tree). When the indices of a tensor reduction + // expression are exhausted, all inner loops can use a scalarized reduction. + Value redVal; + unsigned redExp; + unsigned redCustom; +}; + +} // namespace sparse_tensor +} // namespace mlir + +#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENENV_H_ diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp @@ -0,0 +1,69 @@ +//===- CodegenEnv.cpp - Code generation environment class ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "CodegenEnv.h" + +using namespace mlir; +using namespace mlir::sparse_tensor; + +//===----------------------------------------------------------------------===// +// Code generation environment constructor and setup +//===----------------------------------------------------------------------===// + +CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts, + unsigned numTensors, unsigned numLoops, + unsigned numFilterLoops) + : linalgOp(linop), options(opts), topSort(), + merger(numTensors, numLoops, numFilterLoops), loopEmitter(nullptr), + sparseOut(nullptr), redVal(nullptr), redExp(-1u), redCustom(-1u) {} + +void CodegenEnv::startEmit(SparseTensorLoopEmitter *le) { + assert(!loopEmitter && "must only start emitting once"); + loopEmitter = le; + if (sparseOut) { + insChain = sparseOut->get(); + merger.setHasSparseOut(true); + } +} + +//===----------------------------------------------------------------------===// +// Code generation environment methods +//===----------------------------------------------------------------------===// + +void CodegenEnv::startReduc(unsigned exp, Value val) { + assert(redExp == -1u && exp != -1u); + redExp = exp; + updateReduc(val); +} + +void CodegenEnv::updateReduc(Value val) { + assert(redExp != -1u); + redVal = exp(redExp).val = val; +} + +Value CodegenEnv::endReduc() { + Value val = redVal; + updateReduc(Value()); + redExp = -1u; + return val; +} + +void CodegenEnv::startCustomReduc(unsigned exp) { + assert(redCustom == -1u && exp != -1u); + redCustom = exp; +} + +Value CodegenEnv::getCustomRedId() { + assert(redCustom != -1u); + return dyn_cast<ReduceOp>(exp(redCustom).op).getIdentity(); +} + +void CodegenEnv::endCustomReduc() { + assert(redCustom != -1u); + redCustom = -1u; +} diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "CodegenEnv.h" #include "CodegenUtils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -36,7 +37,7 @@ using namespace mlir::sparse_tensor; //===----------------------------------------------------------------------===// -// Declarations of data structures. +// Declarations //===----------------------------------------------------------------------===// namespace { @@ -49,100 +50,6 @@ kIncludeAll = 0x3 }; -/// Reduction kinds. -enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor, kCustom }; - -/// Code generation environment. This structure aggregates a number -/// of data structures needed during code generation. Such an environment -/// simplifies passing around data during sparsification (rather than -/// passing around all the individual compoments where needed). -// -// TODO: refactor further, move into own file -// -struct CodeGenEnv { - CodeGenEnv(linalg::GenericOp linop, SparsificationOptions opts, - unsigned numTensors, unsigned numLoops, unsigned numFilterLoops) - : linalgOp(linop), options(opts), topSort(), - merger(numTensors, numLoops, numFilterLoops), loopEmitter(nullptr), - redExp(-1u), redKind(kNoReduc), redCustom(-1u), sparseOut(nullptr) {} - - // Start emitting.
- void startEmit(SparseTensorLoopEmitter *le) { - assert(!loopEmitter && "must only start emitting once"); - loopEmitter = le; - if (sparseOut) { - insChain = sparseOut->get(); - merger.setHasSparseOut(true); - } - } - - // Delegate methods to merger. - TensorExp &exp(unsigned e) { return merger.exp(e); } - LatPoint &lat(unsigned l) { return merger.lat(l); } - SmallVector &set(unsigned s) { return merger.set(s); } - DimLevelType dimLevelType(unsigned t, unsigned i) const { - return merger.getDimLevelType(t, i); - } - DimLevelType dimLevelType(unsigned b) const { - return merger.getDimLevelType(b); - } - bool isFilterLoop(unsigned i) const { return merger.isFilterLoop(i); } - - // Delegate methods to loop emitter. - Value getLoopIV(unsigned i) const { return loopEmitter->getLoopIV(i); } - const std::vector &getValBuffer() const { - return loopEmitter->getValBuffer(); - } - - // Convenience method to slice topsort. - ArrayRef getTopSortSlice(size_t n, size_t m) const { - return ArrayRef(topSort).slice(n, m); - } - - // Convenience method to get current loop stack. - ArrayRef getLoopCurStack() const { - return getTopSortSlice(0, loopEmitter->getCurrentDepth()); - } - - // Convenience method to get the IV of the given loop index. - Value getLoopIdxValue(size_t loopIdx) const { - for (unsigned lv = 0, lve = topSort.size(); lv < lve; lv++) - if (topSort[lv] == loopIdx) - return getLoopIV(lv); - llvm_unreachable("invalid loop index"); - } - - // TODO: make private - - /// Linalg operation. - linalg::GenericOp linalgOp; - /// Sparsification options. - SparsificationOptions options; - // Topological sort. - std::vector topSort; - /// Merger helper class. - Merger merger; - /// Loop emitter helper class (keep reference in scope!). - /// TODO: move emitter constructor up in time? - SparseTensorLoopEmitter *loopEmitter; - /// Current reduction, updated during code generation. When indices of a - /// reduction are exhausted, all inner loops can use a scalarized reduction. - unsigned redExp; - Value redVal; - Reduction redKind; - unsigned redCustom; - /// Sparse tensor as output. Implemented either through direct injective - /// insertion in lexicographic index order or through access pattern expansion - /// in the innermost loop nest (`expValues` through `expCount`). - OpOperand *sparseOut; - unsigned outerParNest; - Value insChain; // bookkeeping for insertion chain - Value expValues; - Value expFilled; - Value expAdded; - Value expCount; -}; - /// A helper class that visits an affine expression and tries to find an /// AffineDimExpr to which the corresponding iterator from a GenericOp matches /// the desired iterator type. @@ -212,14 +119,14 @@ } /// Determines if affine expression is invariant. -static bool isInvariantAffine(CodeGenEnv &env, AffineExpr a, unsigned ldx, +static bool isInvariantAffine(CodegenEnv &env, AffineExpr a, unsigned ldx, bool &atLevel) { return isInvariantAffine(a, env.getLoopCurStack(), ldx, atLevel); } /// Helper method to construct a permuted dimension ordering /// that adheres to the given topological sort. -static AffineMap permute(CodeGenEnv &env, AffineMap m) { +static AffineMap permute(CodegenEnv &env, AffineMap m) { assert(m.getNumDims() + env.merger.getNumFilterLoops() == env.topSort.size() && "size mismatch"); @@ -346,7 +253,7 @@ /// Returns true if the sparse annotations and affine subscript /// expressions of all tensors are admissible. Returns false if /// no annotations are found or inadmissible constructs occur. 
-static bool findSparseAnnotations(CodeGenEnv &env) { +static bool findSparseAnnotations(CodegenEnv &env) { bool annotated = false; unsigned filterLdx = env.merger.getFilterLoopStartingIdx(); for (OpOperand &t : env.linalgOp->getOpOperands()) { @@ -371,7 +278,7 @@ /// as we use adj matrix for the graph. /// The sorted result will put the first Reduction iterator to the /// latest possible index. -static bool topSortOptimal(CodeGenEnv &env, unsigned n, +static bool topSortOptimal(CodegenEnv &env, unsigned n, ArrayRef iteratorTypes, std::vector &inDegree, std::vector> &adjM) { @@ -517,7 +424,7 @@ /// along fixed dimensions. Even for dense storage formats, however, the /// natural index order yields innermost unit-stride access with better /// spatial locality. -static bool computeIterationGraph(CodeGenEnv &env, unsigned mask, +static bool computeIterationGraph(CodegenEnv &env, unsigned mask, OpOperand *skip = nullptr) { // Set up an n x n from/to adjacency matrix of the iteration graph // for the implicit loop indices i_0 .. i_n-1. @@ -614,12 +521,12 @@ val.getDefiningOp(); } -/// Returns true when the tensor expression is admissible for env. +/// Returns true when the tensor expression is admissible for codegen. /// Since all sparse input tensors are admissible, we just need to check /// whether the out tensor in the tensor expression codegen is admissible. /// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective /// nesting depth when a "truly dynamic" sparse tensor output occurs. -static bool isAdmissibleTensorExp(CodeGenEnv &env, unsigned exp) { +static bool isAdmissibleTensorExp(CodegenEnv &env, unsigned exp) { // We reject any expression that makes a reduction from `-outTensor`, as those // expressions create a dependency between the current iteration (i) and the // previous iteration (i-1). It would require iterating over the whole @@ -693,48 +600,6 @@ return false; } -//===----------------------------------------------------------------------===// -// Sparse compiler synthesis methods (reductions). -//===----------------------------------------------------------------------===// - -/// Maps operation to reduction. -static Reduction getReduction(Kind kind) { - switch (kind) { - case Kind::kAddF: - case Kind::kAddC: - case Kind::kAddI: - case Kind::kSubF: - case Kind::kSubC: - case Kind::kSubI: - return kSum; - case Kind::kMulF: - case Kind::kMulC: - case Kind::kMulI: - return kProduct; - case Kind::kAndI: - return kAnd; - case Kind::kOrI: - return kOr; - case Kind::kXorI: - return kXor; - case Kind::kReduce: - return kCustom; - default: - llvm_unreachable("unexpected reduction operator"); - } -} - -/// Updates scalarized reduction value. -static void updateReduc(CodeGenEnv &env, Value reduc) { - assert(env.redKind != kNoReduc); - env.redVal = env.exp(env.redExp).val = reduc; -} - -/// Extracts identity from custom reduce. -static Value getCustomRedId(Operation *op) { - return dyn_cast(op).getIdentity(); -} - //===----------------------------------------------------------------------===// // Sparse compiler synthesis methods (statements and expressions). //===----------------------------------------------------------------------===// @@ -742,12 +607,12 @@ /// Generates loop boundary statements (entering/exiting loops). The function /// passes and updates the reduction value. 
static Optional genLoopBoundary( - CodeGenEnv &env, + CodegenEnv &env, function_ref(MutableArrayRef reduc)> callback) { SmallVector reduc; - if (env.redVal) - reduc.push_back(env.redVal); + if (env.isReduc()) + reduc.push_back(env.getReduc()); if (env.expValues) reduc.push_back(env.expCount); if (env.insChain) @@ -757,8 +622,8 @@ // Callback should do in-place update on reduction value vector. unsigned i = 0; - if (env.redVal) - updateReduc(env, reduc[i++]); + if (env.isReduc()) + env.updateReduc(reduc[i++]); if (env.expValues) env.expCount = reduc[i++]; if (env.insChain) @@ -768,7 +633,7 @@ } /// Local bufferization of all dense and sparse data structures. -static void genBuffers(CodeGenEnv &env, OpBuilder &builder) { +static void genBuffers(CodegenEnv &env, OpBuilder &builder) { linalg::GenericOp op = env.linalgOp; Location loc = op.getLoc(); assert(op.getNumOperands() == op.getNumDpsInputs() + 1); @@ -810,7 +675,7 @@ } /// Generates index for load/store on sparse tensor. -static Value genIndex(CodeGenEnv &env, OpOperand *t) { +static Value genIndex(CodegenEnv &env, OpOperand *t) { auto map = env.linalgOp.getMatchingIndexingMap(t); auto enc = getSparseTensorEncoding(t->get().getType()); AffineExpr a = map.getResult(toOrigDim(enc, map.getNumResults() - 1)); @@ -820,7 +685,7 @@ } /// Generates subscript for load/store on a dense or sparse tensor. -static Value genSubscript(CodeGenEnv &env, OpBuilder &builder, OpOperand *t, +static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t, SmallVectorImpl &args) { linalg::GenericOp op = env.linalgOp; unsigned tensor = t->getOperandNumber(); @@ -841,7 +706,7 @@ } /// Generates insertion code to implement dynamic tensor load. -static Value genInsertionLoad(CodeGenEnv &env, OpBuilder &builder, +static Value genInsertionLoad(CodegenEnv &env, OpBuilder &builder, OpOperand *t) { linalg::GenericOp op = env.linalgOp; Location loc = op.getLoc(); @@ -856,15 +721,14 @@ } /// Generates insertion code to implement dynamic tensor load for reduction. -static Value genInsertionLoadReduce(CodeGenEnv &env, OpBuilder &builder, +static Value genInsertionLoadReduce(CodegenEnv &env, OpBuilder &builder, OpOperand *t) { linalg::GenericOp op = env.linalgOp; Location loc = op.getLoc(); - Value identity = getCustomRedId(env.exp(env.redCustom).op); + Value identity = env.getCustomRedId(); // Direct lexicographic index order, tensor loads as identity. - if (!env.expValues) { + if (!env.expValues) return identity; - } // Load from expanded access pattern if filled, identity otherwise. Value index = genIndex(env, t); Value isFilled = builder.create(loc, env.expFilled, index); @@ -873,7 +737,7 @@ } /// Generates insertion code to implement dynamic tensor store. -static void genInsertionStore(CodeGenEnv &env, OpBuilder &builder, OpOperand *t, +static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t, Value rhs) { linalg::GenericOp op = env.linalgOp; Location loc = op.getLoc(); @@ -920,7 +784,7 @@ } /// Generates a load on a dense or sparse tensor. -static Value genTensorLoad(CodeGenEnv &env, OpBuilder &builder, unsigned exp) { +static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, unsigned exp) { // Test if the load was hoisted to a higher loop nest. 
Value val = env.exp(exp).val; if (val) @@ -930,7 +794,7 @@ linalg::GenericOp op = env.linalgOp; OpOperand &t = op->getOpOperand(env.exp(exp).tensor); if (&t == env.sparseOut) { - if (env.redCustom != -1u) + if (env.isCustomReduc()) return genInsertionLoadReduce(env, builder, &t); return genInsertionLoad(env, builder, &t); } @@ -941,13 +805,13 @@ } /// Generates a store on a dense or sparse tensor. -static void genTensorStore(CodeGenEnv &env, OpBuilder &builder, unsigned exp, +static void genTensorStore(CodegenEnv &env, OpBuilder &builder, unsigned exp, Value rhs) { linalg::GenericOp op = env.linalgOp; Location loc = op.getLoc(); // Test if this is a scalarized reduction. - if (env.redVal) { - updateReduc(env, rhs); + if (env.isReduc()) { + env.updateReduc(rhs); return; } // Store during insertion. @@ -989,12 +853,12 @@ } /// Generates an invariant value. -inline static Value genInvariantValue(CodeGenEnv &env, unsigned exp) { +inline static Value genInvariantValue(CodegenEnv &env, unsigned exp) { return env.exp(exp).val; } /// Generates an index value. -inline static Value genIndexValue(CodeGenEnv &env, unsigned idx) { +inline static Value genIndexValue(CodegenEnv &env, unsigned idx) { return env.getLoopIdxValue(idx); } @@ -1003,7 +867,7 @@ /// branch or otherwise invariantly defined outside the loop nest, with the /// exception of index computations, which need to be relinked to actual /// inlined cloned code. -static Value relinkBranch(CodeGenEnv &env, RewriterBase &rewriter, Block *block, +static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block, Value e, unsigned ldx) { if (Operation *def = e.getDefiningOp()) { if (auto indexOp = dyn_cast(def)) @@ -1018,7 +882,7 @@ } /// Recursively generates tensor expression. -static Value genExp(CodeGenEnv &env, RewriterBase &rewriter, unsigned exp, +static Value genExp(CodegenEnv &env, RewriterBase &rewriter, unsigned exp, unsigned ldx) { linalg::GenericOp op = env.linalgOp; Location loc = op.getLoc(); @@ -1032,11 +896,8 @@ if (env.exp(exp).kind == Kind::kIndex) return genIndexValue(env, env.exp(exp).index); - // Make custom reduction identity accessible for expanded access pattern. - if (env.exp(exp).kind == Kind::kReduce) { - assert(env.redCustom == -1u); - env.redCustom = exp; - } + if (env.exp(exp).kind == Kind::kReduce) + env.startCustomReduc(exp); // enter custom Value v0 = genExp(env, rewriter, env.exp(exp).children.e0, ldx); Value v1 = genExp(env, rewriter, env.exp(exp).children.e1, ldx); @@ -1048,20 +909,20 @@ env.exp(exp).kind == Kind::kSelect)) ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx); + if (env.exp(exp).kind == Kind::kReduce) + env.endCustomReduc(); // exit custom + if (env.exp(exp).kind == kSelect) { assert(!env.exp(exp).val); env.exp(exp).val = v0; // Preserve value for later use. - } else if (env.exp(exp).kind == Kind::kReduce) { - assert(env.redCustom != -1u); - env.redCustom = -1u; } return ee; } /// Hoists loop invariant tensor loads for which indices have been exhausted. -static void genInvariants(CodeGenEnv &env, OpBuilder &builder, unsigned exp, - unsigned ldx, bool atStart, unsigned last = -1u) { +static void genInvariants(CodegenEnv &env, OpBuilder &builder, unsigned exp, + unsigned ldx, bool atStart) { if (exp == -1u) return; if (env.exp(exp).kind == Kind::kTensor) { @@ -1090,18 +951,11 @@ if (lhs == &t) { // Start or end a scalarized reduction if (atStart) { - Kind kind = env.exp(last).kind; - Value load = kind == Kind::kReduce ? 
getCustomRedId(env.exp(last).op) - : genTensorLoad(env, builder, exp); - env.redKind = getReduction(kind); - env.redExp = exp; - updateReduc(env, load); + Value load = env.isCustomReduc() ? env.getCustomRedId() + : genTensorLoad(env, builder, exp); + env.startReduc(exp, load); } else { - Value redVal = env.redVal; - updateReduc(env, Value()); - env.redExp = -1u; - env.redKind = kNoReduc; - genTensorStore(env, builder, exp, redVal); + genTensorStore(env, builder, exp, env.endReduc()); } } else { // Start or end loop invariant hoisting of a tensor load. @@ -1112,21 +966,25 @@ // Traverse into the binary operations. Note that we only hoist // tensor loads, since subsequent MLIR/LLVM passes know how to // deal with all other kinds of derived loop invariants. + if (env.exp(exp).kind == Kind::kReduce) + env.startCustomReduc(exp); // enter custom unsigned e0 = env.exp(exp).children.e0; unsigned e1 = env.exp(exp).children.e1; - genInvariants(env, builder, e0, ldx, atStart, exp); - genInvariants(env, builder, e1, ldx, atStart, exp); + genInvariants(env, builder, e0, ldx, atStart); + genInvariants(env, builder, e1, ldx, atStart); + if (env.exp(exp).kind == Kind::kReduce) + env.endCustomReduc(); // exit custom } } /// Generates an expanded access pattern in innermost dimension. -static void genExpansion(CodeGenEnv &env, OpBuilder &builder, unsigned at, +static void genExpansion(CodegenEnv &env, OpBuilder &builder, unsigned at, bool atStart) { linalg::GenericOp op = env.linalgOp; OpOperand *lhs = env.sparseOut; if (!lhs || env.outerParNest != op.getRank(lhs) - 1 || at != env.outerParNest) return; // not needed at this level - assert(env.redVal == nullptr); + assert(!env.isReduc()); // Generate start or end of an expanded access pattern. Note that because // an expension does not rely on the ongoing contents of the sparse storage // scheme, we can use the original tensor as incoming SSA value (which @@ -1166,7 +1024,7 @@ /// Returns parallelization strategy. Any implicit loop in the Linalg /// operation that is marked "parallel" is a candidate. Whether it is actually /// converted to a parallel operation depends on the requested strategy. -static bool isParallelFor(CodeGenEnv &env, bool isOuter, bool isSparse) { +static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) { // Reject parallelization of sparse output. if (env.sparseOut) return false; @@ -1190,7 +1048,7 @@ } /// Generates a for-loop on a single index. -static Operation *genFor(CodeGenEnv &env, OpBuilder &builder, bool isOuter, +static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter, bool isInner, unsigned idx, size_t tid, size_t dim, ArrayRef extraTids, ArrayRef extraDims) { @@ -1222,13 +1080,14 @@ } /// Emit a while-loop for co-iteration over multiple indices. -static Operation *genWhile(CodeGenEnv &env, OpBuilder &builder, unsigned idx, +static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, unsigned idx, bool needsUniv, ArrayRef condTids, ArrayRef condDims, ArrayRef extraTids, ArrayRef extraDims) { Operation *loop = *genLoopBoundary(env, [&](MutableArrayRef reduc) { - // Construct the while-loop with a parameter for each index. + // Construct the while-loop with a parameter for each + // index. 
return env.loopEmitter->enterCoIterationOverTensorsAtDims( builder, env.linalgOp.getLoc(), condTids, condDims, needsUniv, reduc, extraTids, extraDims); @@ -1239,7 +1098,7 @@ /// Generates a for-loop or a while-loop, depending on whether it implements /// singleton iteration or co-iteration over the given conjunction. -static Operation *genLoop(CodeGenEnv &env, OpBuilder &builder, unsigned at, +static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, unsigned at, bool needsUniv, ArrayRef condTids, ArrayRef condDims, ArrayRef extraTids, ArrayRef extraDims) { @@ -1257,19 +1116,19 @@ } /// Generates the induction structure for a while-loop. -static void finalizeWhileOp(CodeGenEnv &env, OpBuilder &builder, unsigned idx, +static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, unsigned idx, bool needsUniv, BitVector &induction, scf::WhileOp whileOp) { Location loc = env.linalgOp.getLoc(); // Finalize each else branch of all if statements. - if (env.redVal || env.expValues || env.insChain) { + if (env.isReduc() || env.expValues || env.insChain) { while (auto ifOp = dyn_cast_or_null( builder.getInsertionBlock()->getParentOp())) { unsigned y = 0; SmallVector yields; - if (env.redVal) { - yields.push_back(env.redVal); - updateReduc(env, ifOp.getResult(y++)); + if (env.isReduc()) { + yields.push_back(env.getReduc()); + env.updateReduc(ifOp.getResult(y++)); } if (env.expValues) { yields.push_back(env.expCount); @@ -1288,7 +1147,7 @@ } /// Generates a single if-statement within a while-loop. -static scf::IfOp genIf(CodeGenEnv &env, OpBuilder &builder, unsigned idx, +static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, unsigned idx, BitVector &conditions) { Location loc = env.linalgOp.getLoc(); SmallVector types; @@ -1313,8 +1172,8 @@ } cond = cond ? builder.create(loc, cond, clause) : clause; } - if (env.redVal) - types.push_back(env.redVal.getType()); + if (env.isReduc()) + types.push_back(env.getReduc().getType()); if (env.expValues) types.push_back(builder.getIndexType()); if (env.insChain) @@ -1325,13 +1184,13 @@ } /// Generates end of true branch of if-statement within a while-loop. -static void endIf(CodeGenEnv &env, OpBuilder &builder, scf::IfOp ifOp, +static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp, Operation *loop, Value redInput, Value cntInput, Value insInput) { SmallVector operands; - if (env.redVal) { - operands.push_back(env.redVal); - updateReduc(env, redInput); + if (env.isReduc()) { + operands.push_back(env.getReduc()); + env.updateReduc(redInput); } if (env.expValues) { operands.push_back(env.expCount); @@ -1352,7 +1211,7 @@ /// Starts a loop sequence at given level. Returns true if /// the universal loop index must be maintained at this level. -static bool startLoopSeq(CodeGenEnv &env, OpBuilder &builder, unsigned exp, +static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp, unsigned at, unsigned idx, unsigned ldx, unsigned lts) { assert(!env.getLoopIdxValue(idx)); @@ -1394,7 +1253,7 @@ return false; } -static void genConstantDenseAddressFromLevel(CodeGenEnv &env, +static void genConstantDenseAddressFromLevel(CodegenEnv &env, OpBuilder &builder, unsigned tid, unsigned lvl) { // TODO: Handle affine expression on output tensor. 
@@ -1416,7 +1275,7 @@ } } -static void genInitConstantDenseAddress(CodeGenEnv &env, +static void genInitConstantDenseAddress(CodegenEnv &env, RewriterBase &rewriter) { // We can generate address for constant affine expression before any loops // starting from the first level as they do not depend on any thing. @@ -1427,7 +1286,7 @@ } static void translateBitsToTidDimPairs( - CodeGenEnv &env, unsigned li, unsigned idx, + CodegenEnv &env, unsigned li, unsigned idx, SmallVectorImpl &condTids, SmallVectorImpl &condDims, SmallVectorImpl &extraTids, SmallVectorImpl &extraDims, SmallVectorImpl &affineTids, SmallVectorImpl &affineDims, @@ -1513,7 +1372,7 @@ } /// Starts a single loop in current sequence. -static Operation *startLoop(CodeGenEnv &env, OpBuilder &builder, unsigned at, +static Operation *startLoop(CodegenEnv &env, OpBuilder &builder, unsigned at, unsigned li, bool needsUniv) { // The set of tensors + dims to generate loops on SmallVector condTids, condDims; @@ -1551,7 +1410,7 @@ } /// Ends a single loop in current sequence. Returns new values for needsUniv. -static bool endLoop(CodeGenEnv &env, RewriterBase &rewriter, Operation *loop, +static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop, unsigned idx, unsigned li, bool needsUniv) { // End a while-loop. if (auto whileOp = dyn_cast(loop)) { @@ -1569,7 +1428,7 @@ } /// Ends a loop sequence at given level. -static void endLoopSeq(CodeGenEnv &env, OpBuilder &builder, unsigned exp, +static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp, unsigned at, unsigned idx, unsigned ldx) { assert(env.getLoopIdxValue(idx) == nullptr); env.loopEmitter->exitCurrentLoopSeq(); @@ -1582,7 +1441,7 @@ /// Recursively generates code while computing iteration lattices in order /// to manage the complexity of implementing co-iteration over unions /// and intersections of sparse iterations spaces. -static void genStmt(CodeGenEnv &env, RewriterBase &rewriter, unsigned exp, +static void genStmt(CodegenEnv &env, RewriterBase &rewriter, unsigned exp, unsigned at) { // At each leaf, assign remaining tensor (sub)expression to output tensor. if (at == env.topSort.size()) { @@ -1612,7 +1471,7 @@ // Visit all lattices points with Li >= Lj to generate the // loop-body, possibly with if statements for coiteration. - Value redInput = env.redVal; + Value redInput = env.getReduc(); Value cntInput = env.expCount; Value insInput = env.insChain; bool isWhile = dyn_cast(loop) != nullptr; @@ -1640,7 +1499,7 @@ } /// Converts the result computed by the sparse kernel into the required form. -static void genResult(CodeGenEnv &env, RewriterBase &rewriter) { +static void genResult(CodegenEnv &env, RewriterBase &rewriter) { linalg::GenericOp op = env.linalgOp; OpOperand *lhs = op.getDpsInitOperand(0); Value tensor = lhs->get(); @@ -1683,7 +1542,7 @@ unsigned numTensors = op->getNumOperands(); unsigned numLoops = op.getNumLoops(); unsigned numFilterLoops = getNumCompoundAffineOnSparseDims(op); - CodeGenEnv env(op, options, numTensors, numLoops, numFilterLoops); + CodegenEnv env(op, options, numTensors, numLoops, numFilterLoops); // Detects sparse annotations and translates the per-dimension sparsity // information for all tensors to loop indices in the kernel. @@ -1744,7 +1603,7 @@ private: // Last resort cycle resolution. 
- LogicalResult resolveCycle(CodeGenEnv &env, PatternRewriter &rewriter) const { + LogicalResult resolveCycle(CodegenEnv &env, PatternRewriter &rewriter) const { // Compute topological sort while leaving out every // sparse input tensor in succession until an acyclic // iteration graph results.
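
For readers unfamiliar with the start/update/end bookkeeping style that the new CodegenEnv methods encapsulate, the following standalone sketch mimics the scalarized-reduction interface (startReduc/updateReduc/endReduc) with plain C++ types so it can be compiled and run on its own; the MiniEnv class and the int stand-in for mlir::Value are illustrative assumptions, not part of the MLIR sources.

// Minimal standalone analogue of the CodegenEnv reduction bookkeeping
// (illustrative only; an int stands in for mlir::Value, and std::optional
// replaces the -1u sentinel used for redExp in the real class).
#include <cassert>
#include <cstdio>
#include <optional>

class MiniEnv {
public:
  void startReduc(unsigned exp, int val) {
    assert(!redExp.has_value() && "reduction already started");
    redExp = exp;
    updateReduc(val);
  }
  void updateReduc(int val) {
    assert(redExp.has_value() && "no reduction in progress");
    redVal = val; // the real class also mirrors the value into the merger's expression
  }
  bool isReduc() const { return redExp.has_value(); }
  int getReduc() const { return redVal; }
  int endReduc() {
    int val = redVal; // hand the final value back to the caller
    redVal = 0;
    redExp.reset();
    return val;
  }

private:
  int redVal = 0;                 // up-to-date value of the reduction
  std::optional<unsigned> redExp; // index of the reduction expression, if any
};

int main() {
  MiniEnv env;
  env.startReduc(/*exp=*/0, /*val=*/0); // entering the reduction loop nest
  for (int i = 1; i <= 4; i++)          // loop body keeps the running value scalarized
    env.updateReduc(env.getReduc() + i);
  std::printf("reduction result: %d\n", env.endReduc()); // prints 10
  return 0;
}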