diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
@@ -13,6 +13,8 @@
 #ifndef MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_
 #define MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_
 
+#include "mlir/Support/LLVM.h"
+
 namespace mlir {
 
 // Forward declarations.
@@ -32,8 +34,15 @@
 /// operands that have been legalized by the conversion framework. This can only
 /// be done if the branch operation implements the BranchOpInterface. Only
 /// needed for partial conversions.
-void populateBranchOpInterfaceTypeConversionPattern(RewritePatternSet &patterns,
-                                                    TypeConverter &converter);
+///
+/// If only a subset of a branch op's operands should be converted/legalized,
+/// such filtering can be specified in branchOpsOperandConversionFilter, which
+/// maps each such op to the set of indices of the operands that need to be
+/// converted.
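+///
+/// For example, to convert only operand 0 of a given branch op (a minimal
+/// sketch; `brOp` is a placeholder for any op implementing
+/// BranchOpInterface):
+///
+///   DenseMap<Operation *, DenseSet<int>> filter;
+///   filter[brOp].insert(0);
+///   populateBranchOpInterfaceTypeConversionPattern(patterns, converter,
+///                                                  &filter);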
+void populateBranchOpInterfaceTypeConversionPattern(
+    RewritePatternSet &patterns, TypeConverter &converter,
+    const DenseMap<Operation *, DenseSet<int>>
+        *branchOpsOperandConversionFilter = nullptr);
 
 /// Return true if op is a BranchOpInterface op whose operands are all legal
 /// according to converter.
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -498,8 +498,13 @@
  /// Convert the types of block arguments within the given region except for
  /// the entry block. This replaces each non-entry block with a new block
  /// containing the updated signature.
-  LogicalResult convertNonEntryRegionTypes(Region *region,
-                                           TypeConverter &converter);
+  ///
+  /// If special conversion behavior is needed for the non-entry blocks (for
+  /// example, if only a subset of a block's arguments should be converted),
+  /// such behavior can be specified in blockConversions, which holds one
+  /// entry per non-entry block, in region order. If blockConversions is
+  /// nullptr, a signature conversion is computed from the converter itself.
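+  ///
+  /// A minimal sketch of populating blockConversions, mirroring how the
+  /// conversion walks the region (the new argument types are placeholders):
+  ///
+  ///   SmallVector<TypeConverter::SignatureConversion, 2> conversions;
+  ///   for (Block &block : llvm::drop_begin(*region, 1)) {
+  ///     conversions.emplace_back(block.getNumArguments());
+  ///     for (unsigned idx = 0; idx < block.getNumArguments(); ++idx)
+  ///       conversions.back().addInputs(idx, {/*new type for arg idx*/});
+  ///   }
+  ///   rewriter.convertNonEntryRegionTypes(region, converter, &conversions);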
+  LogicalResult convertNonEntryRegionTypes(
+      Region *region, TypeConverter &converter,
+      SmallVectorImpl<TypeConverter::SignatureConversion> *blockConversions =
+          nullptr);
 
   /// Replace all the uses of the block argument `from` with value `to`.
   void replaceUsesOfBlockArgument(BlockArgument from, Value to);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -15,6 +15,8 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include <deque>
 #include <iterator>
 #include <memory>
 
@@ -39,13 +41,21 @@
 /// Defines the criteria a TensorType must follow in order to be considered
 /// "detensorable".
 ///
-/// NOTE: For now, only 0-D are supported.
+/// NOTE: For now, only 0-D tensors are supported.
 ///
 /// Returns true if tensorType can be detensored.
 bool canBeDetensored(TensorType tensorType) {
   return tensorType.hasRank() && tensorType.getRank() == 0;
 }
 
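+/// Returns true if `op` is a linalg.generic op whose shaped operand types are
+/// all illegal according to `typeConverter`, i.e. all of its operands and
+/// results can (and should) be detensored.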
+bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
+  GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
+  return genericOp && llvm::all_of(genericOp.getShapedOperandTypes(),
+                                   [&](ShapedType shapedType) {
+                                     return !typeConverter.isLegal(shapedType);
+                                   });
+}
+
 /// A conversion pattern for detensoring `linalg.generic` ops.
 class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
 public:
@@ -81,17 +91,36 @@
 /// A conversion pattern for detensoring internal (non-entry) blocks within a
 /// function.
 struct FunctionNonEntryBlockConversion : public ConversionPattern {
-  FunctionNonEntryBlockConversion(StringRef functionLikeOpName,
-                                  MLIRContext *ctx, TypeConverter &converter)
-      : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
+  FunctionNonEntryBlockConversion(
+      StringRef functionLikeOpName, MLIRContext *ctx, TypeConverter &converter,
+      DenseMap<Block *, DenseSet<int>> blockArgumentDetensoring)
+      : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx),
+        blockArgumentDetensoring(std::move(blockArgumentDetensoring)) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.startRootUpdate(op);
+    Region &region = mlir::impl::getFunctionBody(op);
+    SmallVector<TypeConverter::SignatureConversion, 2> conversions;
+
+    for (Block &block : llvm::drop_begin(region, 1)) {
+      conversions.emplace_back(block.getNumArguments());
+      TypeConverter::SignatureConversion &back = conversions.back();
+      DenseSet<int> blockArgumentDetensoringFilter =
+          blockArgumentDetensoring.lookup(&block);
+
+      for (unsigned int idx = 0; idx < block.getNumArguments(); ++idx) {
+        if (blockArgumentDetensoringFilter.count(idx))
+          back.addInputs(idx, {getTypeConverter()->convertType(
+                                  block.getArgumentTypes()[idx])});
+        else
+          back.addInputs(idx, {block.getArgumentTypes()[idx]});
+      }
+    }
 
-    if (failed(rewriter.convertNonEntryRegionTypes(
-            &mlir::impl::getFunctionBody(op), *typeConverter))) {
+    if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
+                                                   &conversions))) {
       rewriter.cancelRootUpdate(op);
       return failure();
     }
@@ -99,6 +128,9 @@
     rewriter.finalizeRootUpdate(op);
     return success();
   }
+
+private:
+  const DenseMap<Block *, DenseSet<int>> blockArgumentDetensoring;
 };
 
 class DetensorizeTypeConverter : public TypeConverter {
@@ -160,46 +192,291 @@
 
 /// @see LinalgDetensorize in Linalg/Passes.td for more details.
 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
+  LinalgDetensorize() = default;
+  LinalgDetensorize(const LinalgDetensorize &pass) {}
+
+  class CostModel {
+  public:
+    virtual ~CostModel() = default;
+
+    /// A cost model algorithm computes the following outputs:
+    ///
+    /// - detensorableLinalgOps: the list of linalg ops that should be
+    /// detensored.
+    ///
+    /// - detensorableBranchOps: a map from branch ops to sets of operand
+    /// indices. The set of indices corresponding to a branch op specifies
+    /// which subset of the branch's operands should be detensored (i.e.
+    /// converted by typeConverter).
+    ///
+    /// - blockArgumentDetensoring: since the operands and results of detensored
+    /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
+    /// from a BB argument and a linalg op's output can be passed to successor
+    /// BBs), we need to maintain the subset of arguments that should be
+    /// detensored (i.e. converted by typeConverter) for each affected BB.
+    ///
+    /// Example:
+    ///
+    /// For the following snippet:
+    /// ...
+    /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
+    ///   %7 = linalg.init_tensor [] : tensor<i32>
+    ///   %8 = linalg.generic #attrs
+    ///     ins(%6, %6 : tensor<i32>, tensor<i32>)
+    ///     outs(%7 : tensor<i32>) {
+    ///     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
+    ///       %add = addi %arg0, %arg1 : i32
+    ///       linalg.yield %add : i32
+    ///   } -> tensor<i32>
+    ///   %10 = "some.op"(%9)
+    ///   br ^bb2(%8 : tensor<i32>)
+    /// ...
+    ///
+    /// if the cost model decides that the linalg.generic op should be
+    /// detensored, then:
+    /// - detensorableLinalgOps should be {linalg.generic{add}}.
+    /// - detensorableBranchOps should be {br ^bb2 -> {0}}.
+    /// - blockArgumentDetensoring should be {bb1 -> {0}, bb2 -> {0}}.
+    virtual void
+    compute(FuncOp func, DetensorizeTypeConverter typeConverter,
+            DenseSet<Operation *> &detensorableLinalgOps,
+            DenseMap<Operation *, DenseSet<int>> &detensorableBranchOps,
+            DenseMap<Block *, DenseSet<int>> &blockArgumentDetensoring) = 0;
+  };
+
+  class PureControlFlowDetectionModel : public CostModel {
+    void compute(
+        FuncOp func, DetensorizeTypeConverter typeConverter,
+        DenseSet<Operation *> &detensorableLinalgOps,
+        DenseMap<Operation *, DenseSet<int>> &detensorableBranchOps,
+        DenseMap<Block *, DenseSet<int>> &blockArgumentDetensoring) override {
+      // TODO: The following code is written with loops in mind. Support for
+      // if-conditions might need to be added later on.
+
+      DenseSet<Operation *> workList;
+      // 1. Find which detensorable ops are involved in control-flow (i.e.
+      // they produce tensors that are then used in a cond_br's condition).
+      func.walk([&](CondBranchOp condBr) {
+        auto *chainOp = condBr.condition().getDefiningOp();
+
+        while (chainOp && !isa<GenericOp>(chainOp)) {
+          if (chainOp->getNumOperands() != 1)
+            break;
+
+          chainOp = chainOp->getOperand(0).getDefiningOp();
+        }
+
+        if (!shouldBeDetensored(chainOp, typeConverter))
+          return;
+
+        workList.insert(chainOp);
+      });
+
+      // 2. Discover other detensorable ops by walking the def-use chain
+      // backwards starting from the detensorable ops currently on the
+      // workList.
+      while (!workList.empty()) {
+        GenericOp detensorableOp = cast<GenericOp>(*workList.begin());
+        detensorableLinalgOps.insert(detensorableOp);
+        workList.erase(workList.begin());
+
+        // Discover where the detensorableOp's operands come from.
+        for (Value operand : detensorableOp.inputs())
+          if (!discoverDetensorableComponent(
+                  operand, typeConverter, workList, detensorableLinalgOps,
+                  detensorableBranchOps, blockArgumentDetensoring)) {
+            // TODO For now we assume there is one opportunity for detensoring
+            // in a function. This can be extended to support multiple separate
+            // components in a single function.
+            detensorableLinalgOps.clear();
+            detensorableBranchOps.clear();
+            blockArgumentDetensoring.clear();
+            return;
+          }
+      }
+    }
+
+  private:
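+    /// Walks the use-def chain backwards from `operand`, adding newly
+    /// discovered detensorable linalg ops to `workList` and recording, along
+    /// the way, the branch operands and block arguments that would have to be
+    /// detensored as well. Returns false if the chain reaches an op that
+    /// cannot be detensored.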
+    bool discoverDetensorableComponent(
+        Value operand, TypeConverter typeConverter,
+        DenseSet<Operation *> &workList,
+        const DenseSet<Operation *> &detensorableLinalgOps,
+        DenseMap<Operation *, DenseSet<int>> &detensorableBranchOps,
+        DenseMap<Block *, DenseSet<int>> &blockArgumentDetensoring) {
+      auto *definingOp = operand.getDefiningOp();
+
+      if (definingOp) {
+        if (comesFromElements(definingOp))
+          return true;
+
+        if (!shouldBeDetensored(definingOp, typeConverter))
+          return false;
+
+        if (!workList.count(definingOp) &&
+            !detensorableLinalgOps.count(definingOp))
+          workList.insert(definingOp);
+
+        return true;
+      }
+
+      BlockArgument blockArgument = operand.cast<BlockArgument>();
+      Block *ownerBlock = blockArgument.getOwner();
+
+      if (&*ownerBlock->getParent()->begin() == ownerBlock)
+        return true;
+
+      blockArgumentDetensoring[ownerBlock].insert(blockArgument.getArgNumber());
+
+      for (PredecessorIterator pred = ownerBlock->pred_begin();
+           pred != ownerBlock->pred_end(); ++pred) {
+        BranchOpInterface terminator =
+            dyn_cast<BranchOpInterface>((*pred)->getTerminator());
+
+        // Conservatively give up on the component if a predecessor's
+        // terminator does not implement BranchOpInterface; we cannot reason
+        // about which operands it forwards to the block.
+        if (!terminator)
+          return false;
+
+        auto ownerBlockOperands =
+            terminator.getSuccessorOperands(pred.getSuccessorIndex());
+
+        // TODO Add a test where the same operand is passed more than once to
+        // the same block.
+        if (!ownerBlockOperands || ownerBlockOperands->empty())
+          continue;
+
+        Value predOperand =
+            ownerBlockOperands.getValue()[blockArgument.getArgNumber()];
+
+        for (int idx = ownerBlockOperands->getBeginOperandIndex(),
+                 eidx = idx + ownerBlockOperands->size();
+             idx < eidx; ++idx)
+          if (terminator->getOperand(idx) == predOperand)
+            detensorableBranchOps[terminator].insert(idx);
+
+        if (!discoverDetensorableComponent(
+                predOperand, typeConverter, workList, detensorableLinalgOps,
+                detensorableBranchOps, blockArgumentDetensoring))
+          return false;
+      }
+
+      return true;
+    }
+
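+    /// Returns true if `op` is a tensor.from_elements op, or is fed by one
+    /// through a chain of single-operand ops.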
+    bool comesFromElements(Operation *op) {
+      while (op && !isa<tensor::FromElementsOp>(op)) {
+        // Only single-operand ops can be safely looked through; this also
+        // guards the getOperand(0) below against zero-operand ops.
+        if (op->getNumOperands() != 1)
+          return false;
+
+        op = op->getOperand(0).getDefiningOp();
+      }
+
+      return op != nullptr;
+    }
+  };
+
+  /// Detensorize everything that can be detensored.
+  class AggressiveDetensoringModel : public CostModel {
+    void compute(
+        FuncOp func, DetensorizeTypeConverter typeConverter,
+        DenseSet<Operation *> &detensorableLinalgOps,
+        DenseMap<Operation *, DenseSet<int>> &detensorableBranchOps,
+        DenseMap<Block *, DenseSet<int>> &blockArgumentDetensoring) override {
+      func.walk([&](GenericOp genericOp) {
+        if (shouldBeDetensored(genericOp, typeConverter))
+          detensorableLinalgOps.insert(genericOp);
+      });
+
+      func.walk([&](BranchOpInterface brOp) {
+        DenseSet<int> brOpOperandDetensoring;
+
+        for (int p = 0, e = brOp->getBlock()->getNumSuccessors(); p < e; ++p) {
+          auto successorOperands = brOp.getSuccessorOperands(p);
+          Block *successor = brOp->getSuccessor(p);
+
+          if (!successorOperands.hasValue())
+            continue;
+
+          for (int idx = successorOperands->getBeginOperandIndex(),
+                   eidx = idx + successorOperands->size();
+               idx < eidx; ++idx) {
+            brOpOperandDetensoring.insert(idx);
+            blockArgumentDetensoring[successor].insert(
+                idx - successorOperands->getBeginOperandIndex());
+          }
+        }
+
+        detensorableBranchOps.try_emplace(brOp,
+                                          std::move(brOpOperandDetensoring));
+      });
+    }
+  };
+
   void runOnFunction() override {
-    auto *context = &getContext();
+    MLIRContext *context = &getContext();
     DetensorizeTypeConverter typeConverter;
     RewritePatternSet patterns(context);
     ConversionTarget target(*context);
+    DenseSet<Operation *> detensorableLinalgOps;
+    DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
+    DenseMap<Block *, DenseSet<int>> blockArgumentDetensoring;
 
-    target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
-      // If any of the operands or results cannot be detensored (i.e. they are
-      // all legal according the DetensorizeTypeConverter), the op is considered
-      // legal and won't be detensored.
-      return llvm::any_of(op.getShapedOperandTypes(),
-                          [&](ShapedType shapedType) {
-                            return typeConverter.isLegal(shapedType);
-                          });
-    });
+    std::unique_ptr<CostModel> costModel;
+
+    if (aggressiveMode.getValue())
+      costModel = std::make_unique<AggressiveDetensoringModel>();
+    else
+      costModel = std::make_unique<PureControlFlowDetectionModel>();
+
+    costModel->compute(getFunction(), typeConverter, detensorableLinalgOps,
+                       detensorableBranchOps, blockArgumentDetensoring);
+
+    target.addDynamicallyLegalOp<GenericOp>(
+        [&](GenericOp op) { return !detensorableLinalgOps.count(op); });
 
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
-      // A function is legal if all of its non-entry blocks are legal. We don't
-      // legalize the entry block (i.e. the function's signature) since
-      // detensoring can't happen along external calling convention boundaries,
-      // which we conservatively approximate as all function signatures.
+      // A function is legal if all of its non-entry blocks are legal. We
+      // don't legalize the entry block (i.e. the function's signature) since
+      // detensoring can't happen along external calling convention
+      // boundaries, which we conservatively approximate as all function
+      // signatures.
       return llvm::all_of(llvm::drop_begin(op.getBody(), 1), [&](Block &block) {
-        return typeConverter.isLegal(block.getArgumentTypes());
+        if (blockArgumentDetensoring.count(&block) &&
+            llvm::any_of(blockArgumentDetensoring[&block], [&](int idx) {
+              return !typeConverter.isLegal(block.getArgumentTypes()[idx]);
+            })) {
+          return false;
+        }
+        return true;
       });
     });
 
     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
-      return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
-             isLegalForBranchOpInterfaceTypeConversionPattern(op,
-                                                              typeConverter) ||
-             isLegalForReturnOpTypeConversionPattern(
-                 op, typeConverter, /*returnOpAlwaysLegal*/ true);
+      if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
+          isLegalForReturnOpTypeConversionPattern(op, typeConverter,
+                                                  /*returnOpAlwaysLegal*/ true))
+        return true;
+
+      if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
+        if (!detensorableBranchOps.count(branchOp))
+          return true;
+
+        for (auto operandIdx : detensorableBranchOps[branchOp])
+          if (!typeConverter.isLegal(
+                  branchOp->getOperand(operandIdx).getType()))
+            return false;
+
+        return true;
+      }
+
+      return false;
     });
 
-    patterns.add<DetensorizeGenericOp>(typeConverter, context);
-    patterns.add<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(),
-                                                  context, typeConverter);
-    // Since non-entry block arguments get detensorized, we also need to update
-    // the control flow inside the function to reflect the correct types.
-    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
+    patterns.add<DetensorizeGenericOp>(typeConverter, context);
+    patterns.add<FunctionNonEntryBlockConversion>(FuncOp::getOperationName(),
+                                                  context, typeConverter,
+                                                  blockArgumentDetensoring);
+    // Since non-entry block arguments get detensorized, we also need to
+    // update the control flow inside the function to reflect the correct
+    // types.
+    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
+                                                   &detensorableBranchOps);
 
     if (failed(applyFullConversion(getFunction(), target, std::move(patterns))))
       signalPassFailure();
@@ -210,6 +487,11 @@
                                             std::move(canonPatterns))))
       signalPassFailure();
   }
+
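+  // The aggressive mode can be requested on the command line, e.g.:
+  //   mlir-opt -linalg-detensorize=aggressive-mode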
+  Option<bool> aggressiveMode{
+      *this, "aggressive-mode",
+      llvm::cl::desc("Detensorize all ops that qualify for detensoring along "
+                     "with branch operands and basic-block arguments.")};
 };
 } // namespace
 
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
--- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
@@ -52,27 +52,44 @@
   using OpInterfaceConversionPattern<
       BranchOpInterface>::OpInterfaceConversionPattern;
 
+  BranchOpInterfaceTypeConversion(TypeConverter &typeConverter,
+                                  MLIRContext *ctx,
+                                  const DenseMap<Operation *, DenseSet<int>>
+                                      *branchOpsOperandConversionFilter)
+      : OpInterfaceConversionPattern(typeConverter, ctx, /*benefit=*/1),
+        branchOpsOperandConversionFilter(branchOpsOperandConversionFilter) {}
+
   LogicalResult
   matchAndRewrite(BranchOpInterface op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
+    DenseSet<int> opOperandFilter;
+
+    if (branchOpsOperandConversionFilter)
+      opOperandFilter = branchOpsOperandConversionFilter->lookup(op);
+
     // For a branch operation, only some operands go to the target blocks, so
     // only rewrite those.
     SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
     for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
          succIdx < succEnd; ++succIdx) {
       auto successorOperands = op.getSuccessorOperands(succIdx);
-      if (!successorOperands)
+      if (!successorOperands || successorOperands->empty())
         continue;
+
       for (int idx = successorOperands->getBeginOperandIndex(),
                eidx = idx + successorOperands->size();
            idx < eidx; ++idx) {
-        newOperands[idx] = operands[idx];
+        if (!branchOpsOperandConversionFilter || opOperandFilter.count(idx))
+          newOperands[idx] = operands[idx];
       }
     }
     rewriter.updateRootInPlace(
         op, [newOperands, op]() { op->setOperands(newOperands); });
     return success();
   }
+
+private:
+  const DenseMap<Operation *, DenseSet<int>> *branchOpsOperandConversionFilter;
 };
 } // end anonymous namespace
 
@@ -98,9 +115,11 @@
 } // end anonymous namespace
 
 void mlir::populateBranchOpInterfaceTypeConversionPattern(
-    RewritePatternSet &patterns, TypeConverter &typeConverter) {
-  patterns.add<BranchOpInterfaceTypeConversion>(typeConverter,
-                                                patterns.getContext());
+    RewritePatternSet &patterns, TypeConverter &typeConverter,
+    const DenseMap<Operation *, DenseSet<int>>
+        *branchOpsOperandConversionFilter) {
+  patterns.add<BranchOpInterfaceTypeConversion>(
+      typeConverter, patterns.getContext(), branchOpsOperandConversionFilter);
 }
 
 bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -495,8 +495,17 @@
     // to pack the new values. For 1->1 mappings, if there is no materialization
     // provided, use the argument directly instead.
     auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
-    Value newArg = converter.materializeArgumentConversion(
-        rewriter, origArg.getLoc(), origArg.getType(), replArgs);
+    Value newArg;
+
+    // If this is a 1->1 mapping and the replacement argument's type matches
+    // the original argument's type (i.e. the mapping is an identity), use the
+    // replacement argument directly.
+    if (replArgs.size() == 1 && replArgs[0].getType() == origArg.getType())
+      newArg = replArgs[0];
+    else
+      newArg = converter.materializeArgumentConversion(
+          rewriter, origArg.getLoc(), origArg.getType(), replArgs);
+
     if (!newArg) {
       assert(replArgs.size() == 1 &&
              "couldn't materialize the result of 1->N conversion");
@@ -754,8 +763,9 @@
                      TypeConverter::SignatureConversion *entryConversion);
 
   /// Convert the types of non-entry block arguments within the given region.
-  LogicalResult convertNonEntryRegionTypes(Region *region,
-                                           TypeConverter &converter);
+  LogicalResult convertNonEntryRegionTypes(
+      Region *region, TypeConverter &converter,
+      SmallVectorImpl<TypeConverter::SignatureConversion> *blockConversions);
 
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
@@ -1164,7 +1174,7 @@
   if (region->empty())
     return nullptr;
 
-  if (failed(convertNonEntryRegionTypes(region, converter)))
+  if (failed(convertNonEntryRegionTypes(region, converter, nullptr)))
     return failure();
 
   FailureOr<Block *> newEntry =
@@ -1173,14 +1183,18 @@
 }
 
 LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
-    Region *region, TypeConverter &converter) {
+    Region *region, TypeConverter &converter,
+    SmallVectorImpl<TypeConverter::SignatureConversion> *blockConversions) {
   argConverter.setConverter(region, &converter);
   if (region->empty())
     return success();
 
   // Convert the arguments of each block within the region.
+  int blockIdx = 0;
   for (Block &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1)))
-    if (failed(convertBlockSignature(&block, converter)))
+    if (failed(convertBlockSignature(
+            &block, converter,
+            blockConversions ? &(*blockConversions)[blockIdx++] : nullptr)))
       return failure();
   return success();
 }
@@ -1351,8 +1365,9 @@
 }
 
 LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
-    Region *region, TypeConverter &converter) {
-  return impl->convertNonEntryRegionTypes(region, converter);
+    Region *region, TypeConverter &converter,
+    SmallVectorImpl<TypeConverter::SignatureConversion> *blockConversions) {
+  return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
 }
 
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -linalg-detensorize=aggressive-mode | FileCheck %s -check-prefix=DET-ALL
+// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s -check-prefix=DET-CF
+
+#map0 = affine_map<() -> ()>
+
+#attrs = {
+  indexing_maps = [#map0, #map0, #map0],
+  iterator_types = []
+}
+
+func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attributes {} {
+  br ^bb1(%farg0 : tensor<i32>)
+
+^bb1(%0: tensor<i32>):  // 2 preds: ^bb0, ^bb2
+  %1 = linalg.init_tensor [] : tensor<i1>
+  %2 = linalg.generic #attrs
+    ins(%0, %farg1 : tensor<i32>, tensor<i32>)
+    outs(%1 : tensor<i1>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
+      %8 = cmpi slt, %arg0, %arg1 : i32
+      linalg.yield %8 : i1
+  } -> tensor<i1>
+  %3 = tensor.extract %2[] : tensor<i1>
+  cond_br %3, ^bb2(%0 : tensor<i32>), ^bb3(%0 : tensor<i32>)
+
+^bb2(%4: tensor<i32>):  // pred: ^bb1
+  %5 = linalg.init_tensor [] : tensor<i32>
+  %6 = linalg.generic #attrs
+    ins(%4, %4 : tensor<i32>, tensor<i32>)
+    outs(%5 : tensor<i32>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):  // no predecessors
+      %8 = addi %arg0, %arg1 : i32
+      linalg.yield %8 : i32
+  } -> tensor<i32>
+  br ^bb1(%6 : tensor<i32>)
+
+^bb3(%7: tensor<i32>):  // pred: ^bb1
+  return %7 : tensor<i32>
+}
+
+// Test aggressively detensoring all detensorable ops.
+//
+// DET-ALL-LABEL: func @main
+// DET-ALL-SAME:    (%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>)
+// DET-ALL:         tensor.extract {{.*}}
+// DET-ALL:         br ^[[bb1:.*]](%{{.*}} : i32)
+// DET-ALL:       ^[[bb1]](%{{.*}}: i32)
+// DET-ALL:         cmpi slt, {{.*}}
+// DET-ALL:         cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
+// DET-ALL:       ^[[bb2]](%{{.*}}: i32)
+// DET-ALL:         addi {{.*}}
+// DET-ALL:         br ^[[bb1]](%{{.*}} : i32)
+// DET-ALL:       ^[[bb3]](%{{.*}}: i32)
+// DET-ALL:         tensor.from_elements {{.*}}
+// DET-ALL:         linalg.tensor_reshape {{.*}}
+// DET-ALL:         return %{{.*}} : tensor<i32>
+
+// Test detensoring only ops involved in control-flow.
+//
+// DET-CF-LABEL: func @main
+// DET-CF-SAME:    (%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>)
+// DET-CF:         tensor.extract {{.*}}
+// DET-CF:         br ^[[bb1:.*]](%{{.*}} : i32)
+// DET-CF:       ^[[bb1]](%{{.*}}: i32)
+// DET-CF-DAG:     tensor.from_elements {{.*}}
+// DET-CF-DAG:     linalg.tensor_reshape {{.*}}
+// DET-CF-DAG:     cmpi slt, {{.*}}
+// DET-CF:         cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : tensor<i32>)
+// DET-CF:       ^[[bb2]](%{{.*}}: i32)
+// DET-CF:         addi {{.*}}
+// DET-CF:         br ^[[bb1]](%{{.*}} : i32)
+// DET-CF:       ^[[bb3]](%{{.*}}: tensor<i32>)
+// DET-CF:         return %{{.*}} : tensor<i32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir b/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_while_failure.mlir
@@ -0,0 +1,111 @@
+// RUN: mlir-opt %s -linalg-detensorize=aggressive-mode | FileCheck %s -check-prefix=DET-ALL
+// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s -check-prefix=DET-CF
+
+#map0 = affine_map<() -> ()>
+#map1 = affine_map<(i) -> ()>
+#map2 = affine_map<(i) -> (i)>
+
+#attrs = {
+  indexing_maps = [#map0, #map0, #map0],
+  iterator_types = []
+}
+
+#sum_reduction_attrs = {
+  indexing_maps = [#map2, #map1],
+  iterator_types = ["reduction"]
+}
+
+
+#broadcast_attrs = {
+  indexing_maps = [#map1, #map2],
+  iterator_types = ["parallel"]
+}
+
+func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attributes {} {
+  br ^bb1(%farg0 : tensor<10xi32>)
+
+^bb1(%0: tensor<10xi32>):  // 2 preds: ^bb0, ^bb2
+  %1 = linalg.init_tensor [] : tensor<i32>
+  %2 = linalg.generic #sum_reduction_attrs
+    ins(%0: tensor<10xi32>)
+    outs(%1: tensor<i32>) {
+      ^bb(%a: i32, %x: i32):
+        %b = addi %x, %a : i32
+        linalg.yield %b : i32
+  } -> tensor<i32>
+
+  %3 = linalg.init_tensor [] : tensor<i1>
+  %4 = linalg.generic #attrs
+    ins(%2, %farg1 : tensor<i32>, tensor<i32>)
+    outs(%3 : tensor<i1>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
+      %8 = cmpi slt, %arg0, %arg1 : i32
+      linalg.yield %8 : i1
+  } -> tensor<i1>
+  %5 = tensor.extract %4[] : tensor<i1>
+  cond_br %5, ^bb2(%2 : tensor<i32>), ^bb3(%2 : tensor<i32>)
+
+^bb2(%6: tensor<i32>):  // pred: ^bb1
+  %7 = linalg.init_tensor [10] : tensor<10xi32>
+  %9 = linalg.generic #broadcast_attrs
+       ins(%6: tensor<i32>)
+      outs(%7: tensor<10xi32>) {
+    ^bb(%a: i32, %b: i32) :
+      linalg.yield %a : i32
+  } -> tensor<10xi32>
+
+  br ^bb1(%9 : tensor<10xi32>)
+
+^bb3(%10: tensor<i32>):  // pred: ^bb1
+  return %10 : tensor<i32>
+}
+
+// Test aggressively detensoring all detensorable ops.
+//
+// DET-ALL-LABEL: func @main
+// DET-ALL-SAME:    (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor<i32>)
+// DET-ALL:         br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>)
+// DET-ALL:       ^[[bb1]](%{{.*}}: tensor<10xi32>)
+// DET-ALL:         linalg.init_tensor [] : tensor<i32>
+// DET-ALL:         linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
+// DET-ALL:         ^bb0(%{{.*}}: i32, %{{.*}}: i32):  // no predecessors
+// DET-ALL:           %{{.*}} = addi %{{.*}}, %{{.*}}
+// DET-ALL:           linalg.yield %{{.*}} : i32
+// DET-ALL:         } -> tensor<i32>
+// DET-ALL:         tensor.extract %{{.*}}[] : tensor<i32>
+// DET-ALL:         tensor.extract %{{.*}}[] : tensor<i32>
+// DET-ALL:         cmpi slt, %{{.*}}, %{{.*}} : i32
+// DET-ALL:         tensor.extract %{{.*}}[] : tensor<i32>
+// DET-ALL:         tensor.extract %{{.*}}[] : tensor<i32>
+// DET-ALL:         cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
+// DET-ALL:       ^[[bb2]](%{{.*}}: i32)
+// DET-ALL:         tensor.from_elements %{{.*}} : tensor<1xi32>
+// DET-ALL:         linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// DET-ALL:         linalg.init_tensor [10] : tensor<10xi32>
+// DET-ALL:         linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
+// DET-ALL:         ^bb0(%{{.*}}: i32, %{{.*}}: i32):
+// DET-ALL:           linalg.yield %{{.*}} : i32
+// DET-ALL:         } -> tensor<10xi32>
+// DET-ALL:         br ^[[bb1]](%{{.*}} : tensor<10xi32>)
+// DET-ALL:       ^[[bb3]](%{{.*}}: i32)
+// DET-ALL:         tensor.from_elements %{{.*}} : tensor<1xi32>
+// DET-ALL:         linalg.tensor_reshape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// DET-ALL:         return %{{.*}} : tensor<i32>
+// DET-ALL:       }
+
+// Try to detensor pure control-flow. However, that fails since the potentially
+// detensorable component contains some ops that cannot be detensored.
+//
+// DET-CF-LABEL: func @main
+// DET-CF-SAME:    (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor<i32>)
+// DET-CF:         br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>)
+// DET-CF:       ^bb1(%{{.*}}: tensor<10xi32>)
+// DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
+// DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<i32>, tensor<i32>) outs(%{{.*}} : tensor<i1>) {
+// DET-CF:         cond_br %{{.*}}, ^bb2(%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
+// DET-CF:       ^bb2(%{{.*}}: tensor<i32>)
+// DET-CF:         %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
+// DET-CF:         br ^bb1(%{{.*}} : tensor<10xi32>)
+// DET-CF:       ^bb3(%{{.*}}: tensor<i32>)
+// DET-CF:         return %{{.*}} : tensor<i32>
+// DET-CF:       }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+
+#map0 = affine_map<() -> ()>
+
+#attrs = {
+  indexing_maps = [#map0, #map0, #map0],
+  iterator_types = []
+}
+
+func @main() -> () attributes {} {
+  %c0 = constant 0 : i32
+  %0 = tensor.from_elements %c0 : tensor<1xi32>
+  %reshaped0 = linalg.tensor_reshape %0 [] : tensor<1xi32> into tensor<i32>
+  %c10 = constant 10 : i32
+  %1 = tensor.from_elements %c10 : tensor<1xi32>
+  %reshaped1 = linalg.tensor_reshape %1 [] : tensor<1xi32> into tensor<i32>
+  br ^bb1(%reshaped0 : tensor<i32>)
+
+^bb1(%2: tensor<i32>):  // 2 preds: ^bb0, ^bb2
+  %3 = linalg.init_tensor [] : tensor<i1>
+  %4 = linalg.generic #attrs
+    ins(%2, %reshaped1 : tensor<i32>, tensor<i32>)
+    outs(%3 : tensor<i1>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
+      %8 = cmpi slt, %arg0, %arg1 : i32
+      linalg.yield %8 : i1
+  } -> tensor<i1>
+  %5 = tensor.extract %4[] : tensor<i1>
+  cond_br %5, ^bb2(%2 : tensor<i32>), ^bb3
+
+^bb2(%6: tensor<i32>):  // pred: ^bb1
+  %7 = linalg.init_tensor [] : tensor<i32>
+  %8 = linalg.generic #attrs
+    ins(%6, %6 : tensor<i32>, tensor<i32>)
+    outs(%7 : tensor<i32>) {
+    ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):  // no predecessors
+      %9 = addi %arg0, %arg1 : i32
+      linalg.yield %9 : i32
+  } -> tensor<i32>
+  br ^bb1(%8 : tensor<i32>)
+
+^bb3:  // pred: ^bb1
+  return
+}
+
+// CHECK-LABEL: func @main
+// CHECK:         %[[c0:.*]] = constant 0 : i32
+// CHECK:         %[[c10:.*]] = constant 10 : i32
+// CHECK:         br ^[[bb1:.*]](%[[c0]] : i32)
+// CHECK:       ^[[bb1]](%{{.*}}: i32)
+// CHECK:         cmpi slt, %{{.*}}, %[[c10]]
+// CHECK:         cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]]
+// CHECK:       ^[[bb2]](%{{.*}}: i32)
+// CHECK:         addi %{{.*}}, %{{.*}}
+// CHECK:         br ^[[bb1]](%{{.*}} : i32)
+// CHECK:       ^[[bb3]]:
+// CHECK:         return
diff --git a/mlir/test/Dialect/Linalg/detensorized_0d.mlir b/mlir/test/Dialect/Linalg/detensorized_0d.mlir
--- a/mlir/test/Dialect/Linalg/detensorized_0d.mlir
+++ b/mlir/test/Dialect/Linalg/detensorized_0d.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-detensorize=aggressive-mode | FileCheck %s
 
 #map = affine_map<() -> ()>
 
diff --git a/mlir/test/Dialect/Linalg/detensorized_while.mlir b/mlir/test/Dialect/Linalg/detensorized_while.mlir
deleted file mode 100644
--- a/mlir/test/Dialect/Linalg/detensorized_while.mlir
+++ /dev/null
@@ -1,53 +0,0 @@
-// RUN: mlir-opt %s -linalg-detensorize | FileCheck %s
-
-#map0 = affine_map<() -> ()>
-
-#attrs = {
-  indexing_maps = [#map0, #map0, #map0],
-  iterator_types = []
-}
-
-func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attributes {} {
-  br ^bb1(%farg0 : tensor<i32>)
-
-^bb1(%0: tensor<i32>):  // 2 preds: ^bb0, ^bb2
-  %1 = linalg.init_tensor [] : tensor<i1>
-  %2 = linalg.generic #attrs
-    ins(%0, %farg1 : tensor<i32>, tensor<i32>)
-    outs(%1 : tensor<i1>) {
-    ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
-      %8 = cmpi slt, %arg0, %arg1 : i32
-      linalg.yield %8 : i1
-  } -> tensor<i1>
-  %3 = tensor.extract %2[] : tensor<i1>
-  cond_br %3, ^bb2(%0 : tensor<i32>), ^bb3(%0 : tensor<i32>)
-
-^bb2(%4: tensor<i32>):  // pred: ^bb1
-  %5 = linalg.init_tensor [] : tensor<i32>
-  %6 = linalg.generic #attrs
-    ins(%4, %4 : tensor<i32>, tensor<i32>)
-    outs(%5 : tensor<i32>) {
-    ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):  // no predecessors
-      %8 = addi %arg0, %arg1 : i32
-      linalg.yield %8 : i32
-  } -> tensor<i32>
-  br ^bb1(%6 : tensor<i32>)
-
-^bb3(%7: tensor<i32>):  // pred: ^bb1
-  return %7 : tensor<i32>
-}
-
-// CHECK-LABEL: func @main
-// CHECK-SAME:    (%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>)
-// CHECK:         tensor.extract {{.*}}
-// CHECK:         br ^[[bb1:.*]](%{{.*}} : i32)
-// CHECK:       ^[[bb1]](%{{.*}}: i32)
-// CHECK:         cmpi slt, {{.*}}
-// CHECK:         cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK:       ^[[bb2]](%{{.*}}: i32)
-// CHECK:         addi {{.*}}
-// CHECK:         br ^[[bb1]](%{{.*}} : i32)
-// CHECK:       ^[[bb3]](%{{.*}}: i32)
-// CHECK:         tensor.from_elements {{.*}}
-// CHECK:         linalg.tensor_reshape {{.*}}
-// CHECK:         return %{{.*}} : tensor<i32>