diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -55,11 +55,13 @@
   kUnary, // semiring unary op
   // Binary operations.
   kMulF,
+  kMulC,
   kMulI,
   kDivF,
   kDivS, // signed
   kDivU, // unsigned
   kAddF,
+  kAddC,
   kAddI,
   kSubF,
   kSubI,
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorUtils.h
@@ -42,7 +42,9 @@
   kI64 = 3,
   kI32 = 4,
   kI16 = 5,
-  kI8 = 6
+  kI8 = 6,
+  kC64 = 7,
+  kC32 = 8
 };
 
 /// The actions performed by @newSparseTensor.
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Pipelines/CMakeLists.txt
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/CMakeLists.txt
@@ -8,6 +8,8 @@
     MLIRArithmeticTransforms
    MLIRAffineToStandard
    MLIRBufferizationTransforms
+    MLIRComplexToLLVM
+    MLIRComplexToStandard
    MLIRFuncTransforms
    MLIRLinalgTransforms
    MLIRMathToLibm
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -48,7 +48,9 @@
   pm.addPass(createLowerAffinePass());
   pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions()));
   pm.addPass(createMemRefToLLVMPass());
+  pm.addNestedPass<func::FuncOp>(createConvertComplexToStandardPass());
   pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
+  pm.addPass(createConvertComplexToLLVMPass());
   pm.addPass(createConvertMathToLibmPass());
   pm.addPass(createConvertFuncToLLVMPass());
   pm.addPass(createReconcileUnrealizedCastsPass());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -111,6 +111,13 @@
     return PrimaryType::kI16;
   if (elemTp.isInteger(8))
     return PrimaryType::kI8;
+  if (auto complexTp = elemTp.dyn_cast<ComplexType>()) {
+    auto complexEltTp = complexTp.getElementType();
+    if (complexEltTp.isF64())
+      return PrimaryType::kC64;
+    if (complexEltTp.isF32())
+      return PrimaryType::kC32;
+  }
   llvm_unreachable("Unknown primary type");
 }
 
@@ -128,6 +135,10 @@
     return "I16";
   case PrimaryType::kI8:
     return "I8";
+  case PrimaryType::kC64:
+    return "C64";
+  case PrimaryType::kC32:
+    return "C32";
   }
   llvm_unreachable("Unknown PrimaryType");
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
--- a/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Utils/CMakeLists.txt
@@ -6,6 +6,7 @@
 
   LINK_LIBS PUBLIC
   MLIRArithmetic
+  MLIRComplex
   MLIRIR
   MLIRLinalg
 )
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -8,6 +8,7 @@
 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
"mlir/Dialect/SparseTensor/IR/SparseTensor.h" @@ -303,6 +304,7 @@ assert(isInvariant(tensorExps[e].children.e1)); return isSingleCondition(t, tensorExps[e].children.e0); case kMulF: + case kMulC: case kMulI: case kAndI: if (isSingleCondition(t, tensorExps[e].children.e0)) @@ -312,6 +314,7 @@ return isInvariant(tensorExps[e].children.e0); return false; case kAddF: + case kAddC: case kAddI: return isSingleCondition(t, tensorExps[e].children.e0) && isSingleCondition(t, tensorExps[e].children.e1); @@ -371,21 +374,18 @@ case kUnary: return "unary"; case kMulF: - return "*"; + case kMulC: case kMulI: return "*"; case kDivF: - return "/"; case kDivS: - return "/"; case kDivU: return "/"; case kAddF: - return "+"; + case kAddC: case kAddI: return "+"; case kSubF: - return "-"; case kSubI: return "-"; case kAndI: @@ -581,6 +581,7 @@ return takeDisj(kind, child0, buildLattices(rhs, i), unop); } case kMulF: + case kMulC: case kMulI: case kAndI: // A multiplicative operation only needs to be performed @@ -590,6 +591,8 @@ // ---+---+---+ // !x | 0 | 0 | // x | 0 |x*y| + // + // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored. return takeConj(kind, // take binary conjunction buildLattices(tensorExps[e].children.e0, i), buildLattices(tensorExps[e].children.e1, i)); @@ -614,6 +617,7 @@ buildLattices(tensorExps[e].children.e0, i), buildLattices(tensorExps[e].children.e1, i)); case kAddF: + case kAddC: case kAddI: case kSubF: case kSubI: @@ -789,6 +793,8 @@ unsigned e1 = y.getValue(); if (isa(def)) return addExp(kMulF, e0, e1); + if (isa(def)) + return addExp(kMulC, e0, e1); if (isa(def)) return addExp(kMulI, e0, e1); if (isa(def) && !maybeZero(e1)) @@ -799,6 +805,8 @@ return addExp(kDivU, e0, e1); if (isa(def)) return addExp(kAddF, e0, e1); + if (isa(def)) + return addExp(kAddC, e0, e1); if (isa(def)) return addExp(kAddI, e0, e1); if (isa(def)) @@ -927,6 +935,8 @@ // Binary ops. case kMulF: return rewriter.create(loc, v0, v1); + case kMulC: + return rewriter.create(loc, v0, v1); case kMulI: return rewriter.create(loc, v0, v1); case kDivF: @@ -937,6 +947,8 @@ return rewriter.create(loc, v0, v1); case kAddF: return rewriter.create(loc, v0, v1); + case kAddC: + return rewriter.create(loc, v0, v1); case kAddI: return rewriter.create(loc, v0, v1); case kSubF: diff --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp --- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp +++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -33,6 +34,9 @@ #include #include +using complex64 = std::complex; +using complex32 = std::complex; + //===----------------------------------------------------------------------===// // // Internal support for storing and reading sparse tensors. @@ -287,6 +291,8 @@ virtual void getValues(std::vector **) { fatal("vali32"); } virtual void getValues(std::vector **) { fatal("vali16"); } virtual void getValues(std::vector **) { fatal("vali8"); } + virtual void getValues(std::vector **) { fatal("valc64"); } + virtual void getValues(std::vector **) { fatal("valc32"); } /// Element-wise insertion in lexicographic index order. 
   virtual void lexInsert(const uint64_t *, double) { fatal("insf64"); }
@@ -295,6 +301,8 @@
   virtual void lexInsert(const uint64_t *, int32_t) { fatal("insi32"); }
   virtual void lexInsert(const uint64_t *, int16_t) { fatal("ins16"); }
   virtual void lexInsert(const uint64_t *, int8_t) { fatal("insi8"); }
+  virtual void lexInsert(const uint64_t *, complex64) { fatal("insc64"); }
+  virtual void lexInsert(const uint64_t *, complex32) { fatal("insc32"); }
 
   /// Expanded insertion.
   virtual void expInsert(uint64_t *, double *, bool *, uint64_t *, uint64_t) {
@@ -315,6 +323,14 @@
   virtual void expInsert(uint64_t *, int8_t *, bool *, uint64_t *, uint64_t) {
     fatal("expi8");
   }
+  virtual void expInsert(uint64_t *, complex64 *, bool *, uint64_t *,
+                         uint64_t) {
+    fatal("expc64");
+  }
+  virtual void expInsert(uint64_t *, complex32 *, bool *, uint64_t *,
+                         uint64_t) {
+    fatal("expc32");
+  }
 
   /// Finishes insertion.
   virtual void endInsert() = 0;
@@ -898,7 +914,7 @@
          "dimension size mismatch");
   SparseTensorCOO<V> *tensor =
       SparseTensorCOO<V>::newSparseTensorCOO(rank, idata + 2, perm, nnz);
-  //  Read all nonzero elements.
+  // Read all nonzero elements.
   std::vector<uint64_t> indices(rank);
   for (uint64_t k = 0; k < nnz; k++) {
     if (!fgets(line, kColWidth, file)) {
@@ -1006,6 +1022,7 @@
 static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
                                  uint64_t **pShape, V **pValues,
                                  uint64_t **pIndices) {
+  assert(tensor);
   auto sparseTensor =
       static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
   uint64_t rank = sparseTensor->getRank();
@@ -1293,6 +1310,10 @@
     CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
     CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
 
+    // Complex matrices with wide overhead.
+    CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
+    CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
+
     // Unsupported case (add above if needed).
     fputs("unsupported combination of types\n", stderr);
     exit(1);
@@ -1319,6 +1340,8 @@
 IMPL_SPARSEVALUES(sparseValuesI32, int32_t, getValues)
 IMPL_SPARSEVALUES(sparseValuesI16, int16_t, getValues)
 IMPL_SPARSEVALUES(sparseValuesI8, int8_t, getValues)
+IMPL_SPARSEVALUES(sparseValuesC64, complex64, getValues)
+IMPL_SPARSEVALUES(sparseValuesC32, complex32, getValues)
 
 /// Helper to add value to coordinate scheme, one per value type.
 IMPL_ADDELT(addEltF64, double)
@@ -1327,6 +1350,17 @@
 IMPL_ADDELT(addEltI32, int32_t)
 IMPL_ADDELT(addEltI16, int16_t)
 IMPL_ADDELT(addEltI8, int8_t)
+IMPL_ADDELT(addEltC64, complex64)
+IMPL_ADDELT(addEltC32ABI, complex32)
+// Make prototype explicit to accept the !llvm.struct<(f32, f32)> without
+// any padding (which seem to happen for complex32 when passed as scalar;
+// all other cases, e.g. pointer to array, work as expected).
+// TODO: cleaner way to avoid ABI padding problem?
+void *_mlir_ciface_addEltC32(void *tensor, float r, float i,
+                             StridedMemRefType<index_type, 1> *iref,
+                             StridedMemRefType<index_type, 1> *pref) {
+  return _mlir_ciface_addEltC32ABI(tensor, complex32(r, i), iref, pref);
+}
 
 /// Helper to enumerate elements of coordinate scheme, one per value type.
 IMPL_GETNEXT(getNextF64, double)
@@ -1335,6 +1369,8 @@
 IMPL_GETNEXT(getNextI32, int32_t)
 IMPL_GETNEXT(getNextI16, int16_t)
 IMPL_GETNEXT(getNextI8, int8_t)
+IMPL_GETNEXT(getNextC64, complex64)
+IMPL_GETNEXT(getNextC32, complex32)
 
 /// Insert elements in lexicographical index order, one per value type.
 IMPL_LEXINSERT(lexInsertF64, double)
@@ -1343,6 +1379,17 @@
 IMPL_LEXINSERT(lexInsertI32, int32_t)
 IMPL_LEXINSERT(lexInsertI16, int16_t)
 IMPL_LEXINSERT(lexInsertI8, int8_t)
+IMPL_LEXINSERT(lexInsertC64, complex64)
+IMPL_LEXINSERT(lexInsertC32ABI, complex32)
+// Make prototype explicit to accept the !llvm.struct<(f32, f32)> without
+// any padding (which seem to happen for complex32 when passed as scalar;
+// all other cases, e.g. pointer to array, work as expected).
+// TODO: cleaner way to avoid ABI padding problem?
+void _mlir_ciface_lexInsertC32(void *tensor,
+                               StridedMemRefType<index_type, 1> *cref, float r,
+                               float i) {
+  _mlir_ciface_lexInsertC32ABI(tensor, cref, complex32(r, i));
+}
 
 /// Insert using expansion, one per value type.
 IMPL_EXPINSERT(expInsertF64, double)
@@ -1351,6 +1398,8 @@
 IMPL_EXPINSERT(expInsertI32, int32_t)
 IMPL_EXPINSERT(expInsertI16, int16_t)
 IMPL_EXPINSERT(expInsertI8, int8_t)
+IMPL_EXPINSERT(expInsertC64, complex64)
+IMPL_EXPINSERT(expInsertC32, complex32)
 
 #undef CASE
 #undef IMPL_SPARSEVALUES
@@ -1379,6 +1428,12 @@
 void outSparseTensorI8(void *tensor, void *dest, bool sort) {
   return outSparseTensor<int8_t>(tensor, dest, sort);
 }
+void outSparseTensorC64(void *tensor, void *dest, bool sort) {
+  return outSparseTensor<complex64>(tensor, dest, sort);
+}
+void outSparseTensorC32(void *tensor, void *dest, bool sort) {
+  return outSparseTensor<complex32>(tensor, dest, sort);
+}
 
 //===----------------------------------------------------------------------===//
 //
@@ -1428,6 +1483,8 @@
 IMPL_DELCOO(I32, int32_t)
 IMPL_DELCOO(I16, int16_t)
 IMPL_DELCOO(I8, int8_t)
+IMPL_DELCOO(C64, complex64)
+IMPL_DELCOO(C32, complex32)
 #undef IMPL_DELCOO
 
 /// Initializes sparse tensor from a COO-flavored format expressed using C-style
@@ -1489,6 +1546,18 @@
   return toMLIRSparseTensor<int8_t>(rank, nse, shape, values, indices, perm,
                                     sparse);
 }
+void *convertToMLIRSparseTensorC64(uint64_t rank, uint64_t nse, uint64_t *shape,
+                                   complex64 *values, uint64_t *indices,
+                                   uint64_t *perm, uint8_t *sparse) {
+  return toMLIRSparseTensor<complex64>(rank, nse, shape, values, indices, perm,
+                                       sparse);
+}
+void *convertToMLIRSparseTensorC32(uint64_t rank, uint64_t nse, uint64_t *shape,
+                                   complex32 *values, uint64_t *indices,
+                                   uint64_t *perm, uint8_t *sparse) {
+  return toMLIRSparseTensor<complex32>(rank, nse, shape, values, indices, perm,
+                                       sparse);
+}
 
 /// Converts a sparse tensor to COO-flavored format expressed using C-style
 /// data structures. The expected output parameters are pointers for these
@@ -1540,6 +1609,18 @@
                                     int8_t **pValues, uint64_t **pIndices) {
   fromMLIRSparseTensor<int8_t>(tensor, pRank, pNse, pShape, pValues, pIndices);
 }
+void convertFromMLIRSparseTensorC64(void *tensor, uint64_t *pRank,
+                                    uint64_t *pNse, uint64_t **pShape,
+                                    complex64 **pValues, uint64_t **pIndices) {
+  fromMLIRSparseTensor<complex64>(tensor, pRank, pNse, pShape, pValues,
+                                  pIndices);
+}
+void convertFromMLIRSparseTensorC32(void *tensor, uint64_t *pRank,
+                                    uint64_t *pNse, uint64_t **pShape,
+                                    complex32 **pValues, uint64_t **pIndices) {
+  fromMLIRSparseTensor<complex32>(tensor, pRank, pNse, pShape, pValues,
+                                  pIndices);
+}
 
 } // extern "C"
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex32.mlir
@@ -0,0 +1,116 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+#trait_op = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a (in)
+    affine_map<(i) -> (i)>,  // b (in)
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = a(i) OP b(i)"
+}
+
+module {
+  func.func @cadd(%arga: tensor<?xcomplex<f32>, #SparseVector>,
+                  %argb: tensor<?xcomplex<f32>, #SparseVector>)
+                 -> tensor<?xcomplex<f32>, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xcomplex<f32>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f32>, #SparseVector>
+    %0 = linalg.generic #trait_op
+       ins(%arga, %argb: tensor<?xcomplex<f32>, #SparseVector>,
+                         tensor<?xcomplex<f32>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f32>, #SparseVector>) {
+        ^bb(%a: complex<f32>, %b: complex<f32>, %x: complex<f32>):
+          %1 = complex.add %a, %b : complex<f32>
+          linalg.yield %1 : complex<f32>
+    } -> tensor<?xcomplex<f32>, #SparseVector>
+    return %0 : tensor<?xcomplex<f32>, #SparseVector>
+  }
+
+  func.func @cmul(%arga: tensor<?xcomplex<f32>, #SparseVector>,
+                  %argb: tensor<?xcomplex<f32>, #SparseVector>)
+                 -> tensor<?xcomplex<f32>, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xcomplex<f32>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f32>, #SparseVector>
+    %0 = linalg.generic #trait_op
+       ins(%arga, %argb: tensor<?xcomplex<f32>, #SparseVector>,
+                         tensor<?xcomplex<f32>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f32>, #SparseVector>) {
+        ^bb(%a: complex<f32>, %b: complex<f32>, %x: complex<f32>):
+          %1 = complex.mul %a, %b : complex<f32>
+          linalg.yield %1 : complex<f32>
+    } -> tensor<?xcomplex<f32>, #SparseVector>
+    return %0 : tensor<?xcomplex<f32>, #SparseVector>
+  }
+
+  func.func @dump(%arg0: tensor<?xcomplex<f32>, #SparseVector>, %d: index) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %mem = sparse_tensor.values %arg0 : tensor<?xcomplex<f32>, #SparseVector> to memref<?xcomplex<f32>>
+    scf.for %i = %c0 to %d step %c1 {
+      %v = memref.load %mem[%i] : memref<?xcomplex<f32>>
+      %real = complex.re %v : complex<f32>
+      %imag = complex.im %v : complex<f32>
+      vector.print %real : f32
+      vector.print %imag : f32
+    }
+    return
+  }
+
+  // Driver method to call and verify complex kernels.
+  func.func @entry() {
+    // Setup sparse vectors.
+    %v1 = arith.constant sparse<
+       [ [0], [28], [31] ],
+         [ (511.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] > : tensor<32xcomplex<f32>>
+    %v2 = arith.constant sparse<
+       [ [1], [28], [31] ],
+         [ (1.0, 0.0), (2.0, 0.0), (3.0, 0.0) ] > : tensor<32xcomplex<f32>>
+    %sv1 = sparse_tensor.convert %v1 : tensor<32xcomplex<f32>> to tensor<?xcomplex<f32>, #SparseVector>
+    %sv2 = sparse_tensor.convert %v2 : tensor<32xcomplex<f32>> to tensor<?xcomplex<f32>, #SparseVector>
+
+    // Call sparse vector kernels.
+    %0 = call @cadd(%sv1, %sv2)
+       : (tensor<?xcomplex<f32>, #SparseVector>,
+          tensor<?xcomplex<f32>, #SparseVector>) -> tensor<?xcomplex<f32>, #SparseVector>
+    %1 = call @cmul(%sv1, %sv2)
+       : (tensor<?xcomplex<f32>, #SparseVector>,
+          tensor<?xcomplex<f32>, #SparseVector>) -> tensor<?xcomplex<f32>, #SparseVector>
+
+    //
+    // Verify the results.
+    //
+    // CHECK: 511.13
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 5
+    // CHECK-NEXT: 4
+    // CHECK-NEXT: 8
+    // CHECK-NEXT: 6
+    // CHECK-NEXT: 6
+    // CHECK-NEXT: 8
+    // CHECK-NEXT: 15
+    // CHECK-NEXT: 18
+    //
+    %d1 = arith.constant 4 : index
+    %d2 = arith.constant 2 : index
+    call @dump(%0, %d1) : (tensor<?xcomplex<f32>, #SparseVector>, index) -> ()
+    call @dump(%1, %d2) : (tensor<?xcomplex<f32>, #SparseVector>, index) -> ()
+
+    // Release the resources.
+    sparse_tensor.release %sv1 : tensor<?xcomplex<f32>, #SparseVector>
+    sparse_tensor.release %sv2 : tensor<?xcomplex<f32>, #SparseVector>
+    sparse_tensor.release %0 : tensor<?xcomplex<f32>, #SparseVector>
+    sparse_tensor.release %1 : tensor<?xcomplex<f32>, #SparseVector>
+    return
+  }
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex64.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex64.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex64.mlir
@@ -0,0 +1,116 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+#trait_op = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a (in)
+    affine_map<(i) -> (i)>,  // b (in)
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = a(i) OP b(i)"
+}
+
+module {
+  func.func @cadd(%arga: tensor<?xcomplex<f64>, #SparseVector>,
+                  %argb: tensor<?xcomplex<f64>, #SparseVector>)
+                 -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op
+       ins(%arga, %argb: tensor<?xcomplex<f64>, #SparseVector>,
+                         tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %b: complex<f64>, %x: complex<f64>):
+          %1 = complex.add %a, %b : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
+  func.func @cmul(%arga: tensor<?xcomplex<f64>, #SparseVector>,
+                  %argb: tensor<?xcomplex<f64>, #SparseVector>)
+                 -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op
+       ins(%arga, %argb: tensor<?xcomplex<f64>, #SparseVector>,
+                         tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %b: complex<f64>, %x: complex<f64>):
+          %1 = complex.mul %a, %b : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
+  func.func @dump(%arg0: tensor<?xcomplex<f64>, #SparseVector>, %d: index) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %mem = sparse_tensor.values %arg0 : tensor<?xcomplex<f64>, #SparseVector> to memref<?xcomplex<f64>>
+    scf.for %i = %c0 to %d step %c1 {
+      %v = memref.load %mem[%i] : memref<?xcomplex<f64>>
+      %real = complex.re %v : complex<f64>
+      %imag = complex.im %v : complex<f64>
+      vector.print %real : f64
+      vector.print %imag : f64
+    }
+    return
+  }
+
+  // Driver method to call and verify complex kernels.
+  func.func @entry() {
+    // Setup sparse vectors.
+    %v1 = arith.constant sparse<
+       [ [0], [28], [31] ],
+         [ (511.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] > : tensor<32xcomplex<f64>>
+    %v2 = arith.constant sparse<
+       [ [1], [28], [31] ],
+         [ (1.0, 0.0), (2.0, 0.0), (3.0, 0.0) ] > : tensor<32xcomplex<f64>>
+    %sv1 = sparse_tensor.convert %v1 : tensor<32xcomplex<f64>> to tensor<?xcomplex<f64>, #SparseVector>
+    %sv2 = sparse_tensor.convert %v2 : tensor<32xcomplex<f64>> to tensor<?xcomplex<f64>, #SparseVector>
+
+    // Call sparse vector kernels.
+    %0 = call @cadd(%sv1, %sv2)
+       : (tensor<?xcomplex<f64>, #SparseVector>,
+          tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %1 = call @cmul(%sv1, %sv2)
+       : (tensor<?xcomplex<f64>, #SparseVector>,
+          tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+
+    //
+    // Verify the results.
+    //
+    // CHECK: 511.13
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 5
+    // CHECK-NEXT: 4
+    // CHECK-NEXT: 8
+    // CHECK-NEXT: 6
+    // CHECK-NEXT: 6
+    // CHECK-NEXT: 8
+    // CHECK-NEXT: 15
+    // CHECK-NEXT: 18
+    //
+    %d1 = arith.constant 4 : index
+    %d2 = arith.constant 2 : index
+    call @dump(%0, %d1) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    call @dump(%1, %d2) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+
+    // Release the resources.
+    sparse_tensor.release %sv1 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %sv2 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %0 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %1 : tensor<?xcomplex<f64>, #SparseVector>
+    return
+  }
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2009,6 +2009,7 @@
     includes = ["include"],
    deps = [
        ":ArithmeticDialect",
+        ":ComplexDialect",
        ":IR",
        ":LinalgOps",
        ":MathDialect",
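
Illustrative usage (not part of the patch): a minimal C++ sketch that round-trips a complex64 COO tensor through the convertToMLIRSparseTensorC64 / convertFromMLIRSparseTensorC64 entry points added above, compiled and linked against the sparse runtime library (libmlir_c_runner_utils). The flat nse-by-rank layout of `indices` and the meaning of the per-dimension `sparse` annotation value (1 assumed to request a compressed dimension) are assumptions inferred from the existing F64/F32 paths, not something this diff spells out.

// Sketch only: exercises the C API added above for complex64 values.
// Assumes indices are stored flat as nse x rank (row-major) and that a
// per-dimension annotation value of 1 means "compressed", mirroring how
// the pre-existing F64/F32 entry points are used.
#include <complex>
#include <cstdint>
#include <cstdio>

using complex64 = std::complex<double>;

extern "C" {
void *convertToMLIRSparseTensorC64(uint64_t rank, uint64_t nse, uint64_t *shape,
                                   complex64 *values, uint64_t *indices,
                                   uint64_t *perm, uint8_t *sparse);
void convertFromMLIRSparseTensorC64(void *tensor, uint64_t *pRank,
                                    uint64_t *pNse, uint64_t **pShape,
                                    complex64 **pValues, uint64_t **pIndices);
}

int main() {
  // A 1-D sparse vector of length 32 with three nonzero complex entries,
  // matching the data used in the integration tests above.
  uint64_t shape[] = {32};
  uint64_t indices[] = {0, 28, 31};  // nse x rank = 3 x 1
  complex64 values[] = {{511.13, 2.0}, {3.0, 4.0}, {5.0, 6.0}};
  uint64_t perm[] = {0};             // identity dimension ordering
  uint8_t sparse[] = {1};            // assumed: 1 == compressed dimension
  void *tensor =
      convertToMLIRSparseTensorC64(1, 3, shape, values, indices, perm, sparse);

  // Read the contents back out of the opaque sparse-tensor handle.
  uint64_t rank, nse, *outShape, *outIndices;
  complex64 *outValues;
  convertFromMLIRSparseTensorC64(tensor, &rank, &nse, &outShape, &outValues,
                                 &outIndices);
  for (uint64_t k = 0; k < nse; k++)
    printf("[%llu] = (%f, %f)\n", (unsigned long long)outIndices[k],
           outValues[k].real(), outValues[k].imag());
  return 0;
}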