diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -570,6 +570,22 @@
       /*methodBody=*/"",
       /*defaultImplementation=*/""
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the op operands that have a corresponding argument in the basic
+        block. By default, the block has one argument per input/output operand,
+        but there are exceptions: for example, the output operand of `map` is
+        not used in the block. The i-th block argument corresponds to the i-th
+        returned operand.
+      }],
+      /*retTy=*/"OpOperandVector",
+      /*methodName=*/"getOpOperandsMatchingBBargs",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.getInputAndOutputOperands();
+      }]
+    >,
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -762,11 +762,16 @@
   // not used).
   Block &block = linalgOp->getRegion(0).front();
 
-  if (linalgOp.getNumInputsAndOutputs() != block.getNumArguments())
+  OpOperandVector bbArgOperands = linalgOp.getOpOperandsMatchingBBargs();
+  if (bbArgOperands.size() != block.getNumArguments())
     return op->emitOpError("expected as many non-induction variable region "
                            "arguments as the number of input/output operands");
 
-  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
+  // The operands with a matching block argument may be a strict subset of the
+  // op's operands (not necessarily a prefix), so pair them with the block
+  // arguments by position rather than by operand number.
+  for (const auto &en : llvm::enumerate(bbArgOperands)) {
+    OpOperand *opOperand = en.value();
     Type elementType = getElementTypeOrSelf(opOperand->get());
-    Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
+    Type argType = block.getArgument(en.index()).getType();
     if (elementType != argType)