Compare commits
11 Commits
605ec86108
...
5a55be286f
| Author | SHA1 | Date | |
|---|---|---|---|
| | 5a55be286f | | |
| | 9ba39a7bff | | |
| | d15a7a1e2b | | |
| | 6d4d802ffe | | |
| | abac272b31 | | |
| | 95d30d8d81 | | |
| | 375fea77b9 | | |
| | 8c7c995985 | | |
| | f96c55f361 | | |
| | 14278739bf | | |
| | d0eb458392 | | |
20
src/core/steps/__init__.py
Normal file
20
src/core/steps/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""业务步骤层模块"""
|
||||
|
||||
from src.core.steps.water_mask_step import WaterMaskStep
|
||||
from src.core.steps.glint_detection_step import GlintDetectionStep
|
||||
from src.core.steps.glint_removal_step import GlintRemovalStep
|
||||
from src.core.steps.data_preparation_step import DataPreparationStep
|
||||
from src.core.steps.modeling_step import ModelingStep
|
||||
from src.core.steps.prediction_step import PredictionStep
|
||||
from src.core.steps.mapping_step import MappingStep
|
||||
|
||||
__all__ = [
|
||||
"WaterMaskStep",
|
||||
"GlintDetectionStep",
|
||||
"GlintRemovalStep",
|
||||
"DataPreparationStep",
|
||||
"ModelingStep",
|
||||
"PredictionStep",
|
||||
"MappingStep",
|
||||
]
|
||||
184
src/core/steps/data_preparation_step.py
Normal file
184
src/core/steps/data_preparation_step.py
Normal file
@ -0,0 +1,184 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据准备步骤
|
||||
|
||||
包含 step4_process_csv, step5_extract_training_spectra, step5_5_calculate_water_quality_indices
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union, Callable, Dict
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DataPreparationStep:
    """Data-preparation stage of the pipeline (steps 4, 5 and 5.5).

    Each public method is a static, resumable step: it creates its output
    directory, prints a banner, and returns immediately when its output file
    already exists. Project modules that do the real work are imported only
    after that check, so a resumed run does not import them.

    Progress is reported through an optional ``callback(step, status, msg)``
    where ``status`` is ``"skipped"`` or ``"completed"``.
    """

    @staticmethod
    def _notifier(step_label, callback):
        """Return a ``notify(status, msg="")`` function bound to *step_label*."""

        def notify(status, msg=""):
            if callback:
                callback(step_label, status, msg)

        return notify

    # ---- Step 4: clean the measured-values CSV ----

    @staticmethod
    def process_csv(
        csv_path: str,
        output_dir: Union[str, Path] = "./4_processed_data",
        callback: Optional[Callable] = None,
    ) -> str:
        """Filter outliers from the water-quality CSV.

        Args:
            csv_path: Path of the raw CSV passed to
                ``process_water_quality_data``.
            output_dir: Directory that receives ``processed_data.csv``.
            callback: Optional progress callback ``(step, status, msg)``.

        Returns:
            Path of ``processed_data.csv`` (existing or newly written).
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = str(output_dir / "processed_data.csv")
        notify = DataPreparationStep._notifier("步骤4", callback)

        print("\n" + "=" * 80)
        print("步骤4: 处理CSV文件,筛选剔除异常值")
        print("=" * 80)

        if Path(output_path).exists():
            print(f"检测到已存在的处理后CSV文件,直接使用: {output_path}")
            notify("skipped", f"处理后的CSV文件已设置: {output_path}")
            return output_path

        # Imported only when the step actually has to run.
        from src.preprocessing.process_water_quality_data import process_water_quality_data

        process_water_quality_data(csv_path, output_path)
        notify("completed", f"处理后的CSV文件已保存: {output_path}")
        return output_path

    # ---- Step 5: extract spectra at the sampling points ----

    @staticmethod
    def extract_training_spectra(
        deglint_img_path: Optional[str] = None,
        radius: int = 5,
        source_epsg: int = 4326,
        csv_path: Optional[str] = None,
        boundary_path: Optional[str] = None,
        glint_mask_path: Optional[str] = None,
        water_mask_path: Optional[str] = None,
        output_dir: Union[str, Path] = "./5_training_spectra",
        callback: Optional[Callable] = None,
    ) -> str:
        """Extract spectra from the deglinted image at the sampling points.

        Args:
            deglint_img_path: Deglinted image to sample. Required.
            radius: Forwarded to ``get_spectral_in_coor`` as ``radius``.
            source_epsg: Forwarded to ``get_spectral_in_coor`` as
                ``source_epsg``.
            csv_path: CSV with the sampling points. Required.
            boundary_path: Mask/boundary file; takes precedence over
                ``water_mask_path``.
            glint_mask_path: Optional glint mask, forwarded as ``flare_path``.
            water_mask_path: Used as the boundary when ``boundary_path`` is
                ``None``.
            output_dir: Directory that receives ``training_spectra.csv``.
            callback: Optional progress callback ``(step, status, msg)``.

        Returns:
            Path of ``training_spectra.csv`` (existing or newly written).

        Raises:
            ValueError: If ``deglint_img_path`` or ``csv_path`` is ``None``.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = str(output_dir / "training_spectra.csv")
        notify = DataPreparationStep._notifier("步骤5", callback)

        print("\n" + "=" * 80)
        print("步骤5: 提取训练样本点的平均光谱")
        print("=" * 80)

        if deglint_img_path is None:
            raise ValueError("必须提供 deglint_img_path 参数")
        if csv_path is None:
            raise ValueError("必须提供 csv_path 参数")

        if Path(output_path).exists():
            print(f"检测到已存在的训练光谱数据文件,直接使用: {output_path}")
            notify("skipped", f"训练光谱数据已设置: {output_path}")
            return output_path

        # An explicit boundary wins; otherwise fall back to the water mask.
        mask_path = boundary_path
        if mask_path is None and water_mask_path is not None:
            mask_path = water_mask_path

        # A vector .shp mask is swapped for the raster written by step 1.
        # That raster is looked up next to the deglinted image first, then
        # next to this step's output directory.
        if mask_path and str(mask_path).lower().endswith(".shp"):
            candidates = (
                Path(deglint_img_path).parent.parent / "1_water_mask" / "water_mask_from_shp.dat",
                Path(output_path).parent.parent / "1_water_mask" / "water_mask_from_shp.dat",
            )
            raster = next((c for c in candidates if c.exists()), None)
            if raster is not None:
                print(f"💡 智能拦截:检测到输入掩膜为矢量 .shp,自动切换为已生成的栅格掩膜: {raster}")
                mask_path = str(raster)
            else:
                print("⚠️ 警告:检测到输入掩膜为 .shp 且未找到对应 .dat 替身,可能导致底层读取失败。")

        if glint_mask_path:
            print(f"光谱提取使用耀斑掩膜: {glint_mask_path}")

        # Imported only when the step actually has to run.
        from src.core.glint_removal.get_spectral import get_spectral_in_coor

        get_spectral_in_coor(
            deglint_img_path,
            csv_path,
            output_path,
            radius=radius,
            flare_path=glint_mask_path,
            boundary_path=mask_path,
            source_epsg=source_epsg,
        )

        notify("completed", f"训练光谱数据已保存: {output_path}")
        return output_path

    # ---- Step 5.5: compute spectral water-quality indices ----

    @staticmethod
    def calculate_water_quality_indices(
        training_spectra_path: Optional[str] = None,
        formula_csv_file: Optional[str] = None,
        formula_names: Optional[List[str]] = None,
        output_file: Optional[str] = None,
        enabled: bool = True,
        output_dir: Union[str, Path] = "./6_water_quality_indices",
        callback: Optional[Callable] = None,
    ) -> Optional[str]:
        """Compute spectral indices from the training spectra via band math.

        Args:
            training_spectra_path: Spectra CSV given to ``BandMathCalculator``.
                Required when ``enabled``.
            formula_csv_file: CSV listing the formulas. Required when
                ``enabled``.
            formula_names: Forwarded to ``process_formulas_from_csv``.
            output_file: Explicit output path; defaults to
                ``<output_dir>/water_quality_indices.csv``.
            enabled: When ``False`` the step is skipped and ``None`` returned.
            output_dir: Default output directory.
            callback: Optional progress callback ``(step, status, msg)``.

        Returns:
            Path of the index CSV, or ``None`` when the step is disabled.

        Raises:
            ValueError: If a required path is missing, or if
                ``process_formulas_from_csv`` returns ``None``.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        notify = DataPreparationStep._notifier("步骤5.5", callback)

        print("\n" + "=" * 80)
        print("步骤5.5: 计算水质光谱指数(使用band_math方法)")
        print("=" * 80)

        if not enabled:
            print("已设置跳过水质指数计算(enabled=False)。")
            notify("skipped", "跳过水质指数计算")
            return None

        if training_spectra_path is None:
            raise ValueError("必须提供 training_spectra_path 参数")
        if formula_csv_file is None:
            raise ValueError("必须提供 formula_csv_file 参数")

        if output_file:
            output_path = str(Path(output_file))
        else:
            output_path = str(output_dir / "water_quality_indices.csv")

        if Path(output_path).exists():
            print(f"检测到已存在的水质指数文件,直接使用: {output_path}")
            notify("skipped", f"水质指数数据已设置: {output_path}")
            return output_path

        # An explicit output_file may live outside output_dir; make sure its
        # directory exists before the calculator writes to it.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)

        from src.utils.band_math import BandMathCalculator

        calculator = BandMathCalculator(training_spectra_path)
        result_df = calculator.process_formulas_from_csv(
            formula_csv_file=formula_csv_file,
            formula_names=formula_names,
            output_file=output_path,
        )
        if result_df is None:
            raise ValueError("计算水质指数失败,请检查公式CSV文件格式")

        notify("completed", f"水质指数已保存: {output_path}")
        return output_path
|
||||
113
src/core/steps/glint_detection_step.py
Normal file
113
src/core/steps/glint_detection_step.py
Normal file
@ -0,0 +1,113 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
步骤2: 耀斑区域检测
|
||||
|
||||
支持多种检测方法: otsu, zscore, percentile, iqr, adaptive, multi_band
|
||||
"""
|
||||
|
||||
import time
from pathlib import Path
from typing import Callable, List, Optional, Union
|
||||
|
||||
|
||||
class GlintDetectionStep:
    """Step 2 of the pipeline: locate sun-glint areas in the image."""

    @staticmethod
    def run(
        img_path: str,
        glint_wave: float = 750.0,
        method: str = "otsu",
        z_threshold: float = 2.5,
        percentile: float = 95.0,
        iqr_multiplier: float = 1.5,
        window_size: int = 15,
        multi_band_waves: Optional[List[float]] = None,
        sub_method: str = "zscore",
        weights: Optional[List[float]] = None,
        max_area: Optional[int] = None,
        buffer_size: Optional[int] = None,
        water_mask_path: Optional[str] = None,
        glint_dir: Union[str, Path] = "./2_glint",
        callback: Optional[Callable] = None,
    ) -> str:
        """Run glint detection and return the glint-mask path.

        The step is resumable: when ``<glint_dir>/severe_glint_area.dat``
        already exists it is returned as-is and nothing is recomputed.

        Args:
            img_path: Input image path.
            glint_wave: Wavelength (nm) of the band used for detection.
            method: Detection method: ``'otsu'``, ``'zscore'``,
                ``'percentile'``, ``'iqr'``, ``'adaptive'`` or ``'multi_band'``.
            z_threshold: Threshold for the z-score method.
            percentile: Threshold for the percentile method.
            iqr_multiplier: Multiplier for the IQR method.
            window_size: Window size for the adaptive method.
            multi_band_waves: Wavelengths for ``'multi_band'``, e.g.
                ``[750, 800, 850]``. Only forwarded for that method.
            sub_method: Sub-method for ``'multi_band'``. Only forwarded for
                that method.
            weights: Weights for ``'multi_band'``; ``None`` means they are not
                forwarded. Only forwarded for that method.
            max_area: Maximum connected-component area in pixels; forwarded
                only when given.
            buffer_size: Shoreline buffer in pixels; forwarded only when given.
            water_mask_path: Water-mask file. Used only when the file exists;
                otherwise detection runs without a mask.
            glint_dir: Working directory for the output mask.
            callback: Optional progress callback ``(step, status, msg)``.

        Returns:
            Path of the glint mask (``.dat``).
        """
        glint_dir = Path(glint_dir)
        glint_dir.mkdir(parents=True, exist_ok=True)

        def notify(status, msg=""):
            if callback:
                callback("步骤2", status, msg)

        print("\n" + "=" * 80)
        print("步骤2: 找到耀斑区域")
        print("=" * 80)

        output_path = str(glint_dir / "severe_glint_area.dat")

        # Resume support: an existing mask short-circuits the whole step.
        if Path(output_path).exists():
            print(f"检测到已存在的耀斑掩膜文件,直接使用: {output_path}")
            notify("skipped", f"耀斑掩膜已设置: {output_path}")
            return output_path

        # A mask path that does not point at a file is ignored, but say so
        # instead of silently running without a mask.
        final_water_mask_path = None
        if water_mask_path is not None:
            if Path(water_mask_path).exists():
                final_water_mask_path = water_mask_path
            else:
                print(f"警告: 水域掩膜文件不存在,将不使用掩膜: {water_mask_path}")

        detect_kwargs = {
            "method": method,
            "z_threshold": z_threshold,
            "percentile": percentile,
            "iqr_multiplier": iqr_multiplier,
            "window_size": window_size,
        }
        if method == "multi_band":
            multi_band_options = {
                "multi_band_waves": multi_band_waves,
                "sub_method": sub_method,
                "weights": weights,
            }
            detect_kwargs.update(
                {k: v for k, v in multi_band_options.items() if v is not None}
            )
        if max_area is not None:
            detect_kwargs["max_area"] = max_area
        if buffer_size is not None:
            detect_kwargs["buffer_size"] = buffer_size

        # Imported only when the step actually has to run.
        from src.utils.find_severe_glint_area import find_severe_glint_area

        glint_mask_path = find_severe_glint_area(
            img_path, final_water_mask_path, glint_wave, output_path, **detect_kwargs
        )

        print(f"耀斑掩膜已生成: {glint_mask_path}")
        print(f"使用检测方法: {method}")
        notify("completed", f"耀斑掩膜已生成: {glint_mask_path}")
        return glint_mask_path
|
||||
375
src/core/steps/glint_removal_step.py
Normal file
375
src/core/steps/glint_removal_step.py
Normal file
@ -0,0 +1,375 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
步骤3: 去除耀斑
|
||||
|
||||
支持多种方法: subtract_nir, regression_slope, oxygen_absorption, kutser, goodman, hedley, sugar
|
||||
|
||||
每种方法都会:
|
||||
1. 准备水域掩膜(支持 shp 自动转 dat)
|
||||
2. 调用对应的算法类执行处理
|
||||
3. 复制 hdr 文件到输出影像
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _safe_rename(src_bsq: str, src_hdr: str, dest_bsq: str, dest_hdr: str) -> str:
    """Move a ``.bsq`` + ``.hdr`` pair onto the requested destination paths.

    Each file is moved with ``os.replace``, which overwrites an existing
    destination in one step on every platform. A destination is only touched
    when its source exists, so a missing source can never delete an existing
    result. The destination directory is created when needed.

    Both paths of a pair are expected to live on the same filesystem;
    ``os.replace`` raises ``OSError`` otherwise.

    Args:
        src_bsq: Data file written by the algorithm.
        src_hdr: Header file belonging to ``src_bsq``.
        dest_bsq: Requested location of the data file.
        dest_hdr: Requested location of the header file.

    Returns:
        ``dest_bsq`` unchanged.
    """
    moves = (
        (Path(src_bsq), Path(dest_bsq)),
        (Path(src_hdr), Path(dest_hdr)),
    )

    # Source and destination already are the same file: nothing to move.
    if moves[0][0].resolve() == moves[0][1].resolve():
        return dest_bsq

    for source, target in moves:
        if not source.exists():
            # Leave any existing target alone rather than deleting it.
            continue
        target.parent.mkdir(parents=True, exist_ok=True)
        os.replace(source, target)

    return dest_bsq
|
||||
|
||||
|
||||
class GlintRemovalStep:
    """Step 3 of the pipeline: remove sun glint from the image.

    Supported methods are Kutser, Goodman, Hedley and SUGAR. Each algorithm
    writes ``<deglint_dir>/deglint_<method>.bsq``; the result is then moved to
    the caller's ``output_path`` when one is given.
    """

    # Canonical method keys, in the order they are matched against user input.
    _SUPPORTED_METHODS = ("kutser", "goodman", "hedley", "sugar")
    # Names used in the "algorithm produced no output" error messages.
    _METHOD_LABELS = {
        "kutser": "Kutser",
        "goodman": "Goodman",
        "hedley": "Hedley",
        "sugar": "SUGAR",
    }
    # Raster extensions that are normalised to the ``.bsq`` data file.
    _RASTER_SUFFIXES = (".dat", ".tif", ".tiff")

    @staticmethod
    def _normalize_method(method):
        """Map a free-form method label to its canonical key, if it has one."""
        raw = str(method).lower()
        for key in GlintRemovalStep._SUPPORTED_METHODS:
            if key in raw:
                return key
        # Unknown labels are returned untouched so the error can quote them.
        return method

    @staticmethod
    def _normalize_sugar_mask_method(name):
        """Map a SUGAR glint-mask label to ``'cdf'`` or ``'otsu'`` (default ``'cdf'``)."""
        raw = str(name).lower()
        if "cdf" in raw or "累积" in raw:
            return "cdf"
        if "otsu" in raw or "大津" in raw:
            return "otsu"
        return "cdf"

    @staticmethod
    def _resolve_outputs(default_bsq, output_path):
        """Return ``(default_bsq, default_hdr, final_bsq, final_hdr)``.

        Only the file extension is rewritten, never a substring elsewhere in
        the path, and the header path is always distinct from the data path.
        """
        default_hdr = os.path.splitext(default_bsq)[0] + ".hdr"
        if not output_path:
            return default_bsq, default_hdr, default_bsq, default_hdr

        output_path = os.fspath(output_path)
        stem, ext = os.path.splitext(output_path)
        if ext.lower() in GlintRemovalStep._RASTER_SUFFIXES:
            final_bsq = stem + ".bsq"
        else:
            final_bsq = output_path
        final_hdr = os.path.splitext(final_bsq)[0] + ".hdr"
        if final_hdr == final_bsq:
            # The header would otherwise be moved on top of the image data.
            raise ValueError(f"output_path 不能使用 .hdr 扩展名: {output_path}")
        return default_bsq, default_hdr, final_bsq, final_hdr

    @staticmethod
    def run(
        img_path: str,
        method: str = "subtract_nir",
        start_wave: Optional[float] = None,
        end_wave: Optional[float] = None,
        json_path: Optional[str] = None,
        left_shoulder_wave: Optional[float] = None,
        valley_wave: Optional[float] = None,
        right_shoulder_wave: Optional[float] = None,
        water_mask: Optional[Union[str, np.ndarray]] = None,
        interpolated_img_path: Optional[str] = None,
        interpolate_zeros: bool = False,
        interpolation_method: str = "nearest",
        enabled: bool = True,
        # Kutser parameters
        kutser_shp_path: Optional[str] = None,
        oxy_band: int = 38,
        lower_oxy: int = 36,
        upper_oxy: int = 49,
        nir_band: int = 47,
        # Goodman parameters
        nir_lower: int = 25,
        nir_upper: int = 37,
        goodman_A: float = 0.000019,
        goodman_B: float = 0.1,
        # Hedley parameters
        hedley_shp_path: Optional[str] = None,
        hedley_nir_band: int = 47,
        # SUGAR parameters
        sugar_bounds: Optional[List[tuple]] = None,
        sugar_sigma: float = 1.0,
        sugar_estimate_background: bool = True,
        sugar_glint_mask_method: str = "cdf",
        sugar_iter: Optional[int] = 3,
        sugar_termination_thresh: float = 20.0,
        # Injectable helpers
        _get_image_geo_info=None,
        _load_image_as_array=None,
        _save_bands_as_image=None,
        _copy_hdr_info=None,
        _prepare_water_mask_for_algorithm=None,
        _interpolate_zero_pixels_batch=None,
        deglint_dir: Union[str, Path] = "./3_deglint",
        water_mask_dir: Union[str, Path] = "./1_water_mask",
        callback: Optional[Callable] = None,
        output_path: Optional[str] = None,
    ) -> str:
        """Remove glint from ``img_path`` and return the resulting image path.

        The step is resumable: if the algorithm's intermediate file or the
        final output file already exists, it is reused and nothing is
        recomputed. Unsupported methods are rejected before any processing.

        Args:
            img_path: Input image path.
            method: Label containing one of ``kutser``, ``goodman``,
                ``hedley`` or ``sugar`` (case-insensitive). Any other value,
                including the default ``"subtract_nir"``, raises
                ``ValueError`` when the step is enabled.
            water_mask: Water mask as a path or array. A ``.shp`` path is
                replaced by ``<water_mask_dir>/water_mask_from_shp.dat`` when
                that file exists; that file is also the default when no mask
                is given.
            interpolate_zeros: Run zero-pixel interpolation first and process
                its result instead of ``img_path``.
            interpolation_method: Forwarded to the interpolation helper.
            enabled: When ``False`` the step is skipped and ``img_path`` is
                returned.
            oxy_band, lower_oxy, upper_oxy, nir_band: Kutser parameters.
            nir_lower, nir_upper, goodman_A, goodman_B: Goodman parameters.
            hedley_nir_band: Hedley parameter.
            sugar_bounds: SUGAR bounds; defaults to ``[(1, 2)]``.
            sugar_estimate_background, sugar_glint_mask_method, sugar_iter,
            sugar_termination_thresh: SUGAR parameters.
            _get_image_geo_info, _save_bands_as_image, _copy_hdr_info,
            _prepare_water_mask_for_algorithm, _interpolate_zero_pixels_batch:
                Optional replacements for the project helpers.
            deglint_dir: Working directory for the algorithm output.
            water_mask_dir: Directory holding the step-1 raster mask.
            callback: Optional progress callback ``(step, status, msg)``.
            output_path: Requested output file. A ``.dat``/``.tif``/``.tiff``
                extension is replaced by ``.bsq``; the header is written next
                to it with a ``.hdr`` extension.
            start_wave, end_wave, json_path, left_shoulder_wave, valley_wave,
            right_shoulder_wave, interpolated_img_path, kutser_shp_path,
            hedley_shp_path, sugar_sigma, _load_image_as_array: Accepted for
                interface compatibility; not used by this implementation.

        Returns:
            Path of the deglinted image, or ``img_path`` when disabled.

        Raises:
            ValueError: If the method is not supported, or ``output_path``
                ends in ``.hdr``.
            RuntimeError: If the algorithm did not produce its output file.
        """
        deglint_dir = Path(deglint_dir)
        deglint_dir.mkdir(parents=True, exist_ok=True)

        def notify(status, msg=""):
            if callback:
                callback("步骤3", status, msg)

        print("\n" + "=" * 80)
        print("步骤3: 去除耀斑")
        print("=" * 80)

        method = GlintRemovalStep._normalize_method(method)

        # Disabled: hand the original image straight back.
        if not enabled:
            print("已设置跳过去除耀斑(enabled=False),将直接使用原始影像。")
            notify("skipped", "跳过去耀斑,使用原始影像")
            return img_path

        # Reject unknown methods before any expensive work is done.
        if method not in GlintRemovalStep._SUPPORTED_METHODS:
            raise ValueError(
                f"不支持的方法: {method}。支持的方法: kutser, goodman, hedley, sugar"
            )

        hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr = GlintRemovalStep._resolve_outputs(
            str(deglint_dir / f"deglint_{method}.bsq"), output_path
        )

        # ---- Resume support ----
        # A leftover intermediate file is moved to its final place first;
        # otherwise an already-delivered final file is reused directly.
        if Path(hardcoded_bsq).exists():
            print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
            notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
            return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
        if Path(final_bsq).exists():
            print(f"检测到已存在的去耀斑影像文件,直接使用: {final_bsq}")
            notify("skipped", f"去耀斑影像已设置: {final_bsq}")
            return final_bsq

        # ---- Resolve the water mask ----
        raster_mask = str(Path(water_mask_dir) / "water_mask_from_shp.dat")
        final_water_mask = water_mask
        if isinstance(final_water_mask, (str, os.PathLike)) and str(
            final_water_mask
        ).lower().endswith(".shp"):
            if Path(raster_mask).exists():
                print(f"检测到输入掩膜为 .shp,自动替换为栅格掩膜: {raster_mask}")
                final_water_mask = raster_mask
        if final_water_mask is None and Path(raster_mask).exists():
            final_water_mask = raster_mask
            print(f"使用步骤1生成的水域掩膜: {final_water_mask}")

        # ---- Step 3.1: interpolate zero-valued pixels ----
        if interpolate_zeros:
            print("\n" + "-" * 80)
            print("步骤3.1: 对0值像素进行插值")
            print("-" * 80)

            if _interpolate_zero_pixels_batch is None:
                from src.core.algorithms.interpolation.interpolator import (
                    interpolate_zero_pixels_batch as _interpolate_zero_pixels_batch,
                )

            img_path, _ = _interpolate_zero_pixels_batch(
                img_path=img_path,
                interpolation_method=interpolation_method,
                output_path=None,
                water_mask=final_water_mask,
                deglint_dir=str(deglint_dir),
                callback_progress=lambda msg: print(f" {msg}"),
            )
            print(f"插值完成,使用插值后的影像: {img_path}")

        # ---- Fall back to the project helpers that were not injected ----
        if _get_image_geo_info is None:
            from src.core.utils.gdal_helper import get_image_geo_info as _get_image_geo_info
        if _save_bands_as_image is None:
            from src.core.utils.gdal_helper import save_bands_as_image as _save_bands_as_image
        if _copy_hdr_info is None:
            from src.core.utils.gdal_helper import copy_hdr_info as _copy_hdr_info
        if _prepare_water_mask_for_algorithm is None:
            from src.core.utils.mask_converter import (
                prepare_water_mask_for_algorithm as _prepare_water_mask_for_algorithm,
            )

        geotransform, projection, width, height, n_bands = _get_image_geo_info(img_path)
        print(f"影像尺寸: {width} x {height} x {n_bands}")

        mask_for_algorithm = _prepare_water_mask_for_algorithm(
            final_water_mask, (height, width), geotransform, projection, img_path
        )

        # ---- Run the selected algorithm; each one writes hardcoded_bsq ----
        if method == "kutser":
            from src.core.glint_removal.Kutser import Kutser

            print(f"使用方法: Kutser (氧吸收波段={oxy_band}, NIR波段={nir_band})")
            Kutser(
                img_path,
                shp_path=None,
                oxy_band=oxy_band,
                lower_oxy=lower_oxy,
                upper_oxy=upper_oxy,
                NIR_band=nir_band,
                water_mask=mask_for_algorithm,
                output_path=hardcoded_bsq,
            ).get_corrected_bands()
        elif method == "goodman":
            from src.core.glint_removal.Goodman import Goodman

            print(f"使用方法: Goodman (NIR波段范围: {nir_lower}-{nir_upper})")
            corrected_bands = Goodman(
                img_path,
                NIR_lower=nir_lower,
                NIR_upper=nir_upper,
                A=goodman_A,
                B=goodman_B,
                water_mask=mask_for_algorithm,
                output_path=hardcoded_bsq,
            ).get_corrected_bands()
            # Goodman may only return the bands; persist them in that case.
            if not Path(hardcoded_bsq).exists():
                _save_bands_as_image(corrected_bands, hardcoded_bsq, geotransform, projection)
            del corrected_bands
        elif method == "hedley":
            from src.core.glint_removal.Hedley import Hedley

            print(f"使用方法: Hedley (NIR波段={hedley_nir_band})")
            Hedley(
                img_path,
                shp_path=None,
                NIR_band=hedley_nir_band,
                water_mask=mask_for_algorithm,
                output_path=hardcoded_bsq,
            ).get_corrected_bands()
        else:  # "sugar"
            from src.core.glint_removal.SUGAR import correction_iterative

            sugar_glint_mask_method_fixed = GlintRemovalStep._normalize_sugar_mask_method(
                sugar_glint_mask_method
            )
            print(
                f"使用方法: SUGAR (迭代次数={sugar_iter}, 掩膜方法={sugar_glint_mask_method_fixed})"
            )
            if sugar_bounds is None:
                sugar_bounds = [(1, 2)]
            correction_iterative(
                img_path,
                iter=sugar_iter,
                bounds=sugar_bounds,
                estimate_background=sugar_estimate_background,
                glint_mask_method=sugar_glint_mask_method_fixed,
                termination_thresh=sugar_termination_thresh,
                water_mask=mask_for_algorithm,
                output_path=hardcoded_bsq,
            )

        if not Path(hardcoded_bsq).exists():
            label = GlintRemovalStep._METHOD_LABELS[method]
            raise RuntimeError(f"{label}算法未生成输出文件: {hardcoded_bsq}")

        # Copy the header from the processed input, then deliver the pair.
        _copy_hdr_info(img_path, hardcoded_bsq)
        final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
        notify("completed", f"去耀斑影像已生成: {final}")
        return final
|
||||
109
src/core/steps/mapping_step.py
Normal file
109
src/core/steps/mapping_step.py
Normal file
@ -0,0 +1,109 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
成图步骤
|
||||
|
||||
包含 step9_generate_distribution_map
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Callable
|
||||
|
||||
|
||||
class MappingStep:
    """Step 9 of the pipeline: render the water-quality distribution map."""

    @staticmethod
    def generate_distribution_map(
        prediction_csv_path: str,
        boundary_shp_path: str,
        output_image_path: Optional[str] = None,
        resolution: float = 30,
        input_crs: str = "EPSG:32651",
        output_crs: str = "EPSG:4326",
        show_sample_points: bool = False,
        base_map_tif: Optional[str] = None,
        use_distance_diffusion: bool = True,
        max_diffusion_distance: Optional[float] = None,
        diffusion_power: float = 2,
        diffusion_n_neighbors: int = 15,
        cmap: Optional[str] = None,
        expand_ratio: float = 0.05,
        output_dir: Union[str, Path] = "./14_visualization",
        callback: Optional[Callable] = None,
    ) -> str:
        """Interpolate predicted values into a distribution-map image.

        The step is resumable: an existing output image is returned as-is.

        Args:
            prediction_csv_path: Prediction CSV (first two columns are the
                coordinates, the third the predicted value).
            boundary_shp_path: Boundary shapefile.
            output_image_path: Output image; defaults to
                ``<output_dir>/<csv stem>_distribution.png``.
            resolution: Interpolation grid resolution in metres.
            input_crs: CRS of the input coordinates.
            output_crs: CRS of the rendered map.
            show_sample_points: Whether to draw the sampling points.
            base_map_tif: Optional base-map TIF; forwarded only when given.
            use_distance_diffusion: Whether to fill the boundary by distance
                diffusion.
            max_diffusion_distance: Maximum diffusion distance in metres;
                forwarded only when given.
            diffusion_power: Power parameter of the distance diffusion.
            diffusion_n_neighbors: Number of nearest neighbours used.
            cmap: Colour-map name; forwarded only when given.
            expand_ratio: Boundary expansion ratio (between 0 and 1).
            output_dir: Default output directory.
            callback: Optional progress callback ``(step, status, msg)``.

        Returns:
            Path of the distribution-map image.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        def notify(status, msg=""):
            if callback:
                callback("步骤9", status, msg)

        print("\n" + "=" * 80)
        print("步骤9: 生成水质参数可视化分布图")
        print("=" * 80)

        if output_image_path is None:
            csv_name = Path(prediction_csv_path).stem
            output_image_path = str(output_dir / f"{csv_name}_distribution.png")

        if Path(output_image_path).exists():
            print(f"检测到已存在的分布图文件,直接使用: {output_image_path}")
            notify("skipped", f"可视化分布图已设置: {output_image_path}")
            return output_image_path

        # An explicit output path may live outside output_dir.
        Path(output_image_path).parent.mkdir(parents=True, exist_ok=True)

        # Imported only when the map actually has to be rendered.
        from src.postprocessing.map import ContentMapper

        mapper_kwargs = {
            "resolution": resolution,
            "show_sample_points": show_sample_points,
            "use_distance_diffusion": use_distance_diffusion,
            "diffusion_power": diffusion_power,
            "diffusion_n_neighbors": diffusion_n_neighbors,
            "expand_ratio": expand_ratio,
        }
        # Optional settings are forwarded only when the caller supplied them.
        optional_kwargs = {
            "base_map_tif": base_map_tif,
            "max_diffusion_distance": max_diffusion_distance,
            "cmap": cmap,
        }
        mapper_kwargs.update({k: v for k, v in optional_kwargs.items() if v is not None})

        mapper = ContentMapper(input_crs=input_crs, output_crs=output_crs)
        mapper.process_data(
            csv_file=prediction_csv_path,
            shp_file=boundary_shp_path,
            output_file=output_image_path,
            **mapper_kwargs,
        )

        notify("completed", f"可视化分布图已保存: {output_image_path}")
        return output_image_path
|
||||
497
src/core/steps/modeling_step.py
Normal file
497
src/core/steps/modeling_step.py
Normal file
@ -0,0 +1,497 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
建模步骤
|
||||
|
||||
包含 step6_train_models, step6_5_non_empirical_modeling, step6_75_custom_regression
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union, Callable, Dict
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 汉化 -> 英文 反向映射字典(UI 复选框显示文本 -> 底层算法键名)
|
||||
# ============================================================
|
||||
|
||||
# 模型名称:中文 (缩写) -> 英文键名
|
||||
# Model names: localized UI label -> key understood by the modeling backend.
MODEL_NAME_MAP = {
    # Linear models
    "多元线性回归 (MLR)": "LinearRegression",
    "岭回归 (Ridge)": "Ridge",
    "套索回归 (Lasso)": "Lasso",
    "弹性网络 (ElasticNet)": "ElasticNet",
    "偏最小二乘 (PLSR)": "PLS",
    # Tree models and ensembles
    "决策树 (CART)": "DecisionTree",
    "随机森林 (RF)": "RF",
    "极端随机树 (ET)": "ExtraTrees",
    "极值梯度提升 (XGBoost)": "XGBoost",
    "轻量梯度提升 (LightGBM)": "LightGBM",
    "类别梯度提升 (CatBoost)": "CatBoost",
    "梯度提升树 (GBDT)": "GradientBoosting",
    "自适应提升 (AdaBoost)": "AdaBoost",
    # Kernel, neighbour and neural models
    "支持向量回归 (SVR)": "SVR",
    "K近邻回归 (KNN)": "KNN",
    "多层感知机 (BP神经网络)": "MLP",
}

# Preprocessing methods as (localized UI label, canonical key) pairs.
_PREPROC_LABELS = (
    ("无 (None)", "None"),
    ("最小-最大归一化 (MMS)", "MMS"),
    ("标度化 (SS)", "SS"),
    ("标准正态变换 (SNV)", "SNV"),
    ("移动平均 (MA)", "MA"),
    ("Savitzky-Golay (SG)", "SG"),
    ("多元散射校正 (MSC)", "MSC"),
    ("一阶导数 (D1)", "D1"),
    ("二阶导数 (D2)", "D2"),
    ("去趋势 (DT)", "DT"),
    ("中心化 (CT)", "CT"),
)

# Both the UI label and the canonical key itself resolve to the canonical key.
PREPROC_NAME_MAP = {
    name: key for label, key in _PREPROC_LABELS for name in (label, key)
}

# Data-split methods as (localized UI label, canonical key) pairs.
_SPLIT_LABELS = (
    ("SPXY 算法 (考量X-Y空间)", "spxy"),
    ("KS 算法 (考量X空间)", "ks"),
    ("随机划分 (Random)", "random"),
)

# Both the UI label and the canonical key itself resolve to the canonical key.
SPLIT_NAME_MAP = {
    name: key for label, key in _SPLIT_LABELS for name in (label, key)
}
|
||||
|
||||
|
||||
def _normalize_model_names(model_names: List[str]) -> List[str]:
    """Translate UI model labels into backend model keys.

    Labels found in ``MODEL_NAME_MAP`` are replaced by their mapped key; any
    other entry is passed through unchanged. Order is preserved.
    """
    return [MODEL_NAME_MAP.get(label, label) for label in model_names]
|
||||
|
||||
|
||||
def _normalize_preprocessing_methods(methods: List[str]) -> List[str]:
    """Translate UI preprocessing labels into canonical method keys.

    Labels found in ``PREPROC_NAME_MAP`` are replaced by their mapped key; any
    other entry is passed through unchanged. Order is preserved.
    """
    return [PREPROC_NAME_MAP.get(label, label) for label in methods]
|
||||
|
||||
|
||||
def _normalize_split_methods(methods: List[str]) -> List[str]:
    """Translate UI data-split labels into canonical method keys.

    Labels found in ``SPLIT_NAME_MAP`` are replaced by their mapped key; any
    other entry is passed through unchanged. Order is preserved.
    """
    return [SPLIT_NAME_MAP.get(label, label) for label in methods]
|
||||
|
||||
|
||||
class ModelingStep:
    """Modeling stage of the pipeline.

    Groups three independent static entry points: supervised machine-learning
    training (step 6), non-empirical statistical regression training (step 6.5)
    and custom single-variable regression analysis (step 6.75).
    """

    # ---- Step 6: train machine-learning models ----

    @staticmethod
    def _has_saved_models(models_root: Path) -> bool:
        """Return True if any direct sub-directory holds a .pkl/.joblib/.h5 file."""
        return any(
            any(sub_dir.glob(pattern))
            for sub_dir in models_root.iterdir()
            if sub_dir.is_dir()
            for pattern in ("*.pkl", "*.joblib", "*.h5")
        )

    @staticmethod
    def train_models(
        feature_start_column: str = "374.285004",
        preprocessing_methods: Optional[List[str]] = None,
        model_names: Optional[List[str]] = None,
        split_methods: Optional[List[str]] = None,
        cv_folds: int = 5,
        training_csv_path: Optional[str] = None,
        output_dir: Union[str, Path] = "./7_Supervised_Model_Training",
        callback: Optional[Callable] = None,
        _report_generator=None,
    ) -> str:
        """Train machine-learning models from sampled spectra and measured values.

        Args:
            feature_start_column: Name of the first feature column, forwarded
                to the batch trainer.
            preprocessing_methods: Preprocessing names (keys or UI labels);
                defaults to the eleven built-in methods.
            model_names: Model names (keys or UI labels); defaults to
                SVR, RF, Ridge and Lasso.
            split_methods: Split names (keys or UI labels); defaults to
                spxy, ks and random.
            cv_folds: Number of cross-validation folds.
            training_csv_path: Training CSV; required.
            output_dir: Directory receiving the trained models.
            callback: Optional ``callback(step, status, message)`` notifier.
            _report_generator: Optional object exposing
                ``generate_training_summary(output_dir)``.

        Returns:
            The model directory as a string.

        Raises:
            ValueError: If ``training_csv_path`` is None.
        """
        def notify(status, msg=""):
            if callback:
                callback("步骤6", status, msg)

        print("\n" + "=" * 80)
        print("步骤6: 训练机器学习模型")
        print("=" * 80)

        # Validate before touching the filesystem or importing the backend.
        if training_csv_path is None:
            raise ValueError("必须提供 training_csv_path 参数")

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Reuse a previous run when model files are already present.
        if ModelingStep._has_saved_models(output_dir):
            print(f"检测到已存在的模型文件,直接使用: {output_dir}")
            notify("skipped", f"模型目录已设置: {output_dir}")
            return str(output_dir)

        if preprocessing_methods is None:
            preprocessing_methods = ["None", "MMS", "SS", "SNV", "MA", "SG", "MSC", "D1", "D2", "DT", "CT"]
        if model_names is None:
            model_names = ["SVR", "RF", "Ridge", "Lasso"]
        if split_methods is None:
            split_methods = ["spxy", "ks", "random"]

        # Convert Chinese/mixed labels coming from the UI into internal keys.
        preprocessing_methods = _normalize_preprocessing_methods(preprocessing_methods)
        model_names = _normalize_model_names(model_names)
        split_methods = _normalize_split_methods(split_methods)

        print(f"[参数清洗] 预处理方法: {preprocessing_methods}")
        print(f"[参数清洗] 模型名称: {model_names}")
        print(f"[参数清洗] 划分方法: {split_methods}")

        # Imported lazily so the validation and skip paths do not need the backend.
        from src.core.modeling.modeling_batch import WaterQualityModelingBatch

        modeler = WaterQualityModelingBatch(str(output_dir))
        modeler.train_models_batch(
            csv_path=training_csv_path,
            feature_start_column=feature_start_column,
            preprocessing_methods=preprocessing_methods,
            model_names=model_names,
            split_methods=split_methods,
            cv_folds=cv_folds,
        )

        print(f"模型训练完成,结果保存在: {output_dir}")

        if _report_generator is not None:
            # Reporting is best-effort: a failure must not discard the trained models.
            try:
                summary_path = _report_generator.generate_training_summary(str(output_dir))
                print(f"训练摘要报告已生成: {summary_path}")
            except Exception as e:
                print(f"生成训练摘要报告时出错: {e}")

        notify("completed", f"模型训练完成: {output_dir}")
        return str(output_dir)

    # ---- Step 6.5: train non-empirical statistical regression models ----

    @staticmethod
    def train_non_empirical_models(
        csv_path: Optional[str] = None,
        preprocessing_methods: Optional[List[str]] = None,
        algorithms: Optional[List[str]] = None,
        value_cols: Union[int, Dict[str, int]] = 0,
        spectral_start_col: int = 1,
        spectral_end_col: Optional[int] = None,
        window: int = 5,
        output_dir: Optional[str] = None,
        enabled: bool = True,
        callback: Optional[Callable] = None,
    ) -> Dict[str, str]:
        """Train one non-empirical regression model per preprocessing/algorithm pair.

        Args:
            csv_path: Training CSV; required when ``enabled`` is True.
            preprocessing_methods: Preprocessing names; defaults to ``["None"]``.
            algorithms: Algorithm names; defaults to chl_a, nh3, mno4, tn, tp, tss.
            value_cols: Measured-value column index, either one int shared by
                all algorithms or a dict keyed by algorithm name.
            spectral_start_col: Index of the first spectral column.
            spectral_end_col: Index of the last spectral column; defaults to
                the last column of ``csv_path``.
            window: Forwarded to ``run_model_correction``.
            output_dir: Result directory; defaults to
                ``<cwd>/8_Regression_Modeling``.
            enabled: When False the step is skipped and ``{}`` is returned.
            callback: Optional ``callback(step, status, message)`` notifier.

        Returns:
            Mapping ``"<preprocess>_<algorithm>"`` -> model JSON path for every
            model that exists or was trained successfully.

        Raises:
            ValueError: If ``csv_path`` is None or ``value_cols`` is neither
                an int nor a dict.
        """
        def notify(status, msg=""):
            if callback:
                callback("步骤6.5", status, msg)

        print("\n" + "=" * 80)
        print("步骤6.5: 非经验统计回归模型训练")
        print("=" * 80)

        if not enabled:
            print("已设置跳过非经验模型训练(enabled=False)。")
            notify("skipped", "跳过非经验模型训练")
            return {}

        if csv_path is None:
            raise ValueError("必须提供 csv_path 参数")

        if output_dir is not None:
            non_empirical_dir = Path(output_dir)
        else:
            non_empirical_dir = Path.cwd() / "8_Regression_Modeling"
        non_empirical_dir.mkdir(parents=True, exist_ok=True)

        if preprocessing_methods is None:
            preprocessing_methods = ["None"]
        if algorithms is None:
            algorithms = ["chl_a", "nh3", "mno4", "tn", "tp", "tss"]

        if isinstance(value_cols, int):
            value_cols_dict = {algorithm: value_cols for algorithm in algorithms}
        elif isinstance(value_cols, dict):
            value_cols_dict = value_cols
        else:
            raise ValueError("value_cols 参数必须是整数或字典")

        if spectral_end_col is None:
            # Only the header is needed to count columns; skip parsing the data rows.
            spectral_end_col = len(pd.read_csv(csv_path, nrows=0).columns) - 1

        all_model_results = {}

        for preprocess in preprocessing_methods:
            preprocess_dir = non_empirical_dir / preprocess
            preprocess_dir.mkdir(parents=True, exist_ok=True)

            processed_csv_path = _apply_preprocessing_internal(
                csv_path, preprocess, preprocess_dir, spectral_start_col
            )
            # Fall back to the raw CSV if the preprocessed file is missing.
            training_csv = processed_csv_path if Path(processed_csv_path).exists() else csv_path

            for algorithm in algorithms:
                algorithm_value_col = value_cols_dict[algorithm]
                model_key = f"{preprocess}_{algorithm}"
                print(f"\n训练 {preprocess} + {algorithm} 模型 (实测值列: {algorithm_value_col})...")

                model_outpath = str(preprocess_dir / f"{model_key}.json")

                if Path(model_outpath).exists():
                    print(f"检测到已存在的模型文件,直接使用: {model_outpath}")
                    all_model_results[model_key] = model_outpath
                    continue

                # Best-effort per model: one failure must not abort the batch.
                try:
                    from src.core.non_empirical_model_correction import run_model_correction
                    run_model_correction(
                        algorithm=algorithm,
                        csv_file=training_csv,
                        value_col=algorithm_value_col,
                        spectral_start=spectral_start_col,
                        spectral_end=spectral_end_col,
                        model_info_outpath=model_outpath,
                        window=window,
                    )
                    all_model_results[model_key] = model_outpath
                    print(f"模型训练完成: {model_outpath}")
                except Exception as e:
                    print(f"训练 {model_key} 模型时出错: {e}")
                    continue

        _generate_non_empirical_summary(all_model_results, non_empirical_dir)
        notify("completed", f"非经验模型训练完成: {non_empirical_dir}")
        return all_model_results

    # ---- Step 6.75: custom regression analysis ----

    @staticmethod
    def custom_regression(
        csv_path: Optional[str] = None,
        x_columns: Optional[Union[str, List[str]]] = None,
        y_columns: Optional[Union[str, List[str]]] = None,
        methods: Union[str, List[str]] = "all",
        output_dir: Optional[str] = None,
        enabled: bool = True,
        callback: Optional[Callable] = None,
        work_dir: Union[str, Path] = "./work_dir",
    ) -> Optional[str]:
        """Run batch single-variable regressions between chosen columns.

        Args:
            csv_path: Input CSV; required when ``enabled`` is True.
            x_columns: Independent column name(s); required.
            y_columns: Dependent column name(s); required.
            methods: Regression method selection forwarded to the analyzer.
            output_dir: Result directory, joined onto ``work_dir``; defaults
                to ``9_Custom_Regression_Modeling``.
            enabled: When False the step is skipped and None is returned.
            callback: Optional ``callback(step, status, message)`` notifier.
            work_dir: Base directory for the results.

        Returns:
            The result directory as a string, or None when skipped.

        Raises:
            ValueError: If a required argument is missing or a requested
                column is absent from the CSV.
        """
        def notify(status, msg=""):
            if callback:
                callback("步骤6.75", status, msg)

        print("\n" + "=" * 80)
        print("步骤6.75: 自定义回归分析")
        print("=" * 80)

        if not enabled:
            print("已设置跳过自定义回归分析(enabled=False)。")
            notify("skipped", "跳过自定义回归分析")
            return None

        if csv_path is None:
            raise ValueError("必须提供 csv_path 参数")
        if y_columns is None:
            raise ValueError("必须指定 y_columns")
        if x_columns is None:
            raise ValueError("必须指定 x_columns")

        # Accept a single column name as shorthand for a one-element list.
        if isinstance(x_columns, str):
            x_columns = [x_columns]
        if isinstance(y_columns, str):
            y_columns = [y_columns]

        df = pd.read_csv(csv_path)
        missing_x = [col for col in x_columns if col not in df.columns]
        missing_y = [col for col in y_columns if col not in df.columns]
        if missing_x:
            raise ValueError(f"自变量列不存在: {missing_x}")
        if missing_y:
            raise ValueError(f"因变量列不存在: {missing_y}")

        sub_dir = "9_Custom_Regression_Modeling" if output_dir is None else output_dir
        custom_regression_dir = Path(work_dir) / sub_dir
        custom_regression_dir.mkdir(parents=True, exist_ok=True)

        from src.core.modeling.regression import SingleVariableRegressionAnalysis

        analyzer = SingleVariableRegressionAnalysis()
        analyzer.batch_single_variable_regression(
            data=df,
            x_columns=x_columns,
            y_columns=y_columns,
            methods=methods,
            output_dir=str(custom_regression_dir),
        )

        notify("completed", f"自定义回归结果已保存到目录: {custom_regression_dir}")
        return str(custom_regression_dir)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 内部辅助函数(供 ModelingStep 内部使用)
|
||||
# ============================================================
|
||||
|
||||
def _apply_preprocessing_internal(
    csv_path: str,
    preprocess_method: str,
    output_dir: Path,
    spectral_start_col: int = 4,
) -> str:
    """Apply one spectral preprocessing method to a CSV and cache the result.

    Args:
        csv_path: Input CSV path.
        preprocess_method: Canonical code (e.g. ``"SNV"``) or a localized
            label such as ``"标准正态变换 (SNV)"``.
        output_dir: Directory receiving ``preprocessed_<CODE>.csv``.
        spectral_start_col: Index of the first spectral column; columns before
            it are copied through untouched.

    Returns:
        ``csv_path`` itself when the method resolves to ``"None"``, otherwise
        the path of the preprocessed CSV (reused if it already exists).
    """
    preprocess_method = _canonical_preprocess_name(preprocess_method)

    if preprocess_method == "None":
        return csv_path

    output_path = str(output_dir / f"preprocessed_{preprocess_method}.csv")

    if Path(output_path).exists():
        print(f"检测到已存在的预处理文件,直接使用: {output_path}")
        return output_path

    df = pd.read_csv(csv_path)
    non_spectral_cols = df.iloc[:, :spectral_start_col]
    spectral_data = df.iloc[:, spectral_start_col:]

    from src.preprocessing.spectral_Preprocessing import Preprocessing

    save_path = None
    if preprocess_method == "SS":
        # The fitted scaler is stored beside the supervised models, two levels
        # above the per-method output directory.
        models_dir = output_dir.parent.parent / "7_Supervised_Model_Training"
        models_dir.mkdir(parents=True, exist_ok=True)
        save_path = str(models_dir / "scaler_params.pkl")
        print(f"SS预处理: scaler模型将保存到 {save_path}")

    processed_spectral = Preprocessing(preprocess_method, spectral_data, save_path=save_path)

    if not isinstance(processed_spectral, pd.DataFrame):
        # Array-like results get the original column labels and row index back.
        processed_spectral = pd.DataFrame(
            processed_spectral, columns=spectral_data.columns, index=spectral_data.index
        )
    processed_df = pd.concat([non_spectral_cols, processed_spectral], axis=1)

    processed_df.to_csv(output_path, index=False)
    print(f"预处理完成: {output_path}")
    return output_path


def _canonical_preprocess_name(preprocess_method):
    """Resolve a raw or localized preprocessing label to its canonical code.

    Resolution order: exact code (case-insensitive), then a trailing
    ``(CODE)`` tag as used by the UI labels, then keyword matching with the
    most specific keyword first. Unrecognized input is returned unchanged.
    """
    import re

    codes = {
        "none": "None", "mms": "MMS", "ss": "SS", "snv": "SNV", "ma": "MA",
        "sg": "SG", "msc": "MSC", "d1": "D1", "d2": "D2", "dt": "DT", "ct": "CT",
    }
    label = str(preprocess_method).strip()
    lowered = label.lower()
    if lowered in codes:
        return codes[lowered]

    # UI labels look like "标准正态变换 (SNV)"; the tag is unambiguous.
    tagged = re.search(r"[((]\s*([A-Za-z0-9]+)\s*[))]\s*$", label)
    if tagged and tagged.group(1).lower() in codes:
        return codes[tagged.group(1).lower()]

    # "标准正态" must be tested before "标准", otherwise SNV resolves to SS.
    keyword_rules = (
        (("无", "跳过"), "None"),
        (("minmax", "最大最小"), "MMS"),
        (("标准正态",), "SNV"),
        (("标准",), "SS"),
        (("移动",), "MA"),
        (("savitzky", "平滑"), "SG"),
        (("多元散射",), "MSC"),
        (("去趋势",), "DT"),
        (("中心化",), "CT"),
    )
    for keywords, code in keyword_rules:
        if any(keyword in lowered for keyword in keywords):
            return code
    return preprocess_method
|
||||
|
||||
|
||||
def _generate_non_empirical_summary(model_results: Dict[str, str], output_dir: Path) -> str:
    """Write a CSV summarizing the trained non-empirical models.

    Each key of ``model_results`` is split on ``"_"``: the first token is the
    preprocessing method and the remainder is the algorithm name. The model
    JSON is read for its type, coefficients, accuracy list and sample list.
    Entries that cannot be read are reported and skipped.

    Returns:
        Path of ``non_empirical_models_summary.csv`` inside ``output_dir``,
        or an empty string when no entry could be summarized.
    """
    summary_path = str(output_dir / "non_empirical_models_summary.csv")
    rows = []

    for model_key, model_path in model_results.items():
        try:
            tokens = model_key.split("_")
            prep_name = tokens[0]
            algo_name = tokens[1] if len(tokens) <= 2 else "_".join(tokens[1:])

            with open(model_path, "r", encoding="utf-8") as handle:
                info = json.load(handle)

            accuracies = info.get("accuracy", [])
            coefficients = info.get("model_info", [])
            row = {
                "Preprocessing Method": prep_name,
                "Algorithm Name": algo_name,
                "Model Type": info.get("model_type", ""),
                "Coefficient Count": len(coefficients),
                "Average Accuracy(%)": np.mean(accuracies) if accuracies else 0,
                "Min Accuracy(%)": np.min(accuracies) if accuracies else 0,
                "Max Accuracy(%)": np.max(accuracies) if accuracies else 0,
                "Sample Count": len(info.get("long", [])),
                "Model File": model_path,
            }
            # Only the first five coefficients get their own column.
            row.update(
                {f"系数_{position}": value for position, value in enumerate(coefficients[:5], start=1)}
            )
            rows.append(row)
        except Exception as e:
            print(f"读取模型文件 {model_path} 时出错: {e}")
            continue

    if not rows:
        print("警告: 没有有效的模型数据可汇总")
        return ""

    pd.DataFrame(rows).to_csv(summary_path, index=False, encoding="utf-8-sig")
    print(f"汇总文件已生成: {summary_path}")
    return summary_path
|
||||
350
src/core/steps/prediction_step.py
Normal file
350
src/core/steps/prediction_step.py
Normal file
@ -0,0 +1,350 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
预测步骤
|
||||
|
||||
包含 step7_generate_sampling_points, step8_predict_water_quality,
|
||||
step8_5_predict_with_non_empirical_models, step8_75_predict_with_custom_regression
|
||||
"""
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Union, Callable, Dict
|
||||
|
||||
|
||||
class PredictionStep:
    """Prediction stage of the pipeline.

    Groups four static entry points: sampling-point generation (step 7),
    machine-learning prediction (step 8), non-empirical model prediction
    (step 8.5) and custom-regression prediction (step 8.75).
    """

    # ---- Step 7: generate sampling points and extract spectra ----

    @staticmethod
    def generate_sampling_points(
        deglint_img_path: Optional[str] = None,
        interval: int = 50,
        sample_radius: int = 5,
        chunk_size: int = 1000,
        water_mask_path: Optional[str] = None,
        glint_mask_path: Optional[str] = None,
        output_dir: Union[str, Path] = "./10_sampling",
        callback: Optional[Callable] = None,
    ) -> str:
        """Sample spectra from the deglinted image and save them as a CSV.

        When the given image path is missing or ends in ``.dat``, a real
        ``.bsq`` product is searched for in the ``3_deglint`` directory. If
        none is found, the ``.bsq`` sibling of the input is used, unless only
        the input itself exists, in which case the input is kept.

        Args:
            deglint_img_path: Deglinted image path; required.
            interval: Forwarded to the sampler.
            sample_radius: Forwarded to the sampler.
            chunk_size: Forwarded to the sampler.
            water_mask_path: Water mask path forwarded to the sampler.
            glint_mask_path: Optional glint mask path; None disables glint
                exclusion.
            output_dir: Directory receiving ``sampling_spectra.csv``.
            callback: Optional ``callback(step, status, message)`` notifier.

        Returns:
            Path of ``sampling_spectra.csv`` (reused if it already exists).

        Raises:
            ValueError: If ``deglint_img_path`` is None.
        """
        def notify(status, msg=""):
            if callback:
                callback("步骤7", status, msg)

        print("\n" + "=" * 80)
        print("步骤7: 生成预测采样点并提取光谱")
        print("=" * 80)

        # Validate before creating any directory.
        if deglint_img_path is None:
            raise ValueError("必须提供 deglint_img_path 参数")

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = str(output_dir / "sampling_spectra.csv")

        original_path = Path(deglint_img_path)
        final_deglint_path = original_path

        if not original_path.exists() or original_path.suffix.lower() == '.dat':
            print(f"🔍 智能探测:输入去耀斑路径不存在或为 .dat 占位符 ({final_deglint_path}),正在向上搜索真实产物...")

            # Locate the expected 3_deglint directory next to the output directory.
            search_dir = original_path.parent
            if search_dir.name != '3_deglint' and output_dir.parent.exists():
                search_dir = output_dir.parent / "3_deglint"

            # Sorted so the chosen file does not depend on directory listing order.
            existing_bsqs = sorted(search_dir.glob("*.bsq")) if search_dir.exists() else []
            sibling_bsq = original_path.with_suffix('.bsq')
            if existing_bsqs:
                final_deglint_path = existing_bsqs[0]
                print(f"💡 智能拦截成功:自动寻回底层真实去耀斑影像: {final_deglint_path}")
            elif sibling_bsq.exists() or not original_path.exists():
                final_deglint_path = sibling_bsq
            # Otherwise keep the existing input rather than point the sampler
            # at a .bsq path that does not exist.

        deglint_img_str = str(final_deglint_path)

        if Path(output_path).exists():
            print(f"检测到已存在的采样点光谱数据文件,直接使用: {output_path}")
            notify("skipped", f"采样点光谱数据已设置: {output_path}")
            return output_path

        glint_mask_to_use = glint_mask_path
        if glint_mask_to_use is None:
            print("未检测到耀斑掩膜,将在采样点生成时不做耀斑区域剔除。")

        # Imported lazily so the validation and skip paths do not need the backend.
        from src.utils.sampling import get_spectral_sampling_points_chunked

        get_spectral_sampling_points_chunked(
            deglint_img_str, water_mask_path, glint_mask_to_use,
            output_path, interval, sample_radius, chunk_size
        )

        notify("completed", f"采样点光谱数据已保存: {output_path}")
        return output_path

    # ---- Step 8: predict water-quality parameters with ML models ----

    @staticmethod
    def predict_water_quality(
        sampling_csv_path: str,
        models_dir: Optional[str] = None,
        metric: str = "test_r2",
        prediction_column: str = "prediction",
        output_dir: Union[str, Path] = "./11_12_13_predictions/Machine_Learning_Prediction",
        callback: Optional[Callable] = None,
        _report_generator=None,
    ) -> Dict[str, str]:
        """Apply the trained machine-learning models to the sampled spectra.

        Existing ``*.csv`` files in ``output_dir`` are treated as earlier
        predictions; if every target folder under ``models_dir`` already has
        one, inference is skipped.

        Args:
            sampling_csv_path: Sampled-spectra CSV.
            models_dir: Root directory of the trained models; required.
            metric: Model-selection metric forwarded to the inferencer.
            prediction_column: Output column name forwarded to the inferencer.
            output_dir: Directory receiving the prediction CSVs.
            callback: Optional ``callback(step, status, message)`` notifier.
            _report_generator: Optional object exposing
                ``generate_prediction_report(prediction_files)``.

        Returns:
            Mapping target name -> prediction CSV path.

        Raises:
            ValueError: If ``models_dir`` is None.
        """
        def notify(status, msg=""):
            if callback:
                callback("步骤8", status, msg)

        print("\n" + "=" * 80)
        print("步骤8: 预测水质参数")
        print("=" * 80)

        if models_dir is None:
            raise ValueError("必须提供 models_dir 参数")

        ml_prediction_dir = Path(output_dir)
        ml_prediction_dir.mkdir(parents=True, exist_ok=True)

        # Index earlier outputs by target name (file stem without the marker).
        prediction_files = {}
        for csv_file in ml_prediction_dir.glob("*.csv"):
            target_name = csv_file.stem
            for marker in ("_prediction", "_pred"):
                if marker in target_name:
                    target_name = target_name.replace(marker, "")
                    break
            prediction_files[target_name] = str(csv_file)

        # Skip inference only when every model target already has a result.
        models_path_obj = Path(models_dir)
        if prediction_files and models_path_obj.exists():
            target_folders = [d.name for d in models_path_obj.iterdir() if d.is_dir()]
            missing_targets = [t for t in target_folders if t not in prediction_files]
            if not missing_targets:
                print(f"检测到已存在的预测结果文件,直接使用: {ml_prediction_dir}")
                notify("skipped", f"预测结果已设置: {ml_prediction_dir}")
                return prediction_files
            print(f"检测到部分预测结果文件,缺少: {missing_targets},将继续生成...")

        # Imported lazily so the validation and skip paths do not need the backend.
        from src.core.prediction.inference_batch import WaterQualityInference

        inferencer = WaterQualityInference(models_dir)
        all_results = inferencer.batch_inference_multi_models(
            models_root_dir=models_dir,
            sampling_csv_path=sampling_csv_path,
            output_dir=str(ml_prediction_dir),
            metric=metric,
            prediction_column=prediction_column,
            output_format="csv",
        )

        for target_name, result in all_results.items():
            if result.get("status") == "success":
                prediction_files[target_name] = result["output_file"]

        print(f"预测完成,结果保存在: {ml_prediction_dir}")

        if _report_generator is not None:
            # Reporting is best-effort: a failure must not discard the predictions.
            try:
                report_path = _report_generator.generate_prediction_report(prediction_files)
                print(f"预测结果报告已生成: {report_path}")
            except Exception as e:
                print(f"生成预测结果报告时出错: {e}")

        notify("completed", f"预测完成: {ml_prediction_dir}")
        return prediction_files

    # ---- Step 8.5: predict with non-empirical models ----

    @staticmethod
    def predict_with_non_empirical_models(
        sampling_csv_path: str,
        non_empirical_models_dir: Optional[str] = None,
        output_dir: Optional[str] = None,
        metric: str = "Average Accuracy(%)",
        prediction_column: str = "prediction",
        enabled: bool = True,
        callback: Optional[Callable] = None,
        work_dir: Union[str, Path] = "./work_dir",
    ) -> Dict[str, str]:
        """Predict with the best non-empirical model of each algorithm.

        The best model per algorithm is the summary row with the largest
        ``metric``; when the summary has no such column the first row is used.

        Args:
            sampling_csv_path: Sampled-spectra CSV.
            non_empirical_models_dir: Directory holding
                ``non_empirical_models_summary.csv``; defaults to
                ``<work_dir>/8_Regression_Modeling``.
            output_dir: Result directory; defaults to
                ``<work_dir>/11_12_13_predictions/Non_Empirical_Prediction``.
            metric: Summary column used to rank models.
            prediction_column: Suffix used in the output file names.
            enabled: When False the step is skipped and ``{}`` is returned.
            callback: Optional ``callback(step, status, message)`` notifier.
            work_dir: Base directory for the defaults above.

        Returns:
            Mapping algorithm name -> prediction CSV path.

        Raises:
            ValueError: If no models directory is available or the summary
                file is missing.
        """
        def notify(status, msg=""):
            if callback:
                callback("步骤8.5", status, msg)

        print("\n" + "=" * 80)
        print("步骤8.5: 使用非经验模型进行参数预测")
        print("=" * 80)

        if not enabled:
            print("已设置跳过非经验模型预测(enabled=False)。")
            notify("skipped", "跳过非经验模型预测")
            return {}

        if non_empirical_models_dir is not None:
            final_models_dir = non_empirical_models_dir
        else:
            default_models_dir = Path(work_dir) / "8_Regression_Modeling"
            if not default_models_dir.exists():
                raise ValueError("请先执行步骤6.5: 非经验模型训练,或提供 non_empirical_models_dir 参数")
            final_models_dir = str(default_models_dir)

        if output_dir is not None:
            non_empirical_prediction_dir = Path(output_dir)
        else:
            non_empirical_prediction_dir = Path(work_dir) / "11_12_13_predictions" / "Non_Empirical_Prediction"
        non_empirical_prediction_dir.mkdir(parents=True, exist_ok=True)

        prediction_files = {}
        summary_path = Path(final_models_dir) / "non_empirical_models_summary.csv"
        if not summary_path.exists():
            raise ValueError(f"未找到非经验模型汇总文件: {summary_path}")

        import pandas as pd
        df_summary = pd.read_csv(summary_path)
        has_metric = metric in df_summary.columns

        best_models = {}
        for algorithm in df_summary["Algorithm Name"].unique():
            algorithm_df = df_summary[df_summary["Algorithm Name"] == algorithm]
            if has_metric:
                best_model_row = algorithm_df.nlargest(1, metric)
            else:
                best_model_row = algorithm_df.iloc[[0]]

            best_preprocess = best_model_row["Preprocessing Method"].values[0]
            best_accuracy = best_model_row[metric].values[0] if has_metric else "N/A"

            best_models[algorithm] = {
                "model_path": best_model_row["Model File"].values[0],
                "preprocess_method": best_preprocess,
                "accuracy": best_accuracy,
            }
            print(f"算法 {algorithm}: 选择 {best_preprocess} (准确率: {best_accuracy})")

        # Parsed only to fail early on an unreadable sampling file.
        pd.read_csv(sampling_csv_path)

        for algorithm, model_info in best_models.items():
            print(f"\n使用 {algorithm} 算法进行预测...")
            output_path = str(non_empirical_prediction_dir / f"non_empirical_{algorithm}_{prediction_column}.csv")

            if Path(output_path).exists():
                print(f"检测到已存在的预测结果文件,直接使用: {output_path}")
                prediction_files[algorithm] = output_path
                continue

            # Best-effort per algorithm: one failure must not abort the batch.
            try:
                from src.core.non_empirical_retrieval import non_empirical_retrieval
                non_empirical_retrieval(
                    algorithm=algorithm,
                    model_info_path=model_info["model_path"],
                    coor_spectral_path=sampling_csv_path,
                    output_path=output_path,
                    wave_radius=5,
                )
                prediction_files[algorithm] = output_path
                print(f"预测完成: {output_path}")
            except Exception as e:
                print(f"使用 {algorithm} 算法预测时出错: {e}")
                continue

        notify("completed", f"非经验模型预测完成: {non_empirical_prediction_dir}")
        return prediction_files

    # ---- Step 8.75: predict with custom regression models ----

    @staticmethod
    def predict_with_custom_regression(
        sampling_csv_path: str,
        custom_regression_dir: Optional[str] = None,
        formula_csv_path: Optional[str] = None,
        coordinate_columns: Optional[List[str]] = None,
        output_dir: Optional[str] = None,
        filename_prefix: str = "custom_regression_prediction",
        enabled: bool = True,
        callback: Optional[Callable] = None,
        work_dir: Union[str, Path] = "./work_dir",
    ) -> Dict[str, str]:
        """Run batch prediction with the custom regression models.

        Args:
            sampling_csv_path: Sampled-spectra CSV; must exist.
            custom_regression_dir: Directory with the regression results;
                defaults to ``<work_dir>/9_Custom_Regression_Modeling``.
            formula_csv_path: Optional formula CSV forwarded to the predictor.
            coordinate_columns: Optional coordinate column names forwarded to
                the predictor.
            output_dir: Result directory; defaults to
                ``<work_dir>/11_12_13_predictions/Custom_Regression_Prediction``.
            filename_prefix: Output file-name prefix forwarded to the predictor.
            enabled: When False the step is skipped and ``{}`` is returned.
            callback: Optional ``callback(step, status, message)`` notifier.
            work_dir: Base directory for the defaults above.

        Returns:
            Mapping parameter name -> prediction file path, as returned by
            the predictor.

        Raises:
            FileNotFoundError: If ``sampling_csv_path`` does not exist.
            ValueError: If the default regression directory does not exist.
        """
        def notify(status, msg=""):
            if callback:
                callback("步骤8.75", status, msg)

        print("\n" + "=" * 80)
        print("步骤8.75: 使用自定义回归模型进行参数预测")
        print("=" * 80)

        if not enabled:
            print("已设置跳过自定义回归模型预测(enabled=False)。")
            notify("skipped", "跳过自定义回归预测")
            return {}

        if not Path(sampling_csv_path).exists():
            raise FileNotFoundError(f"采样点CSV文件不存在: {sampling_csv_path}")

        if custom_regression_dir is not None:
            final_regression_dir = custom_regression_dir
        else:
            final_regression_dir = str(Path(work_dir) / "9_Custom_Regression_Modeling")
            if not Path(final_regression_dir).exists():
                raise ValueError(
                    "请先执行步骤6.75: 自定义回归分析,或提供 custom_regression_dir 参数"
                )

        if output_dir is None:
            prediction_dir = Path(work_dir) / "11_12_13_predictions" / "Custom_Regression_Prediction"
        else:
            prediction_dir = Path(output_dir)
        # Created in both cases so a caller-supplied directory exists as well.
        prediction_dir.mkdir(parents=True, exist_ok=True)
        prediction_output_dir = output_dir if output_dir is not None else str(prediction_dir)

        from src.core.prediction.custom_regression_prediction import CustomRegressionPredictor

        predictor = CustomRegressionPredictor(
            regression_csv_dir=final_regression_dir,
            formula_csv_path=formula_csv_path,
        )

        print("开始使用自定义回归模块进行批量预测...")
        print(f"  采样点数据: {sampling_csv_path}")
        print(f"  回归模型目录: {final_regression_dir}")
        print(f"  输出目录: {prediction_output_dir}")

        # NOTE(review): prediction_output_dir is only printed, never passed to
        # the predictor; confirm where run_batch_prediction writes its files.
        saved_files = predictor.run_batch_prediction(
            input_csv_path=sampling_csv_path,
            coordinate_columns=coordinate_columns,
            filename_prefix=filename_prefix,
        )

        print(f"自定义回归预测完成,生成 {len(saved_files)} 个预测文件:")
        for param_name, filepath in saved_files.items():
            print(f"  {param_name}: {filepath}")

        notify("completed", f"自定义回归预测完成: {len(saved_files)} 个文件")
        return saved_files
|
||||
148
src/core/steps/water_mask_step.py
Normal file
148
src/core/steps/water_mask_step.py
Normal file
@ -0,0 +1,148 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
步骤1: 水域掩膜生成
|
||||
|
||||
支持三种方式:
|
||||
1. 基于 shp 文件栅格化
|
||||
2. 使用现有栅格格式掩膜文件 (.dat/.tif)
|
||||
3. 基于 NDWI 从影像自动生成水体掩膜
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Callable, Union
|
||||
import numpy as np
|
||||
|
||||
|
||||
class WaterMaskStep:
    """Step 1 of the pipeline: generate or register the water-body mask."""

    @staticmethod
    def run(
        mask_path: Optional[str] = None,
        img_path: Optional[str] = None,
        ndwi_threshold: float = 0.4,
        use_ndwi: bool = False,
        generate_png: bool = True,
        output_path: Optional[str] = None,
        water_mask_dir: Union[str, Path] = "./1_water_mask",
        callback: Optional[Callable] = None,
    ) -> str:
        """
        Produce the water mask by NDWI, by rasterizing a shapefile, or by
        accepting an existing raster mask as-is.

        Args:
            mask_path: Mask file; a ``.shp`` is rasterized (needs img_path),
                any other extension is used directly.
            img_path: Input image; required for NDWI and for ``.shp`` masks.
            ndwi_threshold: NDWI threshold, used when use_ndwi is True.
            use_ndwi: Derive the mask from the image with NDWI.
            generate_png: Also write PNG previews/overlays (default True).
            output_path: Optional explicit path for the generated mask.
            water_mask_dir: Working directory for masks and previews.
            callback: Notifier called as callback(step, status, message).

        Returns:
            Path of the mask: the generated ``.dat`` file, or ``mask_path``
            itself when an existing raster mask is used.

        Raises:
            ValueError: On a missing required argument or a missing file.
        """
        from src.utils.extract_water_area import rasterize_shp, ndwi
        from src.core.utils.preview_generator import (
            generate_image_preview,
            generate_water_mask_overlay,
        )

        water_mask_dir = Path(water_mask_dir)
        water_mask_dir.mkdir(parents=True, exist_ok=True)
        overlay_png = water_mask_dir / "water_mask_overlay.png"

        def notify(status, msg=""):
            if callback:
                callback("步骤1", status, msg)

        rule = "=" * 80
        print("\n" + rule)
        print("步骤1: 生成或设置水域mask")
        print(rule)

        started_at = time.time()

        # Image preview
        if generate_png and img_path is not None and Path(img_path).exists():
            generate_image_preview(
                img_path=img_path,
                output_path=str(water_mask_dir / "hsi_preview.png"),
                title="影像预览: RGB波段(基于波长)"
            )

        # ---- NDWI ----
        if use_ndwi:
            if img_path is None:
                raise ValueError("当 use_ndwi=True 时,必须提供 img_path 参数")
            if not Path(img_path).exists():
                raise ValueError(f"影像文件不存在: {img_path}")

            print(f"使用NDWI方法从影像生成水体掩膜,阈值={ndwi_threshold}...")

            ndwi_mask = output_path or str(water_mask_dir / "water_mask_from_ndwi.dat")
            os.makedirs(Path(ndwi_mask).parent, exist_ok=True)

            if Path(ndwi_mask).exists():
                print(f"检测到已存在的NDWI掩膜文件,直接使用: {ndwi_mask}")
                notify("skipped", f"水域掩膜已设置: {ndwi_mask}")
                return ndwi_mask

            ndwi(img_path, ndwi_threshold, ndwi_mask)

            if generate_png:
                generate_water_mask_overlay(
                    img_path=img_path, mask_path=ndwi_mask, output_path=str(overlay_png)
                )

            notify("completed", f"NDWI水体掩膜已生成: {ndwi_mask}")
            return ndwi_mask

        # ---- mask_path is mandatory from here on ----
        if mask_path is None:
            raise ValueError("必须提供 mask_path 参数或设置 use_ndwi=True")
        if not Path(mask_path).exists():
            raise ValueError(f"文件不存在: {mask_path}")

        # ---- Raster mask: use it directly ----
        if Path(mask_path).suffix.lower() != ".shp":
            print(f"检测到栅格格式的水体掩膜,直接使用: {mask_path}")
            if generate_png and img_path is not None and Path(img_path).exists():
                generate_water_mask_overlay(img_path, mask_path, str(overlay_png))

            notify("completed", f"水域掩膜已设置: {mask_path}")
            return mask_path

        # ---- Shapefile: rasterize ----
        if img_path is None:
            raise ValueError("当 mask_path 为 shp 格式时,必须提供 img_path 参数")

        print("检测到shp格式的水体掩膜,正在转换为dat格式...")

        shp_mask = output_path or str(water_mask_dir / "water_mask_from_shp.dat")
        os.makedirs(Path(shp_mask).parent, exist_ok=True)

        if Path(shp_mask).exists():
            print(f"检测到已存在的栅格化掩膜文件,直接使用: {shp_mask}")
            notify("skipped", f"水域掩膜已设置: {shp_mask}")
            # Regenerate the overlay only if it is missing.
            if generate_png and not overlay_png.exists():
                generate_water_mask_overlay(img_path, shp_mask, str(overlay_png))
            return shp_mask

        # Absolute path with forward slashes is what gets handed to the rasterizer.
        safe_mask_path = os.path.abspath(mask_path).replace("\\", "/")
        rasterize_shp(safe_mask_path, shp_mask, img_path)

        if generate_png:
            generate_water_mask_overlay(img_path, shp_mask, str(overlay_png))

        notify("completed", f"dat格式水域掩膜已生成: {shp_mask}")
        return shp_mask
|
||||
File diff suppressed because it is too large
Load Diff
@ -209,7 +209,7 @@ class Step3Panel(QWidget):
|
||||
"输出影像:",
|
||||
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
|
||||
)
|
||||
self.output_file.line_edit.setPlaceholderText("deglint_image.dat")
|
||||
self.output_file.line_edit.setPlaceholderText("deglint_image.bsq")
|
||||
layout.addWidget(self.output_file)
|
||||
|
||||
# 启用步骤
|
||||
@ -301,7 +301,7 @@ class Step3Panel(QWidget):
|
||||
if self.work_dir:
|
||||
output_dir = os.path.join(self.work_dir, "3_deglint")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
default_output_path = os.path.join(output_dir, "deglint_image.dat").replace('\\', '/')
|
||||
default_output_path = os.path.join(output_dir, "deglint_image.bsq").replace('\\', '/')
|
||||
self.output_file.set_path(default_output_path)
|
||||
else:
|
||||
self.output_file.set_path("")
|
||||
|
||||
@ -17,6 +17,57 @@ from src.gui.components.custom_widgets import FileSelectWidget
|
||||
from src.gui.styles import ModernStylesheet
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 中文映射表(内部键名 -> 显示文本)
|
||||
# ============================================================
|
||||
|
||||
# Preprocessing methods: internal key -> checkbox display text.
# Lookups fall back to the raw key when an entry is missing
# (callers use ``PREPROC_CHINESE.get(method, method)``).
PREPROC_CHINESE = {
    "None": "无 (None)",
    "MMS": "最小-最大归一化 (MMS)",
    "SS": "标度化 (SS)",
    "SNV": "标准正态变换 (SNV)",
    "MA": "移动平均 (MA)",
    "SG": "Savitzky-Golay (SG)",
    "MSC": "多元散射校正 (MSC)",
    "D1": "一阶导数 (D1)",
    "D2": "二阶导数 (D2)",
    "DT": "去趋势 (DT)",
    "CT": "中心化 (CT)",
}
|
||||
|
||||
# Model types: internal key -> checkbox display text.
# Lookups fall back to the raw key when an entry is missing
# (callers use ``MODEL_CHINESE.get(model, model)``).
MODEL_CHINESE = {
    # Linear models
    "LinearRegression": "多元线性回归 (MLR)",
    "Ridge": "岭回归 (Ridge)",
    "Lasso": "套索回归 (Lasso)",
    "ElasticNet": "弹性网络 (ElasticNet)",
    "PLS": "偏最小二乘 (PLSR)",
    # Tree models
    "DecisionTree": "决策树 (CART)",
    "RF": "随机森林 (RF)",
    "ExtraTrees": "极端随机树 (ET)",
    "XGBoost": "极值梯度提升 (XGBoost)",
    "LightGBM": "轻量梯度提升 (LightGBM)",
    "CatBoost": "类别梯度提升 (CatBoost)",
    # Ensemble learning
    "GradientBoosting": "梯度提升树 (GBDT)",
    "AdaBoost": "自适应提升 (AdaBoost)",
    # Other models
    "SVR": "支持向量回归 (SVR)",
    "KNN": "K近邻回归 (KNN)",
    "MLP": "多层感知机 (BP神经网络)",
}
|
||||
|
||||
# Data split methods: internal key -> checkbox display text.
# Lookups fall back to the raw key when an entry is missing
# (callers use ``SPLIT_CHINESE.get(method, method)``).
SPLIT_CHINESE = {
    "spxy": "SPXY 算法 (考量X-Y空间)",
    "ks": "KS 算法 (考量X空间)",
    "random": "随机划分 (Random)",
}
|
||||
|
||||
|
||||
class Step6Panel(QWidget):
|
||||
"""步骤6:机器学习建模"""
|
||||
def __init__(self, parent=None):
|
||||
@ -54,7 +105,7 @@ class Step6Panel(QWidget):
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
self.enable_checkbox.setChecked(True)
|
||||
self.enable_checkbox.setChecked(False)
|
||||
layout.addWidget(self.enable_checkbox)
|
||||
|
||||
# 独立运行按钮
|
||||
@ -95,8 +146,8 @@ class Step6Panel(QWidget):
|
||||
preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT']
|
||||
|
||||
for i, method in enumerate(preproc_methods):
|
||||
checkbox = QCheckBox(method)
|
||||
checkbox.setChecked(True)
|
||||
checkbox = QCheckBox(PREPROC_CHINESE.get(method, method))
|
||||
checkbox.setChecked(False)
|
||||
self.preproc_checkboxes[method] = checkbox
|
||||
preproc_grid.addWidget(checkbox, i // 4, i % 4)
|
||||
|
||||
@ -122,10 +173,10 @@ class Step6Panel(QWidget):
|
||||
self.model_checkboxes = {}
|
||||
|
||||
model_groups = [
|
||||
("线性模型", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']),
|
||||
("树模型", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']),
|
||||
("集成学习", ['GradientBoosting', 'AdaBoost']),
|
||||
("其他模型", ['SVR', 'KNN', 'MLP'])
|
||||
("【线性模型】", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']),
|
||||
("【树模型】", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']),
|
||||
("【集成学习】", ['GradientBoosting', 'AdaBoost']),
|
||||
("【其他模型】", ['SVR', 'KNN', 'MLP'])
|
||||
]
|
||||
|
||||
row = 0
|
||||
@ -140,8 +191,8 @@ class Step6Panel(QWidget):
|
||||
row += 1
|
||||
|
||||
for i, model in enumerate(models):
|
||||
checkbox = QCheckBox(model)
|
||||
checkbox.setChecked(model in ['SVR', 'RF', 'Ridge', 'Lasso'])
|
||||
checkbox = QCheckBox(MODEL_CHINESE.get(model, model))
|
||||
checkbox.setChecked(False)
|
||||
self.model_checkboxes[model] = checkbox
|
||||
model_grid.addWidget(checkbox, row, i % 4)
|
||||
if (i + 1) % 4 == 0:
|
||||
@ -172,8 +223,8 @@ class Step6Panel(QWidget):
|
||||
split_methods = ['spxy', 'ks', 'random']
|
||||
|
||||
for i, method in enumerate(split_methods):
|
||||
checkbox = QCheckBox(method)
|
||||
checkbox.setChecked(True)
|
||||
checkbox = QCheckBox(SPLIT_CHINESE.get(method, method))
|
||||
checkbox.setChecked(False)
|
||||
self.split_checkboxes[method] = checkbox
|
||||
split_grid.addWidget(checkbox, 0, i)
|
||||
|
||||
|
||||
@ -109,7 +109,7 @@ class Step9Panel(QWidget):
|
||||
mode_row.addStretch()
|
||||
layout.addLayout(mode_row)
|
||||
|
||||
# ---------- RadioButton 美化样式(选中状态更醒目) ----------
|
||||
# ---------- RadioButton 美化样式(选中状态为方形实心块,贴合主界面风格) ----------
|
||||
radio_style = """
|
||||
QRadioButton {
|
||||
font-size: 14px;
|
||||
@ -117,21 +117,16 @@ class Step9Panel(QWidget):
|
||||
color: #333333;
|
||||
}
|
||||
QRadioButton::indicator {
|
||||
width: 18px;
|
||||
height: 18px;
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
border: 2px solid #999999;
|
||||
border-radius: 9px;
|
||||
border-radius: 3px;
|
||||
background-color: white;
|
||||
}
|
||||
QRadioButton::indicator:checked {
|
||||
border: 2px solid #0078d4;
|
||||
background-color: qradialgradient(
|
||||
cx:0.5, cy:0.5, radius:0.5,
|
||||
fx:0.5, fy:0.5,
|
||||
stop:0 #0078d4,
|
||||
stop:0.6 white,
|
||||
stop:1.0 white
|
||||
);
|
||||
background-color: #0078d4;
|
||||
image: none;
|
||||
}
|
||||
QRadioButton::indicator:hover {
|
||||
border: 2px solid #005a9e;
|
||||
@ -353,7 +348,7 @@ class Step9Panel(QWidget):
|
||||
if not main_window:
|
||||
return
|
||||
|
||||
# 1. 尝试从 Step8 界面读取机器学习预测输出目录(优先)
|
||||
# 1. 尝试从 Step8 界面读取机器学习预测输出目录(最优先)
|
||||
pred_dir = None
|
||||
if hasattr(main_window, 'step8_panel'):
|
||||
step8_widget = getattr(main_window.step8_panel, 'output_file', None)
|
||||
@ -367,7 +362,10 @@ class Step9Panel(QWidget):
|
||||
# 若为相对路径,使用 work_dir 合成为绝对路径
|
||||
if not os.path.isabs(step8_output):
|
||||
step8_output = os.path.join(self.work_dir or '', step8_output).replace('\\', '/')
|
||||
pred_dir = str(Path(step8_output).parent)
|
||||
# 提取父目录后追加 Machine_Learning_Prediction(最底层真实子目录)
|
||||
base_pred_dir = str(Path(step8_output).parent)
|
||||
ml_pred_dir = Path(base_pred_dir) / "Machine_Learning_Prediction"
|
||||
pred_dir = str(ml_pred_dir) if ml_pred_dir.exists() else base_pred_dir
|
||||
|
||||
# 2. 备选:从 Step8.5 界面读取非经验预测输出目录
|
||||
if not pred_dir and hasattr(main_window, 'step8_5_panel'):
|
||||
@ -411,6 +409,14 @@ class Step9Panel(QWidget):
|
||||
existing_out = self.output_dir.get_path()
|
||||
if not existing_out or not existing_out.strip():
|
||||
self.output_dir.set_path(output_dir)
|
||||
|
||||
# 5. 自动继承步骤1的水域掩膜作为边界文件
|
||||
if self.work_dir:
|
||||
default_mask = Path(self.work_dir) / "1_water_mask" / "water_mask_from_shp.dat"
|
||||
if default_mask.exists():
|
||||
existing_boundary = (self.boundary_file.get_path() or "").strip()
|
||||
if not existing_boundary:
|
||||
self.boundary_file.set_path(str(default_mask))
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"【{self.__class__.__name__}】自动填充失败,跳过: {e}")
|
||||
|
||||
@ -1825,6 +1825,11 @@ class WaterQualityGUI(QMainWindow):
|
||||
for step_id, step_display in steps:
|
||||
item = QListWidgetItem(f" └─ {step_display}")
|
||||
item.setData(Qt.UserRole, step_id)
|
||||
|
||||
# 隐藏4个冗余回归步骤(树节点)
|
||||
if step_id in ("step6_5", "step6_75", "step8_5", "step8_75"):
|
||||
item.setHidden(True)
|
||||
|
||||
self.step_name_map[step_display] = step_id
|
||||
|
||||
# 设置步骤项的样式
|
||||
@ -1905,9 +1910,11 @@ class WaterQualityGUI(QMainWindow):
|
||||
|
||||
self.step6_5_panel = Step6_5Panel()
|
||||
self.step_stack.addTab(self.create_scroll_area(self.step6_5_panel), QIcon(self.get_icon_path("6.png")), "回归建模")
|
||||
self.step_stack.tabBar().setTabVisible(7, False) # 隐藏回归建模 Tab
|
||||
|
||||
self.step6_75_panel = Step6_75Panel()
|
||||
self.step_stack.addTab(self.create_scroll_area(self.step6_75_panel), QIcon(self.get_icon_path("6.png")), "自定义回归建模")
|
||||
self.step_stack.tabBar().setTabVisible(8, False) # 隐藏自定义回归建模 Tab
|
||||
|
||||
self.step7_panel = Step7Panel()
|
||||
self.step_stack.addTab(self.create_scroll_area(self.step7_panel), QIcon(self.get_icon_path("7.png")), "采样点布设")
|
||||
@ -1917,9 +1924,11 @@ class WaterQualityGUI(QMainWindow):
|
||||
|
||||
self.step8_5_panel = Step8_5Panel()
|
||||
self.step_stack.addTab(self.create_scroll_area(self.step8_5_panel), QIcon(self.get_icon_path("8.png")), "回归预测")
|
||||
self.step_stack.tabBar().setTabVisible(11, False) # 隐藏回归预测 Tab
|
||||
|
||||
self.step8_75_panel = Step8_75Panel()
|
||||
self.step_stack.addTab(self.create_scroll_area(self.step8_75_panel), QIcon(self.get_icon_path("8.png")), "自定义回归预测")
|
||||
self.step_stack.tabBar().setTabVisible(12, False) # 隐藏自定义回归预测 Tab
|
||||
|
||||
self.step9_panel = Step9Panel()
|
||||
self.step_stack.addTab(self.create_scroll_area(self.step9_panel), QIcon(self.get_icon_path("10.png")), "专题图生成")
|
||||
|
||||
@ -1003,67 +1003,84 @@ class ReportGenerator:
|
||||
Returns:
|
||||
保存的文件路径
|
||||
"""
|
||||
from modeling_batch import WaterQualityModelingBatch
|
||||
|
||||
from src.core.modeling.modeling_batch import WaterQualityModelingBatch
|
||||
import joblib
|
||||
|
||||
modeler = WaterQualityModelingBatch(models_dir)
|
||||
|
||||
# 需要先加载训练结果
|
||||
# 这里假设results已经存储在modeler中,或者需要从保存的文件中读取
|
||||
# 由于modeling_batch.py的结构,我们需要另一种方式来获取所有结果
|
||||
|
||||
# 尝试遍历模型目录,查找所有保存的结果
|
||||
models_path = Path(models_dir)
|
||||
all_results = []
|
||||
|
||||
# 遍历所有目标参数文件夹
|
||||
for target_folder in models_path.iterdir():
|
||||
if not target_folder.is_dir():
|
||||
continue
|
||||
|
||||
target_name = target_folder.name
|
||||
|
||||
# 查找所有模型文件
|
||||
for model_file in target_folder.rglob("*.pkl"):
|
||||
# 从文件名提取信息(假设格式为:{preprocess}_{model}_{split}.pkl)
|
||||
model_info = {
|
||||
'target': target_name,
|
||||
'model_file': str(model_file),
|
||||
'preprocess': 'Unknown',
|
||||
'model': 'Unknown',
|
||||
'split_method': 'Unknown'
|
||||
|
||||
# 递归扫描 *.joblib 和 *.pkl,兼容 artifacts_dir/target_name/ 的所有子目录层级
|
||||
model_files = list(models_path.rglob("*.joblib")) + list(models_path.rglob("*.pkl"))
|
||||
|
||||
for model_file in model_files:
|
||||
# 目标参数取直系父目录名(符合 artifacts_dir/target_name/ 结构)
|
||||
target_name = model_file.parent.name
|
||||
stem = model_file.stem
|
||||
|
||||
# 文件名格式:{safe_target}_{preprocess}_{model_name}.joblib
|
||||
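# split('_', 2) cuts the stem into at most 3 pieces: target, preprocess, model. NOTE(review): this only holds when the target prefix contains no '_' — a target such as "chl_a" would shift the preprocess/model fields; confirm how safe_target is generated.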
|
||||
parts = stem.split('_', 2)
|
||||
preprocess = parts[1] if len(parts) > 1 else 'Unknown'
|
||||
model_name_str = parts[2] if len(parts) > 2 else stem
|
||||
|
||||
# 尝试从 joblib/pkl 读取元数据,提取性能指标
|
||||
metrics = {}
|
||||
try:
|
||||
data = joblib.load(model_file)
|
||||
metadata = data.get('metadata', {})
|
||||
metrics = {
|
||||
'train_r2': metadata.get('train_r2', 'N/A'),
|
||||
'test_r2': metadata.get('test_r2', 'N/A'),
|
||||
'test_rmse': metadata.get('test_rmse', 'N/A'),
|
||||
'train_rmse': metadata.get('train_rmse', 'N/A'),
|
||||
'train_mae': metadata.get('train_mae', 'N/A'),
|
||||
'test_mae': metadata.get('test_mae', 'N/A'),
|
||||
'cv_mean': metadata.get('cv_mean', 'N/A'),
|
||||
}
|
||||
|
||||
# 尝试从文件名解析
|
||||
parts = model_file.stem.split('_')
|
||||
if len(parts) >= 3:
|
||||
model_info['preprocess'] = parts[0]
|
||||
model_info['model'] = parts[1]
|
||||
model_info['split_method'] = parts[2]
|
||||
|
||||
all_results.append(model_info)
|
||||
|
||||
# 如果有训练结果数据,使用实际指标
|
||||
# 否则创建一个基本的摘要
|
||||
except Exception:
|
||||
pass # 加载失败时 metrics 保持为空字典,摘要中该列为 N/A
|
||||
|
||||
all_results.append({
|
||||
'target': target_name,
|
||||
'model_file': str(model_file),
|
||||
'preprocess': preprocess,
|
||||
'model': model_name_str,
|
||||
**metrics,
|
||||
})
|
||||
|
||||
summary_data = []
|
||||
for result in all_results:
|
||||
summary_data.append({
|
||||
'目标参数': result['target'],
|
||||
'预处理方法': result['preprocess'],
|
||||
'模型名称': result['model'],
|
||||
'划分方法': result['split_method'],
|
||||
'模型文件': result['model_file']
|
||||
'模型文件': result['model_file'],
|
||||
'训练集R²': result.get('train_r2', 'N/A'),
|
||||
'测试集R²': result.get('test_r2', 'N/A'),
|
||||
'测试集RMSE': result.get('test_rmse', 'N/A'),
|
||||
'训练集RMSE': result.get('train_rmse', 'N/A'),
|
||||
'训练集MAE': result.get('train_mae', 'N/A'),
|
||||
'测试集MAE': result.get('test_mae', 'N/A'),
|
||||
'CV均值': result.get('cv_mean', 'N/A'),
|
||||
})
|
||||
|
||||
|
||||
if not summary_data:
|
||||
print("警告:未找到模型文件,生成空摘要")
|
||||
summary_data = [{
|
||||
'目标参数': 'No Data',
|
||||
'预处理方法': 'N/A',
|
||||
'模型名称': 'N/A',
|
||||
'划分方法': 'N/A',
|
||||
'模型文件': 'N/A'
|
||||
'模型文件': 'N/A',
|
||||
'训练集R²': 'N/A',
|
||||
'测试集R²': 'N/A',
|
||||
'测试集RMSE': 'N/A',
|
||||
'训练集RMSE': 'N/A',
|
||||
'训练集MAE': 'N/A',
|
||||
'测试集MAE': 'N/A',
|
||||
'CV均值': 'N/A',
|
||||
}]
|
||||
|
||||
|
||||
df_summary = pd.DataFrame(summary_data)
|
||||
|
||||
if output_path is None:
|
||||
|
||||
@ -96,8 +96,14 @@ class BandMathCalculator:
|
||||
|
||||
print(f"计算表达式: {calc_expression}")
|
||||
|
||||
# 安全地计算表达式
|
||||
result = eval(calc_expression)
|
||||
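# Expose nan / inf / np to eval() so expressions containing a literal "nan" or "inf" resolve to numpy values. NOTE(security): setting __builtins__ to None does NOT sandbox eval() — never pass untrusted expressions here.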
|
||||
import numpy as np
|
||||
try:
|
||||
# 即使 calc_expression 含有纯字符 nan,也能被 np.nan 安全接管
|
||||
result = eval(calc_expression, {"__builtins__": None}, {"nan": np.nan, "inf": np.inf, "np": np})
|
||||
except Exception as e:
|
||||
print(f"⚠️ 警告:公式计算异常 ({e}),该点赋值为 nan")
|
||||
result = np.nan
|
||||
|
||||
# 返回结果
|
||||
if var_name:
|
||||
|
||||
Reference in New Issue
Block a user