diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -21,15 +21,20 @@ namespace sparse_tensor {
 /// Dimension level type for a tensor (undef means index does not appear).
-enum class Dim { kSparse, kDense, kSingle, kUndef };
+enum Dim { kSparse, kDense, kSingle, kUndef };
 
 /// Tensor expression kind.
-enum class Kind {
+enum Kind {
   // Leaf.
-  kTensor,
+  kTensor = 0,
   kInvariant,
-  kZero,
-  // Operation.
+  // Unary operations.
+  kAbsF,
+  kCeilF,
+  kFloorF,
+  kNegF,
+  kNegI,
+  // Binary operations.
   kMulF,
   kMulI,
   kDivF,
@@ -41,6 +46,7 @@
   kSubI,
   kAndI,
   kOrI,
+  kXorI,
 };
 
 /// Children subexpressions of tensor operations.
@@ -105,8 +111,7 @@
         dims(t + 1, std::vector<Dim>(l, Dim::kUndef)) {}
 
   /// Adds a tensor expression. Returns its index.
-  unsigned addExp(Kind k, unsigned e0 = -1u, unsigned e1 = -1u,
-                  Value v = Value());
+  unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value());
   unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }
 
   /// Adds an iteration lattice point. Returns its index.
@@ -129,11 +134,10 @@
   /// Returns the index of the new set.
   unsigned takeDisj(Kind kind, unsigned s0, unsigned s1);
 
-  /// Maps a zero operand over a lattice set, i.e. each lattice point on an
-  /// expression E is simply copied over, but with 0 OP E as new expression.
-  /// This is useful to deal with disjunctive, but non-commutative operators.
-  /// Returns the index of the new set.
-  unsigned mapZero(Kind kind, unsigned s0);
+  /// Maps the unary operator over the lattice set of the operand, i.e. each
+  /// lattice point on an expression E is simply copied over, but with OP E
+  /// as new expression. Returns the index of the new set.
+  unsigned mapSet(Kind kind, unsigned s0);
 
   /// Optimizes the iteration lattice points in the given set. This
   /// method should be called right before code generation to avoid
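Note on the new encoding: with kZero gone, a unary operator is a first-class expression node whose single child sits in e0 while e1 stays -1u, which is why addExp drops the default on e0 but keeps it on e1. A minimal sketch of how x(i) = abs(a(i)) would now be encoded, assuming a two-argument Merger(t, l) constructor and that a kTensor leaf routes its tensor id through e0 (the helper semantics here are assumptions for illustration, not part of the patch):

    Merger merger(/*t=*/2, /*l=*/1);  // tensors a and x, one loop i
    unsigned a = merger.addExp(Kind::kTensor, /*t=*/0);  // leaf a(i)
    unsigned e = merger.addExp(Kind::kAbsF, a);          // unary abs(a(i))
    // For e: children.e0 == a and children.e1 == -1u, which is exactly
    // the unary assert added to the TensorExp constructor in Merger.cpp.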
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -609,18 +609,22 @@
 static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                     linalg::GenericOp op, unsigned exp) {
   Location loc = op.getLoc();
+  if (exp == -1u)
+    return Value();
   if (merger.exp(exp).kind == Kind::kTensor)
     return genTensorLoad(merger, codegen, rewriter, op, exp);
   if (merger.exp(exp).kind == Kind::kInvariant)
     return genInvariantValue(merger, codegen, rewriter, exp);
-  if (merger.exp(exp).kind == Kind::kZero) {
-    Type tp = op.getOutputTensorTypes()[0].getElementType();
-    merger.exp(exp).val =
-        rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
-    return genInvariantValue(merger, codegen, rewriter, exp);
-  }
   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
+  if (merger.exp(exp).kind == Kind::kNegI) {
+    // TODO: no negi in std, need to make zero explicit.
+    Type tp = op.getOutputTensorTypes()[0].getElementType();
+    v1 = v0;
+    v0 = rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
+    if (codegen.curVecLength > 1)
+      v0 = genVectorInvariantValue(codegen, rewriter, v0);
+  }
   return merger.buildExp(rewriter, loc, exp, v0, v1);
 }
 
@@ -628,6 +632,8 @@
 static void genInvariants(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp, unsigned ldx, bool hoist) {
+  if (exp == -1u)
+    return;
   if (merger.exp(exp).kind == Kind::kTensor) {
     // Inspect tensor indices.
     bool atLevel = ldx == -1u;
@@ -649,8 +655,7 @@
       merger.exp(exp).val =
           hoist ? genTensorLoad(merger, codegen, rewriter, op, exp) : Value();
     }
-  } else if (merger.exp(exp).kind != Kind::kInvariant &&
-             merger.exp(exp).kind != Kind::kZero) {
+  } else if (merger.exp(exp).kind != Kind::kInvariant) {
     // Traverse into the binary operations. Note that we only hoist
     // tensor loads, since subsequent MLIR/LLVM passes know how to
     // deal with all other kinds of derived loop invariants.
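Note: std has no integer negation, so genExp above turns the unary kNegI node back into a binary subtraction at emission time: the operand value slides from v0 into v1, and v0 becomes a materialized zero constant (splatted into a vector when vectorizing), after which buildExp can emit an ordinary subi. In scalar form the emitted code is equivalent to this sketch (illustrative only):

    #include <cstdint>

    // Scalar equivalent of the kNegI lowering.
    int64_t negi(int64_t y) {
      int64_t zero = 0; // v0: the explicit zero constant
      return zero - y;  // buildExp emits SubIOp(v0, v1) for kNegI
    }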
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -28,8 +28,14 @@
   case Kind::kInvariant:
     assert(x == -1u && y == -1u && v);
     break;
-  case Kind::kZero:
-    assert(x == -1u && y == -1u && !v);
+  case kAbsF:
+  case kCeilF:
+  case kFloorF:
+  case kNegF:
+  case kNegI:
+    assert(x != -1u && y == -1u && !v);
+    children.e0 = x;
+    children.e1 = y;
     break;
   default:
     assert(x != -1u && y != -1u && !v);
@@ -89,22 +95,25 @@
 unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
   unsigned s = takeConj(kind, s0, s1);
-  // Followed by all in s0 and s1.
+  // Followed by all in s0.
   for (unsigned p : latSets[s0])
     latSets[s].push_back(p);
-  if (Kind::kSubF <= kind && kind <= Kind::kSubI)
-    s1 = mapZero(kind, s1);
+  // Map binary 0-y to unary -y.
+  if (kind == Kind::kSubF)
+    s1 = mapSet(Kind::kNegF, s1);
+  else if (kind == Kind::kSubI)
+    s1 = mapSet(Kind::kNegI, s1);
+  // Followed by all in s1.
   for (unsigned p : latSets[s1])
     latSets[s].push_back(p);
   return s;
 }
 
-unsigned Merger::mapZero(Kind kind, unsigned s0) {
-  assert(Kind::kSubF <= kind && kind <= Kind::kSubI);
+unsigned Merger::mapSet(Kind kind, unsigned s0) {
+  assert(Kind::kAbsF <= kind && kind <= Kind::kNegI);
   unsigned s = addSet();
-  unsigned z = addExp(Kind::kZero);
   for (unsigned p : latSets[s0]) {
-    unsigned e = addExp(kind, z, latPoints[p].exp);
+    unsigned e = addExp(kind, latPoints[p].exp);
     latPoints.push_back(LatPoint(latPoints[p].bits, e));
     latSets[s].push_back(latPoints.size() - 1);
   }
@@ -194,6 +203,12 @@
   switch (tensorExps[e].kind) {
   case Kind::kTensor:
     return tensorExps[e].tensor == t;
+  case kAbsF:
+  case kCeilF:
+  case kFloorF:
+  case kNegF:
+  case kNegI:
+    return isConjunction(t, tensorExps[e].children.e0);
   case Kind::kMulF:
   case Kind::kMulI:
   case Kind::kAndI:
@@ -213,30 +228,9 @@
 // Print methods (for debugging).
 //
-static char kindToOpSymbol(Kind kind) {
-  switch (kind) {
-  case Kind::kMulF:
-  case Kind::kMulI:
-    return '*';
-  case Kind::kDivF:
-  case Kind::kDivS:
-  case Kind::kDivU:
-    return '/';
-  case Kind::kAddF:
-  case Kind::kAddI:
-    return '+';
-  case Kind::kSubF:
-  case Kind::kSubI:
-    return '-';
-  case Kind::kAndI:
-    return '&';
-  case Kind::kOrI:
-    return '|';
-  default:
-    break;
-  }
-  llvm_unreachable("unexpected kind");
-}
+static const char *kOpSymbols[] = {"",  "",  "abs", "ceil", "floor", "-",
+                                   "-", "*", "*",   "/",    "/",     "/",
+                                   "+", "+", "-",   "-",    "&",     "|",
+                                   "^"};
 
 void Merger::dumpExp(unsigned e) const {
   switch (tensorExps[e].kind) {
@@ -250,13 +244,18 @@
   case Kind::kInvariant:
     llvm::dbgs() << "invariant";
     break;
-  case Kind::kZero:
-    llvm::dbgs() << "zero";
+  case kAbsF:
+  case kCeilF:
+  case kFloorF:
+  case kNegF:
+  case kNegI:
+    llvm::dbgs() << kOpSymbols[tensorExps[e].kind] << " ";
+    dumpExp(tensorExps[e].children.e0);
     break;
   default:
     llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
-    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
+    llvm::dbgs() << " " << kOpSymbols[tensorExps[e].kind] << " ";
     dumpExp(tensorExps[e].children.e1);
     llvm::dbgs() << ")";
   }
@@ -315,8 +314,7 @@
   Kind kind = tensorExps[e].kind;
   switch (kind) {
   case Kind::kTensor:
-  case Kind::kInvariant:
-  case Kind::kZero: {
+  case Kind::kInvariant: {
     // Either the index is really used in the tensor expression, or it is
     // set to the undefined index in that dimension. An invariant expression
     // is set to a synthetic tensor with undefined indices only.
@@ -325,6 +323,18 @@
     latSets[s].push_back(addLat(t, i, e));
     return s;
   }
+  case kAbsF:
+  case kCeilF:
+  case kFloorF:
+  case kNegF:
+  case kNegI:
+    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
+    // lattice set of the operand through the operator into a new set.
+    //
+    //  -y|!y | y |
+    //  --+---+---+
+    //    | 0 |-y |
+    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i));
   case Kind::kMulF:
   case Kind::kMulI:
   case Kind::kAndI:
@@ -357,16 +367,12 @@
     return takeConj(kind, // take binary conjunction
                     buildLattices(tensorExps[e].children.e0, i),
                     buildLattices(tensorExps[e].children.e1, i));
-  case Kind::kSubF:
-  case Kind::kSubI:
-    // Special case: 0-y is -y.
-    if (tensorExps[tensorExps[e].children.e0].kind == Kind::kZero)
-      return mapZero(kind, // maps to 0-y with just y's lattices
-                     buildLattices(tensorExps[e].children.e1, i));
-    LLVM_FALLTHROUGH;
   case Kind::kAddF:
   case Kind::kAddI:
+  case Kind::kSubF:
+  case Kind::kSubI:
   case Kind::kOrI:
+  case Kind::kXorI:
     // An additive operation needs to be performed
     // for the disjunction of sparse iteration spaces.
     //
@@ -420,10 +426,15 @@
   if (def->getNumOperands() == 1) {
     auto x = buildTensorExp(op, def->getOperand(0));
     if (x.hasValue()) {
-      unsigned e0 = addExp(Kind::kZero);
-      unsigned e1 = x.getValue();
+      unsigned e = x.getValue();
+      if (isa<AbsFOp>(def))
+        return addExp(Kind::kAbsF, e);
+      if (isa<CeilFOp>(def))
+        return addExp(Kind::kCeilF, e);
+      if (isa<FloorFOp>(def))
+        return addExp(Kind::kFloorF, e);
       if (isa<NegFOp>(def))
-        return addExp(Kind::kSubF, e0, e1);
+        return addExp(Kind::kNegF, e);
       // TODO: no negi in std?
     }
   }
@@ -457,6 +468,8 @@
       return addExp(Kind::kAndI, e0, e1);
     if (isa<OrOp>(def))
       return addExp(Kind::kOrI, e0, e1);
+    if (isa<XOrOp>(def))
+      return addExp(Kind::kXorI, e0, e1);
   }
 }
 // Cannot build.
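Note: the lattice rules above carry the semantic weight of this change. A conjunctive operator (mul, and) only iterates the intersection of the sparse spaces, a disjunctive one (add, sub, or, xor) also needs both set differences, and a zero-preserving unary operator simply relabels the lattice points of its operand via mapSet. A worked example against the API shown above, assuming buildLattices is the public entry point (the exact point layout is an illustration):

    // Lattices for x(i) = a(i) - b(i), built by takeDisj(kSubF, ...):
    Merger m(/*t=*/3, /*l=*/1);
    unsigned a = m.addExp(Kind::kTensor, 0);
    unsigned b = m.addExp(Kind::kTensor, 1);
    unsigned s = m.buildLattices(m.addExp(Kind::kSubF, a, b), /*i=*/0);
    // Resulting points, in order:
    //   1. { i in A and B : a(i) - b(i) }  from takeConj
    //   2. { i in A only  : a(i)        }  all of s0 copied over
    //   3. { i in B only  : -b(i)       }  s1 mapped through mapSet(kNegF)
    // Point 3 is where mapSet supersedes mapZero: instead of a synthetic
    // kZero leaf and the binary expression 0 - b(i), each point keeps its
    // iteration bits and receives the unary expression kNegF over b(i).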
@@ -468,8 +481,18 @@
   switch (tensorExps[e].kind) {
   case Kind::kTensor:
   case Kind::kInvariant:
-  case Kind::kZero:
     llvm_unreachable("unexpected non-op");
+  case kAbsF:
+    return rewriter.create<AbsFOp>(loc, v0);
+  case kCeilF:
+    return rewriter.create<CeilFOp>(loc, v0);
+  case kFloorF:
+    return rewriter.create<FloorFOp>(loc, v0);
+  case kNegF:
+    return rewriter.create<NegFOp>(loc, v0);
+  case kNegI:
+    assert(v1); // no negi in std
+    return rewriter.create<SubIOp>(loc, v0, v1);
   case Kind::kMulF:
     return rewriter.create<MulFOp>(loc, v0, v1);
   case Kind::kMulI:
@@ -492,6 +515,8 @@
     return rewriter.create<AndOp>(loc, v0, v1);
   case Kind::kOrI:
     return rewriter.create<OrOp>(loc, v0, v1);
+  case Kind::kXorI:
+    return rewriter.create<XOrOp>(loc, v0, v1);
   }
   llvm_unreachable("unexpected expression kind in build");
 }
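Note: the FileCheck tests below pin down the codegen payoff. A zero-preserving unary operator never has to visit the implicit zeros, so the kernel remains a single scf.for over the stored elements of the sparse operand. In plain C++ the generated @abs kernel is equivalent to this sketch (the function name and the ptr/idx/val compressed-vector layout are illustrative):

    #include <cmath>
    #include <cstdint>

    // One pass over the stored entries; abs(0) == 0 covers the rest.
    void absKernel(const uint64_t *ptr, const uint64_t *idx,
                   const double *val, double *x) {
      for (uint64_t p = ptr[0]; p < ptr[1]; p++)
        x[idx[p]] = std::fabs(val[p]);
    }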
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir
@@ -31,26 +31,120 @@
   doc = "x(i) = a(i) OP c"
 }
 
+// CHECK-LABEL: func @abs(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK: %[[VAL_2:.*]] = constant 0 : index
+// CHECK: %[[VAL_3:.*]] = constant 1 : index
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK: %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK: %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf64>
+// CHECK: %[[VAL_13:.*]] = absf %[[VAL_12]] : f64
+// CHECK: memref.store %[[VAL_13]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<32xf64>
+// CHECK: }
+// CHECK: %[[VAL_14:.*]] = memref.tensor_load %[[VAL_7]] : memref<32xf64>
+// CHECK: return %[[VAL_14]] : tensor<32xf64>
+func @abs(%arga: tensor<32xf64, #SV>,
+          %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+  %0 = linalg.generic #trait1
+     ins(%arga: tensor<32xf64, #SV>)
+    outs(%argx: tensor<32xf64>) {
+      ^bb(%a: f64, %x: f64):
+        %0 = absf %a : f64
+        linalg.yield %0 : f64
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
+}
+
+// CHECK-LABEL: func @ceil(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK: %[[VAL_2:.*]] = constant 0 : index
+// CHECK: %[[VAL_3:.*]] = constant 1 : index
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK: %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK: %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf64>
+// CHECK: %[[VAL_13:.*]] = ceilf %[[VAL_12]] : f64
+// CHECK: memref.store %[[VAL_13]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<32xf64>
+// CHECK: }
+// CHECK: %[[VAL_14:.*]] = memref.tensor_load %[[VAL_7]] : memref<32xf64>
+// CHECK: return %[[VAL_14]] : tensor<32xf64>
+// CHECK: }
+func @ceil(%arga: tensor<32xf64, #SV>,
+           %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+  %0 = linalg.generic #trait1
+     ins(%arga: tensor<32xf64, #SV>)
+    outs(%argx: tensor<32xf64>) {
+      ^bb(%a: f64, %x: f64):
+        %0 = ceilf %a : f64
+        linalg.yield %0 : f64
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
+}
+
+// CHECK-LABEL: func @floor(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+// CHECK: %[[VAL_2:.*]] = constant 0 : index
+// CHECK: %[[VAL_3:.*]] = constant 1 : index
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK: %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK: %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf64>
+// CHECK: %[[VAL_13:.*]] = floorf %[[VAL_12]] : f64
+// CHECK: memref.store %[[VAL_13]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<32xf64>
+// CHECK: }
+// CHECK: %[[VAL_14:.*]] = memref.tensor_load %[[VAL_7]] : memref<32xf64>
+// CHECK: return %[[VAL_14]] : tensor<32xf64>
+// CHECK: }
+func @floor(%arga: tensor<32xf64, #SV>,
+            %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
+  %0 = linalg.generic #trait1
+     ins(%arga: tensor<32xf64, #SV>)
+    outs(%argx: tensor<32xf64>) {
+      ^bb(%a: f64, %x: f64):
+        %0 = floorf %a : f64
+        linalg.yield %0 : f64
+  } -> tensor<32xf64>
+  return %0 : tensor<32xf64>
+}
+
 // CHECK-LABEL: func @neg(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
 // CHECK: %[[VAL_2:.*]] = constant 0 : index
 // CHECK: %[[VAL_3:.*]] = constant 1 : index
-// CHECK: %[[VAL_4:.*]] = constant 0.000000e+00 : f64
-// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
-// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
-// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
-// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_3]] {
-// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xf64>
-// CHECK: %[[VAL_14:.*]] = subf %[[VAL_4]], %[[VAL_13]] : f64
-// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xf64>
+// CHECK: %[[VAL_4:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK: %[[VAL_7:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK: %[[VAL_8:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_2]]] : memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_10:.*]] = %[[VAL_8]] to %[[VAL_9]] step %[[VAL_3]] {
+// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_10]]] : memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf64>
+// CHECK: %[[VAL_13:.*]] = negf %[[VAL_12]] : f64
+// CHECK: memref.store %[[VAL_13]], %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<32xf64>
 // CHECK: }
-// CHECK: %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xf64>
-// CHECK: return %[[VAL_15]] : tensor<32xf64>
+// CHECK: %[[VAL_14:.*]] = memref.tensor_load %[[VAL_7]] : memref<32xf64>
+// CHECK: return %[[VAL_14]] : tensor<32xf64>
 // CHECK: }
 func @neg(%arga: tensor<32xf64, #SV>,
           %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> {
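Note: @sub below exercises the disjunction path that the old kZero machinery used to serve. Where a(i) has no stored entry, the else branch and the cleanup loop now emit a direct negf instead of a subf against a hoisted zero constant, one constant and one instruction less per element. The co-iteration encoded by the CHECKs is equivalent to this sketch (illustrative; the while/for split of the generated code is folded into one loop):

    #include <cstdint>

    void subKernel(const uint64_t *ptr, const uint64_t *idx,
                   const double *val, const double *b, double *x) {
      uint64_t p = ptr[0];
      for (uint64_t i = 0; i < 32; i++) {
        if (p < ptr[1] && idx[p] == i)
          x[i] = val[p++] - b[i]; // both operands present
        else
          x[i] = -b[i];           // a(i) is an implicit zero: negf
      }
    }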
@@ -132,47 +226,46 @@
 // CHECK: %[[VAL_4:.*]] = constant 0 : index
 // CHECK: %[[VAL_5:.*]] = constant true
 // CHECK: %[[VAL_6:.*]] = constant 1 : index
-// CHECK: %[[VAL_7:.*]] = constant 0.000000e+00 : f64
-// CHECK: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>
-// CHECK: %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
-// CHECK: %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
-// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_4]]] : memref<?xindex>
-// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_6]]] : memref<?xindex>
-// CHECK: %[[VAL_15:.*]]:2 = scf.while (%[[VAL_16:.*]] = %[[VAL_13]], %[[VAL_17:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
-// CHECK: %[[VAL_18:.*]] = cmpi ult, %[[VAL_16]], %[[VAL_14]] : index
-// CHECK: scf.condition(%[[VAL_18]]) %[[VAL_16]], %[[VAL_17]] : index, index
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64>
+// CHECK: %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xf64>
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_17:.*]] = cmpi ult, %[[VAL_15]], %[[VAL_13]] : index
+// CHECK: scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index
 // CHECK: } do {
-// CHECK: ^bb0(%[[VAL_19:.*]]: index, %[[VAL_20:.*]]: index):
-// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xindex>
-// CHECK: %[[VAL_22:.*]] = cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
-// CHECK: scf.if %[[VAL_22]] {
-// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xf64>
-// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<32xf64>
-// CHECK: %[[VAL_25:.*]] = subf %[[VAL_23]], %[[VAL_24]] : f64
-// CHECK: memref.store %[[VAL_25]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<32xf64>
+// CHECK: ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index):
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK: %[[VAL_21:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK: scf.if %[[VAL_21]] {
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xf64>
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf64>
+// CHECK: %[[VAL_24:.*]] = subf %[[VAL_22]], %[[VAL_23]] : f64
+// CHECK: memref.store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xf64>
 // CHECK: } else {
 // CHECK: scf.if %[[VAL_5]] {
-// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<32xf64>
-// CHECK: %[[VAL_27:.*]] = subf %[[VAL_7]], %[[VAL_26]] : f64
-// CHECK: memref.store %[[VAL_27]], %[[VAL_12]]{{\[}}%[[VAL_20]]] : memref<32xf64>
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xf64>
+// CHECK: %[[VAL_26:.*]] = negf %[[VAL_25]] : f64
+// CHECK: memref.store %[[VAL_26]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xf64>
 // CHECK: } else {
 // CHECK: }
 // CHECK: }
-// CHECK: %[[VAL_28:.*]] = cmpi eq, %[[VAL_21]], %[[VAL_20]] : index
-// CHECK: %[[VAL_29:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
-// CHECK: %[[VAL_30:.*]] = select %[[VAL_28]], %[[VAL_29]], %[[VAL_19]] : index
-// CHECK: %[[VAL_31:.*]] = addi %[[VAL_20]], %[[VAL_6]] : index
-// CHECK: scf.yield %[[VAL_30]], %[[VAL_31]] : index, index
+// CHECK: %[[VAL_27:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK: %[[VAL_28:.*]] = addi %[[VAL_18]], %[[VAL_6]] : index
+// CHECK: %[[VAL_29:.*]] = select %[[VAL_27]], %[[VAL_28]], %[[VAL_18]] : index
+// CHECK: %[[VAL_30:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
+// CHECK: scf.yield %[[VAL_29]], %[[VAL_30]] : index, index
 // CHECK: }
-// CHECK: scf.for %[[VAL_32:.*]] = %[[VAL_33:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
-// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_32]]] : memref<32xf64>
-// CHECK: %[[VAL_35:.*]] = subf %[[VAL_7]], %[[VAL_34]] : f64
-// CHECK: memref.store %[[VAL_35]], %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<32xf64>
+// CHECK: scf.for %[[VAL_31:.*]] = %[[VAL_32:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_31]]] : memref<32xf64>
+// CHECK: %[[VAL_34:.*]] = negf %[[VAL_33]] : f64
+// CHECK: memref.store %[[VAL_34]], %[[VAL_11]]{{\[}}%[[VAL_31]]] : memref<32xf64>
 // CHECK: }
-// CHECK: %[[VAL_36:.*]] = memref.tensor_load %[[VAL_12]] : memref<32xf64>
-// CHECK: return %[[VAL_36]] : tensor<32xf64>
+// CHECK: %[[VAL_35:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xf64>
+// CHECK: return %[[VAL_35]] : tensor<32xf64>
 // CHECK: }
 func @sub(%arga: tensor<32xf64, #SV>,
           %argb: tensor<32xf64>,
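Note: the new kXorI follows the same disjunctive pattern as add and sub, since a(i) ^ b(i) can be nonzero whenever either operand is. Because 0 ^ b(i) == b(i), the branch for an absent a(i) degenerates into a plain copy, which is exactly what the else branch and the cleanup loop in the CHECKs below do. An illustrative single-loop equivalent:

    #include <cstdint>

    void xorKernel(const uint64_t *ptr, const uint64_t *idx,
                   const int64_t *val, const int64_t *b, int64_t *x) {
      uint64_t p = ptr[0];
      for (uint64_t i = 0; i < 32; i++) {
        if (p < ptr[1] && idx[p] == i)
          x[i] = val[p++] ^ b[i];
        else
          x[i] = b[i]; // 0 ^ b(i) needs no xor at all
      }
    }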
diff --git a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
@@ -345,3 +345,62 @@
   return %0 : tensor<32xi64>
 }
 
+// CHECK-LABEL: func @xor(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK: %[[VAL_3:.*]] = constant 32 : index
+// CHECK: %[[VAL_4:.*]] = constant 0 : index
+// CHECK: %[[VAL_5:.*]] = constant true
+// CHECK: %[[VAL_6:.*]] = constant 1 : index
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
+// CHECK: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK: %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xi64>
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK: %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
+// CHECK: %[[VAL_17:.*]] = cmpi ult, %[[VAL_15]], %[[VAL_13]] : index
+// CHECK: scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index):
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK: %[[VAL_21:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK: scf.if %[[VAL_21]] {
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xi64>
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK: %[[VAL_24:.*]] = xor %[[VAL_22]], %[[VAL_23]] : i64
+// CHECK: memref.store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK: } else {
+// CHECK: scf.if %[[VAL_5]] {
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK: memref.store %[[VAL_25]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK: } else {
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_26:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK: %[[VAL_27:.*]] = addi %[[VAL_18]], %[[VAL_6]] : index
+// CHECK: %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index
+// CHECK: %[[VAL_29:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
+// CHECK: scf.yield %[[VAL_28]], %[[VAL_29]] : index, index
+// CHECK: }
+// CHECK: scf.for %[[VAL_30:.*]] = %[[VAL_31:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_30]]] : memref<32xi64>
+// CHECK: memref.store %[[VAL_32]], %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<32xi64>
+// CHECK: }
+// CHECK: %[[VAL_33:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xi64>
+// CHECK: return %[[VAL_33]] : tensor<32xi64>
+// CHECK: }
+func @xor(%arga: tensor<32xi64, #SV>,
+          %argb: tensor<32xi64>,
+          %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
+    outs(%argx: tensor<32xi64>) {
+      ^bb(%a: i64, %b: i64, %x: i64):
+        %0 = xor %a, %b : i64
+        linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -145,14 +145,24 @@
   switch (tensorExp.kind) {
   case Kind::kTensor:
     return tensorExp.tensor == pattern->tensorNum;
-  case Kind::kZero:
-    return true;
+  case Kind::kAbsF:
+  case Kind::kCeilF:
+  case Kind::kFloorF:
+  case Kind::kNegF:
+  case Kind::kNegI:
+    return compareExpression(tensorExp.children.e0, pattern->e0);
   case Kind::kMulF:
   case Kind::kMulI:
+  case Kind::kDivF:
+  case Kind::kDivS:
+  case Kind::kDivU:
   case Kind::kAddF:
   case Kind::kAddI:
   case Kind::kSubF:
   case Kind::kSubI:
+  case Kind::kAndI:
+  case Kind::kOrI:
+  case Kind::kXorI:
    return compareExpression(tensorExp.children.e0, pattern->e0) &&
           compareExpression(tensorExp.children.e1, pattern->e1);
  default:
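A closing note on an invariant the patch relies on throughout: the range checks (Kind::kAbsF <= kind && kind <= Kind::kNegI in mapSet, the unary case lists in isConjunction, dumpExp, and compareExpression) and the kOpSymbols table all depend on the declaration order of Kind in Merger.h, with one symbol per kind including all three integer divisions. A cheap compile-time guard, illustrative only and not part of the patch, could pin this down:

    // Hypothetical guards next to kOpSymbols in Merger.cpp.
    static_assert(Kind::kAbsF + 4 == Kind::kNegI,
                  "unary kinds must remain contiguous");
    static_assert(sizeof(kOpSymbols) / sizeof(kOpSymbols[0]) ==
                      Kind::kXorI + 1,
                  "kOpSymbols must cover every Kind");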