diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h --- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h @@ -20,6 +20,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/CopyOpInterface.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/TilingInterface.h" diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -19,6 +19,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h" @@ -26,11 +27,6 @@ namespace linalg { class LinalgOp; -/// OpOperand vector that implicitly converts to a Value vector. -struct OpOperandVector : public SmallVector { - operator SmallVector(); -}; - namespace detail { /// Implementation of the method that that check if given operands /// can be dropped, i.e. 
the remaining operands can compute the loop @@ -57,9 +53,6 @@ /// Verify that `op` conforms to the invariants of StructuredOpInterface LogicalResult verifyStructuredOpInterface(Operation *op); -/// Verify that `op` conforms to the invariants of DestinationStyleOpInterface -LogicalResult verifyDestinationStyleOpInterface(Operation *op); - } // namespace detail } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -879,291 +879,4 @@ let verifyWithRegions = 1; } -// Ops that are in destination style have designated output operands, which act -// as initial tensor values for the results of the operation or the output -// buffers to which the results of the op will be written. -// -// Output operands must be tensors or memrefs. Input operands can have any -// type. All non-output operands are inputs. - -// It is assumed that the output operands of the op are the operands at -// position [start, end). The positions are defined by getOutputsPositionRange -// method. All non-output operands are "inputs" of the DPS op. - -// If the op has "tensor semantics", then the input operands are either scalars -// or tensors. The output operands are tensors and every tensor output is tied -// to a corresponding tensor OpResult in a 1-to-1 fashion. The i-th output -// tensor is tied to the i-th OpResult. The op may not have any additional -// OpResults. Output operands and their tied OpResults have the same type. -// -// If the op has "buffer semantics", then the input operands are either memrefs -// or other non-tensor types, e.g. scalar types. Furthermore, the output -// operands are memrefs and the op has no results. -// -// Destination-passing style abstraction makes certain transformations easier. 
-// For example, tiling implementation can extract/insert slices from/into the -// destination of an op and use the resulting shaped value as an iter_arg in -// the surrounding loop structure. As another example, bufferization does not -// have to allocate new buffers for destinations (in case of in-place -// bufferization) and can directly reuse the existing destination buffer. -// -// Example of a destination style op: `%r = tensor.insert_slice %t into %d`, -// where `%t` is the single input and `%d` is the single output. `%d` is tied -// to `%r`. -// -// Example of an op that is not in destination style: `%r = tensor.pad %t`. -// This op is not in destination style because `%r` and `%t` have different -// shape. -// -// Each op that wants to implement DestinationStyleOpInterface needs to define -// the getOutputsPositionRange() method. -def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { - let cppNamespace = "::mlir::linalg"; - let methods = [ - // This method has to be defined for every DPS op. - InterfaceMethod< - /*desc=*/"Return start and end indices of the output operands range.", - /*retTy=*/"std::pair", - /*methodName=*/"getOutputsPositionRange", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/"" - >, - //===------------------------------------------------------------------===// - // Operands handling. - //===------------------------------------------------------------------===// - // The operand list is assumed to start with the input operands and end - // with the output operands. Therefore, all methods to access the inputs - // and outputs can be expressed if the number of output operands is know. 
- InterfaceMethod< - /*desc=*/"Return the number of outputs.", - /*retTy=*/"int64_t", - /*methodName=*/"getNumOutputs", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - auto [start, end] = $_op.getOutputsPositionRange(); - return end - start; - }] - >, - InterfaceMethod< - /*desc=*/"Return the output operands.", - /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputOperands", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - auto [start, end] = $_op.getOutputsPositionRange(); - - OpOperandVector result; - result.reserve(end - start); - for (int i = start; i < end; ++i) - result.push_back(&$_op->getOpOperand(i)); - return result; - }] - >, - InterfaceMethod< - /*desc=*/"Return the `i`-th output operand.", - /*retTy=*/"OpOperand*", - /*methodName=*/"getOutputOperand", - /*args=*/(ins "int64_t":$i), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(i >= 0 && i < $_op.getNumOutputs()); - auto [start, end] = $_op.getOutputsPositionRange(); - return &$_op->getOpOperand(start + i); - }] - >, - InterfaceMethod< - /*desc=*/"Set the `i`-th output operand.", - /*retTy=*/"void", - /*methodName=*/"setOutputOperand", - /*args=*/(ins "int64_t":$i, "Value":$value), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(i >= 0 && i < $_op.getNumOutputs()); - auto [start, end] = $_op.getOutputsPositionRange(); - $_op->setOperand(start + i, value); - }] - >, - InterfaceMethod< - /*desc=*/"Return the number of inputs.", - /*retTy=*/"int64_t", - /*methodName=*/"getNumInputs", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return $_op.getNumOperands() - $_op.getNumOutputs(); - }] - >, - InterfaceMethod< - /*desc=*/"Return the input operands.", - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputOperands", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - auto [start, end] = $_op.getOutputsPositionRange(); - int64_t numOutputs = end - start; - int64_t numOperands = 
$_op.getNumOperands(); - - OpOperandVector result; - result.reserve(numOperands - numOutputs); - for (int i = 0; i < start; ++i) - result.push_back(&$_op->getOpOperand(i)); - for (int i = end; i < numOperands; ++i) - result.push_back(&$_op->getOpOperand(end + i)); - - return result; - }] - >, - InterfaceMethod< - /*desc=*/[{ Return the `i`-th input operand. }], - /*retTy=*/"OpOperand*", - /*methodName=*/"getInputOperand", - /*args=*/(ins "int64_t":$i), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumInputs()); - auto [start, end] = $_op.getOutputsPositionRange(); - return &$_op->getOpOperand(i < start ? i : i + end - start) ; - }] - >, - //===------------------------------------------------------------------===// - // Input and Output arguments handling. - //===------------------------------------------------------------------===// - InterfaceMethod< - /*desc=*/"Return true if `opOperand` is an input.", - /*retTy=*/"bool", - /*methodName=*/"isInput", - /*args=*/(ins "OpOperand *":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - auto [start, end] = $_op.getOutputsPositionRange(); - auto operandNumber = opOperand->getOperandNumber(); - return operandNumber < start || operandNumber >= end; - }] - >, - InterfaceMethod< - /*desc=*/"Return true if `opOperand` is an output.", - /*retTy=*/"bool", - /*methodName=*/"isOutput", - /*args=*/(ins "OpOperand *":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - auto [start, end] = $_op.getOutputsPositionRange(); - auto operandNumber = opOperand->getOperandNumber(); - return operandNumber >= start && operandNumber < end; - }] - >, - InterfaceMethod< - /*desc=*/"Return true if the `opOperand` is a scalar value.", - /*retTy=*/"bool", - /*methodName=*/"isScalar", - /*args=*/(ins "OpOperand*":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - return !opOperand->get().getType().template isa(); - }] 
- >, - InterfaceMethod< - /*desc=*/"Return the result tied to `opOperand`.", - /*retTy=*/"OpResult", - /*methodName=*/"getTiedOpResult", - /*args=*/(ins "OpOperand*":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - - auto [start, end] = $_op.getOutputsPositionRange(); - int64_t resultIndex = opOperand->getOperandNumber() - start; - assert(resultIndex >= 0 && - resultIndex < $_op->getNumResults() ); - return $_op->getResult(resultIndex); - }] - >, - //===------------------------------------------------------------------===// - // Other interface methods. - //===------------------------------------------------------------------===// - InterfaceMethod< - /*desc=*/"Return whether the op has only MemRef input and outputs.", - /*retTy=*/"bool", - /*methodName=*/"hasBufferSemantics", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return $_op->getNumResults() == 0 && - llvm::all_of($_op->getOpOperands(), - [&](OpOperand &opOperand) { - return isScalar(&opOperand) || - opOperand.get().getType().template isa(); - }); - }] - >, - InterfaceMethod< - /*desc=*/"Return whether the op has only RankedTensor input and outputs.", - /*retTy=*/"bool", - /*methodName=*/"hasTensorSemantics", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return llvm::all_of($_op->getOpOperands(), - [&](OpOperand &opOperand) { - return isScalar(&opOperand) || - opOperand.get().getType().template isa(); - }); - }] - >, - //===------------------------------------------------------------------===// - // Other static interface methods. - //===------------------------------------------------------------------===// - InterfaceMethod< - /*desc=*/[{ - Clone the current operation with the given location and operands. This - is used to abstract away the optional underlying region creation. This - does not change the balance between input, output_buffer and - init_tensors operands. 
- }], - /*retTy=*/"Operation *", - /*methodName=*/"clone", - (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, - "ValueRange":$operands), - [{ - BlockAndValueMapping bvm; - OperationState state( - loc, ConcreteOp::getOperationName(), operands, resultTypes, - $_op->getAttrs()); - for (Region &r : $_op->getRegions()) - r.cloneInto(state.addRegion(), bvm); - return b.create(state); - }] - >, - InterfaceMethod< - /*desc=*/[{ - Clone the current operation with the given location, operands - and BlockAndValueMapping but leave the regions empty. This is - used to abstract away the optional underlying region creation. - This does not change the balance between input, output_buffer - and init_tensors operands. - }], - /*retTy=*/"Operation *", - /*methodName=*/"cloneWithoutRegions", - (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, - "ValueRange":$operands), - [{ - OperationState state( - loc, ConcreteOp::getOperationName(), operands, resultTypes, - $_op->getAttrs()); - for (size_t cnt = 0, e = $_op->getNumRegions(); cnt < e; ++cnt) - state.addRegion(); - return b.create(state); - }] - > - ]; - - let verify = [{ return detail::verifyDestinationStyleOpInterface($_op); }]; - let verifyWithRegions = 1; -} - #endif // LINALG_IR_LINALGINTERFACES diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -17,6 +17,7 @@ include "mlir/Dialect/Linalg/IR/LinalgBase.td" include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpAsmInterface.td" @@ -279,7 +280,7 @@ int64_t getNumOperands = this->getNumOperands(); return 
{getNumOperands - 1, getNumOperands}; } - linalg::OpOperandVector getOpOperandsMatchingBBargs() { + OpOperandVector getOpOperandsMatchingBBargs() { return getInputOperands(); } diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_interface(ControlFlowInterfaces) add_mlir_interface(CopyOpInterface) add_mlir_interface(DerivedAttributeOpInterface) +add_mlir_interface(DestinationStyleOpInterface) add_mlir_interface(InferIntRangeInterface) add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) diff --git a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.h b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.h @@ -0,0 +1,34 @@ +//===- DestinationStyleOpInterface.h ----------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_DESTINATIONSTYLEOPINTERFACE_H_ +#define MLIR_INTERFACES_DESTINATIONSTYLEOPINTERFACE_H_ + +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +/// OpOperand vector that implicitly converts to a Value vector. 
+struct OpOperandVector : public llvm::SmallVector { + operator SmallVector(); +}; + +namespace detail { +/// Verify that `op` conforms to the invariants of DestinationStyleOpInterface +LogicalResult verifyDestinationStyleOpInterface(Operation *op); +} // namespace detail +} // namespace mlir + +/// Include the generated interface declarations. +#include "mlir/Interfaces/DestinationStyleOpInterface.h.inc" + +#endif // MLIR_INTERFACES_DESTINATIONSTYLEOPINTERFACE_H_ diff --git a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td @@ -0,0 +1,306 @@ +//===- DestinationStyleOpInterface.td ----------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DESTINATIONSTYLEOPINTERFACE +#define MLIR_DESTINATIONSTYLEOPINTERFACE + +include "mlir/IR/OpBase.td" + +def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { + let description = [{ + Ops that are in destination style have designated output operands, which act + as initial tensor values for the results of the operation or the output + buffers to which the results of the op will be written. + + Output operands must be tensors or memrefs. Input operands can have any + type. All non-output operands are inputs. + + It is assumed that the output operands of the op are the operands at + position [start, end). The positions are defined by getOutputsPositionRange + method. All non-output operands are "inputs" of the DPS op. + + If the op has "tensor semantics", then the input operands are either scalars + or tensors. 
The output operands are tensors and every tensor output is tied + to a corresponding tensor OpResult in a 1-to-1 fashion. The i-th output + tensor is tied to the i-th OpResult. The op may not have any additional + OpResults. Output operands and their tied OpResults have the same type. + + If the op has "buffer semantics", then the input operands are either memrefs + or other non-tensor types, e.g. scalar types. Furthermore, the output + operands are memrefs and the op has no results. + + Destination-passing style abstraction makes certain transformations easier. + For example, tiling implementation can extract/insert slices from/into the + destination of an op and use the resulting shaped value as an iter_arg in + the surrounding loop structure. As another example, bufferization does not + have to allocate new buffers for destinations (in case of in-place + bufferization) and can directly reuse the existing destination buffer. + + Example of a destination style op: `%r = tensor.insert_slice %t into %d`, + where `%t` is the single input and `%d` is the single output. `%d` is tied + to `%r`. + + Example of an op that is not in destination style: `%r = tensor.pad %t`. + This op is not in destination style because `%r` and `%t` have different + shape. + + Each op that wants to implement DestinationStyleOpInterface needs to define + the getOutputsPositionRange() method. + }]; + + let cppNamespace = "::mlir"; + + let methods = [ + // This method has to be defined for every DPS op. + InterfaceMethod< + /*desc=*/"Return start and end indices of the output operands range.", + /*retTy=*/"std::pair", + /*methodName=*/"getOutputsPositionRange", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" + >, + //===------------------------------------------------------------------===// + // Operands handling. 
+ //===------------------------------------------------------------------===// + // The operand list is assumed to start with the input operands and end + // with the output operands. Therefore, all methods to access the inputs + // and outputs can be expressed if the number of output operands is know. + InterfaceMethod< + /*desc=*/"Return the number of outputs.", + /*retTy=*/"int64_t", + /*methodName=*/"getNumOutputs", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto [start, end] = $_op.getOutputsPositionRange(); + return end - start; + }] + >, + InterfaceMethod< + /*desc=*/"Return the output operands.", + /*retTy=*/"OpOperandVector", + /*methodName=*/"getOutputOperands", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto [start, end] = $_op.getOutputsPositionRange(); + + OpOperandVector result; + result.reserve(end - start); + for (int i = start; i < end; ++i) + result.push_back(&$_op->getOpOperand(i)); + return result; + }] + >, + InterfaceMethod< + /*desc=*/"Return the `i`-th output operand.", + /*retTy=*/"OpOperand *", + /*methodName=*/"getOutputOperand", + /*args=*/(ins "int64_t":$i), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(i >= 0 && i < $_op.getNumOutputs()); + auto [start, end] = $_op.getOutputsPositionRange(); + return &$_op->getOpOperand(start + i); + }] + >, + InterfaceMethod< + /*desc=*/"Set the `i`-th output operand.", + /*retTy=*/"void", + /*methodName=*/"setOutputOperand", + /*args=*/(ins "int64_t":$i, "Value":$value), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(i >= 0 && i < $_op.getNumOutputs()); + auto [start, end] = $_op.getOutputsPositionRange(); + $_op->setOperand(start + i, value); + }] + >, + InterfaceMethod< + /*desc=*/"Return the number of inputs.", + /*retTy=*/"int64_t", + /*methodName=*/"getNumInputs", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getNumOperands() - $_op.getNumOutputs(); + }] + >, + 
+    InterfaceMethod<
+      /*desc=*/"Return the input operands.",
+      /*retTy=*/"OpOperandVector",
+      /*methodName=*/"getInputOperands",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // Inputs are all operands outside the [start, end) output range:
+        // first the ones before it, then the ones after it.
+        auto [start, end] = $_op.getOutputsPositionRange();
+        int64_t numOutputs = end - start;
+        int64_t numOperands = $_op.getNumOperands();
+
+        OpOperandVector result;
+        result.reserve(numOperands - numOutputs);
+        for (int64_t i = 0; i < start; ++i)
+          result.push_back(&$_op->getOpOperand(i));
+        // `i` is already an absolute operand index here; do not offset it by
+        // `end` again.
+        for (int64_t i = end; i < numOperands; ++i)
+          result.push_back(&$_op->getOpOperand(i));
+
+        return result;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{ Return the `i`-th input operand. }],
+      /*retTy=*/"OpOperand *",
+      /*methodName=*/"getInputOperand",
+      /*args=*/(ins "int64_t":$i),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // Go through `$_op` like the sibling defaults so that a concrete-op
+        // override of getNumInputs() is honored.
+        assert(i >= 0 && i < $_op.getNumInputs());
+        auto [start, end] = $_op.getOutputsPositionRange();
+        // Input indices at or beyond `start` skip over the output range.
+        return &$_op->getOpOperand(i < start ? i : i + end - start);
+      }]
+    >,
+    //===------------------------------------------------------------------===//
+    // Input and Output arguments handling.
+    //===------------------------------------------------------------------===//
+    InterfaceMethod<
+      /*desc=*/"Return true if `opOperand` is an input.",
+      /*retTy=*/"bool",
+      /*methodName=*/"isInput",
+      /*args=*/(ins "OpOperand *":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // Every operand outside the [start, end) output range is an input.
+        auto [start, end] = $_op.getOutputsPositionRange();
+        auto operandNumber = opOperand->getOperandNumber();
+        return operandNumber < start || operandNumber >= end;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return true if `opOperand` is an output.",
+      /*retTy=*/"bool",
+      /*methodName=*/"isOutput",
+      /*args=*/(ins "OpOperand *":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto [start, end] = $_op.getOutputsPositionRange();
+        auto operandNumber = opOperand->getOperandNumber();
+        return operandNumber >= start && operandNumber < end;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return true if the `opOperand` is a scalar value.",
+      /*retTy=*/"bool",
+      /*methodName=*/"isScalar",
+      /*args=*/(ins "OpOperand *":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // Anything that is not a shaped type (tensor/memref/vector) counts as
+        // a scalar.
+        assert(opOperand->getOwner() == this->getOperation());
+        return !opOperand->get().getType().template isa<ShapedType>();
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the result tied to `opOperand`.",
+      /*retTy=*/"OpResult",
+      /*methodName=*/"getTiedOpResult",
+      /*args=*/(ins "OpOperand *":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(opOperand->getOwner() == this->getOperation());
+        // Only output operands are tied to results (the i-th output to the
+        // i-th result). Without this check an input placed after the output
+        // range could pass the index range check below and silently map to
+        // an unrelated result.
+        assert($_op.isOutput(opOperand) && "expected an output operand");
+
+        auto [start, end] = $_op.getOutputsPositionRange();
+        int64_t resultIndex = opOperand->getOperandNumber() - start;
+        assert(resultIndex >= 0 &&
+               resultIndex < $_op->getNumResults());
+        return $_op->getResult(resultIndex);
+      }]
+    >,
+    //===------------------------------------------------------------------===//
+    // Other interface methods.
+ //===------------------------------------------------------------------===// + InterfaceMethod< + /*desc=*/"Return whether the op has only MemRef input and outputs.", + /*retTy=*/"bool", + /*methodName=*/"hasBufferSemantics", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op->getNumResults() == 0 && + llvm::all_of($_op->getOpOperands(), + [&](OpOperand &opOperand) { + return isScalar(&opOperand) || + opOperand.get().getType().template isa(); + }); + }] + >, + InterfaceMethod< + /*desc=*/"Return whether the op has only RankedTensor input and outputs.", + /*retTy=*/"bool", + /*methodName=*/"hasTensorSemantics", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return llvm::all_of($_op->getOpOperands(), + [&](OpOperand &opOperand) { + return isScalar(&opOperand) || + opOperand.get().getType().template isa(); + }); + }] + >, + //===------------------------------------------------------------------===// + // Other static interface methods. + //===------------------------------------------------------------------===// + InterfaceMethod< + /*desc=*/[{ + Clone the current operation with the given location and operands. This + is used to abstract away the optional underlying region creation. This + does not change the balance between input, output_buffer and + init_tensors operands. + }], + /*retTy=*/"Operation *", + /*methodName=*/"clone", + (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, + "ValueRange":$operands), + [{ + BlockAndValueMapping bvm; + OperationState state( + loc, ConcreteOp::getOperationName(), operands, resultTypes, + $_op->getAttrs()); + for (Region &r : $_op->getRegions()) + r.cloneInto(state.addRegion(), bvm); + return b.create(state); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Clone the current operation with the given location, operands + and BlockAndValueMapping but leave the regions empty. This is + used to abstract away the optional underlying region creation. 
+ This does not change the balance between input, output_buffer + and init_tensors operands. + }], + /*retTy=*/"Operation *", + /*methodName=*/"cloneWithoutRegions", + (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, + "ValueRange":$operands), + [{ + OperationState state( + loc, ConcreteOp::getOperationName(), operands, resultTypes, + $_op->getAttrs()); + for (size_t cnt = 0, e = $_op->getNumRegions(); cnt < e; ++cnt) + state.addRegion(); + return b.create(state); + }] + > + ]; + + let verify = [{ return detail::verifyDestinationStyleOpInterface($_op); }]; + let verifyWithRegions = 1; +} + + +#endif // MLIR_DESTINATIONSTYLEOPINTERFACE diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt @@ -18,6 +18,7 @@ MLIRArithDialect MLIRArithUtils MLIRBufferizationDialect + MLIRDestinationStyleOpInterface MLIRDialectUtils MLIRInferTypeOpInterface MLIRIR diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -462,14 +462,6 @@ // StructuredOpInterface implementation //===----------------------------------------------------------------------===// -OpOperandVector::operator SmallVector() { - SmallVector result; - result.reserve(this->size()); - llvm::transform(*this, std::back_inserter(result), - [](OpOperand *opOperand) { return opOperand->get(); }); - return result; -} - /// Helper function that creates a memref::DimOp or tensor::DimOp depending on /// the type of `source`. 
static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, @@ -770,55 +762,3 @@ return success(); } - -LogicalResult -mlir::linalg::detail::verifyDestinationStyleOpInterface(Operation *op) { - DestinationStyleOpInterface dstStyleOp = - cast(op); - - SmallVector outputBufferOperands, outputTensorOperands; - for (OpOperand *operand : dstStyleOp.getOutputOperands()) { - Type type = operand->get().getType(); - if (type.isa()) - outputBufferOperands.push_back(operand); - if (type.isa()) - outputTensorOperands.push_back(operand); - } - - // Expect at least one output operand. - // This means an op that constructs a tensor out of indices cannot be a - // LinalgOp at the moment. For now this will have to be a special op until we - // have output shape operands that are not tensors. - int64_t numInputs = dstStyleOp.getNumInputs(); - int64_t numOutputs = dstStyleOp.getNumOutputs(); - if (numOutputs == 0) - return op->emitOpError("expected at least one output operand"); - if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs))) - return failure(); - // Verify the number of results matches the number of output tensors. - if (op->getNumResults() != outputTensorOperands.size()) - return op->emitOpError("expected the number of results (") - << op->getNumResults() - << ") to be equal to the number of output tensors (" - << outputTensorOperands.size() << ")"; - - // Simplifying assumption: either full tensor or full buffer mode. - // This allows simpler verification of output operands vs result types - // without premature tracking of which operand is what in mixed-mode. - // TODO: relax when mixed-mode needs to pass verification. 
- if (!outputBufferOperands.empty() && !outputTensorOperands.empty()) - return op->emitOpError( - "expected output operands to all have tensor type or " - "all have buffer type"); - - for (OpOperand *opOperand : outputTensorOperands) { - OpResult result = dstStyleOp.getTiedOpResult(opOperand); - if (result.getType() != opOperand->get().getType()) - return op->emitOpError("expected type of operand #") - << opOperand->getOperandNumber() << " (" - << opOperand->get().getType() << ")" - << " to match type of corresponding result (" << result.getType() - << ")"; - } - return success(); -} diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" using namespace mlir; using namespace linalg; @@ -115,7 +116,7 @@ SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - auto genericOp = cast(op); + auto genericOp = cast(op); // The i-th "out" tensor may alias with the i-th OpResult. 
if (genericOp.isOutput(&opOperand)) diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -43,6 +43,7 @@ MLIRBufferizationDialect MLIRBufferizationTransforms MLIRComplexDialect + MLIRDestinationStyleOpInterface MLIRDialectUtils MLIRFuncDialect MLIRFuncToLLVM diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -5,6 +5,7 @@ CopyOpInterface.cpp DataLayoutInterfaces.cpp DerivedAttributeOpInterface.cpp + DestinationStyleOpInterface.cpp InferIntRangeInterface.cpp InferTypeOpInterface.cpp LoopLikeInterface.cpp @@ -38,6 +39,7 @@ add_mlir_interface_library(CopyOpInterface) add_mlir_interface_library(DataLayoutInterfaces) add_mlir_interface_library(DerivedAttributeOpInterface) +add_mlir_interface_library(DestinationStyleOpInterface) add_mlir_interface_library(InferIntRangeInterface) add_mlir_interface_library(InferTypeOpInterface) add_mlir_interface_library(LoopLikeInterface) diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp @@ -0,0 +1,71 @@ +//===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/DestinationStyleOpInterface.h" + +using namespace mlir; + +namespace mlir { +#include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc" +} // namespace mlir + +OpOperandVector::operator SmallVector<Value>() { + SmallVector<Value> result; + result.reserve(this->size()); + llvm::transform(*this, std::back_inserter(result), + [](OpOperand *opOperand) { return opOperand->get(); }); + return result; +} + +LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) { + DestinationStyleOpInterface dstStyleOp = + cast<DestinationStyleOpInterface>(op); + + SmallVector<OpOperand *> outputBufferOperands, outputTensorOperands; + for (OpOperand *operand : dstStyleOp.getOutputOperands()) { + Type type = operand->get().getType(); + if (type.isa<MemRefType>()) + outputBufferOperands.push_back(operand); + if (type.isa<RankedTensorType>()) + outputTensorOperands.push_back(operand); + } + + // Expect at least one output operand. + int64_t numInputs = dstStyleOp.getNumInputs(); + int64_t numOutputs = dstStyleOp.getNumOutputs(); + if (numOutputs == 0) + return op->emitOpError("expected at least one output operand"); + if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs))) + return failure(); + // Verify the number of results matches the number of output tensors. + if (op->getNumResults() != outputTensorOperands.size()) + return op->emitOpError("expected the number of results (") + << op->getNumResults() + << ") to be equal to the number of output tensors (" + << outputTensorOperands.size() << ")"; + + // Simplifying assumption: either full tensor or full buffer mode. + // This allows simpler verification of output operands vs result types + // without premature tracking of which operand is what in mixed-mode. + // TODO: relax when mixed-mode needs to pass verification.
+ if (!outputBufferOperands.empty() && !outputTensorOperands.empty()) + return op->emitOpError( + "expected output operands to all have tensor type or " + "all have buffer type"); + + for (OpOperand *opOperand : outputTensorOperands) { + OpResult result = dstStyleOp.getTiedOpResult(opOperand); + if (result.getType() != opOperand->get().getType()) + return op->emitOpError("expected type of operand #") + << opOperand->getOperandNumber() << " (" + << opOperand->get().getType() << ")" + << " to match type of corresponding result (" << result.getType() + << ")"; + } + return success(); +} diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -54,6 +54,7 @@ MLIRControlFlowInterfaces MLIRDataLayoutInterfaces MLIRDerivedAttributeOpInterface + MLIRDestinationStyleOpInterface MLIRDialect MLIRDLTIDialect MLIRFuncDialect diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -23,6 +23,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/DataLayoutInterfaces.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -995,6 +995,13 @@ deps = [":OpBaseTdFiles"], ) +td_library( + name = "DestinationStyleOpInterfaceTdFiles", + srcs = ["include/mlir/Interfaces/DestinationStyleOpInterface.td"], + includes = ["include"], + deps = [":OpBaseTdFiles"], +) + td_library( name = 
"InferIntRangeInterfaceTdFiles", srcs = ["include/mlir/Interfaces/InferIntRangeInterface.td"], @@ -5321,6 +5328,36 @@ ], ) +gentbl_cc_library( + name = "DestinationStyleOpInterfaceIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "include/mlir/Interfaces/DestinationStyleOpInterface.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "include/mlir/Interfaces/DestinationStyleOpInterface.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Interfaces/DestinationStyleOpInterface.td", + deps = [":DestinationStyleOpInterfaceTdFiles"], +) + +cc_library( + name = "DestinationStyleOpInterface", + srcs = ["lib/Interfaces/DestinationStyleOpInterface.cpp"], + hdrs = ["include/mlir/Interfaces/DestinationStyleOpInterface.h"], + includes = ["include"], + deps = [ + ":DestinationStyleOpInterfaceIncGen", + ":IR", + "//llvm:Support", + ], +) + gentbl_cc_library( name = "InferIntRangeInterfaceIncGen", strip_include_prefix = "include", @@ -7437,6 +7474,7 @@ includes = ["include"], deps = [ ":ControlFlowInterfacesTdFiles", + ":DestinationStyleOpInterfaceTdFiles", ":DialectUtilsTdFiles", ":InferTypeOpInterfaceTdFiles", ":LoopLikeInterfaceTdFiles", @@ -7571,6 +7609,7 @@ includes = ["include"], deps = [ ":CopyOpInterfaceTdFiles", + ":DestinationStyleOpInterfaceTdFiles", ":LinalgOpsTdFiles", ":OpBaseTdFiles", ":SideEffectInterfacesTdFiles", @@ -7768,6 +7807,7 @@ ":ComplexDialect", ":ControlFlowInterfaces", ":CopyOpInterface", + ":DestinationStyleOpInterface", ":DialectUtils", ":FuncDialect", ":IR", @@ -7925,6 +7965,7 @@ ":BufferizationTransforms", ":ComplexDialect", ":ControlFlowDialect", + ":DestinationStyleOpInterface", ":DialectUtils", ":FuncDialect", ":FuncTransforms", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -94,6 +94,7 @@
"//mlir:CopyOpInterfaceTdFiles", "//mlir:DLTIDialectTdFiles", "//mlir:DataLayoutInterfacesTdFiles", + "//mlir:DestinationStyleOpInterfaceTdFiles", "//mlir:InferIntRangeInterfaceTdFiles", "//mlir:InferTypeOpInterfaceTdFiles", "//mlir:LinalgStructuredOpsTdFiles", @@ -325,6 +326,7 @@ "//mlir:DLTIDialect", "//mlir:DataLayoutInterfaces", "//mlir:DerivedAttributeOpInterface", + "//mlir:DestinationStyleOpInterface", "//mlir:Dialect", "//mlir:FuncDialect", "//mlir:FuncTransforms",