diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -156,8 +156,12 @@ Merger(unsigned t, unsigned l) : outTensor(t - 1), syntheticTensor(t), numTensors(t + 1), numLoops(l), hasSparseOut(false), - dimTypes(t + 1, std::vector(l, DimLevelType::Undef)), - loopIdxToDim(t + 1, std::vector>(l, llvm::None)) {} + dimTypes(numTensors, + std::vector(numLoops, DimLevelType::Undef)), + loopIdxToDim(numTensors, + std::vector>(numLoops, llvm::None)), + dimToLoopIdx(numTensors, + std::vector>(numLoops, llvm::None)) {} /// Adds a tensor expression. Returns its index. unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value(), @@ -258,6 +262,11 @@ return getDimLevelType(tensor(b), index(b)); } + Optional getLoopIdx(unsigned t, unsigned dim) const { + assert(t < numTensors && dim < numLoops); + return dimToLoopIdx[t][dim]; + } + /// Gets the dimension number of the the `t`th tensor on `i`th loop. Optional getDimNum(unsigned t, unsigned i) const { assert(t < numTensors && i < numLoops); @@ -276,6 +285,8 @@ assert(isValidDLT(dlt)); dimTypes[t][i] = dlt; loopIdxToDim[t][i] = dim; + assert(dim < numLoops); + dimToLoopIdx[t][dim] = i; } // Iterates the bits of a lattice, for each set bit, converts it into the @@ -341,6 +352,8 @@ std::vector> dimTypes; // Map that converts pair to the corresponding dimension. std::vector>> loopIdxToDim; + // Map that converts pair to the corresponding loop id. 
+ std::vector>> dimToLoopIdx; llvm::SmallVector tensorExps; llvm::SmallVector latPoints; llvm::SmallVector> latSets; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -338,7 +338,6 @@ // loopEmiter.exitCurrentLoop(); // exit i //===----------------------------------------------------------------------===// -// TODO: Sparsification should also rely on this class to generate loops. class SparseTensorLoopEmitter { public: /// Optional callback function to setup dense output tensors when @@ -407,6 +406,10 @@ ArrayRef extraTids = {}, ArrayRef extraDims = {}); + void genDenseAffineAddressAtCurLevel(OpBuilder &builder, Location loc, + size_t tid, size_t dim, + AffineExpr affine); + /// Emits a co-iteration loop over a set of tensors. Operation *enterCoIterationOverTensorsAtDims( OpBuilder &builder, Location loc, ArrayRef tids, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -329,6 +329,13 @@ return loop; } +void SparseTensorLoopEmitter::genDenseAffineAddressAtCurLevel( + OpBuilder &builder, Location loc, size_t tid, size_t dim, + AffineExpr affine) { + Value affineV = genAffine(builder, affine, loc); + pidxs[tid][dim] = genAddress(builder, loc, tid, dim, affineV); +} + Operation *SparseTensorLoopEmitter::enterCoIterationOverTensorsAtDims( OpBuilder &builder, Location loc, ArrayRef tids, ArrayRef dims, bool needsUniv, MutableArrayRef reduc, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ 
b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -27,6 +27,7 @@ #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/SparseTensor/Utils/Merger.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/TensorEncoding.h" #include "llvm/ADT/SmallBitVector.h" @@ -98,6 +99,27 @@ } }; +class AffineDimFinder : public AffineExprVisitor { +public: + explicit AffineDimFinder(linalg::GenericOp op) + : iterTypes(op.getIteratorTypesArray()) {} + void visitDimExpr(AffineDimExpr expr) { + if (pickedDim == nullptr || pickIterType == iterTypes[expr.getPosition()]) { + pickedDim = expr; + } + } + + void setPickedIterType(utils::IteratorType iterType) { + pickIterType = iterType; + } + + AffineDimExpr getDimExpr() const { return pickedDim.cast(); } + +private: + AffineExpr pickedDim; + utils::IteratorType pickIterType; + SmallVector iterTypes; +}; } // namespace //===----------------------------------------------------------------------===// @@ -219,25 +241,44 @@ /// Helper method to add all constraints from the indices in one affine /// expression before all indices in the other affine expression. For /// example i0+i1 < i2+i3+1 yields i0 i0 < fidx, i1 < fidx. +/// The affine expression `b` is empty iff `tidx` have a value, leading to +/// tidx < a = (i0 + i1) => tidx < i0, tidx < i1. static void addAffineOrderings(std::vector> &adjM, std::vector &inDegree, AffineExpr a, - AffineExpr b, unsigned fidx) { - switch (a.getKind()) { - case AffineExprKind::DimId: { - unsigned idx = a.cast().getPosition(); - if (b) - addAffineOrderings(adjM, inDegree, b, AffineExpr(), idx); - else if (!adjM[fidx][idx]) { - adjM[fidx][idx] = true; - inDegree[idx]++; + AffineExpr b, Optional fidx, + Optional tidx) { + if (!a && !b) { + // Recursion leaf. 
+ assert(fidx && tidx); + unsigned f = *fidx, t = *tidx; + if (!adjM[f][t]) { + adjM[f][t] = true; + inDegree[t]++; } + return; + } + auto toExpand = a ? a : b; + switch (toExpand.getKind()) { + case AffineExprKind::DimId: { + auto idx = toExpand.cast().getPosition(); + if (toExpand == a) + addAffineOrderings(adjM, inDegree, AffineExpr(), b, idx, tidx); + else // toExpand == b + addAffineOrderings(adjM, inDegree, a, AffineExpr(), fidx, idx); break; } case AffineExprKind::Add: case AffineExprKind::Mul: { - auto binOp = a.cast(); - addAffineOrderings(adjM, inDegree, binOp.getLHS(), b, fidx); - addAffineOrderings(adjM, inDegree, binOp.getRHS(), b, fidx); + auto binOp = toExpand.cast(); + if (toExpand == a) { + addAffineOrderings(adjM, inDegree, binOp.getLHS(), b, fidx, tidx); + addAffineOrderings(adjM, inDegree, binOp.getRHS(), b, fidx, tidx); + } else { + addAffineOrderings(adjM, inDegree, a, binOp.getLHS(), fidx, tidx); + addAffineOrderings(adjM, inDegree, a, binOp.getRHS(), fidx, tidx); + } break; } default: @@ -275,10 +316,53 @@ // by default) puts an ordering constraint on the loop indices. For // example, the tensor expresion A_ijk forces the ordering i < j < k // on the loop indices if no explicit dimension ordering is given. 
- for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) { - AffineExpr f = map.getResult(toOrigDim(enc, d - 1)); - AffineExpr t = map.getResult(toOrigDim(enc, d)); - addAffineOrderings(adjM, inDegree, f, t, 0); + for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { + AffineExpr ta = map.getResult(toOrigDim(enc, d)); + Optional tldx = merger.getLoopIdx(t.getOperandNumber(), d); + + if (d > 0) { + AffineExpr fa = map.getResult(toOrigDim(enc, d - 1)); + Optional fldx = + merger.getLoopIdx(t.getOperandNumber(), d - 1); + + if (!(mask & SortMask::kIncludeDense) && !tldx) { + AffineDimFinder finder(op); + // e.g., for [dense, dense] -> (d0 + d1, d2 + d3) + // It is totally fine to have loop sequence d0->d2->d1->d3 instead of + // requiring d0 < d2, d1 < d2, d0 < d3, d1 < d3. + // We use a heuristic here to only pick one dim expression from each + // compound affine expression to establish the order between two dense + // dimensions. + // NOTE: The ordering can only be loosened when the destination level is + // dense, for [dense, sparse] -> (d0 + d1, d2), we still require both + // d0 < d2 and d1 < d2 to ensure correct ordering (i.e., no ordering + // like d0->d2->d1). + // TODO: this is obviously a suboptimal solution. + if (!fldx && fa.isa()) { + assert(isDenseDLT(getDimLevelType(enc, d - 1)) && + !fa.isa()); + // Heuristic: we prefer parallel loop for lhs to reduce the chance + // we add reduce < parallel ordering. + finder.setPickedIterType(utils::IteratorType::parallel); + finder.walkPostOrder(fa); + fa = finder.getDimExpr(); + fldx = finder.getDimExpr().getPosition(); + } + if (!ta.isa()) { + // Dense compound affine + assert(isDenseDLT(getDimLevelType(enc, d)) && + !ta.isa()); + // Heuristic: we prefer reduction loop for rhs to reduce the chance + // of adding reduce < parallel ordering. 
+ finder.setPickedIterType(utils::IteratorType::reduction); + finder.walkPostOrder(ta); + ta = finder.getDimExpr(); + tldx = finder.getDimExpr().getPosition(); + } + } + + addAffineOrderings(adjM, inDegree, fa, ta, fldx, tldx); + } } // Push unrelated loops into sparse iteration space, so these // will be skipped more often. @@ -505,9 +589,6 @@ unsigned rank = map.getNumResults(); if (enc) { // Note that currently, all sparse subscripts are simple. - // TODO: accept affine too? - assert(map.getResult(toOrigDim(enc, rank - 1)).getKind() == - AffineExprKind::DimId); Value pidx = codegen.loopEmitter.getPidxs()[tensor].back(); assert(pidx); args.push_back(pidx); // position index @@ -750,8 +831,11 @@ switch (a.getKind()) { case AffineExprKind::DimId: { unsigned idx = a.cast().getPosition(); - if (idx == ldx) + if (idx == ldx) { atLevel = true; + // Must be invariant if we are at the level. + return true; + } return codegen.getLoopIdxValue(idx) != nullptr; // no longer in play? } case AffineExprKind::Add: @@ -1090,12 +1174,13 @@ return false; } -static void translateBitsToTidDimPairs(Merger &merger, CodeGen &codegen, - unsigned li, unsigned idx, - SmallVectorImpl &condTids, - SmallVectorImpl &condDims, - SmallVectorImpl &extraTids, - SmallVectorImpl &extraDims) { +static void translateBitsToTidDimPairs( + Merger &merger, CodeGen &codegen, linalg::GenericOp op, unsigned li, + unsigned idx, SmallVectorImpl &condTids, + SmallVectorImpl &condDims, SmallVectorImpl &extraTids, + SmallVectorImpl &extraDims, SmallVectorImpl &affineTids, + SmallVectorImpl &affineDims, SmallVectorImpl &exps) { + const BitVector &all = merger.lat(li).bits; const BitVector &simple = merger.lat(li).simple; @@ -1123,6 +1208,53 @@ // TODO: get rid of extraTids and extraDims. extraTids.push_back(tid); extraDims.push_back(dim.value()); + } else { + assert(isUndefDLT(dlt)); + if (tid >= op.getNumDpsInputs()) + // We only handle affine expression on input tensors (for now). 
+ return; + OpOperand *operand = &op->getOpOperand(tid); + auto enc = getSparseTensorEncoding(operand->get().getType()); + // Non-annotated dense tensors require no special handling. + if (!enc) + return; + + ArrayRef affines = + op.getMatchingIndexingMap(operand).getResults(); + assert(affines.size() == enc.getDimLevelType().size()); + for (unsigned i = 0, e = affines.size(); i < e; i++) { + AffineExpr exp = affines[toOrigDim(enc, i)]; + // Skip simple affine expressions and non-dense dimensions (which have + // their own filter loop). + if (exp.isa() || !isDenseDLT(getDimLevelType(enc, i))) + continue; + + // Constant affine expressions on dense levels are required to be generated + // when + // 1. The previous level is an (at-level) invariant compound dense + // affine (with no corresponding loop idx); or + // 2. The previous level is being generated right now. + if (exp.isa()) { + // TODO: Should we come up with a more cohesive way to handle + // constant expressions? We now require two (somewhat ad-hoc) pieces of + // code for it. + assert(false && "do not support constant affine"); + } + + bool atLevel = false; + if (isInvariantAffine(codegen, exp, idx, atLevel) && atLevel) { + // If the compound affine is invariant and we are right at the + // level, we need to generate the address according to the affine + // expression. This is also the best place we can do it to avoid + // putting it inside inner loops. + // NOTE: It assumes that the levels of the input tensor are + // initialized in order; another more admissible approach might be + // accepting out-of-order access between consecutive dense levels. + affineTids.push_back(tid); + affineDims.push_back(i); + exps.push_back(exp); + } + } } }); @@ -1146,12 +1278,23 @@ // The set of (dense) tensors that is optimized from condition, yet still // need extra locals to iterate on them. 
SmallVector extraTids, extraDims; - - translateBitsToTidDimPairs(merger, codegen, li, codegen.topSort[at], condTids, - condDims, extraTids, extraDims); + // The set of dense tensors with non-trivial affine expression that just + // becomes invariant and the address shall now be generated at the current + // level. + SmallVector affineTids, affineDims; + SmallVector affines; + + translateBitsToTidDimPairs(merger, codegen, op, li, codegen.topSort[at], + condTids, condDims, extraTids, extraDims, + affineTids, affineDims, affines); // Emit the for/while-loop control. Operation *loop = genLoop(merger, codegen, builder, op, at, needsUniv, condTids, condDims, extraTids, extraDims); + + for (auto [tid, dim, exp] : llvm::zip(affineTids, affineDims, affines)) { + codegen.loopEmitter.genDenseAffineAddressAtCurLevel(builder, op.getLoc(), + tid, dim, exp); + } return loop; } diff --git a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir @@ -3,6 +3,7 @@ #SpVec = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> #CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }> +#Row = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "dense" ] }> #trait1 = { indexing_maps = [ @@ -160,3 +161,67 @@ } -> tensor<32x16xf64> return %0 : tensor<32x16xf64> } + +#trait4 = { + indexing_maps = [ + affine_map<(i,j) -> (i+2,j)>, // a + affine_map<(i,j) -> (i,j+3)>, // b + affine_map<(i,j) -> (i,j)> // x (out) + ], + iterator_types = ["parallel","parallel"], + doc = "x(i,j) += a(i+2,j) * b(i,j+3)" +} + +// CHECK-LABEL: func.func @mul_affine_dense_dim_2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<34x16xf64, #sparse_tensor.encoding +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<32x16xf64>) -> tensor<32x16xf64> { +// CHECK-DAG: 
%[[VAL_3:.*]] = arith.constant 19 : index +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<34x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<34x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<34x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 0 : index} : tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf64> +// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref +// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_15]] to %[[VAL_16]] step %[[VAL_5]] { +// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_17]]] : memref +// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_18]], %[[VAL_6]] : index +// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref +// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index +// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_21]]] : memref +// CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_5]] { +// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref +// CHECK: %[[VAL_25:.*]] = 
arith.addi %[[VAL_24]], %[[VAL_7]] : index +// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_17]], %[[VAL_3]] : index +// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_25]] : index +// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_18]], %[[VAL_24]]] : memref<32x16xf64> +// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_23]]] : memref +// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_27]]] : memref +// CHECK: %[[VAL_31:.*]] = arith.mulf %[[VAL_29]], %[[VAL_30]] : f64 +// CHECK: %[[VAL_32:.*]] = arith.addf %[[VAL_28]], %[[VAL_31]] : f64 +// CHECK: memref.store %[[VAL_32]], %[[VAL_14]]{{\[}}%[[VAL_18]], %[[VAL_24]]] : memref<32x16xf64> +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_33:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<32x16xf64> +// CHECK: return %[[VAL_33]] : tensor<32x16xf64> +// CHECK: } +func.func @mul_affine_dense_dim_2d(%arga: tensor<34x16xf64, #CSR>, + %argb: tensor<32x19xf64, #Row>, + %argx: tensor<32x16xf64>) -> tensor<32x16xf64> { + %0 = linalg.generic #trait4 + ins(%arga, %argb: tensor<34x16xf64, #CSR>, tensor<32x19xf64, #Row>) + outs(%argx: tensor<32x16xf64>) { + ^bb(%a: f64, %b: f64, %x: f64): + %0 = arith.mulf %a, %b : f64 + %1 = arith.addf %x, %0 : f64 + linalg.yield %1 : f64 + } -> tensor<32x16xf64> + return %0 : tensor<32x16xf64> +}