#!/usr/bin/env python3
"""
ceph_perf_histogram -- render `ceph daemon osd.X perf histogram dump` JSON
as a human-readable size (and optionally latency) distribution.

Examples:
    ceph daemon osd.0 perf histogram dump | ceph_perf_histogram
    ceph_perf_histogram --latency < dump.json
    ceph_perf_histogram before.json after.json          # rate over interval

The default output is a per-histogram size distribution with a marker on the
bucket boundary at 32 KiB (the default bluestore_prefer_deferred_size_hdd --
writes at or below it take the deferred/SSD-WAL path; writes above it hit
the slow device before ack).
"""

import argparse
import json
import sys


# (perf-dump histogram key, human-readable label) pairs, in render order.
# These are the client-op in/out-bytes histograms exposed by `ceph daemon
# osd.X perf histogram dump`.
HISTS = [
    ("op_w_latency_in_bytes_histogram",  "Client writes"),
    ("op_rw_latency_in_bytes_histogram", "Client read-modify-write"),
    ("op_r_latency_out_bytes_histogram", "Client reads"),
]

# Default bluestore_prefer_deferred_size_hdd (32 KiB); see module docstring
# for why this boundary is marked on write-size distributions.
DEFERRED_HDD = 32 * 1024


def _fmt_scaled(v, suf):
    """Format scaled value *v* with unit suffix *suf*, dropping '.0' on whole numbers."""
    return f"{int(v)}{suf}" if v == int(v) else f"{v:.1f}{suf}"


def fmt_bytes(n):
    """Render a byte count as a short human string ('512B', '4K', '1.5M'); None -> 'inf'."""
    if n is None:
        return "inf"
    if n < 1024:
        return f"{n}B"
    for exp, suf in ((4, "T"), (3, "G"), (2, "M"), (1, "K")):
        unit = 1024 ** exp
        if n >= unit:
            return _fmt_scaled(n / unit, suf)
    return f"{n}B"


def make_fmt_time(ns_per_unit):
    """Build a time formatter for axis values in `ns_per_unit`-nanosecond units.

    The axis JSON often labels the latency axis 'Latency (usec)', but recent
    Ceph stores the values in nanoseconds (quant_size 100000 ns = 100 us), so
    we detect from quant_size and convert.
    """
    scales = ((1_000_000_000, "s"), (1_000_000, "ms"), (1000, "us"))

    def fmt(n):
        if n is None:
            return "inf"
        ns = n * ns_per_unit
        for divisor, suffix in scales:
            if ns >= divisor:
                return _fmt_scaled(ns / divisor, suffix)
        return f"{int(ns)}ns"

    return fmt


def detect_time_scale(axis):
    """Return ns-per-unit for the latency axis. 1 = values are ns, 1000 = values are us."""
    quant = axis.get("quant_size", 1)
    # Heuristic: Ceph's internal clock granularity is 100us. If quant_size
    # is on that scale (100), values are us. If it's 1000x bigger (100000),
    # values are ns despite any '(usec)' label in the axis name.
    if quant >= 1000:
        return 1
    return 1000


def range_label(r, fmt):
    """Format one histogram bucket `r` ({'min': lo, 'max': hi}) using `fmt`.

    Buckets are integer-inclusive [lo, hi]; format as half-open [lo, hi+1)
    so the upper bound lands on the next clean power-of-2 boundary. A missing
    bound means the bucket is open on that side.
    """
    lo, hi = r.get("min"), r.get("max")
    if lo is None and hi is None:
        # Bug fix: a bucket with neither bound (catch-all) used to fall
        # through to the lo == hi branch and print fmt(None) == "inf",
        # which misreads as an upper-tail bucket.
        return "all"
    if lo is None:
        return f"<{fmt(hi + 1)}"
    if hi is None:
        return f">={fmt(lo)}"
    if lo == hi:
        return fmt(lo)
    return f"[{fmt(lo)},{fmt(hi + 1)})"


def find_axis(axes, keyword):
    """Return (index, axis) of the first axis whose lowercased name contains `keyword`,
    or (None, None) when no axis matches."""
    hits = (
        (idx, axis)
        for idx, axis in enumerate(axes)
        if keyword in axis["name"].lower()
    )
    return next(hits, (None, None))


def collapse(values, keep_axis):
    """Sum a 2-D count matrix down to 1-D, keeping `keep_axis` (0 = rows, else columns)."""
    if keep_axis == 0:
        return list(map(sum, values))
    if not values:
        return []
    # Transpose via zip(*...) and sum each column.
    return [sum(column) for column in zip(*values)]


def diff_values(after, before):
    """Element-wise `after - before` for two equally-shaped 2-D count matrices."""
    out = []
    for row_after, row_before in zip(after, before):
        out.append([a - b for a, b in zip(row_after, row_before)])
    return out


def print_dist(title, ranges, counts, fmt, threshold=None):
    """Print one histogram: a header line, then count/percent/bar rows per bucket.

    Zero-count buckets are omitted. When `threshold` is given, the bucket
    whose (half-open) upper boundary equals it gets a marker, e.g. the
    32 KiB deferred-write cutoff.
    """
    total = sum(counts)
    print(f"\n{title}  (total: {total:,})")
    if total <= 0:
        print("  (no samples)")
        return
    label_width = max(len(range_label(r, fmt)) for r in ranges)
    tallest = max(counts) or 1
    for bucket, count in zip(ranges, counts):
        if count == 0:
            continue
        share = 100.0 * count / total
        bar = "#" * max(0, int(40 * count / tallest))
        note = ""
        if threshold is not None:
            upper = bucket.get("max")
            if upper is not None and upper + 1 == threshold:
                note = f"  <-- {fmt(threshold)} boundary"
        print(f"  {range_label(bucket, fmt):>{label_width}}  {count:>10,}  {share:5.1f}%  {bar}{note}")


def render(data, show_latency):
    """Print the size (and, when requested, latency) distribution of each known histogram.

    `data` is either the full dump ({'osd': {...}}) or the inner mapping itself.
    """
    osd = data.get("osd", data)
    for key, label in HISTS:
        hist = osd.get(key)
        if not hist:
            continue
        axes = hist["axes"]
        values = hist["values"]
        size_idx, size_axis = find_axis(axes, "size")
        if size_idx is None:
            continue
        # Only write-type ops are subject to the deferred-write size cutoff.
        is_write = "_w_" in key or "_rw_" in key
        print_dist(
            f"{label} -- size distribution  [{key}]",
            size_axis["ranges"],
            collapse(values, keep_axis=size_idx),
            fmt_bytes,
            threshold=DEFERRED_HDD if is_write else None,
        )
        if not show_latency:
            continue
        lat_idx, lat_axis = find_axis(axes, "latency")
        if lat_idx is None:
            continue
        print_dist(
            f"{label} -- latency distribution",
            lat_axis["ranges"],
            collapse(values, keep_axis=lat_idx),
            make_fmt_time(detect_time_scale(lat_axis)),
        )


def load(path_or_stdin):
    """Parse JSON from the given file path, or from stdin when given '-'."""
    if path_or_stdin != "-":
        with open(path_or_stdin) as f:
            return json.load(f)
    return json.load(sys.stdin)


def main():
    """CLI entry point: render one snapshot, or the rate over two snapshots.

    With zero positionals, reads a dump from stdin. With one, reads that
    file. With two, subtracts the first snapshot's counts from the second
    and renders only the activity from the interval between them.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("before", nargs="?", help="earlier snapshot (for diff mode)")
    parser.add_argument("after", nargs="?", help="later snapshot (for diff mode)")
    parser.add_argument("--latency", action="store_true", help="also show latency distribution")
    opts = parser.parse_args()

    if opts.before and opts.after:
        earlier = load(opts.before)
        later = load(opts.after)
        earlier_osd = earlier.get("osd", earlier)
        later_osd = later.get("osd", later)
        # Diff only histograms present in both snapshots; keep the newer
        # snapshot's axes and subtract counts bucket-by-bucket.
        deltas = {}
        for key, _ in HISTS:
            if key in later_osd and key in earlier_osd:
                deltas[key] = {
                    "axes": later_osd[key]["axes"],
                    "values": diff_values(later_osd[key]["values"], earlier_osd[key]["values"]),
                }
        print(f"# diff: {opts.before} -> {opts.after}")
        render({"osd": deltas}, opts.latency)
    elif opts.before:
        render(load(opts.before), opts.latency)
    else:
        render(load("-"), opts.latency)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
