diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -83,6 +83,7 @@
   kShrU, // unsigned
   kShlI,
   kBinary, // semiring binary op
+  kReduce, // semiring reduction op
 };
 
 /// Children subexpressions of tensor operations.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -50,7 +50,7 @@
 };
 
 // Reduction kinds.
-enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor };
+enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor, kCustom };
 
 // Code generation.
 struct CodeGen {
@@ -376,6 +376,7 @@
 static vector::CombiningKind getCombiningKind(Reduction kind) {
   switch (kind) {
   case kNoReduc:
+  case kCustom:
     break;
   case kSum:
     return vector::CombiningKind::ADD;
@@ -391,6 +392,27 @@
   llvm_unreachable("unknown reduction kind");
 }
 
+static bool isValidReduction(Kind kind) {
+  switch (kind) {
+  case Kind::kAddF:
+  case Kind::kAddC:
+  case Kind::kAddI:
+  case Kind::kSubF:
+  case Kind::kSubC:
+  case Kind::kSubI:
+  case Kind::kMulF:
+  case Kind::kMulC:
+  case Kind::kMulI:
+  case Kind::kAndI:
+  case Kind::kOrI:
+  case Kind::kXorI:
+  case Kind::kReduce:
+    return true;
+  default:
+    return false;
+  }
+}
+
 /// Maps operation to reduction.
 static Reduction getReduction(Kind kind) {
   switch (kind) {
@@ -411,11 +433,37 @@
     return kOr;
   case Kind::kXorI:
     return kXor;
+  case Kind::kReduce:
+    return kCustom;
   default:
     llvm_unreachable("unexpected reduction operator");
   }
 }
 
+/// Generates the reduction identity value based on the kind of operation.
+/// The identity is meant to have no impact on the final reduction value
+/// (i.e. x op identity == x).
+static Value genReductionIdentity(Merger &merger, CodeGen &codegen,
+                                  OpBuilder &builder, linalg::GenericOp op,
+                                  Type tp) {
+  Location loc = op.getLoc();
+  Kind kind = merger.exp(codegen.redExp).kind;
+  switch (kind) {
+  case kMulF:
+    return builder.create<arith::ConstantOp>(loc, tp,
+                                             builder.getFloatAttr(tp, 1.0));
+  case kMulI:
+  case kAndI:
+    return builder.create<arith::ConstantOp>(loc, tp,
+                                             builder.getIntegerAttr(tp, 1));
+  case kReduce:
+    return dyn_cast<ReduceOp>(merger.exp(codegen.redExp).op).identity();
+  default:
+    return builder.create<arith::ConstantOp>(loc, tp, builder.getZeroAttr(tp));
+  }
+}
+
 /// Generates an initial value for a vector reduction, following the scheme
 /// given in Chapter 5 of "The Software Vectorization Handbook", where the
 /// initial scalar value is correctly embedded in the vector reduction value,
@@ -425,6 +473,7 @@
   Value r = codegen.redVal;
   switch (codegen.redKind) {
   case kNoReduc:
+  case kCustom:
     break;
   case kSum:
   case kXor:
@@ -729,6 +778,36 @@
   return builder.create<memref::LoadOp>(loc, codegen.expValues, index);
 }
 
+/// Generates insertion code to implement dynamic tensor load for reduction.
+static Value genInsertionLoadReduce(Merger &merger, CodeGen &codegen,
+                                    OpBuilder &builder, linalg::GenericOp op,
+                                    OpOperand *t) {
+  Location loc = op.getLoc();
+  Type tp = getElementTypeOrSelf(t->get().getType());
+  Value identity = genReductionIdentity(merger, codegen, builder, op, tp);
+  // Direct lexicographic index order, tensor loads as identity.
+  if (!codegen.expValues) {
+    return identity;
+  }
+  // Load from expanded access pattern if filled, identity otherwise.
+  Value index = genIndex(codegen, op, t);
+  Value isFilled =
+      builder.create<memref::LoadOp>(loc, codegen.expFilled, index);
+  scf::IfOp ifIsFilled =
+      builder.create<scf::IfOp>(loc, tp, isFilled, /*else=*/true);
+  // True branch
+  builder.setInsertionPointToStart(ifIsFilled.thenBlock());
+  Value valAtIndex =
+      builder.create<memref::LoadOp>(loc, codegen.expValues, index);
+  builder.create<scf::YieldOp>(loc, valAtIndex);
+  // False branch
+  builder.setInsertionPointToStart(ifIsFilled.elseBlock());
+  builder.create<scf::YieldOp>(loc, identity);
+  builder.setInsertionPointAfter(ifIsFilled);
+  // End if
+  return ifIsFilled.getResult(0);
+}
+
 /// Generates insertion code to implement dynamic tensor store.
 static void genInsertionStore(CodeGen &codegen, OpBuilder &builder,
                               linalg::GenericOp op, OpOperand *t, Value rhs) {
@@ -783,8 +862,12 @@
   }
   // Load during insertion.
   OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
-  if (t == codegen.sparseOut)
-    return genInsertionLoad(codegen, builder, op, t);
+  if (t == codegen.sparseOut) {
+    if (codegen.redKind == kNoReduc)
+      return genInsertionLoad(codegen, builder, op, t);
+    else
+      return genInsertionLoadReduce(merger, codegen, builder, op, t);
+  }
   // Actual load.
   SmallVector args;
   Value ptr = genSubscript(codegen, builder, op, t, args);
@@ -946,24 +1029,41 @@
 /// Recursively generates tensor expression.
 static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
-                    linalg::GenericOp op, unsigned exp, unsigned ldx) {
+                    linalg::GenericOp op, unsigned exp, unsigned ldx,
+                    unsigned last = 0) {
   Location loc = op.getLoc();
   if (exp == -1u)
     return Value();
-  if (merger.exp(exp).kind == Kind::kTensor)
-    return genTensorLoad(merger, codegen, rewriter, op, exp);
+  if (merger.exp(exp).kind == Kind::kTensor) {
+    // Handle reductions for access pattern expansion. The trigger is when the
+    // output tensor is also an operand parameter, although this can also apply
+    // to inplace updates which are not reductions. To avoid errors, check that
+    // the operation is a permitted reduction.
+    bool validRed = isValidReduction(merger.exp(last).kind);
+    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
+    OpOperand *lhs = op.getOutputOperand(0);
+    if (validRed && lhs == t) {
+      codegen.redKind = getReduction(merger.exp(last).kind);
+      codegen.redExp = last; // handling for reduction identity
+    }
+    Value redVal = genTensorLoad(merger, codegen, rewriter, op, exp);
+    if (validRed && lhs == t)
+      codegen.redExp = exp;
+    return redVal;
+  }
   if (merger.exp(exp).kind == Kind::kInvariant)
     return genInvariantValue(merger, codegen, rewriter, exp);
   if (merger.exp(exp).kind == Kind::kIndex)
     return genIndexValue(codegen, rewriter, merger.exp(exp).index, ldx);
-  Value v0 =
-      genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx);
-  Value v1 =
-      genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1, ldx);
+  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0,
+                    ldx, exp);
+  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1,
+                    ldx, exp);
   Value ee = merger.buildExp(rewriter, loc, exp, v0, v1);
   if (ee && (merger.exp(exp).kind == Kind::kUnary ||
              merger.exp(exp).kind == Kind::kBinary ||
-             merger.exp(exp).kind == Kind::kBinaryBranch))
+             merger.exp(exp).kind == Kind::kBinaryBranch ||
+             merger.exp(exp).kind == Kind::kReduce))
     ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx);
   return ee;
 }
@@ -992,7 +1092,7 @@
 /// Hoists loop invariant tensor loads for which indices have been exhausted.
 static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                           linalg::GenericOp op, unsigned exp, unsigned ldx,
-                          bool atStart, Kind last = Kind::kTensor) {
+                          bool atStart, unsigned last = 0) {
   if (exp == -1u)
     return;
   if (merger.exp(exp).kind == Kind::kTensor) {
@@ -1013,8 +1113,9 @@
     if (lhs == t) {
       // Start or end a scalarized reduction
       if (atStart) {
+        codegen.redKind = getReduction(merger.exp(last).kind);
+        codegen.redExp = last; // handling for reduction identity
         Value load = genTensorLoad(merger, codegen, builder, op, exp);
-        codegen.redKind = getReduction(last);
         codegen.redExp = exp;
         updateReduc(merger, codegen, load);
       } else {
@@ -1034,11 +1135,10 @@
     // Traverse into the binary operations. Note that we only hoist
     // tensor loads, since subsequent MLIR/LLVM passes know how to
     // deal with all other kinds of derived loop invariants.
-    Kind last = merger.exp(exp).kind;
     unsigned e0 = merger.exp(exp).children.e0;
     unsigned e1 = merger.exp(exp).children.e1;
-    genInvariants(merger, codegen, builder, op, e0, ldx, atStart, last);
-    genInvariants(merger, codegen, builder, op, e1, ldx, atStart, last);
+    genInvariants(merger, codegen, builder, op, e0, ldx, atStart, exp);
+    genInvariants(merger, codegen, builder, op, e1, ldx, atStart, exp);
   }
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -113,6 +113,7 @@
     children.e1 = y;
     break;
   case kBinary:
+  case kReduce:
     assert(x != -1u && y != -1u && !v && o);
     children.e0 = x;
     children.e1 = y;
@@ -375,6 +376,7 @@
   case kOrI:
   case kXorI:
   case kBinary:
+  case kReduce:
     return false;
   }
   llvm_unreachable("unexpected kind");
@@ -474,6 +476,8 @@
     return "<<";
   case kBinary:
     return "binary";
+  case kReduce:
+    return "reduce";
   }
   llvm_unreachable("unexpected kind for symbol");
 }
@@ -551,6 +555,7 @@
   case kShrU:
   case kShlI:
   case kBinary:
+  case kReduce:
     llvm::dbgs() << "(";
     dumpExp(tensorExps[e].children.e0);
     llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
@@ -793,6 +798,11 @@
                     kBinaryBranch, leftYield, includeRight, kBinaryBranch,
                     rightYield);
   }
+  case kReduce:
+    // A custom reduce operation.
+    return takeConj(kind, buildLattices(tensorExps[e].children.e0, i),
+                    buildLattices(tensorExps[e].children.e1, i),
+                    tensorExps[e].op);
   }
   llvm_unreachable("unexpected expression kind");
 }
@@ -962,7 +972,7 @@
   }
   // Construct binary operations if subexpressions can be built.
   // See buildLattices() for an explanation of rejecting certain
-  // division and shift operations
+  // division and shift operations.
   if (def->getNumOperands() == 2) {
     auto x = buildTensorExp(op, def->getOperand(0));
     auto y = buildTensorExp(op, def->getOperand(1));
@@ -1017,6 +1027,21 @@
       }
     }
   }
+  // Construct ternary operations if subexpressions can be built.
+  if (def->getNumOperands() == 3) {
+    auto x = buildTensorExp(op, def->getOperand(0));
+    auto y = buildTensorExp(op, def->getOperand(1));
+    auto z = buildTensorExp(op, def->getOperand(2));
+    if (x.hasValue() && y.hasValue() && z.hasValue()) {
+      unsigned e0 = x.getValue();
+      unsigned e1 = y.getValue();
+      // unsigned e2 = z.getValue();
+      if (auto redop = dyn_cast<ReduceOp>(def)) {
+        if (isAdmissableBranch(redop, redop.region()))
+          return addExp(kReduce, e0, e1, Value(), def);
+      }
+    }
+  }
   // Cannot build.
return None; } @@ -1066,6 +1091,13 @@ return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1}); } +static Value buildReduce(RewriterBase &rewriter, Location loc, Operation *op, + Value v0, Value v1) { + ReduceOp redop = cast(op); + Region &formula = redop.region(); + return insertYieldOp(rewriter, loc, formula, {v0, v1}); +} + Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e, Value v0, Value v1) { switch (tensorExps[e].kind) { @@ -1194,6 +1226,8 @@ return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0); case kBinary: return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1); + case kReduce: + return buildReduce(rewriter, loc, tensorExps[e].op, v0, v1); } llvm_unreachable("unexpected expression kind in build"); } diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -289,4 +289,4 @@ sparse_tensor.yield %x : f64 } return %r : f64 -} \ No newline at end of file +} diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir @@ -55,37 +55,38 @@ // // Computes C = A x B with all matrices sparse (SpMSpM) in DCSR. // -// CHECK-LABEL: func @matmul2( +// CHECK-LABEL: func.func @matmul2( // CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> { -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 2 : index +// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>> { +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f64 // CHECK-DAG: %[[VAL_6:.*]] = arith.constant false // CHECK-DAG: %[[VAL_7:.*]] = arith.constant true // CHECK: %[[VAL_8:.*]] = bufferization.alloc_tensor() : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref // CHECK: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x8xf64, 
#sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK: %[[VAL_14:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK: %[[VAL_15:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK: %[[VAL_16:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK: %[[VAL_17:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_14:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_2]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_15:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_2]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_16:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_17:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref // CHECK: %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK: %[[VAL_19:.*]] = memref.alloca(%[[VAL_5]]) : memref -// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref -// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref -// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_20]] to %[[VAL_21]] step %[[VAL_4]] { +// CHECK: %[[VAL_19:.*]] = memref.alloca(%[[VAL_4]]) : memref +// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_2]]] : memref +// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref +// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_20]] to %[[VAL_21]] step %[[VAL_3]] { // CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_22]]] : memref -// CHECK: memref.store %[[VAL_23]], %[[VAL_19]]{{\[}}%[[VAL_3]]] : memref +// CHECK: memref.store %[[VAL_23]], %[[VAL_19]]{{\[}}%[[VAL_2]]] : memref // CHECK: %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.expand %[[VAL_8]] : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref, memref, memref, index // CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]]] : memref -// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_22]], %[[VAL_4]] : index +// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_22]], %[[VAL_3]] : index // CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref -// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_3]]] : memref -// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_2]]] : memref +// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_3]]] : memref // CHECK: %[[VAL_33:.*]]:3 = scf.while (%[[VAL_34:.*]] = %[[VAL_28]], %[[VAL_35:.*]] = %[[VAL_31]], %[[VAL_36:.*]] = %[[VAL_27]]) : (index, index, index) -> (index, index, index) { // CHECK: %[[VAL_37:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_30]] : index // CHECK: %[[VAL_38:.*]] = arith.cmpi ult, %[[VAL_35]], %[[VAL_32]] : index @@ -103,43 +104,49 @@ // CHECK: %[[VAL_50:.*]] = scf.if %[[VAL_49]] -> (index) { // CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_40]]] : memref // CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_41]]] : memref -// CHECK: %[[VAL_53:.*]] = arith.addi %[[VAL_41]], %[[VAL_4]] : index +// CHECK: %[[VAL_53:.*]] = 
arith.addi %[[VAL_41]], %[[VAL_3]] : index // CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_53]]] : memref -// CHECK: %[[VAL_55:.*]] = scf.for %[[VAL_56:.*]] = %[[VAL_52]] to %[[VAL_54]] step %[[VAL_4]] iter_args(%[[VAL_57:.*]] = %[[VAL_42]]) -> (index) { +// CHECK: %[[VAL_55:.*]] = scf.for %[[VAL_56:.*]] = %[[VAL_52]] to %[[VAL_54]] step %[[VAL_3]] iter_args(%[[VAL_57:.*]] = %[[VAL_42]]) -> (index) { // CHECK: %[[VAL_58:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_56]]] : memref -// CHECK: %[[VAL_59:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_58]]] : memref -// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_56]]] : memref -// CHECK: %[[VAL_61:.*]] = arith.mulf %[[VAL_51]], %[[VAL_60]] : f64 -// CHECK: %[[VAL_62:.*]] = arith.addf %[[VAL_59]], %[[VAL_61]] : f64 -// CHECK: %[[VAL_63:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_58]]] : memref -// CHECK: %[[VAL_64:.*]] = arith.cmpi eq, %[[VAL_63]], %[[VAL_6]] : i1 -// CHECK: %[[VAL_65:.*]] = scf.if %[[VAL_64]] -> (index) { +// CHECK: %[[VAL_59:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_58]]] : memref +// CHECK: %[[VAL_60:.*]] = scf.if %[[VAL_59]] -> (f64) { +// CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_58]]] : memref +// CHECK: scf.yield %[[VAL_61]] : f64 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_5]] : f64 +// CHECK: } +// CHECK: %[[VAL_62:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_56]]] : memref +// CHECK: %[[VAL_63:.*]] = arith.mulf %[[VAL_51]], %[[VAL_62]] : f64 +// CHECK: %[[VAL_64:.*]] = arith.addf %[[VAL_65:.*]], %[[VAL_63]] : f64 +// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_58]]] : memref +// CHECK: %[[VAL_67:.*]] = arith.cmpi eq, %[[VAL_66]], %[[VAL_6]] : i1 +// CHECK: %[[VAL_68:.*]] = scf.if %[[VAL_67]] -> (index) { // CHECK: memref.store %[[VAL_7]], %[[VAL_25]]{{\[}}%[[VAL_58]]] : memref // CHECK: memref.store %[[VAL_58]], %[[VAL_26]]{{\[}}%[[VAL_57]]] : memref -// CHECK: %[[VAL_66:.*]] = arith.addi %[[VAL_57]], %[[VAL_4]] : index -// CHECK: scf.yield %[[VAL_66]] : index +// CHECK: %[[VAL_69:.*]] = arith.addi %[[VAL_57]], %[[VAL_3]] : index +// CHECK: scf.yield %[[VAL_69]] : index // CHECK: } else { // CHECK: scf.yield %[[VAL_57]] : index // CHECK: } -// CHECK: memref.store %[[VAL_62]], %[[VAL_24]]{{\[}}%[[VAL_58]]] : memref -// CHECK: scf.yield %[[VAL_67:.*]] : index +// CHECK: memref.store %[[VAL_64]], %[[VAL_24]]{{\[}}%[[VAL_58]]] : memref +// CHECK: scf.yield %[[VAL_70:.*]] : index // CHECK: } -// CHECK: scf.yield %[[VAL_68:.*]] : index +// CHECK: scf.yield %[[VAL_71:.*]] : index // CHECK: } else { // CHECK: scf.yield %[[VAL_42]] : index // CHECK: } -// CHECK: %[[VAL_69:.*]] = arith.cmpi eq, %[[VAL_43]], %[[VAL_46]] : index -// CHECK: %[[VAL_70:.*]] = arith.addi %[[VAL_40]], %[[VAL_4]] : index -// CHECK: %[[VAL_71:.*]] = arith.select %[[VAL_69]], %[[VAL_70]], %[[VAL_40]] : index -// CHECK: %[[VAL_72:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_46]] : index -// CHECK: %[[VAL_73:.*]] = arith.addi %[[VAL_41]], %[[VAL_4]] : index -// CHECK: %[[VAL_74:.*]] = arith.select %[[VAL_72]], %[[VAL_73]], %[[VAL_41]] : index -// CHECK: scf.yield %[[VAL_71]], %[[VAL_74]], %[[VAL_75:.*]] : index, index, index +// CHECK: %[[VAL_72:.*]] = arith.cmpi eq, %[[VAL_43]], %[[VAL_46]] : index +// CHECK: %[[VAL_73:.*]] = arith.addi %[[VAL_40]], %[[VAL_3]] : index +// CHECK: %[[VAL_74:.*]] = arith.select %[[VAL_72]], %[[VAL_73]], %[[VAL_40]] : index +// CHECK: %[[VAL_75:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_46]] : index +// CHECK: %[[VAL_76:.*]] = arith.addi %[[VAL_41]], %[[VAL_3]] : index 
+// CHECK: %[[VAL_77:.*]] = arith.select %[[VAL_75]], %[[VAL_76]], %[[VAL_41]] : index +// CHECK: scf.yield %[[VAL_74]], %[[VAL_77]], %[[VAL_78:.*]] : index, index, index // CHECK: } -// CHECK: sparse_tensor.compress %[[VAL_8]], %[[VAL_19]], %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_76:.*]]#2 : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>>, memref, memref, memref, memref, index +// CHECK: sparse_tensor.compress %[[VAL_8]], %[[VAL_19]], %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_79:.*]]#2 : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>>, memref, memref, memref, memref, index // CHECK: } -// CHECK: %[[VAL_77:.*]] = sparse_tensor.load %[[VAL_8]] hasInserts : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK: return %[[VAL_77]] : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: %[[VAL_80:.*]] = sparse_tensor.load %[[VAL_8]] hasInserts : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: return %[[VAL_80]] : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>> // CHECK: } func.func @matmul2(%A: tensor<4x8xf64, #DCSR>, %B: tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> { diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir @@ -312,93 +312,100 @@ doc = "C(i,j) = SUM_k A(i,k) * B(k,j)" } -// CHECK-LABEL: func @matmat( +// CHECK-LABEL: func.func @matmat( // CHECK-SAME: %[[VAL_0:.*]]: tensor>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor> { +// CHECK-SAME: %[[VAL_1:.*]]: tensor>) -> tensor> { // CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant false -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant true -// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor> -// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[VAL_3]] : tensor> -// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor(%[[VAL_7]], %[[VAL_8]]) : tensor> -// CHECK: %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor> to memref -// CHECK: %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor> to memref -// CHECK: %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor> to memref -// CHECK: %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor> to memref -// CHECK: %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor> to memref -// CHECK: %[[VAL_15:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_2]] : tensor> to memref -// CHECK: %[[VAL_16:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_2]] : tensor> to memref -// CHECK: %[[VAL_17:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor> to memref -// CHECK: %[[VAL_18:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor> to memref -// CHECK: %[[VAL_19:.*]] = sparse_tensor.values %[[VAL_1]] : tensor> to memref -// CHECK: %[[VAL_20:.*]] = memref.alloca(%[[VAL_4]]) : memref -// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_2]]] : memref -// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_3]]] : memref -// CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_21]] to %[[VAL_22]] step %[[VAL_3]] { -// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]]] : memref -// CHECK: memref.store %[[VAL_24]], %[[VAL_20]]{{\[}}%[[VAL_2]]] : memref -// CHECK: %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]], %[[VAL_28:.*]] = 
sparse_tensor.expand %[[VAL_9]] : tensor> to memref, memref, memref, index -// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_23]]] : memref -// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_23]], %[[VAL_3]] : index -// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_30]]] : memref -// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_2]]] : memref -// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_3]]] : memref -// CHECK: %[[VAL_34:.*]]:3 = scf.while (%[[VAL_35:.*]] = %[[VAL_29]], %[[VAL_36:.*]] = %[[VAL_32]], %[[VAL_37:.*]] = %[[VAL_28]]) : (index, index, index) -> (index, index, index) { -// CHECK: %[[VAL_38:.*]] = arith.cmpi ult, %[[VAL_35]], %[[VAL_31]] : index -// CHECK: %[[VAL_39:.*]] = arith.cmpi ult, %[[VAL_36]], %[[VAL_33]] : index -// CHECK: %[[VAL_40:.*]] = arith.andi %[[VAL_38]], %[[VAL_39]] : i1 -// CHECK: scf.condition(%[[VAL_40]]) %[[VAL_35]], %[[VAL_36]], %[[VAL_37]] : index, index, index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant false +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant true +// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor> +// CHECK: %[[VAL_9:.*]] = tensor.dim %[[VAL_1]], %[[VAL_3]] : tensor> +// CHECK: %[[VAL_10:.*]] = bufferization.alloc_tensor(%[[VAL_8]], %[[VAL_9]]) : tensor> +// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor> to memref +// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor> to memref +// CHECK: %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor> to memref +// CHECK: %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor> to memref +// CHECK: %[[VAL_15:.*]] = sparse_tensor.values %[[VAL_0]] : tensor> to memref +// CHECK: %[[VAL_16:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_2]] : tensor> to memref +// CHECK: %[[VAL_17:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_2]] : tensor> to memref +// CHECK: %[[VAL_18:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor> to memref +// CHECK: %[[VAL_19:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor> to memref +// CHECK: %[[VAL_20:.*]] = sparse_tensor.values %[[VAL_1]] : tensor> to memref +// CHECK: %[[VAL_21:.*]] = memref.alloca(%[[VAL_4]]) : memref +// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_2]]] : memref +// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_3]]] : memref +// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_22]] to %[[VAL_23]] step %[[VAL_3]] { +// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref +// CHECK: memref.store %[[VAL_25]], %[[VAL_21]]{{\[}}%[[VAL_2]]] : memref +// CHECK: %[[VAL_26:.*]], %[[VAL_27:.*]], %[[VAL_28:.*]], %[[VAL_29:.*]] = sparse_tensor.expand %[[VAL_10]] : tensor> to memref, memref, memref, index +// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_24]], %[[VAL_3]] : index +// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_31]]] : memref +// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_2]]] : memref +// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_35:.*]]:3 = scf.while (%[[VAL_36:.*]] = %[[VAL_30]], %[[VAL_37:.*]] = %[[VAL_33]], %[[VAL_38:.*]] = %[[VAL_29]]) : (index, index, index) -> (index, index, index) { +// CHECK: %[[VAL_39:.*]] = arith.cmpi ult, %[[VAL_36]], %[[VAL_32]] : index +// CHECK: %[[VAL_40:.*]] = 
arith.cmpi ult, %[[VAL_37]], %[[VAL_34]] : index +// CHECK: %[[VAL_41:.*]] = arith.andi %[[VAL_39]], %[[VAL_40]] : i1 +// CHECK: scf.condition(%[[VAL_41]]) %[[VAL_36]], %[[VAL_37]], %[[VAL_38]] : index, index, index // CHECK: } do { -// CHECK: ^bb0(%[[VAL_41:.*]]: index, %[[VAL_42:.*]]: index, %[[VAL_43:.*]]: index): -// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_41]]] : memref -// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_42]]] : memref -// CHECK: %[[VAL_46:.*]] = arith.cmpi ult, %[[VAL_45]], %[[VAL_44]] : index -// CHECK: %[[VAL_47:.*]] = arith.select %[[VAL_46]], %[[VAL_45]], %[[VAL_44]] : index -// CHECK: %[[VAL_48:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_47]] : index -// CHECK: %[[VAL_49:.*]] = arith.cmpi eq, %[[VAL_45]], %[[VAL_47]] : index -// CHECK: %[[VAL_50:.*]] = arith.andi %[[VAL_48]], %[[VAL_49]] : i1 -// CHECK: %[[VAL_51:.*]] = scf.if %[[VAL_50]] -> (index) { -// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_41]]] : memref -// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_42]]] : memref -// CHECK: %[[VAL_54:.*]] = arith.addi %[[VAL_42]], %[[VAL_3]] : index -// CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_54]]] : memref -// CHECK: %[[VAL_56:.*]] = scf.for %[[VAL_57:.*]] = %[[VAL_53]] to %[[VAL_55]] step %[[VAL_3]] iter_args(%[[VAL_58:.*]] = %[[VAL_43]]) -> (index) { -// CHECK: %[[VAL_59:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_57]]] : memref -// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_59]]] : memref -// CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_57]]] : memref -// CHECK: %[[VAL_62:.*]] = arith.mulf %[[VAL_52]], %[[VAL_61]] : f32 -// CHECK: %[[VAL_63:.*]] = arith.addf %[[VAL_60]], %[[VAL_62]] : f32 -// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_59]]] : memref -// CHECK: %[[VAL_65:.*]] = arith.cmpi eq, %[[VAL_64]], %[[VAL_5]] : i1 -// CHECK: %[[VAL_66:.*]] = scf.if %[[VAL_65]] -> (index) { -// CHECK: memref.store %[[VAL_6]], %[[VAL_26]]{{\[}}%[[VAL_59]]] : memref -// CHECK: memref.store %[[VAL_59]], %[[VAL_27]]{{\[}}%[[VAL_58]]] : memref -// CHECK: %[[VAL_67:.*]] = arith.addi %[[VAL_58]], %[[VAL_3]] : index -// CHECK: scf.yield %[[VAL_67]] : index +// CHECK: ^bb0(%[[VAL_42:.*]]: index, %[[VAL_43:.*]]: index, %[[VAL_44:.*]]: index): +// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_42]]] : memref +// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_43]]] : memref +// CHECK: %[[VAL_47:.*]] = arith.cmpi ult, %[[VAL_46]], %[[VAL_45]] : index +// CHECK: %[[VAL_48:.*]] = arith.select %[[VAL_47]], %[[VAL_46]], %[[VAL_45]] : index +// CHECK: %[[VAL_49:.*]] = arith.cmpi eq, %[[VAL_45]], %[[VAL_48]] : index +// CHECK: %[[VAL_50:.*]] = arith.cmpi eq, %[[VAL_46]], %[[VAL_48]] : index +// CHECK: %[[VAL_51:.*]] = arith.andi %[[VAL_49]], %[[VAL_50]] : i1 +// CHECK: %[[VAL_52:.*]] = scf.if %[[VAL_51]] -> (index) { +// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_42]]] : memref +// CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_43]]] : memref +// CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_43]], %[[VAL_3]] : index +// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_55]]] : memref +// CHECK: %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_54]] to %[[VAL_56]] step %[[VAL_3]] iter_args(%[[VAL_59:.*]] = %[[VAL_44]]) -> (index) { +// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_58]]] : memref +// CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_60]]] : memref +// CHECK: 
%[[VAL_62:.*]] = scf.if %[[VAL_61]] -> (f32) { +// CHECK: %[[VAL_63:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_60]]] : memref +// CHECK: scf.yield %[[VAL_63]] : f32 // CHECK: } else { -// CHECK: scf.yield %[[VAL_58]] : index +// CHECK: scf.yield %[[VAL_5]] : f32 // CHECK: } -// CHECK: memref.store %[[VAL_63]], %[[VAL_25]]{{\[}}%[[VAL_59]]] : memref -// CHECK: scf.yield %[[VAL_68:.*]] : index +// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_58]]] : memref +// CHECK: %[[VAL_65:.*]] = arith.mulf %[[VAL_53]], %[[VAL_64]] : f32 +// CHECK: %[[VAL_66:.*]] = arith.addf %[[VAL_67:.*]], %[[VAL_65]] : f32 +// CHECK: %[[VAL_68:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_60]]] : memref +// CHECK: %[[VAL_69:.*]] = arith.cmpi eq, %[[VAL_68]], %[[VAL_6]] : i1 +// CHECK: %[[VAL_70:.*]] = scf.if %[[VAL_69]] -> (index) { +// CHECK: memref.store %[[VAL_7]], %[[VAL_27]]{{\[}}%[[VAL_60]]] : memref +// CHECK: memref.store %[[VAL_60]], %[[VAL_28]]{{\[}}%[[VAL_59]]] : memref +// CHECK: %[[VAL_71:.*]] = arith.addi %[[VAL_59]], %[[VAL_3]] : index +// CHECK: scf.yield %[[VAL_71]] : index +// CHECK: } else { +// CHECK: scf.yield %[[VAL_59]] : index +// CHECK: } +// CHECK: memref.store %[[VAL_66]], %[[VAL_26]]{{\[}}%[[VAL_60]]] : memref +// CHECK: scf.yield %[[VAL_72:.*]] : index // CHECK: } -// CHECK: scf.yield %[[VAL_69:.*]] : index +// CHECK: scf.yield %[[VAL_73:.*]] : index // CHECK: } else { -// CHECK: scf.yield %[[VAL_43]] : index +// CHECK: scf.yield %[[VAL_44]] : index // CHECK: } -// CHECK: %[[VAL_70:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_47]] : index -// CHECK: %[[VAL_71:.*]] = arith.addi %[[VAL_41]], %[[VAL_3]] : index -// CHECK: %[[VAL_72:.*]] = arith.select %[[VAL_70]], %[[VAL_71]], %[[VAL_41]] : index -// CHECK: %[[VAL_73:.*]] = arith.cmpi eq, %[[VAL_45]], %[[VAL_47]] : index -// CHECK: %[[VAL_74:.*]] = arith.addi %[[VAL_42]], %[[VAL_3]] : index -// CHECK: %[[VAL_75:.*]] = arith.select %[[VAL_73]], %[[VAL_74]], %[[VAL_42]] : index -// CHECK: scf.yield %[[VAL_72]], %[[VAL_75]], %[[VAL_76:.*]] : index, index, index +// CHECK: %[[VAL_74:.*]] = arith.cmpi eq, %[[VAL_45]], %[[VAL_48]] : index +// CHECK: %[[VAL_75:.*]] = arith.addi %[[VAL_42]], %[[VAL_3]] : index +// CHECK: %[[VAL_76:.*]] = arith.select %[[VAL_74]], %[[VAL_75]], %[[VAL_42]] : index +// CHECK: %[[VAL_77:.*]] = arith.cmpi eq, %[[VAL_46]], %[[VAL_48]] : index +// CHECK: %[[VAL_78:.*]] = arith.addi %[[VAL_43]], %[[VAL_3]] : index +// CHECK: %[[VAL_79:.*]] = arith.select %[[VAL_77]], %[[VAL_78]], %[[VAL_43]] : index +// CHECK: scf.yield %[[VAL_76]], %[[VAL_79]], %[[VAL_80:.*]] : index, index, index // CHECK: } -// CHECK: sparse_tensor.compress %[[VAL_9]], %[[VAL_20]], %[[VAL_25]], %[[VAL_26]], %[[VAL_27]], %[[VAL_77:.*]]#2 : tensor>, memref, memref, memref, memref, index +// CHECK: sparse_tensor.compress %[[VAL_10]], %[[VAL_21]], %[[VAL_26]], %[[VAL_27]], %[[VAL_28]], %[[VAL_81:.*]]#2 : tensor>, memref, memref, memref, memref, index // CHECK: } -// CHECK: %[[VAL_78:.*]] = sparse_tensor.load %[[VAL_9]] hasInserts : tensor> -// CHECK: return %[[VAL_78]] : tensor> +// CHECK: %[[VAL_82:.*]] = sparse_tensor.load %[[VAL_10]] hasInserts : tensor> +// CHECK: return %[[VAL_82]] : tensor> // CHECK: } func.func @matmat(%arga: tensor, %argb: tensor) -> tensor { diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir +++ 
b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir @@ -4,10 +4,19 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s +#SparseVector = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed" ] +}> + #SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> +#SparseCSCMatrix = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed", "compressed" ], + dimOrdering = affine_map<(i,j) -> (j,i)> +}> + #SparseTensor = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }> @@ -22,6 +31,15 @@ doc = "X(i,j) = SUM_k A(i,j,k) * B(i,j,k)" } +#redprod = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (i)> // X (out) + ], + iterator_types = ["parallel", "reduction"], + doc = "X(i) = PROD_j A(i,j)" +} + module { func.func @redsum(%arga: tensor, %argb: tensor) @@ -43,6 +61,49 @@ return %0 : tensor } + func.func @redprod(%arga: tensor) -> tensor { + %c0 = arith.constant 0 : index + %d0 = tensor.dim %arga, %c0 : tensor + %xinit = bufferization.alloc_tensor(%d0): tensor + %0 = linalg.generic #redprod + ins(%arga: tensor) + outs(%xinit: tensor) { + ^bb(%a: i32, %x: i32): + %0 = arith.muli %x, %a : i32 + linalg.yield %0 : i32 + } -> tensor + return %0 : tensor + } + + func.func @redprod2(%arga: tensor) -> tensor { + %c0 = arith.constant 0 : index + %d0 = tensor.dim %arga, %c0 : tensor + %xinit = bufferization.alloc_tensor(%d0): tensor + %0 = linalg.generic #redprod + ins(%arga: tensor) + outs(%xinit: tensor) { + ^bb(%a: i32, %x: i32): + %0 = arith.muli %x, %a : i32 + linalg.yield %0 : i32 + } -> tensor + return %0 : tensor + } + + // Dumps a sparse vector. + func.func @dumpvec(%arg0: tensor) { + // Dump the values array to verify only sparse contents are stored. + %c0 = arith.constant 0 : index + %d0 = arith.constant -1 : i32 + %0 = sparse_tensor.values %arg0 : tensor to memref + %1 = vector.transfer_read %0[%c0], %d0: memref, vector<4xi32> + vector.print %1 : vector<4xi32> + // Dump the dense vector to verify structure is correct. + %dv = sparse_tensor.convert %arg0 : tensor to tensor + %2 = vector.transfer_read %dv[%c0], %d0: tensor, vector<4xi32> + vector.print %2 : vector<4xi32> + return + } + // Driver method to call and verify tensor kernel. func.func @entry() { %c0 = arith.constant 0 : index @@ -60,16 +121,33 @@ %st2 = sparse_tensor.convert %t2 : tensor<3x3x4xi32> to tensor + // Setup sparse 2-d tensors. + %m1 = arith.constant sparse< + [ [0, 3], [0, 4], [2, 3], [3, 0], [3, 2], [3, 4] ], [ 1, 2, 3, 4, 5, 6 ] + >: tensor<4x5xi32> + %smr = sparse_tensor.convert %m1 + : tensor<4x5xi32> to tensor + %smc = sparse_tensor.convert %m1 + : tensor<4x5xi32> to tensor + // Call kernel. %0 = call @redsum(%st1, %st2) : (tensor, tensor) -> tensor + %1 = call @redprod(%smr) + : (tensor) -> tensor + %2 = call @redprod2(%smc) + : (tensor) -> tensor // // Verify results. Only two entries stored in result. Correct structure. // // CHECK: ( 7, 69, -1, -1 ) // CHECK-NEXT: ( ( 0, 0, 0 ), ( 0, 7, 0 ), ( 0, 0, 69 ) ) + // CHECK-NEXT: ( 2, 3, 120, -1 ) + // CHECK-NEXT: ( 2, 0, 3, 120 ) + // CHECK-NEXT: ( 2, 3, 120, -1 ) + // CHECK-NEXT: ( 2, 0, 3, 120 ) // %val = sparse_tensor.values %0 : tensor to memref @@ -79,11 +157,15 @@ : tensor to tensor %vm = vector.transfer_read %dm[%c0, %c0], %i0: tensor, vector<3x3xi32> vector.print %vm : vector<3x3xi32> + call @dumpvec(%1) : (tensor) -> () + call @dumpvec(%2) : (tensor) -> () // Release the resources. 
sparse_tensor.release %st1 : tensor sparse_tensor.release %st2 : tensor sparse_tensor.release %0 : tensor + sparse_tensor.release %1 : tensor + sparse_tensor.release %2 : tensor return } } diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir @@ -0,0 +1,180 @@ +// RUN: mlir-opt %s --sparse-compiler | \ +// RUN: mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> +#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}> +#CSC = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + dimOrdering = affine_map<(i,j) -> (j,i)> +}> + +// +// Traits for tensor operations. +// +#trait_matmul = { + indexing_maps = [ + affine_map<(i,j,k) -> (i,k)>, // A + affine_map<(i,j,k) -> (k,j)>, // B + affine_map<(i,j,k) -> (i,j)> // C (out) + ], + iterator_types = ["parallel", "parallel", "reduction"], + doc = "C(i,j) = SUM_k A(i,k) * B(k,j)" +} + +#trait_mat_reduce = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A (in) + affine_map<(i,j) -> (i)> // X (out) + ], + iterator_types = ["parallel", "reduce"] +} + +module { + // Creates a new sparse vector using the minimum values from two input sparse vectors. + // When there is no overlap, include the present value in the output. + func.func @min_plus_csrcsr(%arga: tensor, + %argb: tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %maxf = arith.constant 1.0e999 : f64 + %d0 = tensor.dim %arga, %c0 : tensor + %d1 = tensor.dim %argb, %c1 : tensor + %xm = bufferization.alloc_tensor(%d0, %d1) : tensor + %0 = linalg.generic #trait_matmul + ins(%arga, %argb: tensor, tensor) + outs(%xm: tensor) { + ^bb(%a: f64, %b: f64, %output: f64): + %1 = sparse_tensor.binary %a, %b : f64, f64 to f64 + overlap = { + ^bb0(%x: f64, %y: f64): + %3 = arith.addf %x, %y : f64 + sparse_tensor.yield %3 : f64 + } + left={} + right={} + %2 = sparse_tensor.reduce %1, %output, %maxf : f64 { + ^bb0(%x: f64, %y: f64): + %cmp = arith.cmpf "olt", %x, %y : f64 + %3 = arith.select %cmp, %x, %y : f64 + sparse_tensor.yield %3 : f64 + } + linalg.yield %2 : f64 + } -> tensor + return %0 : tensor + } + + func.func @min_plus_csrcsc(%arga: tensor, + %argb: tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %maxf = arith.constant 1.0e999 : f64 + %d0 = tensor.dim %arga, %c0 : tensor + %d1 = tensor.dim %argb, %c1 : tensor + %xm = bufferization.alloc_tensor(%d0, %d1) : tensor + %0 = linalg.generic #trait_matmul + ins(%arga, %argb: tensor, tensor) + outs(%xm: tensor) { + ^bb(%a: f64, %b: f64, %output: f64): + %1 = sparse_tensor.binary %a, %b : f64, f64 to f64 + overlap = { + ^bb0(%x: f64, %y: f64): + %3 = arith.addf %x, %y : f64 + sparse_tensor.yield %3 : f64 + } + left={} + right={} + %2 = sparse_tensor.reduce %1, %output, %maxf : f64 { + ^bb0(%x: f64, %y: f64): + %cmp = arith.cmpf "olt", %x, %y : f64 + %3 = arith.select %cmp, %x, %y : f64 + sparse_tensor.yield %3 : f64 + } + linalg.yield %2 : f64 + } -> tensor + return %0 : tensor + } + + // Dumps a sparse vector of type f64. + func.func @dump_vec(%arg0: tensor) { + // Dump the values array to verify only sparse contents are stored. 
+ %c0 = arith.constant 0 : index + %d0 = arith.constant -1.0 : f64 + %0 = sparse_tensor.values %arg0 : tensor to memref + %1 = vector.transfer_read %0[%c0], %d0: memref, vector<16xf64> + vector.print %1 : vector<16xf64> + // Dump the dense vector to verify structure is correct. + %dv = sparse_tensor.convert %arg0 : tensor to tensor + %2 = vector.transfer_read %dv[%c0], %d0: tensor, vector<32xf64> + vector.print %2 : vector<32xf64> + return + } + + // Dump a sparse matrix. + func.func @dump_mat(%arg0: tensor) { + // Dump the values array to verify only sparse contents are stored. + %c0 = arith.constant 0 : index + %d0 = arith.constant -1.0 : f64 + %0 = sparse_tensor.values %arg0 : tensor to memref + %1 = vector.transfer_read %0[%c0], %d0: memref, vector<16xf64> + vector.print %1 : vector<16xf64> + %dm = sparse_tensor.convert %arg0 : tensor to tensor + %2 = vector.transfer_read %dm[%c0, %c0], %d0: tensor, vector<5x5xf64> + vector.print %2 : vector<5x5xf64> + return + } + + // Driver method to call and verify vector kernels. + func.func @entry() { + %c0 = arith.constant 0 : index + + // Setup sparse matrices. + %m1 = arith.constant sparse< + [ [0,0], [0,1], [1,0], [2,2], [2,3], [2,4], [3,0], [3,2], [3,3] ], + [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ] + > : tensor<4x5xf64> + %m2 = arith.constant sparse< + [ [0,0], [1,3], [2,0], [2,3], [3,1], [4,1] ], + [6.0, 5.0, 4.0, 3.0, 2.0, 11.0 ] + > : tensor<5x4xf64> + %sm1 = sparse_tensor.convert %m1 : tensor<4x5xf64> to tensor + %sm2r = sparse_tensor.convert %m2 : tensor<5x4xf64> to tensor + %sm2c = sparse_tensor.convert %m2 : tensor<5x4xf64> to tensor + + // Call sparse matrix kernels. + %5 = call @min_plus_csrcsr(%sm1, %sm2r) + : (tensor, tensor) -> tensor + // COM: This is broken because the lex-insert version of matmul always inserts the + // COM: identity value, even if there is nothing to accumulate, resulting in a + // COM: dense output. + // COM: %6 = call @min_plus_csrcsc(%sm1, %sm2c) + // COM: : (tensor, tensor) -> tensor + + // + // Verify the results. + // + // CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( ( 1, 2, 0, 0, 0 ), ( 3, 0, 0, 0, 0 ), ( 0, 0, 4, 5, 6 ), ( 7, 0, 8, 9, 0 ), ( -1, -1, -1, -1, -1 ) ) + // CHECK-NEXT: ( 6, 5, 4, 3, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( ( 6, 0, 0, 0, -1 ), ( 0, 0, 0, 5, -1 ), ( 4, 0, 0, 3, -1 ), ( 0, 2, 0, 0, -1 ), ( 0, 11, 0, 0, -1 ) ) + // CHECK-NEXT: ( 7, 7, 9, 8, 7, 7, 12, 11, 11, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( ( 7, 0, 0, 7, -1 ), ( 9, 0, 0, 0, -1 ), ( 8, 7, 0, 7, -1 ), ( 12, 11, 0, 11, -1 ), ( -1, -1, -1, -1, -1 ) ) + // COM: CHECK-NEXT: ( 7, 7, 9, 8, 7, 7, 12, 11, 11, -1, -1, -1, -1, -1, -1, -1 ) + // COM: CHECK-NEXT: ( ( 7, 0, 0, 7, -1 ), ( 9, 0, 0, 0, -1 ), ( 8, 7, 0, 7, -1 ), ( 12, 11, 0, 11, -1 ), ( -1, -1, -1, -1, -1 ) ) + // + call @dump_mat(%sm1) : (tensor) -> () + call @dump_mat(%sm2r) : (tensor) -> () + call @dump_mat(%5) : (tensor) -> () + // COM: call @dump_mat(%6) : (tensor) -> () + + // Release the resources. 
+ sparse_tensor.release %sm1 : tensor + sparse_tensor.release %sm2r : tensor + sparse_tensor.release %sm2c : tensor + sparse_tensor.release %5 : tensor + // COM: sparse_tensor.release %6 : tensor + return + } +} diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp --- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -262,6 +262,7 @@ case kUnary: case kShlI: case kBinary: + case kReduce: return compareExpression(tensorExp.children.e0, pattern->e0); // Binary operations. case kMulF:
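For reference, a minimal usage sketch of the custom reduction this patch enables, assuming the #CSR and #SparseVector encodings defined in sparse_reduce_custom.mlir above; the kernel name @row_min and trait name #trait_row_min are illustrative only, not part of the patch. It computes a per-row minimum X(i) = MIN_j A(i,j) by combining sparse_tensor.reduce with an explicit identity value, mirroring the pattern used in min_plus_csrcsr.

#trait_row_min = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // A
    affine_map<(i,j) -> (i)>     // X (out)
  ],
  iterator_types = ["parallel", "reduction"]
}

func.func @row_min(%arga: tensor<?x?xf64, #CSR>,
                   %xinit: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
  // Identity for min: a value no smaller than any stored entry (same constant
  // style as the test above).
  %maxf = arith.constant 1.0e999 : f64
  %0 = linalg.generic #trait_row_min
    ins(%arga: tensor<?x?xf64, #CSR>)
    outs(%xinit: tensor<?xf64, #SparseVector>) {
    ^bb(%a: f64, %x: f64):
      // Custom reduction: fold each stored element into the accumulator with min.
      %1 = sparse_tensor.reduce %a, %x, %maxf : f64 {
        ^bb0(%p: f64, %q: f64):
          %cmp = arith.cmpf "olt", %p, %q : f64
          %m = arith.select %cmp, %p, %q : f64
          sparse_tensor.yield %m : f64
      }
      linalg.yield %1 : f64
  } -> tensor<?xf64, #SparseVector>
  return %0 : tensor<?xf64, #SparseVector>
}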