diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt @@ -13,6 +13,7 @@ LINK_LIBS PUBLIC MLIRArithmetic MLIRBufferization + MLIRComplex MLIRFunc MLIRIR MLIRLLVMIR diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -14,6 +14,7 @@ #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/ExecutionEngine/SparseTensorUtils.h" #include "mlir/IR/Builders.h" @@ -102,16 +103,27 @@ //===----------------------------------------------------------------------===// /// Generates a 0-valued constant of the given type. In addition to -/// the scalar types (`FloatType`, `IndexType`, `IntegerType`), this also -/// works for `RankedTensorType` and `VectorType` (for which it generates -/// a constant `DenseElementsAttr` of zeros). +/// the scalar types (`ComplexType`, `FloatType`, `IndexType`, `IntegerType`), +/// this also works for `RankedTensorType` and `VectorType` (for which it +/// generates a constant `DenseElementsAttr` of zeros). inline Value constantZero(OpBuilder &builder, Location loc, Type tp) { + if (auto ctp = tp.dyn_cast()) { + auto zeroe = builder.getZeroAttr(ctp.getElementType()); + auto zeroa = builder.getArrayAttr({zeroe, zeroe}); + return builder.create(loc, tp, zeroa); + } return builder.create(loc, tp, builder.getZeroAttr(tp)); } /// Generates a 1-valued constant of the given type. This supports all /// the same types as `constantZero`. 
inline Value constantOne(OpBuilder &builder, Location loc, Type tp) { + if (auto ctp = tp.dyn_cast()) { + auto zeroe = builder.getZeroAttr(ctp.getElementType()); + auto onee = getOneAttr(builder, ctp.getElementType()); + auto onea = builder.getArrayAttr({onee, zeroe}); + return builder.create(loc, tp, onea); + } return builder.create(loc, tp, getOneAttr(builder, tp)); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -189,5 +189,7 @@ if (tp.isIntOrIndex()) return builder.create(loc, arith::CmpIPredicate::ne, v, zero); + if (tp.dyn_cast()) + return builder.create(loc, v, zero); llvm_unreachable("Non-numeric type"); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -112,7 +113,8 @@ // The following operations and dialects may be introduced by the // rewriting rules, and are therefore marked as legal. 
target.addLegalOp(); target .addLegalDialect) -> !llvm.ptr +// CHECK-SAME: %[[A:.*]]: tensor) -> !llvm.ptr { // CHECK-DAG: %[[EmptyCOO:.*]] = arith.constant 4 : i32 // CHECK-DAG: %[[FromCOO:.*]] = arith.constant 2 : i32 +// CHECK-DAG: %[[I0:.*]] = arith.constant 0 : i32 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[U:.*]] = tensor.dim %[[A]], %[[C0]] : tensor @@ -191,8 +192,11 @@ // CHECK: %[[T:.*]] = memref.cast %[[M]] : memref<1xindex> to memref // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[U]] step %[[C1]] { // CHECK: %[[E:.*]] = tensor.extract %[[A]][%[[I]]] : tensor -// CHECK: memref.store %[[I]], %[[M]][%[[C0]]] : memref<1xindex> -// CHECK: call @addEltI32(%[[C]], %[[E]], %[[T]], %[[Z]]) +// CHECK: %[[N:.*]] = arith.cmpi ne, %[[E]], %[[I0]] : i32 +// CHECK: scf.if %[[N]] { +// CHECK: memref.store %[[I]], %[[M]][%[[C0]]] : memref<1xindex> +// CHECK: call @addEltI32(%[[C]], %[[E]], %[[T]], %[[Z]]) +// CHECK: } // CHECK: } // CHECK: %[[T:.*]] = call @newSparseTensor(%[[X]], %[[Y]], %[[Z]], %{{.*}}, %{{.*}}, %{{.*}}, %[[FromCOO]], %[[C]]) // CHECK: call @delSparseTensorCOOI32(%[[C]]) @@ -202,6 +206,28 @@ return %0 : tensor } +// CHECK-LABEL: func @sparse_convert_complex( +// CHECK-SAME: %[[A:.*]]: tensor<100xcomplex>) -> !llvm.ptr { +// CHECK-DAG: %[[CC:.*]] = complex.constant [0.000000e+00, 0.000000e+00] : complex +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C100:.*]] = arith.constant 100 : index +// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C100]] step %[[C1]] { +// CHECK: %[[E:.*]] = tensor.extract %[[A]][%[[I]]] : tensor<100xcomplex> +// CHECK: %[[N:.*]] = complex.neq %[[E]], %[[CC]] : complex +// CHECK: scf.if %[[N]] { +// CHECK: memref.store %[[I]], %{{.*}}[%[[C0]]] : memref<1xindex> +// CHECK: call @addEltC64 +// CHECK: } +// CHECK: } +// CHECK: %[[T:.*]] = call @newSparseTensor +// CHECK: call 
@delSparseTensorCOOC64 +// CHECK: return %[[T]] : !llvm.ptr +func.func @sparse_convert_complex(%arg0: tensor<100xcomplex>) -> tensor<100xcomplex, #SparseVector> { + %0 = sparse_tensor.convert %arg0 : tensor<100xcomplex> to tensor<100xcomplex, #SparseVector> + return %0 : tensor<100xcomplex, #SparseVector> +} + // CHECK-LABEL: func @sparse_convert_1d_ss( // CHECK-SAME: %[[A:.*]]: !llvm.ptr) // CHECK-DAG: %[[ToCOO:.*]] = arith.constant 5 : i32 diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2033,6 +2033,7 @@ ":Affine", ":ArithmeticDialect", ":BufferizationDialect", + ":ComplexDialect", ":FuncDialect", ":FuncTransforms", ":IR",