diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -101,7 +101,9 @@
 }

 /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi],
-/// where 'lo' denotes the current index and 'hi = lo + vl - 1'.
+/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
+/// that the sparse compiler can only generate indirect loads in
+/// the last index, i.e. back().
 static Value genVectorLoad(PatternRewriter &rewriter, Location loc, VL vl,
                            Value ptr, ArrayRef<Value> idxs, Value vmask) {
   VectorType vtp = vectorType(vl, ptr);
@@ -118,7 +120,9 @@
 }

 /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs
-/// where 'lo' denotes the current index and 'hi = lo + vl - 1'.
+/// where 'lo' denotes the current index and 'hi = lo + vl - 1'. Note
+/// that the sparse compiler can only generate indirect stores in
+/// the last index, i.e. back().
 static void genVectorStore(PatternRewriter &rewriter, Location loc, Value ptr,
                            ArrayRef<Value> idxs, Value vmask, Value rhs) {
   if (idxs.back().getType().isa<VectorType>()) {
@@ -132,32 +136,52 @@
   rewriter.create<vector::MaskedStoreOp>(loc, ptr, idxs, vmask, rhs);
 }

-/// Maps operation to combining kind for reduction.
-static vector::CombiningKind getCombiningKind(Operation *def) {
-  if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def) ||
-      isa<arith::SubFOp>(def) || isa<arith::SubIOp>(def))
-    return vector::CombiningKind::ADD;
-  if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
-    return vector::CombiningKind::MUL;
-  if (isa<arith::AndIOp>(def))
-    return vector::CombiningKind::AND;
-  if (isa<arith::OrIOp>(def))
-    return vector::CombiningKind::OR;
-  if (isa<arith::XOrIOp>(def))
-    return vector::CombiningKind::XOR;
-  llvm_unreachable("unknown reduction kind");
+/// Detects a vectorizable reduction operation and returns the
+/// combining kind of the reduction in `kind` on success.
+static bool isVectorizableReduction(Value red, Value iter,
+                                    vector::CombiningKind &kind) {
+  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
+    kind = vector::CombiningKind::ADD;
+    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
+  } else if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
+    kind = vector::CombiningKind::ADD;
+    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
+  } else if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
+    kind = vector::CombiningKind::ADD;
+    return subf->getOperand(0) == iter;
+  } else if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
+    kind = vector::CombiningKind::ADD;
+    return subi->getOperand(0) == iter;
+  } else if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
+    kind = vector::CombiningKind::MUL;
+    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
+  } else if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
+    kind = vector::CombiningKind::MUL;
+    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
+  } else if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
+    kind = vector::CombiningKind::AND;
+    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
+  } else if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
+    kind = vector::CombiningKind::OR;
+    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
+  } else if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
+    kind = vector::CombiningKind::XOR;
+    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
+  }
+  return false;
 }

 /// Generates an initial value for a vector reduction, following the scheme
 /// given in Chapter 5 of "The Software Vectorization Handbook", where the
 /// initial scalar value is correctly embedded in the vector reduction value,
 /// and a straightforward horizontal reduction will complete the operation.
-/// The value 'r' denotes the initial value of the accumulator. Value 'rd'
-/// denotes the accumulation operation, which is solely used here to determine
-/// the kind of combining reduction (viz. addf -> sum-accumulation).
+/// Value 'r' denotes the initial value of the reduction outside the loop.
 static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
-                                VectorType vtp, Value r, Value rd) {
-  vector::CombiningKind kind = getCombiningKind(rd.getDefiningOp());
+                                Value red, Value iter, Value r,
+                                VectorType vtp) {
+  vector::CombiningKind kind;
+  if (!isVectorizableReduction(red, iter, kind))
+    llvm_unreachable("unknown reduction");
   switch (kind) {
   case vector::CombiningKind::ADD:
   case vector::CombiningKind::XOR:
@@ -180,13 +204,6 @@
   llvm_unreachable("unknown reduction kind");
 }

-/// Generates final value for a vector reduction.
-static Value genVectorReducEnd(PatternRewriter &rewriter, Location loc,
-                               Value vexp, Value rd) {
-  vector::CombiningKind kind = getCombiningKind(rd.getDefiningOp());
-  return rewriter.create<vector::ReductionOp>(loc, kind, vexp);
-}
-
 /// This method is called twice to analyze and rewrite the given subscripts.
 /// The first call (!codegen) does the analysis. Then, on success, the second
 /// call (codegen) yields the proper vector form in the output parameter
@@ -379,10 +396,14 @@
     if (!yield.getResults().empty()) {
       Value init = forOp.getInitArgs()[0];
       VectorType vtp = vectorType(vl, init.getType());
-      Value vinit =
-          genVectorReducInit(rewriter, loc, vtp, init, yield->getOperand(0));
+      Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
+                                       forOp.getRegionIterArg(0), init, vtp);
       forOpNew = rewriter.create<scf::ForOp>(
           loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
+      forOpNew->setAttr(
+          SparseTensorLoopEmitter::getLoopEmitterLoopAttrName(),
+          forOp->getAttr(
+              SparseTensorLoopEmitter::getLoopEmitterLoopAttrName()));
       rewriter.setInsertionPointToStart(forOpNew.getBody());
     } else {
       forOp.setStep(step);
@@ -395,20 +416,22 @@
   // Sparse for-loops either are terminated by a non-empty yield operation
   // (reduction loop) or otherwise by a store operation (pararallel loop).
   if (!yield.getResults().empty()) {
+    // Analyze/vectorize reduction.
     if (yield->getNumOperands() != 1)
       return false;
-    Value redOp = yield->getOperand(0);
-    // Analyze/vectorize reduction.
-    // TODO: use linalg utils to verify the actual reduction?
+    Value red = yield->getOperand(0);
+    Value iter = forOp.getRegionIterArg(0);
+    vector::CombiningKind kind;
    Value vrhs;
-    if (vectorizeExpr(rewriter, forOp, vl, redOp, codegen, vmask, vrhs)) {
+    if (isVectorizableReduction(red, iter, kind) &&
+        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      if (codegen) {
-        Value vpass =
-            genVectorInvariantValue(rewriter, vl, forOp.getRegionIterArg(0));
+        Value partial = forOpNew.getResult(0);
+        Value vpass = genVectorInvariantValue(rewriter, vl, iter);
        Value vred = rewriter.create<arith::SelectOp>(loc, vmask, vrhs, vpass);
        rewriter.create<scf::YieldOp>(loc, vred);
        rewriter.setInsertionPointAfter(forOpNew);
-        Value vres = genVectorReducEnd(rewriter, loc, forOpNew.getResult(0), redOp);
+        Value vres = rewriter.create<vector::ReductionOp>(loc, kind, partial);
        // Now do some relinking (last one is not completely type safe
        // but all bad ones are removed right away). This also folds away
        // nop broadcast operations.
@@ -469,6 +492,32 @@
   const VL vl;
 };

+/// Reduction chain cleanup.
+///   v = for { }
+///   s = vsum(v)         v = for { }
+///   u = expand(s)  ->   for (v) { }
+///   for (u) { }
+template <typename VectorOp>
+struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
+public:
+  using OpRewritePattern<VectorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(VectorOp op,
+                                PatternRewriter &rewriter) const override {
+    Value inp = op.getSource();
+    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
+      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
+        if (forOp->hasAttr(
+                SparseTensorLoopEmitter::getLoopEmitterLoopAttrName())) {
+          rewriter.replaceOp(op, redOp.getVector());
+          return success();
+        }
+      }
+    }
+    return failure();
+  }
+};
+
 } // namespace

 //===----------------------------------------------------------------------===//
@@ -482,4 +531,6 @@
                                        bool enableSIMDIndex32) {
   patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                               enableVLAVectorization, enableSIMDIndex32);
+  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
+               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -0,0 +1,122 @@
+// RUN: mlir-opt %s -sparsification -cse -sparse-vectorization="vl=8" -cse | \
+// RUN:   FileCheck %s
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["dense","compressed"]}>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>, // a (in)
+    affine_map<(i,j) -> (i,j)>, // b (in)
+    affine_map<(i,j) -> ()>     // x (out)
+  ],
+  iterator_types = ["reduction", "reduction"]
+}
+
+//
+// Verifies that the SIMD reductions in the two for-loops after the
+// while-loop are chained before horizontally reducing these back to a scalar.
+//
+// CHECK-LABEL: func.func @sparse_matrix_sum(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<f64>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>) -> tensor<f64> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant dense<0.000000e+00> : vector<8xf64>
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 64 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_8:.*]] = sparse_tensor.pointers %[[VAL_1]] {dimension = 1 : index} : tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.indices %[[VAL_1]] {dimension = 1 : index} : tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xf64>
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_2]] {dimension = 1 : index} : tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_2]] {dimension = 1 : index} : tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<64x32xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> to memref<?xf64>
+// CHECK: %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_0]] : memref<f64>
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_14]][] : memref<f64>
+// CHECK: %[[VAL_16:.*]] = scf.for %[[VAL_17:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] iter_args(%[[VAL_18:.*]] = %[[VAL_15]]) -> (f64) {
+// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_17]], %[[VAL_7]] : index
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_17]]] : memref<?xindex>
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_20]]] : memref<?xindex>
+// CHECK: %[[VAL_24:.*]]:3 = scf.while (%[[VAL_25:.*]] = %[[VAL_19]], %[[VAL_26:.*]] = %[[VAL_22]], %[[VAL_27:.*]] = %[[VAL_18]]) : (index, index, f64) -> (index, index, f64) {
+// CHECK: %[[VAL_28:.*]] = arith.cmpi ult, %[[VAL_25]], %[[VAL_21]] : index
+// CHECK: %[[VAL_29:.*]] = arith.cmpi ult, %[[VAL_26]], %[[VAL_23]] : index
+// CHECK: %[[VAL_30:.*]] = arith.andi %[[VAL_28]], %[[VAL_29]] : i1
+// CHECK: scf.condition(%[[VAL_30]]) %[[VAL_25]], %[[VAL_26]], %[[VAL_27]] : index, index, f64
+// CHECK: } do {
+// CHECK: ^bb0(%[[VAL_31:.*]]: index, %[[VAL_32:.*]]: index, %[[VAL_33:.*]]: f64):
+// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_31]]] : memref<?xindex>
+// CHECK: %[[VAL_35:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<?xindex>
+// CHECK: %[[VAL_36:.*]] = arith.cmpi ult, %[[VAL_35]], %[[VAL_34]] : index
+// CHECK: %[[VAL_37:.*]] = arith.select %[[VAL_36]], %[[VAL_35]], %[[VAL_34]] : index
+// CHECK: %[[VAL_38:.*]] = arith.cmpi eq, %[[VAL_34]], %[[VAL_37]] : index
+// CHECK: %[[VAL_39:.*]] = arith.cmpi eq, %[[VAL_35]], %[[VAL_37]] : index
+// CHECK: %[[VAL_40:.*]] = arith.andi %[[VAL_38]], %[[VAL_39]] : i1
+// CHECK: %[[VAL_41:.*]] = scf.if %[[VAL_40]] -> (f64) {
+// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_31]]] : memref<?xf64>
+// CHECK: %[[VAL_43:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK: %[[VAL_44:.*]] = arith.addf %[[VAL_42]], %[[VAL_43]] : f64
+// CHECK: %[[VAL_45:.*]] = arith.addf %[[VAL_33]], %[[VAL_44]] : f64
+// CHECK: scf.yield %[[VAL_45]] : f64
+// CHECK: } else {
+// CHECK: %[[VAL_46:.*]] = scf.if %[[VAL_38]] -> (f64) {
+// CHECK: %[[VAL_47:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_31]]] : memref<?xf64>
+// CHECK: %[[VAL_48:.*]] = arith.addf %[[VAL_33]], %[[VAL_47]] : f64
+// CHECK: scf.yield %[[VAL_48]] : f64
+// CHECK: } else {
+// CHECK: %[[VAL_49:.*]] = scf.if %[[VAL_39]] -> (f64) {
+// CHECK: %[[VAL_50:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_32]]] : memref<?xf64>
+// CHECK: %[[VAL_51:.*]] = arith.addf %[[VAL_33]], %[[VAL_50]] : f64
+// CHECK: scf.yield %[[VAL_51]] : f64
+// CHECK: } else {
+// CHECK: scf.yield %[[VAL_33]] : f64
+// CHECK: }
+// CHECK: scf.yield %[[VAL_52:.*]] : f64
+// CHECK: }
+// CHECK: scf.yield %[[VAL_53:.*]] : f64
+// CHECK: }
+// CHECK: %[[VAL_54:.*]] = arith.addi %[[VAL_31]], %[[VAL_7]] : index
+// CHECK: %[[VAL_55:.*]] = arith.select %[[VAL_38]], %[[VAL_54]], %[[VAL_31]] : index
+// CHECK: %[[VAL_56:.*]] = arith.addi %[[VAL_32]], %[[VAL_7]] : index
+// CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_39]], %[[VAL_56]], %[[VAL_32]] : index
+// CHECK: scf.yield %[[VAL_55]], %[[VAL_57]], %[[VAL_58:.*]] : index, index, f64
+// CHECK: } attributes {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_59:.*]] = vector.insertelement %[[VAL_60:.*]]#2, %[[VAL_4]]{{\[}}%[[VAL_6]] : index] : vector<8xf64>
+// CHECK: %[[VAL_61:.*]] = scf.for %[[VAL_62:.*]] = %[[VAL_60]]#0 to %[[VAL_21]] step %[[VAL_3]] iter_args(%[[VAL_63:.*]] = %[[VAL_59]]) -> (vector<8xf64>) {
+// CHECK: %[[VAL_64:.*]] = affine.min #map2(%[[VAL_21]], %[[VAL_62]]){{\[}}%[[VAL_3]]]
+// CHECK: %[[VAL_65:.*]] = vector.create_mask %[[VAL_64]] : vector<8xi1>
+// CHECK: %[[VAL_66:.*]] = vector.maskedload %[[VAL_10]]{{\[}}%[[VAL_62]]], %[[VAL_65]], %[[VAL_4]] : memref<?xf64>, vector<8xi1>, vector<8xf64> into vector<8xf64>
+// CHECK: %[[VAL_67:.*]] = arith.addf %[[VAL_63]], %[[VAL_66]] : vector<8xf64>
+// CHECK: %[[VAL_68:.*]] = arith.select %[[VAL_65]], %[[VAL_67]], %[[VAL_63]] : vector<8xi1>, vector<8xf64>
+// CHECK: scf.yield %[[VAL_68]] : vector<8xf64>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_69:.*]] = scf.for %[[VAL_70:.*]] = %[[VAL_60]]#1 to %[[VAL_23]] step %[[VAL_3]] iter_args(%[[VAL_71:.*]] = %[[VAL_61]]) -> (vector<8xf64>) {
+// CHECK: %[[VAL_73:.*]] = affine.min #map2(%[[VAL_23]], %[[VAL_70]]){{\[}}%[[VAL_3]]]
+// CHECK: %[[VAL_74:.*]] = vector.create_mask %[[VAL_73]] : vector<8xi1>
+// CHECK: %[[VAL_75:.*]] = vector.maskedload %[[VAL_13]]{{\[}}%[[VAL_70]]], %[[VAL_74]], %[[VAL_4]] : memref<?xf64>, vector<8xi1>, vector<8xf64> into vector<8xf64>
+// CHECK: %[[VAL_76:.*]] = arith.addf %[[VAL_71]], %[[VAL_75]] : vector<8xf64>
+// CHECK: %[[VAL_77:.*]] = arith.select %[[VAL_74]], %[[VAL_76]], %[[VAL_71]] : vector<8xi1>, vector<8xf64>
+// CHECK: scf.yield %[[VAL_77]] : vector<8xf64>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_78:.*]] = vector.reduction <add>, %[[VAL_69]] : vector<8xf64> into f64
+// CHECK: scf.yield %[[VAL_78]] : f64
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: memref.store %[[VAL_80:.*]], %[[VAL_14]][] : memref<f64>
+// CHECK: %[[VAL_81:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<f64>
+// CHECK: return %[[VAL_81]] : tensor<f64>
+// CHECK: }
+
+func.func @sparse_matrix_sum(%argx: tensor<f64>,
+                             %arga: tensor<64x32xf64, #SparseMatrix>,
+                             %argb: tensor<64x32xf64, #SparseMatrix>) -> tensor<f64> {
+  %0 = linalg.generic #trait
+     ins(%arga, %argb: tensor<64x32xf64, #SparseMatrix>,
+                       tensor<64x32xf64, #SparseMatrix>)
+      outs(%argx: tensor<f64>) {
+      ^bb(%a: f64, %b: f64, %x: f64):
+        %m = arith.addf %a, %b : f64
+        %t = arith.addf %x, %m : f64
+        linalg.yield %t : f64
+  } -> tensor<f64>
+  return %0 : tensor<f64>
+}
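---
Illustration (not part of the patch): the ReducChainRewriter cleanup targets IR in which the scalar produced by one vectorized reduction loop is immediately re-expanded to seed the next loop, as in the two chained for-loops checked above. The sketch below is hand-written; the SSA names, loop bounds, and elided loop bodies are made up for illustration, and only the loops carrying the loop-emitter attribute (printed as {"Emitted from" = "linalg.generic"} in the test) are rewritten. The pattern is registered for both vector.insertelement and vector.broadcast re-expansions; the insertelement form shown here is the one produced for additive reductions.

  // Before cleanup: horizontal reduction, then re-expansion into a fresh
  // accumulator for the second loop (made-up names; bodies elided).
  %zero = arith.constant dense<0.000000e+00> : vector<8xf64>
  %c0 = arith.constant 0 : index
  %v = scf.for %i = %lo to %mid step %c8 iter_args(%acc = %init) -> (vector<8xf64>) {
    // ... masked load, add, select ...
    scf.yield %acc_next : vector<8xf64>
  } {"Emitted from" = "linalg.generic"}
  %s = vector.reduction <add>, %v : vector<8xf64> into f64
  %u = vector.insertelement %s, %zero[%c0 : index] : vector<8xf64>
  %w = scf.for %j = %mid to %hi step %c8 iter_args(%acc2 = %u) -> (vector<8xf64>) {
    // ... masked load, add, select ...
    scf.yield %acc2_next : vector<8xf64>
  } {"Emitted from" = "linalg.generic"}
  %r = vector.reduction <add>, %w : vector<8xf64> into f64

  // After cleanup: the vector accumulator of the first loop feeds the second
  // loop directly, and only the final value is reduced back to a scalar.
  %v = scf.for %i = %lo to %mid step %c8 iter_args(%acc = %init) -> (vector<8xf64>) {
    // ... masked load, add, select ...
    scf.yield %acc_next : vector<8xf64>
  } {"Emitted from" = "linalg.generic"}
  %w = scf.for %j = %mid to %hi step %c8 iter_args(%acc2 = %v) -> (vector<8xf64>) {
    // ... masked load, add, select ...
    scf.yield %acc2_next : vector<8xf64>
  } {"Emitted from" = "linalg.generic"}
  %r = vector.reduction <add>, %w : vector<8xf64> into f64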