diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -338,10 +338,12 @@ return op->emitOpError("expected at least one output operand"); if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs))) return failure(); - // Should have at least one output tensor per result tensor. - // Can also have outbut buffers that do not correspond to results. - if (op->getNumResults() > linalgOp.getOutputTensorOperands().size()) - return op->emitOpError("unexpected #results > #outputs"); + // Verify the number of results matches the number of output tensors. + if (op->getNumResults() != linalgOp.getOutputTensorOperands().size()) + return op->emitOpError("expected the number of results (") + << op->getNumResults() + << ") to be equal to the number of output tensors (" + << linalgOp.getOutputTensorOperands().size() << ")"; // Before checking indexing maps, we need to make sure the attributes // referenced by it are valid. @@ -394,10 +396,6 @@ "all have buffer type"); for (OpOperand *opOperand : linalgOp.getOutputTensorOperands()) { - // TODO: Enforce one output tensor per result? - if (opOperand->getOperandNumber() - linalgOp.getNumInputs() >= - linalgOp->getNumResults()) - continue; OpResult result = linalgOp.getTiedOpResult(opOperand); if (result.getType() != opOperand->get().getType()) return op->emitOpError("expected type of operand #") diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -458,10 +458,6 @@ Type fillType = op.value().getType(); if (getElementTypeOrSelf(output->get()) != fillType) return op.emitOpError("expects fill type to match view elemental type"); - if (!op.getNumResults() && !output->get().getType().isa()) { - return op.emitOpError( - "expected fill op with no result value to use memref type"); - } return success(); } diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -640,7 +640,7 @@ func @illegal_fill_tensor_no_return(%arg0 : index, %arg1 : index, %arg2 : f32) { %0 = linalg.init_tensor [%arg0, %arg1] : tensor - // expected-error @+1 {{expected fill op with no result value to use memref type}} + // expected-error @+1 {{expected the number of results (0) to be equal to the number of output tensors (1)}} linalg.fill(%arg2, %0) : f32, tensor } @@ -648,7 +648,7 @@ func @illegal_fill_memref_with_return(%arg0 : memref, %arg1 : f32) -> memref { - // expected-error @+1 {{unexpected #results > #outputs}} + // expected-error @+1 {{expected the number of results (1) to be equal to the number of output tensors (0)}} %0 = linalg.fill(%arg1, %arg0) : f32, memref -> memref return %0 : memref } @@ -658,7 +658,7 @@ func @illegal_fill_memref_with_tensor_return (%arg0 : memref, %arg1 : f32) -> tensor { - // expected-error @+1 {{unexpected #results > #outputs}} + // expected-error @+1 {{expected the number of results (1) to be equal to the number of output tensors (0)}} %0 = linalg.fill(%arg1, %arg0) : f32, memref -> tensor return %0 : tensor }