# Files
# WQ_GUI/new/app/api/_smoke_test_train.py
#
# 202 lines
# 7.0 KiB
# Python

"""
冒烟测试 _run_train_sync: 用合成数据走通真实训练管线。
不依赖 FastAPI / xarray / dask, 只验训练 + 持久化 + 回测。
"""
import sys
import tempfile
from pathlib import Path
import numpy as np
import pandas as pd
# 绕过 main.py 触发 app 包导入(只导入 modeling 模块)
# 当前文件位于 new/app/api/_smoke_test_train.py
# app 包在 new/app/__init__.py, 故 new/ 必须在 sys.path 上
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from app.api.modeling import (
_get_model_pipeline,
_load_train_df,
_resolve_feature_start,
_run_train_sync,
_MODEL_CLASS_REGISTRY,
)
def make_synthetic_csv(n_samples: int = 200, n_features: int = 8, noise: float = 0.1, seed: int = 42) -> Path:
    """Write a synthetic training CSV and return its path.

    Column layout: ``lat, lon, Chl-a, lat2, lon2, feat_0, feat_1, ...``.
    The first three feature columns are shifted by a scaled copy of the
    target so that features and target are correlated.

    Args:
        n_samples: Number of rows to generate.
        n_features: Number of ``feat_i`` columns; the code indexes feature
            columns 0..2, so values below 3 raise ``IndexError``.
        noise: Accepted for interface compatibility; not used by the body.
        seed: Seed passed to ``np.random.default_rng``.

    Returns:
        Path to ``train.csv`` inside a freshly created temporary directory.
        The directory is not removed by this function.
    """
    generator = np.random.default_rng(seed)

    # The order of the draws below fixes the generated values for a given seed.
    latitude = generator.uniform(20, 25, n_samples)
    longitude = generator.uniform(110, 115, n_samples)
    y = generator.uniform(0, 50, n_samples)
    meta_lat = generator.uniform(0, 1, n_samples)  # metadata column
    meta_lon = generator.uniform(0, 1, n_samples)  # metadata column
    features = generator.normal(0, 1, (n_samples, n_features))

    # Tie the first three features to the target; the original author's note
    # says RF should then be able to reach R^2 > 0.5.
    shifts = ((0, y / 10), (1, y / 20), (2, -(y / 15)))
    for column, shift in shifts:
        features[:, column] += shift

    columns = {
        "lat": latitude,
        "lon": longitude,
        "Chl-a": y,
        "lat2": meta_lat,
        "lon2": meta_lon,
    }
    for i in range(n_features):
        columns[f"feat_{i}"] = features[:, i]

    csv_path = Path(tempfile.mkdtemp()) / "train.csv"
    pd.DataFrame(columns).to_csv(csv_path, index=False)
    return csv_path
def test_load_train_df():
    """Check that ``_load_train_df`` returns a frame of the generated shape.

    The synthetic CSV has 50 rows, 5 leading columns and 8 feature columns.
    """
    print("== test_load_train_df ==")
    csv_path = make_synthetic_csv(n_samples=50)
    loaded = _load_train_df(str(csv_path))

    expected_shape = (50, 5 + 8)
    assert loaded.shape == expected_shape, f"shape={loaded.shape}"

    leading_columns = list(loaded.columns[:6])
    print(f" shape={loaded.shape}, columns[:6]={leading_columns}")
    print(" PASS")
def test_resolve_feature_start_int_and_str():
    """Check that an integer index and a column name resolve to the same start.

    In the synthetic layout ``feat_0`` is the sixth column, so both
    ``5`` and ``"feat_0"`` are expected to resolve to index 5.
    """
    print("== test_resolve_feature_start (int + str) ==")
    frame = _load_train_df(str(make_synthetic_csv()))

    from_int = _resolve_feature_start(frame, 5)
    from_name = _resolve_feature_start(frame, "feat_0")

    both_five = from_int == 5 and 5 == from_name
    assert both_five, f"int={from_int}, str={from_name}"

    print(f" int(5) -> {from_int}, str('feat_0') -> {from_name}")
    print(" PASS")
def test_resolve_feature_start_str_miss():
    """Check that an unknown column name makes ``_resolve_feature_start`` raise.

    Raises:
        AssertionError: If no ``ValueError`` is raised. Previously this case
            only printed a FAIL line and then still reported PASS.
    """
    print("== test_resolve_feature_start (str 不存在 -> 抛错) ==")
    p = make_synthetic_csv()
    df = _load_train_df(str(p))
    try:
        _resolve_feature_start(df, "not_exist")
    except ValueError as e:
        print(f" 正确抛 ValueError: {e}")
    else:
        # Reaching here means the call returned normally: fail loudly so the
        # runner never prints PASS for this case.
        print(" FAIL: 应抛 ValueError")
        raise AssertionError("_resolve_feature_start did not raise ValueError for a missing column name")
    print(" PASS")
def test_get_model_pipeline_all_types():
    """Check that each listed model type yields a two-step pipeline.

    Every pipeline is expected to have exactly the steps ``scaler`` then
    ``model``, in that order.
    """
    print("== test_get_model_pipeline (5 种 model_type) ==")
    model_types = ("RF", "SVR", "LinearRegression", "KNN", "PLS")
    for model_type in model_types:
        pipeline = _get_model_pipeline(model_type, {})
        step_names = [step[0] for step in pipeline.steps]
        assert step_names == ["scaler", "model"]
    registered = list(_MODEL_CLASS_REGISTRY)
    print(f" 全部通过: {registered}")
    print(" PASS")
def test_get_model_pipeline_bad_type():
    """Check that an unsupported model type makes ``_get_model_pipeline`` raise.

    Raises:
        AssertionError: If no ``ValueError`` is raised. Previously this case
            only printed a FAIL line and then still reported PASS.
    """
    print("== test_get_model_pipeline (坏 model_type) ==")
    try:
        _get_model_pipeline("XGBoost", {})
    except ValueError as e:
        print(f" 正确抛 ValueError: {e}")
    else:
        # The call returned normally, so the expected rejection did not happen.
        print(" FAIL: 应抛 ValueError")
        raise AssertionError("_get_model_pipeline did not raise ValueError for model_type 'XGBoost'")
    print(" PASS")
def test_run_train_sync_rf_end_to_end():
    """Train an RF model on synthetic data, then reload the saved artifact.

    Verifies that the joblib file is written, that it holds both a
    ``model`` (with ``predict``) and a ``metadata`` entry, and that the
    stored ``test_r2`` equals the one returned by ``_run_train_sync``.
    """
    import time

    print("== test_run_train_sync (RF 端到端) ==")
    csv_path = make_synthetic_csv(n_samples=200)
    model_path = Path(tempfile.mkdtemp()) / "model.joblib"

    rf_params = {"n_estimators": 30, "max_depth": 6, "random_state": 42, "n_jobs": 1}
    started = time.time()
    meta = _run_train_sync(
        model_type="RF",
        target="Chl-a",
        train_data_path=str(csv_path),
        feature_start=5,
        params=rf_params,
        output_model_path=model_path,
    )
    elapsed = time.time() - started

    assert model_path.exists(), f"joblib 未落盘: {model_path}"
    size_bytes = model_path.stat().st_size
    print(f" joblib 落盘: {model_path} ({size_bytes} bytes)")
    print(f" metadata.test_r2={meta['test_r2']:.4f} test_rmse={meta['test_rmse']:.4f} test_mae={meta['test_mae']:.4f}")
    print(f" metadata.n_features={meta['n_features']} n_samples={meta['n_samples']} train_size={meta['train_size']} test_size={meta['test_size']}")
    print(f" 耗时 {elapsed:.2f}s")

    # Round-trip check: reload the persisted artifact and inspect its contents.
    import joblib

    bundle = joblib.load(model_path)
    has_both_keys = all(key in bundle for key in ("model", "metadata"))
    assert has_both_keys, f"joblib 双 key 缺失: {bundle.keys()}"
    assert hasattr(bundle["model"], "predict")
    assert bundle["metadata"]["test_r2"] == meta["test_r2"]
    print(f" joblib 加载 OK, 含 'model''metadata' 双 key")
    print(" PASS")
def test_run_train_sync_linearregression_fast():
    """Train a LinearRegression model and require a minimum test R^2.

    The assertion threshold is ``test_r2 > 0.3``; the printed hint mentions
    0.4, which is informational only.
    """
    print("== test_run_train_sync (LinearRegression 快速路径) ==")
    csv_path = make_synthetic_csv(n_samples=150)
    model_path = Path(tempfile.mkdtemp()) / "lr.joblib"

    train_kwargs = {
        "model_type": "LinearRegression",
        "target": "Chl-a",
        "train_data_path": str(csv_path),
        "feature_start": 5,
        "params": {},
        "output_model_path": model_path,
    }
    meta = _run_train_sync(**train_kwargs)

    r2 = meta["test_r2"]
    print(f" test_r2={r2:.4f} (LR 学到线性, R² 应 >= 0.4)")
    assert r2 > 0.3, f"LR test_r2={r2} 太低, 数据生成可能有问题"
    print(" PASS")
def test_run_train_sync_bad_csv():
    """Check that a non-existent CSV path makes ``_run_train_sync`` raise.

    Either ``FileNotFoundError`` or ``ValueError`` is accepted.

    Raises:
        AssertionError: If the call returns without raising. Previously this
            case only printed a FAIL line and then still reported PASS.
    """
    print("== test_run_train_sync (CSV 不存在) ==")
    try:
        _run_train_sync("RF", "Chl-a", "/no/such/path.csv", 5, {}, Path("/tmp/x.joblib"))
    except (FileNotFoundError, ValueError) as e:
        print(f" 正确抛 {type(e).__name__}: {e}")
    else:
        # Training on a missing file must not succeed; fail the smoke test.
        print(" FAIL: 应抛异常")
        raise AssertionError("_run_train_sync did not raise for a non-existent CSV path")
    print(" PASS")
def test_run_train_sync_bad_target():
    """Check that an unknown target column makes ``_run_train_sync`` raise.

    Raises:
        AssertionError: If no ``ValueError`` is raised. Previously this case
            only printed a FAIL line and then still reported PASS.
    """
    print("== test_run_train_sync (target 列不存在) ==")
    p = make_synthetic_csv()
    try:
        _run_train_sync("RF", "NopeTarget", str(p), 5, {}, Path("/tmp/x.joblib"))
    except ValueError as e:
        print(f" 正确抛 ValueError: {e}")
    else:
        # The target column is absent from the CSV, so success is a failure here.
        print(" FAIL: 应抛 ValueError")
        raise AssertionError("_run_train_sync did not raise ValueError for a missing target column")
    print(" PASS")
def test_run_train_sync_str_feature_start():
    """Train with ``feature_start`` given as a column name instead of an index.

    Expects the returned metadata to echo ``"feat_0"`` as ``feature_start``,
    report 8 features, and list ``feat_0`` as the first feature column.
    """
    print("== test_run_train_sync (feature_start 用列名) ==")
    csv_path = make_synthetic_csv()
    model_path = Path(tempfile.mkdtemp()) / "str_fs.joblib"

    meta = _run_train_sync(
        "RF",
        "Chl-a",
        str(csv_path),
        "feat_0",
        {"n_estimators": 10},
        model_path,
    )

    assert meta["feature_start"] == "feat_0"
    assert meta["n_features"] == 8
    first_feature = meta["feature_columns"][0]
    assert first_feature == "feat_0"

    print(f" 列名 'feat_0' 解析正确, n_features={meta['n_features']}")
    print(" PASS")
if __name__ == "__main__":
    # Run every smoke test in a fixed order; any exception aborts the run
    # before the final summary line is printed.
    smoke_tests = (
        test_load_train_df,
        test_resolve_feature_start_int_and_str,
        test_resolve_feature_start_str_miss,
        test_get_model_pipeline_all_types,
        test_get_model_pipeline_bad_type,
        test_run_train_sync_rf_end_to_end,
        test_run_train_sync_linearregression_fast,
        test_run_train_sync_bad_csv,
        test_run_train_sync_bad_target,
        test_run_train_sync_str_feature_start,
    )
    for smoke_test in smoke_tests:
        smoke_test()
    print("\n>>> ALL SMOKE TESTS PASSED")