diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -28,9 +28,128 @@ namespace sparse_tensor {
 //===----------------------------------------------------------------------===//
-// ExecutionEngine/SparseTensorUtils helper functions.
+// SparseTensorLoopEmitter class: manages sparse tensors and helps generate
+// the loop structure needed to (co-)iterate over sparse tensors.
+//
+// An example usage:
+// To generate the following loops over T1 and T2
+//
+// for i in T1[0] {
+//   for j in T2[0] {
+//     for k in T1[1] {}
+//     for k in T2[1] {}
+//   }
+// }
+//
+// One can use
+//
+// SparseTensorLoopEmitter loopEmitter({T1, T2});
+// loopEmitter.initializeLoopEmit();
+// loopEmitter.enterLoopOverTensorAtDim(T1, 0);
+// loopEmitter.enterLoopOverTensorAtDim(T2, 0);
+// loopEmitter.enterLoopOverTensorAtDim(T1, 1);
+// loopEmitter.exitCurrentLoop();
+// loopEmitter.enterLoopOverTensorAtDim(T2, 1);
+// for 0 -> 3:
+//   loopEmitter.exitCurrentLoop(); // exit the remaining loops (k, j, i)
 //===----------------------------------------------------------------------===//
+
+// TODO: Sparsification should also rely on this class to generate loops.
+class SparseTensorLoopEmitter {
+public:
+  /// Constructor: takes an array of input tensors on which the generated
+  /// loops will iterate. The index of a tensor in the array is also the
+  /// tensor id (tid) used in related functions.
+  explicit SparseTensorLoopEmitter(ValueRange tensors);
+
+  ///
+  /// Core functions.
+  ///
+
+  /// Starts a loop emitting session:
+  /// 1. Generates all the buffers needed to iterate over the tensors.
+  /// 2. Generates the lo/hi bounds to iterate over tensors[0].
+  void initializeLoopEmit(OpBuilder &builder, Location loc);
+
+  // TODO: Get rid of `dim` in the argument list? Track the dimension we
+  // are currently at internally. Then it would be enterNextDimForTensor.
+
+  /// Emits a loop over tensor[dim]; assumes that the loops over
+  /// tensor[0...dim-1] have already been generated.
+  /// It also prepares to enter tensor[dim + 1].
+  Operation *enterLoopOverTensorAtDim(OpBuilder &builder, Location loc,
+                                      size_t tid, size_t dim,
+                                      ArrayRef<Value> reduc = {});
+
+  /// Emits a co-iteration loop over a set of tensors.
+  // TODO: not yet implemented
+  void enterCoiterationOverTensorsAtDims(OpBuilder &builder, Location loc,
+                                         ArrayRef<size_t> ts,
+                                         ArrayRef<size_t> ds);
+
+  /// Emits extra locals, since these locals might not be in the simplified
+  /// lattice points used to generate the loops, but are still required to
+  /// generate expressions.
+  Value emitExtraLocalsForTensorsAtDims(OpBuilder &builder, Location loc,
+                                        size_t tid, size_t dim);
+
+  void exitCurrentLoop();
+
+  /// Returns the array of coordinates for all the loops generated so far.
+  void getCoordinateArray(SmallVectorImpl<Value> &coords) {
+    for (auto &l : loopStack)
+      coords.push_back(l.idx);
+  }
+
+  ///
+  /// Getters.
+  ///
+
+  Value getTensorValueBuffer(size_t tid) { return valBuffer[tid]; }
+  Value getLastLevelTensorPointerIndex(size_t tid) {
+    return pidxs[tid].back();
+  }
+
+private:
+  struct LoopLevelInfo {
+    LoopLevelInfo(ArrayRef<size_t> ts, ArrayRef<size_t> ds, Value idx)
+        : tensors(ts), dims(ds), idx(idx) {}
+    llvm::SmallVector<size_t> tensors;
+    llvm::SmallVector<size_t> dims;
+    Value idx;
+  };
+
+  /// Returns false if tid[dim] is a dense dimension that does not need to be
+  /// prepared (to be used by sparsification for needUniv).
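+  /// For a compressed dimension, preparation loads the pointer range that
+  /// the loop will iterate over; conceptually (a sketch, not literal IR):
+  ///   pLo = (dim == 0) ? 0 : pidxs[tid][dim - 1]
+  ///   pidxs[tid][dim] = ptrBuffer[tid][dim][pLo]
+  ///   highs[tid][dim] = ptrBuffer[tid][dim][pLo + 1]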
+  bool prepareLoopOverTensorAtDim(OpBuilder &builder, Location loc, size_t tid,
+                                  size_t dim);
+
+  /// Input (TODO: and output) tensors.
+  std::vector<Value> tensors;
+  /// The dim type array for each tensor.
+  std::vector<std::vector<SparseTensorEncodingAttr::DimLevelType>> dims;
+  /// Sparse iteration information (by tensor and dim). These arrays
+  /// are updated to remain current within the current loop.
+  std::vector<std::vector<Value>> pidxs;
+  std::vector<std::vector<Value>> coord;
+  std::vector<std::vector<Value>> highs;
+  /// The dimension sizes (by tensor and dim). The sizes array is
+  /// set once with the inferred dimension sizes.
+  std::vector<std::vector<Value>> sizes;
+  std::vector<std::vector<Value>> ptrBuffer; // to_pointers
+  std::vector<std::vector<Value>> idxBuffer; // to_indices
+  std::vector<Value> valBuffer;              // to_value
+
+  std::vector<LoopLevelInfo> loopStack;
+  // TODO: not yet used; it should track the current level for each tensor
+  // to help eliminate the `dim` parameters from the above APIs.
+  std::vector<size_t> curLv;
+};
+
+//===----------------------------------------------------------------------===//
+// ExecutionEngine/SparseTensorUtils helper functions.
+//===----------------------------------------------------------------------===//
+//
 /// Converts an overhead storage bitwidth to its internal type-encoding.
 OverheadType overheadTypeEncoding(unsigned width);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -8,12 +8,237 @@
 #include "CodegenUtils.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
 
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
+/// Generates a pointer/index load from the sparse storage scheme. Narrower
+/// data types need to be zero extended before casting the value into the
+/// index type used for looping and indexing.
+static Value genIndexLoad(OpBuilder &builder, Location loc, Value ptr,
+                          Value s) {
+  // For the scalar case, we simply zero extend narrower indices into 64-bit
+  // values before casting to index without a performance penalty. Note,
+  // however, that even indices that are already 64-bit cannot, in theory,
+  // express the full index range.
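+  // For example (illustrative, assuming an i32 overhead type): the entry is
+  // loaded as i32, zero-extended to i64 via arith.extui, and then cast to
+  // `index` via arith.index_cast.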
+  Value load = builder.create<memref::LoadOp>(loc, ptr, s);
+  if (!load.getType().isa<IndexType>()) {
+    if (load.getType().getIntOrFloatBitWidth() < 64)
+      load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
+    load =
+        builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
+  }
+  return load;
+}
+
+//===----------------------------------------------------------------------===//
+// Sparse tensor loop emitter class implementation.
+//===----------------------------------------------------------------------===//
+
+SparseTensorLoopEmitter::SparseTensorLoopEmitter(ValueRange tensors)
+    : tensors(tensors.begin(), tensors.end()), dims(tensors.size()),
+      pidxs(tensors.size()), coord(tensors.size()), highs(tensors.size()),
+      sizes(tensors.size()), ptrBuffer(tensors.size()),
+      idxBuffer(tensors.size()), valBuffer(tensors.size()), loopStack(),
+      curLv(tensors.size(), 0) {
+  for (size_t i = 0, e = tensors.size(); i < e; i++) {
+    auto t = tensors[i];
+    auto rtp = t.getType().cast<RankedTensorType>();
+    auto rank = static_cast<size_t>(rtp.getRank());
+    auto enc = getSparseTensorEncoding(rtp);
+    if (enc)
+      for (auto dimTp : enc.getDimLevelType())
+        dims[i].push_back(dimTp);
+    else
+      dims[i].assign(rank, SparseTensorEncodingAttr::DimLevelType::Dense);
+
+    // Initialize using empty value.
+    pidxs[i].assign(rank, Value());
+    coord[i].assign(rank, Value());
+    highs[i].assign(rank, Value());
+    sizes[i].assign(rank, Value());
+    ptrBuffer[i].assign(rank, Value());
+    idxBuffer[i].assign(rank, Value());
+  }
+}
+
+void SparseTensorLoopEmitter::initializeLoopEmit(OpBuilder &builder,
+                                                 Location loc) {
+  // For every tensor, find the lower and upper bound on dimensions, set the
+  // same bounds on loop indices, and obtain dense or sparse buffer(s).
+  // TODO: Provide the ability to generate loops on output buffers (with
+  // undef dim level in Merger in GenericOp Sparsification).
+  for (size_t t = 0, e = tensors.size(); t < e; t++) {
+    auto tensor = tensors[t];
+    auto rtp = tensor.getType().cast<RankedTensorType>();
+    auto rank = rtp.getRank();
+    auto shape = rtp.getShape();
+    auto enc = getSparseTensorEncoding(rtp);
+    auto dynShape = {ShapedType::kDynamicSize};
+    // Scan all dimensions of current tensor.
+    for (int64_t d = 0; d < rank; d++) {
+      // This should be called only once at beginning.
+      assert(!ptrBuffer[t][d] && !idxBuffer[t][d] && !sizes[t][d] &&
+             !highs[t][d]);
+      // Handle sparse storage schemes.
+      if (isCompressedDim(dims[t][d])) {
+        auto ptrTp =
+            MemRefType::get(dynShape, getPointerOverheadType(builder, enc));
+        auto indTp =
+            MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
+        auto dim = builder.getIndexAttr(d);
+        // Generate sparse primitives to obtain pointers and indices.
+        ptrBuffer[t][d] = builder.create<ToPointersOp>(loc, ptrTp, tensor, dim);
+        idxBuffer[t][d] = builder.create<ToIndicesOp>(loc, indTp, tensor, dim);
+      } else if (isSingletonDim(dims[t][d])) {
+        llvm_unreachable("TODO: not implemented yet");
+      }
+
+      // Find upper bound in current dimension.
+      unsigned p = toOrigDim(enc, d);
+      Value up = mlir::linalg::createOrFoldDimOp(builder, loc, tensor, p);
+      sizes[t][d] = highs[t][d] = up;
+    }
+    // Perform the required bufferization. Dense inputs materialize
+    // from the input tensors. Dense outputs need special handling.
+    // Sparse inputs use sparse primitives to obtain the values.
+    Type elementType = rtp.getElementType();
+
+    if (!enc) {
+      // Non-annotated dense tensors.
+      auto denseTp = MemRefType::get(shape, elementType);
+      // This is not the output tensor.
+      valBuffer[t] =
+          builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
+    } else {
+      // Annotated sparse tensors.
+      auto dynShape = {ShapedType::kDynamicSize};
+      auto sparseTp = MemRefType::get(dynShape, elementType);
+      valBuffer[t] = builder.create<ToValuesOp>(loc, sparseTp, tensor);
+    }
+    // Prepare to enter the first dimension for all (input) tensors.
+    prepareLoopOverTensorAtDim(builder, loc, t, 0);
+  }
+}
+
+Operation *SparseTensorLoopEmitter::enterLoopOverTensorAtDim(
+    OpBuilder &builder, Location loc, size_t tid, size_t dim,
+    ArrayRef<Value> reduc) {
+  assert(dims[tid].size() > dim);
+  // We cannot re-enter the same level.
+  assert(!coord[tid][dim]);
+  Value step = constantIndex(builder, loc, 1);
+  bool isCompressed = isCompressedDim(dims[tid][dim]);
+  assert(isDenseDim(dims[tid][dim]) || isCompressedDim(dims[tid][dim]));
+
+  Value lo = isCompressed ? pidxs[tid][dim] : constantIndex(builder, loc, 0);
+  Value hi = highs[tid][dim];
+
+  // TODO: support reduction.
+  if (!reduc.empty())
+    llvm_unreachable("TODO: not implemented yet");
+
+  scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
+  builder.setInsertionPointToStart(forOp.getBody());
+  Value iv = forOp.getInductionVar();
+  Operation *loop = forOp;
+
+  assert(iv);
+  if (isCompressed) {
+    pidxs[tid][dim] = iv;
+    // Generating a load on the indices array yields the coordinate.
+    Value ptr = idxBuffer[tid][dim];
+    // TODO: generate loads for vector values.
+    coord[tid][dim] = genIndexLoad(builder, loc, ptr, iv);
+  } else {
+    // Dense tensor, the coordinate is the induction variable.
+    coord[tid][dim] = iv;
+    // Generate pidx for the dense dim (pidx = i * sz + j).
+    // TODO: handle vector loop.
+    Value p = dim == 0 ? constantIndex(builder, loc, 0) : pidxs[tid][dim - 1];
+    Value mul = builder.create<arith::MulIOp>(loc, sizes[tid][dim], p);
+    Value add = builder.create<arith::AddIOp>(loc, mul, iv);
+    pidxs[tid][dim] = add;
+  }
+
+  // Prepare for the next dim if this is not currently the innermost dimension.
+  if (dim != dims[tid].size() - 1)
+    prepareLoopOverTensorAtDim(builder, loc, tid, dim + 1);
+
+  loopStack.push_back(LoopLevelInfo({tid}, {dim}, coord[tid][dim]));
+  return loop;
+}
+
+void SparseTensorLoopEmitter::enterCoiterationOverTensorsAtDims(
+    OpBuilder &builder, Location loc, ArrayRef<size_t> ts,
+    ArrayRef<size_t> ds) {
+  llvm_unreachable("TODO: unimplemented");
+}
+
+bool SparseTensorLoopEmitter::prepareLoopOverTensorAtDim(OpBuilder &builder,
+                                                         Location loc,
+                                                         size_t tid,
+                                                         size_t dim) {
+  // TODO: generate loop iteration on output tensor based on the shape
+  // instead of pointer/indices arrays.
+  assert(dims[tid].size() > dim);
+
+  if (isDenseDim(dims[tid][dim]))
+    return false;
+
+  // Either this is the first dimension, or the previous dimension has been set.
+  assert(dim == 0 || pidxs[tid][dim - 1]);
+  if (isCompressedDim(dims[tid][dim])) {
+    Value ptr = ptrBuffer[tid][dim];
+    Value c1 = constantIndex(builder, loc, 1);
+    Value pLo = dim == 0 ? constantIndex(builder, loc, 0) : pidxs[tid][dim - 1];
+    Value pHi = builder.create<arith::AddIOp>(loc, pLo, c1);
+
+    pidxs[tid][dim] = genIndexLoad(builder, loc, ptr, pLo);
+    highs[tid][dim] = genIndexLoad(builder, loc, ptr, pHi);
+
+    return true;
+  }
+
+  if (isSingletonDim(dims[tid][dim]))
+    llvm_unreachable("TODO: not implemented yet");
+
+  llvm_unreachable("Unrecognizable dimension type!");
+}
+
+Value SparseTensorLoopEmitter::emitExtraLocalsForTensorsAtDims(
+    OpBuilder &builder, Location loc, size_t tid, size_t dim) {
+  llvm_unreachable("TODO: not implemented yet");
+}
+
+void SparseTensorLoopEmitter::exitCurrentLoop() {
+  // Clean up the values; this helps us discover potential bugs at an earlier
+  // stage (instead of silently using a wrong value).
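+  // For example, once the innermost loop over a compressed dimension is
+  // exited, its pidx/coord/high entries are reset so that re-entering the
+  // level without a fresh prepareLoopOverTensorAtDim trips the assertion in
+  // enterLoopOverTensorAtDim instead of reusing stale SSA values.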
+  LoopLevelInfo &loopInfo = loopStack.back();
+  assert(loopInfo.tensors.size() == loopInfo.dims.size());
+  for (auto info : llvm::zip(loopInfo.tensors, loopInfo.dims)) {
+    auto tid = std::get<0>(info);
+    auto dim = std::get<1>(info);
+    assert(pidxs[tid][dim] && coord[tid][dim] && highs[tid][dim]);
+    // Reset to null.
+    pidxs[tid][dim] = Value();
+    coord[tid][dim] = Value();
+    if (!isDenseDim(dims[tid][dim]))
+      // Only reset high for sparse dimensions; for dense dimensions it is
+      // fixed.
+      highs[tid][dim] = Value();
+  }
+  loopStack.pop_back();
+}
+
 //===----------------------------------------------------------------------===//
 // ExecutionEngine/SparseTensorUtils helper functions.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -279,7 +280,8 @@
       auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
       op->setOperand(0, convert);
       return success();
-    } else if (encDst) {
+    }
+    if (encDst) {
       RankedTensorType rtp =
           op.getResult().getType().template cast<RankedTensorType>();
       auto denseTp =
@@ -294,6 +296,60 @@
   }
 };
 
+/// Sparse rewriting rule for the foreach operator.
+struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
+public:
+  using OpRewritePattern<ForeachOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForeachOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    Value input = op.getTensor();
+    auto rtp = input.getType().cast<RankedTensorType>();
+    int64_t rank = rtp.getRank();
+    auto enc = getSparseTensorEncoding(rtp);
+
+    // 1. Generate loops over the sparse input.
+    SparseTensorLoopEmitter loopEmitter(ValueRange{input});
+    loopEmitter.initializeLoopEmit(rewriter, loc);
+    for (int64_t i = 0; i < rank; i++)
+      loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, i);
+
+    Value vals = loopEmitter.getTensorValueBuffer(0);
+    Value idx = loopEmitter.getLastLevelTensorPointerIndex(0);
+    Value val = rewriter.create<memref::LoadOp>(op.getLoc(), vals, idx);
+
+    SmallVector<Value> coords;
+    coords.reserve(rank);
+    loopEmitter.getCoordinateArray(coords);
+
+    for (int64_t i = 0; i < rank; i++)
+      loopEmitter.exitCurrentLoop();
+
+    // 2. Inline the block in the foreach operator.
+    Block::iterator inlinePos = rewriter.getInsertionPoint();
+    Block *srcBlock = op.getBody();
+    // Remove sparse_tensor.yield.
+    rewriter.eraseOp(srcBlock->getTerminator());
+
+    SmallVector<Value> args;
+    // Remap coordinates.
+    for (int64_t i = 0; i < rank; i++) {
+      Value actual = coords[toOrigDim(enc, i)];
+      args.push_back(actual);
+    }
+    // Remap the value.
+    args.push_back(val);
+
+    // Inline the body.
+    rewriter.mergeBlockBefore(srcBlock, &*inlinePos, args);
+    // Delete the foreach operator; its body now lives inside the loop nest.
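+    // As a rough sketch (not verified output), a #CSR input would lower to
+    // something like:
+    //   scf.for %i = %c0 to %dim0 step %c1 {
+    //     scf.for %jj = %pLo to %pHi step %c1 {
+    //       %j = memref.load %indices[%jj] : memref<?xindex>
+    //       %v = memref.load %values[%jj] : memref<?xf64>
+    //       // ... inlined foreach body applied to (%i, %j, %v) ...
+    //     }
+    //   }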
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace
 
 //===---------------------------------------------------------------------===//
@@ -301,9 +357,10 @@
 //===---------------------------------------------------------------------===//
 
 void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns,
-                                         bool /*enableRT*/) {
+                                         bool enableRT) {
   patterns.add,
-               ReshapeRewriter>(patterns.getContext());
+               ReshapeRewriter, ForeachRewriter>(
+      patterns.getContext());
   // TODO: If RT not enabled, rewrite concatenate ops, etc here.
 }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
new file
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
@@ -0,0 +1,116 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#Row = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "dense" ]
+}>
+
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ]
+}>
+
+#DCSC = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
+
+module {
+
+  /// Uses the foreach operator to print coordinates and values.
+  func.func @foreach_print_1(%arg0: tensor<2x2xf64, #Row>) {
+    sparse_tensor.foreach in %arg0 : tensor<2x2xf64, #Row> do {
+    ^bb0(%1: index, %2: index, %v: f64) :
+      vector.print %1: index
+      vector.print %2: index
+      vector.print %v: f64
+    }
+    return
+  }
+
+  func.func @foreach_print_2(%arg0: tensor<2x2xf64, #CSR>) {
+    sparse_tensor.foreach in %arg0 : tensor<2x2xf64, #CSR> do {
+    ^bb0(%1: index, %2: index, %v: f64) :
+      vector.print %1: index
+      vector.print %2: index
+      vector.print %v: f64
+    }
+    return
+  }
+
+  func.func @foreach_print_3(%arg0: tensor<2x2xf64, #DCSC>) {
+    sparse_tensor.foreach in %arg0 : tensor<2x2xf64, #DCSC> do {
+    ^bb0(%1: index, %2: index, %v: f64) :
+      vector.print %1: index
+      vector.print %2: index
+      vector.print %v: f64
+    }
+    return
+  }
+
+  //
+  // Main driver.
+  //
+  func.func @entry() {
+    //
+    // Initialize a 2x2 dense tensor.
+    //
+    %src = arith.constant dense<
+       [[ 1.0, 2.0],
+        [ 5.0, 6.0]]
+    > : tensor<2x2xf64>
+
+    //
+    // Convert the dense tensor directly to various sparse tensors.
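+    // The foreach lowering visits elements in the order of the sparse
+    // storage: row-major for #Row and #CSR, and column-major for #DCSC
+    // (due to its (j,i) dimOrdering), as checked below.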
+ // + %s1 = sparse_tensor.convert %src : tensor<2x2xf64> to tensor<2x2xf64, #Row> + %s2 = sparse_tensor.convert %src : tensor<2x2xf64> to tensor<2x2xf64, #CSR> + %s3 = sparse_tensor.convert %src : tensor<2x2xf64> to tensor<2x2xf64, #DCSC> + // CHECK: 0 + // CHECK-NEXT: 0 + // CHECK-NEXT: 1 + // CHECK-NEXT: 0 + // CHECK-NEXT: 1 + // CHECK-NEXT: 2 + // CHECK-NEXT: 1 + // CHECK-NEXT: 0 + // CHECK-NEXT: 5 + // CHECK-NEXT: 1 + // CHECK-NEXT: 1 + // CHECK-NEXT: 6 + call @foreach_print_1(%s1) : (tensor<2x2xf64, #Row>) -> () + // CHECK-NEXT: 0 + // CHECK-NEXT: 0 + // CHECK-NEXT: 1 + // CHECK-NEXT: 0 + // CHECK-NEXT: 1 + // CHECK-NEXT: 2 + // CHECK-NEXT: 1 + // CHECK-NEXT: 0 + // CHECK-NEXT: 5 + // CHECK-NEXT: 1 + // CHECK-NEXT: 1 + // CHECK-NEXT: 6 + call @foreach_print_2(%s2) : (tensor<2x2xf64, #CSR>) -> () + // CHECK-NEXT: 0 + // CHECK-NEXT: 0 + // CHECK-NEXT: 1 + // CHECK-NEXT: 1 + // CHECK-NEXT: 0 + // CHECK-NEXT: 5 + // CHECK-NEXT: 0 + // CHECK-NEXT: 1 + // CHECK-NEXT: 2 + // CHECK-NEXT: 1 + // CHECK-NEXT: 1 + // CHECK-NEXT: 6 + call @foreach_print_3(%s3) : (tensor<2x2xf64, #DCSC>) -> () + + bufferization.dealloc_tensor %s1 : tensor<2x2xf64, #Row> + bufferization.dealloc_tensor %s2 : tensor<2x2xf64, #CSR> + bufferization.dealloc_tensor %s3 : tensor<2x2xf64, #DCSC> + + return + } +}