diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -644,10 +644,18 @@ for (OpOperand *opOperand : getDpsInitOperands()) { SmallVector shapes; for (int64_t dim : llvm::seq(0, getRank(opOperand))) { - if (checkDimExpr.visit(shapeExprs[pos])) - shapes.push_back(createOrFoldDimOp(b, loc, opOperand->get(), dim)); - else - shapes.push_back(allResultDimValues[pos]); + auto shapedType = opOperand->get().getType().cast(); + if (!shapedType.isDynamicDim(dim)) { + // Static dim: Return IntegerAttr. + shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim))); + } else { + // Dynamic dim: Return Value. + OpFoldResult ofr = + checkDimExpr.visit(shapeExprs[pos]) + ? createOrFoldDimOp(b, loc, opOperand->get(), dim) + : allResultDimValues[pos]; + shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr)); + } pos++; } reifiedReturnShapes.emplace_back(std::move(shapes)); diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2205,8 +2205,13 @@ OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { reifiedReturnShapes.resize(1, SmallVector(getType().getRank())); for (auto dim : llvm::seq(0, getType().getRank())) { - reifiedReturnShapes[0][dim] = - builder.createOrFold(getLoc(), getDest(), dim); + if (getType().isDynamicDim(dim)) { + reifiedReturnShapes[0][dim] = + builder.createOrFold(getLoc(), getDest(), dim); + } else { + reifiedReturnShapes[0][dim] = + builder.getIndexAttr(getType().getDimSize(dim)); + } } return success(); } @@ -3154,9 +3159,15 @@ "applies to only pack or unpack operations"); int64_t destRank = op.getDestRank(); reifiedReturnShapes.resize(1, SmallVector(destRank)); + ShapedType resultType = op.getResult().getType().template cast(); for (auto dim : llvm::seq(0, destRank)) { - reifiedReturnShapes[0][dim] = - builder.createOrFold(op.getLoc(), op.getDest(), dim); + if (resultType.isDynamicDim(dim)) { + reifiedReturnShapes[0][dim] = + builder.createOrFold(op.getLoc(), op.getDest(), dim); + } else { + reifiedReturnShapes[0][dim] = + builder.getIndexAttr(resultType.getDimSize(dim)); + } } return success(); }