diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -50,6 +50,8 @@
   kCastU, // unsigned
   kCastIdx,
   kTruncI,
+  kCIm, // complex.im
+  kCRe, // complex.re
   kBitCast,
   kBinaryBranch, // semiring unary branch created from a binary op
   kUnary,        // semiring unary op
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -46,6 +46,8 @@
   case kTanhF:
   case kNegF:
   case kNegI:
+  case kCIm:
+  case kCRe:
     assert(x != -1u && y == -1u && !v && !o);
     children.e0 = x;
     children.e1 = y;
@@ -291,6 +293,8 @@
   case kCastU:
   case kCastIdx:
   case kTruncI:
+  case kCIm:
+  case kCRe:
   case kBitCast:
     return isSingleCondition(t, tensorExps[e].children.e0);
   case kDivF: // note: x / c only
@@ -367,6 +371,10 @@
   case kCastU:
   case kCastIdx:
   case kTruncI:
+  case kCIm:
+    return "complex.im";
+  case kCRe:
+    return "complex.re";
   case kBitCast:
     return "cast";
   case kBinaryBranch:
@@ -526,6 +534,8 @@
   }
   case kAbsF:
   case kCeilF:
+  case kCIm:
+  case kCRe:
   case kFloorF:
   case kSqrtF:
   case kExpm1F:
@@ -776,6 +786,10 @@
         return addExp(kCastIdx, e, v);
       if (isa<arith::TruncIOp>(def))
         return addExp(kTruncI, e, v);
+      if (isa<complex::ImOp>(def))
+        return addExp(kCIm, e);
+      if (isa<complex::ReOp>(def))
+        return addExp(kCRe, e);
       if (isa<arith::BitcastOp>(def))
         return addExp(kBitCast, e, v);
       if (isa<sparse_tensor::UnaryOp>(def))
@@ -930,6 +944,15 @@
     return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
   case kTruncI:
     return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
+  case kCIm:
+  case kCRe: {
+    auto type = v0.getType().template cast<ComplexType>();
+    auto eltType = type.getElementType().template cast<FloatType>();
+    if (tensorExps[e].kind == kCIm)
+      return rewriter.create<complex::ImOp>(loc, eltType, v0);
+
+    return rewriter.create<complex::ReOp>(loc, eltType, v0);
+  }
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary ops.
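
Note (not part of the patch): the two ops the merger now recognizes simply project a complex value onto its real or imaginary component. A minimal illustrative snippet, with made-up SSA names:

  %z = complex.create %re, %im : complex<f32>
  %r = complex.re %z : complex<f32>   // yields f32
  %i = complex.im %z : complex<f32>   // yields f32

Because both ops map one stored complex entry to one float without inspecting any other entry, they are valid single-operand expressions, which is why kCIm/kCRe join the same case lists as the other unary kinds (kAbsF, kCeilF, ...) in the constructor assert, isSingleCondition, and the lattice-building switch above.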
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_re_im.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_re_im.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_re_im.mlir
@@ -0,0 +1,93 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+#trait_op = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>, // a (in)
+    affine_map<(i) -> (i)>  // x (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = OP a(i)"
+}
+
+module {
+  func.func @cre(%arga: tensor<?xcomplex<f32>, #SparseVector>)
+                 -> tensor<?xf32, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xcomplex<f32>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xf32, #SparseVector>
+    %0 = linalg.generic #trait_op
+      ins(%arga: tensor<?xcomplex<f32>, #SparseVector>)
+      outs(%xv: tensor<?xf32, #SparseVector>) {
+        ^bb(%a: complex<f32>, %x: f32):
+          %1 = complex.re %a : complex<f32>
+          linalg.yield %1 : f32
+    } -> tensor<?xf32, #SparseVector>
+    return %0 : tensor<?xf32, #SparseVector>
+  }
+
+  func.func @cim(%arga: tensor<?xcomplex<f32>, #SparseVector>)
+                 -> tensor<?xf32, #SparseVector> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xcomplex<f32>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xf32, #SparseVector>
+    %0 = linalg.generic #trait_op
+      ins(%arga: tensor<?xcomplex<f32>, #SparseVector>)
+      outs(%xv: tensor<?xf32, #SparseVector>) {
+        ^bb(%a: complex<f32>, %x: f32):
+          %1 = complex.im %a : complex<f32>
+          linalg.yield %1 : f32
+    } -> tensor<?xf32, #SparseVector>
+    return %0 : tensor<?xf32, #SparseVector>
+  }
+
+  func.func @dump(%arg0: tensor<?xf32, #SparseVector>) {
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f32
+    %values = sparse_tensor.values %arg0 : tensor<?xf32, #SparseVector> to memref<?xf32>
+    %0 = vector.transfer_read %values[%c0], %d0: memref<?xf32>, vector<4xf32>
+    vector.print %0 : vector<4xf32>
+    %indices = sparse_tensor.indices %arg0, %c0 : tensor<?xf32, #SparseVector> to memref<?xindex>
+    %1 = vector.transfer_read %indices[%c0], %c0: memref<?xindex>, vector<4xindex>
+    vector.print %1 : vector<4xindex>
+    return
+  }
+
+  // Driver method to call and verify functions cim and cre.
+  func.func @entry() {
+    // Setup sparse vectors.
+    %v1 = arith.constant sparse<
+       [ [0], [20], [31] ],
+         [ (5.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] > : tensor<32xcomplex<f32>>
+    %sv1 = sparse_tensor.convert %v1 : tensor<32xcomplex<f32>> to tensor<?xcomplex<f32>, #SparseVector>
+
+    // Call sparse vector kernels.
+    %0 = call @cre(%sv1)
+       : (tensor<?xcomplex<f32>, #SparseVector>) -> tensor<?xf32, #SparseVector>
+
+    %1 = call @cim(%sv1)
+       : (tensor<?xcomplex<f32>, #SparseVector>) -> tensor<?xf32, #SparseVector>
+
+    //
+    // Verify the results.
+    //
+    // CHECK: ( 5.13, 3, 5, -1 )
+    // CHECK-NEXT: ( 0, 20, 31, 0 )
+    // CHECK-NEXT: ( 2, 4, 6, -1 )
+    // CHECK-NEXT: ( 0, 20, 31, 0 )
+    //
+    call @dump(%0) : (tensor<?xf32, #SparseVector>) -> ()
+    call @dump(%1) : (tensor<?xf32, #SparseVector>) -> ()
+
+    // Release the resources.
+    sparse_tensor.release %sv1 : tensor<?xcomplex<f32>, #SparseVector>
+    sparse_tensor.release %0 : tensor<?xf32, #SparseVector>
+    sparse_tensor.release %1 : tensor<?xf32, #SparseVector>
+    return
+  }
+}
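
Note (illustrative, not the verbatim output of --sparse-compiler): for a unary kernel over a compressed vector, the sparse compiler scalarizes the linalg.generic into a loop over the stored entries only. A rough sketch of the shape of that loop, where %nse (number of stored entries), %values, and %out_values are hypothetical names for values taken from the sparse storage scheme:

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  scf.for %i = %c0 to %nse step %c1 {
    %a = memref.load %values[%i] : memref<?xcomplex<f32>>
    %r = complex.re %a : complex<f32>
    memref.store %r, %out_values[%i] : memref<?xf32>
  }

This is why the CHECK lines above see the original indices (0, 20, 31) paired with the real parts (5.13, 3, 5) and imaginary parts (2, 4, 6), with the -1 padding coming from the transfer_read fill value in @dump.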