diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -419,6 +419,8 @@ // Verify consistency of sparse annotations. if (!op.hasTensorSemantics()) return op.emitOpError("expected sparse annotations on tensors only"); + if (op.getNumOutputs() != 1) + return op.emitOpError("expected single output tensor"); unsigned numTensors = op.getNumInputsAndOutputs(); if (sparseAttr.size() != numTensors) return op.emitOpError("expected one sparse annotation for each tensor"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp @@ -830,22 +830,16 @@ LogicalResult matchAndRewrite(linalg::GenericOp op, PatternRewriter &rewriter) const override { - unsigned numTensors = op.getNumInputsAndOutputs(); - unsigned numLoops = op.iterator_types().getValue().size(); - Merger merger(numTensors, numLoops); - // Detects sparse annotations and translate the per-dimension sparsity // information for all tensors to loop indices in the kernel. if (!op.hasSparseSemantics()) return failure(); + assert(op.getNumOutputs() == 1); + unsigned numTensors = op.getNumInputsAndOutputs(); + unsigned numLoops = op.iterator_types().getValue().size(); + Merger merger(numTensors, numLoops); findSparseAnnotations(op, merger.sparse()); - // Accept only single, dense result. - if (op.getNumOutputs() != 1 || - std::any_of(merger.sparse().back().begin(), - merger.sparse().back().end(), [](bool b) { return b; })) - return failure(); - // Computes a topologically sorted iteration graph to ensure // tensors are visited in natural index order. Fails on cycles. // This assumes that higher-level passes have already put the @@ -858,10 +852,7 @@ // Finds the terminating yield statement and builds the tensor // expression for the Linalg operation in SSA form. - auto ®ion = op.region(); - if (!llvm::hasSingleElement(region)) - return failure(); // single block only - Operation *yield = region.front().getTerminator(); + Operation *yield = op.region().front().getTerminator(); Optional exp = buildTensorExp(merger, op, yield->getOperand(0)); if (!exp.hasValue()) return failure(); // build failure diff --git a/mlir/test/Dialect/Linalg/sparse_invalid.mlir b/mlir/test/Dialect/Linalg/sparse_invalid.mlir --- a/mlir/test/Dialect/Linalg/sparse_invalid.mlir +++ b/mlir/test/Dialect/Linalg/sparse_invalid.mlir @@ -13,7 +13,7 @@ } func @invalid_memref(%arga: memref<32xf32>, %argb: f32) -> tensor<32xf32> { - // expected-error@+1 {{'linalg.generic' op expected sparse annotations on tensors only}} + // expected-error@+1 {{'linalg.generic' op expected sparse annotations on tensors only}} %0 = linalg.generic #trait_memref ins(%arga: memref<32xf32>) { ^bb(%a: f32): @@ -25,6 +25,79 @@ // ----- +#trait_two_out = { + indexing_maps = [ + affine_map<(i) -> (i)>, // a + affine_map<(i) -> (i)>, // x (out) + affine_map<(i) -> (i)> // y (out) + ], + sparse = [ + [ "S" ], // a + [ "D" ], // x + [ "D" ] // y + ], + iterator_types = ["parallel"] +} + +func @invalid_two_out(%arga: tensor<32xf32>) -> tensor<32xf32> { + // expected-error@+1 {{'linalg.generic' op expected single output tensor}} + %0, %1 = linalg.generic #trait_two_out + ins(%arga: tensor<32xf32>) { + ^bb(%a: f32): + %0 = addf %a, %a : f32 + linalg.yield %a, %0 : f32, f32 + } -> tensor<32xf32>, tensor<32xf32> + return %1 : tensor<32xf32> +} + +// ----- + +#trait_two_blocks = { + indexing_maps = [ + affine_map<(i) -> (i)>, // a + affine_map<(i) -> (i)> // x (out) + ], + sparse = [ + [ "S" ], // a + [ "D" ] // x + ], + iterator_types = ["parallel"] +} + +func @invalid_two_blocks(%arga: tensor<32xf32>) -> tensor<32xf32> { + // expected-error@+1 {{'linalg.generic' op expects region #0 to have 0 or 1 blocks}} + %0 = linalg.generic #trait_two_blocks + ins(%arga: tensor<32xf32>) { + ^bb1(%a: f32): + %0 = addf %a, %a : f32 + ^bb2: + linalg.yield %0 : f32 + } -> tensor<32xf32> + return %0 : tensor<32xf32> +} + +// ----- + +#trait_no_block = { + indexing_maps = [ + affine_map<(i) -> (i)> // a + ], + sparse = [ + [ "S" ] // a + ], + iterator_types = ["parallel"] +} + +func @invalid_no_block(%arga: tensor<32xf32>) { + // expected-error@+1 {{'linalg.generic' op expected region with 1 block}} + linalg.generic #trait_no_block + ins(%arga: tensor<32xf32>) { + } + return +} + +// ----- + #trait_too_many = { indexing_maps = [ affine_map<(i) -> (i)>, // a @@ -39,7 +112,7 @@ } func @invalid_too_many(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> { - // expected-error@+1 {{'linalg.generic' op expected one sparse annotation for each tensor}} + // expected-error@+1 {{'linalg.generic' op expected one sparse annotation for each tensor}} %0 = linalg.generic #trait_too_many ins(%arga: tensor<32xf32>) { ^bb(%a: f32): @@ -61,7 +134,7 @@ } func @invalid_no_array(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> { - // expected-error@+1 {{'linalg.generic' op expected sparse annotation array for tensor 0}} + // expected-error@+1 {{'linalg.generic' op expected sparse annotation array for tensor 0}} %0 = linalg.generic #trait_no_array ins(%arga: tensor<32xf32>) { ^bb(%a: f32): @@ -86,7 +159,7 @@ } func @invalid_wrong_rank(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> { - // expected-error@+1 {{'linalg.generic' op expected sparse annotation with rank 1 for tensor 1}} + // expected-error@+1 {{'linalg.generic' op expected sparse annotation with rank 1 for tensor 1}} %0 = linalg.generic #trait_wrong_rank ins(%arga: tensor<32xf32>) { ^bb(%a: f32): @@ -111,7 +184,7 @@ } func @invalid_no_string(%arga: tensor<32x16xf32>, %argb: f32) -> tensor<32x16xf32> { - // expected-error@+1 {{'linalg.generic' op expected sparse annotation at position 1 for tensor 0}} + // expected-error@+1 {{'linalg.generic' op expected sparse annotation at position 1 for tensor 0}} %0 = linalg.generic #trait_no_string ins(%arga: tensor<32x16xf32>) { ^bb(%a: f32):