diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -641,6 +641,19 @@ return !opOperand->get().getType().template isa(); }] >, + InterfaceMethod< + /*desc=*/[{ + Return the block argument of the op's body block tied to `opOperand`. + }], + /*retTy=*/"BlockArgument", + /*methodName=*/"getTiedBlockArgument", + /*args=*/(ins "OpOperand *":$opOperand), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(opOperand->getOwner() == this->getOperation()); + return getBlock()->getArgument(opOperand->getOperandNumber()); + }] + >, InterfaceMethod< /*desc=*/[{ Return the input or output indexing map for `opOperand`. @@ -672,6 +685,23 @@ return this->getOperation()->getResult(resultIndex); }] >, + InterfaceMethod< + /*desc=*/[{ + Return the body terminator's operand that corresponds to the output `opOperand`. + }], + /*retTy=*/"OpOperand *", + /*methodName=*/"getTiedYieldValue", + /*args=*/(ins "OpOperand *":$opOperand), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(opOperand->getOwner() == this->getOperation()); + int64_t resultIndex = opOperand->getOperandNumber() - getNumInputs(); + assert(resultIndex >= 0 && + resultIndex < getNumOutputs()); + Operation *yieldOp = getBlock()->getTerminator(); + return &yieldOp->getOpOperand(resultIndex); + }] + >, //===------------------------------------------------------------------===// // Other interface methods. //===------------------------------------------------------------------===//