diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -867,7 +867,13 @@
                          SparseVectorizationStrategy v, unsigned vl,
                          SparseIntType pt, SparseIntType it)
       : parallelizationStrategy(p), vectorizationStrategy(v), vectorLength(vl),
-        ptrType(pt), indType(it) {}
+        ptrType(pt), indType(it) {
+    // TODO: remove restriction when vectors with index elements are supported
+    assert((v != SparseVectorizationStrategy::kAnyStorageInnerLoop ||
+            (ptrType != SparseIntType::kNative &&
+             indType != SparseIntType::kNative)) &&
+           "This combination requires support for vectors with index elements");
+  }
   SparsificationOptions()
       : SparsificationOptions(SparseParallelizationStrategy::kNone,
                               SparseVectorizationStrategy::kNone, 1u,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -46,6 +46,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Matchers.h"
 
 using namespace mlir;
 
@@ -301,7 +302,8 @@
         indices(numTensors, std::vector<Value>(numLoops)),
         highs(numTensors, std::vector<Value>(numLoops)),
         pidxs(numTensors, std::vector<Value>(numLoops)),
-        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal() {}
+        idxs(numTensors, std::vector<Value>(numLoops)), redExp(-1u), redVal(),
+        curVecLength(1), curVecMask() {}
   /// Sparsification options.
   linalg::SparsificationOptions options;
   /// Universal dense indices and upper bounds (by index). The loops array
@@ -327,6 +329,9 @@
   // is most effective; we could generalize to more outer and while-loops.
   unsigned redExp;
   Value redVal;
+  // Current vector length and mask.
+  unsigned curVecLength;
+  Value curVecMask;
 };
 } // namespace
 
@@ -558,6 +563,71 @@
   }
 }
 
+/// Constructs vector type from pointer.
+static VectorType vectorType(CodeGen &codegen, Value ptr) {
+  Type etp = ptr.getType().cast<MemRefType>().getElementType();
+  return VectorType::get(codegen.curVecLength, etp);
+}
+
+/// Constructs vector iteration mask.
+static Value genVectorMask(CodeGen &codegen, PatternRewriter &rewriter,
+                           Value iv, Value lo, Value hi, Value step) {
+  Location loc = iv.getLoc();
+  VectorType mtp =
+      VectorType::get(codegen.curVecLength, rewriter.getIntegerType(1));
+  // Special case if the vector length evenly divides the trip count (for
+  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
+  // so that all subsequent masked memory operations are immediately folded
+  // into unconditional memory operations.
+  IntegerAttr loInt, hiInt, stepInt;
+  if (matchPattern(lo, m_Constant(&loInt)) &&
+      matchPattern(hi, m_Constant(&hiInt)) &&
+      matchPattern(step, m_Constant(&stepInt))) {
+    if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
+      return rewriter.create<vector::ConstantMaskOp>(
+          loc, mtp, rewriter.getI64ArrayAttr(codegen.curVecLength));
+  }
+  // Otherwise, generate a vector mask that avoids overrunning the upper bound
+  // during vector execution. Here we rely on subsequent loop optimizations to
+  // avoid executing the mask in all iterations, for example, by splitting the
+  // loop into an unconditional vector loop and a scalar cleanup loop.
+  Value end = rewriter.create<SubIOp>(loc, hi, iv);
+  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
+}
+
+/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
+static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
+                           Value ptr, ArrayRef<Value> args) {
+  Location loc = ptr.getLoc();
+  VectorType vtp = vectorType(codegen, ptr);
+  Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
+  if (args.back().getType().isa<VectorType>())
+    return rewriter.create<vector::GatherOp>(loc, vtp, ptr, args.back(),
+                                             codegen.curVecMask, pass);
+  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
+                                               codegen.curVecMask, pass);
+}
+
+/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
+static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
+                           Value rhs, Value ptr, ArrayRef<Value> args) {
+  Location loc = ptr.getLoc();
+  if (args.back().getType().isa<VectorType>())
+    rewriter.create<vector::ScatterOp>(loc, ptr, args.back(),
+                                       codegen.curVecMask, rhs);
+  else
+    rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
+                                           rhs);
+}
+
+/// Generates a vectorized invariant. Here we rely on subsequent loop
+/// optimizations to hoist the invariant broadcast out of the vector loop.
+static Value genVectorInvariantValue(CodeGen &codegen,
+                                     PatternRewriter &rewriter, Value val) {
+  VectorType vtp = VectorType::get(codegen.curVecLength, val.getType());
+  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
+}
+
 /// Generates a load on a dense or sparse tensor.
 static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                            PatternRewriter &rewriter, linalg::GenericOp op,
@@ -582,6 +652,8 @@
   }
   Location loc = op.getLoc();
   Value ptr = codegen.buffers[tensor];
+  if (codegen.curVecLength > 1)
+    return genVectorLoad(codegen, rewriter, ptr, args);
   return rewriter.create<LoadOp>(loc, ptr, args);
 }
 
@@ -595,7 +667,7 @@
     codegen.redVal = rhs;
     return;
   }
-  // Actual load.
+  // Actual store.
   SmallVector<Value, 4> args;
   auto map = op.getIndexingMap(tensor);
   for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
@@ -604,12 +676,17 @@
   }
   Location loc = op.getLoc();
   Value ptr = codegen.buffers[tensor];
-  rewriter.create<StoreOp>(loc, rhs, ptr, args);
+  if (codegen.curVecLength > 1)
+    genVectorStore(codegen, rewriter, rhs, ptr, args);
+  else
+    rewriter.create<StoreOp>(loc, rhs, ptr, args);
 }
 
 /// Generates a pointer/index load from the sparse storage scheme.
-static Value genLoad(PatternRewriter &rewriter, Location loc, Value ptr,
-                     Value s) {
+static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
+                     Value ptr, Value s) {
+  if (codegen.curVecLength > 1)
+    return genVectorLoad(codegen, rewriter, ptr, {s});
   Value load = rewriter.create<LoadOp>(loc, ptr, s);
   return load.getType().isa<IndexType>()
              ? load
@@ -619,7 +696,10 @@
 /// Generates an invariant value.
 static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                                PatternRewriter &rewriter, unsigned exp) {
-  return merger.exp(exp).val;
+  Value val = merger.exp(exp).val;
+  if (codegen.curVecLength > 1)
+    return genVectorInvariantValue(codegen, rewriter, val);
+  return val;
 }
 
 /// Recursively generates tensor expression.
@@ -707,9 +787,9 @@
     Value one = rewriter.create<ConstantIndexOp>(loc, 1);
     Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
-    codegen.pidxs[tensor][idx] = genLoad(rewriter, loc, ptr, p0);
+    codegen.pidxs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p0);
     Value p1 = rewriter.create<AddIOp>(loc, p0, one);
-    codegen.highs[tensor][idx] = genLoad(rewriter, loc, ptr, p1);
+    codegen.highs[tensor][idx] = genLoad(codegen, rewriter, loc, ptr, p1);
   } else {
     // Dense index still in play.
     needsUniv = true;
@@ -722,6 +802,39 @@
   return needsUniv;
 }
 
+/// Returns vectorization strategy. Any implicit inner loop in the Linalg
+/// operation is a candidate. Whether it is actually converted to SIMD code
+/// depends on the requested strategy.
+static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
+  switch (codegen.options.vectorizationStrategy) {
+  case linalg::SparseVectorizationStrategy::kNone:
+    return false;
+  case linalg::SparseVectorizationStrategy::kDenseInnerLoop:
+    return isInner && !isSparse;
+  case linalg::SparseVectorizationStrategy::kAnyStorageInnerLoop:
+    return isInner;
+  }
+}
+
+/// Returns parallelization strategy. Any implicit loop in the Linalg operation
+/// that is marked "parallel" is a candidate. Whether it is actually converted
+/// to a parallel operation depends on the requested strategy.
+static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
+                          bool isSparse, bool isVector) {
+  switch (codegen.options.parallelizationStrategy) {
+  case linalg::SparseParallelizationStrategy::kNone:
+    return false;
+  case linalg::SparseParallelizationStrategy::kDenseOuterLoop:
+    return isOuter && !isSparse && !isReduction && !isVector;
+  case linalg::SparseParallelizationStrategy::kAnyStorageOuterLoop:
+    return isOuter && !isReduction && !isVector;
+  case linalg::SparseParallelizationStrategy::kDenseAnyLoop:
+    return !isSparse && !isReduction && !isVector;
+  case linalg::SparseParallelizationStrategy::kAnyStorageAnyLoop:
+    return !isReduction && !isVector;
+  }
+}
+
 /// Generates a for-loop on a single index.
 static Operation *genFor(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
@@ -730,46 +843,26 @@
   unsigned fb = indices.find_first();
   unsigned tensor = merger.tensor(fb);
   assert(idx == merger.index(fb));
-
-  // Parallelization strategy. Any implicit loop in the Linalg operation that
-  // is marked "parallel" is a candidate. Whether it is actually converted to
-  // a parallel operation depends on the requested strategy.
   auto iteratorTypes = op.iterator_types().getValue();
+  bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]);
   bool isSparse = merger.isDim(fb, Dim::kSparse);
-  bool isParallel = linalg::isParallelIteratorType(iteratorTypes[idx]);
-  switch (codegen.options.parallelizationStrategy) {
-  case linalg::SparseParallelizationStrategy::kNone:
-    isParallel = false;
-    break;
-  case linalg::SparseParallelizationStrategy::kDenseOuterLoop:
-    isParallel &= isOuter && !isSparse;
-    break;
-  case linalg::SparseParallelizationStrategy::kAnyStorageOuterLoop:
-    isParallel &= isOuter;
-    break;
-  case linalg::SparseParallelizationStrategy::kDenseAnyLoop:
-    isParallel &= !isSparse;
-    break;
-  case linalg::SparseParallelizationStrategy::kAnyStorageAnyLoop:
-    break;
-  }
+  bool isVector = isVectorFor(codegen, isInner, isSparse);
+  bool isParallel =
+      isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);
+
+  // Prepare vector length.
+  if (isVector)
+    codegen.curVecLength = codegen.options.vectorLength;
 
   // Loop bounds and increment.
   Location loc = op.getLoc();
-  Value lo;
-  Value hi;
-  Value step = rewriter.create<ConstantIndexOp>(loc, 1);
-  Value index;
-  if (isSparse) {
-    lo = codegen.pidxs[tensor][idx];
-    hi = codegen.highs[tensor][idx];
-  } else {
-    lo = codegen.loops[idx];
-    hi = codegen.sizes[idx];
-  }
+  Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx];
+  Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx];
+  Value step = rewriter.create<ConstantIndexOp>(loc, codegen.curVecLength);
 
   // Emit a parallel loop.
   if (isParallel) {
+    assert(!isVector);
     scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
     if (isSparse)
       codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
@@ -783,10 +876,16 @@
   bool scalarRed = isInner && codegen.redExp != -1u;
   SmallVector<Value, 4> operands;
   if (scalarRed) {
-    Value load =
-        codegen.redVal
-            ? codegen.redVal // chained with previous for-loop
-            : genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
+    Value load;
+    if (codegen.redVal) {
+      load = codegen.redVal; // chained with previous for-loop
+    } else if (isVector) {
+      // TODO: assumes + reductions for now
+      VectorType vtp = vectorType(codegen, codegen.buffers[codegen.redExp]);
+      load = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
+    } else {
+      load = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
+    }
     operands.push_back(load);
   }
   scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, operands);
@@ -795,11 +894,15 @@
         forOp.getRegionIterArgs().front();
   }
   // Assign induction variable to sparse or dense index.
+  Value iv = forOp.getInductionVar();
   if (isSparse)
-    codegen.pidxs[tensor][idx] = forOp.getInductionVar();
+    codegen.pidxs[tensor][idx] = iv;
   else
-    codegen.loops[idx] = forOp.getInductionVar();
+    codegen.loops[idx] = iv;
   rewriter.setInsertionPointToStart(forOp.getBody());
+  // Share vector iteration mask between all subsequent loads/stores.
+  if (isVector)
+    codegen.curVecMask = genVectorMask(codegen, rewriter, iv, lo, hi, step);
   return forOp;
 }
 
@@ -886,7 +989,7 @@
     assert(idx == merger.index(b));
     Value ptr = codegen.indices[tensor][idx];
     Value s = codegen.pidxs[tensor][idx];
-    Value load = genLoad(rewriter, loc, ptr, s);
+    Value load = genLoad(codegen, rewriter, loc, ptr, s);
     codegen.idxs[tensor][idx] = load;
     if (!needsUniv) {
       if (min) {
@@ -998,7 +1101,7 @@
   // Then emit initialization code for the loop sequence at this level.
   // We maintain the universal dense index if dense indices are still
   // in play for a non-singleton loop sequence.
-  // Location loc = op.getLoc();
+  Location loc = op.getLoc();
   unsigned idx = topSort[at];
   unsigned lts = merger.optimizeSet(buildLattices(merger, op, exp, idx));
   unsigned lsize = merger.set(lts).size();
@@ -1015,6 +1118,7 @@
     unsigned li = merger.set(lts)[i];
     // Emit loop.
+    codegen.curVecLength = 1;
     llvm::BitVector indices = merger.lat(li).simple;
     Operation *loop =
         genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
@@ -1055,7 +1159,7 @@
     } else {
       needsUniv = false;
       if (codegen.redVal) {
-        rewriter.create<scf::YieldOp>(op.getLoc(), codegen.redVal);
+        rewriter.create<scf::YieldOp>(loc, codegen.redVal);
         codegen.redVal = loop->getResult(0);
       }
     }
@@ -1067,10 +1171,16 @@
   if (red) {
     codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain
     unsigned lhs = op.getNumShapedOperands() - 1;
+    if (codegen.curVecLength > 1) {
+      codegen.curVecLength = 1;
+      Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp);
+      red = rewriter.create<vector::ReductionOp>(
+          loc, ld.getType(), rewriter.getStringAttr("add"), red, ld);
+    }
     genTensorStore(merger, codegen, rewriter, op, lhs, red);
   }
-  codegen.loops[idx] = Value();
   genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
+  codegen.loops[idx] = Value();
 }
 
 namespace {
diff --git a/mlir/test/Dialect/Linalg/sparse_vector.mlir b/mlir/test/Dialect/Linalg/sparse_vector.mlir
new file
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/sparse_vector.mlir
@@ -0,0 +1,310 @@
+// RUN: mlir-opt %s -test-sparsification="vectorization-strategy=0 ptr-type=2 ind-type=2 vl=16" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-VEC0
+// RUN: mlir-opt %s -test-sparsification="vectorization-strategy=1 ptr-type=2 ind-type=2 vl=16" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-VEC1
+// RUN: mlir-opt %s -test-sparsification="vectorization-strategy=2 ptr-type=2 ind-type=2 vl=16" | \
+// RUN:   FileCheck %s --check-prefix=CHECK-VEC2
+
+#trait_scale_d = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  sparse = [
+    [ "D" ],  // a
+    [ "D" ]   // x
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = a(i) * b"
+}
+
+//
+// CHECK-VEC0-LABEL: func @scale_d
+// CHECK-VEC0-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-VEC0-DAG: %[[c1:.*]] = constant 1 : index
+// CHECK-VEC0-DAG: %[[c1024:.*]] = constant 1024 : index
+// CHECK-VEC0: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] {
+// CHECK-VEC0: %[[l:.*]] = load %{{.*}}[%[[i]]] : memref<1024xf32>
+// CHECK-VEC0: %[[m:.*]] = mulf %[[l]], %{{.*}} : f32
+// CHECK-VEC0: store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>
+// CHECK-VEC0: }
+// CHECK-VEC0: return
+//
+// CHECK-VEC1-LABEL: func @scale_d
+// CHECK-VEC1-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-VEC1-DAG: %[[c16:.*]] = constant 16 : index
+// CHECK-VEC1-DAG: %[[c1024:.*]] = constant 1024 : index
+// CHECK-VEC1: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
+// CHECK-VEC1: %[[r:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %{{.*}} {masked = [false]} : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC1: %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
+// CHECK-VEC1: %[[m:.*]] = mulf %[[r]], %[[b]] : vector<16xf32>
+// CHECK-VEC1: vector.transfer_write %[[m]], %{{.*}}[%[[i]]] {masked = [false]} : vector<16xf32>, memref<1024xf32>
+// CHECK-VEC1: }
+// CHECK-VEC1: return
+//
+// CHECK-VEC2-LABEL: func @scale_d
+// CHECK-VEC2-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-VEC2-DAG: %[[c16:.*]] = constant 16 : index
+// CHECK-VEC2-DAG: %[[c1024:.*]] = constant 1024 : index
+// CHECK-VEC2: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] {
+// CHECK-VEC2: %[[r:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %{{.*}} {masked = [false]} : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC2: %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32>
+// CHECK-VEC2: %[[m:.*]] = mulf %[[r]], %[[b]] : vector<16xf32>
+// CHECK-VEC2: vector.transfer_write %[[m]], %{{.*}}[%[[i]]] {masked = [false]} : vector<16xf32>, memref<1024xf32>
+// CHECK-VEC2: }
+// CHECK-VEC2: return
+//
+func @scale_d(%arga: tensor<1024xf32>, %scale: f32) -> tensor<1024xf32> {
+  %0 = linalg.generic #trait_scale_d
+    ins(%arga: tensor<1024xf32>)
+    outs(%arga: tensor<1024xf32>) {
+      ^bb(%a: f32, %s : f32):
+        %0 = mulf %a, %scale : f32
+        linalg.yield %0 : f32
+  } -> tensor<1024xf32>
+  return %0 : tensor<1024xf32>
+}
+
+#trait_mul_s = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a
+    affine_map<(i) -> (i)>,  // b
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  sparse = [
+    [ "S" ],  // a
+    [ "D" ],  // b
+    [ "D" ]   // x
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = a(i) * b(i)"
+}
+
+//
+// CHECK-VEC0-LABEL: func @mul_s
+// CHECK-VEC0-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-VEC0-DAG: %[[c1:.*]] = constant 1 : index
+// CHECK-VEC0: %[[p:.*]] = load %{{.*}}[%[[c0]]] : memref<?xi32>
+// CHECK-VEC0: %[[q:.*]] = index_cast %[[p]] : i32 to index
+// CHECK-VEC0: %[[r:.*]] = load %{{.*}}[%[[c1]]] : memref<?xi32>
+// CHECK-VEC0: %[[s:.*]] = index_cast %[[r]] : i32 to index
+// CHECK-VEC0: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] {
+// CHECK-VEC0: %[[li:.*]] = load %{{.*}}[%[[i]]] : memref<?xi32>
+// CHECK-VEC0: %[[ci:.*]] = index_cast %[[li]] : i32 to index
+// CHECK-VEC0: %[[la:.*]] = load %{{.*}}[%[[i]]] : memref<?xf32>
+// CHECK-VEC0: %[[lb:.*]] = load %{{.*}}[%[[ci]]] : memref<1024xf32>
+// CHECK-VEC0: %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
+// CHECK-VEC0: store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32>
+// CHECK-VEC0: }
+// CHECK-VEC0: return
+//
+// CHECK-VEC1-LABEL: func @mul_s
+// CHECK-VEC1-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-VEC1-DAG: %[[c1:.*]] = constant 1 : index
+// CHECK-VEC1: %[[p:.*]] = load %{{.*}}[%[[c0]]] : memref<?xi32>
+// CHECK-VEC1: %[[q:.*]] = index_cast %[[p]] : i32 to index
+// CHECK-VEC1: %[[r:.*]] = load %{{.*}}[%[[c1]]] : memref<?xi32>
+// CHECK-VEC1: %[[s:.*]] = index_cast %[[r]] : i32 to index
+// CHECK-VEC1: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] {
+// CHECK-VEC1: %[[li:.*]] = load %{{.*}}[%[[i]]] : memref<?xi32>
+// CHECK-VEC1: %[[ci:.*]] = index_cast %[[li]] : i32 to index
+// CHECK-VEC1: %[[la:.*]] = load %{{.*}}[%[[i]]] : memref<?xf32>
+// CHECK-VEC1: %[[lb:.*]] = load %{{.*}}[%[[ci]]] : memref<1024xf32>
+// CHECK-VEC1: %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
+// CHECK-VEC1: store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32>
+// CHECK-VEC1: }
+// CHECK-VEC1: return
+//
+// CHECK-VEC2-LABEL: func @mul_s
+// CHECK-VEC2-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-VEC2-DAG: %[[c1:.*]] = constant 1 : index
+// CHECK-VEC2-DAG: %[[c16:.*]] = constant 16 : index
+// CHECK-VEC2: %[[p:.*]] = load %{{.*}}[%[[c0]]] : memref<?xi32>
+// CHECK-VEC2: %[[q:.*]] = index_cast %[[p]] : i32 to index
+// CHECK-VEC2: %[[r:.*]] = load %{{.*}}[%[[c1]]] : memref<?xi32>
+// CHECK-VEC2: %[[s:.*]] = index_cast %[[r]] : i32 to index
+// CHECK-VEC2: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] {
+// CHECK-VEC2: %[[sub:.*]] = subi %[[s]], %[[i]] : index
+// CHECK-VEC2: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC2: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK-VEC2: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC2: %[[lb:.*]] = vector.gather %{{.*}}[%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC2: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC2: vector.scatter %{{.*}}[%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC2: }
+// CHECK-VEC2: return
+//
+func @mul_s(%arga: tensor<1024xf32>, %argb: tensor<1024xf32>) -> tensor<1024xf32> {
+  %0 = linalg.generic #trait_mul_s
+    ins(%arga, %argb: tensor<1024xf32>, tensor<1024xf32>)
+    outs(%arga: tensor<1024xf32>) {
+      ^bb(%a: f32, %b: f32, %s : f32):
+        %0 = mulf %a, %b : f32
+        linalg.yield %0 : f32
+  } -> tensor<1024xf32>
+  return %0 : tensor<1024xf32>
+}
+
+#trait_reduction_d = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a
+    affine_map<(i) -> (i)>,  // b
+    affine_map<(i) -> ()>    // x (out)
+  ],
+  sparse = [
+    [ "D" ],  // a
+    [ "D" ],  // b
+    [ ]       // x
+  ],
+  iterator_types = ["reduction"],
+  doc = "x += a(i) * b(i)"
+}
+
+//
+// CHECK-VEC0-LABEL: func @reduction_d
+// CHECK-VEC0-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-VEC0-DAG: %[[c1:.*]] = constant 1 : index
+// CHECK-VEC0-DAG: %[[c1024:.*]] = constant 1024 : index
+// CHECK-VEC0: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] iter_args(%[[red_in:.*]] = %{{.*}}) -> (f32) {
+// CHECK-VEC0: %[[la:.*]] = load %{{.*}}[%[[i]]] : memref<1024xf32>
+// CHECK-VEC0: %[[lb:.*]] = load %{{.*}}[%[[i]]] : memref<1024xf32>
+// CHECK-VEC0: %[[m:.*]] = mulf %[[la]], %[[lb]] : f32
+// CHECK-VEC0: %[[a:.*]] = addf %[[red_in]], %[[m]] : f32
+// CHECK-VEC0: scf.yield %[[a]] : f32
+// CHECK-VEC0: }
+// CHECK-VEC0: return
+//
+// CHECK-VEC1-LABEL: func @reduction_d
+// CHECK-VEC1-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-VEC1-DAG: %[[c16:.*]] = constant 16 : index
+// CHECK-VEC1-DAG: %[[c1024:.*]] = constant 1024 : index
+// CHECK-VEC1-DAG: %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32>
+// CHECK-VEC1: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[v0]]) -> (vector<16xf32>) {
+// CHECK-VEC1: %[[la:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %cst_0 {masked = [false]} : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC1: %[[lb:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %cst_0 {masked = [false]} : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC1: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC1: %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32>
+// CHECK-VEC1: scf.yield %[[a]] : vector<16xf32>
+// CHECK-VEC1: }
+// CHECK-VEC1: %{{.*}} = vector.reduction "add", %[[red]], %{{.*}} : vector<16xf32> into f32
+// CHECK-VEC1: return
+//
+// CHECK-VEC2-LABEL: func @reduction_d
+// CHECK-VEC2-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-VEC2-DAG: %[[c16:.*]] = constant 16 : index
+// CHECK-VEC2-DAG: %[[c1024:.*]] = constant 1024 : index
+// CHECK-VEC2-DAG: %[[v0:.*]] = constant dense<0.000000e+00> : vector<16xf32>
+// CHECK-VEC2: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[v0]]) -> (vector<16xf32>) {
+// CHECK-VEC2: %[[la:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %cst_0 {masked = [false]} : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC2: %[[lb:.*]] = vector.transfer_read %{{.*}}[%[[i]]], %cst_0 {masked = [false]} : memref<1024xf32>, vector<16xf32>
+// CHECK-VEC2: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC2: %[[a:.*]] = addf %[[red_in]], %[[m]] : vector<16xf32>
+// CHECK-VEC2: scf.yield %[[a]] : vector<16xf32>
+// CHECK-VEC2: }
"add", %[[red]], %{{.*}} : vector<16xf32> into f32 +// CHECK-VEC2: return +// +func @reduction_d(%arga: tensor<1024xf32>, %argb: tensor<1024xf32>, %argx: tensor) -> tensor { + %0 = linalg.generic #trait_reduction_d + ins(%arga, %argb: tensor<1024xf32>, tensor<1024xf32>) + outs(%argx: tensor) { + ^bb(%a: f32, %b : f32, %x : f32): + %0 = mulf %a, %b : f32 + %1 = addf %x, %0 : f32 + linalg.yield %1 : f32 + } -> tensor + return %0 : tensor +} + +#trait_mul_ds = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // a + affine_map<(i,j) -> (i,j)>, // b + affine_map<(i,j) -> (i,j)> // x (out) + ], + sparse = [ + [ "D", "S" ], // a + [ "D", "D" ], // b + [ "D", "D" ] // x + ], + iterator_types = ["parallel", "parallel"], + doc = "x(i,j) = a(i,j) * b(i,j)" +} + +// +// CHECK-VEC0-LABEL: func @mul_ds +// CHECK-VEC0-DAG: %[[c0:.*]] = constant 0 : index +// CHECK-VEC0-DAG: %[[c1:.*]] = constant 1 : index +// CHECK-VEC0-DAG: %[[c512:.*]] = constant 512 : index +// CHECK-VEC0: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] { +// CHECK-VEC0: %[[p:.*]] = load %{{.*}}[%[[i]]] : memref +// CHECK-VEC0: %[[q:.*]] = index_cast %[[p]] : i32 to index +// CHECK-VEC0: %[[a:.*]] = addi %[[i]], %[[c1]] : index +// CHECK-VEC0: %[[r:.*]] = load %{{.*}}[%[[a]]] : memref +// CHECK-VEC0: %[[s:.*]] = index_cast %[[r]] : i32 to index +// CHECK-VEC0: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] { +// CHECK-VEC0: %[[lj:.*]] = load %{{.*}}[%[[j]]] : memref +// CHECK-VEC0: %[[cj:.*]] = index_cast %[[lj]] : i32 to index +// CHECK-VEC0: %[[la:.*]] = load %{{.*}}[%[[j]]] : memref +// CHECK-VEC0: %[[lb:.*]] = load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32> +// CHECK-VEC0: %[[m:.*]] = mulf %[[la]], %[[lb]] : f32 +// CHECK-VEC0: store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32> +// CHECK-VEC0: } +// CHECK-VEC0: } +// CHECK-VEC0: return +// +// CHECK-VEC1-LABEL: func @mul_ds +// CHECK-VEC1-DAG: %[[c0:.*]] = constant 0 : index +// CHECK-VEC1-DAG: %[[c1:.*]] = constant 1 : index +// CHECK-VEC1-DAG: %[[c512:.*]] = constant 512 : index +// CHECK-VEC1: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] { +// CHECK-VEC1: %[[p:.*]] = load %{{.*}}[%[[i]]] : memref +// CHECK-VEC1: %[[q:.*]] = index_cast %[[p]] : i32 to index +// CHECK-VEC1: %[[a:.*]] = addi %[[i]], %[[c1]] : index +// CHECK-VEC1: %[[r:.*]] = load %{{.*}}[%[[a]]] : memref +// CHECK-VEC1: %[[s:.*]] = index_cast %[[r]] : i32 to index +// CHECK-VEC1: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] { +// CHECK-VEC1: %[[lj:.*]] = load %{{.*}}[%[[j]]] : memref +// CHECK-VEC1: %[[cj:.*]] = index_cast %[[lj]] : i32 to index +// CHECK-VEC1: %[[la:.*]] = load %{{.*}}[%[[j]]] : memref +// CHECK-VEC1: %[[lb:.*]] = load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32> +// CHECK-VEC1: %[[m:.*]] = mulf %[[la]], %[[lb]] : f32 +// CHECK-VEC1: store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32> +// CHECK-VEC1: } +// CHECK-VEC1: } +// CHECK-VEC1: return +// +// CHECK-VEC2-LABEL: func @mul_ds +// CHECK-VEC2-DAG: %[[c0:.*]] = constant 0 : index +// CHECK-VEC2-DAG: %[[c1:.*]] = constant 1 : index +// CHECK-VEC2-DAG: %[[c16:.*]] = constant 16 : index +// CHECK-VEC2-DAG: %[[c512:.*]] = constant 512 : index +// CHECK-VEC2: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] { +// CHECK-VEC2: %[[p:.*]] = load %{{.*}}[%[[i]]] : memref +// CHECK-VEC2: %[[q:.*]] = index_cast %[[p]] : i32 to index +// CHECK-VEC2: %[[a:.*]] = addi %[[i]], %[[c1]] : index +// CHECK-VEC2: %[[r:.*]] = load %{{.*}}[%[[a]]] : memref +// CHECK-VEC2: %[[s:.*]] = index_cast %[[r]] : i32 
+// CHECK-VEC2: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] {
+// CHECK-VEC2: %[[sub:.*]] = subi %[[s]], %[[j]] : index
+// CHECK-VEC2: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1>
+// CHECK-VEC2: %[[lj:.*]] = vector.maskedload %{{.*}}[%arg3], %[[mask]], %{{.*}} : memref<?xi32>, vector<16xi1>, vector<16xi32> into vector<16xi32>
+// CHECK-VEC2: %[[la:.*]] = vector.maskedload %{{.*}}[%arg3], %[[mask]], %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC2: %[[lb:.*]] = vector.gather %{{.*}}[%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK-VEC2: %[[m:.*]] = mulf %[[la]], %[[lb]] : vector<16xf32>
+// CHECK-VEC2: vector.scatter %{{.*}}[%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+// CHECK-VEC2: }
+// CHECK-VEC2: }
+// CHECK-VEC2: return
+//
+func @mul_ds(%arga: tensor<512x1024xf32>, %argb: tensor<512x1024xf32>) -> tensor<512x1024xf32> {
+  %0 = linalg.generic #trait_mul_ds
+    ins(%arga, %argb: tensor<512x1024xf32>, tensor<512x1024xf32>)
+    outs(%arga: tensor<512x1024xf32>) {
+      ^bb(%a: f32, %b: f32, %s : f32):
+        %0 = mulf %a, %b : f32
+        linalg.yield %0 : f32
+  } -> tensor<512x1024xf32>
+  return %0 : tensor<512x1024xf32>
+}
diff --git a/mlir/test/lib/Transforms/TestSparsification.cpp b/mlir/test/lib/Transforms/TestSparsification.cpp
--- a/mlir/test/lib/Transforms/TestSparsification.cpp
+++ b/mlir/test/lib/Transforms/TestSparsification.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -94,6 +95,7 @@
                          typeOption(indType));
     // Apply rewriting.
     linalg::populateSparsificationPatterns(ctx, patterns, options);
+    vector::populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
    applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }
};
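
Usage sketch (not part of the patch): the snippet below shows one way a pass could enable the new SIMD strategy programmatically, mirroring the TestSparsification.cpp hunk above. Only the SparsificationOptions constructor, the strategy enums, and the two populate* entry points are taken from the patch itself; SparseIntType::kI32, the OwningRewritePatternList type, and the rest of the pass plumbing are assumptions about the surrounding API at this revision.

// Hypothetical sketch, not part of the patch. SparseIntType::kI32 and the
// pattern-list plumbing are assumed; everything else follows the patch.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static void sparsifyWithSIMD(mlir::Operation *func) {
  mlir::MLIRContext *ctx = func->getContext();
  mlir::OwningRewritePatternList patterns;
  // Vectorize every innermost loop (dense or sparse storage) with 16 lanes.
  // Non-native (i32) pointer/index storage is required here, since the
  // SparsificationOptions constructor asserts that kAnyStorageInnerLoop is
  // not combined with SparseIntType::kNative.
  mlir::linalg::SparsificationOptions options(
      mlir::linalg::SparseParallelizationStrategy::kNone,
      mlir::linalg::SparseVectorizationStrategy::kAnyStorageInnerLoop,
      /*vectorLength=*/16,
      /*ptrType=*/mlir::linalg::SparseIntType::kI32,
      /*indType=*/mlir::linalg::SparseIntType::kI32);
  mlir::linalg::populateSparsificationPatterns(ctx, patterns, options);
  // Clean up the generated vector ops, e.g. masked memory operations whose
  // mask is a constant all-true vector.
  mlir::vector::populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
  mlir::applyPatternsAndFoldGreedily(func, std::move(patterns));
}

This cleanup step is presumably what turns the constant-mask case produced by genVectorMask into the unmasked vector.transfer_read/write forms (masked = [false]) expected by the CHECK-VEC1 and CHECK-VEC2 prefixes, while the -test-sparsification flags in the RUN lines (vectorization-strategy=0/1/2, vl=16, ptr-type=2, ind-type=2) exercise the same strategies from the command line.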