[mlir][tensor] Reify dynamic reshape output dims as Values

getCollapsedOutputDimFromInputShape and getExpandedOutputDimFromInputShape
now return an IntegerAttr only for static output dimensions. For dynamic
output dimensions they always return a Value: result 0 of the affine.apply
op built by makeComposedAffineApply. The previous
makeComposedFoldedAffineApply call could presumably fold to an Attribute
when all operands are constants, so a dynamic dimension could be reified
as an Attribute. The new test covers that case: tensor.dim of a
tensor.collapse_shape whose source extents are 7 and 22 resolves to
arith.constant 154.

NOTE(review): this patch was recovered from a copy in which line breaks,
indentation and all angle-bracketed text had been stripped. The template
arguments, the MLIR tensor types and the C++ indentation were
reconstructed (LLVM clang-format style). Whitespace inside the .mlir
context lines could not be recovered, so applying may need
`git apply --ignore-whitespace`. Confirm against the tree before applying.

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -42,6 +42,7 @@
     OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
     ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociationMap) {
   if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
+    // Static dimension: return Attribute.
     return builder.getIndexAttr(dstStaticShape[dimIndex]);
   }
   AffineMap map = reassociationMap[dimIndex];
@@ -55,9 +56,12 @@
     AffineExpr currExpr = builder.getAffineSymbolExpr(dim - startPos);
     expr = (expr ? expr * currExpr : currExpr);
   }
-  return affine::makeComposedFoldedAffineApply(
-      builder, loc, AffineMap::get(0, endPos - startPos + 1, expr),
-      dynamicDims);
+
+  // Dynamic dimension: return Value.
+  return affine::makeComposedAffineApply(
+             builder, loc, AffineMap::get(0, endPos - startPos + 1, expr),
+             dynamicDims)
+      ->getResult(0);
 }
 
 /// Given the `src` of a collapsing reshape op and its reassociation maps,
@@ -79,6 +83,7 @@
     ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
     llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
   if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
+    // Static dimension: return Attribute.
     return builder.getIndexAttr(dstStaticShape[dimIndex]);
   }
   unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
@@ -104,11 +109,15 @@
   }
   OpFoldResult sourceDim =
       builder.create<tensor::DimOp>(loc, src, sourceDimPos).getResult();
-  return affine::makeComposedFoldedAffineApply(
-      builder, loc,
-      AffineMap::get(
-          0, 1, builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
-      sourceDim);
+
+  // Dynamic dimension: return Value.
+  return affine::makeComposedAffineApply(
+             builder, loc,
+             AffineMap::get(
+                 0, 1,
+                 builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
+             sourceDim)
+      ->getResult(0);
 }
 
 /// Given the `src` of an expanding reshape op, the reassociation maps and the
diff --git a/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir
--- a/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir
@@ -142,3 +142,21 @@
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
 // CHECK: return %[[ARG1]], %[[ARG2]]
+
+// -----
+
+func.func @collapse_shape() -> index {
+  %c0 = arith.constant 0 : index
+  %c7 = arith.constant 7 : index
+  %c1_i16 = arith.constant 1 : i16
+  %generated = tensor.generate %c7 {
+  ^bb0(%arg3: index, %arg4: index):
+    tensor.yield %c1_i16 : i16
+  } : tensor<?x22xi16>
+  %collapsed = tensor.collapse_shape %generated [[0, 1]] : tensor<?x22xi16> into tensor<?xi16>
+  %d0 = tensor.dim %collapsed, %c0 : tensor<?xi16>
+  return %d0 : index
+}
+// CHECK-LABEL: func @collapse_shape(
+//       CHECK:   %[[c154:.*]] = arith.constant 154 : index
+//       CHECK:   return %[[c154]]