diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -778,8 +778,45 @@ //===----------------------------------------------------------------------===// // Support for sparse tensor code generation. //===----------------------------------------------------------------------===// -void populateSparsificationPatterns(MLIRContext *context, - OwningRewritePatternList &patterns); + +/// Defines a parallelization strategy for the sparse compiler. Any implicit loop +/// in the Linalg operation that is marked "parallel" (viz. no "reduction") is +/// a candidate for parallelization. The loop is made parallel if (1) allowed by +/// the strategy (e.g. AnyStorageOuterLoop considers either a dense or sparse +/// outermost loop only), and (2) the emitted code is an actual for-loop (and +/// not a co-iterating while-loop). +enum class SparseParallelizationStrategy { + kNone, + kDenseOuterLoop, + kAnyStorageOuterLoop, + kDenseAnyLoop, + kAnyStorageAnyLoop +}; + +/// Defines a vectorization strategy for the sparse compiler. Any implicit inner +/// loop in the Linalg operation is a candidate (full SIMD for "parallel" loops and +/// horizontal SIMD for "reduction" loops). A loop is actually vectorized if +/// (1) allowed by the strategy, and (2) the emitted code is an actual +/// for-loop (and not a co-iterating while-loop). +enum class SparseVectorizationStrategy { + kNone, + kDenseInnerLoop, + kAnyStorageInnerLoop +}; + +/// Sparsification options. +struct SparsificationOptions { + SparsificationOptions() + : parallelizationStrategy(SparseParallelizationStrategy::kNone), + vectorizationStrategy(SparseVectorizationStrategy::kNone) {} + SparseParallelizationStrategy parallelizationStrategy; + SparseVectorizationStrategy vectorizationStrategy; +}; + +/// Set up sparsification rewriting rules with the given options. +void populateSparsificationPatterns( + MLIRContext *context, OwningRewritePatternList &patterns, + const SparsificationOptions &options = SparsificationOptions()); } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp @@ -54,13 +54,16 @@ enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI }; /// Tensor expression. Represents an MLIR expression in tensor index notation. -/// For tensors and invariants, e0 denotes the tensor index. For all binary -/// operations, e0 and e1 denote the index of the children tensor expressions. +/// For tensors, e0 denotes the tensor index. For invariants, the IR value is +/// stored directly. For binary operations, e0 and e1 denote the index of the +/// children tensor expressions. struct TensorExp { - TensorExp(Kind k, unsigned x, unsigned y) : kind(k), e0(x), e1(y) {} + TensorExp(Kind k, unsigned x, unsigned y, Value v) + : kind(k), e0(x), e1(y), val(v) {} Kind kind; unsigned e0; unsigned e1; + Value val; }; /// Lattice point. Each lattice point consists of a conjunction of tensor @@ -85,11 +88,12 @@ : numTensors(t), numLoops(l), isSparse(t, std::vector(l, false)) {} /// Adds a tensor expression. Returns its index.
- unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u) { + unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) { unsigned e = tensorExps.size(); - tensorExps.push_back(TensorExp(k, e0, e1)); + tensorExps.push_back(TensorExp(k, e0, e1, v)); return e; } + unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); } /// Adds an iteration lattice point. Returns its index. unsigned addLat(unsigned t, unsigned i, unsigned e) { @@ -231,13 +235,16 @@ // Code generation. struct CodeGen { - CodeGen(unsigned numTensors, unsigned numLoops) - : loops(numLoops), sizes(numLoops), buffers(numTensors), + CodeGen(linalg::SparsificationOptions o, unsigned numTensors, + unsigned numLoops) + : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors), pointers(numTensors, std::vector(numLoops)), indices(numTensors, std::vector(numLoops)), highs(numTensors, std::vector(numLoops)), pidxs(numTensors, std::vector(numLoops)), idxs(numTensors, std::vector(numLoops)) {} + // Sparsification options. + linalg::SparsificationOptions options; // Universal dense indices and upper bounds (by index). std::vector loops; std::vector sizes; @@ -339,7 +346,6 @@ /// building (compared to using the SSA representation everywhere). static Optional buildTensorExp(Merger &merger, linalg::GenericOp op, Value val) { - Operation *def = val.getDefiningOp(); if (auto arg = val.dyn_cast()) { unsigned argN = arg.getArgNumber(); if (arg.getOwner()->getParentOp() == op) { @@ -348,10 +354,16 @@ auto map = op.getIndexingMap(argN); if (map.isProjectedPermutation()) return merger.addExp(Kind::kTensor, argN); - } else { - // Any parameter of a higher op is invariant in the tensor expression. - return merger.addExp(Kind::kInvariant, argN); + // Cannot handle (yet). + return None; } + // Any parameter of a higher op is invariant. + return merger.addExp(Kind::kInvariant, val); + } + Operation *def = val.getDefiningOp(); + if (def->getBlock() != &op.region().front()) { + // Something defined outside is invariant. + return merger.addExp(Kind::kInvariant, val); } else if (def->getNumOperands() == 2) { // Construct binary operations if subexpressions could be built. auto x = buildTensorExp(merger, op, def->getOperand(0)); @@ -380,9 +392,12 @@ Kind kind = merger.exp(exp).kind; if (kind == Kind::kTensor || kind == Kind::kInvariant) { // Either the index is really used in the tensor expression, or it is - // set to the "non-existing dense index" in that dimension. + // set to the "non-existing dense index" in that dimension. Invariant + // expressions borrow the output tensor indices. unsigned s = merger.addSet(); - merger.set(s).push_back(merger.addLat(merger.exp(exp).e0, idx, exp)); + unsigned t = kind == Kind::kTensor ?
merger.exp(exp).e0 + : op.getNumInputsAndOutputs() - 1; + merger.set(s).push_back(merger.addLat(t, idx, exp)); return s; } unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx); @@ -502,7 +517,7 @@ if (merger.exp(exp).kind == Kind::kTensor) return genTensorLoad(merger, codegen, rewriter, op, merger.exp(exp).e0); else if (merger.exp(exp).kind == Kind::kInvariant) - return op.getParentRegion()->front().getArgument(merger.exp(exp).e0); + return merger.exp(exp).val; Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0); Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1); switch (merger.exp(exp).kind) { @@ -561,38 +576,81 @@ return needsUniv; } -/// Generates a for-loop or a while-loop, depending on whether it implements -/// singleton iteration or co-iteration over the given conjunction. -static void genLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, - linalg::GenericOp op, unsigned idx, bool needsUniv, - llvm::BitVector &indices, scf::ForOp &forOp, - scf::WhileOp &whileOp) { +/// Generates a for-loop on a single index. +static Operation *genFor(Merger &merger, CodeGen &codegen, + PatternRewriter &rewriter, linalg::GenericOp op, + bool isOuter, unsigned idx, llvm::BitVector &indices) { + unsigned fb = indices.find_first(); + unsigned tensor = merger.tensor(fb); + assert(idx == merger.index(fb)); + + // Parallelization strategy. Any implicit loop in the Linalg operation that + // is marked "parallel" is a candidate. Whether it is actually converted to + // a parallel operation depends on the requested strategy. + auto iteratorTypes = op.iterator_types().getValue(); + bool isSparse = merger.isSparseBit(fb); + bool isParallel = linalg::isParallelIteratorType(iteratorTypes[idx]); + switch (codegen.options.parallelizationStrategy) { + case linalg::SparseParallelizationStrategy::kNone: + isParallel = false; + break; + case linalg::SparseParallelizationStrategy::kDenseOuterLoop: + isParallel &= isOuter && !isSparse; + break; + case linalg::SparseParallelizationStrategy::kAnyStorageOuterLoop: + isParallel &= isOuter; + break; + case linalg::SparseParallelizationStrategy::kDenseAnyLoop: + isParallel &= !isSparse; + break; + case linalg::SparseParallelizationStrategy::kAnyStorageAnyLoop: + break; + } + + // Loop bounds and increment. Location loc = op.getLoc(); + Value lo; + Value hi; + Value step = rewriter.create(loc, 1); + Value index; + if (isSparse) { + lo = codegen.pidxs[tensor][idx]; + hi = codegen.highs[tensor][idx]; + } else { + lo = codegen.loops[idx]; + hi = codegen.sizes[idx]; + } + + // Emit a parallel loop. + if (isParallel) { + scf::ParallelOp parOp = rewriter.create(loc, lo, hi, step); + if (isSparse) + codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; + else + codegen.loops[idx] = parOp.getInductionVars()[0]; + rewriter.setInsertionPointToStart(parOp.getBody()); + return parOp; + } + + // Emit a sequential loop. + scf::ForOp forOp = rewriter.create(loc, lo, hi, step); + if (isSparse) + codegen.pidxs[tensor][idx] = forOp.getInductionVar(); + else + codegen.loops[idx] = forOp.getInductionVar(); + rewriter.setInsertionPointToStart(forOp.getBody()); + return forOp; +} - // Emit a for-loop for a single index. - if (indices.count() == 1) { - unsigned fb = indices.find_first(); - unsigned tensor = merger.tensor(fb); - assert(idx == merger.index(fb)); - // Emit a sparse for-loop or a dense for-loop.
- Value one = rewriter.create(loc, 1); - if (merger.isSparseBit(fb)) { - forOp = rewriter.create(loc, codegen.pidxs[tensor][idx], - codegen.highs[tensor][idx], one); - codegen.pidxs[tensor][idx] = forOp.getInductionVar(); - } else { - forOp = rewriter.create(loc, codegen.loops[idx], - codegen.sizes[idx], one); - codegen.loops[idx] = forOp.getInductionVar(); - } - rewriter.setInsertionPointToStart(forOp.getBody()); - return; - } - - // Otherwise, emit a while-loop for co-iteration. +/// Emit a while-loop for co-iteration over multiple indices. +static Operation *genWhile(Merger &merger, CodeGen &codegen, + PatternRewriter &rewriter, linalg::GenericOp op, + unsigned idx, bool needsUniv, + llvm::BitVector &indices) { Type indexType = rewriter.getIndexType(); SmallVector types; SmallVector operands; + // Construct the while-loop with a parameter for each index. for (unsigned b = 0, be = indices.size(); b < be; b++) { if (indices[b] && merger.isSparseBit(b)) { unsigned tensor = merger.tensor(b); @@ -605,9 +663,11 @@ types.push_back(indexType); operands.push_back(codegen.loops[idx]); } - whileOp = rewriter.create(loc, types, operands); + Location loc = op.getLoc(); + scf::WhileOp whileOp = rewriter.create(loc, types, operands); Block *before = rewriter.createBlock(&whileOp.before(), {}, types); Block *after = rewriter.createBlock(&whileOp.after(), {}, types); + // Build the "before" region, which effectively consists // of a conjunction of "i < upper" tests on all induction. rewriter.setInsertionPointToStart(&whileOp.before().front()); @@ -629,6 +689,18 @@ assert(o == operands.size()); rewriter.create(loc, cond, before->getArguments()); rewriter.setInsertionPointToStart(&whileOp.after().front()); + return whileOp; +} + +/// Generates a for-loop or a while-loop, depending on whether it implements +/// singleton iteration or co-iteration over the given conjunction. +static Operation *genLoop(Merger &merger, CodeGen &codegen, + PatternRewriter &rewriter, linalg::GenericOp op, + bool isOuter, unsigned idx, bool needsUniv, + llvm::BitVector &indices) { + if (indices.count() == 1) + return genFor(merger, codegen, rewriter, op, isOuter, idx, indices); + return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices); } /// Generates the local variables for this loop, consisting of the sparse @@ -792,16 +864,16 @@ LatPoint lati = merger.lat(li); // Emit loop. - scf::ForOp forOp; - scf::WhileOp whileOp; llvm::BitVector indices = lati.bits; optimizeIndices(merger, lsize, indices); - genLoop(merger, codegen, rewriter, op, idx, needsUniv, indices, forOp, - whileOp); + bool isOuter = at == 0; + Operation *loop = genLoop(merger, codegen, rewriter, op, isOuter, idx, + needsUniv, indices); genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, lati.bits); // Visit all lattices points with Li >= Lj to generate the // loop-body, possibly with if statements for coiteration. + bool isWhile = dyn_cast(loop) != nullptr; scf::IfOp ifOp; for (unsigned lj : merger.set(lts)) { if (li == lj || merger.latGT(li, lj)) { @@ -811,22 +883,22 @@ if (merger.hasAnyOf(tmp, false)) continue; // dense exhausted within if/else // Recurse into body of each branch. - if (whileOp) + if (isWhile) genIf(merger, codegen, rewriter, op, idx, latj.bits, ifOp); genStmt(merger, codegen, rewriter, op, topSort, latj.exp, at + 1); } } // Wrap-up induction and restore insertion point. 
- if (forOp) { - needsUniv = false; - rewriter.setInsertionPointAfter(forOp); - } else { + if (isWhile) { + scf::WhileOp whileOp = cast(loop); rewriter.setInsertionPointToEnd(&whileOp.after().front()); genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv, lati.bits, whileOp.results()); - rewriter.setInsertionPointAfter(whileOp); + } else { + needsUniv = false; } + rewriter.setInsertionPointAfter(loop); } } @@ -834,7 +906,9 @@ /// Sparse rewriting rule for generic Lingalg operation. struct GenericOpSparsifier : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +public: + GenericOpSparsifier(MLIRContext *context, linalg::SparsificationOptions o) + : OpRewritePattern(context), options(o) {} LogicalResult matchAndRewrite(linalg::GenericOp op, PatternRewriter &rewriter) const override { @@ -866,7 +940,7 @@ return failure(); // build failure // Recursively generates code. - CodeGen codegen(numTensors, numLoops); + CodeGen codegen(options, numTensors, numLoops); genBuffers(merger, codegen, rewriter, op); genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0); Value result = @@ -874,13 +948,18 @@ rewriter.replaceOp(op, result); return success(); } + +private: + /// Options to control sparse code generation. + linalg::SparsificationOptions options; }; } // namespace /// Populates the given patterns list with rewriting rules required for /// the sparsification of linear algebra operations. -void mlir::linalg::populateSparsificationPatterns( - MLIRContext *context, OwningRewritePatternList &patterns) { - patterns.insert(context); +void linalg::populateSparsificationPatterns( + MLIRContext *context, OwningRewritePatternList &patterns, + const SparsificationOptions &options) { + patterns.insert(context, options); } diff --git a/mlir/test/Dialect/Linalg/sparse_2d.mlir b/mlir/test/Dialect/Linalg/sparse_2d.mlir --- a/mlir/test/Dialect/Linalg/sparse_2d.mlir +++ b/mlir/test/Dialect/Linalg/sparse_2d.mlir @@ -1106,6 +1106,56 @@ return %0 : tensor } +#trait_scale = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (i,j)> // X (out) + ], + sparse = [ + [ "D", "S" ], // A + [ "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel"], + doc = "X(i,j) = A(i,j) * SCALE" +} + +// CHECK-LABEL: func @scale( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = constant 2.000000e+00 : f64 +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = dim %[[VAL_0]], %[[VAL_3]] : tensor +// CHECK: %[[VAL_6:.*]] = alloca(%[[VAL_2]]) : memref +// CHECK: %[[VAL_7:.*]] = alloca(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = dim %[[VAL_0]], %[[VAL_4]] : tensor +// CHECK: %[[VAL_9:.*]] = alloca(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloca(%[[VAL_5]], %[[VAL_8]]) : memref +// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] { +// CHECK: %[[VAL_12:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref +// CHECK: %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_4]] : index +// CHECK: %[[VAL_14:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref +// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_4]] { +// CHECK: %[[VAL_16:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref +// CHECK: %[[VAL_17:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref +// CHECK: %[[VAL_18:.*]] = mulf %[[VAL_17]], %[[VAL_1]] : f64 +// CHECK: store %[[VAL_18]], %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_16]]] : 
memref +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_19:.*]] = tensor_load %[[VAL_10]] : memref +// CHECK: return %[[VAL_19]] : tensor +// CHECK: } +func @scale(%arga: tensor) -> tensor { + %0 = constant 2.0 : f64 + %1 = linalg.generic #trait_scale + ins(%arga: tensor) { + ^bb(%a: f64): + %2 = mulf %a, %0 : f64 + linalg.yield %2 : f64 + } -> tensor + return %1 : tensor +} + #trait_sampled_dense_dense = { indexing_maps = [ affine_map<(i,j,k) -> (i,j)>, // S diff --git a/mlir/test/Dialect/Linalg/sparse_parallel.mlir b/mlir/test/Dialect/Linalg/sparse_parallel.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/sparse_parallel.mlir @@ -0,0 +1,161 @@ +// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=0" | \ +// RUN: FileCheck %s --check-prefix=CHECK-PAR0 +// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=1" | \ +// RUN: FileCheck %s --check-prefix=CHECK-PAR1 +// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=2" | \ +// RUN: FileCheck %s --check-prefix=CHECK-PAR2 +// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=3" | \ +// RUN: FileCheck %s --check-prefix=CHECK-PAR3 +// RUN: mlir-opt %s -test-sparsification="parallelization-strategy=4" | \ +// RUN: FileCheck %s --check-prefix=CHECK-PAR4 + +#trait_dd = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (i,j)> // X (out) + ], + sparse = [ + [ "D", "D" ], // A + [ "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel"], + doc = "X(i,j) = A(i,j) * SCALE" +} + +// +// CHECK-PAR0-LABEL: func @scale_dd +// CHECK-PAR0: scf.for +// CHECK-PAR0: scf.for +// CHECK-PAR0: return +// +// CHECK-PAR1-LABEL: func @scale_dd +// CHECK-PAR1: scf.parallel +// CHECK-PAR1: scf.for +// CHECK-PAR1: return +// +// CHECK-PAR2-LABEL: func @scale_dd +// CHECK-PAR2: scf.parallel +// CHECK-PAR2: scf.for +// CHECK-PAR2: return +// +// CHECK-PAR3-LABEL: func @scale_dd +// CHECK-PAR3: scf.parallel +// CHECK-PAR3: scf.parallel +// CHECK-PAR3: return +// +// CHECK-PAR4-LABEL: func @scale_dd +// CHECK-PAR4: scf.parallel +// CHECK-PAR4: scf.parallel +// CHECK-PAR4: return +// +func @scale_dd(%scale: f32, %arga: tensor) -> tensor { + %0 = linalg.generic #trait_dd + ins(%arga: tensor) { + ^bb(%a: f32): + %0 = mulf %a, %scale : f32 + linalg.yield %0 : f32 + } -> tensor + return %0 : tensor +} + +#trait_ss = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (i,j)> // X (out) + ], + sparse = [ + [ "S", "S" ], // A + [ "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel"], + doc = "X(i,j) = A(i,j) * SCALE" +} + +// +// CHECK-PAR0-LABEL: func @scale_ss +// CHECK-PAR0: scf.for +// CHECK-PAR0: scf.for +// CHECK-PAR0: return +// +// CHECK-PAR1-LABEL: func @scale_ss +// CHECK-PAR1: scf.for +// CHECK-PAR1: scf.for +// CHECK-PAR1: return +// +// CHECK-PAR2-LABEL: func @scale_ss +// CHECK-PAR2: scf.parallel +// CHECK-PAR2: scf.for +// CHECK-PAR2: return +// +// CHECK-PAR3-LABEL: func @scale_ss +// CHECK-PAR3: scf.for +// CHECK-PAR3: scf.for +// CHECK-PAR3: return +// +// CHECK-PAR4-LABEL: func @scale_ss +// CHECK-PAR4: scf.parallel +// CHECK-PAR4: scf.parallel +// CHECK-PAR4: return +// +func @scale_ss(%scale: f32, %arga: tensor) -> tensor { + %0 = linalg.generic #trait_ss + ins(%arga: tensor) { + ^bb(%a: f32): + %0 = mulf %a, %scale : f32 + linalg.yield %0 : f32 + } -> tensor + return %0 : tensor +} + +#trait_matvec = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (j)>, // b + affine_map<(i,j) -> (i)> // 
x (out) + ], + sparse = [ + [ "D", "S" ], // A + [ "D" ], // b + [ "D" ] // x + ], + iterator_types = ["parallel", "reduction"], + doc = "x(i) += A(i,j) * b(j)" +} + +// +// CHECK-PAR0-LABEL: func @matvec +// CHECK-PAR0: scf.for +// CHECK-PAR0: scf.for +// CHECK-PAR0: return +// +// CHECK-PAR1-LABEL: func @matvec +// CHECK-PAR1: scf.parallel +// CHECK-PAR1: scf.for +// CHECK-PAR1: return +// +// CHECK-PAR2-LABEL: func @matvec +// CHECK-PAR2: scf.parallel +// CHECK-PAR2: scf.for +// CHECK-PAR2: return +// +// CHECK-PAR3-LABEL: func @matvec +// CHECK-PAR3: scf.parallel +// CHECK-PAR3: scf.for +// CHECK-PAR3: return +// +// CHECK-PAR4-LABEL: func @matvec +// CHECK-PAR4: scf.parallel +// CHECK-PAR4: scf.for +// CHECK-PAR4: return +// +func @matvec(%argA: tensor<16x32xf32>, %argb: tensor<32xf32>, %argx: tensor<16xf32>) -> tensor<16xf32> { + %0 = linalg.generic #trait_matvec + ins(%argA, %argb : tensor<16x32xf32>, tensor<32xf32>) + init(%argx : tensor<16xf32>) { + ^bb(%A: f32, %b: f32, %x: f32): + %0 = mulf %A, %b : f32 + %1 = addf %0, %x : f32 + linalg.yield %1 : f32 + } -> tensor<16xf32> + return %0 : tensor<16xf32> +} diff --git a/mlir/test/lib/Transforms/TestSparsification.cpp b/mlir/test/lib/Transforms/TestSparsification.cpp --- a/mlir/test/lib/Transforms/TestSparsification.cpp +++ b/mlir/test/lib/Transforms/TestSparsification.cpp @@ -16,13 +16,65 @@ struct TestSparsification : public PassWrapper { + + TestSparsification() = default; + TestSparsification(const TestSparsification &pass) {} + + Option parallelization{ + *this, "parallelization-strategy", + llvm::cl::desc("Set the parallelization strategy"), llvm::cl::init(0)}; + + Option vectorization{ + *this, "vectorization-strategy", + llvm::cl::desc("Set the vectorization strategy"), llvm::cl::init(0)}; + void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } + void runOnFunction() override { auto *ctx = &getContext(); OwningRewritePatternList patterns; - linalg::populateSparsificationPatterns(ctx, patterns); + // Translate strategy flag to strategy options. + linalg::SparsificationOptions options; + switch (parallelization) { + case 0: + options.parallelizationStrategy = + linalg::SparseParallelizationStrategy::kNone; + break; + case 1: + options.parallelizationStrategy = + linalg::SparseParallelizationStrategy::kDenseOuterLoop; + break; + case 2: + options.parallelizationStrategy = + linalg::SparseParallelizationStrategy::kAnyStorageOuterLoop; + break; + case 3: + options.parallelizationStrategy = + linalg::SparseParallelizationStrategy::kDenseAnyLoop; + break; + case 4: + options.parallelizationStrategy = + linalg::SparseParallelizationStrategy::kAnyStorageAnyLoop; + break; + } + switch (vectorization) { + case 0: + options.vectorizationStrategy = + linalg::SparseVectorizationStrategy::kNone; + break; + case 1: + options.vectorizationStrategy = + linalg::SparseVectorizationStrategy::kDenseInnerLoop; + break; + case 2: + options.vectorizationStrategy = + linalg::SparseVectorizationStrategy::kAnyStorageInnerLoop; + break; + } + // Apply rewriting. + linalg::populateSparsificationPatterns(ctx, patterns, options); applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } };
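For downstream users, the wiring mirrors the TestSparsification pass above: construct a linalg::SparsificationOptions, pick the strategies, and hand the result to populateSparsificationPatterns before running the greedy rewrite driver. Below is a minimal sketch of such a client pass; the pass name MySparsificationPass, the chosen strategies, and the abbreviated dialect registration are illustrative assumptions and not part of this patch.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
// Hypothetical client pass: requests outer-loop parallelism for any storage
// format and leaves vectorization disabled.
struct MySparsificationPass
    : public mlir::PassWrapper<MySparsificationPass, mlir::FunctionPass> {
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    // Sparse codegen emits scf.for/scf.while/scf.parallel loops.
    registry.insert<mlir::scf::SCFDialect>();
  }
  void runOnFunction() override {
    mlir::MLIRContext *ctx = &getContext();
    mlir::OwningRewritePatternList patterns;
    mlir::linalg::SparsificationOptions options;
    options.parallelizationStrategy =
        mlir::linalg::SparseParallelizationStrategy::kAnyStorageOuterLoop;
    options.vectorizationStrategy =
        mlir::linalg::SparseVectorizationStrategy::kNone;
    mlir::linalg::populateSparsificationPatterns(ctx, patterns, options);
    mlir::applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
} // namespace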