diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h --- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h +++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h @@ -83,8 +83,9 @@ return llvm::None; if (OpOperand *operand = opView.dyn_cast()) return owner.getTiedIndexingMap(operand); - return owner.getTiedIndexingMap(owner.getOutputOperand( - opView.get().cast().getResultNumber())); + return owner.getTiedIndexingMap( + cast(owner).getOutputOperand( + opView.get().cast().getResultNumber())); } // Return the operand number if the `opView` is an OpOperand *. Otherwise // return llvm::None. diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -286,736 +286,829 @@ getNumIterators(getReductionIteratorTypeName(), iters) == 1; }]>, //===------------------------------------------------------------------===// - // Num input/output arguments handling. + // Input and Output arguments handling. //===------------------------------------------------------------------===// - // `inputs` must be defined by each op that wants to implement the - // LinalgStructuredInterface. - InterfaceMethod< - /*desc=*/[{ - Return the input shape operands. - }], - /*retTy=*/"ValueRange", - /*methodName=*/"inputs", - /*args=*/(ins) - >, - // These special methods rely on `inputs` and `outputs` being defined by - // each op that wants to implement the LinalgStructuredInterface. InterfaceMethod< /*desc=*/[{ - Return the number of inputs. + Return true if the payload uses the value loaded from `opOperand`. This + is useful to avoid loading from "write-only" memory that may be + uninitialized, as well as properly cloning "read-write" operands. 
}], - /*retTy=*/"int64_t", - /*methodName=*/"getNumInputs", - /*args=*/(ins), + /*retTy=*/"bool", + /*methodName=*/"payloadUsesValueFromOperand", + /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.getInputs().size(); + unsigned bbArgNumber = opOperand->getOperandNumber(); + // Init tensors have uses. + return !getBlock()->getArgument(bbArgNumber).use_empty(); }] >, - // `outputs` must be defined by each op that wants to implement the - // LinalgStructuredInterface. - InterfaceMethod< - /*desc=*/[{ - Return the output shape operands. - }], - /*retTy=*/"ValueRange", - /*methodName=*/"outputs", - /*args=*/(ins) - >, InterfaceMethod< /*desc=*/[{ - Return the number of outputs. + Return true if `opOperand` is an init tensor. This is true when it is + an output tensor operand whose value is used in the payload region. }], - /*retTy=*/"int64_t", - /*methodName=*/"getNumOutputs", - /*args=*/(ins), + /*retTy=*/"bool", + /*methodName=*/"isInitTensor", + /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.outputs().size(); + if (!$_op.isOutputTensor(opOperand)) + return false; + return payloadUsesValueFromOperand(opOperand); }] >, InterfaceMethod< /*desc=*/[{ - Return the number of inputs and outputs. + Return the `opOperand` rank or zero for scalars. }], /*retTy=*/"int64_t", - /*methodName=*/"getNumInputsAndOutputs", - /*args=*/(ins), + /*methodName=*/"getRank", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - return this->getOperation()->getNumOperands(); + assert(opOperand->getOwner() == this->getOperation()); + if (auto shapedType = + opOperand->get().getType().template dyn_cast()) + return shapedType.getRank(); + return 0; }] >, - //===------------------------------------------------------------------===// - // Input operands handling. 
- //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the input operands. + Return the output block arguments of the region. }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputOperands", + /*retTy=*/"Block::BlockArgListType", + /*methodName=*/"getRegionOutputArgs", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - int64_t numInputs = getNumInputs(); - OpOperandVector result; - result.reserve(numInputs); - llvm::transform( - this->getOperation()->getOpOperands().take_front(numInputs), - std::back_inserter(result), - [](OpOperand &opOperand) { return &opOperand; }); - return result; + return getBlock()->getArguments().take_back( + cast(*this->getOperation()).getNumOutputs()); }] >, InterfaceMethod< /*desc=*/[{ - Return the `i`-th input operand. + Return the `opOperand` shape or an empty vector for scalars. }], - /*retTy=*/"OpOperand*", - /*methodName=*/"getInputOperand", - /*args=*/(ins "int64_t":$i), + /*retTy=*/"ArrayRef", + /*methodName=*/"getShape", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumInputs()); - return &this->getOperation()->getOpOperand(i); + assert(opOperand->getOwner() == this->getOperation()); + if (auto shapedType = + opOperand->get().getType().template dyn_cast()) + return shapedType.getShape(); + return {}; }] >, InterfaceMethod< /*desc=*/[{ - Return the subset of input operands that are of buffer type. + Return the block argument for an `opOperand`. 
}], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputBufferOperands", - /*args=*/(ins), + /*retTy=*/"BlockArgument", + /*methodName=*/"getTiedBlockArgument", + /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumInputs()); - llvm::copy_if(getInputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; + assert(opOperand->getOwner() == this->getOperation()); + return getBlock()->getArgument(opOperand->getOperandNumber()); }] >, InterfaceMethod< /*desc=*/[{ - Return the subset of input operands that are of tensor type. + Return the operand for a `blockArgument`. }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputTensorOperands", - /*args=*/(ins), + /*retTy=*/"OpOperand *", + /*methodName=*/"getTiedOpOperand", + /*args=*/(ins "BlockArgument":$blockArgument), /*methodBody=*/"", /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumInputs()); - llvm::copy_if(getInputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; + assert(blockArgument.getOwner() == getBlock()); + return &this->getOperation()->getOpOperand( + blockArgument.getArgNumber()); }] >, - //===------------------------------------------------------------------===// - // Output operands handling. - //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the output operands. + Return the input or output indexing map for `opOperand`. 
}], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputOperands", - /*args=*/(ins), + /*retTy=*/"AffineMap", + /*methodName=*/"getTiedIndexingMap", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - int64_t numOutputs = getNumOutputs(); - OpOperandVector result; - result.reserve(numOutputs); - llvm::transform( - this->getOperation()->getOpOperands() - .take_back(numOutputs), - std::back_inserter(result), - [](OpOperand &opOperand) { return &opOperand; }); - return result; + assert(opOperand->getOwner() == this->getOperation()); + auto indexingMaps = + $_op.getIndexingMaps().template getAsValueRange(); + return *(indexingMaps.begin() + opOperand->getOperandNumber()); }] >, InterfaceMethod< /*desc=*/[{ - Return the `i`-th output operand. + Return the indexing map for a `result`. }], - /*retTy=*/"OpOperand*", - /*methodName=*/"getOutputOperand", - /*args=*/(ins "int64_t":$i), + /*retTy=*/"AffineMap", + /*methodName=*/"getTiedIndexingMapForResult", + /*args=*/(ins "OpResult":$result), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumOutputs()); - return &this->getOperation()->getOpOperand(getNumInputs() + i); + assert(result.getOwner() == this->getOperation()); + auto indexingMaps = + $_op.getIndexingMaps().template getAsValueRange(); + return *(indexingMaps.begin() + + cast(*this->getOperation()).getNumInputs() + + result.getResultNumber()); }] >, InterfaceMethod< /*desc=*/[{ - Set the `i`-th output operand. + Return the value yielded by the region corresponding to an output + `opOperand`. 
}], - /*retTy=*/"void", - /*methodName=*/"setOutputOperand", - /*args=*/(ins "int64_t":$i, "Value":$value), + /*retTy=*/"OpOperand *", + /*methodName=*/"getTiedYieldValue", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumOutputs()); - this->getOperation()->setOperand(getNumInputs() + i, value); + assert(opOperand->getOwner() == this->getOperation()); + int64_t resultIndex = opOperand->getOperandNumber() - + cast(*this->getOperation()).getNumInputs(); + assert(resultIndex >= 0 && + resultIndex < this->getOperation()->getNumResults()); + Operation *yieldOp = getBlock()->getTerminator(); + return &yieldOp->getOpOperand(resultIndex); }] >, + //===------------------------------------------------------------------===// + // Other interface methods. + //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the subset of output operands that are of buffer type. + Return the single block constituting the body of the operation by + calling the getBody method on the concrete operation. }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputBufferOperands", + /*retTy=*/"Block*", + /*methodName=*/"getBlock", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumOutputs()); - llvm::copy_if(getOutputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; + // Assume the concrete operation implements the + // SingleBlockImplicitTerminator trait. + return $_op.getBody(); }] >, InterfaceMethod< /*desc=*/[{ - Return the subset of output operands that are of tensor type. + Return the iterator types attribute within the current operation. 
}], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputTensorOperands", + /*retTy=*/"ArrayAttr", + /*methodName=*/"iterator_types", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumOutputs()); - llvm::copy_if(getOutputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; + return $_op.iterator_types(); }] >, InterfaceMethod< /*desc=*/[{ - Return the types of the subset of output operands that are of buffer type. + Return true if the indexing map is depending on the current op instance. + This means that the indexing map is dynamically synthesized by using the + op instance's concrete attributes, instead of being static for all + instances of the same op kind. }], - /*retTy=*/"SmallVector", - /*methodName=*/"getOutputBufferTypes", + /*retTy=*/"bool", + /*methodName=*/"hasDynamicIndexingMaps", /*args=*/(ins), /*methodBody=*/"", - /*defaultImplementation=*/[{ - SmallVector result; - result.reserve(getNumOutputs()); - llvm::transform(getOutputBufferOperands(), - std::back_inserter(result), - [](OpOperand *opOperands) { - return opOperands->get().getType().cast(); - }); - return result; - }] + /*defaultImplementation=*/[{ return false; }] >, InterfaceMethod< /*desc=*/[{ - Return the types of the subset of output operands that are of tensor type. + Verify all attributes used by indexing maps are valid. 
}], - /*retTy=*/"SmallVector", - /*methodName=*/"getOutputTensorTypes", + /*retTy=*/"LogicalResult", + /*methodName=*/"verifyIndexingMapRequiredAttributes", /*args=*/(ins), /*methodBody=*/"", - /*defaultImplementation=*/[{ - SmallVector result; - result.reserve(getNumOutputs()); - llvm::transform(getOutputTensorOperands(), - std::back_inserter(result), - [](OpOperand *opOperands) { - return opOperands->get().getType().cast(); - }); - return result; - }] + /*defaultImplementation=*/[{ return success(); }] >, - //===------------------------------------------------------------------===// - // Input and Output arguments handling. - //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the range over input and output operands. + Return the indexing maps attribute within the current operation. }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputAndOutputOperands", + /*retTy=*/"ArrayAttr", + /*methodName=*/"getIndexingMaps" + >, + InterfaceMethod< + /*desc=*/[{ + Return the indexing maps within the current operation. + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getIndexingMapsArray", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - int64_t numInputsAndOutputs = getNumInputsAndOutputs(); - OpOperandVector result; - result.reserve(numInputsAndOutputs); - llvm::transform( - this->getOperation()->getOpOperands(), - std::back_inserter(result), - [](OpOperand &opOperand) { return &opOperand; }); - return result; + auto range = $_op.getIndexingMaps() + .template getAsValueRange(); + return {range.begin(), range.end()}; }] >, InterfaceMethod< /*desc=*/[{ - Return true if the payload uses the value loaded from `opOperand`. This - is useful to avoid loading from "write-only" memory that may be - uninitialized, as well as properly cloning "read-write" operands. + Return true if any of the operands has a dynamic shape. 
}], /*retTy=*/"bool", - /*methodName=*/"payloadUsesValueFromOperand", - /*args=*/(ins "OpOperand *":$opOperand), + /*methodName=*/"hasDynamicShape", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - unsigned bbArgNumber = opOperand->getOperandNumber(); - // Init tensors have uses. - return !getBlock()->getArgument(bbArgNumber).use_empty(); + return llvm::any_of(getStaticShape(), ShapedType::isDynamic); }] >, InterfaceMethod< /*desc=*/[{ - Return true if `opOperand` is an input tensor. + Return the name registered for this op when lowering to an external + library call. }], - /*retTy=*/"bool", - /*methodName=*/"isInputTensor", - /*args=*/(ins "OpOperand *":$opOperand), + /*retTy=*/"std::string", + /*methodName=*/"getLibraryCallName", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (!opOperand->get().getType().template isa()) - return false; - if (opOperand->getOperandNumber() < $_op.getNumInputs()) - return true; - return false; + return $_op.getLibraryCallName(); }] >, InterfaceMethod< /*desc=*/[{ - Return true if `opOperand` is an output tensor. + Return whether the op accesses the iteration indices. }], /*retTy=*/"bool", - /*methodName=*/"isOutputTensor", - /*args=*/(ins "OpOperand *":$opOperand), + /*methodName=*/"hasIndexSemantics", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" + >, + //===------------------------------------------------------------------===// + // Linalg generalization hooks. + //===------------------------------------------------------------------===// + InterfaceMethod< + /*desc=*/[{ + Hook to provide a custom AffineMap used to compute all the operand + subshapes given loop bounds. This is used to answer the question: "given + an iteration space over the codomain, what are the subshapes of the + operands involved in the computation". + The default behavior is to just concatenate all the indexing maps. 
+ A custom AffineMap allows providing a map that can be used to + compute subshapes even in cases where the concatenation of indexing maps + (i.e. the data traversal order) is not a simple permutation of the loop + traversal order. It is then possible to define ops with skewed data + traversal order for which we can still easily compute hyperrectangular + loop bounds and subviews. + }], + /*retTy=*/"AffineMap", + /*methodName=*/"getLoopsToShapesMap", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (!opOperand->get().getType().template isa()) - return false; - if (opOperand->getOperandNumber() >= $_op.getNumInputs()) - return true; - return false; + auto maps = $_op.getIndexingMapsArray(); + return concatAffineMaps(maps); }] >, InterfaceMethod< /*desc=*/[{ - Return true if `opOperand` is an init tensor. This is true when it is - an output tensor operand whose value is used in the payload region. + Hook to provide a custom AffineMap used to construct the + hyperrectangular loop iteration space given all the operand subshapes. + This is used to answer the question: + "Given a list of operand ranges, what is the subportion of the iteration + space involved in the computation". + This is the inverse problem of `getLoopsToShapesMap`. + Return the empty AffineMap when such an AffineMap cannot be constructed. + The default behavior is based on a very simple inference procedure that + only works with permutation affine maps. + A more advanced Tensor-Comprehension like inference is possible but has + proven to be ambiguous in unfavorable case. + A safer and more robust alternative is to allow each op to define + its own AffineMap. 
}], - /*retTy=*/"bool", - /*methodName=*/"isInitTensor", - /*args=*/(ins "OpOperand *":$opOperand), + /*retTy=*/"AffineMap", + /*methodName=*/"getShapesToLoopsMap", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (!$_op.isOutputTensor(opOperand)) - return false; - return payloadUsesValueFromOperand(opOperand); + return inversePermutation(getLoopsToShapesMap()); }] >, InterfaceMethod< /*desc=*/[{ - Return the `opOperand` rank or zero for scalars. + Checks if the given operands can be dropped, and the remaining + operands can still compute the bounds of the op. }], - /*retTy=*/"int64_t", - /*methodName=*/"getRank", - /*args=*/(ins "OpOperand*":$opOperand), + /*retTy=*/"bool", + /*methodName=*/"canOpOperandsBeDropped", + /*args=*/(ins "ArrayRef":$droppedOperands), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - if (auto shapedType = - opOperand->get().getType().template dyn_cast()) - return shapedType.getRank(); - return 0; + return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands); }] >, InterfaceMethod< /*desc=*/[{ - Return the output block arguments of the region. + Like `getShape`, but only returns statically-known information, without + generating any new IR. For each shape dimension, returns >=0 if that + dimension is statically known, or ShapedType::kDynamicSize otherwise. }], - /*retTy=*/"Block::BlockArgListType", - /*methodName=*/"getRegionOutputArgs", + /*retTy=*/"SmallVector", + /*methodName=*/"getStaticShape", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return getBlock()->getArguments().take_back(this->getNumOutputs()); + SmallVector res; + auto iface = cast(*this->getOperation()); + for (OpOperand *opOperand : iface.getInputAndOutputOperands()) + llvm::append_range(res, getShape(opOperand)); + return res; }] >, InterfaceMethod< /*desc=*/[{ - Return the `opOperand` shape or an empty vector for scalars. + Returns the statically-known loop ranges.
Composes + `getShapesToLoopsMap()` with the result of `getStaticShape`. + Returns ShapedType::kDynamicSize for non-statically-known loop ranges. + This is expected to be called by a valid Linalg op. }], - /*retTy=*/"ArrayRef", - /*methodName=*/"getShape", - /*args=*/(ins "OpOperand*":$opOperand), + /*retTy=*/"SmallVector", + /*methodName=*/"getStaticLoopRanges", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - if (auto shapedType = - opOperand->get().getType().template dyn_cast()) - return shapedType.getShape(); - return {}; + SmallVector viewSizes = getStaticShape(); + AffineMap invertedMap = getShapesToLoopsMap(); + assert(invertedMap && "expected a valid Linalg op to call the method"); + return invertedMap.compose(viewSizes); }] >, + //===------------------------------------------------------------------===// + // Other static interface methods. + //===------------------------------------------------------------------===// + StaticInterfaceMethod< + /*desc=*/[{ + Returns the region builder for constructing the body for linalg.generic. + Returns a null function if this named op does not define a region + builder. + }], + /*retTy=*/"std::function)>", + /*methodName=*/"getRegionBuilder", + (ins), + [{ return ConcreteOp::getRegionBuilder(); }] + >, InterfaceMethod< /*desc=*/[{ - Return true if the `opOperand` is a scalar value. + Return true if all the indexing maps are projected permutations. + Otherwise return false.
}], /*retTy=*/"bool", - /*methodName=*/"isScalar", - /*args=*/(ins "OpOperand*":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - return !opOperand->get().getType().template isa(); + /*methodName=*/"hasOnlyProjectedPermutations", + (ins), + [{ + return llvm::all_of($_op.getIndexingMapsArray(), + [](AffineMap map) { return map.isProjectedPermutation(); }); }] + > + ]; + + let extraClassDeclaration = [{ + /// Return the flat list of all operand dimension sizes in the order they + /// appear in the operands. + SmallVector createFlatListOfOperandDims(OpBuilder &, Location); + + /// Return the flat list of all operands' static dimension sizes in the + /// order they appear in the operands. All operand dimension sizes have to + /// be statically known. + SmallVector createFlatListOfOperandStaticDims(); + + /// Create the loop ranges to materialize the computation over the current + /// operands. This is done by applying `getShapesToLoopsMap` to + /// `createFlatListOfOperandDims`. + SmallVector createLoopRanges(OpBuilder &b, Location loc); + + /// Compute the static loop sizes necessary to vectorize the computation. + /// This is done by applying `getShapesToLoopsMap` to + /// `createFlatListOfOperandStaticDims`. + SmallVector computeStaticLoopSizes(); + + /// Returns the value that expresses the shape of the output in terms of + /// shape of the input operands where possible + LogicalResult reifyResultShapes(OpBuilder &b, + ReifiedRankedShapedTypeDims &reifiedReturnShapes); + + // TODO: Remove once prefixing is flipped. + ArrayAttr getIteratorTypes() { return iterator_types(); } + + //========================================================================// + // Helper functions to mutate the `operand_segment_sizes` attribute. + // These are useful when cloning and changing operand types. 
+ //========================================================================// + void setNumInputs(unsigned num) { setOperandSegmentAt(0, num); } + void setNumOutputBuffers(unsigned num) { setOperandSegmentAt(1, num); } + + private: + void setOperandSegmentAt(unsigned idx, unsigned val) { + auto attr = (*this)->getAttr("operand_segment_sizes") + .cast(); + unsigned i = 0; + auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32), + [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; }); + getOperation()->setAttr("operand_segment_sizes", newAttr); + } + }]; + + let verify = [{ return detail::verifyStructuredOpInterface($_op); }]; + let verifyWithRegions = 1; +} + +// The 'DestinationStyleOpInterface' provides access to the input and output operands of destination-style ops. +def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { + let cppNamespace = "::mlir::linalg"; + let methods = [ + //===------------------------------------------------------------------===// + // Num input/output arguments handling. + //===------------------------------------------------------------------===// + // `inputs` must be defined by each op that wants to implement the + // DestinationStyleOpInterface. + InterfaceMethod< + /*desc=*/[{ + Return the input shape operands. + }], + /*retTy=*/"ValueRange", + /*methodName=*/"inputs", + /*args=*/(ins) >, + // These special methods rely on `inputs` and `outputs` being defined by + // each op that wants to implement the DestinationStyleOpInterface. InterfaceMethod< /*desc=*/[{ - Return the block argument for an `opOperand`. + Return the number of inputs.
}], - /*retTy=*/"BlockArgument", - /*methodName=*/"getTiedBlockArgument", - /*args=*/(ins "OpOperand *":$opOperand), + /*retTy=*/"int64_t", + /*methodName=*/"getNumInputs", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - return getBlock()->getArgument(opOperand->getOperandNumber()); + return $_op.getInputs().size(); }] >, + // `outputs` must be defined by each op that wants to implement the + // DestinationStyleOpInterface. InterfaceMethod< /*desc=*/[{ - Return the operand for a `blockArgument`. + Return the output shape operands. }], - /*retTy=*/"OpOperand *", - /*methodName=*/"getTiedOpOperand", - /*args=*/(ins "BlockArgument":$blockArgument), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(blockArgument.getOwner() == getBlock()); - return &this->getOperation()->getOpOperand( - blockArgument.getArgNumber()); - }] + /*retTy=*/"ValueRange", + /*methodName=*/"outputs", + /*args=*/(ins) >, InterfaceMethod< /*desc=*/[{ - Return the input or output indexing map for `opOperand`. + Return the number of outputs. }], - /*retTy=*/"AffineMap", - /*methodName=*/"getTiedIndexingMap", - /*args=*/(ins "OpOperand*":$opOperand), + /*retTy=*/"int64_t", + /*methodName=*/"getNumOutputs", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - auto indexingMaps = - $_op.getIndexingMaps().template getAsValueRange(); - return *(indexingMaps.begin() + opOperand->getOperandNumber()); + return $_op.outputs().size(); }] >, InterfaceMethod< /*desc=*/[{ - Return the indexing map for a `result`. + Return the number of inputs and outputs.
}], - /*retTy=*/"AffineMap", - /*methodName=*/"getTiedIndexingMapForResult", - /*args=*/(ins "OpResult":$result), + /*retTy=*/"int64_t", + /*methodName=*/"getNumInputsAndOutputs", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(result.getOwner() == this->getOperation()); - auto indexingMaps = - $_op.getIndexingMaps().template getAsValueRange(); - return *(indexingMaps.begin() + getNumInputs() + - result.getResultNumber()); + return this->getOperation()->getNumOperands(); }] >, + //===------------------------------------------------------------------===// + // Input operands handling. + //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the result tied to `opOperand`. + Return the input operands. }], - /*retTy=*/"OpResult", - /*methodName=*/"getTiedOpResult", - /*args=*/(ins "OpOperand*":$opOperand), + /*retTy=*/"OpOperandVector", + /*methodName=*/"getInputOperands", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - int64_t resultIndex = opOperand->getOperandNumber() - getNumInputs(); - assert(resultIndex >= 0 && - resultIndex < this->getOperation()->getNumResults() ); - return this->getOperation()->getResult(resultIndex); + int64_t numInputs = getNumInputs(); + OpOperandVector result; + result.reserve(numInputs); + llvm::transform( + this->getOperation()->getOpOperands().take_front(numInputs), + std::back_inserter(result), + [](OpOperand &opOperand) { return &opOperand; }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return the value yielded by the region corresponding to an output - `opOperand`. + Return the `i`-th input operand. 
}], - /*retTy=*/"OpOperand *", - /*methodName=*/"getTiedYieldValue", - /*args=*/(ins "OpOperand*":$opOperand), + /*retTy=*/"OpOperand*", + /*methodName=*/"getInputOperand", + /*args=*/(ins "int64_t":$i), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - int64_t resultIndex = opOperand->getOperandNumber() - getNumInputs(); - assert(resultIndex >= 0 && - resultIndex < this->getOperation()->getNumResults()); - Operation *yieldOp = getBlock()->getTerminator(); - return &yieldOp->getOpOperand(resultIndex); + assert(i >= 0 && i < getNumInputs()); + return &this->getOperation()->getOpOperand(i); }] >, - //===------------------------------------------------------------------===// - // Other interface methods. - //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the single block constituting the body of the operation by - calling the getBody method on the concrete operation. + Return the subset of input operands that are of buffer type. }], - /*retTy=*/"Block*", - /*methodName=*/"getBlock", + /*retTy=*/"OpOperandVector", + /*methodName=*/"getInputBufferOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - // Assume the concrete operation implements the - // SingleBlockImplicitTerminator trait. - return $_op.getBody(); + OpOperandVector result; + result.reserve(getNumInputs()); + llvm::copy_if(getInputOperands(), + std::back_inserter(result), + [](OpOperand *opOperand) { + return opOperand->get().getType().template isa(); + }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return the iterator types attribute within the current operation. + Return the subset of input operands that are of tensor type. 
}], - /*retTy=*/"ArrayAttr", - /*methodName=*/"iterator_types", + /*retTy=*/"OpOperandVector", + /*methodName=*/"getInputTensorOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.iterator_types(); + OpOperandVector result; + result.reserve(getNumInputs()); + llvm::copy_if(getInputOperands(), + std::back_inserter(result), + [](OpOperand *opOperand) { + return opOperand->get().getType().template isa(); + }); + return result; }] >, + //===------------------------------------------------------------------===// + // Output operands handling. + //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return true if the indexing map is depending on the current op instance. - This means that the indexing map is dynamically synthesized by using the - op instance's concrete attributes, instead of being static for all - instances of the same op kind. + Return the output operands. }], - /*retTy=*/"bool", - /*methodName=*/"hasDynamicIndexingMaps", + /*retTy=*/"OpOperandVector", + /*methodName=*/"getOutputOperands", /*args=*/(ins), /*methodBody=*/"", - /*defaultImplementation=*/[{ return false; }] + /*defaultImplementation=*/[{ + int64_t numOutputs = getNumOutputs(); + OpOperandVector result; + result.reserve(numOutputs); + llvm::transform( + this->getOperation()->getOpOperands() + .take_back(numOutputs), + std::back_inserter(result), + [](OpOperand &opOperand) { return &opOperand; }); + return result; + }] >, InterfaceMethod< /*desc=*/[{ - Verify all attributes used by indexing maps are valid. + Return the `i`-th output operand. 
}], - /*retTy=*/"LogicalResult", - /*methodName=*/"verifyIndexingMapRequiredAttributes", - /*args=*/(ins), + /*retTy=*/"OpOperand*", + /*methodName=*/"getOutputOperand", + /*args=*/(ins "int64_t":$i), /*methodBody=*/"", - /*defaultImplementation=*/[{ return success(); }] + /*defaultImplementation=*/[{ + assert(i >= 0 && i < getNumOutputs()); + return &this->getOperation()->getOpOperand(getNumInputs() + i); + }] >, InterfaceMethod< /*desc=*/[{ - Return the indexing maps attribute within the current operation. + Set the `i`-th output operand. }], - /*retTy=*/"ArrayAttr", - /*methodName=*/"getIndexingMaps" + /*retTy=*/"void", + /*methodName=*/"setOutputOperand", + /*args=*/(ins "int64_t":$i, "Value":$value), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(i >= 0 && i < getNumOutputs()); + this->getOperation()->setOperand(getNumInputs() + i, value); + }] >, InterfaceMethod< /*desc=*/[{ - Return the indexing maps within the current operation. + Return the subset of output operands that are of buffer type. }], - /*retTy=*/"SmallVector", - /*methodName=*/"getIndexingMapsArray", + /*retTy=*/"OpOperandVector", + /*methodName=*/"getOutputBufferOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - auto range = $_op.getIndexingMaps() - .template getAsValueRange(); - return {range.begin(), range.end()}; + OpOperandVector result; + result.reserve(getNumOutputs()); + llvm::copy_if(getOutputOperands(), + std::back_inserter(result), + [](OpOperand *opOperand) { + return opOperand->get().getType().template isa(); + }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return true if any of the operands has a dynamic shape. + Return the subset of output operands that are of tensor type. 
}], - /*retTy=*/"bool", - /*methodName=*/"hasDynamicShape", + /*retTy=*/"OpOperandVector", + /*methodName=*/"getOutputTensorOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return llvm::any_of(getStaticShape(), ShapedType::isDynamic); + OpOperandVector result; + result.reserve(getNumOutputs()); + llvm::copy_if(getOutputOperands(), + std::back_inserter(result), + [](OpOperand *opOperand) { + return opOperand->get().getType().template isa(); + }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return whether the op has only MemRef input and outputs. - }], - /*retTy=*/"bool", - /*methodName=*/"hasBufferSemantics", + Return the types of the subset of output operands that are of buffer type. + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getOutputBufferTypes", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return this->getOperation()->getNumResults() == 0 && - llvm::all_of(this->getOperation()->getOpOperands(), - [&](OpOperand &opOperand) { - return isScalar(&opOperand) || - opOperand.get().getType().template isa(); - }); + SmallVector result; + result.reserve(getNumOutputs()); + llvm::transform(getOutputBufferOperands(), + std::back_inserter(result), + [](OpOperand *opOperands) { + return opOperands->get().getType().cast(); + }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return whether the op has only RankedTensor input and outputs. + Return the types of the subset of output operands that are of tensor type. 
}], - /*retTy=*/"bool", - /*methodName=*/"hasTensorSemantics", + /*retTy=*/"SmallVector", + /*methodName=*/"getOutputTensorTypes", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return llvm::all_of(this->getOperation()->getOpOperands(), - [&](OpOperand &opOperand) { - return isScalar(&opOperand) || - opOperand.get().getType().template isa(); + SmallVector result; + result.reserve(getNumOutputs()); + llvm::transform(getOutputTensorOperands(), + std::back_inserter(result), + [](OpOperand *opOperands) { + return opOperands->get().getType().cast(); }); + return result; }] >, + //===------------------------------------------------------------------===// + // Input and Output arguments handling. + //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the name registered for this op when lowering to an external - library call. + Return the range over input and output operands. }], - /*retTy=*/"std::string", - /*methodName=*/"getLibraryCallName", + /*retTy=*/"OpOperandVector", + /*methodName=*/"getInputAndOutputOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.getLibraryCallName(); + int64_t numInputsAndOutputs = getNumInputsAndOutputs(); + OpOperandVector result; + result.reserve(numInputsAndOutputs); + llvm::transform( + this->getOperation()->getOpOperands(), + std::back_inserter(result), + [](OpOperand &opOperand) { return &opOperand; }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return whether the op accesses the iteration indices. + Return true if `opOperand` is an input tensor. 
}], /*retTy=*/"bool", - /*methodName=*/"hasIndexSemantics", - /*args=*/(ins), + /*methodName=*/"isInputTensor", + /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", - /*defaultImplementation=*/"" + /*defaultImplementation=*/[{ + if (!opOperand->get().getType().template isa()) + return false; + if (opOperand->getOperandNumber() < $_op.getNumInputs()) + return true; + return false; + }] >, - //===------------------------------------------------------------------===// - // Linalg generalization hooks. - //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Hook to provide a custom AffineMap used to compute all the operand - subshapes given loop bounds. This is used to answer the question: "given - an iteration space over the codomain, what are the subshapes of the - operands involved in the computation". - The default behavior is to just concatenate all the indexing maps. - A custom AffineMap allows providing a map that can be used to - compute subshapes even in cases where the concatenation of indexing maps - (i.e. the data traversal order) is not a simple permutation of the loop - traversal order. It is then possible to define ops with skewed data - traversal order for which we can still easily compute hyperrectangular - loop bounds and subviews. + Return true if `opOperand` is an output tensor. 
}], - /*retTy=*/"AffineMap", - /*methodName=*/"getLoopsToShapesMap", - /*args=*/(ins), + /*retTy=*/"bool", + /*methodName=*/"isOutputTensor", + /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - auto maps = $_op.getIndexingMapsArray(); - return concatAffineMaps(maps); + if (!opOperand->get().getType().template isa()) + return false; + if (opOperand->getOperandNumber() >= $_op.getNumInputs()) + return true; + return false; }] >, InterfaceMethod< /*desc=*/[{ - Hook to provide a custom AffineMap used to construct the - hyperrectangular loop iteration space given all the operand subshapes. - This is used to answer the question: - "Given a list of operand ranges, what is the subportion of the iteration - space involved in the computation". - This is the inverse problem of `getLoopsToShapesMap`. - Return the empty AffineMap when such an AffineMap cannot be constructed. - The default behavior is based on a very simple inference procedure that - only works with permutation affine maps. - A more advanced Tensor-Comprehension like inference is possible but has - proven to be ambiguous in unfavorable case. - A safer and more robust alternative is to allow each op to define - its own AffineMap. + Return true if the `opOperand` is a scalar value. }], - /*retTy=*/"AffineMap", - /*methodName=*/"getShapesToLoopsMap", - /*args=*/(ins), + /*retTy=*/"bool", + /*methodName=*/"isScalar", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - return inversePermutation(getLoopsToShapesMap()); + assert(opOperand->getOwner() == this->getOperation()); + return !opOperand->get().getType().template isa(); }] >, InterfaceMethod< /*desc=*/[{ - Checks if the given operands can be dropped, and the remaining - operands can still compute the bounds of the op. + Return the result tied to `opOperand`. 
}], - /*retTy=*/"bool", - /*methodName=*/"canOpOperandsBeDropped", - /*args=*/(ins "ArrayRef":$droppedOperands), + /*retTy=*/"OpResult", + /*methodName=*/"getTiedOpResult", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands); + assert(opOperand->getOwner() == this->getOperation()); + int64_t resultIndex = opOperand->getOperandNumber() - getNumInputs(); + assert(resultIndex >= 0 && + resultIndex < this->getOperation()->getNumResults() ); + return this->getOperation()->getResult(resultIndex); }] >, + //===------------------------------------------------------------------===// + // Other interface methods. + //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Like `getShape`, but only returns statically-known information, without - generating any new IR. For each shape dimension, returns >=0 if that - dimension is statically known, or ShapeType::kDynamicSize otherwise. + Return whether the op has only MemRef input and outputs. }], - /*retTy=*/"SmallVector", - /*methodName=*/"getStaticShape", + /*retTy=*/"bool", + /*methodName=*/"hasBufferSemantics", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - SmallVector res; - for (OpOperand *opOperand : getInputAndOutputOperands()) - llvm::append_range(res, getShape(opOperand)); - return res; + return this->getOperation()->getNumResults() == 0 && + llvm::all_of(this->getOperation()->getOpOperands(), + [&](OpOperand &opOperand) { + return isScalar(&opOperand) || + opOperand.get().getType().template isa(); + }); }] >, InterfaceMethod< /*desc=*/[{ - Returns the statically-known loop ranges. Composes - `getShapesToLoopsMap()` with the result of `getStaticShape`. - Returns ShapeType::kDynamicSize for non-statically-known loop ranges. - This is expected to be called by a valid Linalg op + Return whether the op has only RankedTensor input and outputs. 
}], - /*retTy=*/"SmallVector", - /*methodName=*/"getStaticLoopRanges", + /*retTy=*/"bool", + /*methodName=*/"hasTensorSemantics", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - SmallVector viewSizes = getStaticShape(); - AffineMap invertedMap = getShapesToLoopsMap(); - assert(invertedMap && "expected a valid Linalg op to call the method"); - return invertedMap.compose(viewSizes); + return llvm::all_of(this->getOperation()->getOpOperands(), + [&](OpOperand &opOperand) { + return isScalar(&opOperand) || + opOperand.get().getType().template isa(); + }); }] >, //===------------------------------------------------------------------===// @@ -1042,27 +1135,6 @@ return b.create(state); }] >, - InterfaceMethod< - /*desc=*/[{ - Clone the current operation with the given location, operands - and BlockAndValueMapping. This is used to abstract away the - optional underlying region creation. This does not change the - balance between input, output_buffer and init_tensors - operands. - }], - /*retTy=*/"Operation *", - /*methodName=*/"cloneWithMapper", - (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, - "ValueRange":$operands, "BlockAndValueMapping &":$bvm), - [{ - OperationState state( - loc, ConcreteOp::getOperationName(), operands, resultTypes, - $_op->getAttrs()); - for (Region &r : $_op->getRegions()) - r.cloneInto(state.addRegion(), bvm); - return b.create(state); - }] - >, InterfaceMethod< /*desc=*/[{ Clone the current operation with the given location, operands @@ -1083,79 +1155,9 @@ state.addRegion(); return b.create(state); }] - >, - StaticInterfaceMethod< - /*desc=*/[{ - Returns the region builder for constructing the body for linalg.generic. - Returns a null function if this named op does not define a region - builder. 
- }], - /*retTy=*/"std::function)>", - /*methodName=*/"getRegionBuilder", - (ins), - [{ return ConcreteOp::getRegionBuilder(); }] - >, - InterfaceMethod< - /*desc=*/[{ - Return true if all the indexing maps are projected permutations. - Otherwise return false. - }], - /*retTy=*/"bool", - /*methodName=*/"hasOnlyProjectedPermutations", - (ins), - [{ - return llvm::all_of($_op.getIndexingMapsArray(), - [](AffineMap map) { return map.isProjectedPermutation(); }); - }] > ]; - let extraClassDeclaration = [{ - /// Return the flat list of all operand dimension sizes in the order they - /// appear in the operands. - SmallVector createFlatListOfOperandDims(OpBuilder &, Location); - - /// Return the flat list of all operands' static dimension sizes in the - /// order they appear in the operands. All operand dimension sizes have to - /// be statically known. - SmallVector createFlatListOfOperandStaticDims(); - - /// Create the loop ranges to materialize the computation over the current - /// operands. This is done by applying `getShapesToLoopsMap` to - /// `createFlatListOfOperandDims`. - SmallVector createLoopRanges(OpBuilder &b, Location loc); - - /// Compute the static loop sizes necessary to vectorize the computation. - /// This is done by applying `getShapesToLoopsMap` to - /// `createFlatListOfOperandStaticDims`. - SmallVector computeStaticLoopSizes(); - - /// Returns the value that expresses the shape of the output in terms of - /// shape of the input operands where possible - LogicalResult reifyResultShapes(OpBuilder &b, - ReifiedRankedShapedTypeDims &reifiedReturnShapes); - - // TODO: Remove once prefixing is flipped. - ArrayAttr getIteratorTypes() { return iterator_types(); } - - //========================================================================// - // Helper functions to mutate the `operand_segment_sizes` attribute. - // These are useful when cloning and changing operand types. 
- //========================================================================// - void setNumInputs(unsigned num) { setOperandSegmentAt(0, num); } - void setNumOutputBuffers(unsigned num) { setOperandSegmentAt(1, num); } - - private: - void setOperandSegmentAt(unsigned idx, unsigned val) { - auto attr = (*this)->getAttr("operand_segment_sizes") - .cast(); - unsigned i = 0; - auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32), - [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; }); - getOperation()->setAttr("operand_segment_sizes", newAttr); - } - }]; - let verify = [{ return detail::verifyStructuredOpInterface($_op); }]; let verifyWithRegions = 1; } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -28,6 +28,7 @@ : Op, DeclareOpInterfaceMethods, + DestinationStyleOpInterface, LinalgStructuredInterface, RegionBranchOpInterface, ReifyRankedShapedTypeOpInterface], props)> { diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -29,7 +29,8 @@ SmallVector argTypes; SmallVector argLocs; - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { + for (OpOperand *opOperand : cast(linalgOp) + .getInputAndOutputOperands()) { argTypes.push_back(getElementTypeOrSelf(opOperand->get().getType())); argLocs.push_back(opOperand->get().getLoc()); } diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -164,15 +164,18 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { LLVM_DEBUG(dbgs() << 
"addDependencesBetween " << *src.getOperation() << " and " << *dst.getOperation() << "\n"); - if (src.hasTensorSemantics() && dst.hasTensorSemantics()) { - for (OpOperand *dstOpOperand : dst.getInputOperands()) { + if (cast(src).hasTensorSemantics() && + cast(dst).hasTensorSemantics()) { + for (OpOperand *dstOpOperand : + cast(dst).getInputOperands()) { // Check if the operand is defined by the src. auto definingOp = dstOpOperand->get().getDefiningOp(); if (definingOp && definingOp == src) addDependenceElem(DependenceType::RAW, dstOpOperand->get(), dstOpOperand); } - for (OpOperand *dstOpOperand : dst.getOutputOperands()) { + for (OpOperand *dstOpOperand : + cast(dst).getOutputOperands()) { // Check if the operand is defined by the src. auto definingOp = dstOpOperand->get().getDefiningOp(); if (definingOp && definingOp == src) { @@ -186,25 +189,32 @@ } return; } - assert(src.hasBufferSemantics() && dst.hasBufferSemantics() && + assert(cast(src).hasBufferSemantics() && + cast(dst).hasBufferSemantics() && "unhandled dependence tracking for mixed buffer/tensor operations"); - for (OpOperand *srcOpOperand : src.getOutputBufferOperands()) { // W + for (OpOperand *srcOpOperand : + cast(src).getOutputBufferOperands()) { // W // RAW graph - for (OpOperand *dstOpOperand : dst.getInputBufferOperands()) // R + for (OpOperand *dstOpOperand : + cast(dst).getInputBufferOperands()) // R if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand); // WAW graph - for (OpOperand *dstOpOperand : dst.getOutputBufferOperands()) // W + for (OpOperand *dstOpOperand : + cast(dst).getOutputBufferOperands()) // W if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand); } - for (OpOperand *srcOpOperand : src.getInputBufferOperands()) { // R + for (OpOperand *srcOpOperand : + cast(src).getInputBufferOperands()) { // R // RAR 
graph - for (OpOperand *dstOpOperand : dst.getInputBufferOperands()) // R + for (OpOperand *dstOpOperand : + cast(dst).getInputBufferOperands()) // R if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand); // WAR graph - for (OpOperand *dstOpOperand : dst.getOutputBufferOperands()) // W + for (OpOperand *dstOpOperand : + cast(dst).getOutputBufferOperands()) // W if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -31,7 +31,9 @@ bool linalg::detail::canOpOperandsBeDroppedImpl( linalg::LinalgOp linalgOp, ArrayRef droppedOperands) { SmallVector indexingMaps; - for (auto *opOperand : linalgOp.getInputAndOutputOperands()) { + for (auto *opOperand : + cast(*linalgOp.getOperation()) + .getInputAndOutputOperands()) { if (llvm::is_contained(droppedOperands, opOperand)) continue; indexingMaps.push_back(linalgOp.getTiedIndexingMap(opOperand)); @@ -119,7 +121,10 @@ auto linalgOp = dyn_cast(op); if (!linalgOp) return MatchContractionResult::NotLinalgOp; - if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1) + if (cast(*linalgOp.getOperation()) + .getNumInputs() != 2 || + cast(*linalgOp.getOperation()) + .getNumOutputs() != 1) return MatchContractionResult::WrongNumOperands; auto mapRange = linalgOp.getIndexingMapsArray(); if (linalgOp.getNumReductionLoops() == 0) @@ -278,7 +283,10 @@ auto linalgOp = dyn_cast(op); if (!linalgOp) return MatchConvolutionResult::NotLinalgOp; - if (linalgOp.getNumInputs() < 2 || linalgOp.getNumOutputs() != 1) + if (cast(*linalgOp.getOperation()) + .getNumInputs() < 2 || + cast(*linalgOp.getOperation()) + .getNumOutputs() != 1) return 
MatchConvolutionResult::WrongNumOperands; auto indexingMaps = linalgOp.getIndexingMapsArray(); @@ -443,11 +451,16 @@ auto linalgOp = dyn_cast(op); if (!linalgOp) return MatchFillResult::NotLinalgOp; - if (linalgOp.getNumInputs() != 1 || linalgOp.getNumOutputs() != 1) + if (cast(*linalgOp.getOperation()) + .getNumInputs() != 1 || + cast(*linalgOp.getOperation()) + .getNumOutputs() != 1) return MatchFillResult::WrongNumOperands; - OpOperand *value = linalgOp.getInputOperand(0); - if (!linalgOp.isScalar(value)) + OpOperand *value = cast(*linalgOp.getOperation()) + .getInputOperand(0); + if (!cast(*linalgOp.getOperation()) + .isScalar(value)) return MatchFillResult::NotScalarInput; return MatchFillResult::Success; @@ -498,7 +511,9 @@ SmallVector LinalgOp::createFlatListOfOperandDims(OpBuilder &b, Location loc) { SmallVector res; - for (OpOperand *opOperand : getInputAndOutputOperands()) { + for (OpOperand *opOperand : + cast(*this->getOperation()) + .getInputAndOutputOperands()) { for (int64_t i = 0, e = getRank(opOperand); i < e; ++i) res.push_back(createFoldedDimOp(b, loc, opOperand->get(), i)); } @@ -508,7 +523,9 @@ SmallVector LinalgOp::createFlatListOfOperandStaticDims() { SmallVector res; assert(!hasDynamicShape() && "expected operands to have static shapes"); - for (OpOperand *opOperand : getInputAndOutputOperands()) + for (OpOperand *opOperand : + cast(*this->getOperation()) + .getInputAndOutputOperands()) llvm::append_range(res, getShape(opOperand)); return res; } @@ -570,9 +587,10 @@ getResultsPositionInLoopsToShapeMap(LinalgOp &op) { int64_t inputRankSum = 0; int64_t outputRankSum = 0; - for (OpOperand *input : op.getInputOperands()) + auto iface = cast(*op.getOperation()); + for (OpOperand *input : iface.getInputOperands()) inputRankSum += op.getRank(input); - for (OpOperand *output : op.getOutputOperands()) + for (OpOperand *output : iface.getOutputOperands()) outputRankSum += op.getRank(output); return {inputRankSum, inputRankSum + outputRankSum}; } @@ 
-616,7 +634,9 @@ createFlatListOfOperandDims(b, loc)); int64_t pos = 0; ArrayRef shapeExprs = resultShapesFromInputShapesMap.getResults(); - for (OpOperand *opOperand : getOutputOperands()) { + for (OpOperand *opOperand : + cast(*this->getOperation()) + .getOutputOperands()) { SmallVector shapes; for (int64_t dim : llvm::seq(0, getRank(opOperand))) { if (checkDimExpr.visit(shapeExprs[pos])) @@ -637,18 +657,28 @@ // This means an op that constructs a tensor out of indices cannot be a // LinalgOp at the moment. For now this will have to be a special op until we // have output shape operands that are not tensors. - int64_t numInputs = linalgOp.getNumInputs(); - int64_t numOutputs = linalgOp.getNumOutputs(); + int64_t numInputs = + cast(*linalgOp.getOperation()) + .getNumInputs(); + int64_t numOutputs = + cast(*linalgOp.getOperation()) + .getNumOutputs(); if (numOutputs == 0) return op->emitOpError("expected at least one output operand"); if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs))) return failure(); // Verify the number of results matches the number of output tensors. - if (op->getNumResults() != linalgOp.getOutputTensorOperands().size()) + if (op->getNumResults() != + cast(*linalgOp.getOperation()) + .getOutputTensorOperands() + .size()) return op->emitOpError("expected the number of results (") << op->getNumResults() << ") to be equal to the number of output tensors (" - << linalgOp.getOutputTensorOperands().size() << ")"; + << cast(*linalgOp.getOperation()) + .getOutputTensorOperands() + .size() + << ")"; // Check all iterator types are known. auto iteratorTypesRange = @@ -667,13 +697,18 @@ // All input/output operands must be indexed. 
if (static_cast(linalgOp.getIndexingMapsArray().size()) != - linalgOp.getNumInputsAndOutputs()) + cast(*linalgOp.getOperation()) + .getNumInputsAndOutputs()) return op->emitOpError("expected the number of indexing_map (") << linalgOp.getIndexingMapsArray().size() << ") to be equal to the number of input/output operands (" - << linalgOp.getNumInputsAndOutputs() << ")"; + << cast(*linalgOp.getOperation()) + .getNumInputsAndOutputs() + << ")"; - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { + for (OpOperand *opOperand : + cast(*linalgOp.getOperation()) + .getInputAndOutputOperands()) { AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand); // Symbols disallowed. @@ -703,14 +738,22 @@ // This allows simpler verification of output operands vs result types // without premature tracking of which operand is what in mixed-mode. // TODO: relax when mixed-mode needs to pass verification. - if (!linalgOp.getOutputBufferOperands().empty() && - !linalgOp.getOutputTensorOperands().empty()) + if (!cast(*linalgOp.getOperation()) + .getOutputBufferOperands() + .empty() && + !cast(*linalgOp.getOperation()) + .getOutputTensorOperands() + .empty()) return op->emitOpError( "expected output operands to all have tensor type or " "all have buffer type"); - for (OpOperand *opOperand : linalgOp.getOutputTensorOperands()) { - OpResult result = linalgOp.getTiedOpResult(opOperand); + for (OpOperand *opOperand : + cast(*linalgOp.getOperation()) + .getOutputTensorOperands()) { + OpResult result = + cast(*linalgOp.getOperation()) + .getTiedOpResult(opOperand); if (result.getType() != opOperand->get().getType()) return op->emitOpError("expected type of operand #") << opOperand->getOperandNumber() << " (" @@ -720,7 +763,9 @@ } // Output tensor indexing map may not depend on reduction indices. 
- for (OpOperand *opOperand : linalgOp.getOutputOperands()) { + for (OpOperand *opOperand : + cast(*linalgOp.getOperation()) + .getOutputOperands()) { AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand); for (AffineExpr expr : indexingMap.getResults()) { for (unsigned pos : redDims) { @@ -732,7 +777,9 @@ } return op->emitOpError( "unexpected output tensor expression in indexing map #") - << (opOperand->getOperandNumber() - linalgOp.getNumInputs()) + << (opOperand->getOperandNumber() - + cast(*linalgOp.getOperation()) + .getNumInputs()) << " a.k.a '" << exprStr << "' is function of reduction iterator 'd" << pos << "'"; } @@ -756,11 +803,14 @@ // not used). Block &block = linalgOp->getRegion(0).front(); - if (linalgOp.getNumInputsAndOutputs() != block.getNumArguments()) + if (cast(*linalgOp.getOperation()) + .getNumInputsAndOutputs() != block.getNumArguments()) return op->emitOpError("expected as many non-induction variable region " "arguments as the number of input/output operands"); - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { + for (OpOperand *opOperand : + cast(*linalgOp.getOperation()) + .getInputAndOutputOperands()) { Type elementType = getElementTypeOrSelf(opOperand->get()); Type argType = block.getArgument(opOperand->getOperandNumber()).getType(); if (elementType != argType) @@ -779,7 +829,9 @@ if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) { for (int64_t &range : endLoopRangeValues) range -= 1; - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { + for (OpOperand *opOperand : + cast(*linalgOp.getOperation()) + .getInputAndOutputOperands()) { AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand); SmallVector startIndices = indexingMap.compose(startLoopRangeValues); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1485,15 +1485,19 @@ // 
Check the operand number and types must match the element types of the // LinalgOp interface's shaped operands. static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) { - if (op.getNumOperands() != linalgOp.getNumOutputs()) + if (op.getNumOperands() != + cast(*linalgOp.getOperation()) + .getNumOutputs()) return op.emitOpError("expected number of yield values (") - << linalgOp.getNumOutputs() + << cast(*linalgOp.getOperation()) + .getNumOutputs() << ") to match the number of operands of the enclosing " << "LinalgOp (" << op.getNumOperands() << ")"; for (OpOperand &opOperand : op->getOpOperands()) { OpOperand *outputOperand = - linalgOp.getOutputOperand(opOperand.getOperandNumber()); + cast(*linalgOp.getOperation()) + .getOutputOperand(opOperand.getOperandNumber()); Type elementType = getElementTypeOrSelf(outputOperand->get().getType()); if (opOperand.get().getType() != elementType) return op.emitOpError("type of yield operand ") @@ -1630,7 +1634,9 @@ LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override { - for (OpOperand *opOperand : op.getInputAndOutputOperands()) { + for (OpOperand *opOperand : + cast(*op.getOperation()) + .getInputAndOutputOperands()) { // Linalg "inputs" may be either tensor or memref type. // tensor<0xelt_type> is a convention that may not always mean // "0 iterations". Only erase in cases we see memref<...x0x...>. @@ -1653,12 +1659,15 @@ PatternRewriter &rewriter) const override { // If no operand comes from a tensor::CastOp and can be folded then fail. 
bool hasTensorCastOperand = - llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) { - if (opOperand->get().isa()) - return false; - auto castOp = opOperand->get().getDefiningOp(); - return castOp && canFoldIntoConsumerOp(castOp); - }); + llvm::any_of(cast(*op.getOperation()) + .getInputAndOutputOperands(), + [&](OpOperand *opOperand) { + if (opOperand->get().isa()) + return false; + auto castOp = + opOperand->get().getDefiningOp(); + return castOp && canFoldIntoConsumerOp(castOp); + }); if (!hasTensorCastOperand) return failure(); @@ -1667,14 +1676,18 @@ SmallVector newOperands; newOperands.reserve(op->getNumOperands()); // Inputs may fold. - for (OpOperand *opOperand : op.getInputOperands()) { + for (OpOperand *opOperand : + cast(*op.getOperation()) + .getInputOperands()) { auto tensorCastOp = opOperand->get().getDefiningOp(); newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.getSource() : opOperand->get()); } // Init tensors may fold, in which case the resultType must also change. - for (OpOperand *opOperand : op.getOutputOperands()) { + for (OpOperand *opOperand : + cast(*op.getOperation()) + .getOutputOperands()) { auto tensorCastOp = opOperand->get().getDefiningOp(); bool fold = canFoldIntoConsumerOp(tensorCastOp); newOperands.push_back(fold ? tensorCastOp.getOperand() @@ -1683,7 +1696,8 @@ } // Clone op. Operation *newOp = - op.clone(rewriter, op->getLoc(), newResultTypes, newOperands); + cast(*op.getOperation()) + .clone(rewriter, op->getLoc(), newResultTypes, newOperands); SmallVector replacements; replacements.reserve(newOp->getNumResults()); for (auto result : llvm::zip(op->getResults(), newOp->getResults())) { @@ -1734,18 +1748,26 @@ // for this cast, i.e. producer of the out operand, is also an operation // that folds with tensor.cast consumer (like this pattern), the cast will // continue to propagate as far up the stack as it can go. 
- OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber); + OpOperand *outOperand = + cast(*linalgOp.getOperation()) + .getOutputOperand(resultNumber); Value newOperand = rewriter.create(loc, resultType, outOperand->get()); - SmallVector newOperands = linalgOp.getInputOperands(); - SmallVector outputOperands = linalgOp.getOutputOperands(); + SmallVector newOperands = + cast(*linalgOp.getOperation()) + .getInputOperands(); + SmallVector outputOperands = + cast(*linalgOp.getOperation()) + .getOutputOperands(); outputOperands[resultNumber] = newOperand; newOperands.append(outputOperands.begin(), outputOperands.end()); SmallVector resultTypes(linalgOp->result_type_begin(), linalgOp->result_type_end()); resultTypes[resultNumber] = resultType; - Operation *newOp = linalgOp.clone(rewriter, loc, resultTypes, newOperands); + Operation *newOp = + cast(*linalgOp.getOperation()) + .clone(rewriter, loc, resultTypes, newOperands); // Create a tensor.cast operation back to the original type. Value castBack = rewriter.create( @@ -1764,7 +1786,8 @@ static void populateMap(LinalgOp linalgOp, ArrayRef operands, llvm::DenseMap &affineExprToSize) { for (OpOperand *opOperand : operands) { - if (linalgOp.isScalar(opOperand)) + if (cast(*linalgOp.getOperation()) + .isScalar(opOperand)) continue; Value src = opOperand->get(); auto sourceType = src.getType().cast(); @@ -1807,11 +1830,14 @@ bool &changeNeeded) { Value src = opOperand->get(); newOperands.push_back(src); - if (linalgOp.isScalar(opOperand)) + if (cast(*linalgOp.getOperation()) + .isScalar(opOperand)) return; auto sourceType = src.getType().cast(); Type resultType = sourceType; - if (sourceType.hasStaticShape() && linalgOp.isOutputTensor(opOperand)) { + if (sourceType.hasStaticShape() && + cast(*linalgOp.getOperation()) + .isOutputTensor(opOperand)) { resultTypes.push_back(resultType); return; } @@ -1844,7 +1870,8 @@ unsigned index = opOperand->getOperandNumber(); newOperands[index] = newOperand; } - if 
(linalgOp.isOutputTensor(opOperand)) + if (cast(*linalgOp.getOperation()) + .isOutputTensor(opOperand)) resultTypes.push_back(resultType); } @@ -1856,7 +1883,8 @@ LogicalResult matchAndRewrite(LinalgOp linalgOp, PatternRewriter &rewriter) const override { - if (!linalgOp.hasTensorSemantics()) + if (!cast(*linalgOp.getOperation()) + .hasTensorSemantics()) return failure(); // Maps must be projected permutations. @@ -1871,7 +1899,9 @@ // For each of the affine dim expression, check if the size is known. If // known add that in the map. - populateMap(linalgOp, linalgOp.getInputAndOutputOperands(), + populateMap(linalgOp, + cast(*linalgOp.getOperation()) + .getInputAndOutputOperands(), affineExprToSize); SmallVector newOperands; @@ -1880,11 +1910,17 @@ // `changeNeeded` is `false` if the operands of `linalgOp` require no // change in their types. bool changeNeeded = false; - newOperands.reserve(linalgOp.getNumInputsAndOutputs()); - resultTypes.reserve(linalgOp.getNumOutputs()); + newOperands.reserve( + cast(*linalgOp.getOperation()) + .getNumInputsAndOutputs()); + resultTypes.reserve( + cast(*linalgOp.getOperation()) + .getNumOutputs()); // Iterate over all the operands and update the static sizes. - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { + for (OpOperand *opOperand : + cast(*linalgOp.getOperation()) + .getInputAndOutputOperands()) { createNewOperandWithStaticSizes(loc, rewriter, opOperand, affineExprToSize, linalgOp, newOperands, resultTypes, changeNeeded); @@ -1897,7 +1933,8 @@ // Clone op. 
Operation *newOp = - linalgOp.clone(rewriter, linalgOp->getLoc(), resultTypes, newOperands); + cast(*linalgOp.getOperation()) + .clone(rewriter, linalgOp->getLoc(), resultTypes, newOperands); SmallVector replacements; replacements.reserve(newOp->getNumResults()); for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp @@ -64,12 +64,12 @@ "expected single use of linalg op"); } - if (linalgOp.getNumOutputs() != 1) { + if (cast(linalgOp).getNumOutputs() != 1) { return rewriter.notifyMatchFailure(sliceOp, "expected single output of linalg op"); } - if (!linalgOp.hasTensorSemantics()) { + if (!cast(linalgOp).hasTensorSemantics()) { return rewriter.notifyMatchFailure(sliceOp, "expected tensor of linalg op"); } @@ -81,7 +81,8 @@ return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction"); } - OpOperand *outOperand = linalgOp.getOutputOperand(0); + OpOperand *outOperand = + cast(linalgOp).getOutputOperand(0); AffineMap indexingMap = linalgOp.getTiedIndexingMap(outOperand); if (!indexingMap.isProjectedPermutation()) { return rewriter.notifyMatchFailure( @@ -113,19 +114,21 @@ tileSizes[position] = sliceOp.getMixedSizes()[result.index()]; } - SmallVector valuesToTile = linalgOp.getInputAndOutputOperands(); + SmallVector valuesToTile = + cast(linalgOp).getInputAndOutputOperands(); SmallVector tiledOperands = makeTiledShapes(rewriter, linalgLoc, linalgOp, valuesToTile, tileOffsets, tileSizes, sizeBounds, /*omitPartialTileCheck=*/true); SmallVector resultTensorTypes; - for (OpOperand *opOperand : linalgOp.getOutputTensorOperands()) + for (OpOperand *opOperand : + cast(linalgOp).getOutputTensorOperands()) resultTensorTypes.push_back( 
tiledOperands[opOperand->getOperandNumber()].getType()); - Operation *newOp = - linalgOp.clone(rewriter, linalgLoc, resultTensorTypes, tiledOperands); + Operation *newOp = cast(linalgOp).clone( + rewriter, linalgLoc, resultTensorTypes, tiledOperands); rewriter.replaceOp(sliceOp, newOp->getResults()); return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -21,7 +21,8 @@ namespace { /// Generic conversion for any LinalgOp on tensors. -static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op, +static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, + DestinationStyleOpInterface op, const BufferizationOptions &options) { // Take a guard before anything else. OpBuilder::InsertionGuard g(rewriter); @@ -105,7 +106,7 @@ SmallVector getAliasingOpOperand(Operation *op, OpResult opResult, const AnalysisState &state) const { - auto genericOp = cast(op); + auto genericOp = cast(op); // The i-th OpResult may alias with the i-th "out" tensor. return {genericOp.getOutputOperand(opResult.getResultNumber())}; @@ -113,7 +114,7 @@ SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - auto genericOp = cast(op); + auto genericOp = cast(op); // The i-th "out" tensor may alias with the i-th OpResult. 
if (genericOp.isOutputTensor(&opOperand)) @@ -128,7 +129,8 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { - return bufferizeLinalgOp(rewriter, cast(op), options); + return bufferizeLinalgOp(rewriter, cast(op), + options); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -69,7 +69,8 @@ bool fromSubViewOpOnly = false) { // Iterate over the inputs and outputs in order. // Extract the subranges from the linearized ranges. - for (OpOperand *opOperand : op.getInputAndOutputOperands()) { + for (OpOperand *opOperand : + cast(op).getInputAndOutputOperands()) { // The method `getRangeFromOperandShape` requires using SubViewOp or // ExtractSliceOps. If the value isn't defined from there continue. // todo: The method should be adapted to get the values from @@ -105,7 +106,8 @@ } static SmallVector getTiledOperands(LinalgOp producer) { - return producer.getInputAndOutputOperands(); + return cast(producer) + .getInputAndOutputOperands(); } /// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges` @@ -138,7 +140,8 @@ } SmallVector clonedShapes; - clonedShapes.reserve(producer.getNumInputsAndOutputs()); + clonedShapes.reserve( + cast(producer).getNumInputsAndOutputs()); // Compute subranges for all tensor input/output operands. clonedShapes.append(makeTiledShapes( @@ -151,7 +154,8 @@ // fully dynamic at construction time. 
SmallVector resultTypes; resultTypes.reserve(producer->getNumResults()); - for (RankedTensorType t : producer.getOutputTensorTypes()) { + for (RankedTensorType t : + cast(producer).getOutputTensorTypes()) { unsigned rank = t.getRank(); SmallVector staticOffsetsVector( rank, ShapedType::kDynamicStrideOrOffset); @@ -163,7 +167,8 @@ staticStridesVector)); } - Operation *clonedOp = producer.clone(b, loc, resultTypes, clonedShapes); + Operation *clonedOp = cast(producer).clone( + b, loc, resultTypes, clonedShapes); // Shift all IndexOp results by the tile offset. SmallVector allIvs = llvm::to_vector( @@ -205,11 +210,11 @@ // Some of these will be lifted in the future with better analysis. static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, LinalgOp consumer) { - assert(producer.hasBufferSemantics() && + assert(cast(producer).hasBufferSemantics() && "expected linalg op with buffer semantics"); - assert(consumer.hasBufferSemantics() && + assert(cast(consumer).hasBufferSemantics() && "expected linalg op with buffer semantics"); - if (producer.getNumOutputs() != 1) { + if (cast(producer).getNumOutputs() != 1) { LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)"); return false; } @@ -274,11 +279,12 @@ static FailureOr findFusableProducer(OpOperand &consumerOpOperand, const LinalgDependenceGraph &dependenceGraph) { - LLVM_DEBUG(llvm::dbgs() << "findFusableProducer for: " - << consumerOpOperand.get() << " @" - << consumerOpOperand.getOperandNumber() << " in " - << *consumerOpOperand.getOwner() << "\n"); - LinalgOp consumerOp = dyn_cast(consumerOpOperand.getOwner()); + LLVM_DEBUG( + llvm::dbgs() + << "findFusableProducer for: " << consumerOpOperand.get() << " @" + << cast(consumerOpOperand).getOperandNumber() + << " in " << *consumerOpOperand.getOwner() << "\n"); + auto consumerOp = dyn_cast(consumerOpOperand.getOwner()); if (!consumerOp) return failure(); @@ -314,11 +320,13 @@ // If the producer and consumer have tensor 
semantics, the only dependence // between them is through a RAW dependence and they are fusable by // construction. For buffer semantics need additional checks. - if (producer.hasBufferSemantics() && consumerOp.hasBufferSemantics() && + if (cast(producer).hasBufferSemantics() && + cast(consumerOp).hasBufferSemantics() && isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(), producer)) return dependence; - if (producer.hasTensorSemantics() && consumerOp.hasTensorSemantics()) { + if (cast(producer).hasTensorSemantics() && + cast(consumerOp).hasTensorSemantics()) { assert(dependence.dependenceType == LinalgDependenceGraph::DependenceType::RAW); return dependence; @@ -441,7 +449,8 @@ b.setInsertionPoint(consumerOp); LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n"); OpOperand *opOperand = - producerOp.getOutputOperand(producerOpResult.getResultNumber()); + cast(producerOp) + .getOutputOperand(producerOpResult.getResultNumber()); LinalgOp fusedProducer = fuse(b, producerOp, producerOp.getTiedIndexingMap(opOperand), consumerOpOperand); diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -70,7 +70,8 @@ // Get the indexing map of the `producerOp` output operand that matches // ´producerResult´. AffineMap producerIndexingMap = producerOp.getTiedIndexingMap( - producerOp.getOutputOperand(producerResult.getResultNumber())); + cast(producerOp) + .getOutputOperand(producerResult.getResultNumber())); // Keep only the tiled result slice dimensions of `producerIndexingMap`. 
AffineMap tiledProducerIndexingSubMap = @@ -162,7 +163,8 @@ allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop]; } erase_value(tileIvs, OpFoldResult()); - SmallVector tiledOperands = producerOp.getInputAndOutputOperands(); + SmallVector tiledOperands = + cast(producerOp).getInputAndOutputOperands(); tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs, tileSizes, producerLoopBounds, /**omitPartialTileCheck=*/false); @@ -174,16 +176,20 @@ // output operand. if (iterArg) { OpOperand *outputOperand = - producerOp.getOutputOperand(producerResult.getResultNumber()); + cast(producerOp) + .getOutputOperand(producerResult.getResultNumber()); iterArg->set(outputOperand->get()); tiledOperands[outputOperand->getOperandNumber()] = sliceOp.getResult(); } // Clone the producer using the tiled producer operands. - TypeRange resultTypes = ValueRange(tiledOperands) - .take_back(producerOp.getNumOutputs()) - .getTypes(); - LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands); + TypeRange resultTypes = + ValueRange(tiledOperands) + .take_back( + cast(producerOp).getNumOutputs()) + .getTypes(); + LinalgOp clonedOp = cast(producerOp) + .clone(b, loc, resultTypes, tiledOperands); // Shift all IndexOp results by the tile offset. offsetIndices(b, clonedOp, allIvs); @@ -450,7 +456,8 @@ tileLoopNest.fuseProducer(b, candidates.pop_back_val()); if (failed(fusedProducer)) continue; - candidates.append(fusedProducer->getInputAndOutputOperands()); + candidates.append(cast(*fusedProducer) + .getInputAndOutputOperands()); } }; @@ -469,13 +476,17 @@ if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange, tileDistribution))) return failure(); - fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands()); + fuseProducersGreedily( + cast(*tileLoopNest.getRootOp()) + .getOutputOperands()); // Tile the remaining loops and fuse the input operands. 
if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange, tileDistribution))) return failure(); - fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands()); + fuseProducersGreedily( + cast(*tileLoopNest.getRootOp()) + .getInputOperands()); // Exit if the tile loop nest is empty since all tile sizes are zero. if (tileLoopNest.isEmpty()) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -44,12 +44,15 @@ if (failed(generalizeNamedOpPrecondition(linalgOp))) return rewriter.notifyMatchFailure(linalgOp, "preconditions not met"); - SmallVector inputOperands = linalgOp.getInputOperands(); - SmallVector outputOperands = linalgOp.getOutputOperands(); + SmallVector inputOperands = + cast(linalgOp).getInputOperands(); + SmallVector outputOperands = + cast(linalgOp).getOutputOperands(); SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); SmallVector iterators = llvm::to_vector<4>( linalgOp.iterator_types().getAsValueRange()); - SmallVector resultTypes = linalgOp.getOutputTensorTypes(); + SmallVector resultTypes = + cast(linalgOp).getOutputTensorTypes(); SmallVector types(resultTypes.begin(), resultTypes.end()); // All named ops have a region attached that can be inlined. diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -110,7 +110,7 @@ /// LinalgOp. 
static bool isOnlyUsedAsInputOfLinalgOp(tensor::PadOp padOp) { for (OpOperand &use : padOp.getResult().getUses()) { - auto linalgUser = dyn_cast(use.getOwner()); + auto linalgUser = dyn_cast(use.getOwner()); if (!linalgUser || !linalgUser.isInputTensor(&use)) { LLVM_DEBUG(DBGS() << "Found a use of " << *(padOp) << "\nthat is not an input tensor of a LinalgOp, " diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -122,15 +122,17 @@ assert(linalgOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); SmallVector indexedValues; - indexedValues.reserve(linalgOp.getNumInputsAndOutputs()); + indexedValues.reserve( + cast(linalgOp).getNumInputsAndOutputs()); auto allIvsPlusDims = SmallVector(allIvs.begin(), allIvs.end()); // TODO: Avoid the loads if the corresponding argument of the // region has no uses. // 1.a. Emit load from input operand or for scalars access the operand itself. - for (OpOperand *inputOperand : linalgOp.getInputOperands()) { - if (linalgOp.isScalar(inputOperand)) { + for (OpOperand *inputOperand : + cast(linalgOp).getInputOperands()) { + if (cast(linalgOp).isScalar(inputOperand)) { indexedValues.push_back(inputOperand->get()); continue; } @@ -140,7 +142,8 @@ b.create(loc, inputOperand->get(), indexing)); } // 1.b. Emit load from output views. - for (OpOperand *outputOperand : linalgOp.getOutputOperands()) { + for (OpOperand *outputOperand : + cast(linalgOp).getOutputOperands()) { SmallVector indexing = makeCanonicalAffineApplies( b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims); indexedValues.push_back( @@ -152,7 +155,8 @@ // 3. Emit store. 
SmallVector, 8> indexing; SmallVector outputBuffers; - for (OpOperand *outputOperand : linalgOp.getOutputBufferOperands()) { + for (OpOperand *outputOperand : + cast(linalgOp).getOutputBufferOperands()) { indexing.push_back(makeCanonicalAffineApplies( b, loc, linalgOp.getTiedIndexingMap(outputOperand), allIvsPlusDims)); outputBuffers.push_back(outputOperand->get()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp @@ -31,7 +31,7 @@ Value iZp, Value kZp, Value init, Attribute stride, Attribute dilation, PatternRewriter &rewriter) { Location loc = operation->getLoc(); - auto linalgOp = dyn_cast(operation); + auto linalgOp = dyn_cast(operation); // Exit out on the memref version of this operation. if (!linalgOp || !linalgOp.hasTensorSemantics()) return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -143,13 +143,16 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions( LinalgOp linalgOp, const LinalgPromotionOptions &options) : subViews(), alignment(options.alignment) { - assert(linalgOp.hasBufferSemantics() && "revisit usage of shaped operand"); + assert(cast(linalgOp).hasBufferSemantics() && + "revisit usage of shaped operand"); auto vUseFullTileBuffers = options.useFullTileBuffers.value_or(llvm::SmallBitVector()); - vUseFullTileBuffers.resize(linalgOp.getNumInputsAndOutputs(), - options.useFullTileBuffersDefault); + vUseFullTileBuffers.resize( + cast(linalgOp).getNumInputsAndOutputs(), + options.useFullTileBuffersDefault); - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { + for (OpOperand *opOperand : cast(linalgOp) + 
.getInputAndOutputOperands()) { int64_t operandNumber = opOperand->getOperandNumber(); if (options.operandsToPromote && !options.operandsToPromote->count(operandNumber)) @@ -314,7 +317,8 @@ static FailureOr promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op, LinalgOpInstancePromotionOptions options, DataLayout &layout) { - assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); + assert(cast(op).hasBufferSemantics() && + "expected linalg op with buffer semantics"); // 1. Promote the specified views and use them in the new op. auto promotedBuffersAndViews = promoteSubViews(b, options, layout); @@ -326,10 +330,12 @@ // operands are not views. This is to support cases such as FillOp taking // extra scalars etc. Keep a reference to output buffers; SmallVector opViews; - opViews.reserve(op.getNumInputsAndOutputs()); + opViews.reserve( + cast(op).getNumInputsAndOutputs()); SmallVector, 8> writebackViews; writebackViews.reserve(promotedBuffersAndViews->size()); - for (OpOperand *opOperand : op.getInputAndOutputOperands()) { + for (OpOperand *opOperand : + cast(op).getInputAndOutputOperands()) { int64_t operandNumber = opOperand->getOperandNumber(); if (options.subViews.count(operandNumber) != 0) { if (options.useFullTileBuffers[opOperand->get()]) @@ -338,7 +344,7 @@ else opViews.push_back( (*promotedBuffersAndViews)[operandNumber].partialLocalView); - if (operandNumber >= op.getNumInputs()) + if (operandNumber >= cast(op).getNumInputs()) writebackViews.emplace_back(std::make_pair( opOperand->get(), (*promotedBuffersAndViews)[operandNumber].partialLocalView)); @@ -366,7 +372,7 @@ LogicalResult mlir::linalg::promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options) { - LinalgOp linalgOp = dyn_cast(op); + auto linalgOp = dyn_cast(op); // Transformation applies to buffers only. 
if (!linalgOp || !linalgOp.hasBufferSemantics()) return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -62,8 +62,10 @@ PatternRewriter &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, const LinalgTransformationFilter &filter, bool useAlloc) { - if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() || - op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 || + if (failed(filter.checkAndNotify(b, op)) || + !cast(op).hasTensorSemantics() || + op.getNumReductionLoops() != 1 || + cast(op).getNumOutputs() != 1 || !op.hasOnlyProjectedPermutations()) return b.notifyMatchFailure(op, "precondition not met"); @@ -116,7 +118,8 @@ SmallVector newInputs; SmallVector newMaps; // Calculate the new shapes and indexing maps of the input operands. - for (OpOperand *operand : op.getInputOperands()) { + for (OpOperand *operand : + cast(op).getInputOperands()) { AffineMap map = op.getTiedIndexingMap(operand); SmallVector newShape; SmallVector exprs; @@ -156,8 +159,10 @@ // Calculate the new output map and shape, we insert the new dimension based // on the index returned by `controlSplitReductionFn`. SmallVector newOutputShape; - AffineMap oldOutputMap = op.getTiedIndexingMap(op.getOutputOperand(0)); - ArrayRef oldShape = op.getShape(op.getOutputOperand(0)); + AffineMap oldOutputMap = op.getTiedIndexingMap( + cast(op).getOutputOperand(0)); + ArrayRef oldShape = + op.getShape(cast(op).getOutputOperand(0)); SmallVector outputExpr; for (unsigned idx : llvm::seq(0, oldOutputMap.getNumResults() + 1)) { @@ -208,7 +213,8 @@ // from the previous op. 
unsigned intermRank = newOutputShape.size(); AffineMap inputMap = b.getMultiDimIdentityMap(intermRank); - SmallVector outputOperands = op.getOutputOperands(); + SmallVector outputOperands = + cast(op).getOutputOperands(); SmallVector reductionIteratorTypes; SmallVector exprs; for (unsigned i : llvm::seq(0, intermRank)) { @@ -307,7 +313,8 @@ return b.notifyMatchFailure(op, "unknown reduction neutral"); // TODO: relax this when multi-reduction support is available. - if (op.getNumOutputs() != static_cast(neutralElements.size())) + if (cast(op).getNumOutputs() != + static_cast(neutralElements.size())) return b.notifyMatchFailure(op, "expect one reduction per output"); // Rewrite part. @@ -327,11 +334,12 @@ // For now assume outputs are 1-1 with reduction neutralElements. // TODO: generalize when multi-reduction support is available. SmallVector newOutputs; - newOutputs.reserve(op.getNumOutputs()); + newOutputs.reserve(cast(op).getNumOutputs()); SmallVector initOrAllocTensorOps; SmallVector fillOps; - fillOps.reserve(op.getNumOutputs()); - for (auto it : llvm::zip(op.outputs(), neutralElements)) { + fillOps.reserve(cast(op).getNumOutputs()); + for (auto it : llvm::zip(cast(op).outputs(), + neutralElements)) { Value rankedTensor = std::get<0>(it); auto t = rankedTensor.getType().cast(); RankedTensorType newT = RankedTensorType::Builder(t).insertDim( @@ -356,8 +364,9 @@ // Step 2. Reindex / expand indexing maps. // Reindex existing input indexings: k -> k * splitFactor + k'. SmallVector newMaps; - newMaps.reserve(op.getNumInputsAndOutputs() + 1); - for (OpOperand *o : op.getInputOperands()) + newMaps.reserve( + cast(op).getNumInputsAndOutputs() + 1); + for (OpOperand *o : cast(op).getInputOperands()) newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor)); // Provision a new indexing for the shape-only tensor. 
auto nDims = op.getNumLoops() + 1; @@ -368,13 +377,14 @@ // TODO: a subset of these may not reduce along reducePos and should be // reindexed: k -> k * splitFactor + k', when multi-reduction support is // available. - for (OpOperand *o : op.getOutputOperands()) + for (OpOperand *o : cast(op).getOutputOperands()) newMaps.push_back(insertParallelDim(op, *o, reductionDimPos, reductionDimSize / splitFactor)); // Step 3. Handle operands. // Compute the new input tensors. - auto newInputs = llvm::to_vector<4>(op.inputs()); + auto newInputs = + llvm::to_vector<4>(cast(op).inputs()); // Add a single shape-only tensor to carry the dimensions without resorting to // more complex inversions. newInputs.push_back(b.create( @@ -404,8 +414,9 @@ // TODO: all results can be handled in a single GenericOp, when // multi-reduction support is available. SmallVector results; - for (auto it : - llvm::zip(genericOp->getResults(), op.outputs(), combinerOps)) { + for (auto it : llvm::zip(genericOp->getResults(), + cast(op).outputs(), + combinerOps)) { Value reindexedOutput = std::get<0>(it); Value originalOutput = std::get<1>(it); auto originalOutputType = originalOutput.getType().cast(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -483,7 +483,8 @@ SmallVector resultTensorTypes = getTensorOutputTypes(op, tiledOperands); - res = op.clone(b, loc, resultTensorTypes, tiledOperands); + res = cast(op).clone(b, loc, resultTensorTypes, + tiledOperands); tensorResults = insertSlicesBack(builder, loc, op, tiledOperands, res->getResults()); return scf::ValueVector(tensorResults.begin(), tensorResults.end()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ 
b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -60,11 +60,15 @@ Location loc = terminator->getLoc(); for (auto operand : llvm::enumerate(terminator->getOperands())) { Value toStore = map.lookupOrDefault(operand.value()); - OpOperand *storeInto = linalgOp.getOutputOperand(operand.index()); + OpOperand *storeInto = + cast(linalgOp).getOutputOperand( + operand.index()); auto indices = getIndicesForAccess( b, loc, linalgOp.getTiedIndexingMap(storeInto), ivs); b.create(loc, toStore, - linalgOp.getOutputOperand(operand.index())->get(), + cast(linalgOp) + .getOutputOperand(operand.index()) + ->get(), indices); } return success(); @@ -86,7 +90,7 @@ LinalgOpTy> { /// Return the destination operands. SmallVector getDestinationOperands(Operation *op, OpBuilder &b) const { - return cast(op).getOutputOperands(); + return cast(op).getOutputOperands(); } /// Return the loop iterator type. @@ -126,17 +130,19 @@ // specified could lead to out of bounds accesses. Location loc = op->getLoc(); LinalgOp linalgOp = cast(op); - SmallVector valuesToTile = linalgOp.getInputAndOutputOperands(); + SmallVector valuesToTile = + cast(linalgOp).getInputAndOutputOperands(); SmallVector tiledOperands = makeTiledShapes( b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true); SmallVector resultTensorTypes = llvm::to_vector(llvm::map_range( - linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) { + cast(linalgOp).getOutputTensorOperands(), + [&](OpOperand *opOperand) { return tiledOperands[opOperand->getOperandNumber()].getType(); })); - Operation *tiledOp = - linalgOp.clone(b, loc, resultTensorTypes, tiledOperands); + Operation *tiledOp = cast(linalgOp).clone( + b, loc, resultTensorTypes, tiledOperands); offsetIndices(b, cast(tiledOp), offsets); return {tiledOp}; @@ -160,7 +166,9 @@ return makeComposedFoldedAffineApply(b, loc, d0 - 1, ofr); })); - OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber); + OpOperand *outOperand = + cast(linalgOp).getOutputOperand( 
+ resultNumber); Value sliceOpResult = makeTiledShape(b, loc, outOperand->get(), sizes, linalgOp.getTiedIndexingMap(outOperand), offsets, @@ -224,20 +232,22 @@ Location loc, ValueRange ivs) const { auto linalgOp = cast(op); - if (!linalgOp.hasBufferSemantics()) + if (!cast(linalgOp).hasBufferSemantics()) return op->emitOpError("expected operation to have buffer semantics"); SmallVector indexedValues; - indexedValues.reserve(linalgOp.getNumInputsAndOutputs()); + indexedValues.reserve( + cast(linalgOp).getNumInputsAndOutputs()); Location linalgOpLoc = op->getLoc(); /// Load the data corresponding to the block arguments that /// represent input operands. - for (OpOperand *operand : linalgOp.getInputAndOutputOperands()) { + for (OpOperand *operand : cast(linalgOp) + .getInputAndOutputOperands()) { if (!linalgOp.payloadUsesValueFromOperand(operand)) { indexedValues.push_back(nullptr); continue; } - if (linalgOp.isScalar(operand)) { + if (cast(linalgOp).isScalar(operand)) { indexedValues.push_back(operand->get()); continue; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -206,7 +206,9 @@ // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp. OpOperand *currOpOperand = opOperand; - while (auto linalgOp = currOpOperand->get().getDefiningOp()) { + while ( + auto linalgOp = + currOpOperand->get().getDefiningOp()) { OpResult result = currOpOperand->get().cast(); currOpOperand = linalgOp.getOutputOperand(result.getResultNumber()); } @@ -265,7 +267,7 @@ Location loc = opToPad->getLoc(); // TODO: there are cases where we may still want to pad to larger sizes. 
- assert(opToPad.hasTensorSemantics() && + assert(cast(opToPad).hasTensorSemantics() && "expected operation to have tensor semantics"); OpBuilder::InsertionGuard g(b); @@ -273,8 +275,10 @@ b.setInsertionPointAfter(opToPad); // Make a copy of the shaped operands and update it. SmallVector newOperands; - newOperands.reserve(opToPad.getNumInputsAndOutputs()); - for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) { + newOperands.reserve( + cast(opToPad).getNumInputsAndOutputs()); + for (OpOperand *opOperand : + cast(opToPad).getInputAndOutputOperands()) { FailureOr paddedOperand = padOperandToSmallestStaticBoundingBox( b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings); // Exit if `paddingDimensions` cannot be bounded statically. @@ -292,8 +296,11 @@ // Clone `opToPad` to operate on the statically padded shapes. auto resultTensorTypes = - ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes(); - paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands); + ValueRange(newOperands) + .take_back(cast(opToPad).getNumOutputs()) + .getTypes(); + paddedOp = cast(opToPad).clone( + b, loc, resultTensorTypes, newOperands); // Recover the slice out of the new static results. This keeps the original // linalg op around because it uses the dims of the original results. @@ -411,7 +418,7 @@ FailureOr mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite( LinalgOp linalgOp, PatternRewriter &rewriter) const { - if (!linalgOp.hasTensorSemantics()) + if (!cast(linalgOp).hasTensorSemantics()) return failure(); if (failed(filter.checkAndNotify(rewriter, linalgOp))) return failure(); @@ -426,7 +433,8 @@ // Hoist the padding. 
for (const auto &en : enumerate(options.hoistPaddings)) { - if (static_cast(en.index()) >= paddedOp.getNumInputsAndOutputs()) + if (static_cast(en.index()) >= + cast(paddedOp).getNumInputsAndOutputs()) break; OpOperand *opOperand = &paddedOp->getOpOperand(en.index()); auto padOp = opOperand->get().getDefiningOp(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -144,7 +144,8 @@ static Operation *matchLinalgReduction(OpOperand *outputOperand) { auto linalgOp = cast(outputOperand->getOwner()); unsigned outputPos = - outputOperand->getOperandNumber() - linalgOp.getNumInputs(); + outputOperand->getOperandNumber() - + cast(linalgOp).getNumInputs(); // Only single combiner operations are supported for now. SmallVector combinerOps; if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) || @@ -257,7 +258,9 @@ // TODO: use a map. 
Value vectorValue = bvm.lookup(outputs.value()); Value newResult = buildVectorWrite( - b, vectorValue, linalgOp.getOutputOperand(outputs.index())); + b, vectorValue, + cast(linalgOp).getOutputOperand( + outputs.index())); if (newResult) newResults.push_back(newResult); } @@ -366,12 +369,15 @@ SmallVector> reductionOperands; for (Value operand : op->getOperands()) { auto arg = operand.dyn_cast(); - if (!arg || arg.getArgNumber() < linalgOp.getNumInputs()) + if (!arg || arg.getArgNumber() < + cast(linalgOp).getNumInputs()) continue; SmallVector reductionOps; Value reduceValue = matchReduction( linalgOp.getRegionOutputArgs(), - arg.getArgNumber() - linalgOp.getNumInputs(), reductionOps); + arg.getArgNumber() - + cast(linalgOp).getNumInputs(), + reductionOps); if (!reduceValue) continue; reductionOperands.push_back(std::make_pair(reduceValue, operand)); @@ -448,7 +454,7 @@ mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet); bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef()); - if (linalgOp.getNumOutputs() == 0) + if (cast(linalgOp).getNumOutputs() == 0) return failure(); // TODO: the common vector shape is equal to the static loop sizes only when @@ -459,9 +465,10 @@ // 3. Turn all BBArgs into vector.transfer_read / load. 
Location loc = linalgOp.getLoc(); Value zero = b.create(loc, 0); - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { + for (OpOperand *opOperand : cast(linalgOp) + .getInputAndOutputOperands()) { BlockArgument bbarg = block->getArgument(opOperand->getOperandNumber()); - if (linalgOp.isScalar(opOperand)) { + if (cast(linalgOp).isScalar(opOperand)) { bvm.map(bbarg, opOperand->get()); continue; } @@ -471,7 +478,8 @@ // if (linalgOp.getShape(opOperand).empty()) { // readType = VectorType::get({}, bbarg.getType()); // } else { - if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) { + if (opOperand->getOperandNumber() < + cast(linalgOp).getNumInputs()) { map = inverseAndBroadcastProjectedPermutation( linalgOp.getTiedIndexingMap(opOperand)); readType = VectorType::get(commonVectorShape, @@ -538,7 +546,8 @@ LDBG("reduction precondition failed: no reduction iterator"); return failure(); } - for (OpOperand *opOperand : op.getOutputOperands()) { + for (OpOperand *opOperand : + cast(op).getOutputOperands()) { Operation *reduceOp = matchLinalgReduction(opOperand); if (!reduceOp || !getCombinerOpKind(reduceOp)) { LDBG("reduction precondition failed: reduction detection failed"); @@ -1319,11 +1328,12 @@ : StructuredGenerator(builder, linalgOp), strideW(strideW), dilationW(dilationW) { // Determine whether `linalgOp` can be generated with this generator - if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1) + if (cast(linalgOp).getNumInputs() != 2 || + cast(linalgOp).getNumOutputs() != 1) return; - lhsShaped = linalgOp.inputs()[0]; - rhsShaped = linalgOp.inputs()[1]; - resShaped = linalgOp.outputs()[0]; + lhsShaped = cast(linalgOp).inputs()[0]; + rhsShaped = cast(linalgOp).inputs()[1]; + resShaped = cast(linalgOp).outputs()[0]; lhsShapedType = lhsShaped.getType().dyn_cast(); rhsShapedType = rhsShaped.getType().dyn_cast(); resShapedType = resShaped.getType().dyn_cast(); @@ -1335,7 +1345,8 @@ return; // Check for reduction `add` preceded by 
`mul`. - Operation *reduceOp = matchLinalgReduction(linalgOp.getOutputOperand(0)); + Operation *reduceOp = matchLinalgReduction( + cast(linalgOp).getOutputOperand(0)); if (!reduceOp) return; llvm::Optional maybeKind; diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -177,7 +177,8 @@ return false; // TODO: relax the restrictions on indexing map. - for (OpOperand *opOperand : op.getOutputOperands()) { + for (OpOperand *opOperand : + cast(op).getOutputOperands()) { if (!op.getTiedIndexingMap(opOperand).isPermutation()) return false; } @@ -402,7 +403,9 @@ if (!linalgOp) break; OpResult opResult = current.cast(); - current = linalgOp.getOutputOperand(opResult.getResultNumber())->get(); + current = cast(linalgOp) + .getOutputOperand(opResult.getResultNumber()) + ->get(); } auto padOp = current ? current.getDefiningOp() : nullptr; @@ -520,7 +523,8 @@ bodyBuilderFn, Optional distributionOptions, ArrayRef distributionTypes) { - SmallVector iterArgInitValues = linalgOp.getOutputTensorOperands(); + SmallVector iterArgInitValues = + cast(linalgOp).getOutputTensorOperands(); // Create procInfo so it dominates loops, if appropriate. 
SmallVector procInfo; SmallVector distributionMethod; @@ -543,12 +547,16 @@ LoopNest loopNest = mlir::scf::buildLoopNest( b, loc, lbs, ubs, steps, iterArgInitValues, [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) { - assert(iterArgs.size() == linalgOp.getOutputTensorOperands().size() && + assert(iterArgs.size() == cast(linalgOp) + .getOutputTensorOperands() + .size() && "expect the number of output tensors and iter args to match"); SmallVector operandValuesToUse = - linalgOp.getInputAndOutputOperands(); + cast(linalgOp) + .getInputAndOutputOperands(); if (!iterArgs.empty()) { - operandValuesToUse = linalgOp.getInputOperands(); + operandValuesToUse = + cast(linalgOp).getInputOperands(); operandValuesToUse.append(iterArgs.begin(), iterArgs.end()); } return bodyBuilderFn(b, loc, ivs, operandValuesToUse); @@ -579,7 +587,8 @@ ValueRange)> bodyBuilderFn, Optional, ArrayRef) { - SmallVector iterArgInitValues = linalgOp.getOutputTensorOperands(); + SmallVector iterArgInitValues = + cast(linalgOp).getOutputTensorOperands(); assert(iterArgInitValues.empty() && "unexpected AffineForOp init values"); SmallVector lbs, ubs, steps; unpackRanges(b, loc, loopRanges, lbs, ubs, steps); @@ -596,7 +605,8 @@ mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps, [&](OpBuilder &b, Location loc, ValueRange ivs) { SmallVector operandValuesToUse = - linalgOp.getInputAndOutputOperands(); + cast(linalgOp) + .getInputAndOutputOperands(); bodyBuilderFn(b, loc, ivs, operandValuesToUse); }); } @@ -745,7 +755,8 @@ bodyBuilderFn, Optional distributionOptions, ArrayRef distributionTypes) { - SmallVector iterArgInitValues = linalgOp.getOutputTensorOperands(); + SmallVector iterArgInitValues = + cast(linalgOp).getOutputTensorOperands(); assert(iterArgInitValues.empty() && "unexpected ParallelOp init values"); // This function may be passed more iterator types than ranges. 
assert(iteratorTypes.size() >= loopRanges.size() && @@ -794,7 +805,8 @@ b, loc, lbs, ubs, steps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange ivs) { SmallVector operandValuesToUse = - linalgOp.getInputAndOutputOperands(); + cast(linalgOp) + .getInputAndOutputOperands(); bodyBuilderFn(b, loc, ivs, operandValuesToUse); }, ivs, distributionMethod); @@ -975,8 +987,9 @@ SmallVector getTensorOutputTypes(LinalgOp op, ValueRange operands) { // TODO: use an interface/adaptor to avoid leaking position in // `tiledOperands`. - return llvm::to_vector( - llvm::map_range(op.getOutputTensorOperands(), [&](OpOperand *opOperand) { + return llvm::to_vector(llvm::map_range( + cast(op).getOutputTensorOperands(), + [&](OpOperand *opOperand) { return operands[opOperand->getOperandNumber()].getType(); })); } @@ -988,7 +1001,8 @@ tensorResults.reserve(results.size()); // Insert a insert_slice for each output tensor. unsigned resultIdx = 0; - for (OpOperand *opOperand : op.getOutputTensorOperands()) { + for (OpOperand *opOperand : + cast(op).getOutputTensorOperands()) { // TODO: use an interface/adaptor to avoid leaking position in // `tiledOperands`. Value outputTensor = operands[opOperand->getOperandNumber()]; @@ -1047,7 +1061,8 @@ "expected one value to tile for every operand"); SmallVector> allSliceParams; allSliceParams.reserve(valuesToTile.size()); - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { + for (OpOperand *opOperand : cast(linalgOp) + .getInputAndOutputOperands()) { Value shapedOp = valuesToTile[opOperand->getOperandNumber()]; LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp); AffineMap map = linalgOp.getTiedIndexingMap(opOperand); @@ -1056,7 +1071,9 @@ // transformations such as padding and bufferization since the // extract/insert slice pairs make the accessed iteration argument // subdomains explicit. 
- if (!isTiled(map, tileSizes) && !linalgOp.isOutputTensor(opOperand)) { + if (!isTiled(map, tileSizes) && + !cast(linalgOp).isOutputTensor( + opOperand)) { allSliceParams.push_back(llvm::None); LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: " << opOperand->get().getType() << "\n"); diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp @@ -26,7 +26,9 @@ return; TypeSwitch(op) .Case([&](linalg::LinalgOp linalgOp) { - SmallVector inputOperands = linalgOp.getInputOperands(); + SmallVector inputOperands = + cast(linalgOp) + .getInputOperands(); operandSet.insert(inputOperands.begin(), inputOperands.end()); }) .Default([&](Operation *operation) { @@ -136,7 +138,8 @@ dyn_cast(consumer.getOwner())) { if (expandOp->hasOneUse()) { OpOperand &use = *expandOp->getUses().begin(); - auto linalgOp = dyn_cast(use.getOwner()); + auto linalgOp = dyn_cast( + use.getOwner()); if (linalgOp && linalgOp.isOutputTensor(&use)) return true; } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp @@ -38,7 +38,8 @@ // Tile and Fuse for tensors inputs (TODO: all tensor operands). bool changed = false; for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { + for (OpOperand *opOperand : cast(linalgOp) + .getInputAndOutputOperands()) { if (opOperand->get().getType().isa()) { // TODO: LinalgDependenceGraph should be able to update itself. 
// The current naive and expensive reconstruction of the graph should be @@ -56,7 +57,8 @@ changed = true; } else if (opOperand->get().getType().isa()) { // Tile and Fuse tensor input. - if (opOperand->getOperandNumber() >= linalgOp.getNumInputs()) + if (opOperand->getOperandNumber() >= + cast(linalgOp).getNumInputs()) continue; auto info = fuseProducerOfTensor(b, *opOperand); if (failed(info)) diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2741,7 +2741,8 @@ def TestLinalgConvOp : TEST_Op<"linalg_conv_op", [AttrSizedOperandSegments, SingleBlock, - LinalgStructuredInterface, LinalgConvolutionOpInterface]> { + DestinationStyleOpInterface, LinalgStructuredInterface, + LinalgConvolutionOpInterface]> { let arguments = (ins Variadic:$inputs, Variadic:$outputs); @@ -2799,7 +2800,8 @@ def TestLinalgFillOp : TEST_Op<"linalg_fill_op", [AttrSizedOperandSegments, SingleBlock, - LinalgStructuredInterface, LinalgFillOpInterface]> { + DestinationStyleOpInterface, LinalgStructuredInterface, + LinalgFillOpInterface]> { let arguments = (ins Variadic:$inputs, Variadic:$outputs);