diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -1087,18 +1087,20 @@
     >,
     InterfaceMethod<
       /*desc=*/[{
-        Return the position in the results of the affine map computed
-        by getLoopsToShapesMap() that represents the shape of the
-        result value at a dimension.
+        Return the range of positions in the results of the affine map
+        computed by getLoopsToShapesMap() that correspond to the
+        AffineExprs used to access the outputs of the operation.
       }],
-      /*retTy=*/"Optional<unsigned>",
-      /*methodName=*/"getResultValueDimPositionInLoopsToShapeMap",
-      /*args=*/(ins "unsigned":$resultIdx, "unsigned":$dim),
+      /*retTy=*/"std::pair<unsigned, unsigned>",
+      /*methodName=*/"getResultsPositionInLoopsToShapeMap",
+      /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        if (resultIdx >= getNumOutputs()) return {};
-        return getOperandDimPositionInLoopsToShapeMap(
-            getNumInputs() + resultIdx, dim);
+        return
+          {*getOperandDimPositionInLoopsToShapeMap(getNumInputs(), 0),
+           *getOperandDimPositionInLoopsToShapeMap(
+               getNumInputs() + getNumOutputs() - 1,
+               getOutputShapedType(getNumOutputs() - 1).getRank() - 1) + 1};
       }]
     >,
     InterfaceMethod<
@@ -1226,8 +1228,8 @@
 
-    /// Returns the value that expresses the shape of the output in terms of
-    /// shape of the input operands where possible
+    /// Reifies the values that express the shape of each result in terms of
+    /// the shapes of the input operands, where possible.
-    Optional<Value> inferResultDimFromInputShapes
-      (OpBuilder &b, Location loc, unsigned resultIdx, unsigned im);
+    LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,
+        SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes);
 
     //========================================================================//
     // Helper functions to mutate the `operand_segment_sizes` attribute.
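
Illustrative only (not part of this patch): for a linalg.matmul-like op C(i, j) += A(i, k) * B(k, j), the indexing maps are A: (d0, d1, d2) -> (d0, d2), B: (d0, d1, d2) -> (d2, d1) and C: (d0, d1, d2) -> (d0, d1), so getLoopsToShapesMap() is (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1) and the new method reports that the output C occupies the half-open result range [4, 6). A hedged C++ sketch (the helper name is hypothetical):

  static std::pair<unsigned, unsigned>
  getOutputRange(linalg::LinalgOp matmulLikeOp) {
    // Returns {4, 6} for the matmul-like op described above.
    return matmulLikeOp.getResultsPositionInLoopsToShapeMap();
  }
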
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -22,6 +22,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "mlir/Support/LLVM.h"
@@ -107,13 +108,6 @@
 void getDimsOfType(Operation *op, StringRef iteratorTypeName,
                    SmallVectorImpl<AffineExpr> &res);
 
-/// For reshape operation, compute the shape of the output based on the result
-/// type and shape of the input.
-SmallVector<Value, 4>
-getReshapeOutputShapeFromInputShape(OpBuilder &b, Location loc, Value src,
-                                    ArrayRef<int64_t> dstStaticShape,
-                                    ArrayRef<AffineMap> reassociation);
-
 namespace detail {
 LogicalResult verifyStructuredOpInterface(Operation *op);
 } // namespace detail
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -15,6 +15,7 @@
 
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -33,7 +34,10 @@
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
-def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> {
+def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
+    [NoSideEffect,
+     DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+         ["reifyReturnTypeShapesPerResultDim"]>]> {
   let summary = "operation to define a tensor of particular value";
 
   let description = [{
@@ -126,7 +130,10 @@
 }
 
 def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
-    [AttrSizedOperandSegments, NoSideEffect]> {
+    [AttrSizedOperandSegments,
+     DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+         ["reifyReturnTypeShapesPerResultDim"]>,
+     NoSideEffect]> {
   let summary = "tensor pad operation";
   let description = [{
     `linalg.pad_tensor` is an operation that pads the `source` tensor
@@ -348,11 +355,6 @@
               a.cast<AffineMapAttr>().getValue().getResults());
           }));
     }
-    SmallVector<Value, 4> getOutputShape(OpBuilder &b, Location loc) {
-      return getReshapeOutputShapeFromInputShape(
-          b, loc, src(), getResultType().getShape(),
-          getReassociationMaps());
-    }
   }];
   let assemblyFormat = [{
     $src $reassociation attr-dict `:` type($src) `into` type(results)
@@ -417,7 +419,10 @@
   let hasCanonicalizer = 1;
 }
 
-def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
+def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<
+    "tensor_reshape",
+    [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+      ["reifyReturnTypeShapesPerResultDim"]>]>,
     Arguments<(ins AnyTensor:$src,
                    AffineMapArrayAttr:$reassociation)>,
     Results<(outs AnyTensor:$result)> {
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -17,6 +17,7 @@
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
 include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/Interfaces/CopyOpInterface.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 // Base Tablegen class for Linalg ops.
@@ -25,7 +26,7 @@
 // depending on the specific Linalg op.
 class LinalgStructuredBase_Op<string mnemonic, list<OpTrait> props>
   : Op<Linalg_Dialect, mnemonic, !listconcat(props, [
-       LinalgStructuredInterface])> {
+       LinalgStructuredInterface, InferShapedTypeOpInterface])> {
   code structuredOpsBaseDecls = [{
     // Return the number of induction variables in the basic block. This should
     // always be 0 for index-free linalg ops. For IndexedGeneric, this must be
@@ -33,6 +34,12 @@
     unsigned getNumPayloadInductionVariables() {
       return isa<IndexedGenericOp>(this->getOperation()) ? getNumLoops() : 0;
     }
+
+    LogicalResult reifyReturnTypeShapesPerResultDim(OpBuilder &b,
+        SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+      return cast<LinalgOp>(getOperation()).reifyReturnTypeShapesPerResultDim(b,
+          reifiedReturnShapes);
+    }
   }];
 }
 
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -97,21 +97,53 @@
                     "::mlir::DictionaryAttr":$attributes,
                     "::mlir::RegionRange":$regions,
                     "::mlir::SmallVectorImpl<::mlir::ShapedTypeComponents>&":
-                      $inferredReturnShapes)
+                      $inferredReturnShapes),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{ return ::mlir::failure(); }]
     >,
     InterfaceMethod<
       /*desc=*/[{Reify the shape computation for the operation.
 
-      Insert operations using the given OpBuilder that computes the result
-      shape.
+      Insert operations using the given OpBuilder that compute the
+      result shape. Only one of this method and
+      `reifyReturnTypeShapesPerResultDim` needs to be overridden by the
+      operation.
       }],
       /*retTy=*/"::mlir::LogicalResult",
       /*methodName=*/"reifyReturnTypeShapes",
       /*args=*/(ins "::mlir::OpBuilder&":$builder,
-                    "::mlir::SmallVectorImpl<::mlir::Value>&":$reifiedReturnShapes),
+          "::mlir::SmallVectorImpl<Value> &":$reifiedReturnShapes),
       /*methodBody=*/[{}],
       /*defaultImplementation=*/[{ return ::mlir::failure(); }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{Reify the shape computation for the operation.
+
+      Insert operations using the given OpBuilder that compute the
+      result shape. `reifiedReturnShapes` is expected to be populated
+      with as many vectors as the number of results of the op (empty
+      if the shape of a result value cannot be computed). If the
+      returned shape for a result is not empty, its size must match
+      the rank of the shaped type returned. Consequently, this
+      interface method can only be overridden if the result types are
+      ranked.
+
+      If both this method and `reifyReturnTypeShapes` are overridden,
+      `reifyReturnTypeShapes` takes precedence. This method is meant
+      for ops where the shape of each (result, dim) pair can be
+      computed independently. Using it avoids adding extra operations
+      to aggregate the individual dimensions of a result shape into a
+      single `Value` (and, consequently, avoids having to extract each
+      dimension from that shape on the client side).
+      }],
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"reifyReturnTypeShapesPerResultDim",
+      /*args=*/(ins "::mlir::OpBuilder&":$builder,
+          "::mlir::SmallVectorImpl<SmallVector<::mlir::Value>>&"
+          :$reifiedReturnShapes),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{ return ::mlir::failure(); }]
+    >
   ];
 }
 
@@ -129,6 +161,7 @@
     NativeOpTrait<"InferTensorType">
   ];
 }
-defvar InferTensorTypeWithReify = InferTensorType<["reifyReturnTypeShapes"]>;
+defvar InferTensorTypeWithReify = InferTensorType<[
+    "inferReturnTypeComponents", "reifyReturnTypeShapes"]>;
 
 #endif // MLIR_INFERTYPEOPINTERFACE
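
As orientation for implementers, a minimal sketch of overriding the new hook, assuming a hypothetical single-result op MyTileOp whose result has the same shape as its ranked tensor operand source(); none of these names come from this patch:

  LogicalResult MyTileOp::reifyReturnTypeShapesPerResultDim(
      OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
    Location loc = getLoc();
    auto sourceType = source().getType().cast<RankedTensorType>();
    SmallVector<Value> shape;
    // One entry per result dimension; static sizes may equally well be
    // materialized as constant index values here.
    for (int64_t dim = 0, rank = sourceType.getRank(); dim < rank; ++dim)
      shape.push_back(b.createOrFold<memref::DimOp>(loc, source(), dim));
    // One vector per op result, in result order.
    reifiedReturnShapes.emplace_back(std::move(shape));
    return success();
  }
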
diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -14,6 +14,7 @@
   LINK_LIBS PUBLIC
   MLIRAffine
   MLIRDialectUtils
+  MLIRInferTypeOpInterface
   MLIRIR
   MLIRParser
   MLIRSideEffectInterfaces
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -188,7 +188,7 @@
   for (Value v : getShapedOperands()) {
     ShapedType t = v.getType().template cast<ShapedType>();
     for (unsigned i = 0, e = t.getRank(); i < e; ++i)
-      res.push_back(b.create<memref::DimOp>(loc, v, i));
+      res.push_back(b.createOrFold<memref::DimOp>(loc, v, i));
   }
   return res;
 }
@@ -234,57 +234,58 @@
   llvm::SmallSet<unsigned, 4> positions;
 };
 
-Optional<Value> LinalgOp::inferResultDimFromInputShapes(OpBuilder &b,
-                                                        Location loc,
-                                                        unsigned resultIdx,
-                                                        unsigned dim) {
+LogicalResult LinalgOp::reifyReturnTypeShapesPerResultDim(
+    OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
   // An example that helps understand the logic below.
   // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
   // We want to express the shape of dim 0 of O in terms of shape of the inputs.
   // This is achieved as follows.
   //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
-  //   subMapOfResultDim = (d0, d1, d2) -> (d0 + d1)
+  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d2)
-  //   resultFromFromInputDim = subMapOfResultDim.compose(shapesToLoopMap)
-  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d1)
+  //   resultShapesFromInputShapes = subMapOfResultShapes.compose(shapesToLoopsMap)
+  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
   AffineMap loopsToShapesMap = getLoopsToShapesMap();
 
   // Find the position in the above map that represents the shape of the
   // result:dim being inferred.
-  Optional<unsigned> resultDimSubMapPos =
-      getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim);
-  if (!resultDimSubMapPos)
-    return {};
+  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap();
 
   /// From loopsToShapesMap extract the submap that represents the shape of the
-  /// (resultIdx, dim) needed
-  AffineMap loopToResultDimShapeMap =
-      loopsToShapesMap.getSubMap(*resultDimSubMapPos);
-  AffineMap operandShapesToResultDimMap =
-      loopToResultDimShapeMap.compose(getShapesToLoopsMap());
+  /// results.
+  SmallVector<unsigned, 4> resultPosRange =
+      llvm::to_vector<4>(llvm::seq<unsigned>(resultShapesSubMapPos.first,
+                                             resultShapesSubMapPos.second));
+  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSubMap(resultPosRange);
+  AffineMap resultShapesFromInputShapesMap =
+      loopToResultsShapeMap.compose(getShapesToLoopsMap());
 
   // Check that the result dim map does not contain the positions corresponding
   // to the outputs.
   llvm::SmallSet<unsigned, 4> outputDims;
-  unsigned outputDimPosStart =
-      getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue();
-  unsigned outputDimPosEnd =
-      getResultValueDimPositionInLoopsToShapeMap(getNumOutputs() - 1,
-                                                 getOutputOpOperands()
-                                                         .back()
-                                                         .get()
-                                                         .getType()
-                                                         .cast<ShapedType>()
-                                                         .getRank() -
-                                                     1)
-          .getValue();
-  llvm::for_each(llvm::seq<unsigned>(outputDimPosStart, outputDimPosEnd),
+  llvm::for_each(resultPosRange,
                  [&outputDims](unsigned dim) { outputDims.insert(dim); });
   HasAffineDimExprVisitor checkDimExpr(outputDims);
-  if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0)))
-    return llvm::None;
-  return applyMapToValues(b, loc, operandShapesToResultDimMap,
-                          createFlatListOfOperandDims(b, loc))[0];
+  Location loc = getOperation()->getLoc();
+  auto allResultDimValues =
+      applyMapToValues(b, loc, resultShapesFromInputShapesMap,
+                       createFlatListOfOperandDims(b, loc));
+  unsigned pos = 0;
+  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
+  for (auto resultIdx : llvm::seq<unsigned>(0, getNumOutputs())) {
+    ShapedType resultType = getOutputShapedType(resultIdx);
+    SmallVector<Value> shapes;
+    for (unsigned dim : llvm::seq<unsigned>(0, resultType.getRank())) {
+      if (checkDimExpr.visit(shapeExprs[pos]))
+        shapes.push_back(
+            b.createOrFold<memref::DimOp>(loc, getOutput(resultIdx), dim));
+      else
+        shapes.push_back(allResultDimValues[pos]);
+      pos++;
+    }
+    reifiedReturnShapes.emplace_back(std::move(shapes));
+  }
+  return success();
 }
 
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
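
Illustrative usage of the implementation above (the helper name is assumed, not part of this patch), reifying all result shapes of the O(i + j, j) += A(i, k) * B(k, j) example from the comments:

  static LogicalResult reifyExample(OpBuilder &b, linalg::LinalgOp linalgOp) {
    SmallVector<SmallVector<Value>> reifiedReturnShapes;
    if (failed(linalgOp.reifyReturnTypeShapesPerResultDim(b,
                                                          reifiedReturnShapes)))
      return failure();
    // One vector with one Value per result dimension is produced, e.g.
    //   reifiedReturnShapes[0][0] = affine.apply (s0 + s1)(dim(A, 0), dim(B, 1))
    //   reifiedReturnShapes[0][1] = dim(B, 1)
    return success();
  }
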
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser.h"
 
 #include "llvm/ADT/DenseMap.h"
@@ -88,6 +89,33 @@
 template <typename NamedStructuredOpType>
 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op);
 
+/// Helper function to convert a vector of `Value`s into `OpFoldResult`s,
+/// using the attribute form for values that are known constant indices.
+static SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
+  return llvm::to_vector<4>(
+      llvm::map_range(values, [](Value v) -> OpFoldResult {
+        APInt intValue;
+        if (v.getType().isa<IndexType>() &&
+            matchPattern(v, m_ConstantInt(&intValue))) {
+          return IntegerAttr::get(v.getType(), intValue.getSExtValue());
+        }
+        return v;
+      }));
+}
+
+/// Helper function to convert a vector of `OpFoldResult`s into a vector of
+/// `Value`s.
+static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
+                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
+  return llvm::to_vector<4>(
+      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
+        if (auto attr = value.dyn_cast<Attribute>())
+          return b.create<ConstantIndexOp>(loc,
+                                           attr.cast<IntegerAttr>().getInt());
+        return value.get<Value>();
+      }));
+}
+
 /// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
 /// it is a Value or into `staticVec` if it is an IntegerAttr.
 /// In the case of a Value, a copy of the `sentinel` value is also pushed to
@@ -679,10 +707,6 @@
   SmallVector<Value, 4> dynamicSizes;
   SmallVector<int64_t, 4> staticSizes;
   for (unsigned i = 0; i < rank; ++i) {
-    // staticLow and staticHigh have full information of the padding config.
-    // This will grow staticLow and staticHigh with 1 value. If the config is
-    // dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
-    // value as well.
     dispatchIndexOpFoldResult(sizes[i], dynamicSizes, staticSizes,
                               ShapedType::kDynamicSize);
   }
@@ -771,33 +795,6 @@
     return success();
   }
 };
-
-/// Canonicalize a `linalg.init_tensor` -> `dim` pattern by replacing the `dim`
-/// with
-/// - A constant value if the size is static along the dimension.
-/// - The dynamic value that defines the size of the result of
-///   `linalg.init_tensor` op.
-struct ReplaceDimOfInitTensorOp : public OpRewritePattern<memref::DimOp> {
-  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(memref::DimOp dimOp,
-                                PatternRewriter &rewriter) const override {
-    auto initTensorOp = dimOp.memrefOrTensor().getDefiningOp<InitTensorOp>();
-    if (!initTensorOp)
-      return failure();
-    auto dimIndex = dimOp.index().getDefiningOp<ConstantIndexOp>();
-    if (!dimIndex)
-      return failure();
-    int64_t index = dimIndex.getValue();
-    if (!initTensorOp.isDynamicSize(index)) {
-      rewriter.replaceOpWithNewOp<ConstantIndexOp>(
-          dimOp, initTensorOp.getStaticSize(index));
-    } else {
-      rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(index));
-    }
-    return success();
-  }
-};
 } // namespace
 
 namespace {
@@ -831,12 +828,20 @@
     if (!reshapeOp.src().getDefiningOp<InitTensorOp>())
       return failure();
     Location loc = reshapeOp.getLoc();
-    SmallVector<Value, 4> resultShapeValues =
-        reshapeOp.getOutputShape(rewriter, loc);
+    SmallVector<SmallVector<Value>, 4> resultShapes;
+    if (failed(reshapeOp.reifyReturnTypeShapesPerResultDim(rewriter,
+                                                           resultShapes)) ||
+        !llvm::hasSingleElement(resultShapes))
+      return failure();
     Value initTensor = rewriter.create<InitTensorOp>(
-        loc, resultShapeValues, reshapeOp.getResultType().getElementType());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(
-        reshapeOp, reshapeOp.getResultType(), initTensor);
+        loc, getAsOpFoldResult(resultShapes[0]),
+        reshapeOp.getResultType().getElementType());
+    if (initTensor.getType() != reshapeOp.getResultType()) {
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(
+          reshapeOp, reshapeOp.getResultType(), initTensor);
+    } else {
+      rewriter.replaceOp(reshapeOp, initTensor);
+    }
     return success();
   }
 };
@@ -845,7 +850,20 @@
 void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
   results.add<FoldInitTensorWithSubTensorOp, FoldInitTensorWithTensorReshapeOp,
-              ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
+              ReplaceStaticShapeDims>(context);
+}
+
+LogicalResult InitTensorOp::reifyReturnTypeShapesPerResultDim(
+    OpBuilder &builder,
+    SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+  auto shapes = llvm::to_vector<4>(llvm::map_range(
+      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
+        if (isDynamicSize(dim))
+          return getDynamicSize(dim);
+        return builder.create<ConstantIndexOp>(getLoc(), getStaticSize(dim));
+      }));
+  reifiedReturnShapes.emplace_back(std::move(shapes));
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -997,6 +1015,37 @@
                                         builder);
 }
 
+LogicalResult PadTensorOp::reifyReturnTypeShapesPerResultDim(
+    OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+  Location loc = getLoc();
+  auto lowPad = getMixedLowPad();
+  auto highPad = getMixedHighPad();
+  SmallVector<Value> shapes;
+  for (auto dim : llvm::seq<int64_t>(0, getSourceType().getRank())) {
+    // Shape along each dimension is source dim + low pad + high pad.
+    SmallVector<Value> mapOperands;
+    mapOperands.push_back(b.createOrFold<memref::DimOp>(loc, source(), dim));
+    AffineExpr expr = b.getAffineDimExpr(0);
+    unsigned numSymbols = 0;
+    auto addOpFoldResult = [&](OpFoldResult valueOrAttr) {
+      if (Value v = valueOrAttr.dyn_cast<Value>()) {
+        expr = expr + b.getAffineSymbolExpr(numSymbols++);
+        mapOperands.push_back(v);
+        return;
+      }
+      int64_t staticValue =
+          valueOrAttr.get<Attribute>().cast<IntegerAttr>().getInt();
+      expr = expr + staticValue;
+    };
+    addOpFoldResult(lowPad[dim]);
+    addOpFoldResult(highPad[dim]);
+    shapes.push_back(applyMapToValues(
+        b, loc, AffineMap::get(1, numSymbols, expr), mapOperands)[0]);
+  }
+  reifiedReturnShapes.emplace_back(std::move(shapes));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
@@ -1281,7 +1330,7 @@
 /// terms of shape of the `src`, when the reshape op is a collapsing
 /// operation. It is the product of the shape of the collapsed dimensions of the
 /// `src`.
-static Value
+static OpFoldResult
 getCollapsedOutputDimFromInputShape(OpBuilder &builder, Location loc,
                                     int64_t dimIndex, Value src,
                                     ArrayRef<AffineMap> reassociationMap) {
@@ -1292,7 +1341,7 @@
   AffineExpr expr;
   SmallVector<Value, 2> dynamicDims;
   for (auto dim : llvm::seq(startPos, endPos + 1)) {
-    dynamicDims.push_back(builder.create<memref::DimOp>(loc, src, dim));
+    dynamicDims.push_back(builder.createOrFold<memref::DimOp>(loc, src, dim));
     AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
     expr = (expr ? expr * currExpr : currExpr);
   }
@@ -1303,7 +1352,7 @@
 
 /// Given the `src` of a collapsing reshape op and its reassociation maps,
 /// compute the shape of the result of the reshape.
-static SmallVector<Value, 4> getCollapsedOutputShapeFromInputShape(
+static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
     OpBuilder &builder, Location loc, Value src,
     ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
   return llvm::to_vector<4>(llvm::map_range(
@@ -1333,12 +1382,12 @@
 
 /// For an expanding reshape op, compute the value for a dimension of the output
 /// from the shape of the input.
-static Value getExpandedOutputDimFromInputShape(
+static OpFoldResult getExpandedOutputDimFromInputShape(
     OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
     ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
     llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
   if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
-    return builder.create<ConstantIndexOp>(loc, dstStaticShape[dimIndex]);
+    return builder.getI64IntegerAttr(dstStaticShape[dimIndex]);
   }
   unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
   unsigned startPos = reassociation[sourceDimPos]
@@ -1371,7 +1420,7 @@
 
 /// Given the `src` of an expanding reshape op, the reassociation maps and the
 /// result type, compute the shape of the result of the reshape.
-static SmallVector<Value, 4> getExpandedOutputShapeFromInputShape(
+static SmallVector<OpFoldResult, 4> getExpandedOutputShapeFromInputShape(
     OpBuilder &builder, Location loc, Value src,
     ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
   llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
@@ -1384,9 +1433,10 @@
       }));
 }
 
-SmallVector<Value, 4> mlir::linalg::getReshapeOutputShapeFromInputShape(
-    OpBuilder &builder, Location loc, Value src,
-    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassocation) {
+static SmallVector<OpFoldResult, 4>
+getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
+                                    ArrayRef<int64_t> dstStaticShape,
+                                    ArrayRef<AffineMap> reassocation) {
   return dstStaticShape.size() >
                  static_cast<size_t>(src.getType().cast<ShapedType>().getRank())
              ? getExpandedOutputShapeFromInputShape(
@@ -1395,23 +1445,6 @@
                    builder, loc, src, dstStaticShape, reassocation);
 }
 
-/// For a reshape op, compute the value of a given dimension of the output
-/// (`dimIndex`) from the shape of the inputs and type of the result.
-static Value getReshapeOutputDimFromInputShape(
-    OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
-    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
-  if (dstStaticShape.size() >
-      static_cast<size_t>(src.getType().cast<ShapedType>().getRank())) {
-    llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
-        getExpandedDimToCollapsedDimMap(reassociation);
-    return getExpandedOutputDimFromInputShape(builder, loc, dimIndex, src,
-                                              dstStaticShape, reassociation,
-                                              expandedDimToCollapsedDim);
-  }
-  return getCollapsedOutputDimFromInputShape(builder, loc, dimIndex, src,
-                                             reassociation);
-}
-
 void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result,
                                     Value src,
                                     ArrayRef<ReassociationExprs> reassociation,
@@ -1636,29 +1669,6 @@
   }
 };
 
-/// Canonicalize dim ops that use the output shape with dim of the input.
-struct ReplaceDimOfReshapeOpResult : OpRewritePattern<memref::DimOp> {
-  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(memref::DimOp dimOp,
-                                PatternRewriter &rewriter) const override {
-    Value dimValue = dimOp.memrefOrTensor();
-    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
-    if (!dimIndex)
-      return failure();
-
-    auto reshapeOp = dimValue.getDefiningOp<TensorReshapeOp>();
-    if (!reshapeOp)
-      return failure();
-
-    rewriter.replaceOp(dimOp,
-                       getReshapeOutputDimFromInputShape(
-                           rewriter, dimOp.getLoc(), *dimIndex, reshapeOp.src(),
-                           reshapeOp.getResultType().getShape(),
-                           reshapeOp.getReassociationMaps()));
-    return success();
-  }
-};
-
 /// Fold linalg.fill -> linalg.tensor_reshape chain.
 ///
 /// For such op chains, we can create new linalg.fill ops with the result
@@ -1684,7 +1694,18 @@
 void TensorReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
   results.add<CollapseReshapeOps<TensorReshapeOp>, FoldFillWithTensorReshape,
-              FoldReshapeWithConstant, ReplaceDimOfReshapeOpResult>(context);
+              FoldReshapeWithConstant>(context);
+}
+
+LogicalResult TensorReshapeOp::reifyReturnTypeShapesPerResultDim(
+    OpBuilder &b, SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+  auto resultShape =
+      getAsValues(b, getLoc(),
+                  getReshapeOutputShapeFromInputShape(
+                      b, getLoc(), src(), getResultType().getShape(),
+                      getReassociationMaps()));
+  reifiedReturnShapes.emplace_back(std::move(resultShape));
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -2544,50 +2565,6 @@
     return success();
   }
 };
-
-/// Replaces memref.dim operations that use the result of a LinalgOp (on
-/// tensors) with memref.dim operations that use one of the arguments. For
-/// example,
-///
-/// %0 = linalg.matmul ins(%arg0, %arg1, ...)
-/// %1 = memref.dim %0, %c0
-///
-/// with
-///
-/// %1 = memref.dim %arg0, %c0
-///
-/// where possible. With this the result of the `linalg.matmul` is not used in
-/// dim operations. If the value produced is replaced with another value (say by
-/// tiling `linalg.matmul`) will make the `linalg.matmul` truly dead instead of
-/// used in a dim op that would prevent the DCE of this op.
-struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<memref::DimOp> {
-  using OpRewritePattern<memref::DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(memref::DimOp dimOp,
-                                PatternRewriter &rewriter) const override {
-    Value dimValue = dimOp.memrefOrTensor();
-    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
-    if (!dimIndex)
-      return failure();
-    auto linalgOp = dimValue.getDefiningOp<LinalgOp>();
-    if (!linalgOp)
-      return failure();
-
-    unsigned resultIndex = dimValue.cast<OpResult>().getResultNumber();
-    Optional<Value> operandDimValue = linalgOp.inferResultDimFromInputShapes(
-        rewriter, dimOp.getLoc(), resultIndex,
-        static_cast<unsigned>(*dimIndex));
-    if (!operandDimValue) {
-      // Its always possible to replace using the corresponding `outs`
-      // parameter.
-      operandDimValue = rewriter.create<memref::DimOp>(
-          dimOp.getLoc(), linalgOp.getOutput(resultIndex), *dimIndex);
-    }
-    rewriter.replaceOp(dimOp, *operandDimValue);
-    return success();
-  }
-};
-
 } // namespace
 
 namespace {
@@ -2745,7 +2722,7 @@
   void XXX::getCanonicalizationPatterns(RewritePatternSet &results,            \
                                         MLIRContext *context) {                \
     results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp,        \
-                RemoveIdentityLinalgOps, ReplaceDimOfLinalgOpResult>(context); \
+                RemoveIdentityLinalgOps>(context);                             \
   }                                                                            \
                                                                                \
   LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \
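
For orientation, a minimal sketch of the per-dimension size computation that PadTensorOp::reifyReturnTypeShapesPerResultDim above performs, assuming a dynamic low pad and a static high pad of 7; the helper name getPaddedSize is hypothetical and mirrors the call to applyMapToValues in the patch:

  static Value getPaddedSize(OpBuilder &b, Location loc, Value srcDim,
                             Value dynamicLowPad) {
    AffineExpr d0 = b.getAffineDimExpr(0);    // source size along the dimension
    AffineExpr s0 = b.getAffineSymbolExpr(0); // dynamic low pad
    // Static pads fold directly into the expression; dynamic ones become
    // symbols, mirroring addOpFoldResult above.
    AffineMap map =
        AffineMap::get(/*dimCount=*/1, /*symbolCount=*/1, d0 + s0 + 7);
    SmallVector<Value> mapOperands = {srcDim, dynamicLowPad};
    return applyMapToValues(b, loc, map, mapOperands)[0];
  }
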
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -16,6 +16,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "llvm/ADT/STLExtras.h"
 
 using namespace mlir;
@@ -673,12 +674,84 @@
     return success();
   }
 };
+
+/// Helper method to get the `Value` that is the size of `result` at dimension
+/// `dimIndex`, using the `InferShapedTypeOpInterface` of its defining op.
+/// TODO(ravishankarm): This is better placed as an interface utility method
+/// somewhere, but that would make the interface depend on the `tensor`
+/// dialect. Ideally it would be a utility method in the `tensor` dialect.
+static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
+                                            int64_t dimIndex) {
+  unsigned resultNumber = result.getResultNumber();
+  auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
+  Location loc = result.getOwner()->getLoc();
+  if (!shapedTypeOp)
+    return nullptr;
+
+  // The interface exposes two methods: one that returns the shape of each
+  // result as a single `Value`, and another that returns the shape of each
+  // result as a `SmallVector<Value>` with one entry per dimension. The former
+  // takes precedence over the latter, so first check whether the op implements
+  // the first method and fall back to the second otherwise.
+  SmallVector<Value> reifiedResultShapes;
+  if (succeeded(
+          shapedTypeOp.reifyReturnTypeShapes(builder, reifiedResultShapes))) {
+    if (reifiedResultShapes.size() <= resultNumber)
+      return nullptr;
+    Value resultShape = reifiedResultShapes[resultNumber];
+    auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
+    if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
+      return nullptr;
+    return builder.create<tensor::ExtractOp>(
+        loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
+  }
+
+  SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
+  if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
+          builder, reifiedResultShapesPerDim)))
+    return nullptr;
+  if (reifiedResultShapesPerDim.size() <= resultNumber ||
+      reifiedResultShapesPerDim[resultNumber].size() !=
+          static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
+    return nullptr;
+  OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
+  if (auto attr = valueOrAttr.dyn_cast<Attribute>())
+    return builder.createOrFold<ConstantIndexOp>(
+        loc, attr.cast<IntegerAttr>().getInt());
+  return valueOrAttr.get<Value>();
+}
+
+/// Fold dim of an operation that implements the InferShapedTypeOpInterface
+struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
+    if (!dimValue)
+      return failure();
+    auto shapedTypeOp =
+        dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
+    if (!shapedTypeOp)
+      return failure();
+
+    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
+    if (!dimIndex)
+      return failure();
+    Value replacement =
+        getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
+    if (!replacement)
+      return failure();
+    rewriter.replaceOp(dimOp, replacement);
+    return success();
+  }
+};
 } // end anonymous namespace.
 
 void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
   results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
-              DimOfCastOp<tensor::CastOp>>(context);
+              DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context);
 }
 
 // ---------------------------------------------------------------------------
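
A minimal usage sketch (the pass wiring is assumed, not part of this patch): the new DimOfShapedTypeOpInterface pattern is picked up through the regular memref.dim canonicalization entry point, for example inside a pass:

  static void canonicalizeDims(FuncOp funcOp, MLIRContext *context) {
    RewritePatternSet patterns(context);
    memref::DimOp::getCanonicalizationPatterns(patterns, context);
    // memref.dim ops whose source is a result of an op implementing
    // InferShapedTypeOpInterface are rewritten in terms of the reified shapes.
    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
  }
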
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -12,7 +12,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Interfaces/InferTypeOpInterface.h"
-
 #include "mlir/IR/BuiltinTypes.h"
 
 using namespace mlir;
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -404,12 +404,13 @@
 
 func @remove_dim_result_uses
   (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
-   %arg2 : tensor<?x?xf32>) -> (index) {
+   %arg2 : tensor<?x?xf32>) -> (index, index) {
   %c0 = constant 0 : index
+  %c1 = constant 1 : index
   %0 = linalg.generic
     {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
-                      affine_map<(d0, d1, d2) -> (d0 + d1, d1)>],
+                      affine_map<(d0, d1, d2) -> (d0 + d1, d1 - d0)>],
      iterator_types = ["parallel", "parallel", "reduction"]}
     ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
     outs(%arg2 : tensor<?x?xf32>) {
@@ -419,9 +420,11 @@
       linalg.yield %2 : f32
     } -> tensor<?x?xf32>
   %3 = memref.dim %0, %c0 : tensor<?x?xf32>
-  return %3 : index
+  %4 = memref.dim %0, %c1 : tensor<?x?xf32>
+  return %3, %4 : index, index
 }
-//       CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//       CHECK: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//       CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (-s0 + s1)>
 //       CHECK: func @remove_dim_result_uses
 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 //  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -430,8 +433,11 @@
 //   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
 //   CHECK-DAG:   %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
 //   CHECK-DAG:   %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
-//       CHECK:   %[[T2:.+]] = affine.apply #[[MAP]]()[%[[T0]], %[[T1]]]
-//       CHECK:   return %[[T2]]
+//       CHECK:   %[[T2:.+]] = affine.apply #[[MAP0]]()[%[[T0]], %[[T1]]]
+//   CHECK-DAG:   %[[T3:.+]] = memref.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[T4:.+]] = memref.dim %[[ARG1]], %[[C1]]
+//       CHECK:   %[[T5:.+]] = affine.apply #[[MAP1]]()[%[[T3]], %[[T4]]]
+//       CHECK:   return %[[T2]], %[[T5]]
 
 // -----
 
@@ -861,3 +867,38 @@
 // CHECK-SAME: outs (%[[C]]:memref<192x192xf32>) {
 // CHECK-NEXT:   call @foo(%[[A]], %[[B]], %[[C]])
 // CHECK-NEXT:   linalg.yield
+
+// -----
+
+func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,
+    %arg3: f32) -> (index, index, index)
+{
+   %c0 = constant 0 : index
+   %c1 = constant 1 : index
+   %c2 = constant 2 : index
+   %c3 = constant 3 : index
+   %c4 = constant 4 : index
+   %c5 = constant 5 : index
+   %0 = linalg.pad_tensor %arg0 low[%c3, %arg1, %c4] high[7, %c5, %arg2] {
+     ^bb0(%arg4: index, %arg5: index, %arg6: index):
+       linalg.yield %arg3 : f32
+   } : tensor<2x?x?xf32> to tensor<?x?x?xf32>
+   %1 = memref.dim %0, %c0 : tensor<?x?x?xf32>
+   %2 = memref.dim %0, %c1 : tensor<?x?x?xf32>
+   %3 = memref.dim %0, %c2 : tensor<?x?x?xf32>
+   return %1, %2, %3 : index, index, index
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 4)>
+//      CHECK: func @dim_of_pad_op
+// CHECK-SAME:   %[[ARG0:[A-Za-z0-9_]+]]: tensor<2x?x?xf32>
+// CHECK-SAME:   %[[ARG1:[A-Za-z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[A-Za-z0-9_]+]]: index
+//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = constant 2 : index
+//  CHECK-DAG:   %[[C12:.+]] = constant 12 : index
+//      CHECK:   %[[IN_DIM1:.+]] = memref.dim %[[ARG0]], %[[C1]]
+//      CHECK:   %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[IN_DIM1]]]
+//      CHECK:   %[[IN_DIM2:.+]] = memref.dim %[[ARG0]], %[[C2]]
+//      CHECK:   %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]]
+//      CHECK:   return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]]
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -81,3 +81,27 @@
   %0 = "test.passthrough_fold"(%c42) : (f32) -> (i32)
   return %0 : i32
 }
+
+// CHECK-LABEL: func @result_shape_per_dim
+// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
+func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
+    -> (index, index, index, index, index) {
+  // CHECK-DAG: %[[C0:.+]] = constant 0 : index
+  // CHECK-DAG: %[[C2:.+]] = constant 2 : index
+  // CHECK-DAG: %[[C3:.+]] = constant 3 : index
+  // CHECK-DAG: %[[C5:.+]] = constant 5 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1)
+      : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
+  %1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
+  %2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
+  %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
+  %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
+  %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
+  // CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
+  // CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
+  // CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
+  return %1, %2, %3, %4, %5 : index, index, index, index, index
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -754,6 +754,25 @@
   return success();
 }
 
+LogicalResult
+OpWithResultShapePerDimInterfaceOp::reifyReturnTypeShapesPerResultDim(
+    OpBuilder &builder,
+    llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
+  SmallVector<Value> operand1Shape, operand2Shape;
+  Location loc = getLoc();
+  for (auto i :
+       llvm::seq<int>(0, operand1().getType().cast<ShapedType>().getRank())) {
+    operand1Shape.push_back(builder.create<memref::DimOp>(loc, operand1(), i));
+  }
+  for (auto i :
+       llvm::seq<int>(0, operand2().getType().cast<ShapedType>().getRank())) {
+    operand2Shape.push_back(builder.create<memref::DimOp>(loc, operand2(), i));
+  }
+  shapes.emplace_back(std::move(operand2Shape));
+  shapes.emplace_back(std::move(operand1Shape));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Test SideEffect interfaces
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -549,7 +549,8 @@
 }
 
 def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
-    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+    DeclareOpInterfaceMethods<InferTypeOpInterface,
+        ["inferReturnTypeComponents"]>]> {
   let arguments = (ins AnyTensor, AnyTensor);
   let results = (outs AnyTensor);
 }
@@ -560,6 +561,13 @@
   let results = (outs AnyTensor);
 }
 
+def OpWithResultShapePerDimInterfaceOp : TEST_Op<"op_with_result_shape_per_dim_interface",
+      [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+           ["reifyReturnTypeShapesPerResultDim"]>]> {
+  let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
+  let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
+}
+
 def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
 
 def UpdateAttr : Pat<(I32ElementsAttrOp $attr),
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -128,11 +128,13 @@
   // Use permutations of 2 args as operands.
   auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
   SmallVector<Value, 2> shapes;
-  if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)))
+  if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)) ||
+      !llvm::hasSingleElement(shapes))
     return;
-  for (auto it : llvm::enumerate(shapes))
+  for (auto it : llvm::enumerate(shapes)) {
     op->emitRemark() << "value " << it.index() << ": "
                      << it.value().getDefiningOp();
+  }
 }
 
 struct TestReturnTypeDriver