diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h --- a/mlir/include/mlir/Transforms/InliningUtils.h +++ b/mlir/include/mlir/Transforms/InliningUtils.h @@ -143,35 +143,30 @@ } /// Hook to transform the call arguments before using them to replace the - /// callee arguments. It returns the transformation result or `argument` - /// itself if the hook did not change anything. The type of the returned value - /// has to match `targetType`, and the `argumentAttrs` dictionary is non-null - /// even if no attribute is present. The hook is called after converting the + /// callee arguments. Returns a value of the same type or the `argument` + /// itself if nothing changed. The `argumentAttrs` dictionary is non-null even + /// if no attribute is present. The hook is called after converting the /// callsite argument types using the materializeCallConversion callback, and /// right before inlining the callee region. Any operations created using the - /// provided `builder` are inserted right before the inlined callee region. - /// Example use cases are the insertion of copies for by value arguments, or - /// integer conversions that require signedness information. + /// provided `builder` are inserted right before the inlined callee region. An + /// example use case is the insertion of copies for by value arguments. virtual Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, Value argument, - Type targetType, DictionaryAttr argumentAttrs) const { return argument; } /// Hook to transform the callee results before using them to replace the call - /// results. It returns the transformation result or the `result` itself if - /// the hook did not change anything. The type of the returned values has to - /// match `targetType`, and the `resultAttrs` dictionary is non-null even if - /// no attribute is present. The hook is called right before handling + /// results. Returns a value of the same type or the `result` itself if + /// nothing changed. The `resultAttrs` dictionary is non-null even if no + /// attribute is present. The hook is called right before handling /// terminators, and obtains the callee result before converting its type /// using the `materializeCallConversion` callback. Any operations created /// using the provided `builder` are inserted right after the inlined callee - /// region. Example use cases are the insertion of copies for by value results - /// or integer conversions that require signedness information. - /// NOTE: This hook is invoked after inlining the `callable` region. + /// region. An example use case is the insertion of copies for by value + /// results. NOTE: This hook is invoked after inlining the `callable` region. virtual Value handleResult(OpBuilder &builder, Operation *call, - Operation *callable, Value result, Type targetType, + Operation *callable, Value result, DictionaryAttr resultAttrs) const { return result; } @@ -221,10 +216,9 @@ virtual Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, Value argument, - Type targetType, DictionaryAttr argumentAttrs) const; virtual Value handleResult(OpBuilder &builder, Operation *call, - Operation *callable, Value result, Type targetType, + Operation *callable, Value result, DictionaryAttr resultAttrs) const; virtual void processInlinedCallBlocks( diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp @@ -348,7 +348,7 @@ } Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, - Value argument, Type targetType, + Value argument, DictionaryAttr argumentAttrs) const final { if (std::optional attr = argumentAttrs.getNamed(LLVM::LLVMDialect::getByValAttrName())) { diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -105,22 +105,19 @@ Value InlinerInterface::handleArgument(OpBuilder &builder, Operation *call, Operation *callable, Value argument, - Type targetType, DictionaryAttr argumentAttrs) const { auto *handler = getInterfaceFor(callable); assert(handler && "expected valid dialect handler"); - return handler->handleArgument(builder, call, callable, argument, targetType, + return handler->handleArgument(builder, call, callable, argument, argumentAttrs); } Value InlinerInterface::handleResult(OpBuilder &builder, Operation *call, Operation *callable, Value result, - Type targetType, DictionaryAttr resultAttrs) const { auto *handler = getInterfaceFor(callable); assert(handler && "expected valid dialect handler"); - return handler->handleResult(builder, call, callable, result, targetType, - resultAttrs); + return handler->handleResult(builder, call, callable, result, resultAttrs); } void InlinerInterface::processInlinedCallBlocks( @@ -178,11 +175,10 @@ // Run the argument attribute handler for the given argument and attribute. for (auto [blockArg, argAttr] : llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) { - Value newArgument = interface.handleArgument(builder, call, callable, - mapper.lookup(blockArg), - blockArg.getType(), argAttr); - assert(newArgument.getType() == blockArg.getType() && - "expected the handled argument type to match the target type"); + Value newArgument = interface.handleArgument( + builder, call, callable, mapper.lookup(blockArg), argAttr); + assert(newArgument.getType() == mapper.lookup(blockArg).getType() && + "expected the argument type to not change"); // Update the mapping to point the new argument returned by the handler. mapper.map(blockArg, newArgument); @@ -209,15 +205,10 @@ for (Operation *user : result.getUsers()) resultUsers.insert(user); - // TODO: Use the type of the call result to replace once the hook can be - // used for type conversions. At the moment, all type conversions have to be - // done using materializeCallConversion. - Type targetType = result.getType(); - - Value newResult = interface.handleResult(builder, call, callable, result, - targetType, resAttr); - assert(newResult.getType() == targetType && - "expected the handled result type to match the target type"); + Value newResult = + interface.handleResult(builder, call, callable, result, resAttr); + assert(newResult.getType() == result.getType() && + "expected the result type to not change"); // Replace the result uses except for the ones introduce by the handler. result.replaceUsesWithIf(newResult, [&](OpOperand &operand) { diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -356,20 +356,19 @@ } Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable, - Value argument, Type targetType, + Value argument, DictionaryAttr argumentAttrs) const final { if (!argumentAttrs.contains("test.handle_argument")) return argument; - return builder.create(call->getLoc(), targetType, + return builder.create(call->getLoc(), argument.getType(), argument); } Value handleResult(OpBuilder &builder, Operation *call, Operation *callable, - Value result, Type targetType, - DictionaryAttr resultAttrs) const final { + Value result, DictionaryAttr resultAttrs) const final { if (!resultAttrs.contains("test.handle_result")) return result; - return builder.create(call->getLoc(), targetType, + return builder.create(call->getLoc(), result.getType(), result); }