diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -172,6 +172,16 @@
 std::unique_ptr<Pass>
 createSparseBufferRewritePass(bool enableBufferInitialization);
 
+void populateSparseVectorizationPatterns(RewritePatternSet &patterns,
+                                         unsigned vectorLength,
+                                         bool enableVLAVectorization,
+                                         bool enableSIMDIndex32);
+
+std::unique_ptr<Pass> createSparseVectorizationPass();
+std::unique_ptr<Pass> createSparseVectorizationPass(unsigned vectorLength,
+                                                    bool enableVLAVectorization,
+                                                    bool enableSIMDIndex32);
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
@@ -225,4 +225,64 @@
   ];
 }
 
+def SparseVectorization : Pass<"sparse-vectorization", "ModuleOp"> {
+  let summary = "Vectorizes loops after sparsification";
+  let description = [{
+    A pass that converts loops after sparsification into vector loops.
+    The vector dialect is used as target to provide an architecture-neutral
+    way of exploiting any platform that supports SIMD instructions.
+
+    The vector length (viz. `vl`) describes the number of packed data elements
+    (e.g. both vector<16xf32> and vector<16xf64> have a vector length of 16 even
+    though the actual bitwidths differ). A small multiple of the actual lengths
+    supported in hardware typically results in efficient SIMD code, since the
+    backend will map longer vectors to multiple vector registers, thereby
+    effectively unrolling an additional level within the generated for-loop.
+
+    Example of the conversion:
+
+    ```mlir
+    Before:
+      %3 = memref.load %2[] : memref<f32>
+      %4 = scf.for %arg3 = %c0 to %c1024 step %c1 iter_args(%arg4 = %3) -> (f32) {
+        %6 = memref.load %0[%arg3] : memref<?xf32>
+        %7 = memref.load %1[%arg3] : memref<1024xf32>
+        %8 = arith.mulf %6, %7 : f32
+        %9 = arith.addf %arg4, %8 : f32
+        scf.yield %9 : f32
+      }
+      memref.store %4, %2[] : memref<f32>
+
+    After:
+      %3 = memref.load %2[] : memref<f32>
+      %4 = vector.insertelement %3, %cst[%c0 : index] : vector<32xf32>
+      %5 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4) -> (vector<32xf32>) {
+        %8 = vector.load %0[%arg3] : memref<?xf32>, vector<32xf32>
+        %9 = vector.load %1[%arg3] : memref<1024xf32>, vector<32xf32>
+        %10 = arith.mulf %8, %9 : vector<32xf32>
+        %11 = arith.addf %arg4, %10 : vector<32xf32>
+        scf.yield %11 : vector<32xf32>
+      }
+      %6 = vector.reduction <add>, %5 : vector<32xf32> into f32
+      memref.store %6, %2[] : memref<f32>
+    ```
+  }];
+  let constructor = "mlir::createSparseVectorizationPass()";
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "memref::MemRefDialect",
+    "scf::SCFDialect",
+    "sparse_tensor::SparseTensorDialect",
+    "vector::VectorDialect",
+  ];
+  let options = [
+    Option<"vectorLength", "vl", "int32_t", "0",
+           "Set the vector length (use 0 to disable vectorization)">,
+    Option<"enableVLAVectorization", "enable-vla-vectorization", "bool",
+           "false", "Enable vector length agnostic vectorization">,
+    Option<"enableSIMDIndex32", "enable-simd-index32", "bool", "false",
+           "Enable i32 indexing into vectors (for efficient gather/scatter)">,
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
   SparseTensorRewriting.cpp
+  SparseVectorization.cpp
ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -27,6 +27,7 @@ #define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS #define GEN_PASS_DEF_SPARSETENSORCODEGEN #define GEN_PASS_DEF_SPARSEBUFFERREWRITE +#define GEN_PASS_DEF_SPARSEVECTORIZATION #include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc" } // namespace mlir @@ -67,10 +68,9 @@ auto *ctx = &getContext(); // Translate strategy flags to strategy options. SparsificationOptions options(parallelization); - // Apply sparsification and vector cleanup rewriting. + // Apply sparsification and cleanup rewriting. RewritePatternSet patterns(ctx); populateSparsificationPatterns(patterns, options); - vector::populateVectorToVectorCanonicalizationPatterns(patterns); scf::ForOp::getCanonicalizationPatterns(patterns, ctx); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } @@ -250,6 +250,27 @@ } }; +struct SparseVectorizationPass + : public impl::SparseVectorizationBase { + + SparseVectorizationPass() = default; + SparseVectorizationPass(const SparseVectorizationPass &pass) = default; + SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) { + vectorLength = vl; + enableVLAVectorization = vla; + enableSIMDIndex32 = sidx32; + } + + void runOnOperation() override { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + populateSparseVectorizationPatterns( + patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32); + vector::populateVectorToVectorCanonicalizationPatterns(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -322,3 
+343,15 @@ mlir::createSparseBufferRewritePass(bool enableBufferInitialization) { return std::make_unique(enableBufferInitialization); } + +std::unique_ptr mlir::createSparseVectorizationPass() { + return std::make_unique(); +} + +std::unique_ptr +mlir::createSparseVectorizationPass(unsigned vectorLength, + bool enableVLAVectorization, + bool enableSIMDIndex32) { + return std::make_unique( + vectorLength, enableVLAVectorization, enableSIMDIndex32); +} diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -0,0 +1,477 @@ +//===- SparseVectorization.cpp - Vectorization of sparsified loops --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// A pass that converts loops generated by the sparse compiler into a form that +// can exploit SIMD instructions of the target architecture. Note that this pass +// ensures the sparse compiler can generate efficient SIMD (including ArmSVE +// support) with proper separation of concerns as far as sparsification and +// vectorization is concerned. However, this pass is not the final abstraction +// level we want, and not the general vectorizer we want either. It forms a good +// stepping stone for incremental future improvements though. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Matchers.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+namespace {
+
+/// Target SIMD properties:
+///   vectorLength: # packed data elements (viz. vector<16xf32> has length 16)
+///   enableVLAVectorization: enables scalable vectors (viz. ARM SVE)
+///   enableSIMDIndex32: uses 32-bit indices in gather/scatter for efficiency
+struct VL {
+  unsigned vectorLength;
+  bool enableVLAVectorization;
+  bool enableSIMDIndex32;
+};
+
+/// Returns true if 'val' is a constant integer value equal to 'idx'.
+static bool isIntValue(Value val, int64_t idx) {
+  if (auto ival = getConstantIntValue(val))
+    return *ival == idx;
+  return false;
+}
+
+/// Constructs a 1-D vector type of 'vl.vectorLength' elements of type 'etp'.
+/// The single dimension is scalable when VLA vectorization is enabled.
+static VectorType vectorType(VL vl, Type etp) {
+  unsigned numScalableDims = vl.enableVLAVectorization;
+  return VectorType::get(vl.vectorLength, etp, numScalableDims);
+}
+
+/// Constructs the vector type for the element type of memref 'ptr'.
+static VectorType vectorType(VL vl, Value ptr) {
+  return vectorType(vl, ptr.getType().cast<MemRefType>().getElementType());
+}
+
+/// Constructs the vector iteration mask for the vector iteration starting at
+/// 'iv' of a loop running from 'lo' to 'hi' with vector step 'step'.
+static Value genVectorMask(PatternRewriter &rewriter, Location loc, VL vl,
+                           Value iv, Value lo, Value hi, Value step) {
+  VectorType mtp = vectorType(vl, rewriter.getI1Type());
+  // Special case if the vector length evenly divides the trip count (for
+  // example, "for i = 0, 128, 16"). A constant all-true mask is generated
+  // so that all subsequent masked memory operations are immediately folded
+  // into unconditional memory operations. A zero step is excluded so that
+  // the divisibility test never divides by zero.
+  IntegerAttr loInt, hiInt, stepInt;
+  if (matchPattern(lo, m_Constant(&loInt)) &&
+      matchPattern(hi, m_Constant(&hiInt)) &&
+      matchPattern(step, m_Constant(&stepInt))) {
+    int64_t stepVal = stepInt.getInt();
+    if (stepVal != 0 && ((hiInt.getInt() - loInt.getInt()) % stepVal) == 0) {
+      Value trueVal = constantI1(rewriter, loc, true);
+      return rewriter.create<vector::BroadcastOp>(loc, mtp, trueVal);
+    }
+  }
+  // Otherwise, generate a vector mask that avoids overrunning the upperbound
+  // during vector execution: min(step, hi - iv) lanes are enabled. Here we
+  // rely on subsequent loop optimizations to avoid executing the mask in all
+  // iterations, for example, by splitting the loop into an unconditional
+  // vector loop and a scalar cleanup loop.
+  auto min = AffineMap::get(
+      /*dimCount=*/2, /*symbolCount=*/1,
+      {rewriter.getAffineSymbolExpr(0),
+       rewriter.getAffineDimExpr(0) - rewriter.getAffineDimExpr(1)},
+      rewriter.getContext());
+  Value end =
+      rewriter.createOrFold<AffineMinOp>(loc, min, ValueRange{hi, iv, step});
+  return rewriter.create<vector::CreateMaskOp>(loc, mtp, end);
+}
+
+/// Generates a vectorized invariant. Here we rely on subsequent loop
+/// optimizations to hoist the invariant broadcast out of the vector loop.
+static Value genVectorInvariantValue(PatternRewriter &rewriter, VL vl,
+                                     Value val) {
+  VectorType vtp = vectorType(vl, val.getType());
+  return rewriter.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
+}
+
+/// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
+/// where 'lo' denotes the current index and 'hi = lo + vl - 1'.
+static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
+                           Value ptr, ArrayRef<Value> idxs, Value vmask) {
+  VectorType vtp = vectorType(vl, ptr);
+  Value pass = constantZero(rewriter, loc, vtp);
+  // A vector-valued innermost subscript denotes an indirect access, which
+  // becomes a gather with base offset zero in that dimension; otherwise the
+  // access is contiguous and becomes a masked load.
+  if (idxs.back().getType().isa<VectorType>()) {
+    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
+    Value indexVec = idxs.back();
+    scalarArgs.back() = constantIndex(rewriter, loc, 0);
+    return rewriter.create<vector::GatherOp>(loc, vtp, ptr, scalarArgs,
+                                             indexVec, vmask, pass);
+  }
+  return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, idxs, vmask,
+                                               pass);
+}
+
+/// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
+/// where 'lo' denotes the current index and 'hi = lo + vl - 1'.
+static void genVectorStore(PatternRewriter &rewriter, Location loc, Value ptr,
+                           ArrayRef<Value> idxs, Value vmask, Value rhs) {
+  // Same split as genVectorLoad: scatter for an indirect innermost subscript,
+  // masked store for a contiguous one.
+  if (idxs.back().getType().isa<VectorType>()) {
+    SmallVector<Value> scalarArgs(idxs.begin(), idxs.end());
+    Value indexVec = idxs.back();
+    scalarArgs.back() = constantIndex(rewriter, loc, 0);
+    rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec, vmask,
+                                       rhs);
+    return;
+  }
+  rewriter.create<vector::MaskedStoreOp>(loc, ptr, idxs, vmask, rhs);
+}
+
+/// Maps the accumulation operation 'def' of a reduction to its combining
+/// kind. Only the arithmetic operations listed below are supported; callers
+/// must have verified this during analysis, since any other (or a null)
+/// operation is a programming error here.
+static vector::CombiningKind getCombiningKind(Operation *def) {
+  // Subtraction 'r - x' accumulates as a sum of negated terms.
+  if (isa<arith::AddFOp, arith::AddIOp, arith::SubFOp, arith::SubIOp>(def))
+    return vector::CombiningKind::ADD;
+  if (isa<arith::MulFOp, arith::MulIOp>(def))
+    return vector::CombiningKind::MUL;
+  if (isa<arith::AndIOp>(def))
+    return vector::CombiningKind::AND;
+  if (isa<arith::OrIOp>(def))
+    return vector::CombiningKind::OR;
+  if (isa<arith::XOrIOp>(def))
+    return vector::CombiningKind::XOR;
+  llvm_unreachable("unknown reduction kind");
+}
+
+/// Generates an initial value for a vector reduction, following the scheme
+/// given in Chapter 5 of "The Software Vectorization Handbook", where the
+/// initial scalar value is correctly embedded in the vector reduction value,
+/// and a straightforward horizontal reduction will complete the operation.
+/// The value 'r' denotes the initial value of the accumulator. Value 'rd'
+/// denotes the accumulation operation, which is solely used here to determine
+/// the kind of combining reduction (viz. addf -> sum-accumulation).
+static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
+                                VectorType vtp, Value r, Value rd) {
+  vector::CombiningKind kind = getCombiningKind(rd.getDefiningOp());
+  switch (kind) {
+  case vector::CombiningKind::ADD:
+  case vector::CombiningKind::XOR:
+    // Initialize reduction vector to: | 0 | .. | 0 | r |
+    return rewriter.create<vector::InsertElementOp>(
+        loc, r, constantZero(rewriter, loc, vtp),
+        constantIndex(rewriter, loc, 0));
+  case vector::CombiningKind::MUL:
+    // Initialize reduction vector to: | 1 | .. | 1 | r |
+    return rewriter.create<vector::InsertElementOp>(
+        loc, r, constantOne(rewriter, loc, vtp),
+        constantIndex(rewriter, loc, 0));
+  case vector::CombiningKind::AND:
+  case vector::CombiningKind::OR:
+    // Initialize reduction vector to: | r | .. | r | r | (these kinds are
+    // idempotent, so replicating 'r' does not change the result).
+    return rewriter.create<vector::BroadcastOp>(loc, vtp, r);
+  default:
+    break;
+  }
+  llvm_unreachable("unknown reduction kind");
+}
+
+/// Generates final value for a vector reduction (a horizontal reduction of
+/// 'vexp' with the combining kind of accumulation operation 'rd').
+static Value genVectorReducEnd(PatternRewriter &rewriter, Location loc,
+                               Value vexp, Value rd) {
+  vector::CombiningKind kind = getCombiningKind(rd.getDefiningOp());
+  return rewriter.create<vector::ReductionOp>(loc, kind, vexp);
+}
+
+/// This method is called twice to analyze and rewrite the given subscripts.
+/// The first call (!codegen) does the analysis. Then, on success, the second
+/// call (codegen) yields the proper vector form in the output parameter
+/// vector 'idxs'. This mechanism ensures that analysis and rewriting code
+/// stay in sync.
+///
+/// See https://llvm.org/docs/GetElementPtr.html for some background on
+/// the complications described below.
+///
+/// We need to generate a pointer/index load from the sparse storage scheme.
+/// Narrower data types need to be zero extended before casting the value
+/// into the index type used for looping and indexing.
+///
+/// For the scalar case, subscripts simply zero extend narrower indices
+/// into 64-bit values before casting to an index type without a performance
+/// penalty. Indices that already are 64-bit, in theory, cannot express the
+/// full range since the LLVM backend defines addressing in terms of an
+/// unsigned pointer/signed index pair.
+///
+/// For the vector case, the generated vector memory operations (masked
+/// load/store or gather/scatter) only vary their innermost subscript. Hence,
+/// the innermost subscript must either be the loop index itself (contiguous
+/// access a[lo:hi]) or a value loaded from an index array inside the loop
+/// (indirect access a[ind[lo:hi]]), while all outer subscripts must be loop
+/// invariant. Any other subscript form (including a zero-dimensional access)
+/// is rejected.
+static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
+                                VL vl, ValueRange subs, bool codegen,
+                                Value vmask, SmallVectorImpl<Value> &idxs) {
+  // A zero-dimensional access has no subscript that could vary.
+  if (subs.empty())
+    return false;
+  Block *body = &forOp.getRegion().front();
+  Value iv = forOp.getInductionVar();
+  const unsigned rank = subs.size();
+  unsigned d = 0;
+  for (Value sub : subs) {
+    const bool innermost = ++d == rank;
+    // The loop index denotes a unit-stride access, which is only valid in
+    // the innermost dimension (a[i], but not a[i][c]).
+    if (sub == iv) {
+      if (!innermost)
+        return false;
+      if (codegen)
+        idxs.push_back(sub);
+      continue; // success so far
+    }
+    // Values defined outside the loop-body (e.g. block arguments of enclosing
+    // loops or of the function) are invariant and simply pass through in the
+    // outer dimensions. An invariant innermost subscript would address one
+    // and the same element in every iteration, which a vector access cannot
+    // express.
+    if (sub.getParentBlock() != body) {
+      if (innermost)
+        return false;
+      if (codegen)
+        idxs.push_back(sub);
+      continue; // success so far
+    }
+    // Any remaining subscript varies inside the loop-body, which is only
+    // supported as an indirect innermost subscript (gather/scatter).
+    if (!innermost)
+      return false;
+    // Look under the hood of casting.
+    Value cast = sub;
+    while (true) {
+      if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
+        cast = icast->getOperand(0);
+      else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
+        cast = ecast->getOperand(0);
+      else
+        break;
+    }
+    // The index must be loaded inside the loop-body, and the subscripts of
+    // that load (typically ind[i]) must be vectorizable in turn.
+    auto load = cast.getDefiningOp<memref::LoadOp>();
+    if (!load || load->getBlock() != body)
+      return false;
+    SmallVector<Value> idxs2;
+    if (!vectorizeSubscripts(rewriter, forOp, vl, load.getIndices(), codegen,
+                             vmask, idxs2))
+      return false;
+    // Since the index vector is used in subsequent gather/scatter
+    // operations, which effectively define an unsigned pointer + signed
+    // index, we must zero extend the vector to an index width. For 8-bit
+    // and 16-bit values, a 32-bit index width suffices. For 32-bit values,
+    // zero extending the elements into 64-bit loses some performance since
+    // the 32-bit indexed gather/scatter is more efficient than the 64-bit
+    // index variant (if the negative 32-bit index space is unused, the
+    // enableSIMDIndex32 flag can preserve this performance). For 64-bit
+    // values, there is no good way to state that the indices are unsigned,
+    // which creates the potential of incorrect address calculations in the
+    // unlikely case we need such extremely large offsets.
+    if (codegen) {
+      Location loc = forOp.getLoc();
+      Value vload =
+          genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
+      Type etp = vload.getType().cast<VectorType>().getElementType();
+      if (!etp.isa<IndexType>()) {
+        if (etp.getIntOrFloatBitWidth() < 32)
+          vload = rewriter.create<arith::ExtUIOp>(
+              loc, vectorType(vl, rewriter.getI32Type()), vload);
+        else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
+          vload = rewriter.create<arith::ExtUIOp>(
+              loc, vectorType(vl, rewriter.getI64Type()), vload);
+      }
+      idxs.push_back(vload);
+    }
+  }
+  return true;
+}
+
+#define UNAOP(xxx)                                                             \
+  if (isa<xxx>(def)) {                                                         \
+    if (codegen)                                                               \
+      vexp = rewriter.create<xxx>(loc, vx);                                    \
+    return true;                                                               \
+  }
+
+#define BINOP(xxx)                                                             \
+  if (isa<xxx>(def)) {                                                         \
+    if (codegen)                                                               \
+      vexp = rewriter.create<xxx>(loc, vx, vy);                                \
+    return true;                                                               \
+  }
+
+/// This method is called twice to analyze and rewrite the given expression.
+/// The first call (!codegen) does the analysis. Then, on success, the second
+/// call (codegen) yields the proper vector form in the output parameter 'vexp'.
+/// This mechanism ensures that analysis and rewriting code stay in sync.
+static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
+                          Value exp, bool codegen, Value vmask, Value &vexp) {
+  // Only integer, index, and floating-point values can become vector
+  // elements; anything else (e.g. complex values) is rejected.
+  if (!VectorType::isValidElementType(exp.getType()))
+    return false;
+  // A block argument is invariant, with the exception of the loop index
+  // itself: its vector form would be [i, i+1, ...] rather than a broadcast,
+  // which is not supported. Note that the reduction iteration argument also
+  // takes this path; vectorizeStmt relinks it to the vector iteration
+  // argument after codegen.
+  if (auto arg = exp.dyn_cast<BlockArgument>()) {
+    if (arg == forOp.getInductionVar())
+      return false;
+    if (codegen)
+      vexp = genVectorInvariantValue(rewriter, vl, exp);
+    return true;
+  }
+  // Something defined outside the loop-body is invariant as well.
+  Operation *def = exp.getDefiningOp();
+  if (def->getBlock() != &forOp.getRegion().front()) {
+    if (codegen)
+      vexp = genVectorInvariantValue(rewriter, vl, exp);
+    return true;
+  }
+  // Inside loop-body unary and binary operations. Note that it would be
+  // nicer if we could somehow test and build the operations in a more
+  // concise manner than just listing them all (although this way we know
+  // for certain that they can vectorize).
+  Location loc = forOp.getLoc();
+  if (auto load = dyn_cast<memref::LoadOp>(def)) {
+    auto subs = load.getIndices();
+    SmallVector<Value> idxs;
+    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
+      if (codegen)
+        vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
+      return true;
+    }
+  } else if (def->getNumOperands() == 1) {
+    Value vx;
+    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
+                      vx)) {
+      UNAOP(math::AbsFOp)
+      UNAOP(math::AbsIOp)
+      UNAOP(math::CeilOp)
+      UNAOP(math::FloorOp)
+      UNAOP(math::SqrtOp)
+      UNAOP(math::ExpM1Op)
+      UNAOP(math::Log1pOp)
+      UNAOP(math::SinOp)
+      UNAOP(math::TanhOp)
+      UNAOP(arith::NegFOp)
+    }
+  } else if (def->getNumOperands() == 2) {
+    Value vx, vy;
+    if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
+                      vx) &&
+        vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
+                      vy)) {
+      BINOP(arith::MulFOp)
+      BINOP(arith::MulIOp)
+      BINOP(arith::DivFOp)
+      BINOP(arith::DivSIOp)
+      BINOP(arith::DivUIOp)
+      BINOP(arith::AddFOp)
+      BINOP(arith::AddIOp)
+      BINOP(arith::SubFOp)
+      BINOP(arith::SubIOp)
+      BINOP(arith::AndIOp)
+      BINOP(arith::OrIOp)
+      BINOP(arith::XOrIOp)
+    }
+  }
+  return false;
+}
+
+#undef UNAOP
+#undef BINOP
+
+/// Tests whether 'red', the value yielded by a reduction loop, combines the
+/// loop-carried value 'iter' with a single operation whose combining kind is
+/// known to getCombiningKind. Subtraction only qualifies as 'iter - x'. The
+/// iteration argument may not be used anywhere else in the loop-body, since
+/// its vector form only holds partial results.
+static bool isVectorizableReduction(Value red, Value iter) {
+  Operation *def = red.getDefiningOp();
+  if (!def || def->getBlock() != iter.getParentBlock() || !iter.hasOneUse())
+    return false;
+  if (isa<arith::SubFOp, arith::SubIOp>(def))
+    return def->getOperand(0) == iter;
+  if (isa<arith::AddFOp, arith::AddIOp, arith::MulFOp, arith::MulIOp,
+          arith::AndIOp, arith::OrIOp, arith::XOrIOp>(def))
+    return def->getOperand(0) == iter || def->getOperand(1) == iter;
+  return false;
+}
+
+/// This method is called twice to analyze and rewrite the given for-loop.
+/// The first call (!codegen) does the analysis. Then, on success, the second
+/// call (codegen) rewrites the IR into vector form. This mechanism ensures
+/// that analysis and rewriting code stay in sync.
+static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
+                          bool codegen) {
+  Block &block = forOp.getRegion().front();
+  // A loop-body that only consists of the terminator has nothing to
+  // vectorize (and there is no operation preceding the terminator).
+  if (block.getOperations().size() <= 1)
+    return false;
+  Location loc = forOp.getLoc();
+  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
+  auto &last = *++block.rbegin();
+  const bool isReduction = !yield.getResults().empty();
+  scf::ForOp forOpNew;
+
+  // Sparse for-loops either are terminated by a non-empty yield operation
+  // (reduction loop) or otherwise by a store operation (parallel loop).
+  // Reject reductions other than a single, supported combining operation on
+  // the loop-carried value.
+  if (isReduction &&
+      (yield->getNumOperands() != 1 ||
+       !isVectorizableReduction(yield->getOperand(0),
+                                forOp.getRegionIterArg(0))))
+    return false;
+  // Reject loop-bodies with stores other than the one vectorized below: a
+  // reduction loop is rebuilt from the reduction expression only, and in a
+  // parallel loop any other (scalar) store would only execute once per
+  // vector iteration.
+  Operation *vectorizedStore = isReduction ? nullptr : &last;
+  if (forOp
+          ->walk([&](memref::StoreOp store) {
+            return store.getOperation() == vectorizedStore
+                       ? WalkResult::advance()
+                       : WalkResult::interrupt();
+          })
+          .wasInterrupted())
+    return false;
+
+  // Perform initial set up during codegen (we know that the first analysis
+  // pass was successful). For reductions, we need to construct a completely
+  // new for-loop, since the incoming and outgoing reduction type
+  // changes into SIMD form. For stores, we can simply adjust the stride
+  // and insert in the existing for-loop. In both cases, we set up a vector
+  // mask for all operations which takes care of confining vectors to
+  // the original iteration space (later cleanup loops or other
+  // optimizations can take care of those).
+  Value vmask;
+  if (codegen) {
+    Value step = constantIndex(rewriter, loc, vl.vectorLength);
+    if (isReduction) {
+      Value init = forOp.getInitArgs()[0];
+      VectorType vtp = vectorType(vl, init.getType());
+      Value vinit =
+          genVectorReducInit(rewriter, loc, vtp, init, yield->getOperand(0));
+      forOpNew = rewriter.create<scf::ForOp>(
+          loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
+      rewriter.setInsertionPointToStart(forOpNew.getBody());
+    } else {
+      // Notify the rewriter about the in-place modification.
+      rewriter.updateRootInPlace(forOp, [&]() { forOp.setStep(step); });
+      rewriter.setInsertionPoint(yield);
+    }
+    vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
+                          forOp.getLowerBound(), forOp.getUpperBound(), step);
+  }
+
+  if (isReduction) {
+    // Analyze/vectorize reduction.
+    Value redOp = yield->getOperand(0);
+    Value vrhs;
+    if (vectorizeExpr(rewriter, forOp, vl, redOp, codegen, vmask, vrhs)) {
+      if (codegen) {
+        rewriter.create<scf::YieldOp>(loc, vrhs);
+        rewriter.setInsertionPointAfter(forOpNew);
+        Value vres =
+            genVectorReducEnd(rewriter, loc, forOpNew.getResult(0), redOp);
+        // Now do some relinking of the newly generated code to the new loop
+        // (last one is not completely type safe but all bad ones are removed
+        // right away), and replace the old loop by the reduced result.
+        forOp.getInductionVar().replaceAllUsesWith(forOpNew.getInductionVar());
+        forOp.getRegionIterArg(0).replaceAllUsesWith(
+            forOpNew.getRegionIterArg(0));
+        rewriter.replaceOp(forOp, vres);
+      }
+      return true;
+    }
+  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
+    // Analyze/vectorize store operation.
+    auto subs = store.getIndices();
+    SmallVector<Value> idxs;
+    Value rhs = store.getValue();
+    Value vrhs;
+    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
+        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
+      if (codegen) {
+        genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
+        rewriter.eraseOp(store);
+      }
+      return true;
+    }
+  }
+
+  assert(!codegen && "cannot call codegen when analysis failed");
+  return false;
+}
+
+/// Basic for-loop vectorizer.
+struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
+  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
+                bool enableVLAVectorization, bool enableSIMDIndex32)
+      : OpRewritePattern(context),
+        vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}
+
+  LogicalResult matchAndRewrite(scf::ForOp op,
+                                PatternRewriter &rewriter) const override {
+    // Check for single block, unit-stride for-loop that is generated by
+    // sparse compiler, which means no data dependence analysis is required,
+    // and its loop-body is very restricted in form.
+    if (!op.getRegion().hasOneBlock() || !isIntValue(op.getStep(), 1) ||
+        !op->hasAttr(SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()))
+      return failure();
+    // Analyze (!codegen) and rewrite (codegen) loop-body.
+    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
+        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
+      return success();
+    return failure();
+  }
+
+private:
+  const VL vl;
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Public method for populating vectorization rules.
+//===----------------------------------------------------------------------===//
+
+/// Populates the given patterns list with vectorization rules. A vector
+/// length of zero means "do not vectorize" (cf. the `vl` pass option), so
+/// no rules are added in that case.
+void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
+                                               unsigned vectorLength,
+                                               bool enableVLAVectorization,
+                                               bool enableSIMDIndex32) {
+  if (vectorLength == 0)
+    return;
+  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
+                              enableVLAVectorization, enableSIMDIndex32);
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s -sparsification -cse -split-input-file | \
-// RUN: FileCheck %s
+// RUN: FileCheck %s --check-prefix=CHECK-VEC1
+// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=16" -cse -split-input-file | \
+// RUN: FileCheck %s --check-prefix=CHECK-VEC2
+// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=16 enable-simd-index32=true" -cse -split-input-file | \
+// RUN: FileCheck %s --check-prefix=CHECK-VEC3
 
 #DenseVector = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>
 
@@ -13,18 +17,41 @@
 }
 
 //
-// CHECK-LABEL: func @scale_d
-// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[c1024:.*]] = arith.constant 1024 : index
-// CHECK: scf.for %[[i:.*]] = %[[c0]]
to %[[c1024]] step %[[c1]] { -// CHECK: %[[l:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK: %[[m:.*]] = arith.mulf %[[l]], %{{.*}} : f32 -// CHECK: store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32> -// CHECK: } -// CHECK: return +// CHECK-VEC1-LABEL: func @scale_d +// CHECK-VEC1-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-VEC1-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK-VEC1-DAG: %[[c1024:.*]] = arith.constant 1024 : index +// CHECK-VEC1: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] { +// CHECK-VEC1: %[[l:.*]] = memref.load %{{.*}}[%[[i]]] : memref +// CHECK-VEC1: %[[m:.*]] = arith.mulf %[[l]], %{{.*}} : f32 +// CHECK-VEC1: store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32> +// CHECK-VEC1: } +// CHECK-VEC1: return +// +// CHECK-VEC2-LABEL: func @scale_d +// CHECK-VEC2-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-VEC2-DAG: %[[c16:.*]] = arith.constant 16 : index +// CHECK-VEC2-DAG: %[[c1024:.*]] = arith.constant 1024 : index +// CHECK-VEC2: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] { +// CHECK-VEC2: %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref, vector<16xf32> +// CHECK-VEC2: %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32> +// CHECK-VEC2: %[[m:.*]] = arith.mulf %[[r]], %[[b]] : vector<16xf32> +// CHECK-VEC2: vector.store %[[m]], %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32> +// CHECK-VEC2: } +// CHECK-VEC2: return +// +// CHECK-VEC3-LABEL: func @scale_d +// CHECK-VEC3-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-VEC3-DAG: %[[c16:.*]] = arith.constant 16 : index +// CHECK-VEC3-DAG: %[[c1024:.*]] = arith.constant 1024 : index +// CHECK-VEC3: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] { +// CHECK-VEC3: %[[r:.*]] = vector.load %{{.*}}[%[[i]]] : memref, vector<16xf32> +// CHECK-VEC3: %[[b:.*]] = vector.broadcast %{{.*}} : f32 to vector<16xf32> +// CHECK-VEC3: %[[m:.*]] = arith.mulf %[[r]], %[[b]] : vector<16xf32> +// CHECK-VEC3: vector.store %[[m]], %{{.*}}[%[[i]]] : 
memref<1024xf32>, vector<16xf32> +// CHECK-VEC3: } +// CHECK-VEC3: return // - func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor<1024xf32>) -> tensor<1024xf32> { %0 = linalg.generic #trait_scale_d ins(%arga: tensor<1024xf32, #DenseVector>) @@ -55,27 +82,74 @@ } // -// CHECK-LABEL: func @mul_s -// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref -// CHECK: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 -// CHECK: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index -// CHECK: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref -// CHECK: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 -// CHECK: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index -// CHECK: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] { -// CHECK: %[[li:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK: %[[zi:.*]] = arith.extui %[[li]] : i32 to i64 -// CHECK: %[[ci:.*]] = arith.index_cast %[[zi]] : i64 to index -// CHECK: %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK: %[[lb:.*]] = memref.load %{{.*}}[%[[ci]]] : memref<1024xf32> -// CHECK: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32 -// CHECK: store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32> -// CHECK: } -// CHECK: return -// -func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>, %argb: tensor<1024xf32>, %argx: tensor<1024xf32>) -> tensor<1024xf32> { +// CHECK-VEC1-LABEL: func @mul_s +// CHECK-VEC1-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-VEC1-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK-VEC1: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref +// CHECK-VEC1: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 +// CHECK-VEC1: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index +// CHECK-VEC1: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref +// CHECK-VEC1: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 +// CHECK-VEC1: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index +// CHECK-VEC1: 
scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c1]] { +// CHECK-VEC1: %[[li:.*]] = memref.load %{{.*}}[%[[i]]] : memref +// CHECK-VEC1: %[[zi:.*]] = arith.extui %[[li]] : i32 to i64 +// CHECK-VEC1: %[[ci:.*]] = arith.index_cast %[[zi]] : i64 to index +// CHECK-VEC1: %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref +// CHECK-VEC1: %[[lb:.*]] = memref.load %{{.*}}[%[[ci]]] : memref<1024xf32> +// CHECK-VEC1: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32 +// CHECK-VEC1: store %[[m]], %{{.*}}[%[[ci]]] : memref<1024xf32> +// CHECK-VEC1: } +// CHECK-VEC1: return +// +// CHECK-VEC2: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1) +// CHECK-VEC2-LABEL: func @mul_s +// CHECK-VEC2-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-VEC2-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK-VEC2-DAG: %[[c16:.*]] = arith.constant 16 : index +// CHECK-VEC2: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref +// CHECK-VEC2: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 +// CHECK-VEC2: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index +// CHECK-VEC2: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref +// CHECK-VEC2: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 +// CHECK-VEC2: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index +// CHECK-VEC2: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] { +// CHECK-VEC2: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]] +// CHECK-VEC2: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> +// CHECK-VEC2: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xi32> into vector<16xi32> +// CHECK-VEC2: %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64> +// CHECK-VEC2: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC2: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32> 
+// CHECK-VEC2: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32> +// CHECK-VEC2: vector.scatter %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> +// CHECK-VEC2: } +// CHECK-VEC2: return +// +// CHECK-VEC3: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1) +// CHECK-VEC3-LABEL: func @mul_s +// CHECK-VEC3-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-VEC3-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK-VEC3-DAG: %[[c16:.*]] = arith.constant 16 : index +// CHECK-VEC3: %[[p:.*]] = memref.load %{{.*}}[%[[c0]]] : memref +// CHECK-VEC3: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 +// CHECK-VEC3: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index +// CHECK-VEC3: %[[r:.*]] = memref.load %{{.*}}[%[[c1]]] : memref +// CHECK-VEC3: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 +// CHECK-VEC3: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index +// CHECK-VEC3: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] { +// CHECK-VEC3: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]] +// CHECK-VEC3: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> +// CHECK-VEC3: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xi32> into vector<16xi32> +// CHECK-VEC3: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC3: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC3: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32> +// CHECK-VEC3: vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> +// CHECK-VEC3: } +// CHECK-VEC3: return +// +func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>, + %argb: tensor<1024xf32>, + %argx: tensor<1024xf32>) -> 
tensor<1024xf32> { %0 = linalg.generic #trait_mul_s ins(%arga, %argb: tensor<1024xf32, #SparseVector>, tensor<1024xf32>) outs(%argx: tensor<1024xf32>) { @@ -101,20 +175,56 @@ } // -// CHECK-LABEL: func @reduction_d -// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[c1024:.*]] = arith.constant 1024 : index -// CHECK: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] iter_args(%[[red_in:.*]] = %{{.*}}) -> (f32) { -// CHECK: %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK: %[[lb:.*]] = memref.load %{{.*}}[%[[i]]] : memref<1024xf32> -// CHECK: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32 -// CHECK: %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : f32 -// CHECK: scf.yield %[[a]] : f32 -// CHECK: } -// CHECK: return -// -func.func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, %argb: tensor<1024xf32>, %argx: tensor) -> tensor { +// CHECK-VEC1-LABEL: func @reduction_d +// CHECK-VEC1-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-VEC1-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK-VEC1-DAG: %[[c1024:.*]] = arith.constant 1024 : index +// CHECK-VEC1: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c1]] iter_args(%[[red_in:.*]] = %{{.*}}) -> (f32) { +// CHECK-VEC1: %[[la:.*]] = memref.load %{{.*}}[%[[i]]] : memref +// CHECK-VEC1: %[[lb:.*]] = memref.load %{{.*}}[%[[i]]] : memref<1024xf32> +// CHECK-VEC1: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32 +// CHECK-VEC1: %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : f32 +// CHECK-VEC1: scf.yield %[[a]] : f32 +// CHECK-VEC1: } +// CHECK-VEC1: return +// +// CHECK-VEC2-LABEL: func @reduction_d +// CHECK-VEC2-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-VEC2-DAG: %[[c16:.*]] = arith.constant 16 : index +// CHECK-VEC2-DAG: %[[c1024:.*]] = arith.constant 1024 : index +// CHECK-VEC2-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK-VEC2: %[[l:.*]] = memref.load 
%{{.*}}[] : memref +// CHECK-VEC2: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32> +// CHECK-VEC2: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) { +// CHECK-VEC2: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref, vector<16xf32> +// CHECK-VEC2: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32> +// CHECK-VEC2: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32> +// CHECK-VEC2: %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : vector<16xf32> +// CHECK-VEC2: scf.yield %[[a]] : vector<16xf32> +// CHECK-VEC2: } +// CHECK-VEC2: %{{.*}} = vector.reduction , %[[red]] : vector<16xf32> into f32 +// CHECK-VEC2: return +// +// CHECK-VEC3-LABEL: func @reduction_d +// CHECK-VEC3-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-VEC3-DAG: %[[c16:.*]] = arith.constant 16 : index +// CHECK-VEC3-DAG: %[[c1024:.*]] = arith.constant 1024 : index +// CHECK-VEC3-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK-VEC3: %[[l:.*]] = memref.load %{{.*}}[] : memref +// CHECK-VEC3: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32> +// CHECK-VEC3: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) { +// CHECK-VEC3: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref, vector<16xf32> +// CHECK-VEC3: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32> +// CHECK-VEC3: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32> +// CHECK-VEC3: %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : vector<16xf32> +// CHECK-VEC3: scf.yield %[[a]] : vector<16xf32> +// CHECK-VEC3: } +// CHECK-VEC3: %{{.*}} = vector.reduction , %[[red]] : vector<16xf32> into f32 +// CHECK-VEC3: return +// +func.func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, + %argb: tensor<1024xf32>, + %argx: tensor) -> tensor { %0 = 
linalg.generic #trait_reduction_d ins(%arga, %argb: tensor<1024xf32, #DenseVector>, tensor<1024xf32>) outs(%argx: tensor) { @@ -145,31 +255,86 @@ } // -// CHECK-LABEL: func @mul_ds -// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[c512:.*]] = arith.constant 512 : index -// CHECK: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] { -// CHECK: %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 -// CHECK: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index -// CHECK: %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index -// CHECK: %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref -// CHECK: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 -// CHECK: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index -// CHECK: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] { -// CHECK: %[[lj:.*]] = memref.load %{{.*}}[%[[j]]] : memref -// CHECK: %[[zj:.*]] = arith.extui %[[lj]] : i32 to i64 -// CHECK: %[[cj:.*]] = arith.index_cast %[[zj]] : i64 to index -// CHECK: %[[la:.*]] = memref.load %{{.*}}[%[[j]]] : memref -// CHECK: %[[lb:.*]] = memref.load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32> -// CHECK: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32 -// CHECK: store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32> -// CHECK: } -// CHECK: } -// CHECK: return -// -func.func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, %argb: tensor<512x1024xf32>, %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> { +// CHECK-VEC1-LABEL: func @mul_ds +// CHECK-VEC1-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-VEC1-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK-VEC1-DAG: %[[c512:.*]] = arith.constant 512 : index +// CHECK-VEC1: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] { +// CHECK-VEC1: %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref +// CHECK-VEC1: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 +// CHECK-VEC1: %[[q:.*]] = arith.index_cast %[[a]] 
: i64 to index +// CHECK-VEC1: %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index +// CHECK-VEC1: %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref +// CHECK-VEC1: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 +// CHECK-VEC1: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index +// CHECK-VEC1: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c1]] { +// CHECK-VEC1: %[[lj:.*]] = memref.load %{{.*}}[%[[j]]] : memref +// CHECK-VEC1: %[[zj:.*]] = arith.extui %[[lj]] : i32 to i64 +// CHECK-VEC1: %[[cj:.*]] = arith.index_cast %[[zj]] : i64 to index +// CHECK-VEC1: %[[la:.*]] = memref.load %{{.*}}[%[[j]]] : memref +// CHECK-VEC1: %[[lb:.*]] = memref.load %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32> +// CHECK-VEC1: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : f32 +// CHECK-VEC1: store %[[m]], %{{.*}}[%[[i]], %[[cj]]] : memref<512x1024xf32> +// CHECK-VEC1: } +// CHECK-VEC1: } +// CHECK-VEC1: return +// +// CHECK-VEC2: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1) +// CHECK-VEC2-LABEL: func @mul_ds +// CHECK-VEC2-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-VEC2-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK-VEC2-DAG: %[[c16:.*]] = arith.constant 16 : index +// CHECK-VEC2-DAG: %[[c512:.*]] = arith.constant 512 : index +// CHECK-VEC2: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] { +// CHECK-VEC2: %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref +// CHECK-VEC2: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 +// CHECK-VEC2: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index +// CHECK-VEC2: %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index +// CHECK-VEC2: %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref +// CHECK-VEC2: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 +// CHECK-VEC2: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index +// CHECK-VEC2: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] { +// CHECK-VEC2: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[c16]]] +// CHECK-VEC2: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> +// CHECK-VEC2: 
%[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xi32> into vector<16xi32> +// CHECK-VEC2: %[[zj:.*]] = arith.extui %[[lj]] : vector<16xi32> to vector<16xi64> +// CHECK-VEC2: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC2: %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[zj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC2: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32> +// CHECK-VEC2: vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[zj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> +// CHECK-VEC2: } +// CHECK-VEC2: } +// CHECK-VEC2: return +// +// CHECK-VEC3: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1) +// CHECK-VEC3-LABEL: func @mul_ds +// CHECK-VEC3-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-VEC3-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK-VEC3-DAG: %[[c16:.*]] = arith.constant 16 : index +// CHECK-VEC3-DAG: %[[c512:.*]] = arith.constant 512 : index +// CHECK-VEC3: scf.for %[[i:.*]] = %[[c0]] to %[[c512]] step %[[c1]] { +// CHECK-VEC3: %[[p:.*]] = memref.load %{{.*}}[%[[i]]] : memref +// CHECK-VEC3: %[[a:.*]] = arith.extui %[[p]] : i32 to i64 +// CHECK-VEC3: %[[q:.*]] = arith.index_cast %[[a]] : i64 to index +// CHECK-VEC3: %[[a:.*]] = arith.addi %[[i]], %[[c1]] : index +// CHECK-VEC3: %[[r:.*]] = memref.load %{{.*}}[%[[a]]] : memref +// CHECK-VEC3: %[[b:.*]] = arith.extui %[[r]] : i32 to i64 +// CHECK-VEC3: %[[s:.*]] = arith.index_cast %[[b]] : i64 to index +// CHECK-VEC3: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] { +// CHECK-VEC3: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[c16]]] +// CHECK-VEC3: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> +// CHECK-VEC3: %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], 
%{{.*}} : memref, vector<16xi1>, vector<16xi32> into vector<16xi32> +// CHECK-VEC3: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC3: %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC3: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32> +// CHECK-VEC3: vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> +// CHECK-VEC3: } +// CHECK-VEC3: } +// CHECK-VEC3: return +// +func.func @mul_ds(%arga: tensor<512x1024xf32, #SparseMatrix>, + %argb: tensor<512x1024xf32>, + %argx: tensor<512x1024xf32>) -> tensor<512x1024xf32> { %0 = linalg.generic #trait_mul_ds ins(%arga, %argb: tensor<512x1024xf32, #SparseMatrix>, tensor<512x1024xf32>) outs(%argx: tensor<512x1024xf32>) { @@ -194,26 +359,70 @@ } // -// CHECK-LABEL: func @add_dense -// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index -// CHECK: scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] { -// CHECK: %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref -// CHECK: %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index -// CHECK: %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref -// CHECK: scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c1]] { -// CHECK: %[[j:.*]] = memref.load %{{.*}}[%[[jj]]] : memref -// CHECK: %[[x:.*]] = memref.load %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64> -// CHECK: %[[a:.*]] = memref.load %{{.*}}[%[[jj]]] : memref -// CHECK: %[[s:.*]] = arith.addf %[[x]], %[[a]] : f64 -// CHECK: memref.store %[[s]], %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64> -// CHECK: } -// CHECK: } -// CHECK: return +// CHECK-VEC1-LABEL: func @add_dense +// CHECK-VEC1-DAG: %[[c0:.*]] = arith.constant 0 : index 
+// CHECK-VEC1-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK-VEC1-DAG: %[[c32:.*]] = arith.constant 32 : index +// CHECK-VEC1: scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] { +// CHECK-VEC1: %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref +// CHECK-VEC1: %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index +// CHECK-VEC1: %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref +// CHECK-VEC1: scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c1]] { +// CHECK-VEC1: %[[j:.*]] = memref.load %{{.*}}[%[[jj]]] : memref +// CHECK-VEC1: %[[x:.*]] = memref.load %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64> +// CHECK-VEC1: %[[a:.*]] = memref.load %{{.*}}[%[[jj]]] : memref +// CHECK-VEC1: %[[s:.*]] = arith.addf %[[x]], %[[a]] : f64 +// CHECK-VEC1: memref.store %[[s]], %{{.*}}[%[[i1]], %[[j]]] : memref<33x64xf64> +// CHECK-VEC1: } +// CHECK-VEC1: } +// CHECK-VEC1: return +// +// CHECK-VEC2: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1) +// CHECK-VEC2-LABEL: func @add_dense +// CHECK-VEC2-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-VEC2-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK-VEC2-DAG: %[[c16:.*]] = arith.constant 16 : index +// CHECK-VEC2-DAG: %[[c32:.*]] = arith.constant 32 : index +// CHECK-VEC2: scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] { +// CHECK-VEC2: %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref +// CHECK-VEC2: %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index +// CHECK-VEC2: %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref +// CHECK-VEC2: scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c16]] { +// CHECK-VEC2: %[[sub:.*]] = affine.min #[[$map]](%[[hi]], %[[jj]])[%[[c16]]] +// CHECK-VEC2: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> +// CHECK-VEC2: %[[j:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref +// CHECK-VEC2: %[[x:.*]] = vector.gather %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %{{.*}} : memref<33x64xf64> +// CHECK-VEC2: %[[a:.*]] = vector.maskedload %{{.*}}[%[[jj]]], 
%[[mask]], %{{.*}} : memref +// CHECK-VEC2: %[[s:.*]] = arith.addf %[[x]], %[[a]] : vector<16xf64> +// CHECK-VEC2: vector.scatter %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[s]] : memref<33x64xf64> +// CHECK-VEC2: } +// CHECK-VEC2: } +// CHECK-VEC2: return +// +// CHECK-VEC3: #[[$map:.*]] = affine_map<(d0, d1)[s0] -> (16, d0 - d1) +// CHECK-VEC3-LABEL: func @add_dense +// CHECK-VEC3-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-VEC3-DAG: %[[c1:.*]] = arith.constant 1 : index +// CHECK-VEC3-DAG: %[[c16:.*]] = arith.constant 16 : index +// CHECK-VEC3-DAG: %[[c32:.*]] = arith.constant 32 : index +// CHECK-VEC3: scf.for %[[i:.*]] = %[[c0]] to %[[c32]] step %[[c1]] { +// CHECK-VEC3: %[[lo:.*]] = memref.load %{{.*}}[%[[i]]] : memref +// CHECK-VEC3: %[[i1:.*]] = arith.addi %[[i]], %[[c1]] : index +// CHECK-VEC3: %[[hi:.*]] = memref.load %{{.*}}[%[[i1]]] : memref +// CHECK-VEC3: scf.for %[[jj:.*]] = %[[lo]] to %[[hi]] step %[[c16]] { +// CHECK-VEC3: %[[sub:.*]] = affine.min #[[$map]](%[[hi]], %[[jj]])[%[[c16]]] +// CHECK-VEC3: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> +// CHECK-VEC3: %[[j:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref +// CHECK-VEC3: %[[x:.*]] = vector.gather %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %{{.*}} : memref<33x64xf64> +// CHECK-VEC3: %[[a:.*]] = vector.maskedload %{{.*}}[%[[jj]]], %[[mask]], %{{.*}} : memref +// CHECK-VEC3: %[[s:.*]] = arith.addf %[[x]], %[[a]] : vector<16xf64> +// CHECK-VEC3: vector.scatter %{{.*}}[%[[i1]], %[[c0]]] [%[[j]]], %[[mask]], %[[s]] : memref<33x64xf64> +// CHECK-VEC3: } +// CHECK-VEC3: } +// CHECK-VEC3: return // func.func @add_dense(%arga: tensor<32x64xf64, #SparseMatrix>, - %argx: tensor<33x64xf64>) -> tensor<33x64xf64> { + %argx: tensor<33x64xf64>) -> tensor<33x64xf64> { %0 = linalg.generic #trait_affine ins(%arga: tensor<32x64xf64, #SparseMatrix>) outs(%argx: tensor<33x64xf64>) { diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel 
b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2175,6 +2175,7 @@ ":LinalgDialect", ":LinalgTransforms", ":LinalgUtils", + ":MathDialect", ":MemRefDialect", ":Pass", ":SCFDialect", @@ -2186,6 +2187,7 @@ ":Support", ":TensorDialect", ":Transforms", + ":VectorDialect", "//llvm:Support", ], )