diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp --- a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp +++ b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp @@ -22,35 +22,37 @@ return result; } +static size_t getNumTensorResults(Operation *op) { + size_t numTensorResults = 0; + for (auto t : op->getResultTypes()) { + if (isa(t)) { + ++numTensorResults; + } + } + return numTensorResults; +} + LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) { DestinationStyleOpInterface dstStyleOp = cast(op); - SmallVector outputBufferOperands, outputTensorOperands; + SmallVector outputTensorOperands; for (OpOperand *operand : dstStyleOp.getDpsInitOperands()) { Type type = operand->get().getType(); - if (isa(type)) { - outputBufferOperands.push_back(operand); - } else if (isa(type)) { + if (isa(type)) { outputTensorOperands.push_back(operand); - } else { + } else if (!isa(type)) { return op->emitOpError("expected that operand #") << operand->getOperandNumber() << " is a ranked tensor or a ranked memref"; } } - // Expect at least one output operand. - int64_t numInputs = dstStyleOp.getNumDpsInputs(); - int64_t numInits = dstStyleOp.getNumDpsInits(); - if (numInits == 0) - return op->emitOpError("expected at least one output operand"); - if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numInits))) - return failure(); - // Verify the number of results matches the number of output tensors. - if (op->getNumResults() != outputTensorOperands.size()) - return op->emitOpError("expected the number of results (") - << op->getNumResults() + // Verify the number of tensor results matches the number of output tensors.
+ size_t numTensorResults = getNumTensorResults(op); + if (numTensorResults != outputTensorOperands.size()) + return op->emitOpError("expected the number of tensor results (") + << numTensorResults << ") to be equal to the number of output tensors (" << outputTensorOperands.size() << ")"; diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -326,7 +326,7 @@ func.func @illegal_fill_tensor_no_return(%arg0 : index, %arg1 : index, %arg2 : f32) { %0 = tensor.empty(%arg0, %arg1) : tensor - // expected-error @+1 {{expected the number of results (0) to be equal to the number of output tensors (1)}} + // expected-error @+1 {{expected the number of tensor results (0) to be equal to the number of output tensors (1)}} linalg.fill ins(%arg2 : f32) outs(%0 : tensor) } @@ -335,7 +335,7 @@ func.func @illegal_fill_memref_with_tensor_return (%arg0 : memref, %arg1 : f32) -> tensor { - // expected-error @+1 {{expected the number of results (1) to be equal to the number of output tensors (0)}} + // expected-error @+1 {{expected the number of tensor results (1) to be equal to the number of output tensors (0)}} %0 = linalg.fill ins(%arg1 : f32) outs(%arg0 : memref) -> tensor return %0 : tensor } diff --git a/mlir/test/Interfaces/DestinationStyleOpInterface/verify-destination-style-op-interface.mlir b/mlir/test/Interfaces/DestinationStyleOpInterface/verify-destination-style-op-interface.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Interfaces/DestinationStyleOpInterface/verify-destination-style-op-interface.mlir @@ -0,0 +1,66 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +func.func @ins_1_index_outs_none_results_1_index(%arg0 : index) -> index { + %0 = test.destination_style_op ins(%arg0 : index) -> index + func.return %0 : index +} + +// ----- + +func.func @ins_1_index_outs_1_tensor_results_1_index(%arg0 : index, %arg1 : tensor<2x2xf32>) -> index { + // expected-error @+1
{{op expected the number of tensor results (0) to be equal to the number of output tensors (1)}} + %0 = test.destination_style_op ins(%arg0 : index) outs(%arg1 : tensor<2x2xf32>) -> index + func.return %0 : index +} + +// ----- + +func.func @ins_1_tensor_outs_none_results_1_index(%arg0 :tensor<2x2xf32>) -> index { + %0 = test.destination_style_op ins(%arg0 : tensor<2x2xf32>) -> index + func.return %0 : index +} + +// ----- + +func.func @ins_1_tensor_outs_1_tensor_results_1_index(%arg0 :tensor<2x2xf32>, %arg1 : tensor<2x2xf32>) -> index { + // expected-error @+1 {{op expected the number of tensor results (0) to be equal to the number of output tensors (1)}} + %0 = test.destination_style_op ins(%arg0 : tensor<2x2xf32>) outs(%arg1 : tensor<2x2xf32>) -> index + func.return %0 : index +} + +// ----- + +func.func @ins_1_index_outs_none_results_1_tensor(%arg0 : index) -> tensor<2x2xf32> { + // expected-error @+1 {{op expected the number of tensor results (1) to be equal to the number of output tensors (0)}} + %0 = test.destination_style_op ins(%arg0 : index) -> tensor<2x2xf32> + func.return %0 : tensor<2x2xf32> +} + +// ----- + +func.func @ins_1_index_outs_1_tensor_results_1_tensor(%arg0 : index, %arg1 : tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = test.destination_style_op ins(%arg0 : index) outs(%arg1 : tensor<2x2xf32>) -> tensor<2x2xf32> + func.return %0 : tensor<2x2xf32> +} + +// ----- + +func.func @ins_1_tensor_outs_none_results_1_tensor(%arg0 :tensor<2x2xf32>) -> tensor<2x2xf32> { + // expected-error @+1 {{op expected the number of tensor results (1) to be equal to the number of output tensors (0)}} + %0 = test.destination_style_op ins(%arg0 : tensor<2x2xf32>) -> tensor<2x2xf32> + func.return %0 : tensor<2x2xf32> +} + +// ----- + +func.func @ins_1_tensor_outs_1_tensor_results_1_tensor(%arg0 :tensor<2x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = test.destination_style_op ins(%arg0 : tensor<2x2xf32>) outs(%arg1 : tensor<2x2xf32>) -> tensor<2x2xf32> +
func.return %0 : tensor<2x2xf32> +} + +// ----- + +func.func @ins_1_index_outs_1_tensor_other_1_index_results_1_tensor(%arg0 : index, %arg1 : tensor<2x2xf32>, %arg2 : index) -> tensor<2x2xf32> { + %0 = test.destination_style_op ins(%arg0 : index) outs(%arg1 : tensor<2x2xf32>) (%arg2 : index) -> tensor<2x2xf32> + func.return %0 : tensor<2x2xf32> +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2908,6 +2908,37 @@ def : Pat<(OpCrashLong $_, $_, $_), (OpCrashShort)>; +//===----------------------------------------------------------------------===// +// Test DestinationStyleOpInterface. +//===----------------------------------------------------------------------===// + +def TestDestinationStyleOp : + TEST_Op<"destination_style_op", [ + DestinationStyleOpInterface, + AttrSizedOperandSegments]> { + let arguments = (ins + Variadic:$inputs, + Variadic:$outputs, + Variadic:$other_operands); + let results = (outs Variadic:$results); + let assemblyFormat = [{ + attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)? + (`outs` `(` $outputs^ `:` type($outputs) `)`)? + (`(` $other_operands^ `:` type($other_operands) `)`)? + (`->` type($results)^)? + }]; + + let extraClassDeclaration = [{ + std::pair getDpsInitsPositionRange() { + // The inits are the `outputs` segment, which sits between `inputs` and + // `other_operands`, so they are not necessarily the trailing operands. + int64_t numInputs = getInputs().size(); + int64_t numOutputs = getOutputs().size(); + return {numInputs, numInputs + numOutputs}; + } + }]; +} + //===----------------------------------------------------------------------===// // Test LinalgConvolutionOpInterface. //===----------------------------------------------------------------------===//