diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -544,6 +544,46 @@ let hasVerifier = 1; } +def SparseTensor_ReduceOp : SparseTensor_Op<"reduce", [NoSideEffect, SameOperandsAndResultType]>, + Arguments<(ins AnyType:$x, AnyType:$y, AnyType:$identity)>, + Results<(outs AnyType:$output)> { + let summary = "Custom reduction operation utilized within linalg.generic"; + let description = [{ + Defines a computation with a `linalg.generic` operation that takes two + operands and an identity value and reduces all values down to a single + result based on the computation in the region. + + The region must contain exactly one block taking two arguments. The block + must end with a sparse_tensor.yield and the output must match the input + argument types. + + Note that this operation is only required for custom reductions beyond the + standard operations (add, mul, and, or, etc). The `linalg.generic` + `iterator_types` defines which indices are being reduced. When the associated + operands are used in an operation, a reduction will occur. The use of this + explicit `reduce` operation is not required in most cases. + + Example: + + ```mlir + %cf1 = arith.constant 1.0 : f64 + %result = sparse_tensor.reduce %x, %y, %cf1 : f64 { + ^bb0(%a: f64, %b: f64): + %ret1 = arith.mulf %a, %b : f64 + %ret2 = arith.addf %ret1, %cf1 : f64 + sparse_tensor.yield %ret2 : f64 + } + ``` + }]; + + let regions = (region SizedRegion<1>:$region); + + let assemblyFormat = [{ + $x `,` $y `,` $identity attr-dict `:` type($output) $region + }]; + let hasVerifier = 1; +} + def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>, Arguments<(ins AnyType:$result)> { let summary = "Yield from sparse_tensor set-like operations"; diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -83,6 +83,7 @@ kShrU, // unsigned kShlI, kBinary, // semiring binary op + kReduce, // semiring reduction op }; /// Children subexpressions of tensor operations. @@ -273,6 +274,12 @@ /// Returns index of the root expression. unsigned buildLattices(unsigned e, unsigned i); + /// Returns the reduction identity value based on the kind of operation. + /// The identity is meant to have no impact on the final reduction value + /// (i.e. x op identity == x). + Value getReductionIdentity(OpBuilder &builder, Location loc, unsigned e, + Type tp); + /// Builds a tensor expression from the given Linalg operation. /// Returns index of the root expression on success. Optional buildTensorExpFromLinalg(linalg::GenericOp op); diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -357,15 +357,31 @@ return success(); } +LogicalResult ReduceOp::verify() { + Type inputType = x().getType(); + LogicalResult regionResult = success(); + + // Check correct number of block arguments and return type. + Region &formula = region(); + if (!formula.empty()) { + regionResult = verifyNumBlockArgs( + this, formula, "reduce", TypeRange{inputType, inputType}, inputType); + if (failed(regionResult)) + return regionResult; + } + + return success(); +} + LogicalResult YieldOp::verify() { // Check for compatible parent. auto *parentOp = (*this)->getParentOp(); - if (auto binaryOp = dyn_cast(parentOp)) - return success(); - if (auto unaryOp = dyn_cast(parentOp)) + if (isa(parentOp) || isa(parentOp) || + isa(parentOp)) return success(); - return emitOpError("expected parent op to be sparse_tensor binary or unary"); + return emitOpError( + "expected parent op to be sparse_tensor unary, binary, or reduce"); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -44,7 +44,7 @@ enum SortMask { kSparseOnly = 0x0, kIncludeDense = 0x1, kIncludeUndef = 0x2 }; // Reduction kinds. -enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor }; +enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor, kCustom }; // Code generation. struct CodeGen { @@ -360,6 +360,7 @@ static vector::CombiningKind getCombiningKind(Reduction kind) { switch (kind) { case kNoReduc: + case kCustom: break; case kSum: return vector::CombiningKind::ADD; @@ -375,6 +376,27 @@ llvm_unreachable("unknown reduction kind"); } +static bool isValidReduction(Kind kind) { + switch (kind) { + case Kind::kAddF: + case Kind::kAddC: + case Kind::kAddI: + case Kind::kSubF: + case Kind::kSubC: + case Kind::kSubI: + case Kind::kMulF: + case Kind::kMulC: + case Kind::kMulI: + case Kind::kAndI: + case Kind::kOrI: + case Kind::kXorI: + case Kind::kReduce: + return true; + default: + return false; + } +} + /// Maps operation to reduction. static Reduction getReduction(Kind kind) { switch (kind) { @@ -395,6 +417,8 @@ return kOr; case Kind::kXorI: return kXor; + case Kind::kReduce: + return kCustom; default: llvm_unreachable("unexpected reduction operator"); } @@ -409,6 +433,7 @@ Value r = codegen.redVal; switch (codegen.redKind) { case kNoReduc: + case kCustom: break; case kSum: case kXor: @@ -707,17 +732,39 @@ } /// Generates insertion code to implement dynamic tensor load. -static Value genInsertionLoad(CodeGen &codegen, OpBuilder &builder, - linalg::GenericOp op, OpOperand *t) { +static Value genInsertionLoad(Merger &merger, CodeGen &codegen, + OpBuilder &builder, linalg::GenericOp op, + OpOperand *t) { Location loc = op.getLoc(); - // Direct lexicographic index order, tensor loads as zero. + Type tp = getElementTypeOrSelf(t->get().getType()); + // Direct lexicographic index order, tensor loads as identity. if (!codegen.expValues) { - Type tp = getElementTypeOrSelf(t->get().getType()); - return constantZero(builder, loc, tp); + if (codegen.redKind == kNoReduc) + return constantZero(builder, loc, tp); + else + return merger.getReductionIdentity(builder, loc, codegen.redExp, tp); } - // Load from expanded access pattern. + // Load from expanded access pattern if filled, identity otherwise. Value index = genIndex(codegen, op, t); - return builder.create(loc, codegen.expValues, index); + if (codegen.redKind == kNoReduc) + return builder.create(loc, codegen.expValues, index); + Value isFilled = + builder.create(loc, codegen.expFilled, index); + scf::IfOp ifIsFilled = + builder.create(loc, tp, isFilled, /*else=*/true); + // True branch + builder.setInsertionPointToStart(ifIsFilled.thenBlock()); + Value valAtIndex = + builder.create(loc, codegen.expValues, index); + builder.create(loc, valAtIndex); + // False branch + builder.setInsertionPointToStart(ifIsFilled.elseBlock()); + Value identity = + merger.getReductionIdentity(builder, loc, codegen.redExp, tp); + builder.create(loc, identity); + builder.setInsertionPointAfter(ifIsFilled); + // End if + return ifIsFilled.getResult(0); } /// Generates insertion code to implement dynamic tensor store. @@ -775,7 +822,7 @@ // Load during insertion. OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; if (t == codegen.sparseOut) - return genInsertionLoad(codegen, builder, op, t); + return genInsertionLoad(merger, codegen, builder, op, t); // Actual load. SmallVector args; Value ptr = genSubscript(codegen, builder, op, t, args); @@ -937,24 +984,41 @@ /// Recursively generates tensor expression. static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter, - linalg::GenericOp op, unsigned exp, unsigned ldx) { + linalg::GenericOp op, unsigned exp, unsigned ldx, + unsigned last = 0) { Location loc = op.getLoc(); if (exp == -1u) return Value(); - if (merger.exp(exp).kind == Kind::kTensor) - return genTensorLoad(merger, codegen, rewriter, op, exp); + if (merger.exp(exp).kind == Kind::kTensor) { + // Handle reductions for access pattern expansion. The trigger is when the + // output tensor is also an operand parameter, although this can also apply + // to inplace updates which are not reductions. To avoid errors, check that + // the operation is a permitted reduction. + bool validRed = isValidReduction(merger.exp(last).kind); + OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; + OpOperand *lhs = op.getOutputOperand(0); + if (validRed && lhs == t) { + codegen.redKind = getReduction(merger.exp(last).kind); + codegen.redExp = last; // handling for reduction identity + } + Value redVal = genTensorLoad(merger, codegen, rewriter, op, exp); + if (validRed && lhs == t) + codegen.redExp = exp; + return redVal; + } if (merger.exp(exp).kind == Kind::kInvariant) return genInvariantValue(merger, codegen, rewriter, exp); if (merger.exp(exp).kind == Kind::kIndex) return genIndexValue(codegen, rewriter, merger.exp(exp).index, ldx); Value v0 = - genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx); + genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx, exp); Value v1 = - genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1, ldx); + genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1, ldx, exp); Value ee = merger.buildExp(rewriter, loc, exp, v0, v1); if (ee && (merger.exp(exp).kind == Kind::kUnary || merger.exp(exp).kind == Kind::kBinary || - merger.exp(exp).kind == Kind::kBinaryBranch)) + merger.exp(exp).kind == Kind::kBinaryBranch || + merger.exp(exp).kind == Kind::kReduce)) ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx); return ee; } @@ -983,7 +1047,7 @@ /// Hoists loop invariant tensor loads for which indices have been exhausted. static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder, linalg::GenericOp op, unsigned exp, unsigned ldx, - bool atStart, Kind last = Kind::kTensor) { + bool atStart, unsigned last = 0) { if (exp == -1u) return; if (merger.exp(exp).kind == Kind::kTensor) { @@ -1004,8 +1068,9 @@ if (lhs == t) { // Start or end a scalarized reduction if (atStart) { + codegen.redKind = getReduction(merger.exp(last).kind); + codegen.redExp = last; // handling for reduction identity Value load = genTensorLoad(merger, codegen, builder, op, exp); - codegen.redKind = getReduction(last); codegen.redExp = exp; updateReduc(merger, codegen, load); } else { @@ -1025,11 +1090,10 @@ // Traverse into the binary operations. Note that we only hoist // tensor loads, since subsequent MLIR/LLVM passes know how to // deal with all other kinds of derived loop invariants. - Kind last = merger.exp(exp).kind; unsigned e0 = merger.exp(exp).children.e0; unsigned e1 = merger.exp(exp).children.e1; - genInvariants(merger, codegen, builder, op, e0, ldx, atStart, last); - genInvariants(merger, codegen, builder, op, e1, ldx, atStart, last); + genInvariants(merger, codegen, builder, op, e0, ldx, atStart, exp); + genInvariants(merger, codegen, builder, op, e1, ldx, atStart, exp); } } diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -113,6 +113,7 @@ children.e1 = y; break; case kBinary: + case kReduce: assert(x != -1u && y != -1u && !v && o); children.e0 = x; children.e1 = y; @@ -375,6 +376,7 @@ case kOrI: case kXorI: case kBinary: + case kReduce: return false; } llvm_unreachable("unexpected kind"); @@ -474,6 +476,8 @@ return "<<"; case kBinary: return "binary"; + case kReduce: + return "reduce"; } llvm_unreachable("unexpected kind for symbol"); } @@ -551,6 +555,7 @@ case kShrU: case kShlI: case kBinary: + case kReduce: llvm::dbgs() << "("; dumpExp(tensorExps[e].children.e0); llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " "; @@ -793,10 +798,33 @@ kBinaryBranch, leftYield, includeRight, kBinaryBranch, rightYield); } + case kReduce: + // A custom reduce operation. + return takeConj(kind, buildLattices(tensorExps[e].children.e0, i), + buildLattices(tensorExps[e].children.e1, i), + tensorExps[e].op); } llvm_unreachable("unexpected expression kind"); } +Value Merger::getReductionIdentity(OpBuilder &builder, Location loc, unsigned e, + Type tp) { + Kind kind = tensorExps[e].kind; + switch (kind) { + case kMulF: + return builder.create(loc, tp, + builder.getFloatAttr(tp, 1.0)); + case kMulI: + case kAndI: + return builder.create(loc, tp, + builder.getIntegerAttr(tp, 1)); + case kReduce: + return dyn_cast(tensorExps[e].op).identity(); + default: + return builder.create(loc, tp, builder.getZeroAttr(tp)); + } +} + Optional Merger::buildTensorExpFromLinalg(linalg::GenericOp op) { // Build the linalg semantics backward from yield. Operation *yield = op.region().front().getTerminator(); @@ -962,7 +990,7 @@ } // Construct binary operations if subexpressions can be built. // See buildLattices() for an explanation of rejecting certain - // division and shift operations + // division and shift operations. if (def->getNumOperands() == 2) { auto x = buildTensorExp(op, def->getOperand(0)); auto y = buildTensorExp(op, def->getOperand(1)); @@ -1017,6 +1045,21 @@ } } } + // Construct ternary operations if subexpressions can be built. + if (def->getNumOperands() == 3) { + auto x = buildTensorExp(op, def->getOperand(0)); + auto y = buildTensorExp(op, def->getOperand(1)); + auto z = buildTensorExp(op, def->getOperand(2)); + if (x.hasValue() && y.hasValue() && z.hasValue()) { + unsigned e0 = x.getValue(); + unsigned e1 = y.getValue(); + // unsigned e2 = z.getValue(); + if (auto redop = dyn_cast(def)) { + if (isAdmissableBranch(redop, redop.region())) + return addExp(kReduce, e0, e1, Value(), def); + } + } + } // Cannot build. return None; } @@ -1066,6 +1109,13 @@ return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1}); } +static Value buildReduce(RewriterBase &rewriter, Location loc, Operation *op, + Value v0, Value v1) { + ReduceOp redop = cast(op); + Region &formula = redop.region(); + return insertYieldOp(rewriter, loc, formula, {v0, v1}); +} + Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e, Value v0, Value v1) { switch (tensorExps[e].kind) { @@ -1194,6 +1244,8 @@ return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0); case kBinary: return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1); + case kReduce: + return buildReduce(rewriter, loc, tensorExps[e].op, v0, v1); } llvm_unreachable("unexpected expression kind in build"); } diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -253,6 +253,20 @@ // ----- +func.func @invalid_binary_wrong_yield(%arg0: f64, %arg1: f64) -> f64 { + // expected-error@+1 {{left region must end with sparse_tensor.yield}} + %0 = sparse_tensor.binary %arg0, %arg1 : f64, f64 to f64 + overlap={} + left={ + ^bb0(%x: f64): + tensor.yield %x : f64 + } + right=identity + return %0 : f64 +} + +// ----- + func.func @invalid_unary_argtype_mismatch(%arg0: f64) -> f64 { // expected-error@+1 {{present region argument 1 type mismatch}} %r = sparse_tensor.unary %arg0 : f64 to f64 @@ -290,3 +304,67 @@ absent={} return %0 : f64 } + +// ----- + +func.func @invalid_unary_wrong_yield(%arg0: f64) -> f64 { + // expected-error@+1 {{present region must end with sparse_tensor.yield}} + %0 = sparse_tensor.unary %arg0 : f64 to f64 + present={ + ^bb0(%x: f64): + tensor.yield %x : f64 + } + absent={} + return %0 : f64 +} + +// ----- + +func.func @invalid_reduce_num_args_mismatch(%arg0: f64, %arg1: f64) -> f64 { + %cf1 = arith.constant 1.0 : f64 + // expected-error@+1 {{reduce region must have exactly 2 arguments}} + %r = sparse_tensor.reduce %arg0, %arg1, %cf1 : f64 { + ^bb0(%x: f64): + sparse_tensor.yield %x : f64 + } + return %r : f64 +} + +// ----- + +func.func @invalid_reduce_block_arg_type_mismatch(%arg0: i64, %arg1: i64) -> i64 { + %ci1 = arith.constant 1 : i64 + // expected-error@+1 {{reduce region argument 1 type mismatch}} + %r = sparse_tensor.reduce %arg0, %arg1, %ci1 : i64 { + ^bb0(%x: f64, %y: f64): + %cst = arith.constant 2 : i64 + sparse_tensor.yield %cst : i64 + } + return %r : i64 +} + +// ----- + +func.func @invalid_reduce_return_type_mismatch(%arg0: f64, %arg1: f64) -> f64 { + %cf1 = arith.constant 1.0 : f64 + // expected-error@+1 {{reduce region yield type mismatch}} + %r = sparse_tensor.reduce %arg0, %arg1, %cf1 : f64 { + ^bb0(%x: f64, %y: f64): + %cst = arith.constant 2 : i64 + sparse_tensor.yield %cst : i64 + } + return %r : f64 +} + +// ----- + +func.func @invalid_reduce_wrong_yield(%arg0: f64, %arg1: f64) -> f64 { + %cf1 = arith.constant 1.0 : f64 + // expected-error@+1 {{reduce region must end with sparse_tensor.yield}} + %r = sparse_tensor.reduce %arg0, %arg1, %cf1 : f64 { + ^bb0(%x: f64, %y: f64): + %cst = arith.constant 2 : i64 + tensor.yield %cst : i64 + } + return %r : f64 +} diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -268,3 +268,25 @@ absent={} return %r : i64 } + +// ----- + +#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> + +// CHECK-LABEL: func @sparse_reduce_2d_to_1d( +// CHECK-SAME: %[[A:.*]]: f64, %[[B:.*]]: f64) -> f64 { +// CHECK: %[[Z:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK: %[[C1:.*]] = sparse_tensor.reduce %[[A]], %[[B]], %[[Z]] : f64 { +// CHECK: ^bb0(%[[A1:.*]]: f64, %[[B1:.*]]: f64): +// CHECK: sparse_tensor.yield %[[A1]] : f64 +// CHECK: } +// CHECK: return %[[C1]] : f64 +// CHECK: } +func.func @sparse_reduce_2d_to_1d(%arg0: f64, %arg1: f64) -> f64 { + %cf0 = arith.constant 0.0 : f64 + %r = sparse_tensor.reduce %arg0, %arg1, %cf0 : f64 { + ^bb0(%x: f64, %y: f64): + sparse_tensor.yield %x : f64 + } + return %r : f64 +} diff --git a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_kernels.mlir @@ -57,37 +57,38 @@ // // Computes C = A x B with all matrices sparse (SpMSpM) in DCSR. // -// CHECK-LABEL: func @matmul2( +// CHECK-LABEL: func.func @matmul2( // CHECK-SAME: %[[VAL_0:.*]]: tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> { -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 2 : index +// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>>) -> tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>> { +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f64 // CHECK-DAG: %[[VAL_6:.*]] = arith.constant false // CHECK-DAG: %[[VAL_7:.*]] = arith.constant true // CHECK: %[[VAL_8:.*]] = bufferization.alloc_tensor() : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref // CHECK: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK: %[[VAL_14:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK: %[[VAL_15:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK: %[[VAL_16:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_4]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK: %[[VAL_17:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_4]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_14:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_2]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_15:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_2]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_16:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref +// CHECK: %[[VAL_17:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref // CHECK: %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<8x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref -// CHECK: %[[VAL_19:.*]] = memref.alloca(%[[VAL_5]]) : memref -// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref -// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref -// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_20]] to %[[VAL_21]] step %[[VAL_4]] { +// CHECK: %[[VAL_19:.*]] = memref.alloca(%[[VAL_4]]) : memref +// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_2]]] : memref +// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_3]]] : memref +// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_20]] to %[[VAL_21]] step %[[VAL_3]] { // CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_22]]] : memref -// CHECK: memref.store %[[VAL_23]], %[[VAL_19]]{{\[}}%[[VAL_3]]] : memref +// CHECK: memref.store %[[VAL_23]], %[[VAL_19]]{{\[}}%[[VAL_2]]] : memref // CHECK: %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.expand %[[VAL_8]] : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref, memref, memref, index // CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]]] : memref -// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_22]], %[[VAL_4]] : index +// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_22]], %[[VAL_3]] : index // CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref -// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_3]]] : memref -// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_4]]] : memref +// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_2]]] : memref +// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_3]]] : memref // CHECK: %[[VAL_33:.*]]:3 = scf.while (%[[VAL_34:.*]] = %[[VAL_28]], %[[VAL_35:.*]] = %[[VAL_31]], %[[VAL_36:.*]] = %[[VAL_27]]) : (index, index, index) -> (index, index, index) { // CHECK: %[[VAL_37:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_30]] : index // CHECK: %[[VAL_38:.*]] = arith.cmpi ult, %[[VAL_35]], %[[VAL_32]] : index @@ -105,43 +106,49 @@ // CHECK: %[[VAL_50:.*]] = scf.if %[[VAL_49]] -> (index) { // CHECK: %[[VAL_51:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_40]]] : memref // CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_41]]] : memref -// CHECK: %[[VAL_53:.*]] = arith.addi %[[VAL_41]], %[[VAL_4]] : index +// CHECK: %[[VAL_53:.*]] = arith.addi %[[VAL_41]], %[[VAL_3]] : index // CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_53]]] : memref -// CHECK: %[[VAL_55:.*]] = scf.for %[[VAL_56:.*]] = %[[VAL_52]] to %[[VAL_54]] step %[[VAL_4]] iter_args(%[[VAL_57:.*]] = %[[VAL_42]]) -> (index) { +// CHECK: %[[VAL_55:.*]] = scf.for %[[VAL_56:.*]] = %[[VAL_52]] to %[[VAL_54]] step %[[VAL_3]] iter_args(%[[VAL_57:.*]] = %[[VAL_42]]) -> (index) { // CHECK: %[[VAL_58:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_56]]] : memref -// CHECK: %[[VAL_59:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_58]]] : memref -// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_56]]] : memref -// CHECK: %[[VAL_61:.*]] = arith.mulf %[[VAL_51]], %[[VAL_60]] : f64 -// CHECK: %[[VAL_62:.*]] = arith.addf %[[VAL_59]], %[[VAL_61]] : f64 -// CHECK: %[[VAL_63:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_58]]] : memref -// CHECK: %[[VAL_64:.*]] = arith.cmpi eq, %[[VAL_63]], %[[VAL_6]] : i1 -// CHECK: %[[VAL_65:.*]] = scf.if %[[VAL_64]] -> (index) { +// CHECK: %[[VAL_59:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_58]]] : memref +// CHECK: %[[VAL_60:.*]] = scf.if %[[VAL_59]] -> (f64) { +// CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_24]]{{\[}}%[[VAL_58]]] : memref +// CHECK: scf.yield %[[VAL_61]] : f64 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_5]] : f64 +// CHECK: } +// CHECK: %[[VAL_62:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_56]]] : memref +// CHECK: %[[VAL_63:.*]] = arith.mulf %[[VAL_51]], %[[VAL_62]] : f64 +// CHECK: %[[VAL_64:.*]] = arith.addf %[[VAL_65:.*]], %[[VAL_63]] : f64 +// CHECK: %[[VAL_66:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_58]]] : memref +// CHECK: %[[VAL_67:.*]] = arith.cmpi eq, %[[VAL_66]], %[[VAL_6]] : i1 +// CHECK: %[[VAL_68:.*]] = scf.if %[[VAL_67]] -> (index) { // CHECK: memref.store %[[VAL_7]], %[[VAL_25]]{{\[}}%[[VAL_58]]] : memref // CHECK: memref.store %[[VAL_58]], %[[VAL_26]]{{\[}}%[[VAL_57]]] : memref -// CHECK: %[[VAL_66:.*]] = arith.addi %[[VAL_57]], %[[VAL_4]] : index -// CHECK: scf.yield %[[VAL_66]] : index +// CHECK: %[[VAL_69:.*]] = arith.addi %[[VAL_57]], %[[VAL_3]] : index +// CHECK: scf.yield %[[VAL_69]] : index // CHECK: } else { // CHECK: scf.yield %[[VAL_57]] : index // CHECK: } -// CHECK: memref.store %[[VAL_62]], %[[VAL_24]]{{\[}}%[[VAL_58]]] : memref -// CHECK: scf.yield %[[VAL_67:.*]] : index +// CHECK: memref.store %[[VAL_64]], %[[VAL_24]]{{\[}}%[[VAL_58]]] : memref +// CHECK: scf.yield %[[VAL_70:.*]] : index // CHECK: } -// CHECK: scf.yield %[[VAL_68:.*]] : index +// CHECK: scf.yield %[[VAL_71:.*]] : index // CHECK: } else { // CHECK: scf.yield %[[VAL_42]] : index // CHECK: } -// CHECK: %[[VAL_69:.*]] = arith.cmpi eq, %[[VAL_43]], %[[VAL_46]] : index -// CHECK: %[[VAL_70:.*]] = arith.addi %[[VAL_40]], %[[VAL_4]] : index -// CHECK: %[[VAL_71:.*]] = arith.select %[[VAL_69]], %[[VAL_70]], %[[VAL_40]] : index -// CHECK: %[[VAL_72:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_46]] : index -// CHECK: %[[VAL_73:.*]] = arith.addi %[[VAL_41]], %[[VAL_4]] : index -// CHECK: %[[VAL_74:.*]] = arith.select %[[VAL_72]], %[[VAL_73]], %[[VAL_41]] : index -// CHECK: scf.yield %[[VAL_71]], %[[VAL_74]], %[[VAL_75:.*]] : index, index, index +// CHECK: %[[VAL_72:.*]] = arith.cmpi eq, %[[VAL_43]], %[[VAL_46]] : index +// CHECK: %[[VAL_73:.*]] = arith.addi %[[VAL_40]], %[[VAL_3]] : index +// CHECK: %[[VAL_74:.*]] = arith.select %[[VAL_72]], %[[VAL_73]], %[[VAL_40]] : index +// CHECK: %[[VAL_75:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_46]] : index +// CHECK: %[[VAL_76:.*]] = arith.addi %[[VAL_41]], %[[VAL_3]] : index +// CHECK: %[[VAL_77:.*]] = arith.select %[[VAL_75]], %[[VAL_76]], %[[VAL_41]] : index +// CHECK: scf.yield %[[VAL_74]], %[[VAL_77]], %[[VAL_78:.*]] : index, index, index // CHECK: } -// CHECK: sparse_tensor.compress %[[VAL_8]], %[[VAL_19]], %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_76:.*]]#2 : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>>, memref, memref, memref, memref, index +// CHECK: sparse_tensor.compress %[[VAL_8]], %[[VAL_19]], %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_79:.*]]#2 : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>>, memref, memref, memref, memref, index // CHECK: } -// CHECK: %[[VAL_77:.*]] = sparse_tensor.load %[[VAL_8]] hasInserts : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>> -// CHECK: return %[[VAL_77]] : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: %[[VAL_80:.*]] = sparse_tensor.load %[[VAL_8]] hasInserts : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: return %[[VAL_80]] : tensor<4x4xf64, #sparse_tensor.encoding<{{{.*}}}>> // CHECK: } func.func @matmul2(%A: tensor<4x8xf64, #DCSR>, %B: tensor<8x4xf64, #DCSR>) -> tensor<4x4xf64, #DCSR> { diff --git a/mlir/test/Dialect/SparseTensor/sparse_out.mlir b/mlir/test/Dialect/SparseTensor/sparse_out.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_out.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_out.mlir @@ -312,93 +312,100 @@ doc = "C(i,j) = SUM_k A(i,k) * B(k,j)" } -// CHECK-LABEL: func @matmat( +// CHECK-LABEL: func.func @matmat( // CHECK-SAME: %[[VAL_0:.*]]: tensor>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor> { +// CHECK-SAME: %[[VAL_1:.*]]: tensor>) -> tensor> { // CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant false -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant true -// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor> -// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[VAL_3]] : tensor> -// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor(%[[VAL_7]], %[[VAL_8]]) : tensor> -// CHECK: %[[VAL_10:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor> to memref -// CHECK: %[[VAL_11:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor> to memref -// CHECK: %[[VAL_12:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor> to memref -// CHECK: %[[VAL_13:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor> to memref -// CHECK: %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor> to memref -// CHECK: %[[VAL_15:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_2]] : tensor> to memref -// CHECK: %[[VAL_16:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_2]] : tensor> to memref -// CHECK: %[[VAL_17:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor> to memref -// CHECK: %[[VAL_18:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor> to memref -// CHECK: %[[VAL_19:.*]] = sparse_tensor.values %[[VAL_1]] : tensor> to memref -// CHECK: %[[VAL_20:.*]] = memref.alloca(%[[VAL_4]]) : memref -// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_2]]] : memref -// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_3]]] : memref -// CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_21]] to %[[VAL_22]] step %[[VAL_3]] { -// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]]] : memref -// CHECK: memref.store %[[VAL_24]], %[[VAL_20]]{{\[}}%[[VAL_2]]] : memref -// CHECK: %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]], %[[VAL_28:.*]] = sparse_tensor.expand %[[VAL_9]] : tensor> to memref, memref, memref, index -// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_23]]] : memref -// CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_23]], %[[VAL_3]] : index -// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_30]]] : memref -// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_2]]] : memref -// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_3]]] : memref -// CHECK: %[[VAL_34:.*]]:3 = scf.while (%[[VAL_35:.*]] = %[[VAL_29]], %[[VAL_36:.*]] = %[[VAL_32]], %[[VAL_37:.*]] = %[[VAL_28]]) : (index, index, index) -> (index, index, index) { -// CHECK: %[[VAL_38:.*]] = arith.cmpi ult, %[[VAL_35]], %[[VAL_31]] : index -// CHECK: %[[VAL_39:.*]] = arith.cmpi ult, %[[VAL_36]], %[[VAL_33]] : index -// CHECK: %[[VAL_40:.*]] = arith.andi %[[VAL_38]], %[[VAL_39]] : i1 -// CHECK: scf.condition(%[[VAL_40]]) %[[VAL_35]], %[[VAL_36]], %[[VAL_37]] : index, index, index +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant false +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant true +// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor> +// CHECK: %[[VAL_9:.*]] = tensor.dim %[[VAL_1]], %[[VAL_3]] : tensor> +// CHECK: %[[VAL_10:.*]] = bufferization.alloc_tensor(%[[VAL_8]], %[[VAL_9]]) : tensor> +// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_2]] : tensor> to memref +// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_2]] : tensor> to memref +// CHECK: %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor> to memref +// CHECK: %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor> to memref +// CHECK: %[[VAL_15:.*]] = sparse_tensor.values %[[VAL_0]] : tensor> to memref +// CHECK: %[[VAL_16:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_2]] : tensor> to memref +// CHECK: %[[VAL_17:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_2]] : tensor> to memref +// CHECK: %[[VAL_18:.*]] = sparse_tensor.pointers %[[VAL_1]], %[[VAL_3]] : tensor> to memref +// CHECK: %[[VAL_19:.*]] = sparse_tensor.indices %[[VAL_1]], %[[VAL_3]] : tensor> to memref +// CHECK: %[[VAL_20:.*]] = sparse_tensor.values %[[VAL_1]] : tensor> to memref +// CHECK: %[[VAL_21:.*]] = memref.alloca(%[[VAL_4]]) : memref +// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_2]]] : memref +// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_3]]] : memref +// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_22]] to %[[VAL_23]] step %[[VAL_3]] { +// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref +// CHECK: memref.store %[[VAL_25]], %[[VAL_21]]{{\[}}%[[VAL_2]]] : memref +// CHECK: %[[VAL_26:.*]], %[[VAL_27:.*]], %[[VAL_28:.*]], %[[VAL_29:.*]] = sparse_tensor.expand %[[VAL_10]] : tensor> to memref, memref, memref, index +// CHECK: %[[VAL_30:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref +// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_24]], %[[VAL_3]] : index +// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_31]]] : memref +// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_2]]] : memref +// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_35:.*]]:3 = scf.while (%[[VAL_36:.*]] = %[[VAL_30]], %[[VAL_37:.*]] = %[[VAL_33]], %[[VAL_38:.*]] = %[[VAL_29]]) : (index, index, index) -> (index, index, index) { +// CHECK: %[[VAL_39:.*]] = arith.cmpi ult, %[[VAL_36]], %[[VAL_32]] : index +// CHECK: %[[VAL_40:.*]] = arith.cmpi ult, %[[VAL_37]], %[[VAL_34]] : index +// CHECK: %[[VAL_41:.*]] = arith.andi %[[VAL_39]], %[[VAL_40]] : i1 +// CHECK: scf.condition(%[[VAL_41]]) %[[VAL_36]], %[[VAL_37]], %[[VAL_38]] : index, index, index // CHECK: } do { -// CHECK: ^bb0(%[[VAL_41:.*]]: index, %[[VAL_42:.*]]: index, %[[VAL_43:.*]]: index): -// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_41]]] : memref -// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_42]]] : memref -// CHECK: %[[VAL_46:.*]] = arith.cmpi ult, %[[VAL_45]], %[[VAL_44]] : index -// CHECK: %[[VAL_47:.*]] = arith.select %[[VAL_46]], %[[VAL_45]], %[[VAL_44]] : index -// CHECK: %[[VAL_48:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_47]] : index -// CHECK: %[[VAL_49:.*]] = arith.cmpi eq, %[[VAL_45]], %[[VAL_47]] : index -// CHECK: %[[VAL_50:.*]] = arith.andi %[[VAL_48]], %[[VAL_49]] : i1 -// CHECK: %[[VAL_51:.*]] = scf.if %[[VAL_50]] -> (index) { -// CHECK: %[[VAL_52:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_41]]] : memref -// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_42]]] : memref -// CHECK: %[[VAL_54:.*]] = arith.addi %[[VAL_42]], %[[VAL_3]] : index -// CHECK: %[[VAL_55:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_54]]] : memref -// CHECK: %[[VAL_56:.*]] = scf.for %[[VAL_57:.*]] = %[[VAL_53]] to %[[VAL_55]] step %[[VAL_3]] iter_args(%[[VAL_58:.*]] = %[[VAL_43]]) -> (index) { -// CHECK: %[[VAL_59:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_57]]] : memref -// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_25]]{{\[}}%[[VAL_59]]] : memref -// CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_57]]] : memref -// CHECK: %[[VAL_62:.*]] = arith.mulf %[[VAL_52]], %[[VAL_61]] : f32 -// CHECK: %[[VAL_63:.*]] = arith.addf %[[VAL_60]], %[[VAL_62]] : f32 -// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_59]]] : memref -// CHECK: %[[VAL_65:.*]] = arith.cmpi eq, %[[VAL_64]], %[[VAL_5]] : i1 -// CHECK: %[[VAL_66:.*]] = scf.if %[[VAL_65]] -> (index) { -// CHECK: memref.store %[[VAL_6]], %[[VAL_26]]{{\[}}%[[VAL_59]]] : memref -// CHECK: memref.store %[[VAL_59]], %[[VAL_27]]{{\[}}%[[VAL_58]]] : memref -// CHECK: %[[VAL_67:.*]] = arith.addi %[[VAL_58]], %[[VAL_3]] : index -// CHECK: scf.yield %[[VAL_67]] : index +// CHECK: ^bb0(%[[VAL_42:.*]]: index, %[[VAL_43:.*]]: index, %[[VAL_44:.*]]: index): +// CHECK: %[[VAL_45:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_42]]] : memref +// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_43]]] : memref +// CHECK: %[[VAL_47:.*]] = arith.cmpi ult, %[[VAL_46]], %[[VAL_45]] : index +// CHECK: %[[VAL_48:.*]] = arith.select %[[VAL_47]], %[[VAL_46]], %[[VAL_45]] : index +// CHECK: %[[VAL_49:.*]] = arith.cmpi eq, %[[VAL_45]], %[[VAL_48]] : index +// CHECK: %[[VAL_50:.*]] = arith.cmpi eq, %[[VAL_46]], %[[VAL_48]] : index +// CHECK: %[[VAL_51:.*]] = arith.andi %[[VAL_49]], %[[VAL_50]] : i1 +// CHECK: %[[VAL_52:.*]] = scf.if %[[VAL_51]] -> (index) { +// CHECK: %[[VAL_53:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_42]]] : memref +// CHECK: %[[VAL_54:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_43]]] : memref +// CHECK: %[[VAL_55:.*]] = arith.addi %[[VAL_43]], %[[VAL_3]] : index +// CHECK: %[[VAL_56:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_55]]] : memref +// CHECK: %[[VAL_57:.*]] = scf.for %[[VAL_58:.*]] = %[[VAL_54]] to %[[VAL_56]] step %[[VAL_3]] iter_args(%[[VAL_59:.*]] = %[[VAL_44]]) -> (index) { +// CHECK: %[[VAL_60:.*]] = memref.load %[[VAL_19]]{{\[}}%[[VAL_58]]] : memref +// CHECK: %[[VAL_61:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_60]]] : memref +// CHECK: %[[VAL_62:.*]] = scf.if %[[VAL_61]] -> (f32) { +// CHECK: %[[VAL_63:.*]] = memref.load %[[VAL_26]]{{\[}}%[[VAL_60]]] : memref +// CHECK: scf.yield %[[VAL_63]] : f32 // CHECK: } else { -// CHECK: scf.yield %[[VAL_58]] : index +// CHECK: scf.yield %[[VAL_5]] : f32 // CHECK: } -// CHECK: memref.store %[[VAL_63]], %[[VAL_25]]{{\[}}%[[VAL_59]]] : memref -// CHECK: scf.yield %[[VAL_68:.*]] : index +// CHECK: %[[VAL_64:.*]] = memref.load %[[VAL_20]]{{\[}}%[[VAL_58]]] : memref +// CHECK: %[[VAL_65:.*]] = arith.mulf %[[VAL_53]], %[[VAL_64]] : f32 +// CHECK: %[[VAL_66:.*]] = arith.addf %[[VAL_67:.*]], %[[VAL_65]] : f32 +// CHECK: %[[VAL_68:.*]] = memref.load %[[VAL_27]]{{\[}}%[[VAL_60]]] : memref +// CHECK: %[[VAL_69:.*]] = arith.cmpi eq, %[[VAL_68]], %[[VAL_6]] : i1 +// CHECK: %[[VAL_70:.*]] = scf.if %[[VAL_69]] -> (index) { +// CHECK: memref.store %[[VAL_7]], %[[VAL_27]]{{\[}}%[[VAL_60]]] : memref +// CHECK: memref.store %[[VAL_60]], %[[VAL_28]]{{\[}}%[[VAL_59]]] : memref +// CHECK: %[[VAL_71:.*]] = arith.addi %[[VAL_59]], %[[VAL_3]] : index +// CHECK: scf.yield %[[VAL_71]] : index +// CHECK: } else { +// CHECK: scf.yield %[[VAL_59]] : index +// CHECK: } +// CHECK: memref.store %[[VAL_66]], %[[VAL_26]]{{\[}}%[[VAL_60]]] : memref +// CHECK: scf.yield %[[VAL_72:.*]] : index // CHECK: } -// CHECK: scf.yield %[[VAL_69:.*]] : index +// CHECK: scf.yield %[[VAL_73:.*]] : index // CHECK: } else { -// CHECK: scf.yield %[[VAL_43]] : index +// CHECK: scf.yield %[[VAL_44]] : index // CHECK: } -// CHECK: %[[VAL_70:.*]] = arith.cmpi eq, %[[VAL_44]], %[[VAL_47]] : index -// CHECK: %[[VAL_71:.*]] = arith.addi %[[VAL_41]], %[[VAL_3]] : index -// CHECK: %[[VAL_72:.*]] = arith.select %[[VAL_70]], %[[VAL_71]], %[[VAL_41]] : index -// CHECK: %[[VAL_73:.*]] = arith.cmpi eq, %[[VAL_45]], %[[VAL_47]] : index -// CHECK: %[[VAL_74:.*]] = arith.addi %[[VAL_42]], %[[VAL_3]] : index -// CHECK: %[[VAL_75:.*]] = arith.select %[[VAL_73]], %[[VAL_74]], %[[VAL_42]] : index -// CHECK: scf.yield %[[VAL_72]], %[[VAL_75]], %[[VAL_76:.*]] : index, index, index +// CHECK: %[[VAL_74:.*]] = arith.cmpi eq, %[[VAL_45]], %[[VAL_48]] : index +// CHECK: %[[VAL_75:.*]] = arith.addi %[[VAL_42]], %[[VAL_3]] : index +// CHECK: %[[VAL_76:.*]] = arith.select %[[VAL_74]], %[[VAL_75]], %[[VAL_42]] : index +// CHECK: %[[VAL_77:.*]] = arith.cmpi eq, %[[VAL_46]], %[[VAL_48]] : index +// CHECK: %[[VAL_78:.*]] = arith.addi %[[VAL_43]], %[[VAL_3]] : index +// CHECK: %[[VAL_79:.*]] = arith.select %[[VAL_77]], %[[VAL_78]], %[[VAL_43]] : index +// CHECK: scf.yield %[[VAL_76]], %[[VAL_79]], %[[VAL_80:.*]] : index, index, index // CHECK: } -// CHECK: sparse_tensor.compress %[[VAL_9]], %[[VAL_20]], %[[VAL_25]], %[[VAL_26]], %[[VAL_27]], %[[VAL_77:.*]]#2 : tensor>, memref, memref, memref, memref, index +// CHECK: sparse_tensor.compress %[[VAL_10]], %[[VAL_21]], %[[VAL_26]], %[[VAL_27]], %[[VAL_28]], %[[VAL_81:.*]]#2 : tensor>, memref, memref, memref, memref, index // CHECK: } -// CHECK: %[[VAL_78:.*]] = sparse_tensor.load %[[VAL_9]] hasInserts : tensor> -// CHECK: return %[[VAL_78]] : tensor> +// CHECK: %[[VAL_82:.*]] = sparse_tensor.load %[[VAL_10]] hasInserts : tensor> +// CHECK: return %[[VAL_82]] : tensor> // CHECK: } func.func @matmat(%arga: tensor, %argb: tensor) -> tensor { diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_out_reduction.mlir @@ -4,10 +4,19 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s +#SparseVector = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed" ] +}> + #SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> +#SparseCSCMatrix = #sparse_tensor.encoding<{ + dimLevelType = [ "compressed", "compressed" ], + dimOrdering = affine_map<(i,j) -> (j,i)> +}> + #SparseTensor = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }> @@ -22,6 +31,15 @@ doc = "X(i,j) = SUM_k A(i,j,k) * B(i,j,k)" } +#redprod = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (i)> // X (out) + ], + iterator_types = ["parallel", "reduction"], + doc = "X(i) = PROD_j A(i,j)" +} + module { func.func @redsum(%arga: tensor, %argb: tensor) @@ -43,6 +61,51 @@ return %0 : tensor } + func.func @redprod(%arga: tensor) -> tensor { + %c0 = arith.constant 0 : index + %d0 = tensor.dim %arga, %c0 : tensor + %xinit = bufferization.alloc_tensor(%d0): tensor + %0 = linalg.generic #redprod + ins(%arga: tensor) + outs(%xinit: tensor) { + ^bb(%a: i32, %x: i32): + %0 = arith.muli %x, %a : i32 + linalg.yield %0 : i32 + } -> tensor + return %0 : tensor + } + + func.func @redprod2(%arga: tensor) -> tensor { + %c0 = arith.constant 0 : index + %d0 = tensor.dim %arga, %c0 : tensor + %xinit = bufferization.alloc_tensor(%d0): tensor + %0 = linalg.generic #redprod + ins(%arga: tensor) + outs(%xinit: tensor) { + ^bb(%a: i32, %x: i32): + %0 = arith.muli %x, %a : i32 + linalg.yield %0 : i32 + } -> tensor + return %0 : tensor + } + + // Dumps a sparse vector. + func.func @dumpvec(%arg0: tensor) { + // Dump the values array to verify only sparse contents are stored. + %c0 = arith.constant 0 : index + %d0 = arith.constant -1 : i32 + %0 = sparse_tensor.values %arg0 : tensor to memref + %1 = vector.transfer_read %0[%c0], %d0: memref, vector<4xi32> + vector.print %1 : vector<4xi32> + // Dump the dense vector to verify structure is correct. + %dv = sparse_tensor.convert %arg0 : tensor to tensor + %2 = bufferization.to_memref %dv : memref + %3 = vector.transfer_read %2[%c0], %d0: memref, vector<4xi32> + vector.print %3 : vector<4xi32> + memref.dealloc %2 : memref + return + } + // Driver method to call and verify tensor kernel. func.func @entry() { %c0 = arith.constant 0 : index @@ -60,16 +123,33 @@ %st2 = sparse_tensor.convert %t2 : tensor<3x3x4xi32> to tensor + // Setup sparse 2-d tensors. + %m1 = arith.constant sparse< + [ [0, 3], [0, 4], [2, 3], [3, 0], [3, 2], [3, 4] ], [ 1, 2, 3, 4, 5, 6 ] + >: tensor<4x5xi32> + %smr = sparse_tensor.convert %m1 + : tensor<4x5xi32> to tensor + %smc = sparse_tensor.convert %m1 + : tensor<4x5xi32> to tensor + // Call kernel. %0 = call @redsum(%st1, %st2) : (tensor, tensor) -> tensor + %1 = call @redprod(%smr) + : (tensor) -> tensor + %2 = call @redprod2(%smc) + : (tensor) -> tensor // // Verify results. Only two entries stored in result. Correct structure. // // CHECK: ( 7, 69, -1, -1 ) // CHECK-NEXT: ( ( 0, 0, 0 ), ( 0, 7, 0 ), ( 0, 0, 69 ) ) + // CHECK-NEXT: ( 2, 3, 120, -1 ) + // CHECK-NEXT: ( 2, 0, 3, 120 ) + // CHECK-NEXT: ( 2, 3, 120, -1 ) + // CHECK-NEXT: ( 2, 0, 3, 120 ) // %val = sparse_tensor.values %0 : tensor to memref @@ -80,11 +160,15 @@ %db = bufferization.to_memref %dm : memref %vm = vector.transfer_read %db[%c0, %c0], %i0: memref, vector<3x3xi32> vector.print %vm : vector<3x3xi32> + call @dumpvec(%1) : (tensor) -> () + call @dumpvec(%2) : (tensor) -> () // Release the resources. sparse_tensor.release %st1 : tensor sparse_tensor.release %st2 : tensor sparse_tensor.release %0 : tensor + sparse_tensor.release %1 : tensor + sparse_tensor.release %2 : tensor memref.dealloc %db : memref return } diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reduce_custom.mlir @@ -0,0 +1,184 @@ +// RUN: mlir-opt %s --sparse-compiler | \ +// RUN: mlir-cpu-runner \ +// RUN: -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}> +#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}> +#CSC = #sparse_tensor.encoding<{ + dimLevelType = [ "dense", "compressed" ], + dimOrdering = affine_map<(i,j) -> (j,i)> +}> + +// +// Traits for tensor operations. +// +#trait_matmul = { + indexing_maps = [ + affine_map<(i,j,k) -> (i,k)>, // A + affine_map<(i,j,k) -> (k,j)>, // B + affine_map<(i,j,k) -> (i,j)> // C (out) + ], + iterator_types = ["parallel", "parallel", "reduction"], + doc = "C(i,j) = SUM_k A(i,k) * B(k,j)" +} + +#trait_mat_reduce = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A (in) + affine_map<(i,j) -> (i)> // X (out) + ], + iterator_types = ["parallel", "reduce"] +} + +module { + // Creates a new sparse vector using the minimum values from two input sparse vectors. + // When there is no overlap, include the present value in the output. + func.func @min_plus_csrcsr(%arga: tensor, + %argb: tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %maxf = arith.constant 1.0e999 : f64 + %d0 = tensor.dim %arga, %c0 : tensor + %d1 = tensor.dim %argb, %c1 : tensor + %xm = bufferization.alloc_tensor(%d0, %d1) : tensor + %0 = linalg.generic #trait_matmul + ins(%arga, %argb: tensor, tensor) + outs(%xm: tensor) { + ^bb(%a: f64, %b: f64, %output: f64): + %1 = sparse_tensor.binary %a, %b : f64, f64 to f64 + overlap = { + ^bb0(%x: f64, %y: f64): + %3 = arith.addf %x, %y : f64 + sparse_tensor.yield %3 : f64 + } + left={} + right={} + %2 = sparse_tensor.reduce %1, %output, %maxf : f64 { + ^bb0(%x: f64, %y: f64): + %cmp = arith.cmpf "olt", %x, %y : f64 + %3 = arith.select %cmp, %x, %y : f64 + sparse_tensor.yield %3 : f64 + } + linalg.yield %2 : f64 + } -> tensor + return %0 : tensor + } + + func.func @min_plus_csrcsc(%arga: tensor, + %argb: tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %maxf = arith.constant 1.0e999 : f64 + %d0 = tensor.dim %arga, %c0 : tensor + %d1 = tensor.dim %argb, %c1 : tensor + %xm = bufferization.alloc_tensor(%d0, %d1) : tensor + %0 = linalg.generic #trait_matmul + ins(%arga, %argb: tensor, tensor) + outs(%xm: tensor) { + ^bb(%a: f64, %b: f64, %output: f64): + %1 = sparse_tensor.binary %a, %b : f64, f64 to f64 + overlap = { + ^bb0(%x: f64, %y: f64): + %3 = arith.addf %x, %y : f64 + sparse_tensor.yield %3 : f64 + } + left={} + right={} + %2 = sparse_tensor.reduce %1, %output, %maxf : f64 { + ^bb0(%x: f64, %y: f64): + %cmp = arith.cmpf "olt", %x, %y : f64 + %3 = arith.select %cmp, %x, %y : f64 + sparse_tensor.yield %3 : f64 + } + linalg.yield %2 : f64 + } -> tensor + return %0 : tensor + } + + // Dumps a sparse vector of type f64. + func.func @dump_vec(%arg0: tensor) { + // Dump the values array to verify only sparse contents are stored. + %c0 = arith.constant 0 : index + %d0 = arith.constant -1.0 : f64 + %0 = sparse_tensor.values %arg0 : tensor to memref + %1 = vector.transfer_read %0[%c0], %d0: memref, vector<16xf64> + vector.print %1 : vector<16xf64> + // Dump the dense vector to verify structure is correct. + %dv = sparse_tensor.convert %arg0 : tensor to tensor + %2 = bufferization.to_memref %dv : memref + %3 = vector.transfer_read %2[%c0], %d0: memref, vector<32xf64> + vector.print %3 : vector<32xf64> + memref.dealloc %2 : memref + return + } + + // Dump a sparse matrix. + func.func @dump_mat(%arg0: tensor) { + // Dump the values array to verify only sparse contents are stored. + %c0 = arith.constant 0 : index + %d0 = arith.constant -1.0 : f64 + %0 = sparse_tensor.values %arg0 : tensor to memref + %1 = vector.transfer_read %0[%c0], %d0: memref, vector<16xf64> + vector.print %1 : vector<16xf64> + %dm = sparse_tensor.convert %arg0 : tensor to tensor + %2 = bufferization.to_memref %dm : memref + %3 = vector.transfer_read %2[%c0, %c0], %d0: memref, vector<5x5xf64> + vector.print %3 : vector<5x5xf64> + memref.dealloc %2 : memref + return + } + + // Driver method to call and verify vector kernels. + func.func @entry() { + %c0 = arith.constant 0 : index + + // Setup sparse matrices. + %m1 = arith.constant sparse< + [ [0,0], [0,1], [1,0], [2,2], [2,3], [2,4], [3,0], [3,2], [3,3] ], + [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 ] + > : tensor<4x5xf64> + %m2 = arith.constant sparse< + [ [0,0], [1,3], [2,0], [2,3], [3,1], [4,1] ], + [6.0, 5.0, 4.0, 3.0, 2.0, 11.0 ] + > : tensor<5x4xf64> + %sm1 = sparse_tensor.convert %m1 : tensor<4x5xf64> to tensor + %sm2r = sparse_tensor.convert %m2 : tensor<5x4xf64> to tensor + %sm2c = sparse_tensor.convert %m2 : tensor<5x4xf64> to tensor + + // Call sparse matrix kernels. + %5 = call @min_plus_csrcsr(%sm1, %sm2r) + : (tensor, tensor) -> tensor + // COM: This is broken because the lex-insert version of matmul always inserts the + // COM: identity value, even if there is nothing to accumulate, resulting in a + // COM: dense output. + // COM: %6 = call @min_plus_csrcsc(%sm1, %sm2c) + // COM: : (tensor, tensor) -> tensor + + // + // Verify the results. + // + // CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( ( 1, 2, 0, 0, 0 ), ( 3, 0, 0, 0, 0 ), ( 0, 0, 4, 5, 6 ), ( 7, 0, 8, 9, 0 ), ( -1, -1, -1, -1, -1 ) ) + // CHECK-NEXT: ( 6, 5, 4, 3, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( ( 6, 0, 0, 0, -1 ), ( 0, 0, 0, 5, -1 ), ( 4, 0, 0, 3, -1 ), ( 0, 2, 0, 0, -1 ), ( 0, 11, 0, 0, -1 ) ) + // CHECK-NEXT: ( 7, 7, 9, 8, 7, 7, 12, 11, 11, -1, -1, -1, -1, -1, -1, -1 ) + // CHECK-NEXT: ( ( 7, 0, 0, 7, -1 ), ( 9, 0, 0, 0, -1 ), ( 8, 7, 0, 7, -1 ), ( 12, 11, 0, 11, -1 ), ( -1, -1, -1, -1, -1 ) ) + // COM: CHECK-NEXT: ( 7, 7, 9, 8, 7, 7, 12, 11, 11, -1, -1, -1, -1, -1, -1, -1 ) + // COM: CHECK-NEXT: ( ( 7, 0, 0, 7, -1 ), ( 9, 0, 0, 0, -1 ), ( 8, 7, 0, 7, -1 ), ( 12, 11, 0, 11, -1 ), ( -1, -1, -1, -1, -1 ) ) + // + call @dump_mat(%sm1) : (tensor) -> () + call @dump_mat(%sm2r) : (tensor) -> () + call @dump_mat(%5) : (tensor) -> () + // COM: call @dump_mat(%6) : (tensor) -> () + + // Release the resources. + sparse_tensor.release %sm1 : tensor + sparse_tensor.release %sm2r : tensor + sparse_tensor.release %sm2c : tensor + sparse_tensor.release %5 : tensor + // COM: sparse_tensor.release %6 : tensor + return + } +} diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp --- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp +++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp @@ -184,6 +184,7 @@ case kUnary: case kShlI: case kBinary: + case kReduce: return compareExpression(tensorExp.children.e0, pattern->e0); // Binary operations. case kMulF: