Skip to content

Commit 70b1845

Browse files
committed
Add testing and linting workflows
1 parent 5e34e4d commit 70b1845

16 files changed

Lines changed: 486 additions & 165 deletions
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
name: Pytest and Pylint
2+
3+
on:
4+
push:
5+
branches: [master]
6+
pull_request:
7+
branches: [master]
8+
9+
jobs:
10+
test-and-lint:
11+
runs-on: ubuntu-latest
12+
strategy:
13+
matrix:
14+
python-version: ["3.12"]
15+
steps:
16+
- name: Checkout repository
17+
uses: actions/checkout@v6
18+
19+
- name: Set up Python ${{ matrix.python-version }}
20+
uses: actions/setup-python@v6
21+
with:
22+
python-version: ${{ matrix.python-version }}
23+
24+
- name: Cache pip
25+
uses: actions/cache@v5
26+
with:
27+
path: ~/.cache/pip
28+
key: ${{ runner.os }}-pip-${{ hashFiles('**/pyproject.toml') }}
29+
restore-keys: ${{ runner.os }}-pip-
30+
31+
- name: Install dependencies
32+
run: |
33+
python -m pip install --upgrade pip
34+
pip install .
35+
pip install pytest pylint
36+
37+
- name: Run pytest
38+
run: python -m pytest -q
39+
40+
- name: Run pylint
41+
run: pylint $(git ls-files '*.py')

main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717

1818
def main() -> None:
19+
"""Demo script: create plant, design controller, simulate and save plot."""
1920
plant = InvertedPendulum()
2021
plant = add_delay(plant, T_input=0.5, T_output=0.5)
2122
plant = discretize(plant, dt=0.01)

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,7 @@ dependencies = ["control>=0.9", "numpy>=1.24", "matplotlib>=3.7"]
1414

1515
[tool.setuptools.packages.find]
1616
where = ["src"]
17+
18+
[tool.pytest.ini_options]
19+
testpaths = ["tests"]
20+

src/control_toolbox/controller.py

Lines changed: 57 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Nbar — precompensator (n_inputs x n_tracked)
1616
None if closed-loop DC gain is singular (regulates to zero only)
1717
"""
18+
# pylint: disable=invalid-name
1819

1920
import warnings
2021
from dataclasses import dataclass
@@ -40,40 +41,18 @@ class ControllerResult:
4041
tracked: dict[str, float]
4142

4243

43-
def build_controller(
44-
plant: StateSpace,
45-
track: dict[str | int, float | None] | None = None,
46-
p: float = 1.0,
47-
Q: np.ndarray | None = None,
48-
R: np.ndarray | None = None,
49-
) -> ControllerResult:
50-
"""LQR controller for a StateSpace plant. Supports both continuous-
51-
and discrete-time systems (uses ``lqr`` or ``dlqr`` depending on
52-
``plant.dt``). The returned closed-loop model preserves the sampling
53-
period so that downstream simulation routines behave correctly.
44+
def _parse_tracking(
45+
plant: StateSpace, track: dict | None
46+
) -> tuple[list[str], dict[str, float], np.ndarray]:
47+
"""Interpret the ``track`` argument.
5448
55-
Args:
56-
plant: Any StateSpace instance from systems.py.
57-
track: Dict of {output_name_or_index: target_value | None}.
58-
None value = don't care. Omitted outputs = don't care.
59-
track=None regulates all outputs to zero.
60-
p: Output-error weight (Q = p * C_track' @ C_track). Ignored if Q given.
61-
Q: State-cost matrix (n_states x n_states). Overrides p.
62-
R: Input-cost matrix (n_inputs x n_inputs). Defaults to identity.
63-
64-
Returns:
65-
ControllerResult containing ``K``, ``Nbar`` and ``sys_cl`` (plus
66-
the ``tracked`` dictionary used to compute the controller). If
67-
``plant`` was discrete-time, the returned ``sys_cl`` has its
68-
``dt`` field set accordingly and ``K`` is computed via ``dlqr``. The
69-
precompensator ``Nbar`` is computed with the appropriate steady-
70-
state formula for continuous or discrete dynamics.
49+
Returns ``(labels, tracked, C_track)`` where ``labels`` is the list of
50+
plant output names, ``tracked`` maps those names to non-None reference
51+
values, and ``C_track`` is the corresponding rows of the plant output
52+
matrix. This isolates a bulky dictionary comprehension from
53+
``build_controller``.
7154
"""
72-
A = np.array(plant.A, dtype=float)
73-
B = np.array(plant.B, dtype=float)
7455
C = np.array(plant.C, dtype=float)
75-
n_in = B.shape[1]
76-
7756
labels = (
7857
list(plant.output_labels)
7958
if plant.output_labels
@@ -87,47 +66,67 @@ def build_controller(
8766
}
8867
idx = [labels.index(lbl) for lbl in tracked]
8968
C_track = C[idx, :]
69+
return labels, tracked, C_track
9070

91-
# LQR / DLQR depending on plant type
92-
Q = p * C_track.T @ C_track if Q is None else np.array(Q, dtype=float)
93-
R = np.eye(n_in) if R is None else np.array(R, dtype=float)
94-
is_discrete = plant.dt not in (0, None)
95-
if is_discrete:
96-
K, _, _ = dlqr(A, B, Q, R)
97-
else:
98-
K, _, _ = lqr(A, B, Q, R)
9971

100-
# Nbar: solves C_track @ DC-gain(A_cl,B) @ Nbar = -I
101-
# For continuous DC gain = -C_track @ inv(A_cl) @ B
102-
# For discrete DC gain = C_track @ inv(I - A_cl) @ B
72+
def _compute_Nbar(A, B, C_track, K, is_discrete):
73+
"""Calculate the precompensator Nbar from closed-loop data."""
10374
A_cl = A - B @ K
10475
if is_discrete:
105-
# discrete-time steady-state: x = inv(I - A_cl) B Nbar r
106-
# require C_track @ inv(I - A_cl) B Nbar = I
107-
# use pseudo-inverse in case (I-A_cl) is singular
10876
M = C_track @ np.linalg.pinv(np.eye(A_cl.shape[0]) - A_cl) @ B
10977
else:
110-
# continuous-time steady-state: x = -inv(A_cl) B Nbar r
111-
# require -C_track @ inv(A_cl) B Nbar = I
112-
# use pseudo-inverse in case A_cl is singular
11378
M = -C_track @ np.linalg.pinv(A_cl) @ B
114-
Nbar = np.linalg.pinv(M) if np.all(np.isfinite(M)) else None
115-
if Nbar is not None and len(idx) != n_in:
79+
return np.linalg.pinv(M) if np.all(np.isfinite(M)) else None
80+
81+
82+
def build_controller(
83+
plant: StateSpace,
84+
track: dict[str | int, float | None] | None = None,
85+
p: float = 1.0,
86+
Q: np.ndarray | None = None,
87+
R: np.ndarray | None = None,
88+
) -> ControllerResult:
89+
"""LQR controller for a StateSpace plant. Supports both continuous-
90+
and discrete-time systems (uses ``lqr`` or ``dlqr`` depending on
91+
``plant.dt``). The returned closed-loop model preserves the sampling
92+
period so that downstream simulation routines behave correctly.
93+
"""
94+
A = np.array(plant.A, dtype=float)
95+
B = np.array(plant.B, dtype=float)
96+
97+
labels, tracked, C_track = _parse_tracking(plant, track)
98+
99+
# LQR / DLQR depending on plant type; inline Q and R defaults
100+
if plant.dt not in (0, None):
101+
K, _, _ = dlqr(
102+
A,
103+
B,
104+
p * C_track.T @ C_track if Q is None else np.array(Q, dtype=float),
105+
np.eye(B.shape[1]) if R is None else np.array(R, dtype=float),
106+
)
107+
discrete = True
108+
else:
109+
K, _, _ = lqr(
110+
A,
111+
B,
112+
p * C_track.T @ C_track if Q is None else np.array(Q, dtype=float),
113+
np.eye(B.shape[1]) if R is None else np.array(R, dtype=float),
114+
)
115+
discrete = False
116+
117+
Nbar = _compute_Nbar(A, B, C_track, K, discrete)
118+
if Nbar is not None and len(tracked) != B.shape[1]:
116119
warnings.warn(
117120
"n_tracked != n_inputs: Nbar is a pseudoinverse; "
118121
"exact tracking not guaranteed.",
119122
stacklevel=2,
120123
)
121124

122-
n_tracked = len(idx)
123-
B_cl = B @ (Nbar if Nbar is not None else np.zeros((n_in, n_tracked)))
124-
D_cl = np.zeros((C.shape[0], n_tracked))
125-
# preserve sampling period if discrete
126125
sys_cl = ss(
127-
A_cl,
128-
B_cl,
129-
C,
130-
D_cl,
126+
A - B @ K,
127+
B @ (Nbar if Nbar is not None else np.zeros((B.shape[1], len(tracked)))),
128+
np.array(plant.C, dtype=float),
129+
np.zeros((np.array(plant.C, dtype=float).shape[0], len(tracked))),
131130
inputs=[f"r_{lbl}" for lbl in tracked],
132131
outputs=labels,
133132
dt=plant.dt,

src/control_toolbox/metrics.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def compute_metrics(resp: SimulationResult, settling_band: float = 0.02) -> Metr
6363
- If the reference is zero the percentage metrics are computed
6464
against the absolute deviation instead of a percentage.
6565
"""
66+
# pylint: disable=too-many-locals
6667
overshoot_d = {}
6768
settling_time_d = {}
6869
steady_state_error_d = {}

src/control_toolbox/plot.py

Lines changed: 62 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,65 @@
1919
_INPUT_COLOURS = ["#16a34a", "#65a30d", "#0d9488", "#ca8a04"]
2020

2121

22+
def _annotate_metrics(ax, t, y, ref, metrics, colour): # pylint: disable=too-many-arguments,too-many-positional-arguments
23+
"""Add settling band, time and overshoot annotations to a single axis."""
24+
band = 0.02 * (abs(ref) if abs(ref) > 1e-12 else 1.0)
25+
ax.axhspan(ref - band, ref + band, alpha=0.08, color=colour, label="±2% band")
26+
27+
st = metrics.settling_time.get(ref)
28+
if st is not None:
29+
ax.axvline(st, color=colour, lw=1, ls=":", alpha=0.7)
30+
ax.annotate(
31+
f" t_s={st:.2f}s",
32+
xy=(st, ref),
33+
fontsize=8,
34+
color=colour,
35+
va="center",
36+
)
37+
38+
os_pct = metrics.overshoot.get(ref, 0.0)
39+
if os_pct > 0.1:
40+
peak_idx = np.argmax(y) if ref >= 0 else np.argmin(y)
41+
ax.annotate(
42+
f"OS={os_pct:.1f}%",
43+
xy=(t[peak_idx], y[peak_idx]),
44+
xytext=(0, 8),
45+
textcoords="offset points",
46+
fontsize=8,
47+
color=colour,
48+
ha="center",
49+
)
50+
51+
52+
def _plot_output(ax, resp, i, lbl, metrics):
53+
"""Draw a single output subplot including reference/metrics."""
54+
colour = _OUTPUT_COLOURS[i % len(_OUTPUT_COLOURS)]
55+
y = resp.y[i]
56+
ax.plot(resp.t, y, color=colour, lw=2, label=lbl)
57+
if lbl in resp.tracked:
58+
ref = resp.tracked[lbl]
59+
ax.axhline(
60+
ref, color=colour, lw=1, ls="--", alpha=0.5, label=f"reference ({ref})"
61+
)
62+
if metrics is not None:
63+
_annotate_metrics(ax, resp.t, y, ref, metrics, colour)
64+
ax.set_ylabel(lbl)
65+
ax.legend(loc="upper right", fontsize=8)
66+
ax.grid(True, alpha=0.3)
67+
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.3g"))
68+
69+
70+
def _plot_input(ax, resp, j, lbl):
71+
"""Draw a single input subplot."""
72+
colour = _INPUT_COLOURS[j % len(_INPUT_COLOURS)]
73+
ax.plot(resp.t, resp.u[j], color=colour, lw=2, label=f"u: {lbl}")
74+
ax.axhline(0, color="k", lw=0.5, alpha=0.3)
75+
ax.set_ylabel(f"u ({lbl})")
76+
ax.legend(loc="upper right", fontsize=8)
77+
ax.grid(True, alpha=0.3)
78+
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.3g"))
79+
80+
2281
def plot_response(
2382
resp: SimulationResult,
2483
metrics: Metrics | None = None,
@@ -58,65 +117,11 @@ def plot_response(
58117
title = "Closed-loop response: " + ", ".join(parts)
59118
fig.suptitle(title, fontsize=12)
60119

61-
# --- Output subplots -----------------------------------------------------
62120
for i, lbl in enumerate(resp.output_labels):
63-
ax = axes[i]
64-
colour = _OUTPUT_COLOURS[i % len(_OUTPUT_COLOURS)]
65-
y = resp.y[i]
66-
67-
ax.plot(resp.t, y, color=colour, lw=2, label=lbl)
68-
69-
if lbl in resp.tracked:
70-
ref = resp.tracked[lbl]
71-
ax.axhline(
72-
ref, color=colour, lw=1, ls="--", alpha=0.5, label=f"reference ({ref})"
73-
)
74-
75-
if metrics is not None:
76-
band = 0.02 * (abs(ref) if abs(ref) > 1e-12 else 1.0)
77-
ax.axhspan(
78-
ref - band, ref + band, alpha=0.08, color=colour, label="±2% band"
79-
)
80-
81-
st = metrics.settling_time.get(lbl)
82-
if st is not None:
83-
ax.axvline(st, color=colour, lw=1, ls=":", alpha=0.7)
84-
ax.annotate(
85-
f" t_s={st:.2f}s",
86-
xy=(st, ref),
87-
fontsize=8,
88-
color=colour,
89-
va="center",
90-
)
91-
92-
os_pct = metrics.overshoot.get(lbl, 0.0)
93-
if os_pct > 0.1:
94-
peak_idx = np.argmax(y) if ref >= 0 else np.argmin(y)
95-
ax.annotate(
96-
f"OS={os_pct:.1f}%",
97-
xy=(resp.t[peak_idx], y[peak_idx]),
98-
xytext=(0, 8),
99-
textcoords="offset points",
100-
fontsize=8,
101-
color=colour,
102-
ha="center",
103-
)
104-
105-
ax.set_ylabel(lbl)
106-
ax.legend(loc="upper right", fontsize=8)
107-
ax.grid(True, alpha=0.3)
108-
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.3g"))
109-
110-
# --- Input subplots ------------------------------------------------------
121+
_plot_output(axes[i], resp, i, lbl, metrics)
122+
111123
for j, lbl in enumerate(resp.input_labels):
112-
ax = axes[n_out + j]
113-
colour = _INPUT_COLOURS[j % len(_INPUT_COLOURS)]
114-
ax.plot(resp.t, resp.u[j], color=colour, lw=2, label=f"u: {lbl}")
115-
ax.axhline(0, color="k", lw=0.5, alpha=0.3)
116-
ax.set_ylabel(f"u ({lbl})")
117-
ax.legend(loc="upper right", fontsize=8)
118-
ax.grid(True, alpha=0.3)
119-
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.3g"))
124+
_plot_input(axes[n_out + j], resp, j, lbl)
120125

121126
axes[-1].set_xlabel("Time (s)")
122127
plt.tight_layout()

src/control_toolbox/preprocess.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ctrl = build_controller(plant, ...)
1111
K, Nbar, sys_cl = ctrl.K, ctrl.Nbar, ctrl.sys_cl
1212
"""
13+
# pylint: disable=invalid-name
1314

1415
import numpy as np
1516
from control import StateSpace, c2d, pade, series, ss, tf

0 commit comments

Comments
 (0)