diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2357,6 +2357,16 @@
         return calling;
       return getOperand(0);
     }
+
+    /// Set the callee for this operation. Direct calls carry the callee in
+    /// the callee attribute; indirect calls carry it as the first operand.
+    void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
+      if ((*this)->getAttrOfType<mlir::SymbolRefAttr>(getCalleeAttrName()))
+        (*this)->setAttr(getCalleeAttrName(),
+                         callee.get<mlir::SymbolRefAttr>());
+      else
+        setOperand(0, callee.get<mlir::Value>());
+    }
   }];
 }
 
diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md
--- a/mlir/docs/Interfaces.md
+++ b/mlir/docs/Interfaces.md
@@ -728,6 +728,7 @@
 
 *   `CallOpInterface` - Used to represent operations like 'call'
     -   `CallInterfaceCallable getCallableForCallee()`
+    -   `void setCalleeFromCallable(CallInterfaceCallable)`
 *   `CallableOpInterface` - Used to represent the target callee of call.
     -   `Region * getCallableRegion()`
     -   `ArrayRef<Type> getCallableResults()`
diff --git a/mlir/docs/Tutorials/Toy/Ch-4.md b/mlir/docs/Tutorials/Toy/Ch-4.md
--- a/mlir/docs/Tutorials/Toy/Ch-4.md
+++ b/mlir/docs/Tutorials/Toy/Ch-4.md
@@ -189,6 +189,12 @@
   return getAttrOfType<SymbolRefAttr>("callee");
 }
 
+/// Set the callee for the generic call operation, this is required by the call
+/// interface.
+void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+}
+
 /// Get the argument operands to the called function, this is required by the
 /// call interface.
 Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
--- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -338,6 +338,12 @@
   return (*this)->getAttrOfType<SymbolRefAttr>("callee");
 }
 
+/// Set the callee for the generic call operation, this is required by the call
+/// interface.
+void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+}
+
 /// Get the argument operands to the called function, this is required by the
 /// call interface.
 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
--- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp
@@ -338,6 +338,12 @@
   return (*this)->getAttrOfType<SymbolRefAttr>("callee");
 }
 
+/// Set the callee for the generic call operation, this is required by the call
+/// interface.
+void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+}
+
 /// Get the argument operands to the called function, this is required by the
 /// call interface.
 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
--- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp
@@ -338,6 +338,12 @@
   return (*this)->getAttrOfType<SymbolRefAttr>("callee");
 }
 
+/// Set the callee for the generic call operation, this is required by the call
+/// interface.
+void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+}
+
 /// Get the argument operands to the called function, this is required by the
 /// call interface.
 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
--- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp
+++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp
@@ -367,6 +367,12 @@
   return (*this)->getAttrOfType<SymbolRefAttr>("callee");
 }
 
+/// Set the callee for the generic call operation, this is required by the call
+/// interface.
+void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+}
+
 /// Get the argument operands to the called function, this is required by the
 /// call interface.
 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -271,6 +271,11 @@
     CallInterfaceCallable getCallableForCallee() {
       return (*this)->getAttrOfType<SymbolRefAttr>("callee");
     }
+
+    /// Set the callee for this operation.
+    void setCalleeFromCallable(CallInterfaceCallable callee) {
+      (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+    }
   }];
 
   let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
@@ -91,6 +91,11 @@
     CallInterfaceCallable getCallableForCallee() {
       return (*this)->getAttrOfType<SymbolRefAttr>("callee");
     }
+
+    /// Set the callee for this operation.
+    void setCalleeFromCallable(CallInterfaceCallable callee) {
+      (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+    }
   }];
 
   let assemblyFormat = [{
@@ -153,6 +158,11 @@
 
     /// Return the callee of this operation.
     CallInterfaceCallable getCallableForCallee() { return getCallee(); }
+
+    /// Set the callee for this operation.
+    void setCalleeFromCallable(CallInterfaceCallable callee) {
+      setOperand(0, callee.get<Value>());
+    }
   }];
 
   let hasCanonicalizeMethod = 1;
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -372,6 +372,10 @@
       return getTarget();
     }
 
+    void setCalleeFromCallable(::mlir::CallInterfaceCallable callee) {
+      setTargetAttr(callee.get<::mlir::SymbolRefAttr>());
+    }
+
     ::mlir::Operation::operand_range getArgOperands() {
       return getOperands();
     }
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -40,6 +40,15 @@
       }],
       "::mlir::CallInterfaceCallable", "getCallableForCallee"
     >,
+    InterfaceMethod<[{
+        Sets the callee of this call-like operation. A `callee` is either a
+        reference to a symbol, via SymbolRefAttr, or a reference to a defined
+        SSA value. The kind of `callee` is expected to match what
+        `getCallableForCallee` currently returns, e.g., `callee` should be a
+        SymbolRefAttr for `func.call`.
+      }],
+      "void", "setCalleeFromCallable", (ins "::mlir::CallInterfaceCallable":$callee)
+    >,
     InterfaceMethod<[{
         Returns the operands within this call that are used as arguments to the
         callee.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -933,6 +933,16 @@
   return getOperand(0);
 }
 
+void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  // Direct call.
+  if (getCalleeAttr()) {
+    auto symRef = callee.get<SymbolRefAttr>();
+    return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
+  }
+  // Indirect call, callee Value is the first operand.
+  return setOperand(0, callee.get<Value>());
+}
+
 Operation::operand_range CallOp::getArgOperands() {
   return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
 }
@@ -1157,6 +1167,16 @@
   return getOperand(0);
 }
 
+void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  // Direct call.
+  if (getCalleeAttr()) {
+    auto symRef = callee.get<SymbolRefAttr>();
+    return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
+  }
+  // Indirect call, callee Value is the first operand.
+  return setOperand(0, callee.get<Value>());
+}
+
 Operation::operand_range InvokeOp::getArgOperands() {
   return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
 }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -2576,6 +2576,11 @@
   return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
 }
 
+void spirv::FunctionCallOp::setCalleeFromCallable(
+    CallInterfaceCallable callee) {
+  (*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
+}
+
 Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
   return getArguments();
 }
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -495,11 +495,18 @@
   let extraClassDeclaration = [{
     /// Return the callee of this operation.
     ::mlir::CallInterfaceCallable getCallableForCallee();
+
+    /// Set the callee for this operation.
+    void setCalleeFromCallable(::mlir::CallInterfaceCallable);
   }];
   let extraClassDefinition = [{
     ::mlir::CallInterfaceCallable $cppClass::getCallableForCallee() {
       return (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee");
     }
+
+    void $cppClass::setCalleeFromCallable(::mlir::CallInterfaceCallable callee) {
+      (*this)->setAttr("callee", callee.get<::mlir::SymbolRefAttr>());
+    }
   }];
 }
 