diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -783,9 +783,62 @@ //===----------------------------------------------------------------------===// // Support for sparse tensor code generation. +// +// The sparse compiler part of MLIR lowers a tensor expression formulated as a +// Linalg operation into a sequence of loops depending on what dimensions of the +// tensors are marked dense or sparse. The generated code distinguishes between: +// (1) for-loops that iterate over a single dense dimension, +// (2) for-loops that iterate over a single sparse dimension, +// (3) while-loops that co-iterate over several sparse dimensions. +// The for-loops may be subsequently optimized for parallel or vector execution. +// +// For more details, see the Dialect/Linalg/Transforms/Sparsification.cpp file. //===----------------------------------------------------------------------===// -void populateSparsificationPatterns(MLIRContext *context, - OwningRewritePatternList &patterns); + +/// Defines a parallelization strategy. Any implicit loop in the Linalg +/// operation that is marked "parallel" (thus not "reduction") is a candidate +/// for parallelization. The loop is made parallel if (1) allowed by the +/// strategy (e.g., AnyStorageOuterLoop considers either a dense or sparse +/// outermost loop only), and (2) the generated code is an actual for-loop +/// (and not a co-iterating while-loop). +enum class SparseParallelizationStrategy { + kNone, + kDenseOuterLoop, + kAnyStorageOuterLoop, + kDenseAnyLoop, + kAnyStorageAnyLoop + // TODO: support reduction parallelization too? +}; + +/// Defines a vectorization strategy. Any implicit inner loop in the Linalg +/// operation is a candidate (full SIMD for "parallel" loops and horizontal +/// SIMD for "reduction" loops). 
A loop is actually vectorized if (1) allowed +/// by the strategy, and (2) the emitted code is an actual for-loop (and not +/// a co-iterating while-loop). +enum class SparseVectorizationStrategy { + kNone, + kDenseInnerLoop, + kAnyStorageInnerLoop +}; + +/// Sparsification options. +struct SparsificationOptions { + SparsificationOptions(SparseParallelizationStrategy p, + SparseVectorizationStrategy v, unsigned vl) + : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl) { + } + SparsificationOptions() + : SparsificationOptions(SparseParallelizationStrategy::kNone, + SparseVectorizationStrategy::kNone, 1u) {} + SparseParallelizationStrategy parallelizationStrategy; + SparseVectorizationStrategy vectorizationStrategy; + unsigned vectorLength; +}; + +/// Set up sparsification rewriting rules with the given options. +void populateSparsificationPatterns( + MLIRContext *context, OwningRewritePatternList &patterns, + const SparsificationOptions &options = SparsificationOptions()); } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp @@ -235,22 +235,30 @@ // Code generation. struct CodeGen { - CodeGen(unsigned numTensors, unsigned numLoops) - : loops(numLoops), sizes(numLoops), buffers(numTensors), + CodeGen(linalg::SparsificationOptions o, unsigned numTensors, + unsigned numLoops) + : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors), pointers(numTensors, std::vector<Value>(numLoops)), indices(numTensors, std::vector<Value>(numLoops)), highs(numTensors, std::vector<Value>(numLoops)), pidxs(numTensors, std::vector<Value>(numLoops)), idxs(numTensors, std::vector<Value>(numLoops)) {} - // Universal dense indices and upper bounds (by index). + // Sparsification options. 
+ linalg::SparsificationOptions options; + // Universal dense indices and upper bounds (by index). The loops array + // is updated with the value of the universal dense index in the current + // loop. The sizes array is set once with the inferred dimension sizes. std::vector<Value> loops; std::vector<Value> sizes; // Buffers for storing dense and sparse numerical values (by tensor). + // This array is set once during bufferization of all tensors. std::vector<Value> buffers; // Sparse storage schemes (1-D): pointers and indices (by tensor and index). + // This array is set once during bufferization of all sparse tensors. std::vector<std::vector<Value>> pointers; std::vector<std::vector<Value>> indices; - // Sparse iteration information (by tensor and index). + // Sparse iteration information (by tensor and index). These arrays + // are updated to remain current within the current loop. std::vector<std::vector<Value>> highs; std::vector<std::vector<Value>> pidxs; std::vector<std::vector<Value>> idxs; @@ -388,7 +396,7 @@ unsigned exp, unsigned idx) { Kind kind = merger.exp(exp).kind; if (kind == Kind::kTensor || kind == Kind::kInvariant) { - // Either the index is really used in the tensor expression, or it it + // Either the index is really used in the tensor expression, or it is // set to the "non-existing dense index" in that dimension. Invariant // expressions borrow the output tensor indices. unsigned s = merger.addSet(); @@ -573,38 +581,81 @@ return needsUniv; } -/// Generates a for-loop or a while-loop, depending on whether it implements -/// singleton iteration or co-iteration over the given conjunction. -static void genLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, - linalg::GenericOp op, unsigned idx, bool needsUniv, - llvm::BitVector &indices, scf::ForOp &forOp, - scf::WhileOp &whileOp) { +/// Generates a for-loop on a single index. 
+static Operation *genFor(Merger &merger, CodeGen &codegen, + PatternRewriter &rewriter, linalg::GenericOp op, + bool isOuter, unsigned idx, llvm::BitVector &indices) { + unsigned fb = indices.find_first(); + unsigned tensor = merger.tensor(fb); + assert(idx == merger.index(fb)); + + // Parallelization strategy. Any implicit loop in the Linalg operation that + // is marked "parallel" is a candidate. Whether it is actually converted to + // a parallel operation depends on the requested strategy. + auto iteratorTypes = op.iterator_types().getValue(); + bool isSparse = merger.isSparseBit(fb); + bool isParallel = linalg::isParallelIteratorType(iteratorTypes[idx]); + switch (codegen.options.parallelizationStrategy) { + case linalg::SparseParallelizationStrategy::kNone: + isParallel = false; + break; + case linalg::SparseParallelizationStrategy::kDenseOuterLoop: + isParallel &= isOuter && !isSparse; + break; + case linalg::SparseParallelizationStrategy::kAnyStorageOuterLoop: + isParallel &= isOuter; + break; + case linalg::SparseParallelizationStrategy::kDenseAnyLoop: + isParallel &= !isSparse; + break; + case linalg::SparseParallelizationStrategy::kAnyStorageAnyLoop: + break; + } + + // Loop bounds and increment. Location loc = op.getLoc(); + Value lo; + Value hi; + Value step = rewriter.create<ConstantIndexOp>(loc, 1); + Value index; + if (isSparse) { + lo = codegen.pidxs[tensor][idx]; + hi = codegen.highs[tensor][idx]; + } else { + lo = codegen.loops[idx]; + hi = codegen.sizes[idx]; + } + + // Emit a parallel loop. + if (isParallel) { + scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step); + if (isSparse) + codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; + else + codegen.loops[idx] = parOp.getInductionVars()[0]; + rewriter.setInsertionPointToStart(parOp.getBody()); + return parOp; + } + + // Emit a sequential loop. 
+ scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step); + if (isSparse) + codegen.pidxs[tensor][idx] = forOp.getInductionVar(); + else + codegen.loops[idx] = forOp.getInductionVar(); + rewriter.setInsertionPointToStart(forOp.getBody()); + return forOp; +} - - // Emit a for-loop for a single index. - if (indices.count() == 1) { - unsigned fb = indices.find_first(); - unsigned tensor = merger.tensor(fb); - assert(idx == merger.index(fb)); - // Emit a sparse for-loop or a dense for-loop. - Value one = rewriter.create<ConstantIndexOp>(loc, 1); - if (merger.isSparseBit(fb)) { - forOp = rewriter.create<scf::ForOp>(loc, codegen.pidxs[tensor][idx], - codegen.highs[tensor][idx], one); - codegen.pidxs[tensor][idx] = forOp.getInductionVar(); - } else { - forOp = rewriter.create<scf::ForOp>(loc, codegen.loops[idx], - codegen.sizes[idx], one); - codegen.loops[idx] = forOp.getInductionVar(); - } - rewriter.setInsertionPointToStart(forOp.getBody()); - return; - } - - // Otherwise, emit a while-loop for co-iteration. - Type indexType = rewriter.getIndexType(); +/// Emit a while-loop for co-iteration over multiple indices. +static Operation *genWhile(Merger &merger, CodeGen &codegen, + PatternRewriter &rewriter, linalg::GenericOp op, + unsigned idx, bool needsUniv, + llvm::BitVector &indices) { SmallVector<Type, 4> types; SmallVector<Value, 4> operands; + // Construct the while-loop with a parameter for each index. 
+ Type indexType = rewriter.getIndexType(); for (unsigned b = 0, be = indices.size(); b < be; b++) { if (indices[b] && merger.isSparseBit(b)) { unsigned tensor = merger.tensor(b); @@ -617,9 +668,11 @@ types.push_back(indexType); operands.push_back(codegen.loops[idx]); } - whileOp = rewriter.create<scf::WhileOp>(loc, types, operands); + Location loc = op.getLoc(); + scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands); Block *before = rewriter.createBlock(&whileOp.before(), {}, types); Block *after = rewriter.createBlock(&whileOp.after(), {}, types); + // Build the "before" region, which effectively consists // of a conjunction of "i < upper" tests on all induction. rewriter.setInsertionPointToStart(&whileOp.before().front()); @@ -641,6 +694,18 @@ assert(o == operands.size()); rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments()); rewriter.setInsertionPointToStart(&whileOp.after().front()); + return whileOp; +} + +/// Generates a for-loop or a while-loop, depending on whether it implements +/// singleton iteration or co-iteration over the given conjunction. +static Operation *genLoop(Merger &merger, CodeGen &codegen, + PatternRewriter &rewriter, linalg::GenericOp op, + bool isOuter, unsigned idx, bool needsUniv, + llvm::BitVector &indices) { + if (indices.count() == 1) + return genFor(merger, codegen, rewriter, op, isOuter, idx, indices); + return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices); } /// Generates the local variables for this loop, consisting of the sparse @@ -804,16 +869,16 @@ LatPoint lati = merger.lat(li); // Emit loop. 
- scf::ForOp forOp; - scf::WhileOp whileOp; llvm::BitVector indices = lati.bits; optimizeIndices(merger, lsize, indices); - genLoop(merger, codegen, rewriter, op, idx, needsUniv, indices, forOp, - whileOp); + bool isOuter = at == 0; + Operation *loop = genLoop(merger, codegen, rewriter, op, isOuter, idx, + needsUniv, indices); genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, lati.bits); // Visit all lattices points with Li >= Lj to generate the // loop-body, possibly with if statements for coiteration. + bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr; scf::IfOp ifOp; for (unsigned lj : merger.set(lts)) { if (li == lj || merger.latGT(li, lj)) { @@ -823,22 +888,22 @@ if (merger.hasAnyOf(tmp, false)) continue; // dense exhausted within if/else // Recurse into body of each branch. - if (whileOp) + if (isWhile) genIf(merger, codegen, rewriter, op, idx, latj.bits, ifOp); genStmt(merger, codegen, rewriter, op, topSort, latj.exp, at + 1); } } // Wrap-up induction and restore insertion point. - if (forOp) { - needsUniv = false; - rewriter.setInsertionPointAfter(forOp); - } else { + if (isWhile) { + scf::WhileOp whileOp = cast<scf::WhileOp>(loop); rewriter.setInsertionPointToEnd(&whileOp.after().front()); genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv, lati.bits, whileOp.results()); - rewriter.setInsertionPointAfter(whileOp); + } else { + needsUniv = false; } + rewriter.setInsertionPointAfter(loop); } } @@ -846,7 +911,9 @@ /// Sparse rewriting rule for generic Lingalg operation. struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> { - using OpRewritePattern<linalg::GenericOp>::OpRewritePattern; +public: + GenericOpSparsifier(MLIRContext *context, linalg::SparsificationOptions o) + : OpRewritePattern<linalg::GenericOp>(context), options(o) {} LogicalResult matchAndRewrite(linalg::GenericOp op, PatternRewriter &rewriter) const override { @@ -878,7 +945,7 @@ return failure(); // build failure // Recursively generates code. 
- CodeGen codegen(numTensors, numLoops); + CodeGen codegen(options, numTensors, numLoops); genBuffers(merger, codegen, rewriter, op); genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0); Value result = @@ -886,13 +953,18 @@ rewriter.replaceOp(op, result); return success(); } + +private: + /// Options to control sparse code generation. + linalg::SparsificationOptions options; }; } // namespace /// Populates the given patterns list with rewriting rules required for /// the sparsification of linear algebra operations. -void mlir::linalg::populateSparsificationPatterns( - MLIRContext *context, OwningRewritePatternList &patterns) { - patterns.insert<GenericOpSparsifier>(context); +void linalg::populateSparsificationPatterns( + MLIRContext *context, OwningRewritePatternList &patterns, + const SparsificationOptions &options) { + patterns.insert<GenericOpSparsifier>(context, options); } diff --git a/mlir/test/Dialect/Linalg/sparse_parallel.mlir b/mlir/test/Dialect/Linalg/sparse_parallel.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/sparse_parallel.mlir @@ -0,0 +1,161 @@ +// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=0" | \ +// RUN: FileCheck %s --check-prefix=CHECK-PAR0 +// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=1" | \ +// RUN: FileCheck %s --check-prefix=CHECK-PAR1 +// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=2" | \ +// RUN: FileCheck %s --check-prefix=CHECK-PAR2 +// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=3" | \ +// RUN: FileCheck %s --check-prefix=CHECK-PAR3 +// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=4" | \ +// RUN: FileCheck %s --check-prefix=CHECK-PAR4 + +#trait_dd = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (i,j)> // X (out) + ], + sparse = [ + [ "D", "D" ], // A + [ "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel"], + doc = "X(i,j) = A(i,j) * SCALE" } + +// +// CHECK-PAR0-LABEL: func @scale_dd +// 
CHECK-PAR0: scf.for +// CHECK-PAR0: scf.for +// CHECK-PAR0: return +// +// CHECK-PAR1-LABEL: func @scale_dd +// CHECK-PAR1: scf.parallel +// CHECK-PAR1: scf.for +// CHECK-PAR1: return +// +// CHECK-PAR2-LABEL: func @scale_dd +// CHECK-PAR2: scf.parallel +// CHECK-PAR2: scf.for +// CHECK-PAR2: return +// +// CHECK-PAR3-LABEL: func @scale_dd +// CHECK-PAR3: scf.parallel +// CHECK-PAR3: scf.parallel +// CHECK-PAR3: return +// +// CHECK-PAR4-LABEL: func @scale_dd +// CHECK-PAR4: scf.parallel +// CHECK-PAR4: scf.parallel +// CHECK-PAR4: return +// +func @scale_dd(%scale: f32, %arga: tensor<?x?xf32>) -> tensor<?x?xf32> { + %0 = linalg.generic #trait_dd + ins(%arga: tensor<?x?xf32>) { + ^bb(%a: f32): + %0 = mulf %a, %scale : f32 + linalg.yield %0 : f32 + } -> tensor<?x?xf32> + return %0 : tensor<?x?xf32> +} + +#trait_ss = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (i,j)> // X (out) + ], + sparse = [ + [ "S", "S" ], // A + [ "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel"], + doc = "X(i,j) = A(i,j) * SCALE" } + +// +// CHECK-PAR0-LABEL: func @scale_ss +// CHECK-PAR0: scf.for +// CHECK-PAR0: scf.for +// CHECK-PAR0: return +// +// CHECK-PAR1-LABEL: func @scale_ss +// CHECK-PAR1: scf.for +// CHECK-PAR1: scf.for +// CHECK-PAR1: return +// +// CHECK-PAR2-LABEL: func @scale_ss +// CHECK-PAR2: scf.parallel +// CHECK-PAR2: scf.for +// CHECK-PAR2: return +// +// CHECK-PAR3-LABEL: func @scale_ss +// CHECK-PAR3: scf.for +// CHECK-PAR3: scf.for +// CHECK-PAR3: return +// +// CHECK-PAR4-LABEL: func @scale_ss +// CHECK-PAR4: scf.parallel +// CHECK-PAR4: scf.parallel +// CHECK-PAR4: return +// +func @scale_ss(%scale: f32, %arga: tensor<?x?xf32>) -> tensor<?x?xf32> { + %0 = linalg.generic #trait_ss + ins(%arga: tensor<?x?xf32>) { + ^bb(%a: f32): + %0 = mulf %a, %scale : f32 + linalg.yield %0 : f32 + } -> tensor<?x?xf32> + return %0 : tensor<?x?xf32> +} + +#trait_matvec = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (j)>, // b + affine_map<(i,j) -> (i)> // x (out) + ], + sparse = [ + [ "D", 
"S" ], // A + [ "D" ], // b + [ "D" ] // x + ], + iterator_types = ["parallel", "reduction"], + doc = "x(i) += A(i,j) * b(j)" } + +// +// CHECK-PAR0-LABEL: func @matvec +// CHECK-PAR0: scf.for +// CHECK-PAR0: scf.for +// CHECK-PAR0: return +// +// CHECK-PAR1-LABEL: func @matvec +// CHECK-PAR1: scf.parallel +// CHECK-PAR1: scf.for +// CHECK-PAR1: return +// +// CHECK-PAR2-LABEL: func @matvec +// CHECK-PAR2: scf.parallel +// CHECK-PAR2: scf.for +// CHECK-PAR2: return +// +// CHECK-PAR3-LABEL: func @matvec +// CHECK-PAR3: scf.parallel +// CHECK-PAR3: scf.for +// CHECK-PAR3: return +// +// CHECK-PAR4-LABEL: func @matvec +// CHECK-PAR4: scf.parallel +// CHECK-PAR4: scf.for +// CHECK-PAR4: return +// +func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf32>) -> tensor<16xf32> { + %0 = linalg.generic #trait_matvec + ins(%argA, %argb : tensor<16x32xf32>, tensor<32xf32>) + init(%argx : tensor<16xf32>) { + ^bb(%A: f32, %b: f32, %x: f32): + %0 = mulf %A, %b : f32 + %1 = addf %0, %x : f32 + linalg.yield %1 : f32 + } -> tensor<16xf32> + return %0 : tensor<16xf32> +} diff --git a/mlir/test/lib/Transforms/TestSparsification.cpp b/mlir/test/lib/Transforms/TestSparsification.cpp --- a/mlir/test/lib/Transforms/TestSparsification.cpp +++ b/mlir/test/lib/Transforms/TestSparsification.cpp @@ -16,13 +16,63 @@ struct TestSparsification : public PassWrapper<TestSparsification, FunctionPass> { + + TestSparsification() = default; + TestSparsification(const TestSparsification &pass) {} + + Option<int32_t> parallelization{ + *this, "parallelization-strategy", + llvm::cl::desc("Set the parallelization strategy"), llvm::cl::init(0)}; + + Option<int32_t> vectorization{ + *this, "vectorization-strategy", + llvm::cl::desc("Set the vectorization strategy"), llvm::cl::init(0)}; + + Option<int32_t> vectorLength{ + *this, "vl", llvm::cl::desc("Set the vector length"), llvm::cl::init(1)}; + + /// Registers all dialects required by testing. 
void getDependentDialects(DialectRegistry &registry) const override { - registry.insert<scf::SCFDialect>(); + registry.insert<scf::SCFDialect, vector::VectorDialect>(); } + + /// Returns parallelization strategy given on command line. + linalg::SparseParallelizationStrategy parallelOption() { + switch (parallelization) { + default: + return linalg::SparseParallelizationStrategy::kNone; + case 1: + return linalg::SparseParallelizationStrategy::kDenseOuterLoop; + case 2: + return linalg::SparseParallelizationStrategy::kAnyStorageOuterLoop; + case 3: + return linalg::SparseParallelizationStrategy::kDenseAnyLoop; + case 4: + return linalg::SparseParallelizationStrategy::kAnyStorageAnyLoop; + } + } + + /// Returns vectorization strategy given on command line. + linalg::SparseVectorizationStrategy vectorOption() { + switch (vectorization) { + default: + return linalg::SparseVectorizationStrategy::kNone; + case 1: + return linalg::SparseVectorizationStrategy::kDenseInnerLoop; + case 2: + return linalg::SparseVectorizationStrategy::kAnyStorageInnerLoop; + } + } + + /// Runs the test on a function. void runOnFunction() override { auto *ctx = &getContext(); OwningRewritePatternList patterns; - linalg::populateSparsificationPatterns(ctx, patterns); + // Translate strategy flags to strategy options. + linalg::SparsificationOptions options(parallelOption(), vectorOption(), + vectorLength); + // Apply rewriting. + linalg::populateSparsificationPatterns(ctx, patterns, options); applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } };