diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -947,9 +947,17 @@
 }
 
 void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
-  Type type = arg.getType().isa<ShapedType>()
-                  ? (Type)getExtentTensorType(builder.getContext())
-                  : (Type)builder.getType<ShapeType>();
+  // A shaped argument yields a 1-D extent tensor of `index` elements whose
+  // static size is the argument's rank (dynamic if the argument is
+  // unranked); any other argument yields a `ShapeType` value.
+  Type type;
+  if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
+    int64_t rank =
+        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
+    type = RankedTensorType::get({rank}, builder.getIndexType());
+  } else {
+    type = builder.getType<ShapeType>();
+  }
   return ShapeOfOp::build(builder, result, type, arg);
 }
 