diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md
--- a/mlir/docs/Interfaces.md
+++ b/mlir/docs/Interfaces.md
@@ -731,6 +731,8 @@
 *   `CallableOpInterface` - Used to represent the target callee of call.
     -   `Region * getCallableRegion()`
     -   `ArrayRef<Type> getCallableResults()`
+    -   `ArrayAttr getCallableArgAttrs()`
+    -   `ArrayAttr getCallableResAttrs()`
 
 ##### RegionKindInterfaces
 
diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
--- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
+++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td
@@ -299,6 +299,18 @@
     /// executed.
     ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
 
+    /// Returns the argument attributes for all callable region arguments or
+    /// null if there are none.
+    ::mlir::ArrayAttr getCallableArgAttrs() {
+      return getArgAttrs().value_or(nullptr);
+    }
+
+    /// Returns the result attributes for all callable region results or
+    /// null if there are none.
+    ::mlir::ArrayAttr getCallableResAttrs() {
+      return getResAttrs().value_or(nullptr);
+    }
+
     //===------------------------------------------------------------------===//
     // FunctionOpInterface Methods
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -84,6 +84,28 @@
       }],
       "::llvm::ArrayRef<::mlir::Type>", "getCallableResults"
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the argument attributes for all callable region arguments or
+        null if there are none.
+      }],
+      /*retType=*/"::mlir::ArrayAttr",
+      /*methodName=*/"getCallableArgAttrs",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return {}; }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the result attributes for all callable region results or null
+        if there are none.
+      }],
+      /*retType=*/"::mlir::ArrayAttr",
+      /*methodName=*/"getCallableResAttrs",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{ return {}; }]
+    >,
   ];
 }
 
diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h
--- a/mlir/include/mlir/Transforms/InliningUtils.h
+++ b/mlir/include/mlir/Transforms/InliningUtils.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_TRANSFORMS_INLININGUTILS_H
 #define MLIR_TRANSFORMS_INLININGUTILS_H
 
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/DialectInterface.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Region.h"
@@ -141,6 +142,40 @@
     return nullptr;
   }
 
+  /// Hook to transform the call arguments before using them to replace the
+  /// callee arguments. It returns the transformation result or `argument`
+  /// itself if the hook did not change anything. The type of the returned value
+  /// has to match `targetType`, and the `argumentAttrs` dictionary is non-null
+  /// even if no attribute is present. The hook is called after converting the
+  /// callsite argument types using the materializeCallConversion callback, and
+  /// right before inlining the callee region. Any operations created using the
+  /// provided `builder` are inserted right before the inlined callee region.
+  /// Example use cases are the insertion of copies for by-value arguments, or
+  /// integer conversions that require signedness information.
+  virtual Value handleArgument(OpBuilder &builder, Operation *call,
+                               Operation *callable, Value argument,
+                               Type targetType,
+                               DictionaryAttr argumentAttrs) const {
+    return argument;
+  }
+
+  /// Hook to transform the callee results before using them to replace the call
+  /// results. It returns the transformation result or the `result` itself if
+  /// the hook did not change anything. The type of the returned value has to
+  /// match `targetType`, and the `resultAttrs` dictionary is non-null even if
+  /// no attribute is present. The hook is called right before handling
+  /// terminators, and obtains the callee result before converting its type
+  /// using the `materializeCallConversion` callback. Any operations created
+  /// using the provided `builder` are inserted right after the inlined callee
+  /// region. Example use cases are the insertion of copies for by-value results
+  /// or integer conversions that require signedness information.
+  /// NOTE: This hook is invoked after inlining the `callable` region.
+  virtual Value handleResult(OpBuilder &builder, Operation *call,
+                             Operation *callable, Value result, Type targetType,
+                             DictionaryAttr resultAttrs) const {
+    return result;
+  }
+
   /// Process a set of blocks that have been inlined for a call. This callback
   /// is invoked before inlined terminator operations have been processed.
   virtual void processInlinedCallBlocks(
@@ -183,6 +218,15 @@
   virtual void handleTerminator(Operation *op, Block *newDest) const;
   virtual void handleTerminator(Operation *op,
                                 ArrayRef<Value> valuesToRepl) const;
+
+  virtual Value handleArgument(OpBuilder &builder, Operation *call,
+                               Operation *callable, Value argument,
+                               Type targetType,
+                               DictionaryAttr argumentAttrs) const;
+  virtual Value handleResult(OpBuilder &builder, Operation *call,
+                             Operation *callable, Value result, Type targetType,
+                             DictionaryAttr resultAttrs) const;
+
   virtual void processInlinedCallBlocks(
       Operation *call, iterator_range<Region::iterator> inlinedBlocks) const;
 };
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -103,6 +103,26 @@
   handler->handleTerminator(op, valuesToRepl);
 }
 
+Value InlinerInterface::handleArgument(OpBuilder &builder, Operation *call,
+                                       Operation *callable, Value argument,
+                                       Type targetType,
+                                       DictionaryAttr argumentAttrs) const {
+  auto *handler = getInterfaceFor(call);
+  assert(handler && "expected valid dialect handler");
+  return handler->handleArgument(builder, call, callable, argument, targetType,
+                                 argumentAttrs);
+}
+
+Value InlinerInterface::handleResult(OpBuilder &builder, Operation *call,
+                                     Operation *callable, Value result,
+                                     Type targetType,
+                                     DictionaryAttr resultAttrs) const {
+  auto *handler = getInterfaceFor(call);
+  assert(handler && "expected valid dialect handler");
+  return handler->handleResult(builder, call, callable, result, targetType,
+                               resultAttrs);
+}
+
 void InlinerInterface::processInlinedCallBlocks(
     Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
   auto *handler = getInterfaceFor(call);
@@ -141,6 +161,129 @@
 // Inline Methods
 //===----------------------------------------------------------------------===//
 
+/// Returns `size` attribute dictionaries unpacked from `arrayAttr`, the
+/// per-argument or per-result attributes of a callable. A null `arrayAttr`, or
+/// an entry that is not a dictionary, yields an empty dictionary so that the
+/// inliner hooks always receive a non-null `DictionaryAttr`.
+static SmallVector<DictionaryAttr> unpackAttrDicts(OpBuilder &builder,
+                                                   ArrayAttr arrayAttr,
+                                                   size_t size) {
+  SmallVector<DictionaryAttr> dicts(size, builder.getDictionaryAttr({}));
+  if (!arrayAttr)
+    return dicts;
+  assert(arrayAttr.size() == size &&
+         "expected one attribute dictionary per argument or result");
+  for (auto [idx, attr] : llvm::enumerate(arrayAttr))
+    if (auto dict = dyn_cast<DictionaryAttr>(attr))
+      dicts[idx] = dict;
+  return dicts;
+}
+
+/// Runs the `handleArgument` hook on every argument of the callable region and
+/// remaps each region argument to the value returned by the hook. Does nothing
+/// unless both `call` and `callable` are non-null.
+static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder,
+                               CallOpInterface call,
+                               CallableOpInterface callable,
+                               IRMapping &mapper) {
+  if (!call || !callable)
+    return;
+
+  // Unpack the argument attributes if there are any.
+  Region *callableRegion = callable.getCallableRegion();
+  SmallVector<DictionaryAttr> argAttrs =
+      unpackAttrDicts(builder, callable.getCallableArgAttrs(),
+                      callableRegion->getNumArguments());
+
+  // Run the argument attribute handler for the given argument and attribute.
+  for (auto [blockArg, argAttr] :
+       llvm::zip(callableRegion->getArguments(), argAttrs)) {
+    Value newArgument = interface.handleArgument(builder, call, callable,
+                                                 mapper.lookup(blockArg),
+                                                 blockArg.getType(), argAttr);
+    assert(newArgument.getType() == blockArg.getType() &&
+           "expected the handled argument type to match the target type");
+
+    // Update the mapping to point the new argument returned by the handler.
+    mapper.map(blockArg, newArgument);
+  }
+}
+
+/// Invokes the `handleResult` hook on `result` and returns the value it
+/// produces, which must have the same type as `result`.
+static Value runResultHandler(InlinerInterface &interface, OpBuilder &builder,
+                              CallOpInterface call,
+                              CallableOpInterface callable, Value result,
+                              DictionaryAttr resAttr) {
+  // TODO: Use the type of the call result to replace once the hook can be
+  // used for type conversions. At the moment, all type conversions have to be
+  // done using materializeCallConversion.
+  Type targetType = result.getType();
+  Value newResult = interface.handleResult(builder, call, callable, result,
+                                           targetType, resAttr);
+  assert(newResult.getType() == targetType &&
+         "expected the handled result type to match the target type");
+  return newResult;
+}
+
+/// Runs the `handleResult` hook on the operands of `terminator`, the only
+/// terminator of a single inlined block, and rewrites just those terminator
+/// operands. Other users of a returned value are deliberately left alone: they
+/// either precede the operations the hook inserts right before the terminator
+/// or lie outside the inlined block (e.g. when the callee returns one of its
+/// arguments), so redirecting them would break dominance. Rewriting operand by
+/// operand also applies each result's own attributes when the same value is
+/// returned several times.
+static void handleTerminatorResultImpl(InlinerInterface &interface,
+                                       OpBuilder &builder, CallOpInterface call,
+                                       CallableOpInterface callable,
+                                       Operation *terminator) {
+  if (!call || !callable)
+    return;
+
+  // Unpack the result attributes if there are any.
+  SmallVector<DictionaryAttr> resAttrs =
+      unpackAttrDicts(builder, callable.getCallableResAttrs(),
+                      terminator->getNumOperands());
+
+  // Run the result attribute handler on each terminator operand in turn.
+  for (auto [idx, resAttr] : llvm::enumerate(resAttrs)) {
+    Value newResult = runResultHandler(interface, builder, call, callable,
+                                       terminator->getOperand(idx), resAttr);
+    terminator->setOperand(idx, newResult);
+  }
+}
+
+/// Runs the `handleResult` hook on `results`, the arguments of the block that
+/// follows a multi-block inlined region, and redirects their pre-existing users
+/// to the handled values. Uses created by the hook keep the original values.
+static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
+                             CallOpInterface call, CallableOpInterface callable,
+                             ValueRange results) {
+  if (!call || !callable)
+    return;
+
+  // Unpack the result attributes if there are any.
+  SmallVector<DictionaryAttr> resAttrs = unpackAttrDicts(
+      builder, callable.getCallableResAttrs(), results.size());
+
+  // Run the result attribute handler for the given result and attribute.
+  for (auto [result, resAttr] : llvm::zip(results, resAttrs)) {
+    // Store the original result users before running the handler.
+    DenseSet<Operation *> resultUsers;
+    for (Operation *user : result.getUsers())
+      resultUsers.insert(user);
+
+    Value newResult =
+        runResultHandler(interface, builder, call, callable, result, resAttr);
+
+    // Replace the result uses except for the ones introduced by the handler.
+    result.replaceUsesWithIf(newResult, [&](OpOperand &operand) {
+      return resultUsers.count(operand.getOwner());
+    });
+  }
+}
+
 static LogicalResult
 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
                  Block::iterator inlinePoint, IRMapping &mapper,
@@ -166,6 +309,12 @@
                                 mapper))
     return failure();
 
+  // Run the argument attribute handler before inlining the callable region.
+  // `src` may not be attached to an operation, hence `dyn_cast_or_null`.
+  OpBuilder builder(inlineBlock, inlinePoint);
+  auto callable = dyn_cast_or_null<CallableOpInterface>(src->getParentOp());
+  handleArgumentImpl(interface, builder, call, callable, mapper);
+
   // Check to see if the region is being cloned, or moved inline. In either
   // case, move the new blocks after the 'insertBlock' to improve IR
   // readability.
@@ -199,8 +348,13 @@
 
   // Handle the case where only a single block was inlined.
   if (std::next(newBlocks.begin()) == newBlocks.end()) {
+    // Run the result attribute handler on the terminator operands.
+    Operation *firstBlockTerminator = firstNewBlock->getTerminator();
+    builder.setInsertionPoint(firstBlockTerminator);
+    handleTerminatorResultImpl(interface, builder, call, callable,
+                               firstBlockTerminator);
+
     // Have the interface handle the terminator of this block.
-    auto *firstBlockTerminator = firstNewBlock->getTerminator();
     interface.handleTerminator(firstBlockTerminator,
                                llvm::to_vector<6>(resultsToReplace));
     firstBlockTerminator->erase();
@@ -218,6 +372,11 @@
           resultToRepl.value().getLoc()));
     }
 
+    // Run the result attribute handler on the post insertion block arguments.
+    builder.setInsertionPointToStart(postInsertBlock);
+    handleResultImpl(interface, builder, call, callable,
+                     postInsertBlock->getArguments());
+
     /// Handle the terminators for each of the new blocks.
     for (auto &newBlock : newBlocks)
       interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);
diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir
--- a/mlir/test/Transforms/inlining.mlir
+++ b/mlir/test/Transforms/inlining.mlir
@@ -226,3 +226,58 @@
   call @func_with_block_args_location(%arg0) : (i32) -> ()
   return
 }
+
+// Check that we can handle argument and result attributes.
+func.func @handle_attr_callee_fn_multi_arg(%arg0 : i16, %arg1 : i16 {"test.handle_argument"}) -> (i16 {"test.handle_result"}, i16) {
+  %0 = arith.addi %arg0, %arg1 : i16
+  %1 = arith.subi %arg0, %arg1 : i16
+  return %0, %1 : i16, i16
+}
+func.func @handle_attr_callee_fn(%arg0 : i32 {"test.handle_argument"}) -> (i32 {"test.handle_result"}) {
+  return %arg0 : i32
+}
+
+// CHECK-LABEL: func @inline_handle_attr_call
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+func.func @inline_handle_attr_call(%arg0 : i16, %arg1 : i16) -> (i16, i16) {
+
+  // CHECK: %[[CHANGE_INPUT:.*]] = "test.type_changer"(%[[ARG1]]) : (i16) -> i16
+  // CHECK: %[[SUM:.*]] = arith.addi %[[ARG0]], %[[CHANGE_INPUT]]
+  // CHECK: %[[DIFF:.*]] = arith.subi %[[ARG0]], %[[CHANGE_INPUT]]
+  // CHECK: %[[CHANGE_RESULT:.*]] = "test.type_changer"(%[[SUM]]) : (i16) -> i16
+  // CHECK-NEXT: return %[[CHANGE_RESULT]], %[[DIFF]]
+  %res0, %res1 = "test.conversion_call_op"(%arg0, %arg1) { callee=@handle_attr_callee_fn_multi_arg } : (i16, i16) -> (i16, i16)
+  return %res0, %res1 : i16, i16
+}
+
+// CHECK-LABEL: func @inline_convert_and_handle_attr_call
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+func.func @inline_convert_and_handle_attr_call(%arg0 : i16) -> (i16) {
+
+  // CHECK: %[[CAST_INPUT:.*]] = "test.cast"(%[[ARG0]]) : (i16) -> i32
+  // CHECK: %[[CHANGE_INPUT:.*]] = "test.type_changer"(%[[CAST_INPUT]]) : (i32) -> i32
+  // CHECK: %[[CHANGE_RESULT:.*]] = "test.type_changer"(%[[CHANGE_INPUT]]) : (i32) -> i32
+  // CHECK: %[[CAST_RESULT:.*]] = "test.cast"(%[[CHANGE_RESULT]]) : (i32) -> i16
+  // CHECK: return %[[CAST_RESULT]]
+  %res = "test.conversion_call_op"(%arg0) { callee=@handle_attr_callee_fn } : (i16) -> (i16)
+  return %res : i16
+}
+
+// Check that a result attribute handler only rewrites the returned value and
+// leaves the other users of a value forwarded from the caller untouched.
+func.func @handle_attr_callee_fn_forward_arg(%arg0 : i32) -> (i32 {"test.handle_result"}) {
+  return %arg0 : i32
+}
+
+// CHECK-LABEL: func @inline_handle_result_attr_of_forwarded_arg
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+func.func @inline_handle_result_attr_of_forwarded_arg(%arg0 : i32) -> (i32, i32) {
+
+  // CHECK: %[[SUM:.*]] = arith.addi %[[ARG0]], %[[ARG0]]
+  // CHECK: %[[CHANGE_RESULT:.*]] = "test.type_changer"(%[[ARG0]]) : (i32) -> i32
+  // CHECK-NEXT: return %[[SUM]], %[[CHANGE_RESULT]]
+  %0 = arith.addi %arg0, %arg0 : i32
+  %res = "test.conversion_call_op"(%arg0) { callee=@handle_attr_callee_fn_forward_arg } : (i32) -> (i32)
+  return %0, %res : i32, i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -245,6 +245,28 @@
     return builder.create<TestCastOp>(conversionLoc, resultType, input);
   }
 
+  /// Wraps `argument` in a `test.type_changer` op if its attribute dictionary
+  /// contains `test.handle_argument`; otherwise returns it unchanged.
+  Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
+                       Value argument, Type targetType,
+                       DictionaryAttr argumentAttrs) const final {
+    if (!argumentAttrs.contains("test.handle_argument"))
+      return argument;
+    return builder.create<TestTypeChangerOp>(call->getLoc(), targetType,
+                                             argument);
+  }
+
+  /// Wraps `result` in a `test.type_changer` op if its attribute dictionary
+  /// contains `test.handle_result`; otherwise returns it unchanged.
+  Value handleResult(OpBuilder &builder, Operation *call, Operation *callable,
+                     Value result, Type targetType,
+                     DictionaryAttr resultAttrs) const final {
+    if (!resultAttrs.contains("test.handle_result"))
+      return result;
+    return builder.create<TestTypeChangerOp>(call->getLoc(), targetType,
+                                             result);
+  }
+
   void processInlinedCallBlocks(
       Operation *call,
       iterator_range<Region::iterator> inlinedBlocks) const final {