diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp --- a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp +++ b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp @@ -22,33 +22,36 @@ return result; } +namespace { +size_t getNumTensorResults(Operation *op) { + size_t numTensorResults = 0; + for (auto t : op->getResultTypes()) { + if (t.isa()) { + ++numTensorResults; + } + } + return numTensorResults; +} +} // namespace + LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) { DestinationStyleOpInterface dstStyleOp = cast(op); - SmallVector outputBufferOperands, outputTensorOperands; + SmallVector outputTensorOperands; for (OpOperand *operand : dstStyleOp.getDpsInitOperands()) { Type type = operand->get().getType(); - if (type.isa()) { - outputBufferOperands.push_back(operand); - } else if (type.isa()) { + if (type.isa()) { outputTensorOperands.push_back(operand); - } else { + } else if (!type.isa()) { return op->emitOpError("expected that operand #") << operand->getOperandNumber() << " is a ranked tensor or a ranked memref"; } } - // Expect at least one output operand. - int64_t numInputs = dstStyleOp.getNumDpsInputs(); - int64_t numInits = dstStyleOp.getNumDpsInits(); - if (numInits == 0) - return op->emitOpError("expected at least one output operand"); - if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numInits))) - return failure(); - // Verify the number of results matches the number of output tensors. - if (op->getNumResults() != outputTensorOperands.size()) + // Verify the number of tensor results matches the number of output tensors. + if (getNumTensorResults(op) != outputTensorOperands.size()) return op->emitOpError("expected the number of results (") << op->getNumResults() << ") to be equal to the number of output tensors ("