diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -57,6 +57,9 @@ /// Verify that `op` conforms to the invariants of StructuredOpInterface LogicalResult verifyStructuredOpInterface(Operation *op); +/// Verify that `op` conforms to the invariants of DestinationStyleOpInterface +LogicalResult verifyDestinationStyleOpInterface(Operation *op); + } // namespace detail } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -286,736 +286,986 @@ getNumIterators(getReductionIteratorTypeName(), iters) == 1; }]>, //===------------------------------------------------------------------===// - // Num input/output arguments handling. + // Input and Output arguments handling. //===------------------------------------------------------------------===// - // `inputs` must be defined by each op that wants to implement the - // LinalgStructuredInterface. - InterfaceMethod< - /*desc=*/[{ - Return the input shape operands. - }], - /*retTy=*/"ValueRange", - /*methodName=*/"inputs", - /*args=*/(ins) - >, - // These special methods rely on `inputs` and `outputs` being defined by - // each op that wants to implement the LinalgStructuredInterface. InterfaceMethod< /*desc=*/[{ - Return the number of inputs. + Return true if the payload uses the value loaded from `opOperand`. This + is useful to avoid loading from "write-only" memory that may be + uninitialized, as well as properly cloning "read-write" operands. }], - /*retTy=*/"int64_t", - /*methodName=*/"getNumInputs", - /*args=*/(ins), + /*retTy=*/"bool", + /*methodName=*/"payloadUsesValueFromOperand", + /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.getInputs().size(); + unsigned bbArgNumber = opOperand->getOperandNumber(); + // Init tensors have uses. + return !getBlock()->getArgument(bbArgNumber).use_empty(); }] >, - // `outputs` must be defined by each op that wants to implement the - // LinalgStructuredInterface. - InterfaceMethod< - /*desc=*/[{ - Return the output shape operands. - }], - /*retTy=*/"ValueRange", - /*methodName=*/"outputs", - /*args=*/(ins) - >, InterfaceMethod< /*desc=*/[{ - Return the number of outputs. + Return true if `opOperand` is an init tensor. This is true when it is + an output tensor operand whose value is used in the payload region. }], - /*retTy=*/"int64_t", - /*methodName=*/"getNumOutputs", - /*args=*/(ins), + /*retTy=*/"bool", + /*methodName=*/"isInitTensor", + /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.outputs().size(); + if (!$_op.isOutputTensor(opOperand)) + return false; + return payloadUsesValueFromOperand(opOperand); }] >, InterfaceMethod< /*desc=*/[{ - Return the number of inputs and outputs. + Return the `opOperand` rank or zero for scalars. }], /*retTy=*/"int64_t", - /*methodName=*/"getNumInputsAndOutputs", - /*args=*/(ins), + /*methodName=*/"getRank", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - return this->getOperation()->getNumOperands(); + assert(opOperand->getOwner() == this->getOperation()); + if (auto shapedType = + opOperand->get().getType().template dyn_cast()) + return shapedType.getRank(); + return 0; }] >, - //===------------------------------------------------------------------===// - // Input operands handling. - //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the input operands. + Return the output block arguments of the region. }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputOperands", + /*retTy=*/"Block::BlockArgListType", + /*methodName=*/"getRegionOutputArgs", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - int64_t numInputs = getNumInputs(); - OpOperandVector result; - result.reserve(numInputs); - llvm::transform( - this->getOperation()->getOpOperands().take_front(numInputs), - std::back_inserter(result), - [](OpOperand &opOperand) { return &opOperand; }); - return result; + // MLIR currently does not support dependent interfaces or interface + // inheritance. By construction all ops with LinalgStructuredOpInterface + // must implement DestinationStyleOpInterface. + // TODO: reevalute the need for a cast when a better mechanism exists. + return getBlock()->getArguments().take_back( + cast(*this->getOperation()) + .getNumOutputs()); }] >, InterfaceMethod< /*desc=*/[{ - Return the `i`-th input operand. + Return the `opOperand` shape or an empty vector for scalars. }], - /*retTy=*/"OpOperand*", - /*methodName=*/"getInputOperand", - /*args=*/(ins "int64_t":$i), + /*retTy=*/"ArrayRef", + /*methodName=*/"getShape", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumInputs()); - return &this->getOperation()->getOpOperand(i); + assert(opOperand->getOwner() == this->getOperation()); + if (auto shapedType = + opOperand->get().getType().template dyn_cast()) + return shapedType.getShape(); + return {}; }] >, InterfaceMethod< /*desc=*/[{ - Return the subset of input operands that are of buffer type. + Return the block argument for an `opOperand`. }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputBufferOperands", - /*args=*/(ins), + /*retTy=*/"BlockArgument", + /*methodName=*/"getTiedBlockArgument", + /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumInputs()); - llvm::copy_if(getInputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; + assert(opOperand->getOwner() == this->getOperation()); + return getBlock()->getArgument(opOperand->getOperandNumber()); }] >, InterfaceMethod< /*desc=*/[{ - Return the subset of input operands that are of tensor type. + Return the operand for a `blockArgument`. }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputTensorOperands", - /*args=*/(ins), + /*retTy=*/"OpOperand *", + /*methodName=*/"getTiedOpOperand", + /*args=*/(ins "BlockArgument":$blockArgument), /*methodBody=*/"", /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumInputs()); - llvm::copy_if(getInputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; + assert(blockArgument.getOwner() == getBlock()); + return &this->getOperation()->getOpOperand( + blockArgument.getArgNumber()); }] >, - //===------------------------------------------------------------------===// - // Output operands handling. - //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the output operands. + Return the input or output indexing map for `opOperand`. }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputOperands", - /*args=*/(ins), + /*retTy=*/"AffineMap", + /*methodName=*/"getTiedIndexingMap", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - int64_t numOutputs = getNumOutputs(); - OpOperandVector result; - result.reserve(numOutputs); - llvm::transform( - this->getOperation()->getOpOperands() - .take_back(numOutputs), - std::back_inserter(result), - [](OpOperand &opOperand) { return &opOperand; }); - return result; + assert(opOperand->getOwner() == this->getOperation()); + auto indexingMaps = + $_op.getIndexingMaps().template getAsValueRange(); + return *(indexingMaps.begin() + opOperand->getOperandNumber()); }] >, InterfaceMethod< /*desc=*/[{ - Return the `i`-th output operand. + Return the indexing map for a `result`. }], - /*retTy=*/"OpOperand*", - /*methodName=*/"getOutputOperand", - /*args=*/(ins "int64_t":$i), + /*retTy=*/"AffineMap", + /*methodName=*/"getTiedIndexingMapForResult", + /*args=*/(ins "OpResult":$result), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumOutputs()); - return &this->getOperation()->getOpOperand(getNumInputs() + i); + assert(result.getOwner() == this->getOperation()); + auto indexingMaps = + $_op.getIndexingMaps().template getAsValueRange(); + // MLIR currently does not support dependent interfaces or interface + // inheritance. By construction all ops with LinalgStructuredOpInterface + // must implement DestinationStyleOpInterface. + // TODO: reevalute the need for a cast when a better mechanism exists. + return *(indexingMaps.begin() + + cast(*this->getOperation()) + .getNumInputs() + + result.getResultNumber()); }] >, InterfaceMethod< /*desc=*/[{ - Set the `i`-th output operand. + Return the value yielded by the region corresponding to an output + `opOperand`. }], - /*retTy=*/"void", - /*methodName=*/"setOutputOperand", - /*args=*/(ins "int64_t":$i, "Value":$value), + /*retTy=*/"OpOperand *", + /*methodName=*/"getTiedYieldValue", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumOutputs()); - this->getOperation()->setOperand(getNumInputs() + i, value); + assert(opOperand->getOwner() == this->getOperation()); + // MLIR currently does not support dependent interfaces or interface + // inheritance. By construction all ops with LinalgStructuredOpInterface + // must implement DestinationStyleOpInterface. + // TODO: reevalute the need for a cast when a better mechanism exists. + int64_t resultIndex = + opOperand->getOperandNumber() - + cast(*this->getOperation()) + .getNumInputs(); + assert(resultIndex >= 0 && + resultIndex < this->getOperation()->getNumResults()); + Operation *yieldOp = getBlock()->getTerminator(); + return &yieldOp->getOpOperand(resultIndex); }] >, + //===------------------------------------------------------------------===// + // Other interface methods. + //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the subset of output operands that are of buffer type. + Return the single block constituting the body of the operation by + calling the getBody method on the concrete operation. }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputBufferOperands", + /*retTy=*/"Block*", + /*methodName=*/"getBlock", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumOutputs()); - llvm::copy_if(getOutputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; + // Assume the concrete operation implements the + // SingleBlockImplicitTerminator trait. + return $_op.getBody(); }] >, InterfaceMethod< /*desc=*/[{ - Return the subset of output operands that are of tensor type. + Return the iterator types attribute within the current operation. }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputTensorOperands", + /*retTy=*/"ArrayAttr", + /*methodName=*/"iterator_types", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumOutputs()); - llvm::copy_if(getOutputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; + return $_op.iterator_types(); }] >, InterfaceMethod< /*desc=*/[{ - Return the types of the subset of output operands that are of buffer type. + Return true if the indexing map is depending on the current op instance. + This means that the indexing map is dynamically synthesized by using the + op instance's concrete attributes, instead of being static for all + instances of the same op kind. }], - /*retTy=*/"SmallVector", - /*methodName=*/"getOutputBufferTypes", + /*retTy=*/"bool", + /*methodName=*/"hasDynamicIndexingMaps", /*args=*/(ins), /*methodBody=*/"", - /*defaultImplementation=*/[{ - SmallVector result; - result.reserve(getNumOutputs()); - llvm::transform(getOutputBufferOperands(), - std::back_inserter(result), - [](OpOperand *opOperands) { - return opOperands->get().getType().cast(); - }); - return result; - }] + /*defaultImplementation=*/[{ return false; }] >, InterfaceMethod< /*desc=*/[{ - Return the types of the subset of output operands that are of tensor type. + Verify all attributes used by indexing maps are valid. }], - /*retTy=*/"SmallVector", - /*methodName=*/"getOutputTensorTypes", + /*retTy=*/"LogicalResult", + /*methodName=*/"verifyIndexingMapRequiredAttributes", /*args=*/(ins), /*methodBody=*/"", - /*defaultImplementation=*/[{ - SmallVector result; - result.reserve(getNumOutputs()); - llvm::transform(getOutputTensorOperands(), - std::back_inserter(result), - [](OpOperand *opOperands) { - return opOperands->get().getType().cast(); - }); - return result; - }] + /*defaultImplementation=*/[{ return success(); }] >, - //===------------------------------------------------------------------===// - // Input and Output arguments handling. - //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the range over input and output operands. + Return the indexing maps attribute within the current operation. }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputAndOutputOperands", + /*retTy=*/"ArrayAttr", + /*methodName=*/"getIndexingMaps" + >, + InterfaceMethod< + /*desc=*/[{ + Return the indexing maps within the current operation. + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getIndexingMapsArray", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - int64_t numInputsAndOutputs = getNumInputsAndOutputs(); - OpOperandVector result; - result.reserve(numInputsAndOutputs); - llvm::transform( - this->getOperation()->getOpOperands(), - std::back_inserter(result), - [](OpOperand &opOperand) { return &opOperand; }); - return result; + auto range = $_op.getIndexingMaps() + .template getAsValueRange(); + return {range.begin(), range.end()}; }] >, InterfaceMethod< /*desc=*/[{ - Return true if the payload uses the value loaded from `opOperand`. This - is useful to avoid loading from "write-only" memory that may be - uninitialized, as well as properly cloning "read-write" operands. + Return true if any of the operands has a dynamic shape. }], /*retTy=*/"bool", - /*methodName=*/"payloadUsesValueFromOperand", - /*args=*/(ins "OpOperand *":$opOperand), + /*methodName=*/"hasDynamicShape", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - unsigned bbArgNumber = opOperand->getOperandNumber(); - // Init tensors have uses. - return !getBlock()->getArgument(bbArgNumber).use_empty(); + return llvm::any_of(getStaticShape(), ShapedType::isDynamic); }] >, InterfaceMethod< /*desc=*/[{ - Return true if `opOperand` is an input tensor. + Return the name registered for this op when lowering to an external + library call. }], - /*retTy=*/"bool", - /*methodName=*/"isInputTensor", - /*args=*/(ins "OpOperand *":$opOperand), + /*retTy=*/"std::string", + /*methodName=*/"getLibraryCallName", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (!opOperand->get().getType().template isa()) - return false; - if (opOperand->getOperandNumber() < $_op.getNumInputs()) - return true; - return false; + return $_op.getLibraryCallName(); }] >, InterfaceMethod< /*desc=*/[{ - Return true if `opOperand` is an output tensor. + Return whether the op accesses the iteration indices. }], /*retTy=*/"bool", - /*methodName=*/"isOutputTensor", - /*args=*/(ins "OpOperand *":$opOperand), + /*methodName=*/"hasIndexSemantics", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" + >, + //===------------------------------------------------------------------===// + // Linalg generalization hooks. + //===------------------------------------------------------------------===// + InterfaceMethod< + /*desc=*/[{ + Hook to provide a custom AffineMap used to compute all the operand + subshapes given loop bounds. This is used to answer the question: "given + an iteration space over the codomain, what are the subshapes of the + operands involved in the computation". + The default behavior is to just concatenate all the indexing maps. + A custom AffineMap allows providing a map that can be used to + compute subshapes even in cases where the concatenation of indexing maps + (i.e. the data traversal order) is not a simple permutation of the loop + traversal order. It is then possible to define ops with skewed data + traversal order for which we can still easily compute hyperrectangular + loop bounds and subviews. + }], + /*retTy=*/"AffineMap", + /*methodName=*/"getLoopsToShapesMap", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (!opOperand->get().getType().template isa()) - return false; - if (opOperand->getOperandNumber() >= $_op.getNumInputs()) - return true; - return false; + auto maps = $_op.getIndexingMapsArray(); + return concatAffineMaps(maps); }] >, InterfaceMethod< /*desc=*/[{ - Return true if `opOperand` is an init tensor. This is true when it is - an output tensor operand whose value is used in the payload region. + Hook to provide a custom AffineMap used to construct the + hyperrectangular loop iteration space given all the operand subshapes. + This is used to answer the question: + "Given a list of operand ranges, what is the subportion of the iteration + space involved in the computation". + This is the inverse problem of `getLoopsToShapesMap`. + Return the empty AffineMap when such an AffineMap cannot be constructed. + The default behavior is based on a very simple inference procedure that + only works with permutation affine maps. + A more advanced Tensor-Comprehension like inference is possible but has + proven to be ambiguous in unfavorable case. + A safer and more robust alternative is to allow each op to define + its own AffineMap. + }], + /*retTy=*/"AffineMap", + /*methodName=*/"getShapesToLoopsMap", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return inversePermutation(getLoopsToShapesMap()); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Checks if the given operands can be dropped, and the remaining + operands can still compute the bounds of the op. }], /*retTy=*/"bool", - /*methodName=*/"isInitTensor", - /*args=*/(ins "OpOperand *":$opOperand), + /*methodName=*/"canOpOperandsBeDropped", + /*args=*/(ins "ArrayRef":$droppedOperands), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (!$_op.isOutputTensor(opOperand)) - return false; - return payloadUsesValueFromOperand(opOperand); + return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands); }] >, InterfaceMethod< /*desc=*/[{ - Return the `opOperand` rank or zero for scalars. + Like `getShape`, but only returns statically-known information, without + generating any new IR. For each shape dimension, returns >=0 if that + dimension is statically known, or ShapeType::kDynamicSize otherwise. }], - /*retTy=*/"int64_t", - /*methodName=*/"getRank", - /*args=*/(ins "OpOperand*":$opOperand), + /*retTy=*/"SmallVector", + /*methodName=*/"getStaticShape", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - if (auto shapedType = - opOperand->get().getType().template dyn_cast()) - return shapedType.getRank(); - return 0; + SmallVector res; + // MLIR currently does not support dependent interfaces or interface + // inheritance. By construction all ops with LinalgStructuredOpInterface + // must implement DestinationStyleOpInterface. + // TODO: reevalute the need for a cast when a better mechanism exists. + auto iface = cast(*this->getOperation()); + for (OpOperand *opOperand : iface.getInputAndOutputOperands()) + llvm::append_range(res, getShape(opOperand)); + return res; }] >, InterfaceMethod< /*desc=*/[{ - Return the output block arguments of the region. + Returns the statically-known loop ranges. Composes + `getShapesToLoopsMap()` with the result of `getStaticShape`. + Returns ShapeType::kDynamicSize for non-statically-known loop ranges. + This is expected to be called by a valid Linalg op + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getStaticLoopRanges", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + SmallVector viewSizes = getStaticShape(); + AffineMap invertedMap = getShapesToLoopsMap(); + assert(invertedMap && "expected a valid Linalg op to call the method"); + return invertedMap.compose(viewSizes); + }] + >, + //===------------------------------------------------------------------===// + // Other static interface methods. + //===------------------------------------------------------------------===// + StaticInterfaceMethod< + /*desc=*/[{ + Returns the region builder for constructing the body for linalg.generic. + Returns a null function if this named op does not define a region + builder. + }], + /*retTy=*/"std::function)>", + /*methodName=*/"getRegionBuilder", + (ins), + [{ return ConcreteOp::getRegionBuilder(); }] + >, + InterfaceMethod< + /*desc=*/[{ + Return true if all the indexing maps are projected permutations. + Otherwise return false. + }], + /*retTy=*/"bool", + /*methodName=*/"hasOnlyProjectedPermutations", + (ins), + [{ + return llvm::all_of($_op.getIndexingMapsArray(), + [](AffineMap map) { return map.isProjectedPermutation(); }); + }] + > + ]; + + let extraClassDeclaration = [{ + /// Return the flat list of all operand dimension sizes in the order they + /// appear in the operands. + SmallVector createFlatListOfOperandDims(OpBuilder &, Location); + + /// Return the flat list of all operands' static dimension sizes in the + /// order they appear in the operands. All operand dimension sizes have to + /// be statically known. + SmallVector createFlatListOfOperandStaticDims(); + + /// Create the loop ranges to materialize the computation over the current + /// operands. This is done by applying `getShapesToLoopsMap` to + /// `createFlatListOfOperandDims`. + SmallVector createLoopRanges(OpBuilder &b, Location loc); + + /// Compute the static loop sizes necessary to vectorize the computation. + /// This is done by applying `getShapesToLoopsMap` to + /// `createFlatListOfOperandStaticDims`. + SmallVector computeStaticLoopSizes(); + + /// Returns the value that expresses the shape of the output in terms of + /// shape of the input operands where possible + LogicalResult reifyResultShapes(OpBuilder &b, + ReifiedRankedShapedTypeDims &reifiedReturnShapes); + + // TODO: Remove once prefixing is flipped. + ArrayAttr getIteratorTypes() { return iterator_types(); } + + //========================================================================// + // Forwarding functions to access interface methods from the + // DestinationStyleOpInterface. + // MLIR currently does not support dependent interfaces or interface + // inheritance. By construction all ops with LinalgStructuredOpInterface + // must implement DestinationStyleOpInterface. + // TODO: reevalute the need for a cast when a better mechanism exists. + //========================================================================// + + ValueRange inputs() { + return cast(*this->getOperation()).inputs(); + } + + int64_t getNumInputs() { + return cast(*this->getOperation()) + .getNumInputs(); + } + + ValueRange outputs() { + return cast(*this->getOperation()).outputs(); + } + + int64_t getNumOutputs() { + return cast(*this->getOperation()) + .getNumOutputs(); + } + + int64_t getNumInputsAndOutputs() { + return cast(*this->getOperation()) + .getNumInputsAndOutputs(); + } + + OpOperandVector getInputOperands() { + return cast(*this->getOperation()) + .getInputOperands(); + } + + OpOperand *getInputOperand(int64_t i) { + return cast(*this->getOperation()) + .getInputOperand(i); + } + + OpOperandVector getInputBufferOperands() { + return cast(*this->getOperation()) + .getInputBufferOperands(); + } + + OpOperandVector getInputTensorOperands() { + return cast(*this->getOperation()) + .getInputTensorOperands(); + } + + OpOperandVector getOutputOperands() { + return cast(*this->getOperation()) + .getOutputOperands(); + } + + OpOperand *getOutputOperand(int64_t i) { + return cast(*this->getOperation()) + .getOutputOperand(i); + } + + void setOutputOperand(int64_t i, Value value) { + return cast(*this->getOperation()) + .setOutputOperand(i, value); + } + + OpOperandVector getOutputBufferOperands() { + return cast(*this->getOperation()) + .getOutputBufferOperands(); + } + + OpOperandVector getOutputTensorOperands() { + return cast(*this->getOperation()) + .getOutputTensorOperands(); + } + + SmallVector getOutputBufferTypes() { + return cast(*this->getOperation()) + .getOutputBufferTypes(); + } + + SmallVector getOutputTensorTypes() { + return cast(*this->getOperation()) + .getOutputTensorTypes(); + } + + OpOperandVector getInputAndOutputOperands() { + return cast(*this->getOperation()) + .getInputAndOutputOperands(); + } + + bool isInputTensor(OpOperand *opOperand) { + return cast(*this->getOperation()) + .isInputTensor(opOperand); + } + + bool isOutputTensor(OpOperand *opOperand) { + return cast(*this->getOperation()) + .isOutputTensor(opOperand); + } + + bool isScalar(OpOperand *opOperand) { + return cast(*this->getOperation()) + .isScalar(opOperand); + } + + OpResult getTiedOpResult(OpOperand *opOperand) { + return cast(*this->getOperation()) + .getTiedOpResult(opOperand); + } + + bool hasBufferSemantics() { + return cast(*this->getOperation()) + .hasBufferSemantics(); + } + + bool hasTensorSemantics() { + return cast(*this->getOperation()) + .hasTensorSemantics(); + } + + Operation *clone(OpBuilder & b, Location loc, TypeRange resultTypes, + ValueRange operands) { + return cast(*this->getOperation()) + .clone(b, loc, resultTypes, operands); + } + + Operation *cloneWithoutRegions(OpBuilder & b, Location loc, + TypeRange resultTypes, ValueRange operands) { + return cast(*this->getOperation()) + .cloneWithoutRegions(b, loc, resultTypes, operands); + } + + //========================================================================// + // Helper functions to mutate the `operand_segment_sizes` attribute. + // These are useful when cloning and changing operand types. + //========================================================================// + void setNumInputs(unsigned num) { setOperandSegmentAt(0, num); } + void setNumOutputBuffers(unsigned num) { setOperandSegmentAt(1, num); } + + private: + void setOperandSegmentAt(unsigned idx, unsigned val) { + auto attr = (*this)->getAttr("operand_segment_sizes") + .cast(); + unsigned i = 0; + auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32), + [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; }); + getOperation()->setAttr("operand_segment_sizes", newAttr); + } + }]; + + let verify = [{ return detail::verifyStructuredOpInterface($_op); }]; + let verifyWithRegions = 1; +} + +// The 'DestinationStyleOpInterface' provides access to the methods relevant +// for destination-style ops. A destination-style operation has 'n' input +// arguments and 'm' output arguments. Each op that wants to implement +// DestinationStyleOpInterface needs to define inputs() and outputs() methods. +def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { + let cppNamespace = "::mlir::linalg"; + let methods = [ + //===------------------------------------------------------------------===// + // Num input/output arguments handling. + //===------------------------------------------------------------------===// + // `inputs` must be defined by each op that wants to implement the + // DestinationStyleOpInterface. + InterfaceMethod< + /*desc=*/[{ + Return the input shape operands. }], - /*retTy=*/"Block::BlockArgListType", - /*methodName=*/"getRegionOutputArgs", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return getBlock()->getArguments().take_back(this->getNumOutputs()); - }] + /*retTy=*/"ValueRange", + /*methodName=*/"inputs", + /*args=*/(ins) >, + // These special methods rely on `inputs` and `outputs` being defined by + // each op that wants to implement the DestinationStyleOpInterface. InterfaceMethod< /*desc=*/[{ - Return the `opOperand` shape or an empty vector for scalars. + Return the number of inputs. }], - /*retTy=*/"ArrayRef", - /*methodName=*/"getShape", - /*args=*/(ins "OpOperand*":$opOperand), + /*retTy=*/"int64_t", + /*methodName=*/"getNumInputs", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - if (auto shapedType = - opOperand->get().getType().template dyn_cast()) - return shapedType.getShape(); - return {}; + return $_op.getInputs().size(); }] >, + // `outputs` must be defined by each op that wants to implement the + // DestinationStyleOpInterface. InterfaceMethod< /*desc=*/[{ - Return true if the `opOperand` is a scalar value. + Return the output shape operands. }], - /*retTy=*/"bool", - /*methodName=*/"isScalar", - /*args=*/(ins "OpOperand*":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - return !opOperand->get().getType().template isa(); - }] + /*retTy=*/"ValueRange", + /*methodName=*/"outputs", + /*args=*/(ins) >, InterfaceMethod< /*desc=*/[{ - Return the block argument for an `opOperand`. + Return the number of outputs. }], - /*retTy=*/"BlockArgument", - /*methodName=*/"getTiedBlockArgument", - /*args=*/(ins "OpOperand *":$opOperand), + /*retTy=*/"int64_t", + /*methodName=*/"getNumOutputs", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - return getBlock()->getArgument(opOperand->getOperandNumber()); + return $_op.outputs().size(); }] >, InterfaceMethod< /*desc=*/[{ - Return the operand for a `blockArgument`. + Return the number of inputs and outputs. }], - /*retTy=*/"OpOperand *", - /*methodName=*/"getTiedOpOperand", - /*args=*/(ins "BlockArgument":$blockArgument), + /*retTy=*/"int64_t", + /*methodName=*/"getNumInputsAndOutputs", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(blockArgument.getOwner() == getBlock()); - return &this->getOperation()->getOpOperand( - blockArgument.getArgNumber()); + return this->getOperation()->getNumOperands(); }] >, + //===------------------------------------------------------------------===// + // Input operands handling. + //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the input or output indexing map for `opOperand`. + Return the input operands. }], - /*retTy=*/"AffineMap", - /*methodName=*/"getTiedIndexingMap", - /*args=*/(ins "OpOperand*":$opOperand), + /*retTy=*/"OpOperandVector", + /*methodName=*/"getInputOperands", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - auto indexingMaps = - $_op.getIndexingMaps().template getAsValueRange(); - return *(indexingMaps.begin() + opOperand->getOperandNumber()); + int64_t numInputs = getNumInputs(); + OpOperandVector result; + result.reserve(numInputs); + llvm::transform( + this->getOperation()->getOpOperands().take_front(numInputs), + std::back_inserter(result), + [](OpOperand &opOperand) { return &opOperand; }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return the indexing map for a `result`. + Return the `i`-th input operand. }], - /*retTy=*/"AffineMap", - /*methodName=*/"getTiedIndexingMapForResult", - /*args=*/(ins "OpResult":$result), + /*retTy=*/"OpOperand*", + /*methodName=*/"getInputOperand", + /*args=*/(ins "int64_t":$i), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(result.getOwner() == this->getOperation()); - auto indexingMaps = - $_op.getIndexingMaps().template getAsValueRange(); - return *(indexingMaps.begin() + getNumInputs() + - result.getResultNumber()); + assert(i >= 0 && i < getNumInputs()); + return &this->getOperation()->getOpOperand(i); }] >, InterfaceMethod< /*desc=*/[{ - Return the result tied to `opOperand`. + Return the subset of input operands that are of buffer type. }], - /*retTy=*/"OpResult", - /*methodName=*/"getTiedOpResult", - /*args=*/(ins "OpOperand*":$opOperand), + /*retTy=*/"OpOperandVector", + /*methodName=*/"getInputBufferOperands", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - int64_t resultIndex = opOperand->getOperandNumber() - getNumInputs(); - assert(resultIndex >= 0 && - resultIndex < this->getOperation()->getNumResults() ); - return this->getOperation()->getResult(resultIndex); + OpOperandVector result; + result.reserve(getNumInputs()); + llvm::copy_if(getInputOperands(), + std::back_inserter(result), + [](OpOperand *opOperand) { + return opOperand->get().getType().template isa(); + }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return the value yielded by the region corresponding to an output - `opOperand`. + Return the subset of input operands that are of tensor type. }], - /*retTy=*/"OpOperand *", - /*methodName=*/"getTiedYieldValue", - /*args=*/(ins "OpOperand*":$opOperand), + /*retTy=*/"OpOperandVector", + /*methodName=*/"getInputTensorOperands", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - int64_t resultIndex = opOperand->getOperandNumber() - getNumInputs(); - assert(resultIndex >= 0 && - resultIndex < this->getOperation()->getNumResults()); - Operation *yieldOp = getBlock()->getTerminator(); - return &yieldOp->getOpOperand(resultIndex); + OpOperandVector result; + result.reserve(getNumInputs()); + llvm::copy_if(getInputOperands(), + std::back_inserter(result), + [](OpOperand *opOperand) { + return opOperand->get().getType().template isa(); + }); + return result; }] >, //===------------------------------------------------------------------===// - // Other interface methods. + // Output operands handling. //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the single block constituting the body of the operation by - calling the getBody method on the concrete operation. + Return the output operands. }], - /*retTy=*/"Block*", - /*methodName=*/"getBlock", + /*retTy=*/"OpOperandVector", + /*methodName=*/"getOutputOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - // Assume the concrete operation implements the - // SingleBlockImplicitTerminator trait. - return $_op.getBody(); + int64_t numOutputs = getNumOutputs(); + OpOperandVector result; + result.reserve(numOutputs); + llvm::transform( + this->getOperation()->getOpOperands() + .take_back(numOutputs), + std::back_inserter(result), + [](OpOperand &opOperand) { return &opOperand; }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return the iterator types attribute within the current operation. + Return the `i`-th output operand. }], - /*retTy=*/"ArrayAttr", - /*methodName=*/"iterator_types", - /*args=*/(ins), + /*retTy=*/"OpOperand*", + /*methodName=*/"getOutputOperand", + /*args=*/(ins "int64_t":$i), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.iterator_types(); + assert(i >= 0 && i < getNumOutputs()); + return &this->getOperation()->getOpOperand(getNumInputs() + i); }] >, InterfaceMethod< /*desc=*/[{ - Return true if the indexing map is depending on the current op instance. - This means that the indexing map is dynamically synthesized by using the - op instance's concrete attributes, instead of being static for all - instances of the same op kind. - }], - /*retTy=*/"bool", - /*methodName=*/"hasDynamicIndexingMaps", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ return false; }] - >, - InterfaceMethod< - /*desc=*/[{ - Verify all attributes used by indexing maps are valid. + Set the `i`-th output operand. }], - /*retTy=*/"LogicalResult", - /*methodName=*/"verifyIndexingMapRequiredAttributes", - /*args=*/(ins), + /*retTy=*/"void", + /*methodName=*/"setOutputOperand", + /*args=*/(ins "int64_t":$i, "Value":$value), /*methodBody=*/"", - /*defaultImplementation=*/[{ return success(); }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the indexing maps attribute within the current operation. - }], - /*retTy=*/"ArrayAttr", - /*methodName=*/"getIndexingMaps" + /*defaultImplementation=*/[{ + assert(i >= 0 && i < getNumOutputs()); + this->getOperation()->setOperand(getNumInputs() + i, value); + }] >, InterfaceMethod< /*desc=*/[{ - Return the indexing maps within the current operation. + Return the subset of output operands that are of buffer type. }], - /*retTy=*/"SmallVector", - /*methodName=*/"getIndexingMapsArray", + /*retTy=*/"OpOperandVector", + /*methodName=*/"getOutputBufferOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - auto range = $_op.getIndexingMaps() - .template getAsValueRange(); - return {range.begin(), range.end()}; + OpOperandVector result; + result.reserve(getNumOutputs()); + llvm::copy_if(getOutputOperands(), + std::back_inserter(result), + [](OpOperand *opOperand) { + return opOperand->get().getType().template isa(); + }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return true if any of the operands has a dynamic shape. + Return the subset of output operands that are of tensor type. }], - /*retTy=*/"bool", - /*methodName=*/"hasDynamicShape", + /*retTy=*/"OpOperandVector", + /*methodName=*/"getOutputTensorOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return llvm::any_of(getStaticShape(), ShapedType::isDynamic); + OpOperandVector result; + result.reserve(getNumOutputs()); + llvm::copy_if(getOutputOperands(), + std::back_inserter(result), + [](OpOperand *opOperand) { + return opOperand->get().getType().template isa(); + }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return whether the op has only MemRef input and outputs. + Return the types of the subset of output operands that are of buffer type. }], - /*retTy=*/"bool", - /*methodName=*/"hasBufferSemantics", + /*retTy=*/"SmallVector", + /*methodName=*/"getOutputBufferTypes", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return this->getOperation()->getNumResults() == 0 && - llvm::all_of(this->getOperation()->getOpOperands(), - [&](OpOperand &opOperand) { - return isScalar(&opOperand) || - opOperand.get().getType().template isa(); - }); + SmallVector result; + result.reserve(getNumOutputs()); + llvm::transform(getOutputBufferOperands(), + std::back_inserter(result), + [](OpOperand *opOperands) { + return opOperands->get().getType().cast(); + }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return whether the op has only RankedTensor input and outputs. + Return the types of the subset of output operands that are of tensor type. }], - /*retTy=*/"bool", - /*methodName=*/"hasTensorSemantics", + /*retTy=*/"SmallVector", + /*methodName=*/"getOutputTensorTypes", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return llvm::all_of(this->getOperation()->getOpOperands(), - [&](OpOperand &opOperand) { - return isScalar(&opOperand) || - opOperand.get().getType().template isa(); + SmallVector result; + result.reserve(getNumOutputs()); + llvm::transform(getOutputTensorOperands(), + std::back_inserter(result), + [](OpOperand *opOperands) { + return opOperands->get().getType().cast(); }); + return result; }] >, + //===------------------------------------------------------------------===// + // Input and Output arguments handling. + //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the name registered for this op when lowering to an external - library call. + Return the range over input and output operands. }], - /*retTy=*/"std::string", - /*methodName=*/"getLibraryCallName", + /*retTy=*/"OpOperandVector", + /*methodName=*/"getInputAndOutputOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.getLibraryCallName(); + int64_t numInputsAndOutputs = getNumInputsAndOutputs(); + OpOperandVector result; + result.reserve(numInputsAndOutputs); + llvm::transform( + this->getOperation()->getOpOperands(), + std::back_inserter(result), + [](OpOperand &opOperand) { return &opOperand; }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return whether the op accesses the iteration indices. + Return true if `opOperand` is an input tensor. }], /*retTy=*/"bool", - /*methodName=*/"hasIndexSemantics", - /*args=*/(ins), + /*methodName=*/"isInputTensor", + /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", - /*defaultImplementation=*/"" + /*defaultImplementation=*/[{ + if (!opOperand->get().getType().template isa()) + return false; + if (opOperand->getOperandNumber() < $_op.getNumInputs()) + return true; + return false; + }] >, - //===------------------------------------------------------------------===// - // Linalg generalization hooks. - //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Hook to provide a custom AffineMap used to compute all the operand - subshapes given loop bounds. This is used to answer the question: "given - an iteration space over the codomain, what are the subshapes of the - operands involved in the computation". - The default behavior is to just concatenate all the indexing maps. - A custom AffineMap allows providing a map that can be used to - compute subshapes even in cases where the concatenation of indexing maps - (i.e. the data traversal order) is not a simple permutation of the loop - traversal order. It is then possible to define ops with skewed data - traversal order for which we can still easily compute hyperrectangular - loop bounds and subviews. + Return true if `opOperand` is an output tensor. }], - /*retTy=*/"AffineMap", - /*methodName=*/"getLoopsToShapesMap", - /*args=*/(ins), + /*retTy=*/"bool", + /*methodName=*/"isOutputTensor", + /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - auto maps = $_op.getIndexingMapsArray(); - return concatAffineMaps(maps); + if (!opOperand->get().getType().template isa()) + return false; + if (opOperand->getOperandNumber() >= $_op.getNumInputs()) + return true; + return false; }] >, InterfaceMethod< /*desc=*/[{ - Hook to provide a custom AffineMap used to construct the - hyperrectangular loop iteration space given all the operand subshapes. - This is used to answer the question: - "Given a list of operand ranges, what is the subportion of the iteration - space involved in the computation". - This is the inverse problem of `getLoopsToShapesMap`. - Return the empty AffineMap when such an AffineMap cannot be constructed. - The default behavior is based on a very simple inference procedure that - only works with permutation affine maps. - A more advanced Tensor-Comprehension like inference is possible but has - proven to be ambiguous in unfavorable case. - A safer and more robust alternative is to allow each op to define - its own AffineMap. + Return true if the `opOperand` is a scalar value. }], - /*retTy=*/"AffineMap", - /*methodName=*/"getShapesToLoopsMap", - /*args=*/(ins), + /*retTy=*/"bool", + /*methodName=*/"isScalar", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - return inversePermutation(getLoopsToShapesMap()); + assert(opOperand->getOwner() == this->getOperation()); + return !opOperand->get().getType().template isa(); }] >, InterfaceMethod< /*desc=*/[{ - Checks if the given operands can be dropped, and the remaining - operands can still compute the bounds of the op. + Return the result tied to `opOperand`. }], - /*retTy=*/"bool", - /*methodName=*/"canOpOperandsBeDropped", - /*args=*/(ins "ArrayRef":$droppedOperands), + /*retTy=*/"OpResult", + /*methodName=*/"getTiedOpResult", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands); + assert(opOperand->getOwner() == this->getOperation()); + int64_t resultIndex = opOperand->getOperandNumber() - getNumInputs(); + assert(resultIndex >= 0 && + resultIndex < this->getOperation()->getNumResults() ); + return this->getOperation()->getResult(resultIndex); }] >, + //===------------------------------------------------------------------===// + // Other interface methods. + //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Like `getShape`, but only returns statically-known information, without - generating any new IR. For each shape dimension, returns >=0 if that - dimension is statically known, or ShapeType::kDynamicSize otherwise. + Return whether the op has only MemRef input and outputs. }], - /*retTy=*/"SmallVector", - /*methodName=*/"getStaticShape", + /*retTy=*/"bool", + /*methodName=*/"hasBufferSemantics", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - SmallVector res; - for (OpOperand *opOperand : getInputAndOutputOperands()) - llvm::append_range(res, getShape(opOperand)); - return res; + return this->getOperation()->getNumResults() == 0 && + llvm::all_of(this->getOperation()->getOpOperands(), + [&](OpOperand &opOperand) { + return isScalar(&opOperand) || + opOperand.get().getType().template isa(); + }); }] >, InterfaceMethod< /*desc=*/[{ - Returns the statically-known loop ranges. Composes - `getShapesToLoopsMap()` with the result of `getStaticShape`. - Returns ShapeType::kDynamicSize for non-statically-known loop ranges. - This is expected to be called by a valid Linalg op + Return whether the op has only RankedTensor input and outputs. }], - /*retTy=*/"SmallVector", - /*methodName=*/"getStaticLoopRanges", + /*retTy=*/"bool", + /*methodName=*/"hasTensorSemantics", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - SmallVector viewSizes = getStaticShape(); - AffineMap invertedMap = getShapesToLoopsMap(); - assert(invertedMap && "expected a valid Linalg op to call the method"); - return invertedMap.compose(viewSizes); + return llvm::all_of(this->getOperation()->getOpOperands(), + [&](OpOperand &opOperand) { + return isScalar(&opOperand) || + opOperand.get().getType().template isa(); + }); }] >, //===------------------------------------------------------------------===// @@ -1042,27 +1292,6 @@ return b.create(state); }] >, - InterfaceMethod< - /*desc=*/[{ - Clone the current operation with the given location, operands - and BlockAndValueMapping. This is used to abstract away the - optional underlying region creation. This does not change the - balance between input, output_buffer and init_tensors - operands. - }], - /*retTy=*/"Operation *", - /*methodName=*/"cloneWithMapper", - (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, - "ValueRange":$operands, "BlockAndValueMapping &":$bvm), - [{ - OperationState state( - loc, ConcreteOp::getOperationName(), operands, resultTypes, - $_op->getAttrs()); - for (Region &r : $_op->getRegions()) - r.cloneInto(state.addRegion(), bvm); - return b.create(state); - }] - >, InterfaceMethod< /*desc=*/[{ Clone the current operation with the given location, operands @@ -1083,80 +1312,10 @@ state.addRegion(); return b.create(state); }] - >, - StaticInterfaceMethod< - /*desc=*/[{ - Returns the region builder for constructing the body for linalg.generic. - Returns a null function if this named op does not define a region - builder. - }], - /*retTy=*/"std::function)>", - /*methodName=*/"getRegionBuilder", - (ins), - [{ return ConcreteOp::getRegionBuilder(); }] - >, - InterfaceMethod< - /*desc=*/[{ - Return true if all the indexing maps are projected permutations. - Otherwise return false. - }], - /*retTy=*/"bool", - /*methodName=*/"hasOnlyProjectedPermutations", - (ins), - [{ - return llvm::all_of($_op.getIndexingMapsArray(), - [](AffineMap map) { return map.isProjectedPermutation(); }); - }] > ]; - let extraClassDeclaration = [{ - /// Return the flat list of all operand dimension sizes in the order they - /// appear in the operands. - SmallVector createFlatListOfOperandDims(OpBuilder &, Location); - - /// Return the flat list of all operands' static dimension sizes in the - /// order they appear in the operands. All operand dimension sizes have to - /// be statically known. - SmallVector createFlatListOfOperandStaticDims(); - - /// Create the loop ranges to materialize the computation over the current - /// operands. This is done by applying `getShapesToLoopsMap` to - /// `createFlatListOfOperandDims`. - SmallVector createLoopRanges(OpBuilder &b, Location loc); - - /// Compute the static loop sizes necessary to vectorize the computation. - /// This is done by applying `getShapesToLoopsMap` to - /// `createFlatListOfOperandStaticDims`. - SmallVector computeStaticLoopSizes(); - - /// Returns the value that expresses the shape of the output in terms of - /// shape of the input operands where possible - LogicalResult reifyResultShapes(OpBuilder &b, - ReifiedRankedShapedTypeDims &reifiedReturnShapes); - - // TODO: Remove once prefixing is flipped. - ArrayAttr getIteratorTypes() { return iterator_types(); } - - //========================================================================// - // Helper functions to mutate the `operand_segment_sizes` attribute. - // These are useful when cloning and changing operand types. - //========================================================================// - void setNumInputs(unsigned num) { setOperandSegmentAt(0, num); } - void setNumOutputBuffers(unsigned num) { setOperandSegmentAt(1, num); } - - private: - void setOperandSegmentAt(unsigned idx, unsigned val) { - auto attr = (*this)->getAttr("operand_segment_sizes") - .cast(); - unsigned i = 0; - auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32), - [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; }); - getOperation()->setAttr("operand_segment_sizes", newAttr); - } - }]; - - let verify = [{ return detail::verifyStructuredOpInterface($_op); }]; + let verify = [{ return detail::verifyDestinationStyleOpInterface($_op); }]; let verifyWithRegions = 1; } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -28,6 +28,7 @@ : Op, DeclareOpInterfaceMethods, + DestinationStyleOpInterface, LinalgStructuredInterface, RegionBranchOpInterface, ReifyRankedShapedTypeOpInterface], props)> { diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -633,22 +633,6 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { LinalgOp linalgOp = cast(op); - // Expect at least one output operand. - // This means an op that constructs a tensor out of indices cannot be a - // LinalgOp at the moment. For now this will have to be a special op until we - // have output shape operands that are not tensors. - int64_t numInputs = linalgOp.getNumInputs(); - int64_t numOutputs = linalgOp.getNumOutputs(); - if (numOutputs == 0) - return op->emitOpError("expected at least one output operand"); - if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs))) - return failure(); - // Verify the number of results matches the number of output tensors. - if (op->getNumResults() != linalgOp.getOutputTensorOperands().size()) - return op->emitOpError("expected the number of results (") - << op->getNumResults() - << ") to be equal to the number of output tensors (" - << linalgOp.getOutputTensorOperands().size() << ")"; // Check all iterator types are known. auto iteratorTypesRange = @@ -699,26 +683,6 @@ SmallVector redDims; linalgOp.getReductionDims(redDims); - // Simplifying assumption: either full tensor or full buffer mode. - // This allows simpler verification of output operands vs result types - // without premature tracking of which operand is what in mixed-mode. - // TODO: relax when mixed-mode needs to pass verification. - if (!linalgOp.getOutputBufferOperands().empty() && - !linalgOp.getOutputTensorOperands().empty()) - return op->emitOpError( - "expected output operands to all have tensor type or " - "all have buffer type"); - - for (OpOperand *opOperand : linalgOp.getOutputTensorOperands()) { - OpResult result = linalgOp.getTiedOpResult(opOperand); - if (result.getType() != opOperand->get().getType()) - return op->emitOpError("expected type of operand #") - << opOperand->getOperandNumber() << " (" - << opOperand->get().getType() << ")" - << " to match type of corresponding result (" << result.getType() - << ")"; - } - // Output tensor indexing map may not depend on reduction indices. for (OpOperand *opOperand : linalgOp.getOutputOperands()) { AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand); @@ -740,36 +704,9 @@ } } - // Check the region has exactly one block. - if (linalgOp->getNumRegions() != 1 || - !llvm::hasSingleElement(linalgOp->getRegion(0))) - return op->emitOpError("expects to have 1 region with 1 block"); - if (!linalgOp.getShapesToLoopsMap()) return op->emitOpError("expected the shape-to-loops map to be non-null"); - // Simplifying assumption: bbargs match 1-1 with shape operands elemental - // types. - // TODO: once ranked shape types are plugged in, we may want to drop the - // corresponding bbargs, that can never be read from. This will be subject to - // consistency discussions (i.e. what to do with output tensors whose bbarg is - // not used). - Block &block = linalgOp->getRegion(0).front(); - - if (linalgOp.getNumInputsAndOutputs() != block.getNumArguments()) - return op->emitOpError("expected as many non-induction variable region " - "arguments as the number of input/output operands"); - - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - Type elementType = getElementTypeOrSelf(opOperand->get()); - Type argType = block.getArgument(opOperand->getOperandNumber()).getType(); - if (elementType != argType) - return op->emitOpError("expected type of bb argument #") - << opOperand->getOperandNumber() << " (" << argType << ")" - << " to match element or self type of the corresponding operand (" - << elementType << ")"; - } - // Check if given shapes match to inferred shapes. SmallVector endLoopRangeValues = linalgOp.getStaticLoopRanges(); SmallVector startLoopRangeValues(endLoopRangeValues.size(), 0); @@ -835,3 +772,75 @@ return success(); } + +LogicalResult +mlir::linalg::detail::verifyDestinationStyleOpInterface(Operation *op) { + DestinationStyleOpInterface dstStyleOp = + cast(op); + + // Expect at least one output operand. + // This means an op that constructs a tensor out of indices cannot be a + // LinalgOp at the moment. For now this will have to be a special op until we + // have output shape operands that are not tensors. + int64_t numInputs = dstStyleOp.getNumInputs(); + int64_t numOutputs = dstStyleOp.getNumOutputs(); + if (numOutputs == 0) + return op->emitOpError("expected at least one output operand"); + if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs))) + return failure(); + // Verify the number of results matches the number of output tensors. + if (op->getNumResults() != dstStyleOp.getOutputTensorOperands().size()) + return op->emitOpError("expected the number of results (") + << op->getNumResults() + << ") to be equal to the number of output tensors (" + << dstStyleOp.getOutputTensorOperands().size() << ")"; + + // Simplifying assumption: either full tensor or full buffer mode. + // This allows simpler verification of output operands vs result types + // without premature tracking of which operand is what in mixed-mode. + // TODO: relax when mixed-mode needs to pass verification. + if (!dstStyleOp.getOutputBufferOperands().empty() && + !dstStyleOp.getOutputTensorOperands().empty()) + return op->emitOpError( + "expected output operands to all have tensor type or " + "all have buffer type"); + + for (OpOperand *opOperand : dstStyleOp.getOutputTensorOperands()) { + OpResult result = dstStyleOp.getTiedOpResult(opOperand); + if (result.getType() != opOperand->get().getType()) + return op->emitOpError("expected type of operand #") + << opOperand->getOperandNumber() << " (" + << opOperand->get().getType() << ")" + << " to match type of corresponding result (" << result.getType() + << ")"; + } + + // Check the region has exactly one block. + if (dstStyleOp->getNumRegions() != 1 || + !llvm::hasSingleElement(dstStyleOp->getRegion(0))) + return op->emitOpError("expects to have 1 region with 1 block"); + + // Simplifying assumption: bbargs match 1-1 with shape operands elemental + // types. + // TODO: once ranked shape types are plugged in, we may want to drop the + // corresponding bbargs, that can never be read from. This will be subject to + // consistency discussions (i.e. what to do with output tensors whose bbarg is + // not used). + Block &block = dstStyleOp->getRegion(0).front(); + + if (dstStyleOp.getNumInputsAndOutputs() != block.getNumArguments()) + return op->emitOpError("expected as many non-induction variable region " + "arguments as the number of input/output operands"); + + for (OpOperand *opOperand : dstStyleOp.getInputAndOutputOperands()) { + Type elementType = getElementTypeOrSelf(opOperand->get()); + Type argType = block.getArgument(opOperand->getOperandNumber()).getType(); + if (elementType != argType) + return op->emitOpError("expected type of bb argument #") + << opOperand->getOperandNumber() << " (" << argType << ")" + << " to match element or self type of the corresponding operand (" + << elementType << ")"; + } + + return success(); +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2741,7 +2741,8 @@ def TestLinalgConvOp : TEST_Op<"linalg_conv_op", [AttrSizedOperandSegments, SingleBlock, - LinalgStructuredInterface, LinalgConvolutionOpInterface]> { + DestinationStyleOpInterface, LinalgStructuredInterface, + LinalgConvolutionOpInterface]> { let arguments = (ins Variadic:$inputs, Variadic:$outputs); @@ -2799,7 +2800,8 @@ def TestLinalgFillOp : TEST_Op<"linalg_fill_op", [AttrSizedOperandSegments, SingleBlock, - LinalgStructuredInterface, LinalgFillOpInterface]> { + DestinationStyleOpInterface, LinalgStructuredInterface, + LinalgFillOpInterface]> { let arguments = (ins Variadic:$inputs, Variadic:$outputs);