diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
--- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
@@ -351,8 +351,8 @@
 //===----------------------------------------------------------------------===//
 // Small runtime support library for sparse tensors.
 //===----------------------------------------------------------------------===//
-extern "C" MLIR_CRUNNERUTILS_EXPORT void *openTensorC(char *filename,
-                                                      uint64_t *idata);
+extern "C" MLIR_CRUNNERUTILS_EXPORT void *
+openTensorC(char *filename, uint64_t *idata, uint64_t *perm);
 extern "C" MLIR_CRUNNERUTILS_EXPORT void readTensorItemC(void *tensor,
                                                          uint64_t *idata,
                                                          double *ddata);
 extern "C" MLIR_CRUNNERUTILS_EXPORT void closeTensor(void *tensor);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -54,6 +54,20 @@
   }
 }
 
+/// Returns integers of given width and values as a constant tensor.
+/// We cast the static shape into a dynamic shape to ensure that the
+/// method signature remains uniform across different tensor dimensions.
+static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
+                       Location loc, ArrayRef<APInt> values) {
+  Type etp = rewriter.getIntegerType(width);
+  unsigned sz = values.size();
+  RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
+  RankedTensorType tt2 = RankedTensorType::get({ShapedType::kDynamicSize}, etp);
+  auto elts =
+      rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, values));
+  return rewriter.create<tensor::CastOp>(loc, tt2, elts);
+}
+
 /// Returns function reference (first hit also inserts into module).
 static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
                                  ValueRange operands) {
@@ -117,22 +131,29 @@
     return failure();
   // User pointer.
   params.push_back(operands[0]);
-  // Sparsity annotations in tensor constant form. Note that we cast
-  // the static shape into a dynamic shape to ensure that the method
-  // signature remains uniform accross different tensor dimensions.
+  // Sparsity annotations in tensor constant form.
   SmallVector<APInt, 4> attrs;
   unsigned sz = enc.getDimLevelType().size();
   for (unsigned i = 0; i < sz; i++)
     attrs.push_back(
         APInt(8, getDimLevelTypeEncoding(enc.getDimLevelType()[i])));
-  Type etp = rewriter.getIntegerType(8);
-  RankedTensorType tt1 = RankedTensorType::get({sz}, etp);
-  RankedTensorType tt2 =
-      RankedTensorType::get({ShapedType::kDynamicSize}, etp);
-  auto elts =
-      rewriter.create<ConstantOp>(loc, DenseElementsAttr::get(tt1, attrs));
-  params.push_back(rewriter.create<tensor::CastOp>(loc, tt2, elts));
-  // Seconary and primary types encoding.
+  params.push_back(getTensor(rewriter, 8, loc, attrs));
+  // Dimension order permutation array. This is the "identity"
+  // permutation by default, or otherwise the "reverse" permutation
+  // of a given ordering, so that indices can be mapped quickly
+  // to the right position.
+  SmallVector<APInt, 4> perm(sz);
+  AffineMap p = enc.getDimOrdering();
+  if (p) {
+    assert(p.isPermutation() && p.getNumResults() == sz);
+    for (unsigned i = 0; i < sz; i++)
+      perm[p.getDimPosition(i)] = APInt(64, i);
+  } else {
+    for (unsigned i = 0; i < sz; i++)
+      perm[i] = APInt(64, i);
+  }
+  params.push_back(getTensor(rewriter, 64, loc, perm));
+  // Secondary and primary types encoding.
   unsigned secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
   unsigned secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
   unsigned primary;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -333,6 +333,18 @@
 } // namespace
 
+// Helper method to apply dimension ordering permutation.
+static unsigned perm(SparseTensorEncodingAttr &enc, unsigned d) {
+  if (enc) {
+    auto order = enc.getDimOrdering();
+    if (order) {
+      assert(order.isPermutation());
+      return order.getDimPosition(d);
+    }
+  }
+  return d;
+}
+
 // Helper method to translate dim level type to internal representation.
 static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
   if (enc) {
@@ -353,17 +365,17 @@
   unsigned lhs = numTensors - 1;
   for (unsigned t = 0; t < numTensors; t++) {
     auto map = op.getIndexingMap(t);
-    unsigned rank = op.getShapedType(t).getRank();
+    if (!map.isProjectedPermutation())
+      return false;
     auto enc = getSparseTensorEncoding(op.getShapedType(t));
     if (enc) {
       annotated = true;
-      if (enc.getDimOrdering() && !enc.getDimOrdering().isIdentity())
-        return false; // TODO: handle permutations
       if (t == lhs)
         return false; // TODO: handle sparse outputs
     }
-    for (unsigned d = 0; d < rank; d++) {
-      unsigned idx = map.getDimPosition(d);
+    assert(map.getNumResults() == op.getShapedType(t).getRank());
+    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
+      unsigned idx = map.getDimPosition(perm(enc, d));
       merger.setDim(t, idx, toDim(enc, d));
     }
   }
@@ -405,18 +417,18 @@
   unsigned numTensors = op.getNumShapedOperands();
   for (unsigned t = 0; t < numTensors; t++) {
     auto map = op.getIndexingMap(t);
+    auto enc = getSparseTensorEncoding(op.getShapedType(t));
     assert(map.getNumDims() == n);
     // Skip dense tensor constraints when sparse only is requested.
-    if (sparseOnly && !getSparseTensorEncoding(op.getShapedType(t)))
+    if (sparseOnly && !enc)
       continue;
-    // At the moment, we take the index variables in the tensor access
-    // expression in the order in which they appear (conceptually a
-    // "row-major" layout of every tensor). So, a tensor access A_ijk
-    // forces the ordering i < j < k on the loop indices.
-    // TODO: support affine map to define alternative dimension orders.
-    for (unsigned d = 1, e = map.getNumResults(); d < e; d++) {
-      unsigned f = map.getDimPosition(d - 1);
-      unsigned t = map.getDimPosition(d);
+    // Each tensor expression and optional dimension ordering (row-major
+    // by default) puts an ordering constraint on the loop indices. For
+    // example, the tensor expression A_ijk forces the ordering i < j < k
+    // on the loop indices if no explicit dimension ordering is given.
+    for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
+      unsigned f = map.getDimPosition(perm(enc, d - 1));
+      unsigned t = map.getDimPosition(perm(enc, d));
       adjM[f][t] = true;
     }
   }
@@ -441,15 +453,10 @@
                                  Value val) {
   if (auto arg = val.dyn_cast<BlockArgument>()) {
     unsigned argN = arg.getArgNumber();
-    if (arg.getOwner()->getParentOp() == op) {
-      // Any parameter of the generic op is considered a tensor,
-      // indexed by the implicit loop bounds.
-      auto map = op.getIndexingMap(argN);
-      if (map.isProjectedPermutation())
-        return merger.addExp(Kind::kTensor, argN);
-      // Cannot handle (yet).
-      return None;
-    }
+    // Any parameter of the generic op is considered a tensor,
+    // indexed by the implicit loop bounds.
+    if (arg.getOwner()->getParentOp() == op)
+      return merger.addExp(Kind::kTensor, argN);
     // Any parameter of a higher op is invariant.
     return merger.addExp(Kind::kInvariant, val);
   }
@@ -568,10 +575,10 @@
     auto enc = getSparseTensorEncoding(tensorType);
     // Scan all dimensions of current tensor.
     args.clear();
-    for (unsigned d = 0, rank = shape.size(); d < rank; d++) {
-      unsigned i = map.getDimPosition(d);
+    for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
+      unsigned idx = map.getDimPosition(perm(enc, d));
       // Handle sparse storage schemes.
-      if (merger.isDim(t, i, Dim::kSparse)) {
+      if (merger.isDim(t, idx, Dim::kSparse)) {
         auto dynShape = {ShapedType::kDynamicSize};
         auto ptrTp = MemRefType::get(
             dynShape, genIntType(rewriter, enc.getPointerBitWidth()));
@@ -579,9 +586,9 @@
             dynShape, genIntType(rewriter, enc.getIndexBitWidth()));
         Value dim = rewriter.create<ConstantIndexOp>(loc, d);
         // Generate sparse primitives to obtain pointer and indices.
-        codegen.pointers[t][i] =
+        codegen.pointers[t][idx] =
             rewriter.create<ToPointersOp>(loc, ptrTp, tensor, dim);
-        codegen.indices[t][i] =
+        codegen.indices[t][idx] =
             rewriter.create<ToIndicesOp>(loc, indTp, tensor, dim);
       }
       // Find lower and upper bound in current dimension.
@@ -592,7 +599,7 @@
       } else {
         up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
       }
-      codegen.sizes[i] = codegen.highs[t][i] = up;
+      codegen.sizes[idx] = codegen.highs[t][idx] = up;
     }
     // Perform the required bufferization. All dense inputs materialize
     // from the input tensor. The dense output tensor needs special
@@ -705,8 +712,8 @@
   unsigned tensor = merger.exp(exp).e0;
   auto map = op.getIndexingMap(tensor);
   auto enc = getSparseTensorEncoding(op.getShapedType(tensor));
-  for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
-    unsigned idx = map.getDimPosition(i);
+  for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
+    unsigned idx = map.getDimPosition(perm(enc, d));
     args.push_back(codegen.loops[idx]); // universal dense index
     if (enc) {
       args.clear();
@@ -737,8 +744,9 @@
   // Actual store.
   SmallVector<Value, 4> args;
   auto map = op.getIndexingMap(tensor);
-  for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
-    unsigned idx = map.getDimPosition(i);
+  assert(!getSparseTensorEncoding(op.getShapedType(tensor)));
+  for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
+    unsigned idx = map.getDimPosition(d);
     args.push_back(codegen.loops[idx]); // universal dense index
   }
   Value ptr = codegen.buffers[tensor];
@@ -888,8 +896,9 @@
   bool atLevel = ldx == -1u;
   unsigned tensor = merger.exp(exp).e0;
   auto map = op.getIndexingMap(tensor);
-  for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
-    unsigned idx = map.getDimPosition(i);
+  auto enc = getSparseTensorEncoding(op.getShapedType(tensor));
+  for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
+    unsigned idx = map.getDimPosition(perm(enc, d));
     if (!codegen.loops[idx])
       return; // still in play
     else if (idx == ldx)
@@ -1001,9 +1010,8 @@
   for (unsigned t = 0; t < numTensors; t++) {
     if (!getSparseTensorEncoding(op.getShapedType(t))) {
       auto map = op.getIndexingMap(t);
-      unsigned r = map.getNumResults();
-      for (unsigned i = 0; i < r; i++) {
-        if (map.getDimPosition(i) == idx && i != r - 1)
+      for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
+        if (map.getDimPosition(d) == idx && d != rank - 1)
           return false;
       }
     }
diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -243,9 +243,11 @@
 /// Templated reader.
 template <typename P, typename I, typename V>
-void *newSparseTensor(char *filename, uint8_t *sparsity, uint64_t size) {
+void *newSparseTensor(char *filename, uint8_t *sparsity, uint64_t *perm,
+                      uint64_t size) {
   uint64_t idata[64];
-  SparseTensor *t = static_cast<SparseTensor *>(openTensorC(filename, idata));
+  SparseTensor *t =
+      static_cast<SparseTensor *>(openTensorC(filename, idata, perm));
   assert(size == t->getRank()); // sparsity array must match rank
   SparseTensorStorageBase *tensor =
       new SparseTensorStorage<P, I, V>(t, sparsity);
@@ -371,7 +373,7 @@
 /// understood by other methods in the sparse runtime support library. An
 /// array parameter is used to pass the rank, the number of nonzero elements,
 /// and the dimension sizes (one per rank).
-void *openTensorC(char *filename, uint64_t *idata) {
+void *openTensorC(char *filename, uint64_t *idata, uint64_t *perm) {
   // Open the file.
   FILE *file = fopen(filename, "r");
   if (!file) {
@@ -393,16 +395,24 @@
   uint64_t nnz = idata[1];
   std::vector<uint64_t> indices(rank);
   for (uint64_t r = 0; r < rank; r++)
-    indices[r] = idata[2 + r];
+    if (perm)
+      indices[perm[r]] = idata[2 + r];
+    else
+      indices[r] = idata[2 + r];
   SparseTensor *tensor = new SparseTensor(indices, nnz);
   // Read all nonzero elements.
   for (uint64_t k = 0; k < nnz; k++) {
+    uint64_t idx = -1;
     for (uint64_t r = 0; r < rank; r++) {
-      if (fscanf(file, "%" PRIu64, &indices[r]) != 1) {
+      if (fscanf(file, "%" PRIu64, &idx) != 1) {
         fprintf(stderr, "Cannot find next index in %s\n", filename);
         exit(1);
       }
-      indices[r]--; // 0-based index
+      // Add 0-based index.
+      if (perm)
+        indices[perm[r]] = idx - 1;
+      else
+        indices[r] = idx - 1;
     }
     double value;
     if (fscanf(file, "%lg\n", &value) != 1) {
@@ -421,7 +431,7 @@
 void *openTensor(char *filename, uint64_t *ibase, uint64_t *idata,
                  uint64_t ioff, uint64_t isize, uint64_t istride) {
   assert(istride == 1);
-  return openTensorC(filename, idata + ioff);
+  return openTensorC(filename, idata + ioff, nullptr);
 }
 
 /// Yields the next element from the given opaque sparse tensor object.
@@ -477,7 +487,7 @@
 #define CASE(p, i, v, P, I, V)                                                 \
   if (ptrTp == (p) && indTp == (i) && valTp == (v))                            \
-  return newSparseTensor<P, I, V>(filename, sparsity, asize)
+  return newSparseTensor<P, I, V>(filename, sparsity, perm, asize)
 
 #define IMPL1(RET, NAME, TYPE, LIB)                                            \
   RET NAME(void *tensor) {                                                     \
@@ -515,9 +525,12 @@
 void *newSparseTensor(char *filename, uint8_t *abase, uint8_t *adata,
                       uint64_t aoff, uint64_t asize, uint64_t astride,
-                      uint64_t ptrTp, uint64_t indTp, uint64_t valTp) {
-  assert(astride == 1);
+                      uint64_t *pbase, uint64_t *pdata, uint64_t poff,
+                      uint64_t psize, uint64_t pstride, uint64_t ptrTp,
+                      uint64_t indTp, uint64_t valTp) {
+  assert(astride == 1 && pstride == 1);
   uint8_t *sparsity = adata + aoff;
+  uint64_t *perm = pdata + poff;
 
   // The most common cases: 64-bit or 32-bit overhead, double/float values.
   CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -20,6 +20,11 @@
   dimLevelType = ["dense", "compressed"]
 }>
 
+#SparseTensor = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "compressed", "compressed"],
+  dimOrdering = affine_map<(i,j,k) -> (k,i,j)>
+}>
+
 // CHECK-LABEL: func @sparse_dim(
 // CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
 // CHECK: %[[C:.*]] = constant 0 : index
@@ -35,7 +40,9 @@
 // CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
 // CHECK: %[[D:.*]] = constant dense<1> : tensor<1xi8>
 // CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<1xi8> to tensor<?xi8>
-// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, i64, i64, i64) -> !llvm.ptr<i8>
+// CHECK: %[[P:.*]] = constant dense<0> : tensor<1xi64>
+// CHECK: %[[Q:.*]] = tensor.cast %[[P]] : tensor<1xi64> to tensor<?xi64>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %[[Q]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, tensor<?xi64>, i64, i64, i64) -> !llvm.ptr<i8>
 // CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<128xf64, #SparseVector>
@@ -46,13 +53,28 @@
 // CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
 // CHECK: %[[D:.*]] = constant dense<[0, 1]> : tensor<2xi8>
 // CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<2xi8> to tensor<?xi8>
-// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, i64, i64, i64) -> !llvm.ptr<i8>
+// CHECK: %[[P:.*]] = constant dense<[0, 1]> : tensor<2xi64>
+// CHECK: %[[Q:.*]] = tensor.cast %[[P]] : tensor<2xi64> to tensor<?xi64>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %[[Q]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, tensor<?xi64>, i64, i64, i64) -> !llvm.ptr<i8>
 // CHECK: return %[[T]] : !llvm.ptr<i8>
 func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor
   return %0 : tensor
 }
 
+// CHECK-LABEL: func @sparse_new3d(
+// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
+// CHECK: %[[D:.*]] = constant dense<[0, 1, 1]> : tensor<3xi8>
+// CHECK: %[[C:.*]] = tensor.cast %[[D]] : tensor<3xi8> to tensor<?xi8>
+// CHECK: %[[P:.*]] = constant dense<[1, 2, 0]> : tensor<3xi64>
+// CHECK: %[[Q:.*]] = tensor.cast %[[P]] : tensor<3xi64> to tensor<?xi64>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[A]], %[[C]], %[[Q]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr<i8>, tensor<?xi8>, tensor<?xi64>, i64, i64, i64) -> !llvm.ptr<i8>
+// CHECK: return %[[T]] : !llvm.ptr<i8>
+func @sparse_new3d(%arg0: !llvm.ptr<i8>)
-> tensor { + %0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor + return %0 : tensor +} + // CHECK-LABEL: func @sparse_pointers( // CHECK-SAME: %[[A:.*]]: !llvm.ptr) // CHECK: %[[C:.*]] = constant 0 : index diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_lower.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_lower.mlir @@ -21,60 +21,60 @@ } // CHECK-HIR-LABEL: func @matvec( -// CHECK-HIR-SAME: %[[VAL_0:.*]]: tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>>, +// CHECK-HIR-SAME: %[[VAL_0:.*]]: tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>, // CHECK-HIR-SAME: %[[VAL_1:.*]]: tensor<64xf64>, -// CHECK-HIR-SAME: %[[VAL_2:.*]]: tensor<64xf64>) -> tensor<64xf64> { -// CHECK-HIR: %[[VAL_3:.*]] = constant 64 : index +// CHECK-HIR-SAME: %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> { +// CHECK-HIR: %[[VAL_3:.*]] = constant 32 : index // CHECK-HIR: %[[VAL_4:.*]] = constant 0 : index // CHECK-HIR: %[[VAL_5:.*]] = constant 1 : index -// CHECK-HIR: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref -// CHECK-HIR: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref -// CHECK-HIR: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref -// CHECK-HIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64> -// CHECK-HIR: %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<64xf64> -// CHECK-HIR: %[[VAL_12:.*]] = memref.alloc() : memref<64xf64> -// CHECK-HIR: linalg.copy(%[[VAL_11]], %[[VAL_12]]) : memref<64xf64>, memref<64xf64> -// CHECK-HIR: scf.for %[[VAL_13:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { -// CHECK-HIR: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref -// CHECK-HIR: %[[VAL_15:.*]] = addi %[[VAL_13]], %[[VAL_5]] : index -// CHECK-HIR: %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref -// CHECK-HIR: %[[VAL_17:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_13]]] : memref<64xf64> -// CHECK-HIR: %[[VAL_18:.*]] = scf.for %[[VAL_19:.*]] = %[[VAL_14]] to %[[VAL_16]] step %[[VAL_5]] iter_args(%[[VAL_20:.*]] = %[[VAL_17]]) -> (f64) { -// CHECK-HIR: %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref -// CHECK-HIR: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref -// CHECK-HIR: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref<64xf64> -// CHECK-HIR: %[[VAL_24:.*]] = mulf %[[VAL_22]], %[[VAL_23]] : f64 -// CHECK-HIR: %[[VAL_25:.*]] = addf %[[VAL_20]], %[[VAL_24]] : f64 -// CHECK-HIR: scf.yield %[[VAL_25]] : f64 +// CHECK-HIR: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref +// CHECK-HIR: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref +// CHECK-HIR: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref +// CHECK-HIR: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64> 
+// CHECK-HIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64> +// CHECK-HIR: %[[VAL_11:.*]] = memref.alloc() : memref<32xf64> +// CHECK-HIR: linalg.copy(%[[VAL_10]], %[[VAL_11]]) : memref<32xf64>, memref<32xf64> +// CHECK-HIR: scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { +// CHECK-HIR: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref +// CHECK-HIR: %[[VAL_14:.*]] = addi %[[VAL_12]], %[[VAL_5]] : index +// CHECK-HIR: %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref +// CHECK-HIR: %[[VAL_16:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<32xf64> +// CHECK-HIR: %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_5]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (f64) { +// CHECK-HIR: %[[VAL_20:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_18]]] : memref +// CHECK-HIR: %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref +// CHECK-HIR: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]]] : memref<64xf64> +// CHECK-HIR: %[[VAL_23:.*]] = mulf %[[VAL_21]], %[[VAL_22]] : f64 +// CHECK-HIR: %[[VAL_24:.*]] = addf %[[VAL_19]], %[[VAL_23]] : f64 +// CHECK-HIR: scf.yield %[[VAL_24]] : f64 // CHECK-HIR: } -// CHECK-HIR: store %[[VAL_26:.*]], %[[VAL_12]]{{\[}}%[[VAL_13]]] : memref<64xf64> +// CHECK-HIR: memref.store %[[VAL_25:.*]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<32xf64> // CHECK-HIR: } -// CHECK-HIR: %[[VAL_27:.*]] = memref.tensor_load %[[VAL_12]] : memref<64xf64> -// CHECK-HIR: return %[[VAL_27]] : tensor<64xf64> +// CHECK-HIR: %[[VAL_26:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xf64> +// CHECK-HIR: return %[[VAL_26]] : tensor<32xf64> // CHECK-HIR: } // CHECK-MIR-LABEL: func @matvec( // CHECK-MIR-SAME: %[[VAL_0:.*]]: !llvm.ptr, // CHECK-MIR-SAME: %[[VAL_1:.*]]: tensor<64xf64>, -// CHECK-MIR-SAME: %[[VAL_2:.*]]: tensor<64xf64>) -> tensor<64xf64> { -// CHECK-MIR: %[[VAL_3:.*]] = constant 64 : index +// CHECK-MIR-SAME: %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> { +// CHECK-MIR: %[[VAL_3:.*]] = constant 32 : index // CHECK-MIR: %[[VAL_4:.*]] = constant 0 : index // CHECK-MIR: %[[VAL_5:.*]] = constant 1 : index // CHECK-MIR: %[[VAL_6:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref // CHECK-MIR: %[[VAL_7:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref // CHECK-MIR: %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref // CHECK-MIR: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64> -// CHECK-MIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<64xf64> -// CHECK-MIR: %[[VAL_11:.*]] = memref.alloc() : memref<64xf64> +// CHECK-MIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64> +// CHECK-MIR: %[[VAL_11:.*]] = memref.alloc() : memref<32xf64> // CHECK-MIR: scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { -// CHECK-MIR: %[[VAL_13:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_12]]] : memref<64xf64> -// CHECK-MIR: store %[[VAL_13]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<64xf64> +// CHECK-MIR: %[[VAL_13:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_12]]] : memref<32xf64> +// CHECK-MIR: memref.store %[[VAL_13]], %[[VAL_11]]{{\[}}%[[VAL_12]]] : memref<32xf64> // CHECK-MIR: } // CHECK-MIR: scf.for %[[VAL_14:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { // CHECK-MIR: %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref // CHECK-MIR: %[[VAL_16:.*]] = addi %[[VAL_14]], %[[VAL_5]] : index // CHECK-MIR: %[[VAL_17:.*]] = 
memref.load %[[VAL_6]]{{\[}}%[[VAL_16]]] : memref -// CHECK-MIR: %[[VAL_18:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_14]]] : memref<64xf64> +// CHECK-MIR: %[[VAL_18:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_14]]] : memref<32xf64> // CHECK-MIR: %[[VAL_19:.*]] = scf.for %[[VAL_20:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_5]] iter_args(%[[VAL_21:.*]] = %[[VAL_18]]) -> (f64) { // CHECK-MIR: %[[VAL_22:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref // CHECK-MIR: %[[VAL_23:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref @@ -83,32 +83,32 @@ // CHECK-MIR: %[[VAL_26:.*]] = addf %[[VAL_21]], %[[VAL_25]] : f64 // CHECK-MIR: scf.yield %[[VAL_26]] : f64 // CHECK-MIR: } -// CHECK-MIR: store %[[VAL_27:.*]], %[[VAL_11]]{{\[}}%[[VAL_14]]] : memref<64xf64> +// CHECK-MIR: memref.store %[[VAL_27:.*]], %[[VAL_11]]{{\[}}%[[VAL_14]]] : memref<32xf64> // CHECK-MIR: } -// CHECK-MIR: %[[VAL_28:.*]] = memref.tensor_load %[[VAL_11]] : memref<64xf64> -// CHECK-MIR: return %[[VAL_28]] : tensor<64xf64> +// CHECK-MIR: %[[VAL_28:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xf64> +// CHECK-MIR: return %[[VAL_28]] : tensor<32xf64> // CHECK-MIR: } // CHECK-LIR-LABEL: func @matvec( // CHECK-LIR-SAME: %[[VAL_0:.*]]: !llvm.ptr, // CHECK-LIR-SAME: %[[VAL_1:.*]]: memref<64xf64>, -// CHECK-LIR-SAME: %[[VAL_2:.*]]: memref<64xf64>) -> memref<64xf64> { -// CHECK-LIR: %[[VAL_3:.*]] = constant 64 : index +// CHECK-LIR-SAME: %[[VAL_2:.*]]: memref<32xf64>) -> memref<32xf64> { +// CHECK-LIR: %[[VAL_3:.*]] = constant 32 : index // CHECK-LIR: %[[VAL_4:.*]] = constant 0 : index // CHECK-LIR: %[[VAL_5:.*]] = constant 1 : index // CHECK-LIR: %[[VAL_6:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref // CHECK-LIR: %[[VAL_7:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref // CHECK-LIR: %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref -// CHECK-LIR: %[[VAL_9:.*]] = memref.alloc() : memref<64xf64> +// CHECK-LIR: %[[VAL_9:.*]] = memref.alloc() : memref<32xf64> // CHECK-LIR: scf.for %[[VAL_10:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { -// CHECK-LIR: %[[VAL_11:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<64xf64> -// CHECK-LIR: store %[[VAL_11]], %[[VAL_9]]{{\[}}%[[VAL_10]]] : memref<64xf64> +// CHECK-LIR: %[[VAL_11:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_10]]] : memref<32xf64> +// CHECK-LIR: memref.store %[[VAL_11]], %[[VAL_9]]{{\[}}%[[VAL_10]]] : memref<32xf64> // CHECK-LIR: } // CHECK-LIR: scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { // CHECK-LIR: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref // CHECK-LIR: %[[VAL_14:.*]] = addi %[[VAL_12]], %[[VAL_5]] : index // CHECK-LIR: %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_14]]] : memref -// CHECK-LIR: %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<64xf64> +// CHECK-LIR: %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<32xf64> // CHECK-LIR: %[[VAL_17:.*]] = scf.for %[[VAL_18:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_5]] iter_args(%[[VAL_19:.*]] = %[[VAL_16]]) -> (f64) { // CHECK-LIR: %[[VAL_20:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_18]]] : memref // CHECK-LIR: %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref @@ -117,21 +117,21 @@ // CHECK-LIR: %[[VAL_24:.*]] = addf %[[VAL_19]], %[[VAL_23]] : f64 // CHECK-LIR: scf.yield %[[VAL_24]] : f64 // CHECK-LIR: } -// CHECK-LIR: store %[[VAL_25:.*]], %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<64xf64> +// CHECK-LIR: 
memref.store %[[VAL_25:.*]], %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<32xf64> // CHECK-LIR: } -// CHECK-LIR: return %[[VAL_9]] : memref<64xf64> +// CHECK-LIR: return %[[VAL_9]] : memref<32xf64> // CHECK-LIR: } -func @matvec(%arga: tensor<64x64xf64, #CSR>, +func @matvec(%arga: tensor<32x64xf64, #CSR>, %argb: tensor<64xf64>, - %argx: tensor<64xf64>) -> tensor<64xf64> { + %argx: tensor<32xf64>) -> tensor<32xf64> { %0 = linalg.generic #trait_matvec - ins(%arga, %argb : tensor<64x64xf64, #CSR>, tensor<64xf64>) - outs(%argx: tensor<64xf64>) { + ins(%arga, %argb : tensor<32x64xf64, #CSR>, tensor<64xf64>) + outs(%argx: tensor<32xf64>) { ^bb(%A: f64, %b: f64, %x: f64): %0 = mulf %A, %b : f64 %1 = addf %x, %0 : f64 linalg.yield %1 : f64 - } -> tensor<64xf64> - return %0 : tensor<64xf64> + } -> tensor<32xf64> + return %0 : tensor<32xf64> } diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/sparse_lower_col.mlir @@ -0,0 +1,139 @@ +// RUN: mlir-opt %s -sparsification | FileCheck %s --check-prefix=CHECK-HIR +// +// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion \ +// RUN: --convert-linalg-to-loops | FileCheck %s --check-prefix=CHECK-MIR +// +// RUN: mlir-opt %s -sparsification --sparse-tensor-conversion \ +// RUN: --convert-linalg-to-loops --func-bufferize --tensor-constant-bufferize \ +// RUN: --tensor-bufferize --finalizing-bufferize | \ +// RUN: FileCheck %s --check-prefix=CHECK-LIR + +#CSC = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + dimOrdering = affine_map<(i,j) -> (j,i)> +}> + +#trait_matvec = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (j)>, // b + affine_map<(i,j) -> (i)> // x (out) + ], + iterator_types = ["parallel","reduction"], + doc = "x(i) += A(i,j) * b(j)" +} + +// CHECK-HIR-LABEL: func @matvec( +// CHECK-HIR-SAME: %[[VAL_0:.*]]: tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>>, +// CHECK-HIR-SAME: %[[VAL_1:.*]]: tensor<64xf64>, +// CHECK-HIR-SAME: %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> { +// CHECK-HIR: %[[VAL_3:.*]] = constant 64 : index +// CHECK-HIR: %[[VAL_4:.*]] = constant 0 : index +// CHECK-HIR: %[[VAL_5:.*]] = constant 1 : index +// CHECK-HIR: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>> to memref +// CHECK-HIR: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>> to memref +// CHECK-HIR: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)>, pointerBitWidth = 0, indexBitWidth = 0 }>> to memref +// CHECK-HIR: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64> +// CHECK-HIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64> +// CHECK-HIR: %[[VAL_11:.*]] = memref.alloc() : memref<32xf64> +// CHECK-HIR: linalg.copy(%[[VAL_10]], %[[VAL_11]]) : memref<32xf64>, memref<32xf64> +// CHECK-HIR: scf.for 
%[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { +// CHECK-HIR: %[[VAL_13:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<64xf64> +// CHECK-HIR: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref +// CHECK-HIR: %[[VAL_15:.*]] = addi %[[VAL_12]], %[[VAL_5]] : index +// CHECK-HIR: %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref +// CHECK-HIR: scf.for %[[VAL_17:.*]] = %[[VAL_14]] to %[[VAL_16]] step %[[VAL_5]] { +// CHECK-HIR: %[[VAL_18:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref +// CHECK-HIR: %[[VAL_19:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_18]]] : memref<32xf64> +// CHECK-HIR: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref +// CHECK-HIR: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_13]] : f64 +// CHECK-HIR: %[[VAL_22:.*]] = addf %[[VAL_19]], %[[VAL_21]] : f64 +// CHECK-HIR: memref.store %[[VAL_22]], %[[VAL_11]]{{\[}}%[[VAL_18]]] : memref<32xf64> +// CHECK-HIR: } +// CHECK-HIR: } +// CHECK-HIR: %[[VAL_23:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xf64> +// CHECK-HIR: return %[[VAL_23]] : tensor<32xf64> +// CHECK-HIR: } + +// CHECK-MIR-LABEL: func @matvec( +// CHECK-MIR-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-MIR-SAME: %[[VAL_1:.*]]: tensor<64xf64>, +// CHECK-MIR-SAME: %[[VAL_2:.*]]: tensor<32xf64>) -> tensor<32xf64> { +// CHECK-MIR: %[[VAL_3:.*]] = constant 64 : index +// CHECK-MIR: %[[VAL_4:.*]] = constant 32 : index +// CHECK-MIR: %[[VAL_5:.*]] = constant 0 : index +// CHECK-MIR: %[[VAL_6:.*]] = constant 1 : index +// CHECK-MIR: %[[VAL_7:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref +// CHECK-MIR: %[[VAL_8:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref +// CHECK-MIR: %[[VAL_9:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref +// CHECK-MIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64> +// CHECK-MIR: %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64> +// CHECK-MIR: %[[VAL_12:.*]] = memref.alloc() : memref<32xf64> +// CHECK-MIR: scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { +// CHECK-MIR: %[[VAL_14:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_13]]] : memref<32xf64> +// CHECK-MIR: memref.store %[[VAL_14]], %[[VAL_12]]{{\[}}%[[VAL_13]]] : memref<32xf64> +// CHECK-MIR: } +// CHECK-MIR: scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] { +// CHECK-MIR: %[[VAL_16:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_15]]] : memref<64xf64> +// CHECK-MIR: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref +// CHECK-MIR: %[[VAL_18:.*]] = addi %[[VAL_15]], %[[VAL_6]] : index +// CHECK-MIR: %[[VAL_19:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_18]]] : memref +// CHECK-MIR: scf.for %[[VAL_20:.*]] = %[[VAL_17]] to %[[VAL_19]] step %[[VAL_6]] { +// CHECK-MIR: %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref +// CHECK-MIR: %[[VAL_22:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_21]]] : memref<32xf64> +// CHECK-MIR: %[[VAL_23:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]]] : memref +// CHECK-MIR: %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_16]] : f64 +// CHECK-MIR: %[[VAL_25:.*]] = addf %[[VAL_22]], %[[VAL_24]] : f64 +// CHECK-MIR: memref.store %[[VAL_25]], %[[VAL_12]]{{\[}}%[[VAL_21]]] : memref<32xf64> +// CHECK-MIR: } +// CHECK-MIR: } +// CHECK-MIR: %[[VAL_26:.*]] = memref.tensor_load %[[VAL_12]] : memref<32xf64> +// CHECK-MIR: return %[[VAL_26]] : tensor<32xf64> +// CHECK-MIR: } + +// CHECK-LIR-LABEL: func @matvec( +// CHECK-LIR-SAME: 
%[[VAL_0:.*]]: !llvm.ptr, +// CHECK-LIR-SAME: %[[VAL_1:.*]]: memref<64xf64>, +// CHECK-LIR-SAME: %[[VAL_2:.*]]: memref<32xf64>) -> memref<32xf64> { +// CHECK-LIR: %[[VAL_3:.*]] = constant 64 : index +// CHECK-LIR: %[[VAL_4:.*]] = constant 32 : index +// CHECK-LIR: %[[VAL_5:.*]] = constant 0 : index +// CHECK-LIR: %[[VAL_6:.*]] = constant 1 : index +// CHECK-LIR: %[[VAL_7:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref +// CHECK-LIR: %[[VAL_8:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_6]]) : (!llvm.ptr, index) -> memref +// CHECK-LIR: %[[VAL_9:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref +// CHECK-LIR: %[[VAL_10:.*]] = memref.alloc() : memref<32xf64> +// CHECK-LIR: scf.for %[[VAL_11:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] { +// CHECK-LIR: %[[VAL_12:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_11]]] : memref<32xf64> +// CHECK-LIR: memref.store %[[VAL_12]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64> +// CHECK-LIR: } +// CHECK-LIR: scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] { +// CHECK-LIR: %[[VAL_14:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_13]]] : memref<64xf64> +// CHECK-LIR: %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref +// CHECK-LIR: %[[VAL_16:.*]] = addi %[[VAL_13]], %[[VAL_6]] : index +// CHECK-LIR: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref +// CHECK-LIR: scf.for %[[VAL_18:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_6]] { +// CHECK-LIR: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref +// CHECK-LIR: %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf64> +// CHECK-LIR: %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref +// CHECK-LIR: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_14]] : f64 +// CHECK-LIR: %[[VAL_23:.*]] = addf %[[VAL_20]], %[[VAL_22]] : f64 +// CHECK-LIR: memref.store %[[VAL_23]], %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf64> +// CHECK-LIR: } +// CHECK-LIR: } +// CHECK-LIR: return %[[VAL_10]] : memref<32xf64> +// CHECK-LIR: } + +func @matvec(%arga: tensor<32x64xf64, #CSC>, + %argb: tensor<64xf64>, + %argx: tensor<32xf64>) -> tensor<32xf64> { + %0 = linalg.generic #trait_matvec + ins(%arga, %argb : tensor<32x64xf64, #CSC>, tensor<64xf64>) + outs(%argx: tensor<32xf64>) { + ^bb(%A: f64, %b: f64, %x: f64): + %0 = mulf %A, %b : f64 + %1 = addf %x, %0 : f64 + linalg.yield %1 : f64 + } -> tensor<32xf64> + return %0 : tensor<32xf64> +} diff --git a/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir b/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_lower_inplace.mlir @@ -21,22 +21,22 @@ } // CHECK-HIR-LABEL: func @matvec( -// CHECK-HIR-SAME: %[[VAL_0:.*]]: tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>>, +// CHECK-HIR-SAME: %[[VAL_0:.*]]: tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>>, // CHECK-HIR-SAME: %[[VAL_1:.*]]: tensor<64xf64>, -// CHECK-HIR-SAME: %[[VAL_2:.*]]: tensor<64xf64> {linalg.inplaceable = true}) -> tensor<64xf64> { -// CHECK-HIR: %[[VAL_3:.*]] = constant 64 : index +// CHECK-HIR-SAME: %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> { +// CHECK-HIR: %[[VAL_3:.*]] = constant 32 : index // CHECK-HIR: %[[VAL_4:.*]] = constant 0 : index // CHECK-HIR: %[[VAL_5:.*]] = constant 1 : index -// CHECK-HIR: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], 
%[[VAL_5]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref -// CHECK-HIR: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref -// CHECK-HIR: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<64x64xf64, #sparse_tensor.encoding<{{.*}}>> to memref +// CHECK-HIR: %[[VAL_6:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref +// CHECK-HIR: %[[VAL_7:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref +// CHECK-HIR: %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x64xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref // CHECK-HIR: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64> -// CHECK-HIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<64xf64> +// CHECK-HIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64> // CHECK-HIR: scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { // CHECK-HIR: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref // CHECK-HIR: %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_5]] : index // CHECK-HIR: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref -// CHECK-HIR: %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<64xf64> +// CHECK-HIR: %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64> // CHECK-HIR: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f64) { // CHECK-HIR: %[[VAL_19:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref // CHECK-HIR: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref @@ -45,29 +45,29 @@ // CHECK-HIR: %[[VAL_23:.*]] = addf %[[VAL_18]], %[[VAL_22]] : f64 // CHECK-HIR: scf.yield %[[VAL_23]] : f64 // CHECK-HIR: } -// CHECK-HIR: memref.store %[[VAL_24:.*]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<64xf64> +// CHECK-HIR: memref.store %[[VAL_24:.*]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64> // CHECK-HIR: } -// CHECK-HIR: %[[VAL_25:.*]] = memref.tensor_load %[[VAL_10]] : memref<64xf64> -// CHECK-HIR: return %[[VAL_25]] : tensor<64xf64> +// CHECK-HIR: %[[VAL_25:.*]] = memref.tensor_load %[[VAL_10]] : memref<32xf64> +// CHECK-HIR: return %[[VAL_25]] : tensor<32xf64> // CHECK-HIR: } // CHECK-MIR-LABEL: func @matvec( // CHECK-MIR-SAME: %[[VAL_0:.*]]: !llvm.ptr, // CHECK-MIR-SAME: %[[VAL_1:.*]]: tensor<64xf64>, -// CHECK-MIR-SAME: %[[VAL_2:.*]]: tensor<64xf64> {linalg.inplaceable = true}) -> tensor<64xf64> { -// CHECK-MIR: %[[VAL_3:.*]] = constant 64 : index +// CHECK-MIR-SAME: %[[VAL_2:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> { +// CHECK-MIR: %[[VAL_3:.*]] = constant 32 : index // CHECK-MIR: %[[VAL_4:.*]] = constant 0 : index // CHECK-MIR: %[[VAL_5:.*]] = constant 1 : index // CHECK-MIR: %[[VAL_6:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref // CHECK-MIR: %[[VAL_7:.*]] = call @sparseIndices(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref // CHECK-MIR: %[[VAL_8:.*]] = call @sparseValuesF64(%[[VAL_0]]) : (!llvm.ptr) -> memref // CHECK-MIR: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_1]] : memref<64xf64> -// 
CHECK-MIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<64xf64> +// CHECK-MIR: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64> // CHECK-MIR: scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] { // CHECK-MIR: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref // CHECK-MIR: %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_5]] : index // CHECK-MIR: %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref -// CHECK-MIR: %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<64xf64> +// CHECK-MIR: %[[VAL_15:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64> // CHECK-MIR: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_5]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f64) { // CHECK-MIR: %[[VAL_19:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref // CHECK-MIR: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref @@ -76,17 +76,17 @@ // CHECK-MIR: %[[VAL_23:.*]] = addf %[[VAL_18]], %[[VAL_22]] : f64 // CHECK-MIR: scf.yield %[[VAL_23]] : f64 // CHECK-MIR: } -// CHECK-MIR: memref.store %[[VAL_24:.*]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<64xf64> +// CHECK-MIR: memref.store %[[VAL_24:.*]], %[[VAL_10]]{{\[}}%[[VAL_11]]] : memref<32xf64> // CHECK-MIR: } -// CHECK-MIR: %[[VAL_25:.*]] = memref.tensor_load %[[VAL_10]] : memref<64xf64> -// CHECK-MIR: return %[[VAL_25]] : tensor<64xf64> +// CHECK-MIR: %[[VAL_25:.*]] = memref.tensor_load %[[VAL_10]] : memref<32xf64> +// CHECK-MIR: return %[[VAL_25]] : tensor<32xf64> // CHECK-MIR: } // CHECK-LIR-LABEL: func @matvec( // CHECK-LIR-SAME: %[[VAL_0:.*]]: !llvm.ptr, // CHECK-LIR-SAME: %[[VAL_1:.*]]: memref<64xf64>, -// CHECK-LIR-SAME: %[[VAL_2:.*]]: memref<64xf64> {linalg.inplaceable = true}) -> memref<64xf64> { -// CHECK-LIR: %[[VAL_3:.*]] = constant 64 : index +// CHECK-LIR-SAME: %[[VAL_2:.*]]: memref<32xf64> {linalg.inplaceable = true}) -> memref<32xf64> { +// CHECK-LIR: %[[VAL_3:.*]] = constant 32 : index // CHECK-LIR: %[[VAL_4:.*]] = constant 0 : index // CHECK-LIR: %[[VAL_5:.*]] = constant 1 : index // CHECK-LIR: %[[VAL_6:.*]] = call @sparsePointers(%[[VAL_0]], %[[VAL_5]]) : (!llvm.ptr, index) -> memref @@ -96,7 +96,7 @@ // CHECK-LIR: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_9]]] : memref // CHECK-LIR: %[[VAL_11:.*]] = addi %[[VAL_9]], %[[VAL_5]] : index // CHECK-LIR: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref -// CHECK-LIR: %[[VAL_13:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<64xf64> +// CHECK-LIR: %[[VAL_13:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<32xf64> // CHECK-LIR: %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_10]] to %[[VAL_12]] step %[[VAL_5]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f64) { // CHECK-LIR: %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref // CHECK-LIR: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref @@ -105,21 +105,21 @@ // CHECK-LIR: %[[VAL_21:.*]] = addf %[[VAL_16]], %[[VAL_20]] : f64 // CHECK-LIR: scf.yield %[[VAL_21]] : f64 // CHECK-LIR: } -// CHECK-LIR: memref.store %[[VAL_22:.*]], %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<64xf64> +// CHECK-LIR: memref.store %[[VAL_22:.*]], %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<32xf64> // CHECK-LIR: } -// CHECK-LIR: return %[[VAL_2]] : memref<64xf64> +// CHECK-LIR: return %[[VAL_2]] : memref<32xf64> // CHECK-LIR: } -func @matvec(%arga: tensor<64x64xf64, #CSR>, +func @matvec(%arga: tensor<32x64xf64, #CSR>, %argb: tensor<64xf64>, - %argx: tensor<64xf64> 
{linalg.inplaceable = true}) -> tensor<64xf64> { + %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> { %0 = linalg.generic #trait_matvec - ins(%arga, %argb : tensor<64x64xf64, #CSR>, tensor<64xf64>) - outs(%argx: tensor<64xf64>) { + ins(%arga, %argb : tensor<32x64xf64, #CSR>, tensor<64xf64>) + outs(%argx: tensor<32xf64>) { ^bb(%A: f64, %b: f64, %x: f64): %0 = mulf %A, %b : f64 %1 = addf %x, %0 : f64 linalg.yield %1 : f64 - } -> tensor<64xf64> - return %0 : tensor<64xf64> + } -> tensor<32xf64> + return %0 : tensor<32xf64> } diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_flatten.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_flatten.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_flatten.mlir @@ -0,0 +1,105 @@ +// RUN: mlir-opt %s \ +// RUN: --sparsification --sparse-tensor-conversion \ +// RUN: --convert-linalg-to-loops --convert-vector-to-scf --convert-scf-to-std \ +// RUN: --func-bufferize --tensor-constant-bufferize --tensor-bufferize \ +// RUN: --std-bufferize --finalizing-bufferize \ +// RUN: --convert-vector-to-llvm --convert-std-to-llvm | \ +// RUN: TENSOR0="%mlir_integration_test_dir/data/test.tns" \ +// RUN: mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +!Filename = type !llvm.ptr + +#SparseTensor = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed", "compressed", "compressed", "compressed", + "compressed", "compressed", "compressed", "compressed" ], + // Note that any dimOrdering permutation should give the same results + // since, even though it impacts the sparse storage scheme layout, + // it should not change the semantics. + dimOrdering = affine_map<(i,j,k,l,m,n,o,p) -> (p,o,j,k,i,l,m,n)> +}> + +#trait_flatten = { + indexing_maps = [ + affine_map<(i,j,k,l,m,n,o,p) -> (i,j,k,l,m,n,o,p)>, // A + affine_map<(i,j,k,l,m,n,o,p) -> (i,j)> // X (out) + ], + iterator_types = [ "parallel", "parallel", "reduction", "reduction", + "reduction", "reduction", "reduction", "reduction" ], + doc = "X(i,j) += A(i,j,k,l,m,n,o,p)" +} + +// +// Integration test that lowers a kernel annotated as sparse to +// actual sparse code, initializes a matching sparse storage scheme +// from file, and runs the resulting code with the JIT compiler. +// +module { + // + // A kernel that flattens a rank 8 tensor into a dense matrix. + // + func @kernel_flatten(%arga: tensor<7x3x3x3x3x3x5x3xf64, #SparseTensor>, + %argx: tensor<7x3xf64>) -> tensor<7x3xf64> { + %0 = linalg.generic #trait_flatten + ins(%arga: tensor<7x3x3x3x3x3x5x3xf64, #SparseTensor>) + outs(%argx: tensor<7x3xf64>) { + ^bb(%a: f64, %x: f64): + %0 = addf %x, %a : f64 + linalg.yield %0 : f64 + } -> tensor<7x3xf64> + return %0 : tensor<7x3xf64> + } + + func private @getTensorFilename(index) -> (!Filename) + + // + // Main driver that reads tensor from file and calls the sparse kernel. + // + func @entry() { + %d0 = constant 0.0 : f64 + %c0 = constant 0 : index + %c1 = constant 1 : index + %c3 = constant 3 : index + %c7 = constant 7 : index + + // Setup matrix memory that is initialized to zero. + %xdata = memref.alloc() : memref<7x3xf64> + scf.for %i = %c0 to %c7 step %c1 { + scf.for %j = %c0 to %c3 step %c1 { + memref.store %d0, %xdata[%i, %j] : memref<7x3xf64> + } + } + %x = memref.tensor_load %xdata : memref<7x3xf64> + + // Read the sparse tensor from file, construct sparse storage. 
+ %fileName = call @getTensorFilename(%c0) : (index) -> (!Filename) + %a = sparse_tensor.new %fileName : !llvm.ptr to tensor<7x3x3x3x3x3x5x3xf64, #SparseTensor> + + // Call the kernel. + %0 = call @kernel_flatten(%a, %x) + : (tensor<7x3x3x3x3x3x5x3xf64, #SparseTensor>, tensor<7x3xf64>) -> tensor<7x3xf64> + + // Print the result for verification. + // + // CHECK: ( 6.25, 0, 0 ) + // CHECK: ( 4.224, 6.21, 0 ) + // CHECK: ( 0, 0, 15.455 ) + // CHECK: ( 0, 0, 0 ) + // CHECK: ( 0, 0, 0 ) + // CHECK: ( 0, 0, 0 ) + // CHECK: ( 7, 0, 0 ) + // + %r = memref.buffer_cast %0 : memref<7x3xf64> + scf.for %i = %c0 to %c7 step %c1 { + %v = vector.transfer_read %r[%i, %c0], %d0: memref<7x3xf64>, vector<3xf64> + vector.print %v : vector<3xf64> + } + + // Release the resources. + memref.dealloc %xdata : memref<7x3xf64> + + return + } +}
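
Note (illustrative, not part of the patch): the permutation bookkeeping above can be summarized in a small standalone C++ sketch. The conversion pass materializes the inverse of the dimOrdering map, e.g. [1, 2, 0] for affine_map<(i,j,k) -> (k,i,j)>, matching the dense<[1, 2, 0]> constant checked in @sparse_new3d, and openTensorC uses that inverse to scatter per-dimension sizes and indices, read in original dimension order, into storage order. The helper names below are made up for the sketch.

// Standalone sketch of the "reverse" permutation logic; mirrors
// perm[p.getDimPosition(i)] = i in the conversion pass and
// indices[perm[r]] = ... in openTensorC. Helper names are hypothetical.
#include <cassert>
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <vector>

// ordering[i] = original dimension stored at position i (what the
// dimOrdering map's getDimPosition(i) yields); returns inverse with
// inverse[d] = storage position of original dimension d.
static std::vector<uint64_t>
inversePermutation(const std::vector<uint64_t> &ordering) {
  std::vector<uint64_t> inverse(ordering.size());
  for (uint64_t i = 0, n = ordering.size(); i < n; i++)
    inverse[ordering[i]] = i;
  return inverse;
}

// Scatters a coordinate given in original dimension order into storage
// order, the same way openTensorC scatters sizes and 0-based indices.
static std::vector<uint64_t>
toStorageOrder(const std::vector<uint64_t> &coord,
               const std::vector<uint64_t> &perm) {
  std::vector<uint64_t> out(coord.size());
  for (uint64_t r = 0, n = coord.size(); r < n; r++)
    out[perm[r]] = coord[r];
  return out;
}

int main() {
  // dimOrdering = affine_map<(i,j,k) -> (k,i,j)> from the sparse_new3d test:
  // storage position 0 holds k (dim 2), position 1 holds i (dim 0), etc.
  std::vector<uint64_t> ordering = {2, 0, 1};
  std::vector<uint64_t> perm = inversePermutation(ordering);
  assert(perm[0] == 1 && perm[1] == 2 && perm[2] == 0); // the dense<[1, 2, 0]> constant
  // Original coordinate (i,j,k) = (4,5,6) is stored as (k,i,j) = (6,4,5).
  std::vector<uint64_t> stored = toStorageOrder({4, 5, 6}, perm);
  printf("%" PRIu64 " %" PRIu64 " %" PRIu64 "\n", stored[0], stored[1], stored[2]);
  return 0;
}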