Index: mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
+++ mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
@@ -68,6 +68,17 @@
                                 PatternRewriter &rewrite) const override;
 };
 
+/// Transforms type-inconsistent stores, i.e. stores where the type hint of
+/// the address contradicts the type of the stored value, by inserting a
+/// bitcast if possible.
+class BitcastStores : public OpRewritePattern<StoreOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(StoreOp store,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // namespace LLVM
 } // namespace mlir
 
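Since the new pattern is exposed through this public header, it can also be picked up outside of the type consistency pass. Below is a minimal sketch of such standalone use, assuming the standard greedy pattern driver; the helper name is illustrative and not part of this patch:

#include "mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Illustrative helper: run only the bitcast-insertion pattern on the ops
// nested under `op`.
static LogicalResult insertStoreBitcasts(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  patterns.add<LLVM::BitcastStores>(op->getContext());
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}
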
Index: mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
===================================================================
--- mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -43,7 +43,7 @@
 }
 
 /// Checks that two types are the same or can be bitcast into one another.
-static bool areCastCompatible(DataLayout &layout, Type lhs, Type rhs) {
+static bool areBitcastCompatible(DataLayout &layout, Type lhs, Type rhs) {
   return lhs == rhs || (!isa<LLVMStructType, LLVMArrayType>(lhs) &&
                         !isa<LLVMStructType, LLVMArrayType>(rhs) &&
                         layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
@@ -104,7 +104,7 @@
   if (!firstType)
     return failure();
   DataLayout layout = DataLayout::closest(load);
-  if (!areCastCompatible(layout, firstType, load.getResult().getType()))
+  if (!areBitcastCompatible(layout, firstType, load.getResult().getType()))
     return failure();
 
   insertFieldIndirection<LoadOp>(load, rewriter, inconsistentElementType);
@@ -144,20 +144,13 @@
   DataLayout layout = DataLayout::closest(store);
   // Check that the first field has the right type or can at least be bitcast
   // to the right type.
-  if (!areCastCompatible(layout, firstType, store.getValue().getType()))
+  if (!areBitcastCompatible(layout, firstType, store.getValue().getType()))
     return failure();
 
   insertFieldIndirection<StoreOp>(store, rewriter, inconsistentElementType);
 
-  Value replaceValue = store.getValue();
-  if (firstType != store.getValue().getType()) {
-    rewriter.setInsertionPointAfterValue(store.getValue());
-    replaceValue = rewriter.create<BitcastOp>(store->getLoc(), firstType,
-                                              store.getValue());
-  }
-
   rewriter.updateRootInPlace(
-      store, [&]() { store.getValueMutable().assign(replaceValue); });
+      store, [&]() { store.getValueMutable().assign(store.getValue()); });
 
   return success();
 }
@@ -458,12 +451,6 @@
 
     IntegerType fieldIntType = rewriter.getIntegerType(fieldSize * 8);
     Value valueToStore = rewriter.create<TruncOp>(loc, fieldIntType, shrOp);
-    if (fieldIntType != type) {
-      // Bitcast to the right type. `fieldIntType` was explicitly created
-      // to be of the same size as `type` and must currently be a primitive as
-      // well.
-      valueToStore = rewriter.create<BitcastOp>(loc, type, valueToStore);
-    }
 
     // We create an `i8` indexed GEP here as that is the easiest (offset is
     // already known). Other patterns turn this into a type-consistent GEP.
@@ -558,6 +545,26 @@
   return success();
 }
 
+LogicalResult BitcastStores::matchAndRewrite(StoreOp store,
+                                             PatternRewriter &rewriter) const {
+  Type sourceType = store.getValue().getType();
+  Type typeHint = isElementTypeInconsistent(store.getAddr(), sourceType);
+  if (!typeHint) {
+    // Nothing to do, since the store is already type-consistent.
+    return failure();
+  }
+
+  auto dataLayout = DataLayout::closest(store);
+  if (!areBitcastCompatible(dataLayout, typeHint, sourceType))
+    return failure();
+
+  auto bitcastOp =
+      rewriter.create<BitcastOp>(store.getLoc(), typeHint, store.getValue());
+  rewriter.updateRootInPlace(
+      store, [&] { store.getValueMutable().assign(bitcastOp); });
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Type consistency pass
 //===----------------------------------------------------------------------===//
@@ -572,6 +579,7 @@
         &getContext());
     rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
     rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
+    rewritePatterns.add<BitcastStores>(&getContext());
     FrozenRewritePatternSet frozen(std::move(rewritePatterns));
 
     if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen)))
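A note for readers tracing the moved CHECK lines in the tests below: BitcastStores creates the cast at the rewriter's current insertion point, and the greedy driver places that point at the matched store, so the bitcast now materializes directly in front of the store rather than next to the value it casts. Using the names from BitcastStores::matchAndRewrite above, the rewrite behaves under the greedy driver like this explicit spelling (illustrative only, not part of the patch):

  rewriter.setInsertionPoint(store);
  Value cast =
      rewriter.create<BitcastOp>(store.getLoc(), typeHint, store.getValue());
  rewriter.updateRootInPlace(
      store, [&] { store.getValueMutable().assign(cast); });
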
Index: mlir/test/Dialect/LLVMIR/type-consistency.mlir
===================================================================
--- mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -218,8 +218,8 @@
   // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
   // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
   // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
-  // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32
   // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr)  -> !llvm.ptr, !llvm.struct<"foo", (f32, f32)>
+  // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32
   // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
   llvm.store %arg, %1 : i64, !llvm.ptr
   // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
@@ -409,12 +409,27 @@
 // CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
 llvm.func @type_consistent_vector_store_other_type(%arg: vector<4xi32>) {
   %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : vector<4xi32> to vector<4xf32>
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (vector<4xf32>)>
   %1 = llvm.alloca %0 x !llvm.struct<"foo", (vector<4xf32>)> : (i32) -> !llvm.ptr
   // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (vector<4xf32>)>
+  // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : vector<4xi32> to vector<4xf32>
   // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
   llvm.store %arg, %1 : vector<4xi32>, !llvm.ptr
   // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @bitcast_insertion
+// CHECK-SAME: %[[ARG:.*]]: i32
+llvm.func @bitcast_insertion(%arg: i32) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x f32
+  %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+  // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[ARG]] : i32 to f32
+  // CHECK: llvm.store %[[BIT_CAST]], %[[ALLOCA]]
+  llvm.store %arg, %1 : i32, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}