403 lines
17 KiB
Python
403 lines
17 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
Prediction steps.

Contains step7_generate_sampling_points, step8_predict_water_quality,
step8_5_predict_with_non_empirical_models, step8_75_predict_with_custom_regression
|
||
"""
|
||
|
||
import time
|
||
from pathlib import Path
|
||
from typing import Optional, List, Union, Callable, Dict
|
||
|
||
|
||
class PredictionStep:
    """Prediction-stage steps of the pipeline (steps 7, 8, 8.5 and 8.75).

    Every public step is a static method that prints a banner and, when a
    ``callback`` is supplied, reports progress by calling
    ``callback(step_label, status, message)``.  Steps 7, 8 and 8.5 reuse
    result files that already exist on disk instead of recomputing them.
    """

    # ---- Step 7: generate sampling points and extract spectra ----

    @staticmethod
    def _resolve_deglint_image(deglint_img_path: str, output_dir: Path) -> Path:
        """Pick the de-glinted image file that step 7 should sample from.

        The given path is used as-is when it exists and is not a ``.dat``
        file.  Otherwise a ``3_deglint`` directory is searched for ``*.bsq``
        files: the parent of the given path when it is already named
        ``3_deglint``, else the ``3_deglint`` sibling of ``output_dir``.

        Args:
            deglint_img_path: Path supplied by the caller.
            output_dir: Output directory of step 7.

        Returns:
            The chosen image path.  It is not guaranteed to exist when
            neither the given path nor any ``.bsq`` candidate exists.
        """
        original_path = Path(deglint_img_path)
        original_exists = original_path.exists()
        if original_exists and original_path.suffix.lower() != '.dat':
            return original_path

        print(f"🔍 智能探测:输入去耀斑路径不存在或为 .dat 占位符 ({original_path}),正在向上搜索真实产物...")

        search_dir = original_path.parent
        if search_dir.name != '3_deglint' and output_dir.parent.exists():
            search_dir = output_dir.parent / "3_deglint"

        if search_dir.exists():
            # glob() yields entries in arbitrary order; sort so that the same
            # directory content always resolves to the same file.
            candidates = sorted(search_dir.glob("*.bsq"))
            if candidates:
                # Prefer the product that shares the requested file's stem.
                same_stem = [c for c in candidates if c.stem == original_path.stem]
                chosen = same_stem[0] if same_stem else candidates[0]
                print(f"💡 智能拦截成功:自动寻回底层真实去耀斑影像: {chosen}")
                return chosen

        # No .bsq product was found.  Keep an input file that really exists
        # rather than swapping it for a .bsq path that does not.
        if original_exists:
            return original_path
        return original_path.with_suffix('.bsq')

    @staticmethod
    def generate_sampling_points(
        deglint_img_path: Optional[str] = None,
        interval: int = 50,
        sample_radius: int = 5,
        chunk_size: int = 1000,
        water_mask_path: Optional[str] = None,
        glint_mask_path: Optional[str] = None,
        output_dir: Union[str, Path] = "./4_sampling",
        callback: Optional[Callable] = None,
        use_adaptive_sampling: bool = True,
    ) -> str:
        """Generate sampling points and write their spectra to a CSV file.

        The work is delegated to
        ``src.utils.sampling.get_spectral_sampling_points_chunked``; the
        result is written to ``<output_dir>/sampling_spectra.csv``.  When
        that file already exists it is reused and nothing is recomputed.

        Args:
            deglint_img_path: De-glinted image path (required).  It is first
                passed through ``_resolve_deglint_image``.
            interval: Forwarded unchanged to the sampling helper.
            sample_radius: Forwarded unchanged to the sampling helper.
            chunk_size: Forwarded unchanged to the sampling helper.
            water_mask_path: Water mask path, forwarded to the helper.
            glint_mask_path: Glint mask path, forwarded to the helper; a
                notice is printed when it is ``None``.
            output_dir: Directory for the CSV; created if missing.
            callback: Optional ``callback("步骤7", status, message)``.
            use_adaptive_sampling: Forwarded unchanged to the helper.

        Returns:
            Path of the sampling-spectra CSV as a string.

        Raises:
            ValueError: If ``deglint_img_path`` is ``None``.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = str(output_dir / "sampling_spectra.csv")

        def notify(status, msg=""):
            if callback:
                callback("步骤7", status, msg)

        print("\n" + "=" * 80)
        print("步骤7: 生成预测采样点并提取光谱")
        print("=" * 80)

        if deglint_img_path is None:
            raise ValueError("必须提供 deglint_img_path 参数")

        deglint_img_str = str(
            PredictionStep._resolve_deglint_image(deglint_img_path, output_dir)
        )

        if Path(output_path).exists():
            print(f"检测到已存在的采样点光谱数据文件,直接使用: {output_path}")
            notify("skipped", f"采样点光谱数据已设置: {output_path}")
            return output_path

        if glint_mask_path is None:
            print("未检测到耀斑掩膜,将在采样点生成时不做耀斑区域剔除。")

        # Imported only when sampling actually has to run, so the cache-hit
        # path above does not depend on the sampling module.
        from src.utils.sampling import get_spectral_sampling_points_chunked

        # The first four arguments are positional, the tuning options are
        # keywords.  NOTE(review): confirm the positional order matches the
        # helper's signature.
        get_spectral_sampling_points_chunked(
            deglint_img_str, water_mask_path, glint_mask_path,
            output_path,
            interval=interval,
            sample_radius=sample_radius,
            chunk_size=chunk_size,
            use_adaptive_sampling=use_adaptive_sampling,
        )

        notify("completed", f"采样点光谱数据已保存: {output_path}")
        return output_path

    # ---- Step 8: predict water-quality parameters with ML models ----

    @staticmethod
    def _scan_cached_predictions(prediction_dir: Path) -> Dict[str, str]:
        """Map target names to the ``*.csv`` files found in ``prediction_dir``.

        The target name is the file stem with every ``"_prediction"``
        occurrence removed, or, when that marker is absent, with every
        ``"_pred"`` occurrence removed.
        """
        cached = {}
        # Sorted so that name collisions resolve the same way on every run.
        for csv_file in sorted(prediction_dir.glob("*.csv")):
            stem = csv_file.stem
            if "_prediction" in stem:
                target_name = stem.replace("_prediction", "")
            elif "_pred" in stem:
                target_name = stem.replace("_pred", "")
            else:
                target_name = stem
            cached[target_name] = str(csv_file)
        return cached

    @staticmethod
    def predict_water_quality(
        sampling_csv_path: str,
        models_dir: Optional[str] = None,
        metric: str = "test_r2",
        prediction_column: str = "prediction",
        output_dir: Union[str, Path] = "./9_ML_Prediction",
        callback: Optional[Callable] = None,
        _report_generator=None,
        _external_model=None,
        _external_model_path=None,
        _external_models_dict=None,
        _external_model_dir=None,
    ) -> Dict[str, str]:
        """Apply trained ML models to the sampling spectra.

        Existing ``*.csv`` files in ``output_dir`` are collected first.  When
        ``models_dir`` exists and every sub-directory name in it already has
        a matching file, those files are returned without running inference.

        Otherwise one of two paths runs:

        * ``_external_models_dict`` is non-empty: each ``{target: model}``
          entry gets its own ``WaterQualityInference`` and writes
          ``<output_dir>/<target>.csv``.  A failing target is recorded with
          the value ``None`` and does not stop the others.
        * otherwise: ``WaterQualityInference.batch_inference_multi_models``
          is run on ``models_dir`` and only successful targets are recorded.

        Args:
            sampling_csv_path: CSV produced by step 7.
            models_dir: Root directory of trained models.  Required unless
                ``_external_models_dict`` is supplied.
            metric: Forwarded to the inference calls.
            prediction_column: Forwarded to the inference calls.
            output_dir: Directory for prediction CSVs; created if missing.
            callback: Optional ``callback("步骤8", status, message)``.
            _report_generator: Optional object whose
                ``generate_prediction_report(prediction_files)`` is called
                after prediction; its errors are printed, not raised.
            _external_model: Forwarded to ``WaterQualityInference``.
            _external_model_path: Forwarded to ``WaterQualityInference``.
            _external_models_dict: Optional ``{target_name: model}`` mapping.
            _external_model_dir: Passed as ``external_model_path`` for the
                models taken from ``_external_models_dict``.

        Returns:
            Mapping of target name to prediction CSV path.

        Raises:
            ValueError: If neither ``models_dir`` nor a non-empty
                ``_external_models_dict`` is given.
        """
        def notify(status, msg=""):
            if callback:
                callback("步骤8", status, msg)

        print("\n" + "=" * 80)
        print("步骤8: 预测水质参数")
        print("=" * 80)
        print(f"[PredictionStep] 准备执行预测,字典状态: {'Yes' if _external_models_dict else 'No'}"
              f", 单模型状态: {'Yes' if _external_model else 'No'}")

        # The external-dict path below falls back to "./" for a missing
        # models_dir, so the directory is only mandatory without a dict.
        if models_dir is None and not _external_models_dict:
            raise ValueError("必须提供 models_dir 参数")

        ml_prediction_dir = Path(output_dir)
        ml_prediction_dir.mkdir(parents=True, exist_ok=True)

        prediction_files = PredictionStep._scan_cached_predictions(ml_prediction_dir)

        # Skip inference when every target folder already has a prediction.
        if prediction_files and models_dir is not None:
            models_path_obj = Path(models_dir)
            if models_path_obj.exists():
                target_folders = [d.name for d in models_path_obj.iterdir() if d.is_dir()]
                missing_targets = [t for t in target_folders if t not in prediction_files]
                if not missing_targets:
                    print(f"检测到已存在的预测结果文件,直接使用: {ml_prediction_dir}")
                    notify("skipped", f"预测结果已设置: {ml_prediction_dir}")
                    return prediction_files
                print(f"检测到部分预测结果文件,缺少: {missing_targets},将继续生成...")

        # Imported only once inference is really needed.
        from src.core.prediction.inference_batch import WaterQualityInference

        all_results = {}

        if _external_models_dict:
            # The external dict takes priority: its keys are the targets.
            print(f"\n使用外部导入模型字典({len(_external_models_dict)} 个模型)...")
            for target_name, model_obj in _external_models_dict.items():
                try:
                    output_file = ml_prediction_dir / f"{target_name}.csv"
                    model_inferencer = WaterQualityInference(
                        models_dir or "./",
                        external_model=model_obj,
                        external_model_path=_external_model_dir or "",
                    )
                    predictions, result_df = model_inferencer.inference_pipeline(
                        sampling_csv_path=sampling_csv_path,
                        output_csv_path=str(output_file),
                        metric=metric,
                        prediction_column=prediction_column,
                    )
                    prediction_files[target_name] = str(output_file)
                    all_results[target_name] = {
                        "status": "success",
                        "output_file": str(output_file),
                        "sample_count": len(predictions),
                    }
                    print(f" ✓ {target_name}: {len(predictions)} 个预测值")
                except Exception as e:
                    # Best effort: one failing model must not abort the rest.
                    print(f" ✗ {target_name}: 失败 — {type(e).__name__}: {e}")
                    prediction_files[target_name] = None
                    all_results[target_name] = {"status": "error", "error": str(e)}
        else:
            # No external dict: scan the sub-directories of models_dir.
            inferencer = WaterQualityInference(
                models_dir,
                external_model=_external_model,
                external_model_path=_external_model_path,
            )
            all_results = inferencer.batch_inference_multi_models(
                models_root_dir=models_dir,
                sampling_csv_path=sampling_csv_path,
                output_dir=str(ml_prediction_dir),
                metric=metric,
                prediction_column=prediction_column,
                output_format="csv",
                external_model=_external_model,
                external_model_path=_external_model_path,
                external_models_dict=_external_models_dict,
            )
            if all_results:
                for target_name, result in all_results.items():
                    if result.get("status") == "success":
                        prediction_files[target_name] = result["output_file"]

        print(f"预测完成,结果保存在: {ml_prediction_dir}")

        if _report_generator is not None:
            try:
                report_path = _report_generator.generate_prediction_report(prediction_files)
                print(f"预测结果报告已生成: {report_path}")
            except Exception as e:
                # Reporting is optional; a report failure keeps the results.
                print(f"生成预测结果报告时出错: {e}")

        notify("completed", f"预测完成: {ml_prediction_dir}")
        return prediction_files

    # ---- Step 8.5: prediction with non-empirical models ----

    @staticmethod
    def predict_with_non_empirical_models(
        sampling_csv_path: str,
        non_empirical_models_dir: Optional[str] = None,
        output_dir: Optional[str] = None,
        metric: str = "Average Accuracy(%)",
        prediction_column: str = "prediction",
        enabled: bool = True,
        callback: Optional[Callable] = None,
        work_dir: Union[str, Path] = "./work_dir",
    ) -> Dict[str, str]:
        """Predict with the best non-empirical model of each algorithm.

        ``non_empirical_models_summary.csv`` in the models directory is read
        with pandas.  For every value of its ``Algorithm Name`` column the
        row with the largest ``metric`` is chosen (the first row when the
        summary has no ``metric`` column), and
        ``src.core.non_empirical_retrieval.non_empirical_retrieval`` is run
        with that row's ``Model File``.  Output files that already exist are
        reused; an algorithm that fails is reported and skipped.

        Args:
            sampling_csv_path: CSV produced by step 7.
            non_empirical_models_dir: Directory holding the summary file;
                defaults to ``<work_dir>/8_Non_Empirical_Regression``.
            output_dir: Directory for prediction files; defaults to
                ``<work_dir>/11_12_13_predictions/Non_Empirical_Prediction``.
            metric: Summary column used to rank models.
            prediction_column: Used as the suffix of the output file name.
            enabled: When ``False`` the step is skipped and ``{}`` returned.
            callback: Optional ``callback("步骤8.5", status, message)``.
            work_dir: Base directory for the default locations.

        Returns:
            Mapping of algorithm name to prediction file path; algorithms
            that failed are absent.

        Raises:
            ValueError: If no models directory can be determined or the
                summary file is missing.
        """
        def notify(status, msg=""):
            if callback:
                callback("步骤8.5", status, msg)

        print("\n" + "=" * 80)
        print("步骤8.5: 使用非经验模型进行参数预测")
        print("=" * 80)

        if not enabled:
            print("已设置跳过非经验模型预测(enabled=False)。")
            notify("skipped", "跳过非经验模型预测")
            return {}

        if non_empirical_models_dir is not None:
            final_models_dir = non_empirical_models_dir
        else:
            final_models_dir = str(Path(work_dir) / "8_Non_Empirical_Regression")
            if not Path(final_models_dir).exists():
                raise ValueError("请先执行步骤6.5: 非经验模型训练,或提供 non_empirical_models_dir 参数")

        if output_dir is not None:
            non_empirical_prediction_dir = Path(output_dir)
        else:
            non_empirical_prediction_dir = Path(work_dir) / "11_12_13_predictions" / "Non_Empirical_Prediction"
        non_empirical_prediction_dir.mkdir(parents=True, exist_ok=True)

        prediction_files = {}
        summary_path = Path(final_models_dir) / "non_empirical_models_summary.csv"
        if not summary_path.exists():
            raise ValueError(f"未找到非经验模型汇总文件: {summary_path}")

        import pandas as pd
        df_summary = pd.read_csv(summary_path)
        has_metric = metric in df_summary.columns

        best_models = {}
        # Rows without an algorithm name are dropped: they would form an
        # empty group with no best row to pick.
        for algorithm in df_summary["Algorithm Name"].dropna().unique():
            algorithm_df = df_summary[df_summary["Algorithm Name"] == algorithm]
            if has_metric:
                best_model_row = algorithm_df.nlargest(1, metric)
            else:
                best_model_row = algorithm_df.iloc[[0]]

            best_preprocess = best_model_row["Preprocessing Method"].values[0]
            best_accuracy = best_model_row[metric].values[0] if has_metric else "N/A"

            best_models[algorithm] = {
                "model_path": best_model_row["Model File"].values[0],
                "preprocess_method": best_preprocess,
                "accuracy": best_accuracy,
            }
            print(f"算法 {algorithm}: 选择 {best_preprocess} (准确率: {best_accuracy})")

        # Fail early on a missing or empty sampling CSV.  Only the header is
        # parsed, so the whole table is not loaded just to be thrown away.
        pd.read_csv(sampling_csv_path, nrows=0)

        for algorithm, model_info in best_models.items():
            print(f"\n使用 {algorithm} 算法进行预测...")
            output_path = str(non_empirical_prediction_dir / f"non_empirical_{algorithm}_{prediction_column}.csv")

            if Path(output_path).exists():
                print(f"检测到已存在的预测结果文件,直接使用: {output_path}")
                prediction_files[algorithm] = output_path
                continue

            try:
                from src.core.non_empirical_retrieval import non_empirical_retrieval
                non_empirical_retrieval(
                    algorithm=algorithm,
                    model_info_path=model_info["model_path"],
                    coor_spectral_path=sampling_csv_path,
                    output_path=output_path,
                    wave_radius=5,
                )
                prediction_files[algorithm] = output_path
                print(f"预测完成: {output_path}")
            except Exception as e:
                # Best effort: report the failure and try the next algorithm.
                print(f"使用 {algorithm} 算法预测时出错: {e}")
                continue

        notify("completed", f"非经验模型预测完成: {non_empirical_prediction_dir}")
        return prediction_files

    # ---- Step 8.75: prediction with custom regression models ----

    @staticmethod
    def predict_with_custom_regression(
        sampling_csv_path: str,
        custom_regression_dir: Optional[str] = None,
        formula_csv_path: Optional[str] = None,
        coordinate_columns: Optional[List[str]] = None,
        output_dir: Optional[str] = None,
        filename_prefix: str = "custom_regression_prediction",
        enabled: bool = True,
        callback: Optional[Callable] = None,
        work_dir: Union[str, Path] = "./work_dir",
    ) -> Dict[str, str]:
        """Predict parameters with the custom regression models.

        Builds a ``CustomRegressionPredictor`` for the regression directory
        and returns the result of its ``run_batch_prediction``.

        Args:
            sampling_csv_path: CSV produced by step 7; must exist.
            custom_regression_dir: Directory with the regression results;
                defaults to ``<work_dir>/13_Custom_Regression``.
            formula_csv_path: Forwarded to ``CustomRegressionPredictor``.
            coordinate_columns: Forwarded to ``run_batch_prediction``.
            output_dir: Output directory; defaults to
                ``<work_dir>/13_Custom_Regression/Custom_Regression_Prediction``.
                Created if missing.
            filename_prefix: Forwarded to ``run_batch_prediction``.
            enabled: When ``False`` the step is skipped and ``{}`` returned.
            callback: Optional ``callback("步骤8.75", status, message)``.
            work_dir: Base directory for the default locations.

        Returns:
            Whatever ``run_batch_prediction`` returns; it is iterated here
            as a mapping of parameter name to file path.

        Raises:
            FileNotFoundError: If ``sampling_csv_path`` does not exist.
            ValueError: If the default regression directory is needed but
                does not exist.
        """
        def notify(status, msg=""):
            if callback:
                callback("步骤8.75", status, msg)

        print("\n" + "=" * 80)
        print("步骤8.75: 使用自定义回归模型进行参数预测")
        print("=" * 80)

        if not enabled:
            print("已设置跳过自定义回归模型预测(enabled=False)。")
            notify("skipped", "跳过自定义回归预测")
            return {}

        if not Path(sampling_csv_path).exists():
            raise FileNotFoundError(f"采样点CSV文件不存在: {sampling_csv_path}")

        if custom_regression_dir is not None:
            final_regression_dir = custom_regression_dir
        else:
            final_regression_dir = str(Path(work_dir) / "13_Custom_Regression")
            if not Path(final_regression_dir).exists():
                raise ValueError(
                    "请先执行步骤6.75: 自定义回归分析,或提供 custom_regression_dir 参数"
                )

        if output_dir is None:
            prediction_output_dir = str(
                Path(work_dir) / "13_Custom_Regression" / "Custom_Regression_Prediction"
            )
        else:
            prediction_output_dir = output_dir
        # Create the directory in both cases; previously a caller-supplied
        # output_dir was left uncreated.
        Path(prediction_output_dir).mkdir(parents=True, exist_ok=True)

        from src.core.prediction.custom_regression_prediction import CustomRegressionPredictor

        predictor = CustomRegressionPredictor(
            regression_csv_dir=final_regression_dir,
            formula_csv_path=formula_csv_path,
        )

        print("开始使用自定义回归模块进行批量预测...")
        print(f" 采样点数据: {sampling_csv_path}")
        print(f" 回归模型目录: {final_regression_dir}")
        print(f" 输出目录: {prediction_output_dir}")

        # NOTE(review): prediction_output_dir is printed but never handed to
        # the predictor, so where the files land is decided by
        # CustomRegressionPredictor — confirm whether it should receive it.
        saved_files = predictor.run_batch_prediction(
            input_csv_path=sampling_csv_path,
            coordinate_columns=coordinate_columns,
            filename_prefix=filename_prefix,
        )

        print(f"自定义回归预测完成,生成 {len(saved_files)} 个预测文件:")
        for param_name, filepath in saved_files.items():
            print(f" {param_name}: {filepath}")

        notify("completed", f"自定义回归预测完成: {len(saved_files)} 个文件")
        return saved_files
|