diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2347,6 +2347,12 @@
       return {arg_operand_begin() + 1, arg_operand_end()};
     }
 
+    mlir::MutableOperandRange getArgOperandsMutable() {
+      if ((*this)->getAttrOfType<mlir::SymbolRefAttr>(getCalleeAttrName()))
+        return getArgsMutable();
+      return mlir::MutableOperandRange(*this, 1, getArgs().size() - 1);
+    }
+
     operand_iterator arg_operand_begin() { return operand_begin(); }
     operand_iterator arg_operand_end() { return operand_end(); }
 
diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -348,6 +348,12 @@
 /// call interface.
 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
 
+/// Get the argument operands to the called function as a mutable range, this is
+/// required by the call interface.
+MutableOperandRange GenericCallOp::getArgOperandsMutable() {
+  return getInputsMutable();
+}
+
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -348,6 +348,12 @@
 /// call interface.
 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
 
+/// Get the argument operands to the called function as a mutable range, this is
+/// required by the call interface.
+MutableOperandRange GenericCallOp::getArgOperandsMutable() {
+  return getInputsMutable();
+}
+
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -348,6 +348,12 @@
 /// call interface.
 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
 
+/// Get the argument operands to the called function as a mutable range, this is
+/// required by the call interface.
+MutableOperandRange GenericCallOp::getArgOperandsMutable() {
+  return getInputsMutable();
+}
+
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -377,6 +377,12 @@
 /// call interface.
 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
 
+/// Get the argument operands to the called function as a mutable range, this is
+/// required by the call interface.
+MutableOperandRange GenericCallOp::getArgOperandsMutable() {
+  return getInputsMutable();
+}
+
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -264,6 +264,10 @@
       return {arg_operand_begin(), arg_operand_end()};
     }
 
+    MutableOperandRange getArgOperandsMutable() {
+      return getOperandsMutable();
+    }
+
     operand_iterator arg_operand_begin() { return operand_begin(); }
     operand_iterator arg_operand_end() { return operand_end(); }
 
diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
@@ -83,6 +83,10 @@
       return {arg_operand_begin(), arg_operand_end()};
     }
 
+    MutableOperandRange getArgOperandsMutable() {
+      return getOperandsMutable();
+    }
+
     operand_iterator arg_operand_begin() { return operand_begin(); }
     operand_iterator arg_operand_end() { return operand_end(); }
 
@@ -152,6 +156,10 @@
       return {arg_operand_begin(), arg_operand_end()};
     }
 
+    MutableOperandRange getArgOperandsMutable() {
+      return getCalleeOperandsMutable();
+    }
+
     operand_iterator arg_operand_begin() { return ++operand_begin(); }
     operand_iterator arg_operand_end() { return operand_end(); }
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -616,7 +616,7 @@
   }];
 
   dag args = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
-                  Variadic<LLVM_Type>,
+                  Variadic<LLVM_Type>:$callee_operands,
                   DefaultValuedAttr<LLVM_FastmathFlagsAttr,
                                    "{}">:$fastmathFlags,
                   OptionalAttr<ElementsAttr>:$branch_weights);
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -632,6 +632,10 @@
     ::mlir::Operation::operand_range getArgOperands() {
       return getOperands();
     }
+
+    ::mlir::MutableOperandRange getArgOperandsMutable() {
+      return getOperandsMutable();
+    }
   }];
 }
 
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -55,6 +55,11 @@
       }],
       "::mlir::Operation::operand_range", "getArgOperands"
     >,
+    InterfaceMethod<[{
+        Returns the operands within this call that are used as arguments to the
+        callee as a mutable range.
+      }],
+      "::mlir::MutableOperandRange", "getArgOperandsMutable">,
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1003,6 +1003,14 @@
   return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
 }
 
+MutableOperandRange CallOp::getArgOperandsMutable() {
+  // Indirect calls pass the function pointer as the first callee operand; it
+  // is not an argument, so skip it and shorten the range to match.
+  MutableOperandRange calleeOperands = getCalleeOperandsMutable();
+  unsigned first = getCallee().has_value() ? 0 : 1;
+  return calleeOperands.slice(first, calleeOperands.size() - first);
+}
+
 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   if (getNumResults() > 1)
     return emitOpError("must have 0 or 1 result");
@@ -1237,6 +1245,14 @@
-  return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
+  return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
 }
 
+MutableOperandRange InvokeOp::getArgOperandsMutable() {
+  // Slice the named callee-operand group so that mutations also update the
+  // operand segment sizes, skipping the leading function pointer if indirect.
+  MutableOperandRange calleeOperands = getCalleeOperandsMutable();
+  unsigned first = getCallee().has_value() ? 0 : 1;
+  return calleeOperands.slice(first, calleeOperands.size() - first);
+}
+
 LogicalResult InvokeOp::verify() {
   if (getNumResults() > 1)
     return emitOpError("must have 0 or 1 result");
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -208,6 +208,10 @@
   return getArguments();
 }
 
+MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
+  return getArgumentsMutable();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.loop
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1263,6 +1263,10 @@
   return getCalleeOperands();
 }
 
+MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
+  return getCalleeOperandsMutable();
+}
+
 void TestStoreWithARegion::getSuccessorRegions(
     std::optional<unsigned> index, ArrayRef<Attribute> operands,
     SmallVectorImpl<RegionSuccessor> &regions) {