diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp @@ -799,11 +799,11 @@ Value red = codegen.redVal; if (!red) return; + assert(codegen.curVecLength == 1); codegen.redVal = merger.exp(codegen.redExp).val = Value(); // end chain unsigned lhs = op.getNumShapedOperands() - 1; if (red.getType().isa<VectorType>()) { // TODO: assumes + reductions for now - codegen.curVecLength = 1; Value ld = genTensorLoad(merger, codegen, rewriter, op, codegen.redExp); red = rewriter.create<vector::ReductionOp>( op.getLoc(), ld.getType(), rewriter.getStringAttr("add"), red, ld); @@ -947,6 +947,25 @@ llvm_unreachable("unexpected parallelization strategy"); } +/// Checks unit strides for dense tensors. The iteration graph may have ignored +/// dense access patterns in order to avoid cycles (sparse access patterns are +/// always placed innermost), but that means dense access has become strided. +/// For now, we reject vectorization of such cases. +/// TODO: implement strided load/stores on dense arrays +static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, +                             unsigned idx) { +  unsigned numTensors = op.getNumShapedOperands(); +  for (unsigned t = 0; t < numTensors; t++) { +    if (!merger.isSparseTensor(t) && !linkedSparse(op, t)) { +      auto map = op.getIndexingMap(t); +      unsigned r = map.getNumResults(); +      if (r && map.getDimPosition(r - 1) != idx) +        return false; +    } +  } +  return true; +} + /// Generates a for-loop on a single index. 
static Operation *genFor(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, @@ -958,7 +977,8 @@ auto iteratorTypes = op.iterator_types().getValue(); bool isReduction = linalg::isReductionIteratorType(iteratorTypes[idx]); bool isSparse = merger.isDim(fb, Dim::kSparse); - bool isVector = isVectorFor(codegen, isInner, isSparse); + bool isVector = isVectorFor(codegen, isInner, isSparse) && + denseUnitStrides(merger, op, idx); bool isParallel = isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); @@ -1279,10 +1299,10 @@ } // Wrap-up loop sequence. + codegen.curVecLength = 1; genReductionEnd(merger, codegen, rewriter, op); genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false); codegen.loops[idx] = Value(); - codegen.curVecLength = 1; } namespace {