202 lines
7.0 KiB
Python
202 lines
7.0 KiB
Python
"""
|
|
冒烟测试 _run_train_sync: 用合成数据走通真实训练管线。
|
|
不依赖 FastAPI / xarray / dask, 只验训练 + 持久化 + 回测。
|
|
"""
|
|
import sys
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
# Bypass main.py (which would trigger the full `app` package import) and import
# only the modeling module.
# This file lives at new/app/api/_smoke_test_train.py and the `app` package is
# new/app/__init__.py, so new/ must be on sys.path.
# resolve() makes the computed root absolute, so it is correct regardless of
# the current working directory or whether __file__ was given as a relative
# path (e.g. `python _smoke_test_train.py` run from inside api/, where an
# unresolved path would collapse to ".").
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
|
|
|
|
from app.api.modeling import (
|
|
_get_model_pipeline,
|
|
_load_train_df,
|
|
_resolve_feature_start,
|
|
_run_train_sync,
|
|
_MODEL_CLASS_REGISTRY,
|
|
)
|
|
|
|
|
|
def make_synthetic_csv(n_samples: int = 200, n_features: int = 8, noise: float = 0.1, seed: int = 42) -> Path:
    """Write a synthetic training CSV and return its path.

    Column layout is ``[lat, lon, Chl-a, lat2, lon2, feat_0, feat_1, ...]``:
    the target ``Chl-a`` is column 2 and the features start at column index 5.
    The first (up to) three feature columns carry a linear signal derived from
    the target, so regressors have something to learn.

    Args:
        n_samples: number of rows to generate.
        n_features: number of ``feat_*`` columns; values below 3 are allowed
            (only the existing columns receive the target signal).
        noise: NOTE(review): currently unused -- the random component of every
            feature is a fixed standard normal (std 1.0). Kept for signature
            compatibility.
        seed: seed for ``numpy.random.default_rng``; the same arguments always
            produce the same data.

    Returns:
        Path to ``train.csv`` inside a fresh ``tempfile.mkdtemp()`` directory.
        The directory is not cleaned up automatically.
    """
    rng = np.random.default_rng(seed)
    lat = rng.uniform(20, 25, n_samples)
    lon = rng.uniform(110, 115, n_samples)
    target = rng.uniform(0, 50, n_samples)
    lat2 = rng.uniform(0, 1, n_samples)  # metadata column
    lon2 = rng.uniform(0, 1, n_samples)  # metadata column
    feats = rng.normal(0, 1, (n_samples, n_features))

    # Make the first three features depend linearly on the target.
    # zip() stops at n_features, so n_features < 3 no longer raises IndexError.
    # Dividing by (10, 20, -15) -- rather than multiplying by reciprocals --
    # keeps the values bit-identical to `+= t/10`, `+= t/20`, `-= t/15`.
    for col, divisor in zip(range(n_features), (10, 20, -15)):
        feats[:, col] += target / divisor

    df = pd.DataFrame({
        "lat": lat,
        "lon": lon,
        "Chl-a": target,
        "lat2": lat2,
        "lon2": lon2,
        **{f"feat_{i}": feats[:, i] for i in range(n_features)},
    })
    tmp = Path(tempfile.mkdtemp()) / "train.csv"
    df.to_csv(tmp, index=False)
    return tmp
|
|
|
|
|
|
def test_load_train_df():
    """Load a 50-row synthetic CSV and check the resulting DataFrame shape.

    Expected width: 5 leading columns (lat, lon, Chl-a, lat2, lon2) plus the
    8 default ``feat_*`` columns.
    """
    print("== test_load_train_df ==")
    csv_path = make_synthetic_csv(n_samples=50)
    frame = _load_train_df(str(csv_path))

    leading_cols, feature_cols = 5, 8
    assert frame.shape == (50, leading_cols + feature_cols), f"shape={frame.shape}"

    head = list(frame.columns[:6])
    print(f" shape={frame.shape}, columns[:6]={head}")
    print(" PASS")
|
|
|
|
|
|
def test_resolve_feature_start_int_and_str():
    """feature_start given as a position (5) or as a column name ('feat_0')
    must both resolve to column index 5 of the synthetic layout."""
    print("== test_resolve_feature_start (int + str) ==")
    frame = _load_train_df(str(make_synthetic_csv()))

    by_position = _resolve_feature_start(frame, 5)
    by_name = _resolve_feature_start(frame, "feat_0")
    assert by_position == 5 == by_name, f"int={by_position}, str={by_name}"

    print(f" int(5) -> {by_position}, str('feat_0') -> {by_name}")
    print(" PASS")
|
|
|
|
|
|
def test_resolve_feature_start_str_miss():
    """An unknown column name passed as feature_start must raise ValueError.

    Raises:
        AssertionError: if _resolve_feature_start returns instead of raising.
    """
    print("== test_resolve_feature_start (str 不存在 -> 抛错) ==")
    p = make_synthetic_csv()
    df = _load_train_df(str(p))
    try:
        _resolve_feature_start(df, "not_exist")
    except ValueError as e:
        print(f" 正确抛 ValueError: {e}")
    else:
        # Fail loudly: merely printing "FAIL" would let the run fall through to
        # "PASS" even though the expected error never happened.
        print(" FAIL: 应抛 ValueError")
        raise AssertionError("_resolve_feature_start(df, 'not_exist') should raise ValueError")
    print(" PASS")
|
|
|
|
|
|
def test_get_model_pipeline_all_types():
    """Every supported model_type builds a two-step pipeline named
    ('scaler', 'model') when given empty params."""
    print("== test_get_model_pipeline (5 种 model_type) ==")
    for model_type in ("RF", "SVR", "LinearRegression", "KNN", "PLS"):
        pipeline = _get_model_pipeline(model_type, {})
        step_names = [step[0] for step in pipeline.steps]
        assert step_names == ["scaler", "model"]
    print(f" 全部通过: {list(_MODEL_CLASS_REGISTRY)}")
    print(" PASS")
|
|
|
|
|
|
def test_get_model_pipeline_bad_type():
    """An unsupported model_type ("XGBoost") must raise ValueError.

    Raises:
        AssertionError: if _get_model_pipeline returns instead of raising.
    """
    print("== test_get_model_pipeline (坏 model_type) ==")
    try:
        _get_model_pipeline("XGBoost", {})
    except ValueError as e:
        print(f" 正确抛 ValueError: {e}")
    else:
        # Fail loudly instead of printing "FAIL" and then "PASS".
        print(" FAIL: 应抛 ValueError")
        raise AssertionError("_get_model_pipeline('XGBoost', {}) should raise ValueError")
    print(" PASS")
|
|
|
|
|
|
def test_run_train_sync_rf_end_to_end():
    """Train a small RandomForest end to end, then reload and use the artifact.

    Checks that _run_train_sync writes the joblib file, that the file holds
    both "model" and "metadata" keys, that the stored metadata matches the
    returned metadata, and that the reloaded model can predict on the
    training CSV's feature columns.
    """
    import time

    import joblib

    print("== test_run_train_sync (RF 端到端) ==")
    p = make_synthetic_csv(n_samples=200)
    # TemporaryDirectory removes the model file afterwards (mkdtemp leaked it).
    with tempfile.TemporaryDirectory() as tmp_dir:
        out_path = Path(tmp_dir) / "model.joblib"

        # perf_counter is monotonic and high-resolution; time.time() can jump
        # with wall-clock adjustments.
        t0 = time.perf_counter()
        metadata = _run_train_sync(
            model_type="RF",
            target="Chl-a",
            train_data_path=str(p),
            feature_start=5,
            params={"n_estimators": 30, "max_depth": 6, "random_state": 42, "n_jobs": 1},
            output_model_path=out_path,
        )
        dt = time.perf_counter() - t0

        assert out_path.exists(), f"joblib 未落盘: {out_path}"
        print(f" joblib 落盘: {out_path} ({out_path.stat().st_size} bytes)")
        print(f" metadata.test_r2={metadata['test_r2']:.4f} test_rmse={metadata['test_rmse']:.4f} test_mae={metadata['test_mae']:.4f}")
        print(f" metadata.n_features={metadata['n_features']} n_samples={metadata['n_samples']} train_size={metadata['train_size']} test_size={metadata['test_size']}")
        print(f" 耗时 {dt:.2f}s")

        # Backtest: reload the joblib artifact and actually call predict.
        saved = joblib.load(out_path)

    assert "model" in saved and "metadata" in saved, f"joblib 双 key 缺失: {saved.keys()}"
    assert hasattr(saved["model"], "predict")
    assert saved["metadata"]["test_r2"] == metadata["test_r2"]
    # Assumes metadata["feature_columns"] lists the training feature columns
    # (as asserted in test_run_train_sync_str_feature_start).
    features = pd.read_csv(p)[saved["metadata"]["feature_columns"]]
    preds = saved["model"].predict(features)
    assert len(preds) == len(features), f"predict 返回 {len(preds)} 行, 期望 {len(features)}"
    print(" joblib 加载 OK, 含 'model' 和 'metadata' 双 key")
    print(" PASS")
|
|
|
|
|
|
def test_run_train_sync_linearregression_fast():
    """Train LinearRegression on the synthetic data and require a usable R².

    The synthetic features are linear in the target, so a linear model should
    clear a modest R² threshold. A failure points at the data generator or the
    training pipeline.
    """
    print("== test_run_train_sync (LinearRegression 快速路径) ==")
    # Single source of truth so the printed expectation matches the assertion
    # (the message previously said ">= 0.4" while the check used "> 0.3").
    min_r2 = 0.3
    p = make_synthetic_csv(n_samples=150)
    with tempfile.TemporaryDirectory() as tmp_dir:
        out_path = Path(tmp_dir) / "lr.joblib"
        metadata = _run_train_sync(
            model_type="LinearRegression",
            target="Chl-a",
            train_data_path=str(p),
            feature_start=5,
            params={},
            output_model_path=out_path,
        )
    print(f" test_r2={metadata['test_r2']:.4f} (LR 学到线性, R² 应 > {min_r2})")
    assert metadata["test_r2"] > min_r2, f"LR test_r2={metadata['test_r2']} 太低, 数据生成可能有问题"
    print(" PASS")
|
|
|
|
|
|
def test_run_train_sync_bad_csv():
    """_run_train_sync must reject a training CSV path that does not exist.

    Either FileNotFoundError or ValueError is accepted.

    Raises:
        AssertionError: if the call returns instead of raising.
    """
    print("== test_run_train_sync (CSV 不存在) ==")
    # Use a private temp dir instead of the hard-coded /tmp/x.joblib: it works
    # on non-POSIX systems and cannot clobber a shared file if the call
    # unexpectedly succeeds.
    with tempfile.TemporaryDirectory() as tmp_dir:
        out_path = Path(tmp_dir) / "x.joblib"
        try:
            _run_train_sync("RF", "Chl-a", "/no/such/path.csv", 5, {}, out_path)
        except (FileNotFoundError, ValueError) as e:
            print(f" 正确抛 {type(e).__name__}: {e}")
        else:
            # Fail loudly instead of printing "FAIL" and then "PASS".
            print(" FAIL: 应抛异常")
            raise AssertionError("_run_train_sync should raise for a missing CSV")
    print(" PASS")
|
|
|
|
|
|
def test_run_train_sync_bad_target():
    """_run_train_sync must raise ValueError when the target column is absent.

    Raises:
        AssertionError: if the call returns instead of raising.
    """
    print("== test_run_train_sync (target 列不存在) ==")
    p = make_synthetic_csv()
    # Private temp dir instead of a hard-coded /tmp path (portable; no
    # clobbering of shared files if training unexpectedly succeeds).
    with tempfile.TemporaryDirectory() as tmp_dir:
        out_path = Path(tmp_dir) / "x.joblib"
        try:
            _run_train_sync("RF", "NopeTarget", str(p), 5, {}, out_path)
        except ValueError as e:
            print(f" 正确抛 ValueError: {e}")
        else:
            # Fail loudly instead of printing "FAIL" and then "PASS".
            print(" FAIL: 应抛 ValueError")
            raise AssertionError("_run_train_sync should raise ValueError for an unknown target")
    print(" PASS")
|
|
|
|
|
|
def test_run_train_sync_str_feature_start():
    """Train with feature_start given as a column name ('feat_0').

    The returned metadata must echo the name and report all 8 synthetic
    features, with 'feat_0' as the first feature column.
    """
    print("== test_run_train_sync (feature_start 用列名) ==")
    csv_path = make_synthetic_csv()
    model_path = Path(tempfile.mkdtemp()) / "str_fs.joblib"
    meta = _run_train_sync(
        "RF", "Chl-a", str(csv_path), "feat_0", {"n_estimators": 10}, model_path
    )

    assert meta["feature_start"] == "feat_0"
    assert meta["n_features"] == 8
    assert meta["feature_columns"][0] == "feat_0"

    print(f" 列名 'feat_0' 解析正确, n_features={meta['n_features']}")
    print(" PASS")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run each smoke test in order. Any uncaught exception aborts the run
    # before the final banner is printed.
    for smoke_test in (
        test_load_train_df,
        test_resolve_feature_start_int_and_str,
        test_resolve_feature_start_str_miss,
        test_get_model_pipeline_all_types,
        test_get_model_pipeline_bad_type,
        test_run_train_sync_rf_end_to_end,
        test_run_train_sync_linearregression_fast,
        test_run_train_sync_bad_csv,
        test_run_train_sync_bad_target,
        test_run_train_sync_str_feature_start,
    ):
        smoke_test()
    print("\n>>> ALL SMOKE TESTS PASSED")
|