diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -584,9 +584,33 @@ ]; let extraClassDeclaration = [{ + /// Number of operands controlling the loop: lbs, ubs and steps (3 per loop). + unsigned getNumControlOperands() { return 3 * getNumLoops(); } + ValueRange getInductionVars() { return getBody()->getArguments(); } + + /// Returns the result tied to a tensor operand of `outputs`, or a null OpResult for memref operands and for operands outside `outputs`. + OpResult getTiedOpResult(OpOperand& opOperand) { + // No result can correspond to a memref argument. + if (opOperand.get().getType().isa<MemRefType>()) return OpResult(); + + // Check whether the operand index is in bounds of `outputs()` arg. + int operandIndex = opOperand.getOperandNumber(); + int outputIndexStart = + getNumControlOperands() + inputs().size(); + int outputIndexEnd = outputIndexStart + outputs().size(); + if (operandIndex < outputIndexStart || operandIndex >= outputIndexEnd) + return OpResult(); + + // Count tensor arguments in `outputs` up to and including this operand; starting at -1 makes the resulting index zero-based. + int tensorId = -1; + for (int i = outputIndexStart; i <= operandIndex; ++i) + tensorId += getOperand(i).getType().isa<RankedTensorType>(); + return getOperation()->getResult(tensorId); + } + unsigned getNumLoops() { return step().size(); } }];