from __future__ import annotations
import uuid
from collections.abc import AsyncIterator, Iterator
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime
from typing import TYPE_CHECKING, Any
from distributed.compatibility import PeriodicCallback
if TYPE_CHECKING:
# Optional runtime dependencies
import pandas as pd
# Circular dependencies
from distributed.client import Client
from distributed.scheduler import Scheduler
class MemorySampler:
    """Sample cluster-wide memory usage every <interval> seconds.

    **Usage**

    .. code-block:: python

        client = Client()
        ms = MemorySampler()

        with ms.sample("run 1"):
            <run first workflow>
        with ms.sample("run 2"):
            <run second workflow>
        ...
        ms.plot()

    or with an asynchronous client:

    .. code-block:: python

        client = await Client(asynchronous=True)
        ms = MemorySampler()

        async with ms.sample("run 1"):
            <run first workflow>
        async with ms.sample("run 2"):
            <run second workflow>
        ...
        ms.plot()
    """

    # label -> list of (timestamp in seconds, nbytes) pairs, exactly as returned by
    # the scheduler's ``memory_sampler_stop`` handler.
    samples: dict[str, list[tuple[float, int]]]

    def __init__(self):
        self.samples = {}

    def sample(
        self,
        label: str | None = None,
        *,
        client: Client | None = None,
        measure: str = "process",
        interval: float = 0.5,
    ) -> Any:
        """Context manager that records memory usage in the cluster.

        This is synchronous if the client is synchronous and
        asynchronous if the client is asynchronous.

        The samples are recorded in ``self.samples[<label>]``.

        Parameters
        ----------
        label: str, optional
            Tag to record the samples under in the self.samples dict.
            Default: use the random key generated by the scheduler for this
            sampling session (an empty string is treated the same as None)
        client: Client, optional
            client used to connect to the scheduler.
            Default: use the global client
        measure: str, optional
            One of the measures from :class:`distributed.scheduler.MemoryState`.
            Default: sample process memory
        interval: float, optional
            sampling interval, in seconds.
            Default: 0.5

        Returns
        -------
        A context manager (for use with ``with``) if the client is synchronous, or
        an asynchronous context manager (for use with ``async with``) if the client
        is asynchronous.
        """
        if client is None:
            # Imported lazily to avoid a circular import with distributed.client
            from distributed.client import get_client

            client = get_client()
        if client.asynchronous:
            return self._sample_async(label, client, measure, interval)
        return self._sample_sync(label, client, measure, interval)

    @contextmanager
    def _sample_sync(
        self, label: str | None, client: Client, measure: str, interval: float
    ) -> Iterator[None]:
        """Synchronous implementation of :meth:`sample`."""
        key = client.sync(
            client.scheduler.memory_sampler_start,
            client=client.id,
            measure=measure,
            interval=interval,
        )
        try:
            yield
        finally:
            # Stop sampling and store the results even if the body raised
            samples = client.sync(client.scheduler.memory_sampler_stop, key=key)
            self.samples[label or key] = samples

    @asynccontextmanager
    async def _sample_async(
        self, label: str | None, client: Client, measure: str, interval: float
    ) -> AsyncIterator[None]:
        """Asynchronous implementation of :meth:`sample`."""
        key = await client.scheduler.memory_sampler_start(
            client=client.id, measure=measure, interval=interval
        )
        try:
            yield
        finally:
            # Stop sampling and store the results even if the body raised
            samples = await client.scheduler.memory_sampler_stop(key=key)
            self.samples[label or key] = samples

    def to_pandas(self, *, align: bool = False) -> pd.DataFrame:
        """Return the data series as a :class:`pandas.DataFrame`.

        The result has one column per label and is indexed by sample time.
        Where a series has no sample at a given index value but has samples
        both before and after it, the previous value is carried forward.
        Outside a series' own time range the column is NaN.

        Parameters
        ----------
        align : bool, optional
            If True, change the absolute timestamps into time deltas from the first
            sample of each series, so that different series can be visualized side by
            side. If False (the default), use absolute timestamps.
        """
        import pandas as pd

        ss = {}
        for label, s_list in self.samples.items():
            assert s_list  # There's always at least one sample
            s = pd.DataFrame(s_list).set_index(0)[1]
            s.index = pd.to_datetime(s.index, unit="s")
            s.name = label
            if align:
                # convert datetime to timedelta from the first sample
                s.index -= s.index[0]
            # Keep only the first sample for any repeated timestamp
            ss[label] = s[~s.index.duplicated()]  # type: ignore[attr-defined]

        df = pd.DataFrame(ss)

        if len(ss) > 1:
            # Forward-fill NaNs in the middle of a series created either by overlapping
            # sampling time range or by align=True. Do not ffill series beyond their
            # last sample: the mask is False exactly where nothing follows.
            df = df.ffill().where(~pd.isna(df.bfill()))

        return df

    def plot(self, *, align: bool = False, **kwargs: Any) -> Any:
        """Plot the data series collected so far.

        Values are resampled to a 1-second grid (1 ms if that collapses to a
        single point) and converted to GiB.

        Parameters
        ----------
        align : bool (optional)
            See :meth:`~distributed.diagnostics.MemorySampler.to_pandas`
        kwargs
            Passed verbatim to :meth:`pandas.DataFrame.plot`

        Returns
        -------
        Output of :meth:`pandas.DataFrame.plot`
        """
        df = self.to_pandas(align=align)
        resampled = df.resample("1s").nearest() / 2**30
        # If resampling collapses data onto one point, we'll run into
        # https://stackoverflow.com/questions/58322744/matplotlib-userwarning-attempting-to-set-identical-left-right-737342-0
        # This should only happen in tests since users typically sample for more
        # than a second
        if len(resampled) == 1:
            resampled = df.resample("1ms").nearest() / 2**30

        return resampled.plot(
            xlabel="time",
            ylabel="Cluster memory (GiB)",
            **kwargs,
        )
class MemorySamplerExtension:
    """Scheduler extension - server side of MemorySampler

    Registers the ``memory_sampler_start`` and ``memory_sampler_stop`` handlers
    on the scheduler and keeps the in-progress samples for each sampling session.
    """

    scheduler: Scheduler
    # sampling key (str(uuid4())) -> list of (timestamp in seconds, nbytes) pairs
    samples: dict[str, list[tuple[float, int]]]

    def __init__(self, scheduler: Scheduler):
        self.scheduler = scheduler
        self.scheduler.handlers["memory_sampler_start"] = self.start
        self.scheduler.handlers["memory_sampler_stop"] = self.stop
        self.samples = {}

    def start(self, client: str, measure: str, interval: float) -> str:
        """Start periodically sampling memory.

        Parameters
        ----------
        client: str
            Id of the requesting client. Sampling stops on its own, and the
            samples are discarded, once this id is no longer in
            ``scheduler.clients`` (presumably because the client disconnected).
        measure: str
            Name of a public int attribute of ``scheduler.memory``.
        interval: float
            Sampling interval, in seconds.

        Returns
        -------
        The key identifying this sampling session; pass it to :meth:`stop`.

        Raises
        ------
        ValueError
            If ``measure`` is private or does not name an int attribute of
            ``scheduler.memory``.
        """
        # Explicit checks instead of ``assert``: this handler is invoked remotely
        # and asserts are stripped under ``python -O``, which would let a caller
        # read arbitrary (including private) attributes.
        if measure.startswith("_") or not isinstance(
            getattr(self.scheduler.memory, measure, None), int
        ):
            raise ValueError(f"Invalid memory measure: {measure!r}")

        key = str(uuid.uuid4())
        self.samples[key] = []

        def sample():
            if client in self.scheduler.clients:
                ts = datetime.now().timestamp()
                nbytes = getattr(self.scheduler.memory, measure)
                self.samples[key].append((ts, nbytes))
            else:
                # The requesting client is gone; nobody will collect these samples
                self.stop(key)

        # PeriodicCallback takes its period in milliseconds
        pc = PeriodicCallback(sample, interval * 1000)
        self.scheduler.periodic_callbacks["MemorySampler-" + key] = pc
        pc.start()

        # Immediately collect the first sample; this also ensures there's always at
        # least one sample
        sample()
        return key

    def stop(self, key: str) -> list[tuple[float, int]]:
        """Stop sampling and return the samples.

        Raises ``KeyError`` if ``key`` is unknown or its samples were already
        returned or discarded.
        """
        pc = self.scheduler.periodic_callbacks.pop("MemorySampler-" + key, None)
        if pc is not None:  # Race condition with scheduler shutdown
            pc.stop()
        return self.samples.pop(key)