diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -47,6 +47,9 @@
   kAndI,
   kOrI,
   kXorI,
+  kShrS, // signed
+  kShrU, // unsigned
+  kShlI,
 };
 
 /// Children subexpressions of tensor operations.
@@ -215,7 +218,8 @@
                  Value v1);
 
 private:
-  bool maybeZero(unsigned e);
+  bool maybeZero(unsigned e) const;
+  bool isInvariant(unsigned e) const;
 
   /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
   Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value v);
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -208,13 +208,16 @@
   case kFloorF:
   case kNegF:
   case kNegI:
+  case Kind::kDivF: // note: x / c only
+  case Kind::kDivS:
+  case Kind::kDivU:
+  case Kind::kShrS: // note: x >> inv only
+  case Kind::kShrU:
+  case Kind::kShlI:
     return isConjunction(t, tensorExps[e].children.e0);
   case Kind::kMulF:
   case Kind::kMulI:
   case Kind::kAndI:
-  case Kind::kDivF: // note: x / c only
-  case Kind::kDivS:
-  case Kind::kDivU:
     return isConjunction(t, tensorExps[e].children.e0) ||
            isConjunction(t, tensorExps[e].children.e1);
   default:
@@ -228,9 +231,9 @@
 // Print methods (for debugging).
 //
 
-static const char *kOpSymbols[] = {"", "", "abs", "ceil", "floor", "-",
-                                   "-", "*", "*", "/", "/", "+",
-                                   "+", "-", "-", "&", "|", "^"};
+static const char *kOpSymbols[] = {
+    "", "", "abs", "ceil", "floor", "-", "-", "*", "*", "/", "/",
+    "+", "+", "-", "-", "&", "|", "^", "a>>", ">>", "<<"};
 
 void Merger::dumpExp(unsigned e) const {
   switch (tensorExps[e].kind) {
@@ -383,6 +386,15 @@
     return takeDisj(kind, // take binary disjunction
                     buildLattices(tensorExps[e].children.e0, i),
                     buildLattices(tensorExps[e].children.e1, i));
+  case Kind::kShrS:
+  case Kind::kShrU:
+  case Kind::kShlI:
+    // A shift operation by an invariant amount (viz. tensor expressions
+    // can only occur at the left-hand side of the operator) can be
+    // handled with the conjunction rule.
+    return takeConj(kind, // take binary conjunction
+                    buildLattices(tensorExps[e].children.e0, i),
+                    buildLattices(tensorExps[e].children.e1, i));
   }
   llvm_unreachable("unexpected expression kind");
 }
@@ -392,7 +404,7 @@
   return buildTensorExp(op, yield->getOperand(0));
 }
 
-bool Merger::maybeZero(unsigned e) {
+bool Merger::maybeZero(unsigned e) const {
   if (tensorExps[e].kind == Kind::kInvariant) {
     if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
       return c.getValue() == 0;
@@ -402,6 +414,10 @@
   return true;
 }
 
+bool Merger::isInvariant(unsigned e) const {
+  return tensorExps[e].kind == Kind::kInvariant;
+}
+
 Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
   if (auto arg = v.dyn_cast<BlockArgument>()) {
     unsigned argN = arg.getArgNumber();
@@ -470,6 +486,12 @@
         return addExp(Kind::kOrI, e0, e1);
       if (isa<XOrOp>(def))
         return addExp(Kind::kXorI, e0, e1);
+      if (isa<SignedShiftRightOp>(def) && isInvariant(e1))
+        return addExp(Kind::kShrS, e0, e1);
+      if (isa<UnsignedShiftRightOp>(def) && isInvariant(e1))
+        return addExp(Kind::kShrU, e0, e1);
+      if (isa<ShiftLeftOp>(def) && isInvariant(e1))
+        return addExp(Kind::kShlI, e0, e1);
     }
   }
   // Cannot build.
@@ -517,6 +539,12 @@
     return rewriter.create<OrOp>(loc, v0, v1);
   case Kind::kXorI:
     return rewriter.create<XOrOp>(loc, v0, v1);
+  case Kind::kShrS:
+    return rewriter.create<SignedShiftRightOp>(loc, v0, v1);
+  case Kind::kShrU:
+    return rewriter.create<UnsignedShiftRightOp>(loc, v0, v1);
+  case Kind::kShlI:
+    return rewriter.create<ShiftLeftOp>(loc, v0, v1);
   }
   llvm_unreachable("unexpected expression kind in build");
 }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
@@ -404,3 +404,106 @@
   } -> tensor<32xi64>
   return %0 : tensor<32xi64>
 }
+
+// CHECK-LABEL: func @ashrbyc(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK: %[[VAL_2:.*]] = constant 2 : i64
+// CHECK: %[[VAL_3:.*]] = constant 0 : index
+// CHECK: %[[VAL_4:.*]] = constant 1 : index
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
+// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xi64>
+// CHECK: %[[VAL_14:.*]] = shift_right_signed %[[VAL_13]], %[[VAL_2]] : i64
+// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xi64>
+// CHECK: }
+// CHECK: %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xi64>
+// CHECK: return %[[VAL_15]] : tensor<32xi64>
+// CHECK: }
+func @ashrbyc(%arga: tensor<32xi64, #SV>,
+              %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %c = constant 2 : i64
+  %0 = linalg.generic #traitc
+     ins(%arga: tensor<32xi64, #SV>)
+    outs(%argx: tensor<32xi64>) {
+      ^bb(%a: i64, %x: i64):
+        %0 = shift_right_signed %a, %c : i64
+        linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}
+
+// CHECK-LABEL: func @lsrbyc(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK: %[[VAL_2:.*]] = constant 2 : i64
+// CHECK: %[[VAL_3:.*]] = constant 0 : index
+// CHECK: %[[VAL_4:.*]] = constant 1 : index
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
+// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xi64>
+// CHECK: %[[VAL_14:.*]] = shift_right_unsigned %[[VAL_13]], %[[VAL_2]] : i64
+// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xi64>
+// CHECK: }
+// CHECK: %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xi64>
+// CHECK: return %[[VAL_15]] : tensor<32xi64>
+// CHECK: }
+func @lsrbyc(%arga: tensor<32xi64, #SV>,
+             %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %c = constant 2 : i64
+  %0 = linalg.generic #traitc
+     ins(%arga: tensor<32xi64, #SV>)
+    outs(%argx: tensor<32xi64>) {
+      ^bb(%a: i64, %x: i64):
+        %0 = shift_right_unsigned %a, %c : i64
+        linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}
+
+// CHECK-LABEL: func @lslbyc(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK: %[[VAL_2:.*]] = constant 2 : i64
+// CHECK: %[[VAL_3:.*]] = constant 0 : index
+// CHECK: %[[VAL_4:.*]] = constant 1 : index
+// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
+// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] {
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref<?xi64>
+// CHECK: %[[VAL_14:.*]] = shift_left %[[VAL_13]], %[[VAL_2]] : i64
+// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xi64>
+// CHECK: }
+// CHECK: %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xi64>
+// CHECK: return %[[VAL_15]] : tensor<32xi64>
+// CHECK: }
+func @lslbyc(%arga: tensor<32xi64, #SV>,
+             %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %c = constant 2 : i64
+  %0 = linalg.generic #traitc
+     ins(%arga: tensor<32xi64, #SV>)
+    outs(%argx: tensor<32xi64>) {
+      ^bb(%a: i64, %x: i64):
+        %0 = shift_left %a, %c : i64
+        linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}
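
Reviewer note (not part of the patch): the isInvariant(e1) guard only requires the
shift amount to be loop-invariant, not a compile-time constant. buildTensorExp
classifies any value defined outside the linalg.generic body (including block
arguments of enclosing ops, such as function arguments) as Kind::kInvariant, so a
runtime scalar also qualifies. A minimal sketch, assuming the #SV and #traitc
definitions earlier in sparse_int_ops.mlir; @ashrbyv and %args are hypothetical
names, not part of the patch:

func @ashrbyv(%arga: tensor<32xi64, #SV>,
              %args: i64,
              %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
  // %args is defined outside the linalg.generic body, so the merger treats
  // it as kInvariant and accepts the shift under the conjunction rule.
  %0 = linalg.generic #traitc
     ins(%arga: tensor<32xi64, #SV>)
    outs(%argx: tensor<32xi64>) {
      ^bb(%a: i64, %x: i64):
        %0 = shift_right_signed %a, %args : i64
        linalg.yield %0 : i64
  } -> tensor<32xi64>
  return %0 : tensor<32xi64>
}

By contrast, a shift whose right-hand operand comes from a tensor element fails the
isInvariant(e1) test: none of the isa<> cases fire, buildTensorExp falls through to
"Cannot build", and the kernel is simply left unsparsified.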