diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -35,12 +35,15 @@
   kCeilF,
   kFloorF,
   kSqrtF,
+  kSqrtC,
   kExpm1F,
+  kExpm1C,
   kLog1pF,
   kLog1pC,
   kSinF,
   kSinC,
   kTanhF,
+  kTanhC,
   kNegF,
   kNegC,
   kNegI,
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -41,12 +41,15 @@
   case kCeilF:
   case kFloorF:
   case kSqrtF:
+  case kSqrtC:
   case kExpm1F:
+  case kExpm1C:
   case kLog1pF:
   case kLog1pC:
   case kSinF:
   case kSinC:
   case kTanhF:
+  case kTanhC:
   case kNegF:
   case kNegC:
   case kNegI:
@@ -284,12 +287,15 @@
   case kCeilF:
   case kFloorF:
   case kSqrtF:
+  case kSqrtC:
   case kExpm1F:
+  case kExpm1C:
   case kLog1pF:
   case kLog1pC:
   case kSinF:
   case kSinC:
   case kTanhF:
+  case kTanhC:
   case kNegF:
   case kNegC:
   case kNegI:
@@ -360,8 +366,10 @@
   case kFloorF:
     return "floor";
   case kSqrtF:
+  case kSqrtC:
     return "sqrt";
   case kExpm1F:
+  case kExpm1C:
     return "expm1";
   case kLog1pF:
   case kLog1pC:
@@ -370,6 +378,7 @@
   case kSinC:
     return "sin";
   case kTanhF:
+  case kTanhC:
     return "tanh";
   case kNegF:
   case kNegC:
@@ -449,10 +458,13 @@
   case kCeilF:
   case kFloorF:
   case kSqrtF:
+  case kSqrtC:
   case kExpm1F:
+  case kExpm1C:
   case kLog1pF:
   case kSinF:
   case kTanhF:
+  case kTanhC:
   case kNegF:
   case kNegI:
   case kTruncF:
@@ -555,12 +567,15 @@
   case kCRe:
   case kFloorF:
   case kSqrtF:
+  case kSqrtC:
   case kExpm1F:
+  case kExpm1C:
   case kLog1pF:
   case kLog1pC:
   case kSinF:
   case kSinC:
   case kTanhF:
+  case kTanhC:
   case kNegF:
   case kNegC:
   case kNegI:
@@ -785,8 +800,12 @@
     return addExp(kFloorF, e);
   if (isa<math::SqrtOp>(def))
     return addExp(kSqrtF, e);
+  if (isa<complex::SqrtOp>(def))
+    return addExp(kSqrtC, e);
   if (isa<math::ExpM1Op>(def))
     return addExp(kExpm1F, e);
+  if (isa<complex::Expm1Op>(def))
+    return addExp(kExpm1C, e);
   if (isa<math::Log1pOp>(def))
     return addExp(kLog1pF, e);
   if (isa<complex::Log1pOp>(def))
@@ -797,6 +816,8 @@
     return addExp(kSinC, e);
   if (isa<math::TanhOp>(def))
     return addExp(kTanhF, e);
+  if (isa<complex::TanhOp>(def))
+    return addExp(kTanhC, e);
   if (isa<arith::NegFOp>(def))
     return addExp(kNegF, e); // no negi in std
   if (isa<complex::NegOp>(def))
@@ -952,8 +973,12 @@
     return rewriter.create<math::FloorOp>(loc, v0);
   case kSqrtF:
     return rewriter.create<math::SqrtOp>(loc, v0);
+  case kSqrtC:
+    return rewriter.create<complex::SqrtOp>(loc, v0);
   case kExpm1F:
     return rewriter.create<math::ExpM1Op>(loc, v0);
+  case kExpm1C:
+    return rewriter.create<complex::Expm1Op>(loc, v0);
   case kLog1pF:
     return rewriter.create<math::Log1pOp>(loc, v0);
   case kLog1pC:
@@ -964,6 +989,8 @@
     return rewriter.create<complex::SinOp>(loc, v0);
   case kTanhF:
     return rewriter.create<math::TanhOp>(loc, v0);
+  case kTanhC:
+    return rewriter.create<complex::TanhOp>(loc, v0);
   case kNegF:
     return rewriter.create<arith::NegFOp>(loc, v0);
   case kNegC:
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
@@ -59,6 +59,54 @@
     return %0 : tensor<?xcomplex<f64>, #SparseVector>
   }
 
+  func.func @complex_sqrt(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+                 -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op1
+       ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %x: complex<f64>):
+          %1 = complex.sqrt %a : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
+  func.func @complex_tanh(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+                 -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op1
+       ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %x: complex<f64>):
+          %1 = complex.tanh %a : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
+  func.func @clog1p_expm1(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+                 -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op1
+       ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %x: complex<f64>):
+          %1 = complex.log1p %a : complex<f64>
+          // TODO(bixia): Enable this line after adding complex.expm1 to
+          // complex to standard lowering.
+          // %2 = complex.expm1 %1 : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
   func.func @cdiv(%arga: tensor<?xcomplex<f64>, #SparseVector>)
                  -> tensor<?xcomplex<f64>, #SparseVector> {
     %c0 = arith.constant 0 : index
@@ -131,9 +179,15 @@
           tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
     %1 = call @csin(%sv1)
        : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
-    %2 = call @cdiv(%sv1)
+    %2 = call @complex_sqrt(%sv1)
+       : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %3 = call @complex_tanh(%sv2)
+       : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %4 = call @clog1p_expm1(%sv1)
        : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
-    %3 = call @cabs(%sv1)
+    %5 = call @cdiv(%sv1)
+       : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %6 = call @cabs(%sv1)
        : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xf64, #SparseVector>
 
     //
@@ -157,15 +211,36 @@
     // CHECK-NEXT: -193.43
     // CHECK-NEXT: 57.2184
     call @dumpc(%1, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    // CHECK-NEXT: 0.433635
+    // CHECK-NEXT: 2.30609
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 2.53083
+    // CHECK-NEXT: 1.18538
+    call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    // CHECK-NEXT: 0.761594
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: -0.964028
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 0.995055
+    // CHECK-NEXT: 0
+    call @dumpc(%3, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    // CHECK-NEXT: 1.52361
+    // CHECK-NEXT: 2.69061
+    // CHECK-NEXT: 1.73287
+    // CHECK-NEXT: 0.785398
+    // CHECK-NEXT: 2.13833
+    // CHECK-NEXT: 0.785398
+    call @dumpc(%4, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
     // CHECK-NEXT: -2.565
     // CHECK-NEXT: 1
     // CHECK-NEXT: 1.5
     // CHECK-NEXT: 2
     // CHECK-NEXT: 2.5
     // CHECK-NEXT: 3
-    call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    call @dumpc(%5, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
     // CHECK-NEXT: ( 5.50608, 5, 7.81025 )
-    call @dumpf(%3) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dumpf(%6) : (tensor<?xf64, #SparseVector>) -> ()
 
     // Release the resources.
     sparse_tensor.release %sv1 : tensor<?xcomplex<f64>, #SparseVector>
@@ -173,7 +248,10 @@
     sparse_tensor.release %0 : tensor<?xcomplex<f64>, #SparseVector>
     sparse_tensor.release %1 : tensor<?xcomplex<f64>, #SparseVector>
     sparse_tensor.release %2 : tensor<?xcomplex<f64>, #SparseVector>
-    sparse_tensor.release %3 : tensor<?xf64, #SparseVector>
+    sparse_tensor.release %3 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %4 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %5 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %6 : tensor<?xf64, #SparseVector>
     return
   }
 }
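
Not part of the patch, but as context for reviewers: once the merger tags a unary complex op with one of the new kinds (kSqrtC, kExpm1C, kTanhC), sparsification emits a loop over only the stored entries of the compressed operand and rebuilds the op per element in Merger::buildExp (e.g. rewriter.create<complex::TanhOp>(loc, v0)). A hand-written MLIR sketch of the conceptual result for @complex_tanh follows; the function name, memref arguments, and %nse bound are illustrative only, and the real generated code also walks the compressed level's pointer and index arrays.

// Conceptual sketch only -- not verbatim sparse-compiler output.
func.func @tanh_on_stored_values(%values: memref<?xcomplex<f64>>,
                                 %out: memref<?xcomplex<f64>>,
                                 %nse: index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // Iterate over the stored entries only; implicit zeros are skipped,
  // which is sound because tanh(0) == 0 (likewise sqrt and expm1).
  scf.for %i = %c0 to %nse step %c1 {
    %v = memref.load %values[%i] : memref<?xcomplex<f64>>
    // Rebuilt by Merger::buildExp for the new kTanhC kind.
    %t = complex.tanh %v : complex<f64>
    memref.store %t, %out[%i] : memref<?xcomplex<f64>>
  }
  return
}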