diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp @@ -59,14 +59,21 @@ /// children tensor expressions. struct TensorExp { TensorExp(Kind k, unsigned x, unsigned y, Value v) - : kind(k), e0(x), e1(y), val(v) {} + : kind(k), e0(x), e1(y), val(v) { + assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) || + (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) || + (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val)); + } Kind kind; + /// Indices of children expression(s), or -1u if absent (kTensor keeps its tensor number in e0). unsigned e0; unsigned e1; + /// Direct link to IR for an invariant. During code generation, + /// this field is used to cache "hoisted" loop invariant tensor loads. Value val; }; -/// Lattice point. Each lattice point consist of a conjunction of tensor +/// Lattice point. Each lattice point consists of a conjunction of tensor /// loop indices (encoded in a bitvector) and the index of the corresponding /// tensor expression. struct LatPoint { @@ -74,7 +81,9 @@ bits.set(b); } LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {} + /// Conjunction of tensor loop indices as bitvector. llvm::BitVector bits; + /// Index of the tensor expression. unsigned exp; }; @@ -502,8 +511,16 @@ /// Generates a load on a dense or sparse tensor. static Value genTensorLoad(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, - unsigned tensor) { + unsigned exp) { + // Test if the load was hoisted to a higher loop nest. + Value val = merger.exp(exp).val; + if (val) { + merger.exp(exp).val = Value(); // reset, so a cached load is handed out once + return val; + } + // Actual load (nothing cached for this expression).
SmallVector args; + unsigned tensor = merger.exp(exp).e0; auto map = op.getIndexingMap(tensor); bool sparse = false; for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) { @@ -515,7 +532,9 @@ args.push_back(codegen.pidxs[tensor][idx]); // position index } } - return rewriter.create(op.getLoc(), codegen.buffers[tensor], args); + Location loc = op.getLoc(); + Value ptr = codegen.buffers[tensor]; + return rewriter.create(loc, ptr, args); } /// Generates a store on a dense tensor. @@ -528,25 +547,33 @@ unsigned idx = map.getDimPosition(i); args.push_back(codegen.loops[idx]); // universal dense index } - rewriter.create(op.getLoc(), rhs, codegen.buffers[tensor], args); + Location loc = op.getLoc(); + Value ptr = codegen.buffers[tensor]; + rewriter.create(loc, rhs, ptr, args); } /// Generates a pointer/index load from the sparse storage scheme. -static Value genIntLoad(PatternRewriter &rewriter, Location loc, Value ptr, - Value s) { +static Value genLoad(PatternRewriter &rewriter, Location loc, Value ptr, + Value s) { Value load = rewriter.create(loc, ptr, s); return load.getType().isa() ? load : rewriter.create(loc, load, rewriter.getIndexType()); } +/// Generates an invariant value (the IR value linked from the expression). +static Value genInvariantValue(Merger &merger, CodeGen &codegen, + PatternRewriter &rewriter, unsigned exp) { + return merger.exp(exp).val; +} + /// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, unsigned exp) { if (merger.exp(exp).kind == Kind::kTensor) - return genTensorLoad(merger, codegen, rewriter, op, merger.exp(exp).e0); + return genTensorLoad(merger, codegen, rewriter, op, exp); else if (merger.exp(exp).kind == Kind::kInvariant) - return merger.exp(exp).val; + return genInvariantValue(merger, codegen, rewriter, exp); Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0); Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1); switch (merger.exp(exp).kind) { @@ -564,6 +591,33 @@ } } +/// Hoists loop invariant tensor loads for which indices have been exhausted. +static void genInvariants(Merger &merger, CodeGen &codegen, + PatternRewriter &rewriter, linalg::GenericOp op, + unsigned exp) { + if (merger.exp(exp).kind == Kind::kTensor) { + unsigned lhs = op.getNumInputsAndOutputs() - 1; + unsigned tensor = merger.exp(exp).e0; + if (tensor == lhs) + return; // TODO: scalarize reduction as well (using scf.yield) + auto map = op.getIndexingMap(tensor); + for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) { + unsigned idx = map.getDimPosition(i); + if (!codegen.loops[idx]) + return; // still in play (no loop value for idx yet) + } + // All exhausted at this level: cache the load in val for genTensorLoad. + merger.exp(exp).val = genTensorLoad(merger, codegen, rewriter, op, exp); + + } else if (merger.exp(exp).kind != Kind::kInvariant) { + // Traverse into the binary operations. Note that we only hoist + // tensor loads, since subsequent MLIR/LLVM passes know how to + // deal with all other kinds of derived loop invariants. + genInvariants(merger, codegen, rewriter, op, merger.exp(exp).e0); + genInvariants(merger, codegen, rewriter, op, merger.exp(exp).e1); + } +} + /// Generates initialization code for the subsequent loop sequence at /// current index level. Returns true if the loop sequence needs to /// maintain the universal index.
@@ -590,9 +644,9 @@ Value one = rewriter.create(loc, 1); Value p0 = (pat == 0) ? rewriter.create(loc, 0) : codegen.pidxs[tensor][topSort[pat - 1]]; - codegen.pidxs[tensor][idx] = genIntLoad(rewriter, loc, ptr, p0); + codegen.pidxs[tensor][idx] = genLoad(rewriter, loc, ptr, p0); Value p1 = rewriter.create(loc, p0, one); - codegen.highs[tensor][idx] = genIntLoad(rewriter, loc, ptr, p1); + codegen.highs[tensor][idx] = genLoad(rewriter, loc, ptr, p1); } else { // Dense index still in play. needsUniv = true; @@ -608,7 +662,8 @@ /// Generates a for-loop on a single index. static Operation *genFor(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, - bool isOuter, unsigned idx, llvm::BitVector &indices) { + bool isOuter, bool isInner, unsigned idx, // isInner: innermost topSort level + llvm::BitVector &indices) { unsigned fb = indices.find_first(); unsigned tensor = merger.tensor(fb); assert(idx == merger.index(fb)); @@ -725,10 +780,15 @@ /// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter, linalg::GenericOp op, - bool isOuter, unsigned idx, bool needsUniv, - llvm::BitVector &indices) { - if (indices.count() == 1) - return genFor(merger, codegen, rewriter, op, isOuter, idx, indices); + std::vector &topSort, unsigned at, + bool needsUniv, llvm::BitVector &indices) { + unsigned idx = topSort[at]; + if (indices.count() == 1) { + bool isOuter = at == 0; + bool isInner = at == topSort.size() - 1; + return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx, + indices); + } return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices); } @@ -749,7 +809,7 @@ assert(idx == merger.index(b)); Value ptr = codegen.indices[tensor][idx]; Value s = codegen.pidxs[tensor][idx]; - Value load = genIntLoad(rewriter, loc, ptr, s); + Value load = genLoad(rewriter, loc, ptr, s); codegen.idxs[tensor][idx] = load; if (!needsUniv) { if (min) { @@ -886,6 +946,7 @@ assert(lsize != 0); unsigned l0 = merger.set(lts)[0]; LatPoint lat0 = merger.lat(l0); + genInvariants(merger, codegen, rewriter, op, exp); // hoist before this loop sequence bool needsUniv = genInit(merger, codegen, rewriter, op, topSort, at, lat0.bits) && lsize > 1; @@ -897,9 +958,8 @@ // Emit loop.
llvm::BitVector indices = lati.bits; optimizeIndices(merger, lsize, indices); - bool isOuter = at == 0; - Operation *loop = genLoop(merger, codegen, rewriter, op, isOuter, idx, - needsUniv, indices); + Operation *loop = + genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices); genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, lati.bits); // Visit all lattices points with Li >= Lj to generate the @@ -931,6 +991,7 @@ } rewriter.setInsertionPointAfter(loop); } + codegen.loops[idx] = Value(); // idx is back "in play" for genInvariants } namespace { diff --git a/mlir/test/Dialect/Linalg/sparse_2d.mlir b/mlir/test/Dialect/Linalg/sparse_2d.mlir --- a/mlir/test/Dialect/Linalg/sparse_2d.mlir +++ b/mlir/test/Dialect/Linalg/sparse_2d.mlir @@ -1071,8 +1071,8 @@ } // CHECK-LABEL: func @sum_reduction( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<10x20xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK-SAME: %[[VAL_0:.*0]]: tensor<10x20xf32>, +// CHECK-SAME: %[[VAL_1:.*1]]: tensor) -> tensor { // CHECK: %[[VAL_2:.*]] = constant 999 : index // CHECK: %[[VAL_3:.*]] = constant 10 : index // CHECK: %[[VAL_4:.*]] = constant 0 : index @@ -1200,19 +1200,19 @@ // CHECK: scf.for %[[VAL_23:.*]] = %[[VAL_21]] to %[[VAL_22]] step %[[VAL_6]] { // CHECK: %[[VAL_24:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_23]]] : memref // CHECK: scf.for %[[VAL_25:.*]] = %[[VAL_5]] to %[[VAL_15]] step %[[VAL_6]] { -// CHECK: %[[VAL_26:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref -// CHECK: %[[VAL_27:.*]] = addi %[[VAL_23]], %[[VAL_6]] : index -// CHECK: %[[VAL_28:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_27]]] : memref -// CHECK: scf.for %[[VAL_29:.*]] = %[[VAL_26]] to %[[VAL_28]] step %[[VAL_6]] { -// CHECK: %[[VAL_30:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_29]]] : memref -// CHECK: %[[VAL_31:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_24]], %[[VAL_30]]] : memref -// CHECK: %[[VAL_32:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref -// CHECK: %[[VAL_33:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_25]]] : memref -// CHECK:
%[[VAL_34:.*]] = load %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_30]]] : memref -// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_33]], %[[VAL_34]] : f32 -// CHECK: %[[VAL_36:.*]] = mulf %[[VAL_32]], %[[VAL_35]] : f32 -// CHECK: %[[VAL_37:.*]] = addf %[[VAL_31]], %[[VAL_36]] : f32 -// CHECK: store %[[VAL_37]], %[[VAL_20]]{{\[}}%[[VAL_24]], %[[VAL_30]]] : memref +// CHECK: %[[VAL_26:.*]] = load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_25]]] : memref +// CHECK: %[[VAL_27:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref +// CHECK: %[[VAL_28:.*]] = addi %[[VAL_23]], %[[VAL_6]] : index +// CHECK: %[[VAL_29:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_28]]] : memref +// CHECK: scf.for %[[VAL_30:.*]] = %[[VAL_27]] to %[[VAL_29]] step %[[VAL_6]] { +// CHECK: %[[VAL_31:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_30]]] : memref +// CHECK: %[[VAL_32:.*]] = load %[[VAL_20]]{{\[}}%[[VAL_24]], %[[VAL_31]]] : memref +// CHECK: %[[VAL_33:.*]] = load %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref +// CHECK: %[[VAL_34:.*]] = load %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_31]]] : memref +// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_26]], %[[VAL_34]] : f32 +// CHECK: %[[VAL_36:.*]] = mulf %[[VAL_33]], %[[VAL_35]] : f32 +// CHECK: %[[VAL_37:.*]] = addf %[[VAL_32]], %[[VAL_36]] : f32 +// CHECK: store %[[VAL_37]], %[[VAL_20]]{{\[}}%[[VAL_24]], %[[VAL_31]]] : memref // CHECK: } // CHECK: } // CHECK: } diff --git a/mlir/test/Dialect/Linalg/sparse_3d.mlir b/mlir/test/Dialect/Linalg/sparse_3d.mlir --- a/mlir/test/Dialect/Linalg/sparse_3d.mlir +++ b/mlir/test/Dialect/Linalg/sparse_3d.mlir @@ -1192,15 +1192,15 @@ // CHECK: %[[VAL_25:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref // CHECK: scf.for %[[VAL_26:.*]] = %[[VAL_23]] to %[[VAL_25]] step %[[VAL_6]] { // CHECK: %[[VAL_27:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_26]]] : memref -// CHECK: scf.for %[[VAL_28:.*]] = %[[VAL_5]] to %[[VAL_17]] step %[[VAL_6]] { -// CHECK: %[[VAL_29:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_26]]] : memref -// CHECK: %[[VAL_30:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_20]],
%[[VAL_28]]] : memref -// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_29]], %[[VAL_30]] : f32 -// CHECK: %[[VAL_32:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_27]], %[[VAL_28]]] : memref +// CHECK: %[[VAL_28:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_26]]] : memref +// CHECK: scf.for %[[VAL_29:.*]] = %[[VAL_5]] to %[[VAL_17]] step %[[VAL_6]] { +// CHECK: %[[VAL_30:.*]] = load %[[VAL_12]]{{\[}}%[[VAL_20]], %[[VAL_29]]] : memref +// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_28]], %[[VAL_30]] : f32 +// CHECK: %[[VAL_32:.*]] = load %[[VAL_15]]{{\[}}%[[VAL_27]], %[[VAL_29]]] : memref // CHECK: %[[VAL_33:.*]] = mulf %[[VAL_31]], %[[VAL_32]] : f32 -// CHECK: %[[VAL_34:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_19]], %[[VAL_28]]] : memref +// CHECK: %[[VAL_34:.*]] = load %[[VAL_18]]{{\[}}%[[VAL_19]], %[[VAL_29]]] : memref // CHECK: %[[VAL_35:.*]] = addf %[[VAL_33]], %[[VAL_34]] : f32 -// CHECK: store %[[VAL_35]], %[[VAL_18]]{{\[}}%[[VAL_19]], %[[VAL_28]]] : memref +// CHECK: store %[[VAL_35]], %[[VAL_18]]{{\[}}%[[VAL_19]], %[[VAL_29]]] : memref // CHECK: } // CHECK: } // CHECK: } @@ -1281,3 +1281,61 @@ } -> tensor return %0 : tensor } + +#trait_invariants = { + indexing_maps = [ + affine_map<(i,j,k) -> (i)>, // a + affine_map<(i,j,k) -> (j)>, // b + affine_map<(i,j,k) -> (k)>, // c + affine_map<(i,j,k) -> (i,j,k)> // x + ], + sparse = [ + [ "D" ], // a + [ "D" ], // b + [ "D" ], // c + [ "D", "D", "D" ] // x + ], + iterator_types = ["parallel", "parallel", "parallel"], + doc = "x(i,j,k) = a(i) * b(j) * c(k)" +} + +// CHECK-LABEL: func @invariants( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<20xf32>, +// CHECK-SAME: %[[VAL_2:.*]]: tensor<30xf32>) -> tensor<10x20x30xf32> { +// CHECK: %[[VAL_3:.*]] = constant 10 : index +// CHECK: %[[VAL_4:.*]] = constant 20 : index +// CHECK: %[[VAL_5:.*]] = constant 30 : index +// CHECK: %[[VAL_6:.*]] = constant 0 : index +// CHECK: %[[VAL_7:.*]] = constant 1 : index +// CHECK: %[[VAL_8:.*]] = alloca() : memref<10xf32> +// CHECK:
%[[VAL_9:.*]] = alloca() : memref<20xf32> +// CHECK: %[[VAL_10:.*]] = alloca() : memref<30xf32> +// CHECK: %[[VAL_11:.*]] = alloca() : memref<10x20x30xf32> +// CHECK: scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] { +// CHECK: %[[VAL_13:.*]] = load %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<10xf32> +// CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] { +// CHECK: %[[VAL_15:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_14]]] : memref<20xf32> +// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] { +// CHECK: %[[VAL_17:.*]] = mulf %[[VAL_13]], %[[VAL_15]] : f32 +// CHECK: %[[VAL_18:.*]] = load %[[VAL_10]]{{\[}}%[[VAL_16]]] : memref<30xf32> +// CHECK: %[[VAL_19:.*]] = mulf %[[VAL_17]], %[[VAL_18]] : f32 +// CHECK: store %[[VAL_19]], %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_14]], %[[VAL_16]]] : memref<10x20x30xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_20:.*]] = tensor_load %[[VAL_11]] : memref<10x20x30xf32> +// CHECK: return %[[VAL_20]] : tensor<10x20x30xf32> +// CHECK: } +func @invariants(%arga: tensor<10xf32>, + %argb: tensor<20xf32>, + %argc: tensor<30xf32>) -> tensor<10x20x30xf32> { + %0 = linalg.generic #trait_invariants + ins(%arga, %argb, %argc : tensor<10xf32>, tensor<20xf32>, tensor<30xf32>) { + ^bb(%a : f32, %b : f32, %c : f32): + %0 = mulf %a, %b : f32 + %1 = mulf %0, %c : f32 + linalg.yield %1: f32 + } -> tensor<10x20x30xf32> + return %0 : tensor<10x20x30xf32> +}