[mlir] Type post-insertion block arguments with the region's result types

When inlineCall inlines a callable whose result types differ from the call's
result types (a conversion is materialized through the inliner interface,
e.g. "test.cast" in the test below), inlining a multi-block region used to
create the post-insertion block arguments with the *call* result types, while
the inlined terminators forward values of the *callable* result types.

inlineRegion now takes the expected result types of the region's terminators
('regionResultTypes') and uses them for the new block arguments. The overload
that builds its own mapping forwards resultsToReplace.getTypes(), and
inlineCall forwards callableResultTypes. The ArrayRef parameters are relaxed
to ValueRange so callers can pass operand/result ranges without copying them
into a vector first. The inliner's debug output now also states whether each
considered call is actually inlined.

diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h
--- a/mlir/include/mlir/Transforms/InliningUtils.h
+++ b/mlir/include/mlir/Transforms/InliningUtils.h
@@ -27,7 +27,9 @@
 class OpBuilder;
 class Operation;
 class Region;
+class TypeRange;
 class Value;
+class ValueRange;

 //===----------------------------------------------------------------------===//
 // InlinerInterface
@@ -172,13 +174,15 @@
 /// remapped operands that are used within the region, and *must* include
 /// remappings for the entry arguments to the region. 'resultsToReplace'
 /// corresponds to any results that should be replaced by terminators within the
-/// inlined region. 'inlineLoc' is an optional Location that, if provided, will
-/// be used to update the inlined operations' location information.
-/// 'shouldCloneInlinedRegion' corresponds to whether the source region should
-/// be cloned into the 'inlinePoint' or spliced directly.
+/// inlined region. 'regionResultTypes' specifies the expected return types of
+/// the terminators in the region. 'inlineLoc' is an optional Location that, if
+/// provided, will be used to update the inlined operations' location
+/// information. 'shouldCloneInlinedRegion' corresponds to whether the source
+/// region should be cloned into the 'inlinePoint' or spliced directly.
 LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
                            Operation *inlinePoint, BlockAndValueMapping &mapper,
-                           ArrayRef<Value> resultsToReplace,
+                           ValueRange resultsToReplace,
+                           TypeRange regionResultTypes,
                            Optional<Location> inlineLoc = llvm::None,
                            bool shouldCloneInlinedRegion = true);

@@ -187,8 +191,8 @@
 /// in-favor of the region arguments when inlining.
 LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
                            Operation *inlinePoint,
-                           ArrayRef<Value> inlinedOperands,
-                           ArrayRef<Value> resultsToReplace,
+                           ValueRange inlinedOperands,
+                           ValueRange resultsToReplace,
                            Optional<Location> inlineLoc = llvm::None,
                            bool shouldCloneInlinedRegion = true);

diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -456,11 +456,15 @@
   bool inlinedAnyCalls = false;
   for (unsigned i = 0; i != calls.size(); ++i) {
     ResolvedCall it = calls[i];
+    bool doInline = shouldInline(it);
     LLVM_DEBUG({
-      llvm::dbgs() << "* Considering inlining call: ";
+      if (doInline)
+        llvm::dbgs() << "* Inlining call: ";
+      else
+        llvm::dbgs() << "* Not inlining call: ";
       it.call.dump();
     });
-    if (!shouldInline(it))
+    if (!doInline)
       continue;
     CallOpInterface call = it.call;
     Region *targetRegion = it.targetNode->getCallableRegion();
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -128,9 +128,12 @@
 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
                                  Operation *inlinePoint,
                                  BlockAndValueMapping &mapper,
-                                 ArrayRef<Value> resultsToReplace,
+                                 ValueRange resultsToReplace,
+                                 TypeRange regionResultTypes,
                                  Optional<Location> inlineLoc,
                                  bool shouldCloneInlinedRegion) {
+  assert(resultsToReplace.size() == regionResultTypes.size() &&
+         "expected one region result type per result to replace");
   // We expect the region to have at least one block.
   if (src->empty())
     return failure();
@@ -188,7 +191,8 @@
   if (std::next(newBlocks.begin()) == newBlocks.end()) {
     // Have the interface handle the terminator of this block.
     auto *firstBlockTerminator = firstNewBlock->getTerminator();
-    interface.handleTerminator(firstBlockTerminator, resultsToReplace);
+    interface.handleTerminator(firstBlockTerminator,
+                               llvm::to_vector<6>(resultsToReplace));
     firstBlockTerminator->erase();

     // Merge the post insert block into the cloned entry block.
@@ -198,9 +202,9 @@
   } else {
     // Otherwise, there were multiple blocks inlined. Add arguments to the post
     // insertion block to represent the results to replace.
-    for (Value resultToRepl : resultsToReplace) {
-      resultToRepl.replaceAllUsesWith(
-          postInsertBlock->addArgument(resultToRepl.getType()));
+    for (auto resultToRepl : llvm::enumerate(resultsToReplace)) {
+      resultToRepl.value().replaceAllUsesWith(postInsertBlock->addArgument(
+          regionResultTypes[resultToRepl.index()]));
     }

     /// Handle the terminators for each of the new blocks.
@@ -220,8 +224,8 @@
 /// in-favor of the region arguments when inlining.
 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src,
                                  Operation *inlinePoint,
-                                 ArrayRef<Value> inlinedOperands,
-                                 ArrayRef<Value> resultsToReplace,
+                                 ValueRange inlinedOperands,
+                                 ValueRange resultsToReplace,
                                  Optional<Location> inlineLoc,
                                  bool shouldCloneInlinedRegion) {
   // We expect the region to have at least one block.
@@ -245,7 +249,8 @@

   // Call into the main region inliner function.
   return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace,
-                      inlineLoc, shouldCloneInlinedRegion);
+                      resultsToReplace.getTypes(), inlineLoc,
+                      shouldCloneInlinedRegion);
 }

 /// Utility function used to generate a cast operation from the given interface,
@@ -350,7 +355,8 @@

   // Attempt to inline the call.
   if (failed(inlineRegion(interface, src, call, mapper, callResults,
-                          call.getLoc(), shouldCloneInlinedRegion)))
+                          callableResultTypes, call.getLoc(),
+                          shouldCloneInlinedRegion)))
     return cleanupState();
   return success();
 }
diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir
--- a/mlir/test/Transforms/inlining.mlir
+++ b/mlir/test/Transforms/inlining.mlir
@@ -131,6 +131,27 @@
   return %res : i16
 }

+func @convert_callee_fn_multiblock() -> i32 {
+  br ^bb0
+^bb0:
+  %0 = constant 0 : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: func @inline_convert_result_multiblock
+func @inline_convert_result_multiblock() -> i16 {
+// CHECK:   br ^bb1
+// CHECK: ^bb1:
+// CHECK:   %[[C:.+]] = constant 0 : i32
+// CHECK:   br ^bb2(%[[C]] : i32)
+// CHECK: ^bb2(%[[BBARG:.+]]: i32):
+// CHECK:   %[[CAST_RESULT:.+]] = "test.cast"(%[[BBARG]]) : (i32) -> i16
+// CHECK:   return %[[CAST_RESULT]] : i16
+
+  %res = "test.conversion_call_op"() { callee=@convert_callee_fn_multiblock } : () -> (i16)
+  return %res : i16
+}
+
 // CHECK-LABEL: func @no_inline_convert_call
 func @no_inline_convert_call() {
   // CHECK: "test.conversion_call_op"
diff --git a/mlir/test/lib/Transforms/TestInlining.cpp b/mlir/test/lib/Transforms/TestInlining.cpp
--- a/mlir/test/lib/Transforms/TestInlining.cpp
+++ b/mlir/test/lib/Transforms/TestInlining.cpp
@@ -44,8 +44,7 @@
       // Inline the functional region operation, but only clone the internal
       // region if there is more than one use.
       if (failed(inlineRegion(
-              interface, &callee.body(), caller,
-              llvm::to_vector<8>(caller.getArgOperands()),
+              interface, &callee.body(), caller, caller.getArgOperands(),
               SmallVector<Value, 8>(caller.getResults()), caller.getLoc(),
               /*shouldCloneInlinedRegion=*/!callee.getResult().hasOneUse())))
         continue;