diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -997,7 +997,6 @@ ArrayRef extraTids, ArrayRef extraDims) { Location loc = op.getLoc(); - auto iteratorTypes = op.getIteratorTypesArray(); bool isSparse = isCompressedDLT(merger.getDimLevelType(tid, idx)) || isSingletonDLT(merger.getDimLevelType(tid, idx)); bool isParallel = isParallelFor(codegen, isOuter, isSparse); @@ -1189,6 +1188,42 @@ return false; } +static void genConstantDenseAddressFromLevel(CodeGen &codegen, + OpBuilder &builder, + linalg::GenericOp op, unsigned tid, + unsigned lvl) { + // TODO: Handle affine expressions on the output tensor. + assert(tid < op.getNumDpsInputs()); + + OpOperand *input = op.getDpsInputOperands()[tid]; + ArrayRef affines = op.getMatchingIndexingMap(input).getResults(); + auto enc = getSparseTensorEncoding(input->get().getType()); + if (enc) { + for (unsigned i = lvl, e = affines.size(); i < e; i++) { + AffineExpr affine = affines[toOrigDim(enc, i)]; + if (isDenseDLT(getDimLevelType(enc, i)) && + affine.isa()) { + codegen.loopEmitter.genDenseAffineAddressAtCurLevel( + builder, op.getLoc(), input->getOperandNumber(), i, affine); + } else { + // Stop at the first level that is non-dense or non-constant. + return; + } + } + } +} + +static void genInitConstantDenseAddress(CodeGen &codegen, + RewriterBase &rewriter, + linalg::GenericOp op) { + // We can generate addresses for constant affine expressions before any loops + // starting from the first level, as they do not depend on anything. + // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two + // levels can be determined before loops.
+ for (unsigned tid = 0, e = op.getNumDpsInputs(); tid < e; tid++) + genConstantDenseAddressFromLevel(codegen, rewriter, op, tid, 0); +} + static void translateBitsToTidDimPairs( Merger &merger, CodeGen &codegen, linalg::GenericOp op, unsigned li, unsigned idx, SmallVectorImpl &condTids, @@ -1244,30 +1279,21 @@ if (exp.isa() || !isDenseDLT(getDimLevelType(enc, i))) continue; - // Constant affine expressions on dense level required to be generated - // when - // 1. The previous level is an (at-level) invariant compound dense - // affine (with no corresponding loop idx); or - // 2. The previous level is being generated right now. - if (exp.isa()) { - // TODO: Should we come up with a more adhersive way to handle - // constant expression? We now requires two (somehow ad-hoc) code for - // it. - assert(false && "do not support constant affine"); - } - - bool atLevel = false; - if (isInvariantAffine(codegen, exp, idx, atLevel) && atLevel) { - // If the compound affine is invariant and we are right at the - // level. We need to generate the address according to the affine - // expression. This is also the best place we can do it to avoid - // putting it inside inner loops. - // NOTE: It assumes that the levels of the input tensor are - // initialized in order, another more admissible approach might be - // accepting out-of-order access between consecutive dense levels. - affineTids.push_back(tid); - affineDims.push_back(i); - exps.push_back(exp); + // Constant affine expressions are handled in genLoop. + if (!exp.isa()) { + bool atLevel = false; + if (isInvariantAffine(codegen, exp, idx, atLevel) && atLevel) { + // If the compound affine is invariant and we are right at the + // level. We need to generate the address according to the affine + // expression. This is also the best place we can do it to avoid + // putting it inside inner loops.
+ // NOTE: It assumes that the levels of the input tensor are + // initialized in order, another more admissible approach might be + // accepting out-of-order access between consecutive dense levels. + affineTids.push_back(tid); + affineDims.push_back(i); + exps.push_back(exp); + } } } } @@ -1310,6 +1336,17 @@ codegen.loopEmitter.genDenseAffineAddressAtCurLevel(builder, op.getLoc(), tid, dim, exp); } + + // Until now, we have entered every (tid, dim) pair in {cond, extra, + // affine}Tids/Dims. The addresses of the upcoming levels which are dependent + // on constant affine expressions may now be determined. + auto allTids = llvm::concat(condTids, extraTids, affineTids); + auto allDims = llvm::concat(condDims, extraDims, affineDims); + for (auto [tid, dim] : llvm::zip(allTids, allDims)) { + if (tid != merger.getOutTensorID()) + genConstantDenseAddressFromLevel(codegen, builder, op, tid, dim + 1); + } + return loop; } @@ -1437,7 +1474,6 @@ //===----------------------------------------------------------------------===// namespace { - /// Sparse rewriting rule for generic Lingalg operation.
struct GenericOpSparsifier : public OpRewritePattern { public: @@ -1505,6 +1541,7 @@ CodeGen codegen(options, op.getContext(), tensors, numTensors, numLoops, sparseOut, outerParNest, topSort); genBuffers(merger, codegen, rewriter, op); + genInitConstantDenseAddress(codegen, rewriter, op); genStmt(merger, codegen, rewriter, op, exp, 0); genResult(merger, codegen, rewriter, op); return success(); diff --git a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_affine.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_affine.mlir @@ -225,3 +225,64 @@ } -> tensor<32x16xf64> return %0 : tensor<32x16xf64> } + +#trait5 = { + indexing_maps = [ + affine_map<(i,j) -> (2,j)>, // a: row fixed to constant 2 + affine_map<(i,j) -> (i,3)>, // b: column fixed to constant 3 + affine_map<(i,j) -> (i,j)> // x (out) + ], + iterator_types = ["parallel","parallel"], + doc = "x(i,j) += a(2,j) * b(i,3)" +} + +// CHECK-LABEL: func.func @mul_const_affine_dense_dim_2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<34x16xf64, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<32x16xf64>) -> tensor<32x16xf64> { +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 19 : index +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 3 : index +// CHECK: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]] {dimension = 1 : index} : tensor<34x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]] {dimension = 1 : index} : tensor<34x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<34x16xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 0 : index} :
tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 0 : index} : tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<32x19xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : memref<32x16xf64> +// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref +// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_6]]] : memref +// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_15]] to %[[VAL_16]] step %[[VAL_6]] { +// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_17]]] : memref +// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_17]], %[[VAL_3]] : index +// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_19]], %[[VAL_7]] : index +// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_20]]] : memref +// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_7]]] : memref +// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_22]] to %[[VAL_23]] step %[[VAL_6]] { +// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_18]], %[[VAL_25]]] : memref<32x16xf64> +// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_28:.*]] = arith.mulf %[[VAL_27]], %[[VAL_21]] : f64 +// CHECK: %[[VAL_29:.*]] = arith.addf %[[VAL_26]], %[[VAL_28]] : f64 +// CHECK: memref.store %[[VAL_29]], %[[VAL_14]]{{\[}}%[[VAL_18]], %[[VAL_25]]] : memref<32x16xf64> +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_30:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<32x16xf64> +// CHECK: return %[[VAL_30]] : tensor<32x16xf64> +// CHECK: } +func.func @mul_const_affine_dense_dim_2d(%arga: tensor<34x16xf64, #CSR>, + %argb: tensor<32x19xf64, #Row>, + %argx:
tensor<32x16xf64>) -> tensor<32x16xf64> { + %0 = linalg.generic #trait5 + ins(%arga, %argb: tensor<34x16xf64, #CSR>, tensor<32x19xf64, #Row>) + outs(%argx: tensor<32x16xf64>) { + ^bb(%a: f64, %b: f64, %x: f64): + %0 = arith.mulf %a, %b : f64 + %1 = arith.addf %x, %0 : f64 + linalg.yield %1 : f64 + } -> tensor<32x16xf64> + return %0 : tensor<32x16xf64> +}