refactor(step4): 剥离 Steps 层 - step4~step9 业务逻辑下沉到独立模块
This commit is contained in:
@ -4,9 +4,17 @@
|
|||||||
from src.core.steps.water_mask_step import WaterMaskStep
|
from src.core.steps.water_mask_step import WaterMaskStep
|
||||||
from src.core.steps.glint_detection_step import GlintDetectionStep
|
from src.core.steps.glint_detection_step import GlintDetectionStep
|
||||||
from src.core.steps.glint_removal_step import GlintRemovalStep
|
from src.core.steps.glint_removal_step import GlintRemovalStep
|
||||||
|
from src.core.steps.data_preparation_step import DataPreparationStep
|
||||||
|
from src.core.steps.modeling_step import ModelingStep
|
||||||
|
from src.core.steps.prediction_step import PredictionStep
|
||||||
|
from src.core.steps.mapping_step import MappingStep
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"WaterMaskStep",
|
"WaterMaskStep",
|
||||||
"GlintDetectionStep",
|
"GlintDetectionStep",
|
||||||
"GlintRemovalStep",
|
"GlintRemovalStep",
|
||||||
|
"DataPreparationStep",
|
||||||
|
"ModelingStep",
|
||||||
|
"PredictionStep",
|
||||||
|
"MappingStep",
|
||||||
]
|
]
|
||||||
|
|||||||
171
src/core/steps/data_preparation_step.py
Normal file
171
src/core/steps/data_preparation_step.py
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
数据准备步骤
|
||||||
|
|
||||||
|
包含 step4_process_csv, step5_extract_training_spectra, step5_5_calculate_water_quality_indices
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, List, Union, Callable, Dict
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class DataPreparationStep:
    """Data-preparation steps of the pipeline (steps 4, 5 and 5.5).

    Wraps step4_process_csv, step5_extract_training_spectra and
    step5_5_calculate_water_quality_indices as static methods.
    """

    # ---- Step 4: process the raw CSV ----

    @staticmethod
    def process_csv(
        csv_path: str,
        output_dir: Union[str, Path] = "./4_processed_data",
        callback: Optional[Callable] = None,
    ) -> str:
        """Filter the raw water-quality CSV and remove outliers.

        Args:
            csv_path: Path of the raw CSV file.
            output_dir: Directory the processed CSV is written to.
            callback: Optional progress hook called as
                ``callback(step_name, status, message)``.

        Returns:
            Path of the processed CSV file.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = str(output_dir / "processed_data.csv")

        def notify(status, msg=""):
            if callback:
                callback("步骤4", status, msg)

        print("\n" + "=" * 80)
        print("步骤4: 处理CSV文件,筛选剔除异常值")
        print("=" * 80)

        # Resume support: reuse a previously produced result.
        if Path(output_path).exists():
            print(f"检测到已存在的处理后CSV文件,直接使用: {output_path}")
            notify("skipped", f"处理后的CSV文件已设置: {output_path}")
            return output_path

        # Deferred import so the skip path above does not pull in the
        # preprocessing stack.
        from src.preprocessing.process_water_quality_data import process_water_quality_data

        process_water_quality_data(csv_path, output_path)
        notify("completed", f"处理后的CSV文件已保存: {output_path}")
        return output_path

    # ---- Step 5: extract training-sample spectra ----

    @staticmethod
    def extract_training_spectra(
        deglint_img_path: Optional[str] = None,
        radius: int = 5,
        source_epsg: int = 4326,
        csv_path: Optional[str] = None,
        boundary_path: Optional[str] = None,
        glint_mask_path: Optional[str] = None,
        water_mask_path: Optional[str] = None,
        output_dir: Union[str, Path] = "./5_training_spectra",
        callback: Optional[Callable] = None,
    ) -> str:
        """Extract mean spectra around sampling coordinates from the
        deglinted image.

        Args:
            deglint_img_path: Deglinted image path (required).
            radius: Sampling radius in pixels.
            source_epsg: EPSG code of the sampling coordinates.
            csv_path: CSV with sampling-point coordinates (required).
            boundary_path: Optional boundary file; falls back to the water mask.
            glint_mask_path: Optional glint mask to exclude flare pixels.
            water_mask_path: Optional water mask used as boundary fallback.
            output_dir: Output directory.
            callback: Optional progress hook.

        Returns:
            Path of the training-spectra CSV.

        Raises:
            ValueError: If ``deglint_img_path`` or ``csv_path`` is missing.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = str(output_dir / "training_spectra.csv")

        def notify(status, msg=""):
            if callback:
                callback("步骤5", status, msg)

        print("\n" + "=" * 80)
        print("步骤5: 提取训练样本点的平均光谱")
        print("=" * 80)

        if deglint_img_path is None:
            raise ValueError("必须提供 deglint_img_path 参数")
        if csv_path is None:
            raise ValueError("必须提供 csv_path 参数")

        # Resume support: reuse a previously produced result.
        if Path(output_path).exists():
            print(f"检测到已存在的训练光谱数据文件,直接使用: {output_path}")
            notify("skipped", f"训练光谱数据已设置: {output_path}")
            return output_path

        # Fall back to the water mask when no explicit boundary is given.
        final_boundary_path = boundary_path
        if final_boundary_path is None and water_mask_path is not None:
            final_boundary_path = water_mask_path

        flare_path = glint_mask_path
        if flare_path:
            print(f"光谱提取使用耀斑掩膜: {flare_path}")

        # Deferred import: only needed when extraction actually runs.
        from src.core.glint_removal.get_spectral import get_spectral_in_coor

        get_spectral_in_coor(
            deglint_img_path, csv_path, output_path,
            radius=radius, flare_path=flare_path,
            boundary_path=final_boundary_path, source_epsg=source_epsg
        )

        notify("completed", f"训练光谱数据已保存: {output_path}")
        return output_path

    # ---- Step 5.5: water-quality spectral indices ----

    @staticmethod
    def calculate_water_quality_indices(
        training_spectra_path: Optional[str] = None,
        formula_csv_file: Optional[str] = None,
        formula_names: Optional[List[str]] = None,
        output_file: Optional[str] = None,
        enabled: bool = True,
        output_dir: Union[str, Path] = "./6_water_quality_indices",
        callback: Optional[Callable] = None,
    ) -> Optional[str]:
        """Compute water-quality spectral indices from the training spectra
        using the band-math method.

        Args:
            training_spectra_path: Training-spectra CSV (required when enabled).
            formula_csv_file: CSV describing the index formulas (required).
            formula_names: Optional subset of formulas to evaluate.
            output_file: Explicit output path (default inside ``output_dir``).
            enabled: When False the step is skipped and None is returned.
            output_dir: Output directory.
            callback: Optional progress hook.

        Returns:
            Path of the index CSV, or None when the step is disabled.

        Raises:
            ValueError: On missing inputs or when calculation fails.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        def notify(status, msg=""):
            if callback:
                callback("步骤5.5", status, msg)

        print("\n" + "=" * 80)
        print("步骤5.5: 计算水质光谱指数(使用band_math方法)")
        print("=" * 80)

        if not enabled:
            print("已设置跳过水质指数计算(enabled=False)。")
            notify("skipped", "跳过水质指数计算")
            return None

        if training_spectra_path is None:
            raise ValueError("必须提供 training_spectra_path 参数")
        if formula_csv_file is None:
            raise ValueError("必须提供 formula_csv_file 参数")

        if output_file:
            output_path = str(Path(output_file))
        else:
            output_path = str(output_dir / "water_quality_indices.csv")

        # Resume support: reuse a previously produced result.
        if Path(output_path).exists():
            print(f"检测到已存在的水质指数文件,直接使用: {output_path}")
            notify("skipped", f"水质指数数据已设置: {output_path}")
            return output_path

        from src.utils.band_math import BandMathCalculator

        calculator = BandMathCalculator(training_spectra_path)
        result_df = calculator.process_formulas_from_csv(
            formula_csv_file=formula_csv_file,
            formula_names=formula_names,
            output_file=output_path
        )

        if result_df is None:
            raise ValueError("计算水质指数失败,请检查公式CSV文件格式")

        notify("completed", f"水质指数已保存: {output_path}")
        return output_path
||||||
109
src/core/steps/mapping_step.py
Normal file
109
src/core/steps/mapping_step.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
成图步骤
|
||||||
|
|
||||||
|
包含 step9_generate_distribution_map
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union, Callable
|
||||||
|
|
||||||
|
|
||||||
|
class MappingStep:
    """Mapping step of the pipeline (step 9): render the distribution map."""

    @staticmethod
    def generate_distribution_map(
        prediction_csv_path: str,
        boundary_shp_path: str,
        output_image_path: Optional[str] = None,
        resolution: float = 30,
        input_crs: str = "EPSG:32651",
        output_crs: str = "EPSG:4326",
        show_sample_points: bool = False,
        base_map_tif: Optional[str] = None,
        use_distance_diffusion: bool = True,
        max_diffusion_distance: Optional[float] = None,
        diffusion_power: float = 2,
        diffusion_n_neighbors: int = 15,
        cmap: Optional[str] = None,
        expand_ratio: float = 0.05,
        output_dir: Union[str, Path] = "./14_visualization",
        callback: Optional[Callable] = None,
    ) -> str:
        """Interpolate point predictions into a water-quality distribution map.

        Args:
            prediction_csv_path: Prediction CSV (first two columns lon/lat,
                third column the predicted value).
            boundary_shp_path: Boundary shapefile path.
            output_image_path: Output image path (auto-generated when None).
            resolution: Interpolation grid resolution in metres.
            input_crs: Input coordinate reference system.
            output_crs: Output coordinate reference system.
            show_sample_points: Draw the sampling points on the map.
            base_map_tif: Optional background GeoTIFF.
            use_distance_diffusion: Enable distance diffusion to fill edges.
            max_diffusion_distance: Maximum diffusion distance in metres.
            diffusion_power: Distance-diffusion power parameter.
            diffusion_n_neighbors: Nearest-neighbour count for diffusion.
            cmap: Colormap name (None = auto-detected).
            expand_ratio: Boundary expansion ratio (0-1).
            output_dir: Output directory.
            callback: Optional progress hook.

        Returns:
            Path of the rendered distribution image.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        def notify(status, msg=""):
            if callback:
                callback("步骤9", status, msg)

        print("\n" + "=" * 80)
        print("步骤9: 生成水质参数可视化分布图")
        print("=" * 80)

        if output_image_path is None:
            csv_name = Path(prediction_csv_path).stem
            output_image_path = str(output_dir / f"{csv_name}_distribution.png")

        # Resume support: reuse a previously rendered map.
        if Path(output_image_path).exists():
            print(f"检测到已存在的分布图文件,直接使用: {output_image_path}")
            notify("skipped", f"可视化分布图已设置: {output_image_path}")
            return output_image_path

        # Deferred import: only needed when a new map is actually rendered.
        from src.postprocessing.map import ContentMapper

        mapper = ContentMapper(input_crs=input_crs, output_crs=output_crs)

        mapper_kwargs = {
            "resolution": resolution,
            "show_sample_points": show_sample_points,
            "use_distance_diffusion": use_distance_diffusion,
            "diffusion_power": diffusion_power,
            "diffusion_n_neighbors": diffusion_n_neighbors,
            "expand_ratio": expand_ratio,
        }

        # Forward optional settings only when explicitly provided, so
        # ContentMapper keeps its own defaults otherwise.
        optional_kwargs = {
            "base_map_tif": base_map_tif,
            "max_diffusion_distance": max_diffusion_distance,
            "cmap": cmap,
        }
        mapper_kwargs.update({k: v for k, v in optional_kwargs.items() if v is not None})

        mapper.process_data(
            csv_file=prediction_csv_path,
            shp_file=boundary_shp_path,
            output_file=output_image_path,
            **mapper_kwargs,
        )

        notify("completed", f"可视化分布图已保存: {output_image_path}")
        return output_image_path
||||||
380
src/core/steps/modeling_step.py
Normal file
380
src/core/steps/modeling_step.py
Normal file
@ -0,0 +1,380 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
建模步骤
|
||||||
|
|
||||||
|
包含 step6_train_models, step6_5_non_empirical_modeling, step6_75_custom_regression
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, List, Union, Callable, Dict
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class ModelingStep:
    """Modeling steps of the pipeline (steps 6, 6.5 and 6.75)."""

    # ---- Step 6: supervised machine-learning model training ----

    @staticmethod
    def train_models(
        feature_start_column: str = "374.285004",
        preprocessing_methods: Optional[List[str]] = None,
        model_names: Optional[List[str]] = None,
        split_methods: Optional[List[str]] = None,
        cv_folds: int = 5,
        training_csv_path: Optional[str] = None,
        output_dir: Union[str, Path] = "./7_Supervised_Model_Training",
        callback: Optional[Callable] = None,
        _report_generator=None,
    ) -> str:
        """Train machine-learning models from sample spectra and measured values.

        Args:
            feature_start_column: Name of the first spectral feature column.
            preprocessing_methods: Preprocessing codes (defaults to full set).
            model_names: Model identifiers (defaults to SVR/RF/Ridge/Lasso).
            split_methods: Train/test split strategies.
            cv_folds: Cross-validation fold count.
            training_csv_path: Training CSV (required).
            output_dir: Directory models are written to.
            callback: Optional progress hook.
            _report_generator: Optional report generator with
                ``generate_training_summary``.

        Returns:
            The model output directory as str.

        Raises:
            ValueError: If ``training_csv_path`` is missing.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        def notify(status, msg=""):
            if callback:
                callback("步骤6", status, msg)

        print("\n" + "=" * 80)
        print("步骤6: 训练机器学习模型")
        print("=" * 80)

        if training_csv_path is None:
            raise ValueError("必须提供 training_csv_path 参数")

        # Resume support: if any sub-directory already holds trained model
        # files, reuse the directory instead of retraining.
        if output_dir.exists() and any(output_dir.iterdir()):
            has_models = False
            for item in output_dir.iterdir():
                if item.is_dir():
                    model_files = (
                        list(item.glob("*.pkl"))
                        + list(item.glob("*.joblib"))
                        + list(item.glob("*.h5"))
                    )
                    if model_files:
                        has_models = True
                        break
            if has_models:
                print(f"检测到已存在的模型文件,直接使用: {output_dir}")
                notify("skipped", f"模型目录已设置: {output_dir}")
                return str(output_dir)

        if preprocessing_methods is None:
            preprocessing_methods = ["None", "MMS", "SS", "SNV", "MA", "SG", "MSC", "D1", "D2", "DT", "CT"]
        if model_names is None:
            model_names = ["SVR", "RF", "Ridge", "Lasso"]
        if split_methods is None:
            split_methods = ["spxy", "ks", "random"]

        # Deferred import: the training stack is heavy and not needed on skip.
        from src.core.modeling.modeling_batch import WaterQualityModelingBatch

        modeler = WaterQualityModelingBatch(str(output_dir))
        modeler.train_models_batch(
            csv_path=training_csv_path,
            feature_start_column=feature_start_column,
            preprocessing_methods=preprocessing_methods,
            model_names=model_names,
            split_methods=split_methods,
            cv_folds=cv_folds,
        )

        print(f"模型训练完成,结果保存在: {output_dir}")

        # Report generation is best-effort and must not fail the step.
        if _report_generator is not None:
            try:
                summary_path = _report_generator.generate_training_summary(str(output_dir))
                print(f"训练摘要报告已生成: {summary_path}")
            except Exception as e:
                print(f"生成训练摘要报告时出错: {e}")

        notify("completed", f"模型训练完成: {output_dir}")
        return str(output_dir)

    # ---- Step 6.5: non-empirical statistical regression training ----

    @staticmethod
    def train_non_empirical_models(
        csv_path: Optional[str] = None,
        preprocessing_methods: Optional[List[str]] = None,
        algorithms: Optional[List[str]] = None,
        value_cols: Union[int, Dict[str, int]] = 0,
        spectral_start_col: int = 1,
        spectral_end_col: Optional[int] = None,
        window: int = 5,
        output_dir: Optional[str] = None,
        enabled: bool = True,
        callback: Optional[Callable] = None,
    ) -> Dict[str, str]:
        """Train non-empirical statistical regression models.

        Args:
            csv_path: Training CSV (required when enabled).
            preprocessing_methods: Preprocessing codes (default ``["None"]``).
            algorithms: Target algorithms (default: the six water parameters).
            value_cols: Measured-value column index, or per-algorithm mapping.
            spectral_start_col: First spectral column index.
            spectral_end_col: Last spectral column index (auto when None).
            window: Smoothing window passed to model correction.
            output_dir: Output directory (default ``cwd/8_Regression_Modeling``).
            enabled: When False the step is skipped and ``{}`` returned.
            callback: Optional progress hook.

        Returns:
            Mapping of ``"<preprocess>_<algorithm>"`` to model-info JSON path.

        Raises:
            ValueError: On missing ``csv_path`` or invalid ``value_cols``.
        """
        def notify(status, msg=""):
            if callback:
                callback("步骤6.5", status, msg)

        print("\n" + "=" * 80)
        print("步骤6.5: 非经验统计回归模型训练")
        print("=" * 80)

        if not enabled:
            print("已设置跳过非经验模型训练(enabled=False)。")
            # BUGFIX: message previously read "跳过的经验模型训练".
            notify("skipped", "跳过非经验模型训练")
            return {}

        if csv_path is None:
            raise ValueError("必须提供 csv_path 参数")

        if output_dir is not None:
            non_empirical_dir = Path(output_dir)
        else:
            non_empirical_dir = Path.cwd() / "8_Regression_Modeling"
        non_empirical_dir.mkdir(parents=True, exist_ok=True)

        if preprocessing_methods is None:
            preprocessing_methods = ["None"]
        if algorithms is None:
            algorithms = ["chl_a", "nh3", "mno4", "tn", "tp", "tss"]

        # Normalise value_cols into a per-algorithm mapping.
        if isinstance(value_cols, int):
            value_cols_dict = {algorithm: value_cols for algorithm in algorithms}
        elif isinstance(value_cols, dict):
            value_cols_dict = value_cols
        else:
            raise ValueError("value_cols 参数必须是整数或字典")

        if spectral_end_col is None:
            df = pd.read_csv(csv_path)
            spectral_end_col = len(df.columns) - 1

        all_model_results = {}

        for preprocess in preprocessing_methods:
            preprocess_dir = non_empirical_dir / preprocess
            preprocess_dir.mkdir(parents=True, exist_ok=True)

            processed_csv_path = _apply_preprocessing_internal(
                csv_path, preprocess, preprocess_dir, spectral_start_col
            )

            for algorithm in algorithms:
                algorithm_value_col = value_cols_dict[algorithm]
                print(f"\n训练 {preprocess} + {algorithm} 模型 (实测值列: {algorithm_value_col})...")

                model_outpath = str(preprocess_dir / f"{preprocess}_{algorithm}.json")

                # Resume support per model file.
                if Path(model_outpath).exists():
                    print(f"检测到已存在的模型文件,直接使用: {model_outpath}")
                    all_model_results[f"{preprocess}_{algorithm}"] = model_outpath
                    continue

                try:
                    from src.core.non_empirical_model_correction import run_model_correction
                    run_model_correction(
                        algorithm=algorithm,
                        csv_file=processed_csv_path if Path(processed_csv_path).exists() else csv_path,
                        value_col=algorithm_value_col,
                        spectral_start=spectral_start_col,
                        spectral_end=spectral_end_col,
                        model_info_outpath=model_outpath,
                        window=window,
                    )
                    all_model_results[f"{preprocess}_{algorithm}"] = model_outpath
                    print(f"模型训练完成: {model_outpath}")
                except Exception as e:
                    # Best-effort: one failed combination must not abort the batch.
                    print(f"训练 {preprocess}_{algorithm} 模型时出错: {e}")
                    continue

        # Summary CSV is written for reference; its path is not returned.
        _generate_non_empirical_summary(all_model_results, non_empirical_dir)
        notify("completed", f"非经验模型训练完成: {non_empirical_dir}")
        return all_model_results

    # ---- Step 6.75: custom regression analysis ----

    @staticmethod
    def custom_regression(
        csv_path: Optional[str] = None,
        x_columns: Optional[Union[str, List[str]]] = None,
        y_columns: Optional[Union[str, List[str]]] = None,
        methods: Union[str, List[str]] = "all",
        output_dir: Optional[str] = None,
        enabled: bool = True,
        callback: Optional[Callable] = None,
        work_dir: Union[str, Path] = "./work_dir",
    ) -> Optional[str]:
        """Analyse relations between index columns and target parameters with
        custom single-variable regression methods.

        Args:
            csv_path: Input CSV (required when enabled).
            x_columns: Independent column name(s).
            y_columns: Dependent column name(s).
            methods: Regression method selection (``"all"`` or a list).
            output_dir: Output directory name under ``work_dir``.
            enabled: When False the step is skipped and None returned.
            callback: Optional progress hook.
            work_dir: Pipeline working directory.

        Returns:
            Output directory path, or None when the step is disabled.

        Raises:
            ValueError: On missing inputs or unknown column names.
        """
        def notify(status, msg=""):
            if callback:
                callback("步骤6.75", status, msg)

        print("\n" + "=" * 80)
        print("步骤6.75: 自定义回归分析")
        print("=" * 80)

        if not enabled:
            print("已设置跳过自定义回归分析(enabled=False)。")
            notify("skipped", "跳过自定义回归分析")
            return None

        if csv_path is None:
            raise ValueError("必须提供 csv_path 参数")
        if y_columns is None:
            raise ValueError("必须指定 y_columns")
        if x_columns is None:
            raise ValueError("必须指定 x_columns")

        if isinstance(x_columns, str):
            x_columns = [x_columns]
        if isinstance(y_columns, str):
            y_columns = [y_columns]

        df = pd.read_csv(csv_path)
        missing_x = [col for col in x_columns if col not in df.columns]
        missing_y = [col for col in y_columns if col not in df.columns]
        if missing_x:
            raise ValueError(f"自变量列不存在: {missing_x}")
        if missing_y:
            raise ValueError(f"因变量列不存在: {missing_y}")

        if output_dir is None:
            custom_regression_dir = Path(work_dir) / "9_Custom_Regression_Modeling"
        else:
            custom_regression_dir = Path(work_dir) / output_dir
        custom_regression_dir.mkdir(parents=True, exist_ok=True)

        from src.core.modeling.regression import SingleVariableRegressionAnalysis
        analyzer = SingleVariableRegressionAnalysis()
        analyzer.batch_single_variable_regression(
            data=df,
            x_columns=x_columns,
            y_columns=y_columns,
            methods=methods,
            output_dir=str(custom_regression_dir),
        )

        notify("completed", f"自定义回归结果已保存到目录: {custom_regression_dir}")
        return str(custom_regression_dir)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 内部辅助函数(供 ModelingStep 内部使用)
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def _apply_preprocessing_internal(
|
||||||
|
csv_path: str,
|
||||||
|
preprocess_method: str,
|
||||||
|
output_dir: Path,
|
||||||
|
spectral_start_col: int = 4,
|
||||||
|
) -> str:
|
||||||
|
"""应用预处理到CSV数据(内部函数)"""
|
||||||
|
raw_p = str(preprocess_method).lower()
|
||||||
|
if raw_p == "none" or "无" in raw_p or "跳过" in raw_p:
|
||||||
|
preprocess_method = "None"
|
||||||
|
elif raw_p == "mms" or "minmax" in raw_p or "最大最小" in raw_p:
|
||||||
|
preprocess_method = "MMS"
|
||||||
|
elif raw_p == "ss" or "标准" in raw_p or "标准化" in raw_p:
|
||||||
|
preprocess_method = "SS"
|
||||||
|
elif raw_p == "snv" or "标准正态" in raw_p:
|
||||||
|
preprocess_method = "SNV"
|
||||||
|
elif raw_p == "ma" or "移动" in raw_p:
|
||||||
|
preprocess_method = "MA"
|
||||||
|
elif raw_p == "sg" or "savitzky" in raw_p or "平滑" in raw_p:
|
||||||
|
preprocess_method = "SG"
|
||||||
|
elif raw_p == "msc" or "多元散射" in raw_p:
|
||||||
|
preprocess_method = "MSC"
|
||||||
|
elif raw_p in ("d1", "d2", "dt"):
|
||||||
|
preprocess_method = {"d1": "D1", "d2": "D2", "dt": "DT"}.get(raw_p, raw_p.upper())
|
||||||
|
elif raw_p == "ct" or "去趋势" in raw_p:
|
||||||
|
preprocess_method = "CT"
|
||||||
|
|
||||||
|
if preprocess_method == "None":
|
||||||
|
return csv_path
|
||||||
|
|
||||||
|
output_filename = f"preprocessed_{preprocess_method}.csv"
|
||||||
|
output_path = str(output_dir / output_filename)
|
||||||
|
|
||||||
|
if Path(output_path).exists():
|
||||||
|
print(f"检测到已存在的预处理文件,直接使用: {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
df = pd.read_csv(csv_path)
|
||||||
|
non_spectral_cols = df.iloc[:, :spectral_start_col]
|
||||||
|
spectral_data = df.iloc[:, spectral_start_col:]
|
||||||
|
|
||||||
|
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
||||||
|
|
||||||
|
save_path = None
|
||||||
|
if preprocess_method == "SS":
|
||||||
|
models_dir = output_dir.parent.parent / "7_Supervised_Model_Training"
|
||||||
|
models_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
save_path = str(models_dir / "scaler_params.pkl")
|
||||||
|
print(f"SS预处理: scaler模型将保存到 {save_path}")
|
||||||
|
|
||||||
|
processed_spectral = Preprocessing(preprocess_method, spectral_data, save_path=save_path)
|
||||||
|
|
||||||
|
if isinstance(processed_spectral, pd.DataFrame):
|
||||||
|
processed_df = pd.concat([non_spectral_cols, processed_spectral], axis=1)
|
||||||
|
else:
|
||||||
|
processed_spectral_df = pd.DataFrame(
|
||||||
|
processed_spectral, columns=spectral_data.columns, index=spectral_data.index
|
||||||
|
)
|
||||||
|
processed_df = pd.concat([non_spectral_cols, processed_spectral_df], axis=1)
|
||||||
|
|
||||||
|
processed_df.to_csv(output_path, index=False)
|
||||||
|
print(f"预处理完成: {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_non_empirical_summary(model_results: Dict[str, str], output_dir: Path) -> str:
|
||||||
|
"""生成非经验模型训练结果汇总CSV"""
|
||||||
|
summary_path = str(output_dir / "non_empirical_models_summary.csv")
|
||||||
|
summary_data = []
|
||||||
|
|
||||||
|
for model_key, model_path in model_results.items():
|
||||||
|
try:
|
||||||
|
parts = model_key.split("_")
|
||||||
|
preprocess_method = parts[0]
|
||||||
|
algorithm_name = "_".join(parts[1:]) if len(parts) > 2 else parts[1]
|
||||||
|
|
||||||
|
with open(model_path, "r", encoding="utf-8") as f:
|
||||||
|
model_info = json.load(f)
|
||||||
|
|
||||||
|
accuracy_list = model_info.get("accuracy", [])
|
||||||
|
summary_row = {
|
||||||
|
"Preprocessing Method": preprocess_method,
|
||||||
|
"Algorithm Name": algorithm_name,
|
||||||
|
"Model Type": model_info.get("model_type", ""),
|
||||||
|
"Coefficient Count": len(model_info.get("model_info", [])),
|
||||||
|
"Average Accuracy(%)": np.mean(accuracy_list) if accuracy_list else 0,
|
||||||
|
"Min Accuracy(%)": np.min(accuracy_list) if accuracy_list else 0,
|
||||||
|
"Max Accuracy(%)": np.max(accuracy_list) if accuracy_list else 0,
|
||||||
|
"Sample Count": len(model_info.get("long", [])),
|
||||||
|
"Model File": model_path,
|
||||||
|
}
|
||||||
|
|
||||||
|
coefficients = model_info.get("model_info", [])
|
||||||
|
for i, coeff in enumerate(coefficients[:5]):
|
||||||
|
summary_row[f"系数_{i+1}"] = coeff
|
||||||
|
|
||||||
|
summary_data.append(summary_row)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"读取模型文件 {model_path} 时出错: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if summary_data:
|
||||||
|
df_summary = pd.DataFrame(summary_data)
|
||||||
|
df_summary.to_csv(summary_path, index=False, encoding="utf-8-sig")
|
||||||
|
print(f"汇总文件已生成: {summary_path}")
|
||||||
|
else:
|
||||||
|
print("警告: 没有有效的模型数据可汇总")
|
||||||
|
summary_path = ""
|
||||||
|
|
||||||
|
return summary_path
|
||||||
323
src/core/steps/prediction_step.py
Normal file
323
src/core/steps/prediction_step.py
Normal file
@ -0,0 +1,323 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
预测步骤
|
||||||
|
|
||||||
|
包含 step7_generate_sampling_points, step8_predict_water_quality,
|
||||||
|
step8_5_predict_with_non_empirical_models, step8_75_predict_with_custom_regression
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, List, Union, Callable, Dict
|
||||||
|
|
||||||
|
|
||||||
|
class PredictionStep:
|
||||||
|
"""预测步骤"""
|
||||||
|
|
||||||
|
# ---- Step 7: 生成采样点并提取光谱 ----
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_sampling_points(
|
||||||
|
deglint_img_path: Optional[str] = None,
|
||||||
|
interval: int = 50,
|
||||||
|
sample_radius: int = 5,
|
||||||
|
chunk_size: int = 1000,
|
||||||
|
water_mask_path: Optional[str] = None,
|
||||||
|
glint_mask_path: Optional[str] = None,
|
||||||
|
output_dir: Union[str, Path] = "./10_sampling",
|
||||||
|
callback: Optional[Callable] = None,
|
||||||
|
) -> str:
|
||||||
|
"""生成水域掩膜内且耀斑掩膜外的采样点,统计平均光谱"""
|
||||||
|
from src.utils.sampling import get_spectral_sampling_points_chunked
|
||||||
|
|
||||||
|
output_dir = Path(output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
output_path = str(output_dir / "sampling_spectra.csv")
|
||||||
|
|
||||||
|
def notify(status, msg=""):
|
||||||
|
if callback:
|
||||||
|
callback("步骤7", status, msg)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤7: 生成预测采样点并提取光谱")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
if deglint_img_path is None:
|
||||||
|
raise ValueError("必须提供 deglint_img_path 参数")
|
||||||
|
|
||||||
|
if Path(output_path).exists():
|
||||||
|
print(f"检测到已存在的采样点光谱数据文件,直接使用: {output_path}")
|
||||||
|
notify("skipped", f"采样点光谱数据已设置: {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
# 允许外部显式传入 glint_mask_path 覆盖内部默认值
|
||||||
|
glint_mask_to_use = glint_mask_path
|
||||||
|
if glint_mask_to_use is None:
|
||||||
|
print("未检测到耀斑掩膜,将在采样点生成时不做耀斑区域剔除。")
|
||||||
|
|
||||||
|
get_spectral_sampling_points_chunked(
|
||||||
|
deglint_img_path, water_mask_path, glint_mask_to_use,
|
||||||
|
output_path, interval, sample_radius, chunk_size
|
||||||
|
)
|
||||||
|
|
||||||
|
notify("completed", f"采样点光谱数据已保存: {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
# ---- Step 8: 机器学习模型预测水质参数 ----
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def predict_water_quality(
|
||||||
|
sampling_csv_path: str,
|
||||||
|
models_dir: Optional[str] = None,
|
||||||
|
metric: str = "test_r2",
|
||||||
|
prediction_column: str = "prediction",
|
||||||
|
output_dir: Union[str, Path] = "./11_12_13_predictions/Machine_Learning_Prediction",
|
||||||
|
callback: Optional[Callable] = None,
|
||||||
|
_report_generator=None,
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
"""将训练好的最佳机器学习模型应用到采样点光谱上,预测水质参数"""
|
||||||
|
from src.core.prediction.inference_batch import WaterQualityInference
|
||||||
|
|
||||||
|
def notify(status, msg=""):
|
||||||
|
if callback:
|
||||||
|
callback("步骤8", status, msg)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤8: 预测水质参数")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
if models_dir is None:
|
||||||
|
raise ValueError("必须提供 models_dir 参数")
|
||||||
|
|
||||||
|
ml_prediction_dir = Path(output_dir)
|
||||||
|
ml_prediction_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
prediction_files = {}
|
||||||
|
if ml_prediction_dir.exists():
|
||||||
|
csv_files = list(ml_prediction_dir.glob("*.csv"))
|
||||||
|
for csv_file in csv_files:
|
||||||
|
file_stem = csv_file.stem
|
||||||
|
if "_prediction" in file_stem:
|
||||||
|
target_name = file_stem.replace("_prediction", "")
|
||||||
|
elif "_pred" in file_stem:
|
||||||
|
target_name = file_stem.replace("_pred", "")
|
||||||
|
else:
|
||||||
|
target_name = file_stem
|
||||||
|
prediction_files[target_name] = str(csv_file)
|
||||||
|
|
||||||
|
# 检查是否所有目标参数都有预测文件
|
||||||
|
if prediction_files:
|
||||||
|
models_path_obj = Path(models_dir)
|
||||||
|
if models_path_obj.exists():
|
||||||
|
target_folders = [d.name for d in models_path_obj.iterdir() if d.is_dir()]
|
||||||
|
missing_targets = [t for t in target_folders if t not in prediction_files]
|
||||||
|
if not missing_targets:
|
||||||
|
print(f"检测到已存在的预测结果文件,直接使用: {ml_prediction_dir}")
|
||||||
|
notify("skipped", f"预测结果已设置: {ml_prediction_dir}")
|
||||||
|
return prediction_files
|
||||||
|
else:
|
||||||
|
print(f"检测到部分预测结果文件,缺少: {missing_targets},将继续生成...")
|
||||||
|
|
||||||
|
inferencer = WaterQualityInference(models_dir)
|
||||||
|
all_results = inferencer.batch_inference_multi_models(
|
||||||
|
models_root_dir=models_dir,
|
||||||
|
sampling_csv_path=sampling_csv_path,
|
||||||
|
output_dir=str(ml_prediction_dir),
|
||||||
|
metric=metric,
|
||||||
|
prediction_column=prediction_column,
|
||||||
|
output_format="csv",
|
||||||
|
)
|
||||||
|
|
||||||
|
for target_name, result in all_results.items():
|
||||||
|
if result.get("status") == "success":
|
||||||
|
prediction_files[target_name] = result["output_file"]
|
||||||
|
|
||||||
|
print(f"预测完成,结果保存在: {ml_prediction_dir}")
|
||||||
|
|
||||||
|
if _report_generator is not None:
|
||||||
|
try:
|
||||||
|
report_path = _report_generator.generate_prediction_report(prediction_files)
|
||||||
|
print(f"预测结果报告已生成: {report_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"生成预测结果报告时出错: {e}")
|
||||||
|
|
||||||
|
notify("completed", f"预测完成: {ml_prediction_dir}")
|
||||||
|
return prediction_files
|
||||||
|
|
||||||
|
# ---- Step 8.5: 非经验模型预测 ----
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def predict_with_non_empirical_models(
|
||||||
|
sampling_csv_path: str,
|
||||||
|
non_empirical_models_dir: Optional[str] = None,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
metric: str = "Average Accuracy(%)",
|
||||||
|
prediction_column: str = "prediction",
|
||||||
|
enabled: bool = True,
|
||||||
|
callback: Optional[Callable] = None,
|
||||||
|
work_dir: Union[str, Path] = "./work_dir",
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
"""使用非经验统计回归模型进行参数预测"""
|
||||||
|
def notify(status, msg=""):
|
||||||
|
if callback:
|
||||||
|
callback("步骤8.5", status, msg)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤8.5: 使用非经验模型进行参数预测")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
if not enabled:
|
||||||
|
print("已设置跳过非经验模型预测(enabled=False)。")
|
||||||
|
notify("skipped", "跳过非经验模型预测")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if non_empirical_models_dir is not None:
|
||||||
|
final_models_dir = non_empirical_models_dir
|
||||||
|
else:
|
||||||
|
default_models_dir = str(Path(work_dir) / "8_Regression_Modeling")
|
||||||
|
if Path(default_models_dir).exists():
|
||||||
|
final_models_dir = default_models_dir
|
||||||
|
else:
|
||||||
|
raise ValueError("请先执行步骤6.5: 非经验模型训练,或提供 non_empirical_models_dir 参数")
|
||||||
|
|
||||||
|
if output_dir is not None:
|
||||||
|
non_empirical_prediction_dir = Path(output_dir)
|
||||||
|
else:
|
||||||
|
non_empirical_prediction_dir = Path(work_dir) / "11_12_13_predictions" / "Non_Empirical_Prediction"
|
||||||
|
non_empirical_prediction_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
prediction_files = {}
|
||||||
|
summary_path = Path(final_models_dir) / "non_empirical_models_summary.csv"
|
||||||
|
if not summary_path.exists():
|
||||||
|
raise ValueError(f"未找到非经验模型汇总文件: {summary_path}")
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
df_summary = pd.read_csv(summary_path)
|
||||||
|
|
||||||
|
best_models = {}
|
||||||
|
for algorithm in df_summary["Algorithm Name"].unique():
|
||||||
|
algorithm_df = df_summary[df_summary["Algorithm Name"] == algorithm]
|
||||||
|
if metric in algorithm_df.columns:
|
||||||
|
best_model_row = algorithm_df.nlargest(1, metric)
|
||||||
|
else:
|
||||||
|
best_model_row = algorithm_df.iloc[[0]]
|
||||||
|
|
||||||
|
best_model_path = best_model_row["Model File"].values[0]
|
||||||
|
best_preprocess = best_model_row["Preprocessing Method"].values[0]
|
||||||
|
best_accuracy = best_model_row[metric].values[0] if metric in best_model_row.columns else "N/A"
|
||||||
|
|
||||||
|
best_models[algorithm] = {
|
||||||
|
"model_path": best_model_path,
|
||||||
|
"preprocess_method": best_preprocess,
|
||||||
|
"accuracy": best_accuracy,
|
||||||
|
}
|
||||||
|
print(f"算法 {algorithm}: 选择 {best_preprocess} (准确率: {best_accuracy})")
|
||||||
|
|
||||||
|
pd.read_csv(sampling_csv_path) # just to validate
|
||||||
|
|
||||||
|
for algorithm, model_info in best_models.items():
|
||||||
|
print(f"\n使用 {algorithm} 算法进行预测...")
|
||||||
|
output_path = str(non_empirical_prediction_dir / f"non_empirical_{algorithm}_{prediction_column}.csv")
|
||||||
|
|
||||||
|
if Path(output_path).exists():
|
||||||
|
print(f"检测到已存在的预测结果文件,直接使用: {output_path}")
|
||||||
|
prediction_files[algorithm] = output_path
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
from src.core.non_empirical_retrieval import non_empirical_retrieval
|
||||||
|
non_empirical_retrieval(
|
||||||
|
algorithm=algorithm,
|
||||||
|
model_info_path=model_info["model_path"],
|
||||||
|
coor_spectral_path=sampling_csv_path,
|
||||||
|
output_path=output_path,
|
||||||
|
wave_radius=5,
|
||||||
|
)
|
||||||
|
prediction_files[algorithm] = output_path
|
||||||
|
print(f"预测完成: {output_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"使用 {algorithm} 算法预测时出错: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
notify("completed", f"非经验模型预测完成: {non_empirical_prediction_dir}")
|
||||||
|
return prediction_files
|
||||||
|
|
||||||
|
# ---- Step 8.75: 自定义回归模型预测 ----
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def predict_with_custom_regression(
|
||||||
|
sampling_csv_path: str,
|
||||||
|
custom_regression_dir: Optional[str] = None,
|
||||||
|
formula_csv_path: Optional[str] = None,
|
||||||
|
coordinate_columns: Optional[List[str]] = None,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
filename_prefix: str = "custom_regression_prediction",
|
||||||
|
enabled: bool = True,
|
||||||
|
callback: Optional[Callable] = None,
|
||||||
|
work_dir: Union[str, Path] = "./work_dir",
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
"""使用自定义回归模型进行参数预测"""
|
||||||
|
def notify(status, msg=""):
|
||||||
|
if callback:
|
||||||
|
callback("步骤8.75", status, msg)
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤8.75: 使用自定义回归模型进行参数预测")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
if not enabled:
|
||||||
|
print("已设置跳过自定义回归模型预测(enabled=False)。")
|
||||||
|
notify("skipped", "跳过自定义回归预测")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if not Path(sampling_csv_path).exists():
|
||||||
|
raise FileNotFoundError(f"采样点CSV文件不存在: {sampling_csv_path}")
|
||||||
|
|
||||||
|
if custom_regression_dir is not None:
|
||||||
|
final_regression_dir = custom_regression_dir
|
||||||
|
else:
|
||||||
|
final_regression_dir = str(Path(work_dir) / "9_Custom_Regression_Modeling")
|
||||||
|
if not Path(final_regression_dir).exists():
|
||||||
|
raise ValueError(
|
||||||
|
"请先执行步骤6.75: 自定义回归分析,或提供 custom_regression_dir 参数"
|
||||||
|
)
|
||||||
|
|
||||||
|
if output_dir is None:
|
||||||
|
custom_regression_prediction_dir = Path(work_dir) / "11_12_13_predictions" / "Custom_Regression_Prediction"
|
||||||
|
custom_regression_prediction_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
prediction_output_dir = str(custom_regression_prediction_dir)
|
||||||
|
else:
|
||||||
|
prediction_output_dir = output_dir
|
||||||
|
|
||||||
|
from src.core.prediction.custom_regression_prediction import CustomRegressionPredictor
|
||||||
|
|
||||||
|
predictor = CustomRegressionPredictor(
|
||||||
|
regression_csv_dir=final_regression_dir,
|
||||||
|
formula_csv_path=formula_csv_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"开始使用自定义回归模块进行批量预测...")
|
||||||
|
print(f" 采样点数据: {sampling_csv_path}")
|
||||||
|
print(f" 回归模型目录: {final_regression_dir}")
|
||||||
|
print(f" 输出目录: {prediction_output_dir}")
|
||||||
|
|
||||||
|
saved_files = predictor.run_batch_prediction(
|
||||||
|
input_csv_path=sampling_csv_path,
|
||||||
|
coordinate_columns=coordinate_columns,
|
||||||
|
filename_prefix=filename_prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"自定义回归预测完成,生成 {len(saved_files)} 个预测文件:")
|
||||||
|
for param_name, filepath in saved_files.items():
|
||||||
|
print(f" {param_name}: {filepath}")
|
||||||
|
|
||||||
|
notify("completed", f"自定义回归预测完成: {len(saved_files)} 个文件")
|
||||||
|
return saved_files
|
||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user