diff --git a/mlir/include/mlir/Transforms/InliningUtils.h b/mlir/include/mlir/Transforms/InliningUtils.h
--- a/mlir/include/mlir/Transforms/InliningUtils.h
+++ b/mlir/include/mlir/Transforms/InliningUtils.h
@@ -172,13 +172,16 @@
 /// remapped operands that are used within the region, and *must* include
 /// remappings for the entry arguments to the region. 'resultsToReplace'
 /// corresponds to any results that should be replaced by terminators within the
-/// inlined region. 'inlineLoc' is an optional Location that, if provided, will
-/// be used to update the inlined operations' location information.
+/// inlined region. 'regionResultTypes' specifies the expected return types of
+/// the terminators in the region; it must have one entry per value in
+/// 'resultsToReplace'. 'inlineLoc' is an optional Location that, if provided,
+/// will be used to update the inlined operations' location information.
 /// 'shouldCloneInlinedRegion' corresponds to whether the source region should
 /// be cloned into the 'inlinePoint' or spliced directly.
 LogicalResult inlineRegion(InlinerInterface &interface, Region *src,
                            Operation *inlinePoint, BlockAndValueMapping &mapper,
                            ArrayRef<Value> resultsToReplace,
+                           ArrayRef<Type> regionResultTypes,
                            Optional<Location> inlineLoc = llvm::None,
                            bool shouldCloneInlinedRegion = true);
 
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -21,12 +21,17 @@
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/SCCIterator.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugCounter.h"
 #include "llvm/Support/Parallel.h"
 
 #define DEBUG_TYPE "inlining"
 
+using llvm::DebugCounter;
 using namespace mlir;
 
+DEBUG_COUNTER(InlineFunction, "inliner-inline-function",
+              "Controls which inlinings we do.");
+
 //===----------------------------------------------------------------------===//
 // Symbol Use Tracking
 //===----------------------------------------------------------------------===//
@@ -456,11 +461,16 @@
   bool inlinedAnyCalls = false;
   for (unsigned i = 0; i != calls.size(); ++i) {
     ResolvedCall it = calls[i];
+    bool doInline =
+        shouldInline(it) && DebugCounter::shouldExecute(InlineFunction);
     LLVM_DEBUG({
-      llvm::dbgs() << "* Considering inlining call: ";
+      if (doInline)
+        llvm::dbgs() << "* Inlining call: ";
+      else
+        llvm::dbgs() << "* Not inlining call: ";
       it.call.dump();
     });
-    if (!shouldInline(it))
+    if (!doInline)
       continue;
     CallOpInterface call = it.call;
     Region *targetRegion = it.targetNode->getCallableRegion();
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -129,8 +129,11 @@
                                  Operation *inlinePoint,
                                  BlockAndValueMapping &mapper,
                                  ArrayRef<Value> resultsToReplace,
+                                 ArrayRef<Type> regionResultTypes,
                                  Optional<Location> inlineLoc,
                                  bool shouldCloneInlinedRegion) {
+  assert(resultsToReplace.size() == regionResultTypes.size() &&
+         "expected one region result type per result to replace");
   // We expect the region to have at least one block.
   if (src->empty())
     return failure();
@@ -198,9 +201,12 @@
   } else {
     // Otherwise, there were multiple blocks inlined. Add arguments to the post
     // insertion block to represent the results to replace.
-    for (Value resultToRepl : resultsToReplace) {
-      resultToRepl.replaceAllUsesWith(
-          postInsertBlock->addArgument(resultToRepl.getType()));
+    // Type the arguments with 'regionResultTypes' rather than the types of the
+    // values being replaced; the two can differ, e.g. when inlineCall inlines a
+    // callable whose result types differ from those of the call.
+    for (auto resultToRepl : llvm::enumerate(resultsToReplace)) {
+      resultToRepl.value().replaceAllUsesWith(postInsertBlock->addArgument(
+          regionResultTypes[resultToRepl.index()]));
     }
 
     /// Handle the terminators for each of the new blocks.
@@ -243,9 +249,11 @@
     mapper.map(regionArg, inlinedOperands[i]);
   }
 
+  auto regionResultTypes = llvm::to_vector<6>(
+      llvm::map_range(resultsToReplace, [](Value v) { return v.getType(); }));
   // Call into the main region inliner function.
   return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace,
-                      inlineLoc, shouldCloneInlinedRegion);
+                      regionResultTypes, inlineLoc, shouldCloneInlinedRegion);
 }
 
 /// Utility function used to generate a cast operation from the given interface,
@@ -350,7 +358,8 @@
 
   // Attempt to inline the call.
   if (failed(inlineRegion(interface, src, call, mapper, callResults,
-                          call.getLoc(), shouldCloneInlinedRegion)))
+                          callableResultTypes, call.getLoc(),
+                          shouldCloneInlinedRegion)))
     return cleanupState();
   return success();
 }
diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir
--- a/mlir/test/Transforms/inlining.mlir
+++ b/mlir/test/Transforms/inlining.mlir
@@ -131,6 +131,22 @@
   return %res : i16
 }
 
+func @convert_callee_fn_multiblock() -> i32 {
+  br ^bb0
+^bb0:
+  %0 = constant 0 : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: func @inline_convert_result_multiblock
+func @inline_convert_result_multiblock() -> i16 {
+  // CHECK: ^bb{{[0-9]+}}(%[[BBARG:.+]]: i32):
+  // CHECK: %[[CAST_RESULT:.+]] = "test.cast"(%[[BBARG]]) : (i32) -> i16
+  // CHECK: return %[[CAST_RESULT]] : i16
+  %res = "test.conversion_call_op"() { callee=@convert_callee_fn_multiblock } : () -> (i16)
+  return %res : i16
+}
+
 // CHECK-LABEL: func @no_inline_convert_call
 func @no_inline_convert_call() {
   // CHECK: "test.conversion_call_op"