118 lines
3.8 KiB
Python
118 lines
3.8 KiB
Python
# python
|
|
"""
|
|
plot_metrics.py

Plot entity counts over time from a metrics CSV and save the figure as an image.

Usage examples:
    python plot_metrics.py --csv metrics_combined.csv
    python plot_metrics.py --csv metrics_combined.csv --time-col tick --out myplot.png
    python plot_metrics.py --csv metrics_combined.csv --cols entity_counts_cells,entity_counts_food
|
|
"""
|
|
|
|
import argparse
|
|
from pathlib import Path
|
|
import sys
|
|
|
|
import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
# Candidate names tried, in order, when the time column has to be auto-detected.
COMMON_TIME_COLS = [
    "tick",
    "time",
    "step",
    "tick_number",
    "t",
]

# Columns plotted when `--cols` is not supplied on the command line.
DEFAULT_PLOT_COLS = [
    "entity_counts_cells",
    "entity_counts_food",
]
|
|
|
|
|
|
def find_column(df: pd.DataFrame, candidates):
    """Return the label of the first column in *df* that matches a candidate.

    Matching is case-insensitive and runs in two passes:

    1. exact name match, trying every candidate in order;
    2. substring match (candidate contained in the column name), again in
       candidate order.

    The exact pass covers all candidates before any substring matching so
    that a loose substring hit for an early candidate (e.g. ``"time"`` inside
    ``"timestamp_label"``) cannot shadow an exact match for a later one
    (e.g. ``"step"``).

    Args:
        df: DataFrame whose column labels are searched.
        candidates: Iterable of candidate column names, in priority order.

    Returns:
        The original (un-lowercased) column label, or ``None`` when no
        candidate matches any column.
    """
    # Lowercased label -> original label. str() keeps non-string labels
    # (e.g. integer headers) from raising on .lower().
    lowered = {str(col).lower(): col for col in df.columns}
    # Materialize once so a generator argument survives both passes.
    wanted = [str(cand).lower() for cand in candidates]

    # Pass 1: exact matches win over any substring match.
    for cand in wanted:
        if cand in lowered:
            return lowered[cand]

    # Pass 2: substring fallback. Note that very short candidates such as
    # "t" match almost any column name here.
    for cand in wanted:
        for key, original in lowered.items():
            if cand in key:
                return original

    return None
|
|
|
|
|
|
def main():
    """Command-line entry point: plot CSV columns against a time column.

    Parses the command line, loads the CSV, resolves the time column (given
    via ``--time-col`` or auto-detected from ``COMMON_TIME_COLS``) and the
    columns to plot (given via ``--cols`` or looked up from
    ``DEFAULT_PLOT_COLS``), coerces them to numeric, sorts by time, writes
    the figure to ``--out`` and then attempts to show it interactively.

    Any unrecoverable problem (missing or unreadable file, unknown column,
    no numeric time values, nothing to plot) is reported on stderr and the
    process exits with status 1.
    """
    parser = argparse.ArgumentParser(description="Plot entity counts over time from a metrics CSV")
    parser.add_argument("--csv", "-c", type=str, default="metrics_combined.csv", help="Path to CSV file")
    parser.add_argument("--time-col", "-t", type=str, default=None, help="Name of the time column (optional)")
    parser.add_argument("--cols", type=str, default=None, help="Comma-separated column names to plot (default: entity_counts_cells,entity_counts_food)")
    parser.add_argument("--out", "-o", type=str, default="metrics_counts_plot.png", help="Output image path")
    args = parser.parse_args()

    csv_path = Path(args.csv)
    if not csv_path.exists():
        print(f"CSV not found: {csv_path}", file=sys.stderr)
        sys.exit(1)

    # Report unreadable/empty/malformed files the same way as every other
    # input problem instead of letting a traceback escape.
    try:
        df = pd.read_csv(csv_path)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError, OSError) as exc:
        print(f"Could not read CSV `{csv_path}`: {exc}", file=sys.stderr)
        sys.exit(1)

    # Resolve the time column: an explicit name must exist verbatim,
    # otherwise fall back to auto-detection.
    if args.time_col:
        if args.time_col not in df.columns:
            print(f"Specified time column `{args.time_col}` not found in CSV columns.", file=sys.stderr)
            sys.exit(1)
        time_col = args.time_col
    else:
        time_col = find_column(df, COMMON_TIME_COLS)
        if time_col is None:
            print("Could not auto-detect a time column. Provide one with `--time-col`.", file=sys.stderr)
            sys.exit(1)

    # Resolve the columns to plot.
    if args.cols:
        cols = [c.strip() for c in args.cols.split(",") if c.strip()]
        missing = [c for c in cols if c not in df.columns]
        if missing:
            print(f"Columns not found in CSV: {missing}", file=sys.stderr)
            sys.exit(1)
    else:
        cols = []
        for want in DEFAULT_PLOT_COLS:
            found = find_column(df, [want])
            if found:
                cols.append(found)
        if not cols:
            print(f"Could not find default columns `{DEFAULT_PLOT_COLS}`. Provide `--cols` explicitly.", file=sys.stderr)
            sys.exit(1)

    # Selecting the same label twice (a repeated --cols entry, or a plot
    # column that is also the time column) would make df[label] a DataFrame
    # and break pd.to_numeric below, so keep each label once, in order.
    cols = [c for c in dict.fromkeys(cols) if c != time_col]
    if not cols:
        print("No columns left to plot besides the time column.", file=sys.stderr)
        sys.exit(1)

    # Keep only the needed columns and coerce them to numbers; values that
    # cannot be parsed become NaN.
    df = df[[time_col] + cols].copy()
    df[time_col] = pd.to_numeric(df[time_col], errors="coerce")
    for c in cols:
        df[c] = pd.to_numeric(df[c], errors="coerce")

    # Rows without a usable time value cannot be placed on the x axis.
    df = df.dropna(subset=[time_col])
    if df.empty:
        print("No numeric time values found after cleaning.", file=sys.stderr)
        sys.exit(1)

    df = df.sort_values(by=time_col)

    # Draw one line per selected column.
    plt.figure(figsize=(10, 5))
    for c in cols:
        plt.plot(df[time_col], df[c], label=c, linewidth=2)
    plt.xlabel(time_col)
    plt.ylabel("Count")
    plt.title("Entity counts over time")
    plt.grid(True, linestyle="--", alpha=0.4)
    plt.legend()
    plt.tight_layout()

    out_path = Path(args.out)
    # savefig does not create missing directories.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, dpi=150)
    print(f"Wrote plot to `{out_path}`")

    # Showing the figure is best-effort: the image is already on disk, so a
    # failure here is reported but does not fail the run.
    try:
        plt.show()
    except Exception as exc:
        print(f"Interactive display unavailable: {exc}", file=sys.stderr)
|
|
|
|
|
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()