diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -219,10 +219,13 @@
   assert(tp1.getRank() == tp2.getRank());
   auto shape1 = tp1.getShape();
   auto shape2 = tp2.getShape();
-  for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++)
+  for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) {
     if (shape1[d] != shape2[d])
       return op.emitError() << "unexpected conversion mismatch in dimension "
                             << d;
+    if (shape1[d] == MemRefType::kDynamicSize)
+      return op.emitError("unexpected dynamic size");
+  }
   return success();
 }
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -27,6 +27,10 @@
 
 namespace {
 
+//===----------------------------------------------------------------------===//
+// Helper methods.
+//===----------------------------------------------------------------------===//
+
 /// Returns internal type encoding for primary storage. Keep these
 /// values consistent with the sparse runtime support library.
 static unsigned getPrimaryTypeEncoding(Type tp) {
@@ -105,6 +109,109 @@
   return SymbolRefAttr::get(context, name);
 }
 
+/// Generates a call into the "swiss army knife" method of the sparse runtime
+/// support library for materializing sparse tensors into the computation.
+static void genNewCall(ConversionPatternRewriter &rewriter, Operation *op,
+                       SparseTensorEncodingAttr &enc, uint32_t action,
+                       Value ptr) {
+  Location loc = op->getLoc();
+  ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
+  SmallVector<Value, 8> params;
+  // Sparsity annotations in tensor constant form.
+  SmallVector<APInt, 4> attrs;
+  unsigned sz = enc.getDimLevelType().size();
+  for (unsigned i = 0; i < sz; i++)
+    attrs.push_back(
+        APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
+  params.push_back(getTensor(rewriter, 8, loc, attrs));
+  // Dimension sizes array of the enveloping *dense* tensor. Useful for either
+  // verification of external data, or for construction of internal data.
+  auto shape = resType.getShape();
+  SmallVector<APInt, 4> sizes;
+  for (unsigned i = 0; i < sz; i++) {
+    uint64_t s = shape[i] == ShapedType::kDynamicSize ? 0 : shape[i];
+    sizes.push_back(APInt(64, s));
+  }
+  params.push_back(getTensor(rewriter, 64, loc, sizes));
+  // Dimension order permutation array. This is the "identity" permutation by
+  // default, or otherwise the "reverse" permutation of a given ordering, so
+  // that indices can be mapped quickly to the right position.
+  SmallVector<APInt, 4> perm(sz);
+  AffineMap p = enc.getDimOrdering();
+  if (p) {
+    assert(p.isPermutation() && p.getNumResults() == sz);
+    for (unsigned i = 0; i < sz; i++)
+      perm[p.getDimPosition(i)] = APInt(64, i);
+  } else {
+    for (unsigned i = 0; i < sz; i++)
+      perm[i] = APInt(64, i);
+  }
+  params.push_back(getTensor(rewriter, 64, loc, perm));
+  // Secondary and primary types encoding.
+  unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
+  unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
+  unsigned primary = getPrimaryTypeEncoding(resType.getElementType());
+  assert(primary);
+  params.push_back(
+      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
+  params.push_back(
+      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
+  params.push_back(
+      rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
+  // User action and pointer.
+  params.push_back(
+      rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(action)));
+  params.push_back(ptr);
+  // Generate the call to create new tensor.
+  Type ptrType =
+      LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
+  StringRef name = "newSparseTensor";
+  rewriter.replaceOpWithNewOp<CallOp>(
+      op, ptrType, getFunc(op, name, ptrType, params), params);
+}
+
+/// Generates a call that exposes the data pointer as a void pointer.
+// TODO: probing the data pointer directly is a bit raw; we should replace
+// this with proper memref util calls once they become available.
+static bool genPtrCall(ConversionPatternRewriter &rewriter, Operation *op,
+                       Value val, Value &ptr) {
+  Location loc = op->getLoc();
+  ShapedType sType = op->getResult(0).getType().cast<ShapedType>();
+  Type eltType = sType.getElementType();
+  // Specialize name for the data type. Even though the final bufferized
+  // version only operates on pointers, different names are required to
+  // ensure type correctness for all intermediate states.
+  StringRef name;
+  if (eltType.isF64())
+    name = "getPtrF64";
+  else if (eltType.isF32())
+    name = "getPtrF32";
+  else if (eltType.isInteger(64))
+    name = "getPtrI64";
+  else if (eltType.isInteger(32))
+    name = "getPtrI32";
+  else if (eltType.isInteger(16))
+    name = "getPtrI16";
+  else if (eltType.isInteger(8))
+    name = "getPtrI8";
+  else
+    return false;
+  auto memRefTp = MemRefType::get(sType.getShape(), eltType);
+  auto unrankedTp = UnrankedMemRefType::get(eltType, 0);
+  Value c = rewriter.create<memref::BufferCastOp>(loc, memRefTp, val);
+  Value d = rewriter.create<memref::CastOp>(loc, unrankedTp, c);
+  Type ptrType =
+      LLVM::LLVMPointerType::get(IntegerType::get(op->getContext(), 8));
+  auto call =
+      rewriter.create<CallOp>(loc, ptrType, getFunc(op, name, ptrType, d), d);
+  ptr = call.getResult(0);
+  return true;
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion rules.
+//===----------------------------------------------------------------------===//
+
 /// Sparse conversion rule for returns.
 class SparseReturnConverter : public OpConversionPattern<ReturnOp> {
 public:
@@ -141,56 +248,11 @@
   LogicalResult
   matchAndRewrite(NewOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
     Type resType = op.getType();
-    Type eltType = resType.cast<ShapedType>().getElementType();
-    MLIRContext *context = op->getContext();
-    SmallVector<Value, 8> params;
-    // Sparse encoding.
     auto enc = getSparseTensorEncoding(resType);
     if (!enc)
       return failure();
-    // User pointer.
-    params.push_back(operands[0]);
-    // Sparsity annotations in tensor constant form.
-    SmallVector<APInt, 4> attrs;
-    unsigned sz = enc.getDimLevelType().size();
-    for (unsigned i = 0; i < sz; i++)
-      attrs.push_back(
-          APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
-    params.push_back(getTensor(rewriter, 8, loc, attrs));
-    // Dimension order permutation array. This is the "identity"
-    // permutation by default, or otherwise the "reverse" permutation
-    // of a given ordering, so that indices can be mapped quickly
-    // to the right position.
-    SmallVector<APInt, 4> perm(sz);
-    AffineMap p = enc.getDimOrdering();
-    if (p) {
-      assert(p.isPermutation() && p.getNumResults() == sz);
-      for (unsigned i = 0; i < sz; i++)
-        perm[p.getDimPosition(i)] = APInt(64, i);
-    } else {
-      for (unsigned i = 0; i < sz; i++)
-        perm[i] = APInt(64, i);
-    }
-    params.push_back(getTensor(rewriter, 64, loc, perm));
-    // Secondary and primary types encoding.
-    unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
-    unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
-    unsigned primary = getPrimaryTypeEncoding(eltType);
-    if (!primary)
-      return failure();
-    params.push_back(
-        rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secPtr)));
-    params.push_back(
-        rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(secInd)));
-    params.push_back(
-        rewriter.create<ConstantOp>(loc, rewriter.getI64IntegerAttr(primary)));
-    // Generate the call to create new tensor.
-    Type ptrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
-    StringRef name = "newSparseTensor";
-    rewriter.replaceOpWithNewOp<CallOp>(
-        op, ptrType, getFunc(op, name, ptrType, params), params);
+    genNewCall(rewriter, op, enc, 0, operands[0]);
     return success();
   }
 };
@@ -201,8 +263,19 @@
   LogicalResult
   matchAndRewrite(ConvertOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    // TODO: implement conversions lowering
-    return failure();
+    Type resType = op.getType();
+    auto encDst = getSparseTensorEncoding(resType);
+    auto encSrc = getSparseTensorEncoding(op.source().getType());
+    // TODO: implement sparse => sparse
+    // and sparse => dense
+    if (!encDst || encSrc)
+      return failure();
+    // This is a dense => sparse conversion.
+    Value ptr;
+    if (!genPtrCall(rewriter, op, operands[0], ptr))
+      return failure();
+    genNewCall(rewriter, op, encDst, 1, ptr);
+    return success();
   }
 };
 
@@ -325,6 +398,10 @@
 
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Public method for populating conversion rules.
+//===----------------------------------------------------------------------===//
+
 /// Populates the given patterns list with conversion rules required for
 /// the sparsification of linear algebra operations.
 void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -97,7 +97,8 @@
     RewritePatternSet patterns(ctx);
     SparseTensorTypeConverter converter;
    ConversionTarget target(*ctx);
-    target.addIllegalOp();
+    target.addIllegalOp();
     target.addDynamicallyLegalOp<FuncOp>(
         [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
     target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
@@ -105,8 +106,8 @@
     });
     target.addDynamicallyLegalOp<ReturnOp>(
         [&](ReturnOp op) { return converter.isLegal(op.getOperandTypes()); });
-    target.addLegalOp();
-    target.addLegalOp();
+    target.addLegalOp();
     populateFuncOpTypeConversionPattern(patterns, converter);
     populateCallOpTypeConversionPattern(patterns, converter);
     populateSparseTensorConversionPatterns(converter, patterns);
diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -71,14 +71,15 @@
 template <typename V>
 struct SparseTensor {
 public:
-  SparseTensor(const std::vector<uint64_t> &szs, uint64_t capacity)
+  SparseTensor(const std::vector<uint64_t> &szs, uint64_t capacity = 0)
       : sizes(szs), pos(0) {
-    elements.reserve(capacity);
+    if (capacity)
+      elements.reserve(capacity);
   }
   /// Adds element as indices and value.
   void add(const std::vector<uint64_t> &ind, V val) {
     assert(getRank() == ind.size());
-    for (int64_t r = 0, rank = getRank(); r < rank; r++)
+    for (uint64_t r = 0, rank = getRank(); r < rank; r++)
       assert(ind[r] < sizes[r]); // within bounds
     elements.emplace_back(Element<V>(ind, val));
   }
@@ -97,7 +98,7 @@
   /// Returns true if indices of e1 < indices of e2.
   static bool lexOrder(const Element<V> &e1, const Element<V> &e2) {
     assert(e1.indices.size() == e2.indices.size());
-    for (int64_t r = 0, rank = e1.indices.size(); r < rank; r++) {
+    for (uint64_t r = 0, rank = e1.indices.size(); r < rank; r++) {
       if (e1.indices[r] == e2.indices[r])
         continue;
       return e1.indices[r] < e2.indices[r];
@@ -332,7 +333,8 @@
 /// Reads a sparse tensor with the given filename into a memory-resident
 /// sparse tensor in coordinate scheme.
 template <typename V>
-static SparseTensor<V> *openTensor(char *filename, uint64_t *perm) {
+static SparseTensor<V> *openTensor(char *filename, uint64_t size,
+                                   uint64_t *sizes, uint64_t *perm) {
   // Open the file.
   FILE *file = fopen(filename, "r");
   if (!file) {
@@ -351,16 +353,19 @@
   }
   // Prepare sparse tensor object with per-rank dimension sizes
   // and the number of nonzeros as initial capacity.
-  uint64_t rank = idata[0];
+  assert(size == idata[0] && "rank mismatch");
   uint64_t nnz = idata[1];
-  std::vector<uint64_t> indices(rank);
-  for (uint64_t r = 0; r < rank; r++)
-    indices[perm[r]] = idata[2 + r];
+  std::vector<uint64_t> indices(size);
+  for (uint64_t r = 0; r < size; r++) {
+    uint64_t sz = idata[2 + r];
+    assert((sizes[r] == 0 || sizes[r] == sz) && "dimension size mismatch");
+    indices[perm[r]] = sz;
+  }
   SparseTensor<V> *tensor = new SparseTensor<V>(indices, nnz);
   // Read all nonzero elements.
   for (uint64_t k = 0; k < nnz; k++) {
     uint64_t idx = -1;
-    for (uint64_t r = 0; r < rank; r++) {
+    for (uint64_t r = 0; r < size; r++) {
       if (fscanf(file, "%" PRIu64, &idx) != 1) {
         fprintf(stderr, "Cannot find next index in %s\n", filename);
         exit(1);
@@ -382,6 +387,39 @@
   return tensor;
 }
 
+/// Helper to copy a linearized dense tensor.
+template <typename V>
+static V *copyTensorTraverse(SparseTensor<V> *tensor,
+                             std::vector<uint64_t> &indices, uint64_t r,
+                             uint64_t rank, uint64_t *sizes, uint64_t *perm,
+                             V *data) {
+  for (uint64_t i = 0, sz = sizes[r]; i < sz; i++) {
+    indices[perm[r]] = i;
+    if (r + 1 == rank) {
+      V d = *data++;
+      if (d)
+        tensor->add(indices, d);
+    } else {
+      data =
+          copyTensorTraverse(tensor, indices, r + 1, rank, sizes, perm, data);
+    }
+  }
+  return data;
+}
+
+/// Copies the nonzeros of a linearized dense tensor into a memory-resident
+/// sparse tensor in coordinate scheme.
+template <typename V>
+static SparseTensor<V> *copyTensor(uint64_t size, uint64_t *sizes,
+                                   uint64_t *perm, V *data) {
+  std::vector<uint64_t> indices(size);
+  for (uint64_t r = 0; r < size; r++)
+    indices[perm[r]] = sizes[r];
+  SparseTensor<V> *tensor = new SparseTensor<V>(indices);
+  copyTensorTraverse<V>(tensor, indices, 0, size, sizes, perm, data);
+  return tensor;
+}
+
 } // anonymous namespace
 
 extern "C" {
@@ -407,6 +445,11 @@
 //
 //===----------------------------------------------------------------------===//
 
+struct UnrankedMemRef {
+  uint64_t rank;
+  void *descriptor;
+};
+
 #define TEMPLATE(NAME, TYPE)                                                   \
   struct NAME {                                                                \
     const TYPE *base;                                                          \
@@ -418,8 +461,11 @@
 
 #define CASE(p, i, v, P, I, V)                                                 \
   if (ptrTp == (p) && indTp == (i) && valTp == (v)) {                          \
-    SparseTensor<V> *tensor = openTensor<V>(filename, perm);                   \
-    assert(asize == tensor->getRank());                                        \
+    SparseTensor<V> *tensor;                                                   \
+    if (action == 0)                                                           \
+      tensor = openTensor<V>(static_cast<char *>(ptr), asize, sizes, perm);    \
+    else                                                                       \
+      tensor = copyTensor<V>(asize, sizes, perm, static_cast<V *>(ptr));       \
     return SparseTensorStorage<P, I, V>::newSparseTensor(tensor, sparsity);    \
   }
 
@@ -437,6 +483,9 @@
     return {v->data(), v->data(), 0, {v->size()}, {1}};                        \
   }
 
+#define PTR(NAME)                                                              \
+  const void *NAME(int64_t sz, UnrankedMemRef *m) { return m->descriptor; }
+
 TEMPLATE(MemRef1DU64, uint64_t);
 TEMPLATE(MemRef1DU32, uint32_t);
 TEMPLATE(MemRef1DU16, uint16_t);
@@ -459,13 +508,18 @@
   kI8 = 6
 };
 
-void *newSparseTensor(char *filename, uint8_t *abase, uint8_t *adata,
-                      uint64_t aoff, uint64_t asize, uint64_t astride,
-                      uint64_t *pbase, uint64_t *pdata, uint64_t poff,
-                      uint64_t psize, uint64_t pstride, uint64_t ptrTp,
-                      uint64_t indTp, uint64_t valTp) {
-  assert(astride == 1 && pstride == 1);
+/// Constructs a new sparse tensor. This is the "swiss army knife"
+/// method for materializing sparse tensors into the computation.
+void *newSparseTensor(uint8_t *abase, uint8_t *adata, uint64_t aoff,
+                      uint64_t asize, uint64_t astride, uint64_t *sbase,
+                      uint64_t *sdata, uint64_t soff, uint64_t ssize,
+                      uint64_t sstride, uint64_t *pbase, uint64_t *pdata,
+                      uint64_t poff, uint64_t psize, uint64_t pstride,
+                      uint64_t ptrTp, uint64_t indTp, uint64_t valTp,
+                      uint32_t action, void *ptr) {
+  assert(astride == 1 && sstride == 1 && pstride == 1);
   uint8_t *sparsity = adata + aoff;
+  uint64_t *sizes = sdata + soff;
   uint64_t *perm = pdata + poff;
 
   // Double matrices with all combinations of overhead storage.
@@ -524,10 +578,12 @@
   exit(1);
 }
 
+/// Returns size of sparse tensor in given dimension.
 uint64_t sparseDimSize(void *tensor, uint64_t d) {
   return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
 }
 
+/// Methods that provide direct access to pointers, indices, and values.
 IMPL2(MemRef1DU64, sparsePointers, uint64_t, getPointers)
 IMPL2(MemRef1DU64, sparsePointers64, uint64_t, getPointers)
 IMPL2(MemRef1DU32, sparsePointers32, uint32_t, getPointers)
@@ -545,10 +601,19 @@
 IMPL1(MemRef1DI16, sparseValuesI16, int16_t, getValues)
 IMPL1(MemRef1DI8, sparseValuesI8, int8_t, getValues)
 
+/// Releases sparse tensor storage.
 void delSparseTensor(void *tensor) {
   delete static_cast<SparseTensorStorageBase *>(tensor);
 }
 
+/// Helper to get pointer, one per value type.
+PTR(getPtrF64)
+PTR(getPtrF32)
+PTR(getPtrI64)
+PTR(getPtrI32)
+PTR(getPtrI16)
+PTR(getPtrI8)
+
 #undef TEMPLATE
 #undef CASE
 #undef IMPL1
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-tensor-conversion | FileCheck %s
+// RUN: mlir-opt %s --sparse-tensor-conversion --canonicalize | FileCheck %s
 
 #DenseVector = #sparse_tensor.encoding<{
   dimLevelType = ["dense"]
@@ -42,11 +42,13 @@
 
 // CHECK-LABEL: func @sparse_new1d(
 // CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[D:.*]] = constant dense<1> : tensor<1xi8>
-// CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<1xi8> to tensor<?xi8>
-// CHECK: %[[P:.*]] = constant dense<0> : tensor<1xi64>
-// CHECK: %[[Q:.*]] = tensor.cast %[[P]] : tensor<1xi64> to tensor<?xi64>
-// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %[[Q]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, tensor<?xi64>, i64, i64, i64) -> !llvm.ptr<i8>
+// CHECK-DAG: %[[U:.*]] = constant dense<1> : tensor<1xi8>
+// CHECK-DAG: %[[V:.*]] = constant dense<128> : tensor<1xi64>
+// CHECK-DAG: %[[W:.*]] = constant dense<0> : tensor<1xi64>
+// CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<1xi8> to tensor<?xi8>
+// CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<1xi64> to tensor<?xi64>
+// CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<1xi64> to tensor<?xi64>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
 // CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<128xf64, #SparseVector>
@@ -55,11 +57,13 @@
 
 // CHECK-LABEL: func @sparse_new2d(
 // CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[D:.*]] = constant dense<[0, 1]> : tensor<2xi8>
-// CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<2xi8> to tensor<?xi8>
-// CHECK: %[[P:.*]] = constant dense<[0, 1]> : tensor<2xi64>
-// CHECK: %[[Q:.*]] = tensor.cast %[[P]] : tensor<2xi64> to tensor<?xi64>
-// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %[[Q]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, tensor<?xi64>, i64, i64, i64) -> !llvm.ptr<i8>
+// CHECK-DAG: %[[U:.*]] = constant dense<[0, 1]> : tensor<2xi8>
+// CHECK-DAG: %[[V:.*]] = constant dense<0> : tensor<2xi64>
+// CHECK-DAG: %[[W:.*]] = constant dense<[0, 1]> : tensor<2xi64>
+// CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<2xi8> to tensor<?xi8>
+// CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<2xi64> to tensor<?xi64>
+// CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<2xi64> to tensor<?xi64>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
 // CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf64, #SparseMatrix> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf64, #SparseMatrix>
@@ -68,17 +72,37 @@
 
 // CHECK-LABEL: func @sparse_new3d(
 // CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
-// CHECK: %[[D:.*]] = constant dense<[0, 1, 1]> : tensor<3xi8>
-// CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<3xi8> to tensor<?xi8>
-// CHECK: %[[P:.*]] = constant dense<[1, 2, 0]> : tensor<3xi64>
-// CHECK: %[[Q:.*]] = tensor.cast %[[P]] : tensor<3xi64> to tensor<?xi64>
-// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %[[Q]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, tensor<?xi64>, i64, i64, i64) -> !llvm.ptr<i8>
+// CHECK-DAG: %[[U:.*]] = constant dense<[0, 1, 1]> : tensor<3xi8>
+// CHECK-DAG: %[[V:.*]] = constant dense<0> : tensor<3xi64>
+// CHECK-DAG: %[[W:.*]] = constant dense<[1, 2, 0]> : tensor<3xi64>
+// CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<3xi8> to tensor<?xi8>
+// CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<3xi64> to tensor<?xi64>
+// CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<3xi64> to tensor<?xi64>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[A]])
 // CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?x?xf32, #SparseTensor>
   return %0 : tensor<?x?x?xf32, #SparseTensor>
 }
 
+// CHECK-LABEL: func @sparse_convert(
+// CHECK-SAME: %[[A:.*]]: tensor<2x4xf64>) -> !llvm.ptr<i8>
+// CHECK-DAG: %[[U:.*]] = constant dense<[0, 1]> : tensor<2xi8>
+// CHECK-DAG: %[[V:.*]] = constant dense<[2, 4]> : tensor<2xi64>
+// CHECK-DAG: %[[W:.*]] = constant dense<[0, 1]> : tensor<2xi64>
+// CHECK: %[[C:.*]] = memref.buffer_cast %arg0 : memref<2x4xf64>
+// CHECK: %[[M:.*]] = memref.cast %[[C]] : memref<2x4xf64> to memref<*xf64>
+// CHECK: %[[C:.*]] = call @getPtrF64(%[[M]]) : (memref<*xf64>) -> !llvm.ptr<i8>
+// CHECK-DAG: %[[X:.*]] = tensor.cast %[[U]] : tensor<2xi8> to tensor<?xi8>
+// CHECK-DAG: %[[Y:.*]] = tensor.cast %[[V]] : tensor<2xi64> to tensor<?xi64>
+// CHECK-DAG: %[[Z:.*]] = tensor.cast %[[W]] : tensor<2xi64> to tensor<?xi64>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[C]])
+// CHECK: return %[[T]] : !llvm.ptr<i8>
+func @sparse_convert(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #SparseMatrix> {
+  %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #SparseMatrix>
+  return %0 : tensor<2x4xf64, #SparseMatrix>
+}
+
 // CHECK-LABEL: func @sparse_pointers(
 // CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 // CHECK: %[[C:.*]] = constant 0 : index
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_scale.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_scale.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_scale.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-opt %s \
+// RUN:   --sparsification --sparse-tensor-conversion \
+// RUN:   --convert-vector-to-scf --convert-scf-to-std \
+// RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
+// RUN:   --std-bufferize --finalizing-bufferize \
+// RUN:   --convert-vector-to-llvm --convert-memref-to-llvm --convert-std-to-llvm | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>
+
+#trait_scale = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>  // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) = X(i,j) * 2"
+}
+
+//
+// Integration test that lowers a kernel annotated as sparse to actual sparse
+// code, initializes a matching sparse storage scheme from a dense tensor,
+// and runs the resulting code with the JIT compiler.
+//
+module {
+  //
+  // A kernel that scales a sparse matrix A by a factor of 2.0.
+  //
+  func @sparse_scale(%argx: tensor<8x8xf32, #CSR>
+                     {linalg.inplaceable = true}) -> tensor<8x8xf32, #CSR> {
+    %c = constant 2.0 : f32
+    %0 = linalg.generic #trait_scale
+      outs(%argx: tensor<8x8xf32, #CSR>) {
+        ^bb(%x: f32):
+          %1 = mulf %x, %c : f32
+          linalg.yield %1 : f32
+    } -> tensor<8x8xf32, #CSR>
+    return %0 : tensor<8x8xf32, #CSR>
+  }
+
+  //
+  // Main driver that converts a dense tensor into a sparse tensor
+  // and then calls the sparse scaling kernel with the sparse tensor
+  // as input argument.
+  //
+  func @entry() {
+    %c0 = constant 0 : index
+    %f0 = constant 0.0 : f32
+
+    // Initialize a dense tensor.
+    %0 = constant dense<[
+       [1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0],
+       [0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       [0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+       [0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0],
+       [0.0, 1.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0],
+       [0.0, 1.0, 1.0, 0.0, 0.0, 6.0, 0.0, 0.0],
+       [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 7.0, 1.0],
+       [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 8.0]
+    ]> : tensor<8x8xf32>
+
+    // Convert dense tensor to sparse tensor and call sparse kernel.
+    %1 = sparse_tensor.convert %0 : tensor<8x8xf32> to tensor<8x8xf32, #CSR>
+    %2 = call @sparse_scale(%1)
+      : (tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR>
+
+    // Print the resulting compacted values for verification.
+    //
+    // CHECK: ( 2, 2, 2, 4, 6, 8, 2, 10, 2, 2, 12, 2, 14, 2, 2, 16 )
+    //
+    %m = sparse_tensor.values %2 : tensor<8x8xf32, #CSR> to memref<?xf32>
+    %v = vector.transfer_read %m[%c0], %f0: memref<?xf32>, vector<16xf32>
+    vector.print %v : vector<16xf32>
+
+    return
+  }
+}