diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -458,11 +458,17 @@
                                          Value val) {
   if (auto arg = val.dyn_cast<BlockArgument>()) {
     unsigned argN = arg.getArgNumber();
-    // Any parameter of the generic op is considered a tensor,
-    // indexed by the implicit loop bounds.
-    if (arg.getOwner()->getParentOp() == op)
-      return merger.addExp(Kind::kTensor, argN);
-    // Any parameter of a higher op is invariant.
+    // Any argument of the generic op that is not marked as a scalar
+    // argument is considered a tensor, indexed by the implicit loop
+    // bounds. This includes rank-0 tensor arguments.
+    if (arg.getOwner()->getParentOp() == op) {
+      OpOperand *t = op.getInputAndOutputOperands()[argN];
+      if (!op.isScalar(t))
+        return merger.addExp(Kind::kTensor, argN);
+      val = t->get(); // get scalar value
+    }
+    // Any other argument (marked as scalar argument for the generic op
+    // or belonging to an enveloping op) is considered invariant.
     return merger.addExp(Kind::kInvariant, val);
   }
   Operation *def = val.getDefiningOp();
@@ -719,9 +725,7 @@
   }
   // Actual load.
   SmallVector<Value, 4> args;
-  OpOperand *t = merger.exp(exp).e0 < op.getNumInputs()
-                     ? op.getInputOperand(merger.exp(exp).e0)
-                     : op.getOutputOperand(0);
+  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
   unsigned tensor = t->getOperandNumber();
   auto map = op.getTiedIndexingMap(t);
   auto enc = getSparseTensorEncoding(t->get().getType());
@@ -919,11 +923,9 @@
   if (merger.exp(exp).kind == Kind::kTensor) {
     // Inspect tensor indices.
     bool atLevel = ldx == -1u;
-    OpOperand *tensor = merger.exp(exp).e0 < op.getNumInputs()
-                            ? op.getInputOperand(merger.exp(exp).e0)
-                            : op.getOutputOperand(0);
-    auto map = op.getTiedIndexingMap(tensor);
-    auto enc = getSparseTensorEncoding(tensor->get().getType());
+    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).e0];
+    auto map = op.getTiedIndexingMap(t);
+    auto enc = getSparseTensorEncoding(t->get().getType());
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       unsigned idx = map.getDimPosition(perm(enc, d));
       if (!codegen.loops[idx])
@@ -933,7 +935,7 @@
     }
     // All exhausted at this level (atLevel denotes exactly at this level).
     OpOperand *lhs = op.getOutputOperand(0);
-    if (lhs == tensor) {
+    if (lhs == t) {
       codegen.redExp = hoist ? exp : -1u;
     } else if (atLevel) {
       merger.exp(exp).val =
@@ -1413,8 +1415,6 @@
     // Detects sparse annotations and translate the per-dimension sparsity
     // information for all tensors to loop indices in the kernel.
     assert(op.getNumOutputs() == 1);
-    assert(llvm::none_of(op.getInputAndOutputOperands(),
-                         [&](OpOperand *t) { return op.isScalar(t); }));
     unsigned numTensors = op.getNumInputsAndOutputs();
     unsigned numLoops = op.iterator_types().getValue().size();
     Merger merger(numTensors, numLoops);
diff --git a/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir b/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_scalars.mlir
@@ -0,0 +1,83 @@
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+
+// A contrived example that demonstrates the many different ways
+// in which scalar values can be involved in a sparse kernel
+// through the linalg generic op.
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A (sparse tensor)
+    affine_map<(i,j) -> ()>,     // p (scalar tensor)
+    affine_map<(i,j) -> ()>,     // q (true scalar)
+    affine_map<(i,j) -> (i,j)>   // X (dense tensor out)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) += A(i,j) * p * q * r * s * 2.2"
+}
+
+// CHECK-LABEL:   func @mul(
+// CHECK-SAME:      %[[VAL_0:.*0]]: tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME:      %[[VAL_1:.*1]]: tensor<f32>,
+// CHECK-SAME:      %[[VAL_2:.*2]]: f32,
+// CHECK-SAME:      %[[VAL_3:.*3]]: f32,
+// CHECK-SAME:      %[[VAL_4:.*4]]: tensor<32x16xf32> {linalg.inplaceable = true}) -> tensor<32x16xf32> {
+// CHECK:           %[[VAL_5:.*]] = constant 2.200000e+00 : f32
+// CHECK:           %[[VAL_6:.*]] = constant 0 : index
+// CHECK:           %[[VAL_7:.*]] = constant 1 : index
+// CHECK:           %[[VAL_8:.*]] = addf %[[VAL_2]], %[[VAL_3]] : f32
+// CHECK:           %[[VAL_9:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_10:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_6]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_7]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16xf32, #sparse_tensor.encoding<{{.*}}>> to memref<?xf32>
+// CHECK:           %[[VAL_14:.*]] = memref.buffer_cast %[[VAL_1]] : memref<f32>
+// CHECK:           %[[VAL_15:.*]] = memref.buffer_cast %[[VAL_4]] : memref<32x16xf32>
+// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_14]][] : memref<f32>
+// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_19:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_7]] {
+// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK:             %[[VAL_22:.*]] = addi %[[VAL_19]], %[[VAL_7]] : index
+// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]]] : memref<?xindex>
+// CHECK:             scf.for %[[VAL_24:.*]] = %[[VAL_21]] to %[[VAL_23]] step %[[VAL_7]] {
+// CHECK:               %[[VAL_25:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex>
+// CHECK:               %[[VAL_26:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref<?xf32>
+// CHECK:               %[[VAL_27:.*]] = mulf %[[VAL_26]], %[[VAL_16]] : f32
+// CHECK:               %[[VAL_28:.*]] = mulf %[[VAL_27]], %[[VAL_2]] : f32
+// CHECK:               %[[VAL_29:.*]] = mulf %[[VAL_28]], %[[VAL_3]] : f32
+// CHECK:               %[[VAL_30:.*]] = mulf %[[VAL_29]], %[[VAL_8]] : f32
+// CHECK:               %[[VAL_31:.*]] = mulf %[[VAL_30]], %[[VAL_5]] : f32
+// CHECK:               %[[VAL_32:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<32x16xf32>
+// CHECK:               %[[VAL_33:.*]] = addf %[[VAL_31]], %[[VAL_32]] : f32
+// CHECK:               memref.store %[[VAL_33]], %[[VAL_15]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<32x16xf32>
+// CHECK:             }
+// CHECK:           }
+// CHECK:           %[[VAL_34:.*]] = memref.tensor_load %[[VAL_15]] : memref<32x16xf32>
+// CHECK:           return %[[VAL_34]] : tensor<32x16xf32>
+// CHECK:         }
+func @mul(%arga: tensor<32x16xf32, #SparseMatrix>,
+          %argp: tensor<f32>,
+          %argq: f32,
+          %argr: f32,
+          %argx: tensor<32x16xf32> {linalg.inplaceable = true}) -> tensor<32x16xf32> {
+  %s = addf %argq, %argr : f32
+  %c = constant 2.2 : f32
+  %0 = linalg.generic #trait
+     ins(%arga, %argp, %argq: tensor<32x16xf32, #SparseMatrix>, tensor<f32>, f32)
+    outs(%argx: tensor<32x16xf32>) {
+      ^bb(%a: f32, %p: f32, %q: f32, %x: f32):
+        %0 = mulf %a, %p : f32    // scalar tensor argument
+        %1 = mulf %0, %q : f32    // scalar argument
+        %2 = mulf %1, %argr : f32 // scalar argument from outside block
+        %3 = mulf %2, %s : f32    // scalar value from outside block
+        %4 = mulf %3, %c : f32    // direct constant from outside block
+        %5 = addf %4, %x : f32
+        linalg.yield %5 : f32
+  } -> tensor<32x16xf32>
+
+  return %0 : tensor<32x16xf32>
+}