diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -28,6 +28,7 @@
   // Leaf.
   kTensor = 0,
   kInvariant,
+  kIndex,
   // Unary operations.
   kAbsF,
   kCeilF,
@@ -42,6 +43,7 @@
   kCastUF, // unsigned
   kCastS,  // signed
   kCastU,  // unsigned
+  kCastIdx,
   kTruncI,
   kBitCast,
   // Binary operations.
@@ -79,6 +81,9 @@
   /// Expressions representing tensors simply have a tensor number.
   unsigned tensor;

+  /// Indices hold the index number.
+  unsigned index;
+
   /// Tensor operations hold the indices of their children.
   Children children;
 };
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//

 #include "CodegenUtils.h"
+
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
@@ -870,6 +871,13 @@
   return rewriter.create<arith::AddIOp>(loc, mul, i);
 }

+/// Generates an index value.
+static Value genIndexValue(Merger &merger, CodeGen &codegen, unsigned exp) {
+  assert(codegen.curVecLength == 1); // TODO: implement vectorization!
+  unsigned idx = merger.exp(exp).index;
+  return codegen.loops[idx];
+}
+
 /// Recursively generates tensor expression.
 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                     linalg::GenericOp op, unsigned exp) {
@@ -880,6 +888,8 @@
     return genTensorLoad(merger, codegen, rewriter, op, exp);
   if (merger.exp(exp).kind == Kind::kInvariant)
     return genInvariantValue(merger, codegen, rewriter, exp);
+  if (merger.exp(exp).kind == Kind::kIndex)
+    return genIndexValue(merger, codegen, exp);
   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
   return merger.buildExp(rewriter, loc, exp, v0, v1);
@@ -947,7 +957,8 @@
       merger.exp(exp).val =
           atStart ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
     }
-  } else if (merger.exp(exp).kind != Kind::kInvariant) {
+  } else if (merger.exp(exp).kind != Kind::kInvariant &&
+             merger.exp(exp).kind != Kind::kIndex) {
     // Traverse into the binary operations. Note that we only hoist
     // tensor loads, since subsequent MLIR/LLVM passes know how to
     // deal with all other kinds of derived loop invariants.
@@ -1039,7 +1050,12 @@
 /// Returns vectorization strategy. Any implicit inner loop in the Linalg
 /// operation is a candidate. Whether it is actually converted to SIMD code
 /// depends on the requested strategy.
-static bool isVectorFor(CodeGen &codegen, bool isInner, bool isSparse) {
+static bool isVectorFor(CodeGen &codegen, bool isInner, bool isReduction,
+                        bool isSparse) {
+  // Reject vectorization of sparse output, unless innermost is reduction.
+  if (codegen.sparseOut && !isReduction)
+    return false;
+  // Inspect strategy.
   switch (codegen.options.vectorizationStrategy) {
   case SparseVectorizationStrategy::kNone:
     return false;
@@ -1056,6 +1072,10 @@
 /// to a parallel operation depends on the requested strategy.
 static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
                           bool isSparse, bool isVector) {
+  // Reject parallelization of sparse output.
+  if (codegen.sparseOut)
+    return false;
+  // Inspect strategy.
   switch (codegen.options.parallelizationStrategy) {
   case SparseParallelizationStrategy::kNone:
     return false;
@@ -1107,11 +1127,9 @@
   auto iteratorTypes = op.iterator_types().getValue();
   bool isReduction = isReductionIterator(iteratorTypes[idx]);
   bool isSparse = merger.isDim(fb, Dim::kSparse);
-  bool isVector = !codegen.sparseOut &&
-                  isVectorFor(codegen, isInner, isSparse) &&
+  bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) &&
                   denseUnitStrides(merger, op, idx);
   bool isParallel =
-      !codegen.sparseOut &&
       isParallelFor(codegen, isOuter, isReduction, isSparse, isVector);

   // Prepare vector length.
@@ -1626,6 +1644,7 @@
   LogicalResult matchAndRewrite(linalg::GenericOp op,
                                 PatternRewriter &rewriter) const override {
+
     // Detects sparse annotations and translate the per-dimension sparsity
     // information for all tensors to loop indices in the kernel.
     assert(op.getNumOutputs() == 1);
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -29,6 +29,10 @@
   case kInvariant:
     assert(x == -1u && y == -1u && v);
     break;
+  case kIndex:
+    assert(x != -1u && y == -1u && !v);
+    index = x;
+    break;
   case kAbsF:
   case kCeilF:
   case kFloorF:
@@ -46,6 +50,7 @@
   case kCastUF:
   case kCastS:
   case kCastU:
+  case kCastIdx:
   case kTruncI:
   case kBitCast:
     assert(x != -1u && y == -1u && v);
@@ -230,6 +235,7 @@
   case kCastUF:
   case kCastS:
   case kCastU:
+  case kCastIdx:
   case kTruncI:
   case kBitCast:
     return isSingleCondition(t, tensorExps[e].children.e0);
@@ -273,6 +279,8 @@
     return "tensor";
   case kInvariant:
     return "invariant";
+  case kIndex:
+    return "index";
   case kAbsF:
     return "abs";
   case kCeilF:
@@ -291,6 +299,7 @@
   case kCastUF:
   case kCastS:
   case kCastU:
+  case kCastIdx:
   case kTruncI:
   case kBitCast:
     return "cast";
@@ -340,6 +349,9 @@
   case kInvariant:
     llvm::dbgs() << "invariant";
     break;
+  case kIndex:
+    llvm::dbgs() << "index_" << tensorExps[e].index;
+    break;
   case kAbsF:
   case kCeilF:
   case kFloorF:
@@ -353,6 +365,7 @@
   case kCastUF:
   case kCastS:
   case kCastU:
+  case kCastIdx:
   case kTruncI:
   case kBitCast:
     llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
@@ -420,16 +433,20 @@
   Kind kind = tensorExps[e].kind;
   switch (kind) {
   case kTensor:
-  case kInvariant: {
+  case kInvariant:
+  case kIndex: {
     // Either the index is really used in the tensor expression, or it is
-    // set to the undefined index in that dimension. An invariant expression
-    // and a truly dynamic sparse output tensor are set to a synthetic tensor
-    // with undefined indices only to ensure the iteration space is not
-    // skipped as a result of their contents.
+    // set to the undefined index in that dimension. An invariant expression,
+    // a proper index value, and a truly dynamic sparse output tensor are set
+    // to a synthetic tensor with undefined indices only to ensure the
+    // iteration space is not skipped as a result of their contents.
     unsigned s = addSet();
-    unsigned t = kind == kTensor ? tensorExps[e].tensor : syntheticTensor;
-    if (hasSparseOut && t == outTensor)
-      t = syntheticTensor;
+    unsigned t = syntheticTensor;
+    if (kind == kTensor) {
+      t = tensorExps[e].tensor;
+      if (hasSparseOut && t == outTensor)
+        t = syntheticTensor;
+    }
     latSets[s].push_back(addLat(t, i, e));
     return s;
   }
@@ -446,6 +463,7 @@
   case kCastUF:
   case kCastS:
   case kCastU:
+  case kCastIdx:
   case kTruncI:
   case kBitCast:
     // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
@@ -569,6 +587,11 @@
   Operation *def = v.getDefiningOp();
   if (def->getBlock() != &op.region().front())
     return addExp(kInvariant, v);
+  // Construct index operations.
+  if (def->getNumOperands() == 0) {
+    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
+      return addExp(kIndex, indexOp.dim());
+  }
   // Construct unary operations if subexpression can be built.
   if (def->getNumOperands() == 1) {
     auto x = buildTensorExp(op, def->getOperand(0));
@@ -598,6 +621,8 @@
       return addExp(kCastS, e, v);
     if (isa<arith::ExtUIOp>(def))
       return addExp(kCastU, e, v);
+    if (isa<arith::IndexCastOp>(def))
+      return addExp(kCastIdx, e, v);
    if (isa<arith::TruncIOp>(def))
       return addExp(kTruncI, e, v);
    if (isa<arith::BitcastOp>(def))
@@ -654,6 +679,7 @@
   switch (tensorExps[e].kind) {
   case kTensor:
   case kInvariant:
+  case kIndex:
     llvm_unreachable("unexpected non-op");
   // Unary ops.
   case kAbsF:
@@ -686,6 +712,8 @@
     return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
   case kCastU:
     return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
+  case kCastIdx:
+    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
   case kTruncI:
     return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
   case kBitCast:
diff --git a/mlir/test/Dialect/SparseTensor/sparse_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_index.mlir
@@ -0,0 +1,128 @@
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#DenseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "dense"]
+}>
+
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed", "compressed"]
+}>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>, // A
+    affine_map<(i,j) -> (i,j)>  // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) = A(i,j) * i * j"
+}
+
+// CHECK-LABEL: func @dense_index(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor
+// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_17]], %[[VAL_18]] : i64
+// CHECK: %[[VAL_20:.*]] = arith.muli %[[VAL_16]], %[[VAL_19]] : i64
+// CHECK: memref.store %[[VAL_20]], %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_21:.*]] = sparse_tensor.load %[[VAL_5]] : tensor
+func @dense_index(%arga: tensor<?x?xi64, #DenseMatrix>)
+    -> tensor<?x?xi64, #DenseMatrix> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = tensor.dim %arga, %c0 : tensor<?x?xi64, #DenseMatrix>
+  %1 = tensor.dim %arga, %c1 : tensor<?x?xi64, #DenseMatrix>
+  %init = sparse_tensor.init [%0, %1] : tensor<?x?xi64, #DenseMatrix>
+  %r = linalg.generic #trait
+      ins(%arga: tensor<?x?xi64, #DenseMatrix>)
+      outs(%init: tensor<?x?xi64, #DenseMatrix>) {
+    ^bb(%a: i64, %x: i64):
+      %i = linalg.index 0 : index
+      %j = linalg.index 1 : index
+      %ii = arith.index_cast %i : index to i64
+      %jj = arith.index_cast %j : index to i64
+      %m1 = arith.muli %ii, %a : i64
+      %m2 = arith.muli %jj, %m1 : i64
+      linalg.yield %m2 : i64
+  } -> tensor<?x?xi64, #DenseMatrix>
+  return %r : tensor<?x?xi64, #DenseMatrix>
+}
+
+// CHECK-LABEL: func @sparse_index(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_1]]] : memref
+// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_2]]] : memref
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_2]] {
+// CHECK: %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref
+// CHECK: memref.store %[[VAL_16]], %[[VAL_12]]{{\[}}%[[VAL_1]]] : memref
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref
+// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_2]] : index
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref
+// CHECK: scf.for %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_19]] step %[[VAL_2]] {
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_20]]] : memref
+// CHECK: memref.store %[[VAL_21]], %[[VAL_12]]{{\[}}%[[VAL_2]]] : memref
+// CHECK: %[[VAL_22:.*]] = arith.index_cast %[[VAL_21]] : index to i64
+// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_16]] : index to i64
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref
+// CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_23]], %[[VAL_24]] : i64
+// CHECK: %[[VAL_26:.*]] = arith.muli %[[VAL_22]], %[[VAL_25]] : i64
+// CHECK: sparse_tensor.lex_insert %[[VAL_6]], %[[VAL_12]], %[[VAL_26]] : tensor
+func @sparse_index(%arga: tensor<?x?xi64, #SparseMatrix>)
+    -> tensor<?x?xi64, #SparseMatrix> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = tensor.dim %arga, %c0 : tensor<?x?xi64, #SparseMatrix>
+  %1 = tensor.dim %arga, %c1 : tensor<?x?xi64, #SparseMatrix>
+  %init = sparse_tensor.init [%0, %1] : tensor<?x?xi64, #SparseMatrix>
+  %r = linalg.generic #trait
+      ins(%arga: tensor<?x?xi64, #SparseMatrix>)
+      outs(%init: tensor<?x?xi64, #SparseMatrix>) {
+    ^bb(%a: i64, %x: i64):
+      %i = linalg.index 0 : index
+      %j = linalg.index 1 : index
+      %ii = arith.index_cast %i : index to i64
+      %jj = arith.index_cast %j : index to i64
+      %m1 = arith.muli %ii, %a : i64
+      %m2 = arith.muli %jj, %m1 : i64
+      linalg.yield %m2 : i64
+  } -> tensor<?x?xi64, #SparseMatrix>
+  return %r : tensor<?x?xi64, #SparseMatrix>
+}
+
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseMatrix = #sparse_tensor.encoding<{
+  dimLevelType = ["compressed", "compressed"]
+}>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>, // A
+    affine_map<(i,j) -> (i,j)>  // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) = A(i,j) * i * j"
+}
+
+module {
+
+  //
+  // Kernel that uses indices in the index notation.
+  //
+  func @sparse_index(%arga: tensor<3x4xi64, #SparseMatrix>)
+      -> tensor<3x4xi64, #SparseMatrix> {
+    %d0 = arith.constant 3 : index
+    %d1 = arith.constant 4 : index
+    %init = sparse_tensor.init [%d0, %d1] : tensor<3x4xi64, #SparseMatrix>
+    %r = linalg.generic #trait
+        ins(%arga: tensor<3x4xi64, #SparseMatrix>)
+        outs(%init: tensor<3x4xi64, #SparseMatrix>) {
+      ^bb(%a: i64, %x: i64):
+        %i = linalg.index 0 : index
+        %j = linalg.index 1 : index
+        %ii = arith.index_cast %i : index to i64
+        %jj = arith.index_cast %j : index to i64
+        %m1 = arith.muli %ii, %a : i64
+        %m2 = arith.muli %jj, %m1 : i64
+        linalg.yield %m2 : i64
+    } -> tensor<3x4xi64, #SparseMatrix>
+    return %r : tensor<3x4xi64, #SparseMatrix>
+  }
+
+  //
+  // Main driver.
+  //
+  func @entry() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    %du = arith.constant -1 : i64
+
+    // Setup input "sparse" matrix.
+    %d = arith.constant dense<[
+       [ 1, 1, 1, 1 ],
+       [ 1, 1, 1, 1 ],
+       [ 1, 1, 1, 1 ]
+    ]> : tensor<3x4xi64>
+    %a = sparse_tensor.convert %d : tensor<3x4xi64> to tensor<3x4xi64, #SparseMatrix>
+
+    // Call the kernel.
+    %0 = call @sparse_index(%a)
+      : (tensor<3x4xi64, #SparseMatrix>) -> tensor<3x4xi64, #SparseMatrix>
+
+    //
+    // Verify result.
+    //
+    // CHECK: ( ( 0, 0, 0, 0 ), ( 0, 1, 2, 3 ), ( 0, 2, 4, 6 ) )
+    //
+    %x = sparse_tensor.convert %0 : tensor<3x4xi64, #SparseMatrix> to tensor<3x4xi64>
+    %m = bufferization.to_memref %x : memref<3x4xi64>
+    %v = vector.transfer_read %m[%c0, %c0], %du: memref<3x4xi64>, vector<3x4xi64>
+    vector.print %v : vector<3x4xi64>
+
+    // Release resources.
+    sparse_tensor.release %a : tensor<3x4xi64, #SparseMatrix>
+    sparse_tensor.release %0 : tensor<3x4xi64, #SparseMatrix>
+    memref.dealloc %m : memref<3x4xi64>
+
+    return
+  }
+}
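
For reference, a rough sketch of the kind of loop nest the sparsification pass is expected to emit for the @dense_index kernel above: each linalg.index op is resolved by genIndexValue to the induction variable of the enclosing loop (codegen.loops[idx]). This is an illustrative assumption, not the autogenerated FileCheck output; the names %d0, %d1, %a_values, and %x_values stand in for the SSA values the pass actually generates.

// Illustrative sketch only; assumes %c0/%c1 loop constants, the dimension
// sizes %d0/%d1, and the linearized i64 buffers %a_values/%x_values have
// already been materialized by the generated prologue.
scf.for %i = %c0 to %d0 step %c1 {
  scf.for %j = %c0 to %d1 step %c1 {
    %p  = arith.muli %d1, %i : index                   // linearized position of A(i,j)
    %q  = arith.addi %p, %j : index
    %a  = memref.load %a_values[%q] : memref<?xi64>    // A(i,j)
    %ii = arith.index_cast %i : index to i64           // index i (kIndex, dim 0)
    %jj = arith.index_cast %j : index to i64           // index j (kIndex, dim 1)
    %m1 = arith.muli %ii, %a : i64
    %m2 = arith.muli %jj, %m1 : i64
    memref.store %m2, %x_values[%q] : memref<?xi64>    // X(i,j) = A(i,j) * i * j
  }
}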