Index: mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -192,62 +192,15 @@
 /// assumes that `reduceOp` has two operands and one of them is the reduction
 /// initial value.
 static Value buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
-                                 Value outputArg,
-                                 const SmallVector<bool> &reductionMask,
-                                 const BlockAndValueMapping &bvm) {
+                                 Value valueToReduce,
+                                 const SmallVector<bool> &reductionMask) {
   auto maybeKind = getKindForOp(reduceOp);
   assert(maybeKind && "Failed precondition: could not get reduction kind");
-  Value operandToReduce = reduceOp->getOperand(0) == outputArg
-                              ? reduceOp->getOperand(1)
-                              : reduceOp->getOperand(0);
-  Value vec = bvm.lookup(operandToReduce);
-  return b.create<vector::MultiDimReductionOp>(reduceOp->getLoc(), vec,
-                                               reductionMask, *maybeKind);
+  return b.create<vector::MultiDimReductionOp>(
+      reduceOp->getLoc(), valueToReduce, reductionMask, *maybeKind);
 }
 
-/// Read the initial value associated to the given `outputOperand`.
-static Value readInitialValue(OpBuilder &b, LinalgOp linalgOp,
-                              OpOperand *outputOperand) {
-  AffineMap map = inversePermutation(
-      reindexIndexingMap(linalgOp.getTiedIndexingMap(outputOperand)));
-  Type readType;
-  if (linalgOp.getShape(outputOperand).empty()) {
-    readType = getElementTypeOrSelf(outputOperand->get());
-  } else {
-    readType = VectorType::get(map.compose(linalgOp.getShape(outputOperand)),
-                               getElementTypeOrSelf(outputOperand->get()));
-  }
-  Value vectorRead = buildVectorRead(b, outputOperand->get(), readType, map);
-  return vectorRead;
-}
-
-/// Assuming `outputOperand` is an output operand of a LinalgOp, determine
-/// whether a reduction is needed to produce a `targetType` and create that
-/// reduction if it is the case.
-static Value reduceIfNeeded(OpBuilder &b, Type targetType, Value value,
-                            OpOperand *outputOperand,
-                            const BlockAndValueMapping &bvm) {
-  LDBG("Reduce " << value << " to type " << targetType);
-  LDBG("In LinalgOp operand #" << outputOperand->getOperandNumber() << "\n"
-                               << *(outputOperand->getOwner()));
-  auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
-  auto vecType = value.getType().dyn_cast<VectorType>();
-  VectorType targetVectorType = targetType.dyn_cast<VectorType>();
-  if (!vecType)
-    return value;
-  if (targetVectorType && vecType.getShape() == targetVectorType.getShape())
-    return value;
-
-  // At this point, we know we need to reduce. Detect the reduction operator.
-  unsigned pos = 0;
-  MLIRContext *ctx = b.getContext();
-  SmallVector<AffineExpr> exprs;
-  for (auto s : linalgOp.iterator_types())
-    if (isParallelIterator(s))
-      exprs.push_back(getAffineDimExpr(pos++, ctx));
-
-  Operation *reduceOp = matchLinalgReduction(outputOperand);
-  assert(reduceOp && "Failed precondition: could not math a reduction");
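+/// Return a boolean mask, with one entry per iterator of `linalgOp`, that is
+/// true for reduction iterators and false otherwise.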
+static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
   unsigned idx = 0;
   SmallVector<bool> reductionMask(linalgOp.iterator_types().size(), false);
   for (auto attr : linalgOp.iterator_types()) {
@@ -255,24 +208,7 @@
       reductionMask[idx] = true;
     ++idx;
   }
-  assert(reduceOp->getNumOperands() == 2 &&
-         "Only support binary reduce op right now");
-  unsigned outputPos =
-      outputOperand->getOperandNumber() - linalgOp.getNumInputs();
-  Value outputArg = linalgOp.getRegionOutputArgs()[outputPos];
-  // Reduce across the iteration space.
-  Value reduce =
-      buildMultiDimReduce(b, reduceOp, outputArg, reductionMask, bvm);
-
-  // Read the original output value.
-  Value initialValue = readInitialValue(b, linalgOp, outputOperand);
-
-  // Combine the output argument with the reduced value.
-  OperationState state(reduceOp->getLoc(), reduceOp->getName());
-  state.addAttributes(reduceOp->getAttrs());
-  state.addOperands({reduce, initialValue});
-  state.addTypes(initialValue.getType());
-  return b.createOperation(state)->getResult(0);
+  return reductionMask;
 }
 
 /// Build a vector.transfer_write of `value` into `outputOperand` at indices set
@@ -280,8 +216,7 @@
 /// currently being vectorized. If `dest` has null rank, build a memref.store.
 /// Return the produced value or null if no value is produced.
 static Value buildVectorWrite(OpBuilder &b, Value value,
-                              OpOperand *outputOperand,
-                              const BlockAndValueMapping &bvm) {
+                              OpOperand *outputOperand) {
   Operation *write;
   Location loc = value.getLoc();
   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
@@ -296,12 +231,9 @@
     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
                                b.create<arith::ConstantIndexOp>(loc, 0));
     value = broadcastIfNeeded(b, value, vectorType.getShape());
-    value = reduceIfNeeded(b, vectorType, value, outputOperand, bvm);
     write = b.create<vector::TransferWriteOp>(loc, value, outputOperand->get(),
                                               indices, map);
   } else {
-    value = reduceIfNeeded(b, getElementTypeOrSelf(value), value, outputOperand,
-                           bvm);
     write = vector::TransferWriteOp::createScalarOp(
         b, loc, value, outputOperand->get(), ValueRange{});
   }
@@ -336,7 +268,7 @@
     // TODO: use a map.
     Value vectorValue = bvm.lookup(outputs.value());
     Value newResult = buildVectorWrite(
-        b, vectorValue, linalgOp.getOutputOperand(outputs.index()), bvm);
+        b, vectorValue, linalgOp.getOutputOperand(outputs.index()));
     if (newResult)
       newResults.push_back(newResult);
   }
@@ -379,6 +311,17 @@
   return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
 }
 
+/// Create a new vectorized version of `op` with the given operands and types.
+static Operation *createVectorizedOp(OpBuilder &b, Operation *op,
+                                     ValueRange newOperands,
+                                     ArrayRef<Type> types) {
+  OperationState state(op->getLoc(), op->getName());
+  state.addAttributes(op->getAttrs());
+  state.addOperands(newOperands);
+  state.addTypes(types);
+  return b.createOperation(state);
+}
+
 /// Generic vectorization for a single operation `op`, given already vectorized
 /// operands carried by `bvm`. Vectorization occurs as follows:
 ///   1. Try to apply any of the `customVectorizationHooks` and return its
@@ -399,7 +342,8 @@
 /// This function does not update `bvm` but returns a VectorizationStatus that
 /// instructs the caller what `bvm` update needs to occur.
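+/// Operations that combine a reduced value with an output block argument are
+/// rewritten as a vector.multi_reduction followed by the combining operation.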
 static VectorizationResult
-vectorizeOneOp(OpBuilder &b, Operation *op, const BlockAndValueMapping &bvm,
+vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op,
+               const BlockAndValueMapping &bvm,
                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
   LDBG("vectorize op " << *op);
 
@@ -422,7 +366,36 @@
   if (!OpTrait::hasElementwiseMappableTraits(op))
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
 
-  // 4. Generic vectorization path for ElementwiseMappable ops.
+  // 4. Check if the operation is a reduction.
+  for (Value operand : op->getOperands()) {
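+    // Only block arguments tied to output operands can carry a reduction.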
+    auto arg = operand.dyn_cast<BlockArgument>();
+    if (!arg || arg.getArgNumber() < linalgOp.getNumInputs())
+      continue;
+    SmallVector<Operation *> reductionOps;
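+    // `matchReduction` returns the value being accumulated into the output
+    // argument and fills `reductionOps` with the combining operation(s).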
+    Value reduceValue = matchReduction(
+        linalgOp.getRegionOutputArgs(),
+        arg.getArgNumber() - linalgOp.getNumInputs(), reductionOps);
+    if (!reduceValue)
+      continue;
+    Value reduceVec = bvm.lookup(reduceValue);
+    Value outputVec = bvm.lookup(operand);
+    auto reduceType = reduceVec.getType().dyn_cast<VectorType>();
+    auto outputType = outputVec.getType().dyn_cast<VectorType>();
+    // Reduce only if needed, as the value may already have been reduced for
+    // contraction vectorization.
+    if (!reduceType ||
+        (outputType && reduceType.getShape() == outputType.getShape()))
+      continue;
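+    // Reduce the vectorized value across all the reduction dimensions.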
+    SmallVector<bool> reductionMask = getReductionMask(linalgOp);
+    Value reduce =
+        buildMultiDimReduce(b, reductionOps[0], reduceVec, reductionMask);
+    // Combine the output argument with the reduced value.
+    return VectorizationResult{
+        VectorizationStatus::NewOp,
+        createVectorizedOp(b, op, {reduce, outputVec}, reduce.getType())};
+  }
+
+  // 5. Generic vectorization path for ElementwiseMappable ops.
   //   a. first get the first max ranked shape.
   SmallVector<int64_t, 4> firstMaxRankedShape;
   for (Value operand : op->getOperands()) {
@@ -444,12 +417,10 @@
   });
 
   // Build and return the new op.
-  OperationState state(op->getLoc(), op->getName());
-  state.addAttributes(op->getAttrs());
-  state.addOperands(llvm::to_vector<4>(vectorizedOperands));
-  state.addTypes(llvm::to_vector<4>(returnTypes));
-  return VectorizationResult{VectorizationStatus::NewOp,
-                             b.createOperation(state)};
+  return VectorizationResult{
+      VectorizationStatus::NewOp,
+      createVectorizedOp(b, op, llvm::to_vector<4>(vectorizedOperands),
+                         llvm::to_vector<4>(returnTypes))};
 }
 
 /// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
@@ -544,7 +515,8 @@
     if (linalgOp.getShape(opOperand).empty()) {
       readType = bbarg.getType();
     } else {
-      if (broadcastToMaximalCommonShape) {
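+      // Only broadcast input operands to the maximal common vector shape;
+      // output operands keep their original shape.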
+      if (broadcastToMaximalCommonShape &&
+          opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
         map = inverseAndBroadcastProjectedPermuation(
             linalgOp.getTiedIndexingMap(opOperand));
         readType = VectorType::get(commonVectorShape,
@@ -581,7 +553,7 @@
 
   // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
   for (Operation &op : block.getOperations()) {
-    VectorizationResult result = vectorizeOneOp(b, &op, bvm, hooks);
+    VectorizationResult result = vectorizeOneOp(b, linalgOp, &op, bvm, hooks);
     if (result.status == VectorizationStatus::Failure) {
       LDBG("failed to vectorize: " << op);
       return failure();
Index: mlir/test/Dialect/Linalg/vectorization.mlir
===================================================================
--- mlir/test/Dialect/Linalg/vectorization.mlir
+++ mlir/test/Dialect/Linalg/vectorization.mlir
@@ -749,9 +749,9 @@
   -> tensor<4x16xf32>
 {
   // CHECK: vector.transfer_read {{.*}} : tensor<4x16x8xf32>, vector<4x16x8xf32>
+  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
   // CHECK: math.exp {{.*}} : vector<4x16x8xf32>
   // CHECK: vector.multi_reduction #vector.kind<add>, %{{.*}} [2] : vector<4x16x8xf32> to vector<4x16xf32>
-  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true]} : tensor<4x16xf32>, vector<4x16xf32>
   // CHECK: addf {{.*}} : vector<4x16xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4x16xf32>, tensor<4x16xf32>
   // CHECK: return {{.*}} : tensor<4x16xf32>
@@ -782,11 +782,11 @@
 {
   // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M1]]} : tensor<3x2xf32>, vector<2x3x4x5xf32>
   // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true, true, true], permutation_map = #[[$M2]]} : tensor<5x4xf32>, vector<2x3x4x5xf32>
+  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : tensor<5x2xf32>, vector<2x5xf32>
   // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
   // CHECK: math.exp {{.*}} : vector<2x3x4x5xf32>
   // CHECK: addf {{.*}} : vector<2x3x4x5xf32>
   // CHECK: vector.multi_reduction #vector.kind<add>, {{.*}}  [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
-  // CHECK: vector.transfer_read {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : tensor<5x2xf32>, vector<2x5xf32>
   // CHECK: addf {{.*}} : vector<2x5xf32>
   // CHECK: vector.transfer_write {{.*}} {in_bounds = [true, true], permutation_map = #[[$M3]]} : vector<2x5xf32>, tensor<5x2xf32>
   // CHECK: return {{.*}} : tensor<5x2xf32>