diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp @@ -54,13 +54,16 @@ enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI }; /// Tensor expression. Represents a MLIR expression in tensor index notation. -/// For tensors and invariants, e0 denotes the tensor index. For all binary -/// operations, e0 and e1 denote the index of the children tensor expressions. +/// For tensors, e0 denotes the tensor index. For invariants, the IR value is +/// stored directly. For binary operations, e0 and e1 denote the index of the +/// children tensor expressions. struct TensorExp { - TensorExp(Kind k, unsigned x, unsigned y) : kind(k), e0(x), e1(y) {} + TensorExp(Kind k, unsigned x, unsigned y, Value v) + : kind(k), e0(x), e1(y), val(v) {} Kind kind; unsigned e0; unsigned e1; + Value val; }; -/// Lattice point. Each lattice point consist of a conjunction of tensor +/// Lattice point. Each lattice point consists of a conjunction of tensor @@ -85,11 +88,12 @@ : numTensors(t), numLoops(l), isSparse(t, std::vector(l, false)) {} /// Adds a tensor expression. Returns its index. - unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u) { + unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) { unsigned e = tensorExps.size(); - tensorExps.push_back(TensorExp(k, e0, e1)); + tensorExps.push_back(TensorExp(k, e0, e1, v)); return e; } + unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); } /// Adds an iteration lattice point. Returns its index. unsigned addLat(unsigned t, unsigned i, unsigned e) { @@ -339,7 +343,6 @@ /// building (compared to using the SSA representation everywhere). 
static Optional buildTensorExp(Merger &merger, linalg::GenericOp op, Value val) { - Operation *def = val.getDefiningOp(); if (auto arg = val.dyn_cast()) { unsigned argN = arg.getArgNumber(); if (arg.getOwner()->getParentOp() == op) { @@ -348,10 +351,16 @@ auto map = op.getIndexingMap(argN); if (map.isProjectedPermutation()) return merger.addExp(Kind::kTensor, argN); - } else { - // Any parameter of a higher op is invariant in the tensor expression. - return merger.addExp(Kind::kInvariant, argN); + // Cannot handle (yet). + return None; } + // Any parameter of a higher op is invariant. + return merger.addExp(Kind::kInvariant, val); + } + Operation *def = val.getDefiningOp(); + if (def->getBlock() != &op.region().front()) { + // Something defined outside is invariant. + return merger.addExp(Kind::kInvariant, val); } else if (def->getNumOperands() == 2) { // Construct binary operations if subexpressions could be built. auto x = buildTensorExp(merger, op, def->getOperand(0)); @@ -380,9 +389,12 @@ Kind kind = merger.exp(exp).kind; if (kind == Kind::kTensor || kind == Kind::kInvariant) { - // Either the index is really used in the tensor expression, or it it - // set to the "non-existing dense index" in that dimension. + // Either the index is really used in the tensor expression, or it is + // set to the "non-existing dense index" in that dimension. Invariant + // expressions borrow the output tensor indices. unsigned s = merger.addSet(); - merger.set(s).push_back(merger.addLat(merger.exp(exp).e0, idx, exp)); + unsigned t = kind == Kind::kTensor ? 
merger.exp(exp).e0 + : op.getNumInputsAndOutputs() - 1; + merger.set(s).push_back(merger.addLat(t, idx, exp)); return s; } unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx); @@ -502,7 +514,7 @@ if (merger.exp(exp).kind == Kind::kTensor) return genTensorLoad(merger, codegen, rewriter, op, merger.exp(exp).e0); else if (merger.exp(exp).kind == Kind::kInvariant) - return op.getParentRegion()->front().getArgument(merger.exp(exp).e0); + return merger.exp(exp).val; Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0); Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1); switch (merger.exp(exp).kind) { diff --git a/mlir/test/Dialect/Linalg/sparse_2d.mlir b/mlir/test/Dialect/Linalg/sparse_2d.mlir --- a/mlir/test/Dialect/Linalg/sparse_2d.mlir +++ b/mlir/test/Dialect/Linalg/sparse_2d.mlir @@ -1106,6 +1106,56 @@ return %0 : tensor } +#trait_scale = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (i,j)> // X (out) + ], + sparse = [ + [ "D", "S" ], // A + [ "D", "D" ] // X + ], + iterator_types = ["parallel", "parallel"], + doc = "X(i,j) = A(i,j) * SCALE" +} + +// CHECK-LABEL: func @scale( +// CHECK-SAME: %[[VAL_0:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_1:.*]] = constant 2.000000e+00 : f64 +// CHECK: %[[VAL_2:.*]] = constant 999 : index +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = dim %[[VAL_0]], %[[VAL_3]] : tensor +// CHECK: %[[VAL_6:.*]] = alloca(%[[VAL_2]]) : memref +// CHECK: %[[VAL_7:.*]] = alloca(%[[VAL_2]]) : memref +// CHECK: %[[VAL_8:.*]] = dim %[[VAL_0]], %[[VAL_4]] : tensor +// CHECK: %[[VAL_9:.*]] = alloca(%[[VAL_2]]) : memref +// CHECK: %[[VAL_10:.*]] = alloca(%[[VAL_5]], %[[VAL_8]]) : memref +// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_3]] to %[[VAL_5]] step %[[VAL_4]] { +// CHECK: %[[VAL_12:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref +// CHECK: %[[VAL_13:.*]] = addi %[[VAL_11]], %[[VAL_4]] : index +// 
CHECK: %[[VAL_14:.*]] = load %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref +// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_12]] to %[[VAL_14]] step %[[VAL_4]] { +// CHECK: %[[VAL_16:.*]] = load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref +// CHECK: %[[VAL_17:.*]] = load %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref +// CHECK: %[[VAL_18:.*]] = mulf %[[VAL_17]], %[[VAL_1]] : f64 +// CHECK: store %[[VAL_18]], %[[VAL_10]]{{\[}}%[[VAL_11]], %[[VAL_16]]] : memref +// CHECK: } +// CHECK: } +// CHECK: %[[VAL_19:.*]] = tensor_load %[[VAL_10]] : memref +// CHECK: return %[[VAL_19]] : tensor +// CHECK: } +func @scale(%arga: tensor) -> tensor { + %0 = constant 2.0 : f64 + %1 = linalg.generic #trait_scale + ins(%arga: tensor) { + ^bb(%a: f64): + %2 = mulf %a, %0 : f64 + linalg.yield %2 : f64 + } -> tensor + return %1 : tensor +} + #trait_sampled_dense_dense = { indexing_maps = [ affine_map<(i,j,k) -> (i,j)>, // S