diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -31,14 +31,18 @@
   kIndex,
   // Unary operations.
   kAbsF,
+  kAbsC,
   kCeilF,
   kFloorF,
   kSqrtF,
   kExpm1F,
   kLog1pF,
+  kLog1pC,
   kSinF,
+  kSinC,
   kTanhF,
   kNegF,
+  kNegC,
   kNegI,
   kTruncF,
   kExtF,
@@ -60,12 +64,14 @@
   kMulC,
   kMulI,
   kDivF,
+  kDivC, // complex
   kDivS, // signed
   kDivU, // unsigned
   kAddF,
   kAddC,
   kAddI,
   kSubF,
+  kSubC,
   kSubI,
   kAndI,
   kOrI,
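Note: each new Kind mirrors one operation from the MLIR complex dialect. As a
rough mental model (a minimal standalone sketch, not part of this patch; the
kindToComplexOpName helper is hypothetical), the mapping is:

  // Hypothetical helper, for illustration only: names the complex-dialect
  // op that each newly added kind represents.
  static const char *kindToComplexOpName(Kind kind) {
    switch (kind) {
    case kAbsC:   return "complex.abs";   // |z|; result is the float element type
    case kLog1pC: return "complex.log1p";
    case kSinC:   return "complex.sin";
    case kNegC:   return "complex.neg";
    case kDivC:   return "complex.div";
    case kSubC:   return "complex.sub";
    default:      return "<not a complex kind>";
    }
  }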
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -37,14 +37,18 @@
     index = x;
     break;
   case kAbsF:
+  case kAbsC:
   case kCeilF:
   case kFloorF:
   case kSqrtF:
   case kExpm1F:
   case kLog1pF:
+  case kLog1pC:
   case kSinF:
+  case kSinC:
   case kTanhF:
   case kNegF:
+  case kNegC:
   case kNegI:
   case kCIm:
   case kCRe:
@@ -151,6 +155,8 @@
   // TODO: move this if-else logic into buildLattices
   if (kind == kSubF)
     s1 = mapSet(kNegF, s1);
+  else if (kind == kSubC)
+    s1 = mapSet(kNegC, s1);
   else if (kind == kSubI)
     s1 = mapSet(kNegI, s1);
   // Followed by all in s1.
@@ -274,14 +280,18 @@
   case kTensor:
     return tensorExps[e].tensor == t;
   case kAbsF:
+  case kAbsC:
   case kCeilF:
   case kFloorF:
   case kSqrtF:
   case kExpm1F:
   case kLog1pF:
+  case kLog1pC:
   case kSinF:
+  case kSinC:
   case kTanhF:
   case kNegF:
+  case kNegC:
   case kNegI:
   case kTruncF:
   case kExtF:
@@ -298,6 +308,7 @@
   case kBitCast:
     return isSingleCondition(t, tensorExps[e].children.e0);
   case kDivF: // note: x / c only
+  case kDivC:
   case kDivS:
   case kDivU:
     assert(!maybeZero(tensorExps[e].children.e1));
@@ -342,6 +353,7 @@
   case kIndex:
     return "index";
   case kAbsF:
+  case kAbsC:
     return "abs";
   case kCeilF:
     return "ceil";
@@ -352,13 +364,15 @@
   case kExpm1F:
     return "expm1";
   case kLog1pF:
+  case kLog1pC:
     return "log1p";
   case kSinF:
+  case kSinC:
     return "sin";
   case kTanhF:
     return "tanh";
   case kNegF:
-    return "-";
+  case kNegC:
   case kNegI:
     return "-";
   case kTruncF:
@@ -386,6 +400,7 @@
   case kMulI:
     return "*";
   case kDivF:
+  case kDivC:
   case kDivS:
   case kDivU:
     return "/";
@@ -394,6 +409,7 @@
   case kAddI:
     return "+";
   case kSubF:
+  case kSubC:
   case kSubI:
     return "-";
   case kAndI:
@@ -533,6 +549,7 @@
     return s;
   }
   case kAbsF:
+  case kAbsC:
   case kCeilF:
   case kCIm:
   case kCRe:
@@ -540,9 +557,12 @@
   case kSqrtF:
   case kExpm1F:
   case kLog1pF:
+  case kLog1pC:
   case kSinF:
+  case kSinC:
   case kTanhF:
   case kNegF:
+  case kNegC:
   case kNegI:
   case kTruncF:
   case kExtF:
@@ -607,6 +627,7 @@
                         buildLattices(tensorExps[e].children.e0, i),
                         buildLattices(tensorExps[e].children.e1, i));
   case kDivF:
+  case kDivC:
   case kDivS:
   case kDivU:
     // A division is tricky, since 0/0, 0/c, c/0 all have
@@ -630,6 +651,7 @@
   case kAddC:
   case kAddI:
   case kSubF:
+  case kSubC:
   case kSubI:
   case kOrI:
   case kXorI:
@@ -696,6 +718,11 @@
 /// Only returns false if we are certain this is a nonzero.
 bool Merger::maybeZero(unsigned e) const {
   if (tensorExps[e].kind == kInvariant) {
+    if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
+      ArrayAttr arrayAttr = c.getValue();
+      return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
+             arrayAttr[1].cast<FloatAttr>().getValue().isZero();
+    }
     if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
       return c.value() == 0;
     if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
@@ -750,6 +777,8 @@
       unsigned e = x.getValue();
       if (isa<math::AbsOp>(def))
         return addExp(kAbsF, e);
+      if (isa<complex::AbsOp>(def))
+        return addExp(kAbsC, e);
       if (isa<math::CeilOp>(def))
         return addExp(kCeilF, e);
       if (isa<math::FloorOp>(def))
@@ -760,12 +789,18 @@
         return addExp(kExpm1F, e);
       if (isa<math::Log1pOp>(def))
         return addExp(kLog1pF, e);
+      if (isa<complex::Log1pOp>(def))
+        return addExp(kLog1pC, e);
       if (isa<math::SinOp>(def))
         return addExp(kSinF, e);
+      if (isa<complex::SinOp>(def))
+        return addExp(kSinC, e);
       if (isa<math::TanhOp>(def))
         return addExp(kTanhF, e);
       if (isa<arith::NegFOp>(def))
         return addExp(kNegF, e); // no negi in std
+      if (isa<complex::NegOp>(def))
+        return addExp(kNegC, e);
       if (isa<arith::TruncFOp>(def))
         return addExp(kTruncF, e, v);
       if (isa<arith::ExtFOp>(def))
@@ -813,6 +848,8 @@
         return addExp(kMulI, e0, e1);
       if (isa<arith::DivFOp>(def) && !maybeZero(e1))
         return addExp(kDivF, e0, e1);
+      if (isa<complex::DivOp>(def) && !maybeZero(e1))
+        return addExp(kDivC, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
         return addExp(kDivS, e0, e1);
       if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
@@ -825,6 +862,8 @@
         return addExp(kAddI, e0, e1);
       if (isa<arith::SubFOp>(def))
         return addExp(kSubF, e0, e1);
+      if (isa<complex::SubOp>(def))
+        return addExp(kSubC, e0, e1);
       if (isa<arith::SubIOp>(def))
         return addExp(kSubI, e0, e1);
       if (isa<arith::AndIOp>(def))
@@ -902,6 +941,11 @@
   // Unary ops.
   case kAbsF:
     return rewriter.create<math::AbsOp>(loc, v0);
+  case kAbsC: {
+    auto type = v0.getType().template cast<ComplexType>();
+    auto eltType = type.getElementType().template cast<FloatType>();
+    return rewriter.create<complex::AbsOp>(loc, eltType, v0);
+  }
   case kCeilF:
     return rewriter.create<math::CeilOp>(loc, v0);
   case kFloorF:
@@ -912,12 +956,18 @@
     return rewriter.create<math::ExpM1Op>(loc, v0);
   case kLog1pF:
     return rewriter.create<math::Log1pOp>(loc, v0);
+  case kLog1pC:
+    return rewriter.create<complex::Log1pOp>(loc, v0);
   case kSinF:
     return rewriter.create<math::SinOp>(loc, v0);
+  case kSinC:
+    return rewriter.create<complex::SinOp>(loc, v0);
   case kTanhF:
     return rewriter.create<math::TanhOp>(loc, v0);
   case kNegF:
     return rewriter.create<arith::NegFOp>(loc, v0);
+  case kNegC:
+    return rewriter.create<complex::NegOp>(loc, v0);
   case kNegI: // no negi in std
     return rewriter.create<arith::SubIOp>(
         loc,
@@ -964,6 +1014,8 @@
     return rewriter.create<arith::MulIOp>(loc, v0, v1);
   case kDivF:
     return rewriter.create<arith::DivFOp>(loc, v0, v1);
+  case kDivC:
+    return rewriter.create<complex::DivOp>(loc, v0, v1);
   case kDivS:
     return rewriter.create<arith::DivSIOp>(loc, v0, v1);
   case kDivU:
@@ -976,6 +1028,8 @@
     return rewriter.create<arith::AddIOp>(loc, v0, v1);
   case kSubF:
     return rewriter.create<arith::SubFOp>(loc, v0, v1);
+  case kSubC:
+    return rewriter.create<complex::SubOp>(loc, v0, v1);
   case kSubI:
     return rewriter.create<arith::SubIOp>(loc, v0, v1);
   case kAndI:
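Note: the new maybeZero case inspects the two-element ArrayAttr carried by a
complex.constant and reports "may be zero" only when both the real and the
imaginary component are zero; any other constant is provably nonzero, which is
what allows x / c to participate in sparsification. A minimal standalone
sketch of that test in plain C++ (not the Merger API):

  // The constant "may be zero" only when both components are zero.
  #include <complex>

  static bool maybeZeroConstant(std::complex<double> c) {
    return c.real() == 0.0 && c.imag() == 0.0;
  }
  // The sparsifier only admits x / c when maybeZeroConstant(c) is false,
  // because 0 / c == 0 must hold for all entries it skips.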
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
@@ -0,0 +1,179 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+#trait_op1 = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>, // a (in)
+    affine_map<(i) -> (i)>  // x (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = OP a(i)"
+}
+
+#trait_op2 = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>, // a (in)
+    affine_map<(i) -> (i)>, // b (in)
+    affine_map<(i) -> (i)>  // x (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "x(i) = a(i) OP b(i)"
+}
+
+module {
+  func.func @cops(%arga: tensor<?xcomplex<f64>, #SparseVector>,
+                  %argb: tensor<?xcomplex<f64>, #SparseVector>)
+                      -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op2
+       ins(%arga, %argb: tensor<?xcomplex<f64>, #SparseVector>,
+                         tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %b: complex<f64>, %x: complex<f64>):
+          %1 = complex.neg %b : complex<f64>
+          %2 = complex.sub %a, %1 : complex<f64>
+          linalg.yield %2 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
+  func.func @csin(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+                      -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %0 = linalg.generic #trait_op1
+       ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %x: complex<f64>):
+          %1 = complex.sin %a : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
+  func.func @cdiv(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+                      -> tensor<?xcomplex<f64>, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xcomplex<f64>, #SparseVector>
+    %c = complex.constant [2.0 : f64, 0.0 : f64] : complex<f64>
+    %0 = linalg.generic #trait_op1
+       ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xcomplex<f64>, #SparseVector>) {
+        ^bb(%a: complex<f64>, %x: complex<f64>):
+          %1 = complex.div %a, %c : complex<f64>
+          linalg.yield %1 : complex<f64>
+    } -> tensor<?xcomplex<f64>, #SparseVector>
+    return %0 : tensor<?xcomplex<f64>, #SparseVector>
+  }
+
+  func.func @cabs(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+                      -> tensor<?xf64, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %d = tensor.dim %arga, %c0 : tensor<?xcomplex<f64>, #SparseVector>
+    %xv = sparse_tensor.init [%d] : tensor<?xf64, #SparseVector>
+    %0 = linalg.generic #trait_op1
+       ins(%arga: tensor<?xcomplex<f64>, #SparseVector>)
+        outs(%xv: tensor<?xf64, #SparseVector>) {
+        ^bb(%a: complex<f64>, %x: f64):
+          %1 = complex.abs %a : complex<f64>
+          linalg.yield %1 : f64
+    } -> tensor<?xf64, #SparseVector>
+    return %0 : tensor<?xf64, #SparseVector>
+  }
+
+  func.func @dumpc(%arg0: tensor<?xcomplex<f64>, #SparseVector>, %d: index) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %mem = sparse_tensor.values %arg0 : tensor<?xcomplex<f64>, #SparseVector> to memref<?xcomplex<f64>>
+    scf.for %i = %c0 to %d step %c1 {
+      %v = memref.load %mem[%i] : memref<?xcomplex<f64>>
+      %real = complex.re %v : complex<f64>
+      %imag = complex.im %v : complex<f64>
+      vector.print %real : f64
+      vector.print %imag : f64
+    }
+    return
+  }
+
+  func.func @dumpf(%arg0: tensor<?xf64, #SparseVector>) {
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant 0.0 : f64
+    %values = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64>
+    %0 = vector.transfer_read %values[%c0], %d0: memref<?xf64>, vector<3xf64>
+    vector.print %0 : vector<3xf64>
+    return
+  }
+
+  // Driver method to call and verify complex kernels.
+  func.func @entry() {
+    // Setup sparse vectors.
+    %v1 = arith.constant sparse<
+       [ [0], [28], [31] ],
+         [ (-5.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] > : tensor<32xcomplex<f64>>
+    %v2 = arith.constant sparse<
+       [ [1], [28], [31] ],
+         [ (1.0, 0.0), (-2.0, 0.0), (3.0, 0.0) ] > : tensor<32xcomplex<f64>>
+    %sv1 = sparse_tensor.convert %v1 : tensor<32xcomplex<f64>> to tensor<?xcomplex<f64>, #SparseVector>
+    %sv2 = sparse_tensor.convert %v2 : tensor<32xcomplex<f64>> to tensor<?xcomplex<f64>, #SparseVector>
+
+    // Call sparse vector kernels.
+    %0 = call @cops(%sv1, %sv2)
+       : (tensor<?xcomplex<f64>, #SparseVector>,
+          tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %1 = call @csin(%sv1)
+       : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %2 = call @cdiv(%sv1)
+       : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xcomplex<f64>, #SparseVector>
+    %3 = call @cabs(%sv1)
+       : (tensor<?xcomplex<f64>, #SparseVector>) -> tensor<?xf64, #SparseVector>
+
+    //
+    // Verify the results.
+    //
+    %d3 = arith.constant 3 : index
+    %d4 = arith.constant 4 : index
+    // CHECK: -5.13
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 4
+    // CHECK-NEXT: 8
+    // CHECK-NEXT: 6
+    call @dumpc(%0, %d4) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    // CHECK-NEXT: 3.43887
+    // CHECK-NEXT: 1.47097
+    // CHECK-NEXT: 3.85374
+    // CHECK-NEXT: -27.0168
+    // CHECK-NEXT: -193.43
+    // CHECK-NEXT: 57.2184
+    call @dumpc(%1, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    // CHECK-NEXT: -2.565
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 1.5
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 2.5
+    // CHECK-NEXT: 3
+    call @dumpc(%2, %d3) : (tensor<?xcomplex<f64>, #SparseVector>, index) -> ()
+    // CHECK-NEXT: ( 5.50608, 5, 7.81025 )
+    call @dumpf(%3) : (tensor<?xf64, #SparseVector>) -> ()
+
+    // Release the resources.
+    sparse_tensor.release %sv1 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %sv2 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %0 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %1 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %2 : tensor<?xcomplex<f64>, #SparseVector>
+    sparse_tensor.release %3 : tensor<?xf64, #SparseVector>
+    return
+  }
+}
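Note: the patch routes kSubC through mapSet(kNegC, s1) because subtraction is
disjunctive: the kernel must visit entries where either operand is nonzero,
and on b-only entries "a - b" must still yield "-b". A standalone sketch of
that rewrite in plain C++ (not the Merger API; SparseVec is an illustrative
stand-in for a sparse vector):

  // Subtraction over sparse operands, expressed as a + (-b) so that the
  // generic union handling covers the b-only entries correctly.
  #include <complex>
  #include <map>

  using SparseVec = std::map<int, std::complex<double>>; // index -> nonzero

  static SparseVec subViaNeg(const SparseVec &a, const SparseVec &b) {
    SparseVec out = a;    // entries where a is nonzero
    for (auto &[i, v] : b)
      out[i] += -v;       // union with the negated b entries
    return out;
  }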