Compare commits

...

11 Commits

15 changed files with 2256 additions and 1428 deletions

View File

@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
"""业务步骤层模块"""
from src.core.steps.water_mask_step import WaterMaskStep
from src.core.steps.glint_detection_step import GlintDetectionStep
from src.core.steps.glint_removal_step import GlintRemovalStep
from src.core.steps.data_preparation_step import DataPreparationStep
from src.core.steps.modeling_step import ModelingStep
from src.core.steps.prediction_step import PredictionStep
from src.core.steps.mapping_step import MappingStep
__all__ = [
"WaterMaskStep",
"GlintDetectionStep",
"GlintRemovalStep",
"DataPreparationStep",
"ModelingStep",
"PredictionStep",
"MappingStep",
]

View File

@ -0,0 +1,184 @@
# -*- coding: utf-8 -*-
"""
数据准备步骤
包含 step4_process_csv, step5_extract_training_spectra, step5_5_calculate_water_quality_indices
"""
import time
from pathlib import Path
from typing import Optional, List, Union, Callable, Dict
import pandas as pd
import numpy as np
class DataPreparationStep:
"""数据准备步骤"""
# ---- Step 4: 处理CSV文件 ----
@staticmethod
def process_csv(
csv_path: str,
output_dir: Union[str, Path] = "./4_processed_data",
callback: Optional[Callable] = None,
) -> str:
"""处理CSV文件筛选剔除异常值"""
from src.preprocessing.process_water_quality_data import process_water_quality_data
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_path = str(output_dir / "processed_data.csv")
def notify(status, msg=""):
if callback:
callback("步骤4", status, msg)
print("\n" + "=" * 80)
print("步骤4: 处理CSV文件筛选剔除异常值")
print("=" * 80)
step_start_time = time.time()
if Path(output_path).exists():
print(f"检测到已存在的处理后CSV文件直接使用: {output_path}")
notify("skipped", f"处理后的CSV文件已设置: {output_path}")
return output_path
process_water_quality_data(csv_path, output_path)
notify("completed", f"处理后的CSV文件已保存: {output_path}")
return output_path
# ---- Step 5: 提取训练样本点光谱 ----
@staticmethod
def extract_training_spectra(
deglint_img_path: Optional[str] = None,
radius: int = 5,
source_epsg: int = 4326,
csv_path: Optional[str] = None,
boundary_path: Optional[str] = None,
glint_mask_path: Optional[str] = None,
water_mask_path: Optional[str] = None,
output_dir: Union[str, Path] = "./5_training_spectra",
callback: Optional[Callable] = None,
) -> str:
"""根据采样点坐标在去耀斑影像中提取平均光谱"""
from src.core.glint_removal.get_spectral import get_spectral_in_coor
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_path = str(output_dir / "training_spectra.csv")
def notify(status, msg=""):
if callback:
callback("步骤5", status, msg)
print("\n" + "=" * 80)
print("步骤5: 提取训练样本点的平均光谱")
print("=" * 80)
step_start_time = time.time()
if deglint_img_path is None:
raise ValueError("必须提供 deglint_img_path 参数")
if csv_path is None:
raise ValueError("必须提供 csv_path 参数")
if Path(output_path).exists():
print(f"检测到已存在的训练光谱数据文件,直接使用: {output_path}")
notify("skipped", f"训练光谱数据已设置: {output_path}")
return output_path
# 确保水体掩膜存在
final_boundary_path = boundary_path
if final_boundary_path is None and water_mask_path is not None:
final_boundary_path = water_mask_path
# 【新增安全防护】智能拦截矢量 .shp强制替换为步骤 1 生成的栅格 .dat
if final_boundary_path and str(final_boundary_path).lower().endswith('.shp'):
# 向上追溯查找 1_water_mask 目录下的 dat 替身
possible_dat = Path(deglint_img_path).parent.parent / "1_water_mask" / "water_mask_from_shp.dat"
if not possible_dat.exists() and output_path:
possible_dat = Path(output_path).parent.parent / "1_water_mask" / "water_mask_from_shp.dat"
if possible_dat.exists():
print(f"💡 智能拦截:检测到输入掩膜为矢量 .shp自动切换为已生成的栅格掩膜: {possible_dat}")
final_boundary_path = str(possible_dat)
else:
print(f"⚠️ 警告:检测到输入掩膜为 .shp 且未找到对应 .dat 替身,可能导致底层读取失败。")
flare_path = glint_mask_path
if flare_path:
print(f"光谱提取使用耀斑掩膜: {flare_path}")
get_spectral_in_coor(
deglint_img_path, csv_path, output_path,
radius=radius, flare_path=flare_path,
boundary_path=final_boundary_path, source_epsg=source_epsg
)
notify("completed", f"训练光谱数据已保存: {output_path}")
return output_path
# ---- Step 5.5: 计算水质光谱指数 ----
@staticmethod
def calculate_water_quality_indices(
training_spectra_path: Optional[str] = None,
formula_csv_file: Optional[str] = None,
formula_names: Optional[List[str]] = None,
output_file: Optional[str] = None,
enabled: bool = True,
output_dir: Union[str, Path] = "./6_water_quality_indices",
callback: Optional[Callable] = None,
) -> Optional[str]:
"""根据训练光谱计算水质光谱指数(使用 band_math 方法)"""
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
def notify(status, msg=""):
if callback:
callback("步骤5.5", status, msg)
print("\n" + "=" * 80)
print("步骤5.5: 计算水质光谱指数使用band_math方法")
print("=" * 80)
step_start_time = time.time()
if not enabled:
print("已设置跳过水质指数计算enabled=False")
notify("skipped", "跳过水质指数计算")
return None
if training_spectra_path is None:
raise ValueError("必须提供 training_spectra_path 参数")
if formula_csv_file is None:
raise ValueError("必须提供 formula_csv_file 参数")
if output_file:
output_path = str(Path(output_file))
else:
output_path = str(output_dir / "water_quality_indices.csv")
if Path(output_path).exists():
print(f"检测到已存在的水质指数文件,直接使用: {output_path}")
notify("skipped", f"水质指数数据已设置: {output_path}")
return output_path
from src.utils.band_math import BandMathCalculator
calculator = BandMathCalculator(training_spectra_path)
result_df = calculator.process_formulas_from_csv(
formula_csv_file=formula_csv_file,
formula_names=formula_names,
output_file=output_path
)
if result_df is None:
raise ValueError("计算水质指数失败请检查公式CSV文件格式")
notify("completed", f"水质指数已保存: {output_path}")
return output_path

View File

@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
"""
步骤2: 耀斑区域检测
支持多种检测方法: otsu, zscore, percentile, iqr, adaptive, multi_band
"""
import time
from pathlib import Path
from typing import Optional, List, Union
class GlintDetectionStep:
"""耀斑区域检测步骤"""
@staticmethod
def run(
img_path: str,
glint_wave: float = 750.0,
method: str = "otsu",
z_threshold: float = 2.5,
percentile: float = 95.0,
iqr_multiplier: float = 1.5,
window_size: int = 15,
multi_band_waves: Optional[List[float]] = None,
sub_method: str = "zscore",
weights: Optional[List[float]] = None,
max_area: Optional[int] = None,
buffer_size: Optional[int] = None,
water_mask_path: Optional[str] = None,
glint_dir: Union[str, Path] = "./2_glint",
callback: Optional[callable] = None,
) -> str:
"""
执行耀斑区域检测
Args:
img_path: 输入影像文件路径
glint_wave: 用于耀斑检测的波段波长nm
method: 检测方法 ('otsu' | 'zscore' | 'percentile' | 'iqr' | 'adaptive' | 'multi_band')
z_threshold: Z-score 方法阈值(默认 2.5
percentile: 百分位数阈值(默认 95.0
iqr_multiplier: IQR 倍数(默认 1.5
window_size: 自适应阈值窗口大小(默认 15
multi_band_waves: 多波段方法的波长列表,如 [750, 800, 850]
sub_method: 多波段方法的子方法(默认 'zscore'
weights: 多波段方法的权重列表None 表示等权重)
max_area: 最大连通域面积阈值(像素),超过则过滤
buffer_size: 岸边缓冲区大小(像素),用于去除岸边附近错误掩膜
water_mask_path: 水域掩膜文件路径dat 格式优先)
glint_dir: 工作目录
callback: 回调函数
Returns:
耀斑掩膜文件路径 (.dat)
"""
from src.utils.find_severe_glint_area import find_severe_glint_area
glint_dir = Path(glint_dir)
glint_dir.mkdir(parents=True, exist_ok=True)
def notify(status, msg=""):
if callback:
callback("步骤2", status, msg)
print("\n" + "=" * 80)
print("步骤2: 找到耀斑区域")
print("=" * 80)
step_start_time = time.time()
# 确定水体掩膜路径
if water_mask_path is not None and Path(water_mask_path).exists():
final_water_mask_path = water_mask_path
else:
final_water_mask_path = None
output_path = str(glint_dir / "severe_glint_area.dat")
# 跳过已存在的文件
if Path(output_path).exists():
print(f"检测到已存在的耀斑掩膜文件,直接使用: {output_path}")
notify("skipped", f"耀斑掩膜已设置: {output_path}")
return output_path
# 构建检测参数字典
kwargs = {
"method": method,
"z_threshold": z_threshold,
"percentile": percentile,
"iqr_multiplier": iqr_multiplier,
"window_size": window_size,
}
if method == "multi_band":
if multi_band_waves is not None:
kwargs["multi_band_waves"] = multi_band_waves
if sub_method is not None:
kwargs["sub_method"] = sub_method
if weights is not None:
kwargs["weights"] = weights
if max_area is not None:
kwargs["max_area"] = max_area
if buffer_size is not None:
kwargs["buffer_size"] = buffer_size
glint_mask_path = find_severe_glint_area(
img_path, final_water_mask_path, glint_wave, output_path, **kwargs
)
print(f"耀斑掩膜已生成: {glint_mask_path}")
print(f"使用检测方法: {method}")
notify("completed", f"耀斑掩膜已生成: {glint_mask_path}")
return glint_mask_path

View File

@ -0,0 +1,375 @@
# -*- coding: utf-8 -*-
"""
步骤3: 去除耀斑
支持多种方法: subtract_nir, regression_slope, oxygen_absorption, kutser, goodman, hedley, sugar
每种方法都会:
1. 准备水域掩膜(支持 shp 自动转 dat
2. 调用对应的算法类执行处理
3. 复制 hdr 文件到输出影像
"""
import os
import time
from pathlib import Path
from typing import Optional, List, Union, Callable
import numpy as np
def _safe_rename(src_bsq: str, src_hdr: str, dest_bsq: str, dest_hdr: str) -> str:
"""将底层硬编码生成的 .bsq + .hdr 文件对重命名到用户指定的 output_path
使用 os.remove + os.rename 确保原子覆盖(不等 os.replace 的跨设备行为),
resolve() 断路防止同路径 self-rename 报错。
Returns:
dest_bsq 路径
"""
src_bsq_p = Path(src_bsq)
src_hdr_p = Path(src_hdr)
dest_bsq_p = Path(dest_bsq)
dest_hdr_p = Path(dest_hdr)
if str(src_bsq_p.resolve()) == str(dest_bsq_p.resolve()):
return dest_bsq
if dest_bsq_p.exists():
os.remove(dest_bsq_p)
if dest_hdr_p.exists():
os.remove(dest_hdr_p)
if src_bsq_p.exists():
os.rename(src_bsq_p, dest_bsq_p)
if src_hdr_p.exists():
os.rename(src_hdr_p, dest_hdr_p)
return dest_bsq
class GlintRemovalStep:
"""去除耀斑步骤"""
@staticmethod
def run(
img_path: str,
method: str = "subtract_nir",
start_wave: Optional[float] = None,
end_wave: Optional[float] = None,
json_path: Optional[str] = None,
left_shoulder_wave: Optional[float] = None,
valley_wave: Optional[float] = None,
right_shoulder_wave: Optional[float] = None,
water_mask: Optional[Union[str, np.ndarray]] = None,
interpolated_img_path: Optional[str] = None,
interpolate_zeros: bool = False,
interpolation_method: str = "nearest",
enabled: bool = True,
# Kutser 参数
kutser_shp_path: Optional[str] = None,
oxy_band: int = 38,
lower_oxy: int = 36,
upper_oxy: int = 49,
nir_band: int = 47,
# Goodman 参数
nir_lower: int = 25,
nir_upper: int = 37,
goodman_A: float = 0.000019,
goodman_B: float = 0.1,
# Hedley 参数
hedley_shp_path: Optional[str] = None,
hedley_nir_band: int = 47,
# SUGAR 参数
sugar_bounds: Optional[List[tuple]] = None,
sugar_sigma: float = 1.0,
sugar_estimate_background: bool = True,
sugar_glint_mask_method: str = "cdf",
sugar_iter: Optional[int] = 3,
sugar_termination_thresh: float = 20.0,
# 内部工具函数
_get_image_geo_info=None,
_load_image_as_array=None,
_save_bands_as_image=None,
_copy_hdr_info=None,
_prepare_water_mask_for_algorithm=None,
_interpolate_zero_pixels_batch=None,
deglint_dir: Union[str, Path] = "./3_deglint",
water_mask_dir: Union[str, Path] = "./1_water_mask",
callback: Optional[Callable] = None,
output_path: Optional[str] = None,
) -> str:
"""
执行去除耀斑处理
Args:
img_path: 输入影像文件路径
method: 去耀斑方法
...(其余参数同主类 step3_remove_glint
Returns:
去除耀斑后的影像文件路径
"""
from src.core.glint_removal.Kutser import Kutser
from src.core.glint_removal.Goodman import Goodman
from src.core.glint_removal.Hedley import Hedley
from src.core.glint_removal.SUGAR import SUGAR, correction_iterative
from src.core.utils.gdal_helper import (
get_image_geo_info as _default_get_geo,
load_image_as_array as _default_load,
save_bands_as_image as _default_save_bands,
copy_hdr_info as _default_copy_hdr,
)
from src.core.utils.mask_converter import (
prepare_water_mask_for_algorithm as _default_prepare,
)
# 使用提供的函数或默认函数
if _get_image_geo_info is None:
_get_image_geo_info = _default_get_geo
if _load_image_as_array is None:
_load_image_as_array = _default_load
if _save_bands_as_image is None:
_save_bands_as_image = _default_save_bands
if _copy_hdr_info is None:
_copy_hdr_info = _default_copy_hdr
if _prepare_water_mask_for_algorithm is None:
_prepare_water_mask_for_algorithm = _default_prepare
deglint_dir = Path(deglint_dir)
deglint_dir.mkdir(parents=True, exist_ok=True)
def notify(status, msg=""):
if callback:
callback("步骤3", status, msg)
print("\n" + "=" * 80)
print("步骤3: 去除耀斑")
print("=" * 80)
step_start_time = time.time()
# 方法名标准化
raw_method = str(method).lower()
if "kutser" in raw_method:
method = "kutser"
elif "goodman" in raw_method:
method = "goodman"
elif "hedley" in raw_method:
method = "hedley"
elif "sugar" in raw_method:
method = "sugar"
# 如果未启用,直接返回原始影像
if not enabled:
print("已设置跳过去除耀斑enabled=False将直接使用原始影像。")
notify("skipped", "跳过去耀斑,使用原始影像")
return img_path
# ---- 确定水域掩膜 ----
final_water_mask = water_mask
if final_water_mask is not None and str(final_water_mask).lower().endswith(".shp"):
# shp 自动替换为 dat
dat_mask = str(Path(water_mask_dir) / "water_mask_from_shp.dat")
if Path(dat_mask).exists():
print(f"检测到输入掩膜为 .shp自动替换为栅格掩膜: {dat_mask}")
final_water_mask = dat_mask
if final_water_mask is None:
dat_mask_default = str(Path(water_mask_dir) / "water_mask_from_shp.dat")
if Path(dat_mask_default).exists():
final_water_mask = dat_mask_default
print(f"使用步骤1生成的水域掩膜: {final_water_mask}")
# ---- 步骤3.1: 0值像素插值 ----
if interpolate_zeros:
print("\n" + "-" * 80)
print("步骤3.1: 对0值像素进行插值")
print("-" * 80)
interp_start_time = time.time()
if _interpolate_zero_pixels_batch is None:
from src.core.algorithms.interpolation.interpolator import (
interpolate_zero_pixels_batch as _interp_batch,
)
_interpolate_zero_pixels_batch = _interp_batch
interp_result, _ = _interpolate_zero_pixels_batch(
img_path=img_path,
interpolation_method=interpolation_method,
output_path=None,
water_mask=final_water_mask,
deglint_dir=str(deglint_dir),
callback_progress=lambda msg: print(f" {msg}"),
)
img_path = interp_result
interp_end_time = time.time()
print(f"插值完成,使用插值后的影像: {img_path}")
# ---- 获取影像信息 ----
geotransform, projection, width, height, n_bands = _get_image_geo_info(img_path)
print(f"影像尺寸: {width} x {height} x {n_bands}")
mask_for_algorithm = _prepare_water_mask_for_algorithm(
final_water_mask, (height, width), geotransform, projection, img_path
)
# ==================== Kutser ====================
if method == "kutser":
print(f"使用方法: Kutser (氧吸收波段={oxy_band}, NIR波段={nir_band})")
hardcoded_bsq = str(deglint_dir / "deglint_kutser.bsq")
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
# 将用户指定的 output_path 标准化为 .bsq 路径
if output_path:
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
final_hdr = final_bsq.replace(".bsq", ".hdr")
else:
final_bsq = hardcoded_bsq
final_hdr = hardcoded_hdr
if Path(hardcoded_bsq).exists():
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
kutser = Kutser(
img_path,
shp_path=None,
oxy_band=oxy_band,
lower_oxy=lower_oxy,
upper_oxy=upper_oxy,
NIR_band=nir_band,
water_mask=mask_for_algorithm,
output_path=hardcoded_bsq,
)
kutser.get_corrected_bands()
if Path(hardcoded_bsq).exists():
_copy_hdr_info(img_path, hardcoded_bsq)
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
notify("completed", f"去耀斑影像已生成: {final}")
return final
raise RuntimeError(f"Kutser算法未生成输出文件: {hardcoded_bsq}")
# ==================== Goodman ====================
elif method == "goodman":
print(f"使用方法: Goodman (NIR波段范围: {nir_lower}-{nir_upper})")
hardcoded_bsq = str(deglint_dir / "deglint_goodman.bsq")
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
if output_path:
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
final_hdr = final_bsq.replace(".bsq", ".hdr")
else:
final_bsq = hardcoded_bsq
final_hdr = hardcoded_hdr
if Path(hardcoded_bsq).exists():
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
goodman = Goodman(
img_path,
NIR_lower=nir_lower,
NIR_upper=nir_upper,
A=goodman_A,
B=goodman_B,
water_mask=mask_for_algorithm,
output_path=hardcoded_bsq,
)
corrected_bands = goodman.get_corrected_bands()
if not Path(hardcoded_bsq).exists():
_save_bands_as_image(corrected_bands, hardcoded_bsq, geotransform, projection)
_copy_hdr_info(img_path, hardcoded_bsq)
del corrected_bands
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
notify("completed", f"去耀斑影像已生成: {final}")
return final
# ==================== Hedley ====================
elif method == "hedley":
print(f"使用方法: Hedley (NIR波段={hedley_nir_band})")
hardcoded_bsq = str(deglint_dir / "deglint_hedley.bsq")
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
if output_path:
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
final_hdr = final_bsq.replace(".bsq", ".hdr")
else:
final_bsq = hardcoded_bsq
final_hdr = hardcoded_hdr
if Path(hardcoded_bsq).exists():
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
hedley = Hedley(
img_path,
shp_path=None,
NIR_band=hedley_nir_band,
water_mask=mask_for_algorithm,
output_path=hardcoded_bsq,
)
hedley.get_corrected_bands()
if Path(hardcoded_bsq).exists():
_copy_hdr_info(img_path, hardcoded_bsq)
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
notify("completed", f"去耀斑影像已生成: {final}")
return final
raise RuntimeError(f"Hedley算法未生成输出文件: {hardcoded_bsq}")
# ==================== SUGAR ====================
elif method == "sugar":
glint_method_raw = str(sugar_glint_mask_method).lower()
if "cdf" in glint_method_raw or "累积" in glint_method_raw:
sugar_glint_mask_method_fixed = "cdf"
elif "otsu" in glint_method_raw or "大津" in glint_method_raw:
sugar_glint_mask_method_fixed = "otsu"
else:
sugar_glint_mask_method_fixed = "cdf"
print(
f"使用方法: SUGAR (迭代次数={sugar_iter}, 掩膜方法={sugar_glint_mask_method_fixed})"
)
hardcoded_bsq = str(deglint_dir / "deglint_sugar.bsq")
hardcoded_hdr = hardcoded_bsq.replace(".bsq", ".hdr")
if output_path:
final_bsq = output_path.replace('.dat', '.bsq').replace('.tif', '.bsq')
final_hdr = final_bsq.replace(".bsq", ".hdr")
else:
final_bsq = hardcoded_bsq
final_hdr = hardcoded_hdr
if Path(hardcoded_bsq).exists():
print(f"检测到已存在的去耀斑影像文件,直接使用: {hardcoded_bsq}")
notify("skipped", f"去耀斑影像已设置: {hardcoded_bsq}")
return _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
if sugar_bounds is None:
sugar_bounds = [(1, 2)]
correction_iterative(
img_path,
iter=sugar_iter,
bounds=sugar_bounds,
estimate_background=sugar_estimate_background,
glint_mask_method=sugar_glint_mask_method_fixed,
termination_thresh=sugar_termination_thresh,
water_mask=mask_for_algorithm,
output_path=hardcoded_bsq,
)
if Path(hardcoded_bsq).exists():
_copy_hdr_info(img_path, hardcoded_bsq)
final = _safe_rename(hardcoded_bsq, hardcoded_hdr, final_bsq, final_hdr)
notify("completed", f"去耀斑影像已生成: {final}")
return final
raise RuntimeError(f"SUGAR算法未生成输出文件: {hardcoded_bsq}")
else:
raise ValueError(
f"不支持的方法: {method}。支持的方法: kutser, goodman, hedley, sugar"
)

View File

@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
"""
成图步骤
包含 step9_generate_distribution_map
"""
import time
from pathlib import Path
from typing import Optional, Union, Callable
class MappingStep:
"""成图步骤"""
@staticmethod
def generate_distribution_map(
prediction_csv_path: str,
boundary_shp_path: str,
output_image_path: Optional[str] = None,
resolution: float = 30,
input_crs: str = "EPSG:32651",
output_crs: str = "EPSG:4326",
show_sample_points: bool = False,
base_map_tif: Optional[str] = None,
use_distance_diffusion: bool = True,
max_diffusion_distance: Optional[float] = None,
diffusion_power: float = 2,
diffusion_n_neighbors: int = 15,
cmap: Optional[str] = None,
expand_ratio: float = 0.05,
output_dir: Union[str, Path] = "./14_visualization",
callback: Optional[Callable] = None,
) -> str:
"""
根据采样点的坐标和反演的实测参数,通过插值方法得到水质参数可视化分布图
Args:
prediction_csv_path: 预测结果CSV文件路径前两列为经纬度第三列为预测值
boundary_shp_path: 边界shapefile文件路径
output_image_path: 输出图片路径如果为None自动生成
resolution: 插值网格分辨率(米)
input_crs: 输入坐标系
output_crs: 输出坐标系
show_sample_points: 是否在图上显示采样点
base_map_tif: 底图TIF路径
use_distance_diffusion: 是否启用距离扩散补全边界
max_diffusion_distance: 距离扩散最大距离(米)
diffusion_power: 距离扩散幂参数
diffusion_n_neighbors: 距离扩散最近邻数量
cmap: 颜色映射名称None表示自动识别
expand_ratio: 边界外扩比例0-1之间
output_dir: 输出目录
callback: 回调函数
Returns:
可视化分布图文件路径
"""
from src.postprocessing.map import ContentMapper
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
def notify(status, msg=""):
if callback:
callback("步骤9", status, msg)
print("\n" + "=" * 80)
print("步骤9: 生成水质参数可视化分布图")
print("=" * 80)
step_start_time = time.time()
if output_image_path is None:
csv_name = Path(prediction_csv_path).stem
output_image_path = str(output_dir / f"{csv_name}_distribution.png")
if Path(output_image_path).exists():
print(f"检测到已存在的分布图文件,直接使用: {output_image_path}")
notify("skipped", f"可视化分布图已设置: {output_image_path}")
return output_image_path
mapper = ContentMapper(input_crs=input_crs, output_crs=output_crs)
mapper_kwargs = {
"resolution": resolution,
"show_sample_points": show_sample_points,
"use_distance_diffusion": use_distance_diffusion,
"diffusion_power": diffusion_power,
"diffusion_n_neighbors": diffusion_n_neighbors,
"expand_ratio": expand_ratio,
}
optional_kwargs = {
"base_map_tif": base_map_tif,
"max_diffusion_distance": max_diffusion_distance,
"cmap": cmap,
}
mapper_kwargs.update({k: v for k, v in optional_kwargs.items() if v is not None})
mapper.process_data(
csv_file=prediction_csv_path,
shp_file=boundary_shp_path,
output_file=output_image_path,
**mapper_kwargs,
)
notify("completed", f"可视化分布图已保存: {output_image_path}")
return output_image_path

View File

@ -0,0 +1,497 @@
# -*- coding: utf-8 -*-
"""
建模步骤
包含 step6_train_models, step6_5_non_empirical_modeling, step6_75_custom_regression
"""
import time
import json
from pathlib import Path
from typing import Optional, List, Union, Callable, Dict
import pandas as pd
import numpy as np
# ============================================================
# 汉化 -> 英文 反向映射字典UI 复选框显示文本 -> 底层算法键名)
# ============================================================
# 模型名称:中文 (缩写) -> 英文键名
MODEL_NAME_MAP = {
"多元线性回归 (MLR)": "LinearRegression",
"岭回归 (Ridge)": "Ridge",
"套索回归 (Lasso)": "Lasso",
"弹性网络 (ElasticNet)": "ElasticNet",
"偏最小二乘 (PLSR)": "PLS",
"决策树 (CART)": "DecisionTree",
"随机森林 (RF)": "RF",
"极端随机树 (ET)": "ExtraTrees",
"极值梯度提升 (XGBoost)": "XGBoost",
"轻量梯度提升 (LightGBM)": "LightGBM",
"类别梯度提升 (CatBoost)": "CatBoost",
"梯度提升树 (GBDT)": "GradientBoosting",
"自适应提升 (AdaBoost)": "AdaBoost",
"支持向量回归 (SVR)": "SVR",
"K近邻回归 (KNN)": "KNN",
"多层感知机 (BP神经网络)": "MLP",
}
# 预处理方法:各种可能的中文变体 -> 标准键名
PREPROC_NAME_MAP = {
# 无处理
"无 (None)": "None",
"None": "None",
# MMS
"最小-最大归一化 (MMS)": "MMS",
"MMS": "MMS",
# SS
"标度化 (SS)": "SS",
"SS": "SS",
# SNV
"标准正态变换 (SNV)": "SNV",
"SNV": "SNV",
# MA
"移动平均 (MA)": "MA",
"MA": "MA",
# SG
"Savitzky-Golay (SG)": "SG",
"SG": "SG",
# MSC
"多元散射校正 (MSC)": "MSC",
"MSC": "MSC",
# D1
"一阶导数 (D1)": "D1",
"D1": "D1",
# D2
"二阶导数 (D2)": "D2",
"D2": "D2",
# DT
"去趋势 (DT)": "DT",
"DT": "DT",
# CT
"中心化 (CT)": "CT",
"CT": "CT",
}
# 数据划分方法:各种可能的中文变体 -> 标准键名
SPLIT_NAME_MAP = {
"SPXY 算法 (考量X-Y空间)": "spxy",
"spxy": "spxy",
"KS 算法 (考量X空间)": "ks",
"ks": "ks",
"随机划分 (Random)": "random",
"random": "random",
}
def _normalize_model_names(model_names: List[str]) -> List[str]:
"""清洗模型名称列表:将汉化显示文本还原为英文键名"""
result = []
for name in model_names:
if name in MODEL_NAME_MAP:
result.append(MODEL_NAME_MAP[name])
else:
# 已经是英文键名,直接保留
result.append(name)
return result
def _normalize_preprocessing_methods(methods: List[str]) -> List[str]:
"""清洗预处理方法列表:将汉化显示文本还原为标准键名"""
result = []
for method in methods:
if method in PREPROC_NAME_MAP:
result.append(PREPROC_NAME_MAP[method])
else:
# 已经是标准键名,直接保留
result.append(method)
return result
def _normalize_split_methods(methods: List[str]) -> List[str]:
"""清洗数据划分方法列表:将汉化显示文本还原为标准键名"""
result = []
for method in methods:
if method in SPLIT_NAME_MAP:
result.append(SPLIT_NAME_MAP[method])
else:
# 已经是标准键名,直接保留
result.append(method)
return result
class ModelingStep:
"""建模步骤"""
# ---- Step 6: 训练机器学习模型 ----
@staticmethod
def train_models(
feature_start_column: str = "374.285004",
preprocessing_methods: Optional[List[str]] = None,
model_names: Optional[List[str]] = None,
split_methods: Optional[List[str]] = None,
cv_folds: int = 5,
training_csv_path: Optional[str] = None,
output_dir: Union[str, Path] = "./7_Supervised_Model_Training",
callback: Optional[Callable] = None,
_report_generator=None,
) -> str:
"""使用采样点光谱和实测值建立机器学习模型"""
from src.core.modeling.modeling_batch import WaterQualityModelingBatch
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
def notify(status, msg=""):
if callback:
callback("步骤6", status, msg)
print("\n" + "=" * 80)
print("步骤6: 训练机器学习模型")
print("=" * 80)
step_start_time = time.time()
if training_csv_path is None:
raise ValueError("必须提供 training_csv_path 参数")
# 检查模型目录是否已有模型
if output_dir.exists() and any(output_dir.iterdir()):
has_models = False
for item in output_dir.iterdir():
if item.is_dir():
model_files = (
list(item.glob("*.pkl"))
+ list(item.glob("*.joblib"))
+ list(item.glob("*.h5"))
)
if model_files:
has_models = True
break
if has_models:
print(f"检测到已存在的模型文件,直接使用: {output_dir}")
notify("skipped", f"模型目录已设置: {output_dir}")
return str(output_dir)
if preprocessing_methods is None:
preprocessing_methods = ["None", "MMS", "SS", "SNV", "MA", "SG", "MSC", "D1", "D2", "DT", "CT"]
if model_names is None:
model_names = ["SVR", "RF", "Ridge", "Lasso"]
if split_methods is None:
split_methods = ["spxy", "ks", "random"]
# ---- 汉化清洗:将 UI 传来的中文/混合名称转换为底层英文键名 ----
preprocessing_methods = _normalize_preprocessing_methods(preprocessing_methods)
model_names = _normalize_model_names(model_names)
split_methods = _normalize_split_methods(split_methods)
print(f"[参数清洗] 预处理方法: {preprocessing_methods}")
print(f"[参数清洗] 模型名称: {model_names}")
print(f"[参数清洗] 划分方法: {split_methods}")
modeler = WaterQualityModelingBatch(str(output_dir))
modeler.train_models_batch(
csv_path=training_csv_path,
feature_start_column=feature_start_column,
preprocessing_methods=preprocessing_methods,
model_names=model_names,
split_methods=split_methods,
cv_folds=cv_folds,
)
print(f"模型训练完成,结果保存在: {output_dir}")
if _report_generator is not None:
try:
summary_path = _report_generator.generate_training_summary(str(output_dir))
print(f"训练摘要报告已生成: {summary_path}")
except Exception as e:
print(f"生成训练摘要报告时出错: {e}")
notify("completed", f"模型训练完成: {output_dir}")
return str(output_dir)
# ---- Step 6.5: 非经验统计回归模型训练 ----
@staticmethod
def train_non_empirical_models(
csv_path: Optional[str] = None,
preprocessing_methods: Optional[List[str]] = None,
algorithms: Optional[List[str]] = None,
value_cols: Union[int, Dict[str, int]] = 0,
spectral_start_col: int = 1,
spectral_end_col: Optional[int] = None,
window: int = 5,
output_dir: Optional[str] = None,
enabled: bool = True,
callback: Optional[Callable] = None,
) -> Dict[str, str]:
"""非经验统计回归模型训练"""
def notify(status, msg=""):
if callback:
callback("步骤6.5", status, msg)
print("\n" + "=" * 80)
print("步骤6.5: 非经验统计回归模型训练")
print("=" * 80)
step_start_time = time.time()
if not enabled:
print("已设置跳过非经验模型训练enabled=False")
notify("skipped", "跳过的经验模型训练")
return {}
if csv_path is None:
raise ValueError("必须提供 csv_path 参数")
if output_dir is not None:
non_empirical_dir = Path(output_dir)
else:
non_empirical_dir = Path.cwd() / "8_Regression_Modeling"
non_empirical_dir.mkdir(parents=True, exist_ok=True)
if preprocessing_methods is None:
preprocessing_methods = ["None"]
if algorithms is None:
algorithms = ["chl_a", "nh3", "mno4", "tn", "tp", "tss"]
if isinstance(value_cols, int):
value_cols_dict = {algorithm: value_cols for algorithm in algorithms}
elif isinstance(value_cols, dict):
value_cols_dict = value_cols
else:
raise ValueError("value_cols 参数必须是整数或字典")
if spectral_end_col is None:
df = pd.read_csv(csv_path)
spectral_end_col = len(df.columns) - 1
all_model_results = {}
for preprocess in preprocessing_methods:
preprocess_dir = non_empirical_dir / preprocess
preprocess_dir.mkdir(parents=True, exist_ok=True)
processed_csv_path = _apply_preprocessing_internal(
csv_path, preprocess, preprocess_dir, spectral_start_col
)
for algorithm in algorithms:
algorithm_value_col = value_cols_dict[algorithm]
print(f"\n训练 {preprocess} + {algorithm} 模型 (实测值列: {algorithm_value_col})...")
model_outpath = str(preprocess_dir / f"{preprocess}_{algorithm}.json")
if Path(model_outpath).exists():
print(f"检测到已存在的模型文件,直接使用: {model_outpath}")
all_model_results[f"{preprocess}_{algorithm}"] = model_outpath
continue
try:
from src.core.non_empirical_model_correction import run_model_correction
run_model_correction(
algorithm=algorithm,
csv_file=processed_csv_path if Path(processed_csv_path).exists() else csv_path,
value_col=algorithm_value_col,
spectral_start=spectral_start_col,
spectral_end=spectral_end_col,
model_info_outpath=model_outpath,
window=window,
)
all_model_results[f"{preprocess}_{algorithm}"] = model_outpath
print(f"模型训练完成: {model_outpath}")
except Exception as e:
print(f"训练 {preprocess}_{algorithm} 模型时出错: {e}")
continue
summary_path = _generate_non_empirical_summary(all_model_results, non_empirical_dir)
notify("completed", f"非经验模型训练完成: {non_empirical_dir}")
return all_model_results
# ---- Step 6.75: 自定义回归分析 ----
@staticmethod
def custom_regression(
csv_path: Optional[str] = None,
x_columns: Optional[Union[str, List[str]]] = None,
y_columns: Optional[Union[str, List[str]]] = None,
methods: Union[str, List[str]] = "all",
output_dir: Optional[str] = None,
enabled: bool = True,
callback: Optional[Callable] = None,
work_dir: Union[str, Path] = "./work_dir",
) -> Optional[str]:
"""使用自定义回归方法分析指标与目标参数之间的关系"""
def notify(status, msg=""):
if callback:
callback("步骤6.75", status, msg)
print("\n" + "=" * 80)
print("步骤6.75: 自定义回归分析")
print("=" * 80)
step_start_time = time.time()
if not enabled:
print("已设置跳过自定义回归分析enabled=False")
notify("skipped", "跳过自定义回归分析")
return None
if csv_path is None:
raise ValueError("必须提供 csv_path 参数")
if y_columns is None:
raise ValueError("必须指定 y_columns")
if x_columns is None:
raise ValueError("必须指定 x_columns")
if isinstance(x_columns, str):
x_columns = [x_columns]
if isinstance(y_columns, str):
y_columns = [y_columns]
df = pd.read_csv(csv_path)
missing_x = [col for col in x_columns if col not in df.columns]
missing_y = [col for col in y_columns if col not in df.columns]
if missing_x:
raise ValueError(f"自变量列不存在: {missing_x}")
if missing_y:
raise ValueError(f"因变量列不存在: {missing_y}")
if output_dir is None:
custom_regression_dir = Path(work_dir) / "9_Custom_Regression_Modeling"
else:
custom_regression_dir = Path(work_dir) / output_dir
custom_regression_dir.mkdir(parents=True, exist_ok=True)
from src.core.modeling.regression import SingleVariableRegressionAnalysis
analyzer = SingleVariableRegressionAnalysis()
analyzer.batch_single_variable_regression(
data=df,
x_columns=x_columns,
y_columns=y_columns,
methods=methods,
output_dir=str(custom_regression_dir),
)
notify("completed", f"自定义回归结果已保存到目录: {custom_regression_dir}")
return str(custom_regression_dir)
# ============================================================
# 内部辅助函数(供 ModelingStep 内部使用)
# ============================================================
def _apply_preprocessing_internal(
csv_path: str,
preprocess_method: str,
output_dir: Path,
spectral_start_col: int = 4,
) -> str:
"""应用预处理到CSV数据内部函数"""
raw_p = str(preprocess_method).lower()
if raw_p == "none" or "" in raw_p or "跳过" in raw_p:
preprocess_method = "None"
elif raw_p == "mms" or "minmax" in raw_p or "最大最小" in raw_p:
preprocess_method = "MMS"
elif raw_p == "ss" or "标准" in raw_p or "标准化" in raw_p:
preprocess_method = "SS"
elif raw_p == "snv" or "标准正态" in raw_p:
preprocess_method = "SNV"
elif raw_p == "ma" or "移动" in raw_p:
preprocess_method = "MA"
elif raw_p == "sg" or "savitzky" in raw_p or "平滑" in raw_p:
preprocess_method = "SG"
elif raw_p == "msc" or "多元散射" in raw_p:
preprocess_method = "MSC"
elif raw_p in ("d1", "d2", "dt"):
preprocess_method = {"d1": "D1", "d2": "D2", "dt": "DT"}.get(raw_p, raw_p.upper())
elif raw_p == "ct" or "去趋势" in raw_p:
preprocess_method = "CT"
if preprocess_method == "None":
return csv_path
output_filename = f"preprocessed_{preprocess_method}.csv"
output_path = str(output_dir / output_filename)
if Path(output_path).exists():
print(f"检测到已存在的预处理文件,直接使用: {output_path}")
return output_path
df = pd.read_csv(csv_path)
non_spectral_cols = df.iloc[:, :spectral_start_col]
spectral_data = df.iloc[:, spectral_start_col:]
from src.preprocessing.spectral_Preprocessing import Preprocessing
save_path = None
if preprocess_method == "SS":
models_dir = output_dir.parent.parent / "7_Supervised_Model_Training"
models_dir.mkdir(parents=True, exist_ok=True)
save_path = str(models_dir / "scaler_params.pkl")
print(f"SS预处理: scaler模型将保存到 {save_path}")
processed_spectral = Preprocessing(preprocess_method, spectral_data, save_path=save_path)
if isinstance(processed_spectral, pd.DataFrame):
processed_df = pd.concat([non_spectral_cols, processed_spectral], axis=1)
else:
processed_spectral_df = pd.DataFrame(
processed_spectral, columns=spectral_data.columns, index=spectral_data.index
)
processed_df = pd.concat([non_spectral_cols, processed_spectral_df], axis=1)
processed_df.to_csv(output_path, index=False)
print(f"预处理完成: {output_path}")
return output_path
def _generate_non_empirical_summary(model_results: Dict[str, str], output_dir: Path) -> str:
"""生成非经验模型训练结果汇总CSV"""
summary_path = str(output_dir / "non_empirical_models_summary.csv")
summary_data = []
for model_key, model_path in model_results.items():
try:
parts = model_key.split("_")
preprocess_method = parts[0]
algorithm_name = "_".join(parts[1:]) if len(parts) > 2 else parts[1]
with open(model_path, "r", encoding="utf-8") as f:
model_info = json.load(f)
accuracy_list = model_info.get("accuracy", [])
summary_row = {
"Preprocessing Method": preprocess_method,
"Algorithm Name": algorithm_name,
"Model Type": model_info.get("model_type", ""),
"Coefficient Count": len(model_info.get("model_info", [])),
"Average Accuracy(%)": np.mean(accuracy_list) if accuracy_list else 0,
"Min Accuracy(%)": np.min(accuracy_list) if accuracy_list else 0,
"Max Accuracy(%)": np.max(accuracy_list) if accuracy_list else 0,
"Sample Count": len(model_info.get("long", [])),
"Model File": model_path,
}
coefficients = model_info.get("model_info", [])
for i, coeff in enumerate(coefficients[:5]):
summary_row[f"系数_{i+1}"] = coeff
summary_data.append(summary_row)
except Exception as e:
print(f"读取模型文件 {model_path} 时出错: {e}")
continue
if summary_data:
df_summary = pd.DataFrame(summary_data)
df_summary.to_csv(summary_path, index=False, encoding="utf-8-sig")
print(f"汇总文件已生成: {summary_path}")
else:
print("警告: 没有有效的模型数据可汇总")
summary_path = ""
return summary_path

View File

@ -0,0 +1,350 @@
# -*- coding: utf-8 -*-
"""
预测步骤
包含 step7_generate_sampling_points, step8_predict_water_quality,
step8_5_predict_with_non_empirical_models, step8_75_predict_with_custom_regression
"""
import time
from pathlib import Path
from typing import Optional, List, Union, Callable, Dict
class PredictionStep:
"""预测步骤"""
# ---- Step 7: 生成采样点并提取光谱 ----
@staticmethod
def generate_sampling_points(
deglint_img_path: Optional[str] = None,
interval: int = 50,
sample_radius: int = 5,
chunk_size: int = 1000,
water_mask_path: Optional[str] = None,
glint_mask_path: Optional[str] = None,
output_dir: Union[str, Path] = "./10_sampling",
callback: Optional[Callable] = None,
) -> str:
"""生成水域掩膜内且耀斑掩膜外的采样点,统计平均光谱"""
from pathlib import Path
from src.utils.sampling import get_spectral_sampling_points_chunked
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_path = str(output_dir / "sampling_spectra.csv")
def notify(status, msg=""):
if callback:
callback("步骤7", status, msg)
print("\n" + "=" * 80)
print("步骤7: 生成预测采样点并提取光谱")
print("=" * 80)
step_start_time = time.time()
if deglint_img_path is None:
raise ValueError("必须提供 deglint_img_path 参数")
# 1. 初始归一化与安全转换
original_path = Path(deglint_img_path)
final_deglint_path = original_path
# 2. 智能回溯探测:如果当前路径不存在,或者后缀是前端死板的 .dat
if not final_deglint_path.exists() or final_deglint_path.suffix.lower() == '.dat':
print(f"🔍 智能探测:输入去耀斑路径不存在或为 .dat 占位符 ({final_deglint_path}),正在向上搜索真实产物...")
# 定位到预期的 3_deglint 根目录
possible_dir = original_path.parent
if possible_dir.name != '3_deglint' and Path(output_path).parent.parent.exists():
possible_dir = Path(output_path).parent.parent / "3_deglint"
if possible_dir.exists():
# 搜寻该目录下所有真实存在的 .bsq 文件(接管 goodman/sugar/kutser/hedley 的硬编码产物)
existing_bsqs = list(possible_dir.glob("*.bsq"))
if existing_bsqs:
final_deglint_path = existing_bsqs[0]
print(f"💡 智能拦截成功:自动寻回底层真实去耀斑影像: {final_deglint_path}")
else:
final_deglint_path = original_path.with_suffix('.bsq')
else:
final_deglint_path = original_path.with_suffix('.bsq')
deglint_img_str = str(final_deglint_path)
if Path(output_path).exists():
print(f"检测到已存在的采样点光谱数据文件,直接使用: {output_path}")
notify("skipped", f"采样点光谱数据已设置: {output_path}")
return output_path
glint_mask_to_use = glint_mask_path
if glint_mask_to_use is None:
print("未检测到耀斑掩膜,将在采样点生成时不做耀斑区域剔除。")
# 传递极度安全的 deglint_img_str 进底层
get_spectral_sampling_points_chunked(
deglint_img_str, water_mask_path, glint_mask_to_use,
output_path, interval, sample_radius, chunk_size
)
notify("completed", f"采样点光谱数据已保存: {output_path}")
return output_path
# ---- Step 8: 机器学习模型预测水质参数 ----
@staticmethod
def predict_water_quality(
sampling_csv_path: str,
models_dir: Optional[str] = None,
metric: str = "test_r2",
prediction_column: str = "prediction",
output_dir: Union[str, Path] = "./11_12_13_predictions/Machine_Learning_Prediction",
callback: Optional[Callable] = None,
_report_generator=None,
) -> Dict[str, str]:
"""将训练好的最佳机器学习模型应用到采样点光谱上,预测水质参数"""
from src.core.prediction.inference_batch import WaterQualityInference
def notify(status, msg=""):
if callback:
callback("步骤8", status, msg)
print("\n" + "=" * 80)
print("步骤8: 预测水质参数")
print("=" * 80)
step_start_time = time.time()
if models_dir is None:
raise ValueError("必须提供 models_dir 参数")
ml_prediction_dir = Path(output_dir)
ml_prediction_dir.mkdir(parents=True, exist_ok=True)
prediction_files = {}
if ml_prediction_dir.exists():
csv_files = list(ml_prediction_dir.glob("*.csv"))
for csv_file in csv_files:
file_stem = csv_file.stem
if "_prediction" in file_stem:
target_name = file_stem.replace("_prediction", "")
elif "_pred" in file_stem:
target_name = file_stem.replace("_pred", "")
else:
target_name = file_stem
prediction_files[target_name] = str(csv_file)
# 检查是否所有目标参数都有预测文件
if prediction_files:
models_path_obj = Path(models_dir)
if models_path_obj.exists():
target_folders = [d.name for d in models_path_obj.iterdir() if d.is_dir()]
missing_targets = [t for t in target_folders if t not in prediction_files]
if not missing_targets:
print(f"检测到已存在的预测结果文件,直接使用: {ml_prediction_dir}")
notify("skipped", f"预测结果已设置: {ml_prediction_dir}")
return prediction_files
else:
print(f"检测到部分预测结果文件,缺少: {missing_targets},将继续生成...")
inferencer = WaterQualityInference(models_dir)
all_results = inferencer.batch_inference_multi_models(
models_root_dir=models_dir,
sampling_csv_path=sampling_csv_path,
output_dir=str(ml_prediction_dir),
metric=metric,
prediction_column=prediction_column,
output_format="csv",
)
for target_name, result in all_results.items():
if result.get("status") == "success":
prediction_files[target_name] = result["output_file"]
print(f"预测完成,结果保存在: {ml_prediction_dir}")
if _report_generator is not None:
try:
report_path = _report_generator.generate_prediction_report(prediction_files)
print(f"预测结果报告已生成: {report_path}")
except Exception as e:
print(f"生成预测结果报告时出错: {e}")
notify("completed", f"预测完成: {ml_prediction_dir}")
return prediction_files
# ---- Step 8.5: 非经验模型预测 ----
@staticmethod
def predict_with_non_empirical_models(
sampling_csv_path: str,
non_empirical_models_dir: Optional[str] = None,
output_dir: Optional[str] = None,
metric: str = "Average Accuracy(%)",
prediction_column: str = "prediction",
enabled: bool = True,
callback: Optional[Callable] = None,
work_dir: Union[str, Path] = "./work_dir",
) -> Dict[str, str]:
"""使用非经验统计回归模型进行参数预测"""
def notify(status, msg=""):
if callback:
callback("步骤8.5", status, msg)
print("\n" + "=" * 80)
print("步骤8.5: 使用非经验模型进行参数预测")
print("=" * 80)
step_start_time = time.time()
if not enabled:
print("已设置跳过非经验模型预测enabled=False")
notify("skipped", "跳过非经验模型预测")
return {}
if non_empirical_models_dir is not None:
final_models_dir = non_empirical_models_dir
else:
default_models_dir = str(Path(work_dir) / "8_Regression_Modeling")
if Path(default_models_dir).exists():
final_models_dir = default_models_dir
else:
raise ValueError("请先执行步骤6.5: 非经验模型训练,或提供 non_empirical_models_dir 参数")
if output_dir is not None:
non_empirical_prediction_dir = Path(output_dir)
else:
non_empirical_prediction_dir = Path(work_dir) / "11_12_13_predictions" / "Non_Empirical_Prediction"
non_empirical_prediction_dir.mkdir(parents=True, exist_ok=True)
prediction_files = {}
summary_path = Path(final_models_dir) / "non_empirical_models_summary.csv"
if not summary_path.exists():
raise ValueError(f"未找到非经验模型汇总文件: {summary_path}")
import pandas as pd
df_summary = pd.read_csv(summary_path)
best_models = {}
for algorithm in df_summary["Algorithm Name"].unique():
algorithm_df = df_summary[df_summary["Algorithm Name"] == algorithm]
if metric in algorithm_df.columns:
best_model_row = algorithm_df.nlargest(1, metric)
else:
best_model_row = algorithm_df.iloc[[0]]
best_model_path = best_model_row["Model File"].values[0]
best_preprocess = best_model_row["Preprocessing Method"].values[0]
best_accuracy = best_model_row[metric].values[0] if metric in best_model_row.columns else "N/A"
best_models[algorithm] = {
"model_path": best_model_path,
"preprocess_method": best_preprocess,
"accuracy": best_accuracy,
}
print(f"算法 {algorithm}: 选择 {best_preprocess} (准确率: {best_accuracy})")
pd.read_csv(sampling_csv_path) # just to validate
for algorithm, model_info in best_models.items():
print(f"\n使用 {algorithm} 算法进行预测...")
output_path = str(non_empirical_prediction_dir / f"non_empirical_{algorithm}_{prediction_column}.csv")
if Path(output_path).exists():
print(f"检测到已存在的预测结果文件,直接使用: {output_path}")
prediction_files[algorithm] = output_path
continue
try:
from src.core.non_empirical_retrieval import non_empirical_retrieval
non_empirical_retrieval(
algorithm=algorithm,
model_info_path=model_info["model_path"],
coor_spectral_path=sampling_csv_path,
output_path=output_path,
wave_radius=5,
)
prediction_files[algorithm] = output_path
print(f"预测完成: {output_path}")
except Exception as e:
print(f"使用 {algorithm} 算法预测时出错: {e}")
continue
notify("completed", f"非经验模型预测完成: {non_empirical_prediction_dir}")
return prediction_files
# ---- Step 8.75: 自定义回归模型预测 ----
@staticmethod
def predict_with_custom_regression(
sampling_csv_path: str,
custom_regression_dir: Optional[str] = None,
formula_csv_path: Optional[str] = None,
coordinate_columns: Optional[List[str]] = None,
output_dir: Optional[str] = None,
filename_prefix: str = "custom_regression_prediction",
enabled: bool = True,
callback: Optional[Callable] = None,
work_dir: Union[str, Path] = "./work_dir",
) -> Dict[str, str]:
"""使用自定义回归模型进行参数预测"""
def notify(status, msg=""):
if callback:
callback("步骤8.75", status, msg)
print("\n" + "=" * 80)
print("步骤8.75: 使用自定义回归模型进行参数预测")
print("=" * 80)
step_start_time = time.time()
if not enabled:
print("已设置跳过自定义回归模型预测enabled=False")
notify("skipped", "跳过自定义回归预测")
return {}
if not Path(sampling_csv_path).exists():
raise FileNotFoundError(f"采样点CSV文件不存在: {sampling_csv_path}")
if custom_regression_dir is not None:
final_regression_dir = custom_regression_dir
else:
final_regression_dir = str(Path(work_dir) / "9_Custom_Regression_Modeling")
if not Path(final_regression_dir).exists():
raise ValueError(
"请先执行步骤6.75: 自定义回归分析,或提供 custom_regression_dir 参数"
)
if output_dir is None:
custom_regression_prediction_dir = Path(work_dir) / "11_12_13_predictions" / "Custom_Regression_Prediction"
custom_regression_prediction_dir.mkdir(parents=True, exist_ok=True)
prediction_output_dir = str(custom_regression_prediction_dir)
else:
prediction_output_dir = output_dir
from src.core.prediction.custom_regression_prediction import CustomRegressionPredictor
predictor = CustomRegressionPredictor(
regression_csv_dir=final_regression_dir,
formula_csv_path=formula_csv_path,
)
print(f"开始使用自定义回归模块进行批量预测...")
print(f" 采样点数据: {sampling_csv_path}")
print(f" 回归模型目录: {final_regression_dir}")
print(f" 输出目录: {prediction_output_dir}")
saved_files = predictor.run_batch_prediction(
input_csv_path=sampling_csv_path,
coordinate_columns=coordinate_columns,
filename_prefix=filename_prefix,
)
print(f"自定义回归预测完成,生成 {len(saved_files)} 个预测文件:")
for param_name, filepath in saved_files.items():
print(f" {param_name}: {filepath}")
notify("completed", f"自定义回归预测完成: {len(saved_files)} 个文件")
return saved_files

View File

@ -0,0 +1,148 @@
# -*- coding: utf-8 -*-
"""
步骤1: 水域掩膜生成
支持三种方式:
1. 基于 shp 文件栅格化
2. 使用现有栅格格式掩膜文件 (.dat/.tif)
3. 基于 NDWI 从影像自动生成水体掩膜
"""
import os
import time
from pathlib import Path
from typing import Optional, List, Callable, Union
import numpy as np
class WaterMaskStep:
"""水域掩膜生成步骤"""
@staticmethod
def run(
mask_path: Optional[str] = None,
img_path: Optional[str] = None,
ndwi_threshold: float = 0.4,
use_ndwi: bool = False,
generate_png: bool = True,
output_path: Optional[str] = None,
water_mask_dir: Union[str, Path] = "./1_water_mask",
callback: Optional[Callable] = None,
) -> str:
"""
执行水域掩膜生成
Args:
mask_path: 水体掩膜文件路径,支持 .shp需 img_path或 .dat/.tif直接使用
img_path: 输入影像文件路径(当 mask_path 为 shp 或 use_ndwi=True 时必须提供)
ndwi_threshold: NDWI 阈值use_ndwi=True 时使用)
use_ndwi: 是否使用 NDWI 方法从影像生成水体掩膜
generate_png: 是否生成 PNG 预览图(默认 True
output_path: 指定输出掩膜文件的保存路径(可选)
water_mask_dir: 工作目录
callback: 回调函数,签名为 callback(step, status, message)
Returns:
dat 格式的水域掩膜文件路径
"""
from src.utils.extract_water_area import rasterize_shp, ndwi
from src.core.utils.preview_generator import (
generate_image_preview,
generate_water_mask_overlay,
)
water_mask_dir = Path(water_mask_dir)
water_mask_dir.mkdir(parents=True, exist_ok=True)
def notify(status, msg=""):
if callback:
callback("步骤1", status, msg)
print("\n" + "=" * 80)
print("步骤1: 生成或设置水域mask")
print("=" * 80)
step_start_time = time.time()
# 生成影像预览图
if generate_png and img_path is not None and Path(img_path).exists():
preview_path = str(water_mask_dir / "hsi_preview.png")
generate_image_preview(
img_path=img_path,
output_path=preview_path,
title="影像预览: RGB波段(基于波长)"
)
# ---- NDWI 方法 ----
if use_ndwi:
if img_path is None:
raise ValueError("当 use_ndwi=True 时,必须提供 img_path 参数")
if not Path(img_path).exists():
raise ValueError(f"影像文件不存在: {img_path}")
print(f"使用NDWI方法从影像生成水体掩膜阈值={ndwi_threshold}...")
ndwi_output_path = output_path or str(water_mask_dir / "water_mask_from_ndwi.dat")
os.makedirs(Path(ndwi_output_path).parent, exist_ok=True)
if Path(ndwi_output_path).exists():
print(f"检测到已存在的NDWI掩膜文件直接使用: {ndwi_output_path}")
notify("skipped", f"水域掩膜已设置: {ndwi_output_path}")
return ndwi_output_path
ndwi(img_path, ndwi_threshold, ndwi_output_path)
if generate_png:
overlay_path = water_mask_dir / "water_mask_overlay.png"
generate_water_mask_overlay(
img_path=img_path, mask_path=ndwi_output_path, output_path=str(overlay_path)
)
notify("completed", f"NDWI水体掩膜已生成: {ndwi_output_path}")
return ndwi_output_path
# ---- 必须提供 mask_path ----
if mask_path is None:
raise ValueError("必须提供 mask_path 参数或设置 use_ndwi=True")
if not Path(mask_path).exists():
raise ValueError(f"文件不存在: {mask_path}")
file_ext = Path(mask_path).suffix.lower()
# ---- SHP 栅格化 ----
if file_ext == ".shp":
if img_path is None:
raise ValueError("当 mask_path 为 shp 格式时,必须提供 img_path 参数")
print(f"检测到shp格式的水体掩膜正在转换为dat格式...")
shp_output_path = output_path or str(water_mask_dir / "water_mask_from_shp.dat")
os.makedirs(Path(shp_output_path).parent, exist_ok=True)
if Path(shp_output_path).exists():
print(f"检测到已存在的栅格化掩膜文件,直接使用: {shp_output_path}")
notify("skipped", f"水域掩膜已设置: {shp_output_path}")
if generate_png:
overlay_path = water_mask_dir / "water_mask_overlay.png"
if not overlay_path.exists():
generate_water_mask_overlay(img_path, shp_output_path, str(overlay_path))
return shp_output_path
safe_mask_path = os.path.abspath(mask_path).replace("\\", "/")
rasterize_shp(safe_mask_path, shp_output_path, img_path)
if generate_png:
overlay_path = water_mask_dir / "water_mask_overlay.png"
generate_water_mask_overlay(img_path, shp_output_path, str(overlay_path))
notify("completed", f"dat格式水域掩膜已生成: {shp_output_path}")
return shp_output_path
# ---- 栅格格式直接使用 ----
print(f"检测到栅格格式的水体掩膜,直接使用: {mask_path}")
if generate_png and img_path is not None and Path(img_path).exists():
overlay_path = water_mask_dir / "water_mask_overlay.png"
generate_water_mask_overlay(img_path, mask_path, str(overlay_path))
notify("completed", f"水域掩膜已设置: {mask_path}")
return mask_path

File diff suppressed because it is too large Load Diff

View File

@ -209,7 +209,7 @@ class Step3Panel(QWidget):
"输出影像:",
"Image Files (*.bsq *.dat *.tif);;All Files (*.*)"
)
self.output_file.line_edit.setPlaceholderText("deglint_image.dat")
self.output_file.line_edit.setPlaceholderText("deglint_image.bsq")
layout.addWidget(self.output_file)
# 启用步骤
@ -301,7 +301,7 @@ class Step3Panel(QWidget):
if self.work_dir:
output_dir = os.path.join(self.work_dir, "3_deglint")
os.makedirs(output_dir, exist_ok=True)
default_output_path = os.path.join(output_dir, "deglint_image.dat").replace('\\', '/')
default_output_path = os.path.join(output_dir, "deglint_image.bsq").replace('\\', '/')
self.output_file.set_path(default_output_path)
else:
self.output_file.set_path("")

View File

@ -17,6 +17,57 @@ from src.gui.components.custom_widgets import FileSelectWidget
from src.gui.styles import ModernStylesheet
# ============================================================
# 中文映射表(内部键名 -> 显示文本)
# ============================================================
# 预处理方法:内部键 -> 显示文本
PREPROC_CHINESE = {
'None': '无 (None)',
'MMS': '最小-最大归一化 (MMS)',
'SS': '标度化 (SS)',
'SNV': '标准正态变换 (SNV)',
'MA': '移动平均 (MA)',
'SG': 'Savitzky-Golay (SG)',
'MSC': '多元散射校正 (MSC)',
'D1': '一阶导数 (D1)',
'D2': '二阶导数 (D2)',
'DT': '去趋势 (DT)',
'CT': '中心化 (CT)',
}
# 模型类型:内部键 -> 显示文本
MODEL_CHINESE = {
# 线性模型
'LinearRegression': '多元线性回归 (MLR)',
'Ridge': '岭回归 (Ridge)',
'Lasso': '套索回归 (Lasso)',
'ElasticNet': '弹性网络 (ElasticNet)',
'PLS': '偏最小二乘 (PLSR)',
# 树模型
'DecisionTree': '决策树 (CART)',
'RF': '随机森林 (RF)',
'ExtraTrees': '极端随机树 (ET)',
'XGBoost': '极值梯度提升 (XGBoost)',
'LightGBM': '轻量梯度提升 (LightGBM)',
'CatBoost': '类别梯度提升 (CatBoost)',
# 集成学习
'GradientBoosting': '梯度提升树 (GBDT)',
'AdaBoost': '自适应提升 (AdaBoost)',
# 其他模型
'SVR': '支持向量回归 (SVR)',
'KNN': 'K近邻回归 (KNN)',
'MLP': '多层感知机 (BP神经网络)',
}
# 数据划分方法:内部键 -> 显示文本
SPLIT_CHINESE = {
'spxy': 'SPXY 算法 (考量X-Y空间)',
'ks': 'KS 算法 (考量X空间)',
'random': '随机划分 (Random)',
}
class Step6Panel(QWidget):
"""步骤6机器学习建模"""
def __init__(self, parent=None):
@ -54,7 +105,7 @@ class Step6Panel(QWidget):
# 启用步骤
self.enable_checkbox = QCheckBox("启用此步骤")
self.enable_checkbox.setChecked(True)
self.enable_checkbox.setChecked(False)
layout.addWidget(self.enable_checkbox)
# 独立运行按钮
@ -95,8 +146,8 @@ class Step6Panel(QWidget):
preproc_methods = ['None', 'MMS', 'SS', 'SNV', 'MA', 'SG', 'MSC', 'D1', 'D2', 'DT', 'CT']
for i, method in enumerate(preproc_methods):
checkbox = QCheckBox(method)
checkbox.setChecked(True)
checkbox = QCheckBox(PREPROC_CHINESE.get(method, method))
checkbox.setChecked(False)
self.preproc_checkboxes[method] = checkbox
preproc_grid.addWidget(checkbox, i // 4, i % 4)
@ -122,10 +173,10 @@ class Step6Panel(QWidget):
self.model_checkboxes = {}
model_groups = [
("线性模型", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']),
("树模型", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']),
("集成学习", ['GradientBoosting', 'AdaBoost']),
("其他模型", ['SVR', 'KNN', 'MLP'])
("线性模型", ['LinearRegression', 'Ridge', 'Lasso', 'ElasticNet', 'PLS']),
("树模型", ['DecisionTree', 'RF', 'ExtraTrees', 'XGBoost', 'LightGBM', 'CatBoost']),
("集成学习", ['GradientBoosting', 'AdaBoost']),
("其他模型", ['SVR', 'KNN', 'MLP'])
]
row = 0
@ -140,8 +191,8 @@ class Step6Panel(QWidget):
row += 1
for i, model in enumerate(models):
checkbox = QCheckBox(model)
checkbox.setChecked(model in ['SVR', 'RF', 'Ridge', 'Lasso'])
checkbox = QCheckBox(MODEL_CHINESE.get(model, model))
checkbox.setChecked(False)
self.model_checkboxes[model] = checkbox
model_grid.addWidget(checkbox, row, i % 4)
if (i + 1) % 4 == 0:
@ -172,8 +223,8 @@ class Step6Panel(QWidget):
split_methods = ['spxy', 'ks', 'random']
for i, method in enumerate(split_methods):
checkbox = QCheckBox(method)
checkbox.setChecked(True)
checkbox = QCheckBox(SPLIT_CHINESE.get(method, method))
checkbox.setChecked(False)
self.split_checkboxes[method] = checkbox
split_grid.addWidget(checkbox, 0, i)

View File

@ -109,7 +109,7 @@ class Step9Panel(QWidget):
mode_row.addStretch()
layout.addLayout(mode_row)
# ---------- RadioButton 美化样式(选中状态更醒目 ----------
# ---------- RadioButton 美化样式(选中状态为方形实心块,贴合主界面风格 ----------
radio_style = """
QRadioButton {
font-size: 14px;
@ -117,21 +117,16 @@ class Step9Panel(QWidget):
color: #333333;
}
QRadioButton::indicator {
width: 18px;
height: 18px;
width: 16px;
height: 16px;
border: 2px solid #999999;
border-radius: 9px;
border-radius: 3px;
background-color: white;
}
QRadioButton::indicator:checked {
border: 2px solid #0078d4;
background-color: qradialgradient(
cx:0.5, cy:0.5, radius:0.5,
fx:0.5, fy:0.5,
stop:0 #0078d4,
stop:0.6 white,
stop:1.0 white
);
background-color: #0078d4;
image: none;
}
QRadioButton::indicator:hover {
border: 2px solid #005a9e;
@ -353,7 +348,7 @@ class Step9Panel(QWidget):
if not main_window:
return
# 1. 尝试从 Step8 界面读取机器学习预测输出目录(优先)
# 1. 尝试从 Step8 界面读取机器学习预测输出目录(优先)
pred_dir = None
if hasattr(main_window, 'step8_panel'):
step8_widget = getattr(main_window.step8_panel, 'output_file', None)
@ -367,7 +362,10 @@ class Step9Panel(QWidget):
# 若为相对路径,使用 work_dir 合成为绝对路径
if not os.path.isabs(step8_output):
step8_output = os.path.join(self.work_dir or '', step8_output).replace('\\', '/')
pred_dir = str(Path(step8_output).parent)
# 提取父目录后追加 Machine_Learning_Prediction最底层真实子目录
base_pred_dir = str(Path(step8_output).parent)
ml_pred_dir = Path(base_pred_dir) / "Machine_Learning_Prediction"
pred_dir = str(ml_pred_dir) if ml_pred_dir.exists() else base_pred_dir
# 2. 备选:从 Step8.5 界面读取非经验预测输出目录
if not pred_dir and hasattr(main_window, 'step8_5_panel'):
@ -411,6 +409,14 @@ class Step9Panel(QWidget):
existing_out = self.output_dir.get_path()
if not existing_out or not existing_out.strip():
self.output_dir.set_path(output_dir)
# 5. 自动继承步骤1的水域掩膜作为边界文件
if self.work_dir:
default_mask = Path(self.work_dir) / "1_water_mask" / "water_mask_from_shp.dat"
if default_mask.exists():
existing_boundary = (self.boundary_file.get_path() or "").strip()
if not existing_boundary:
self.boundary_file.set_path(str(default_mask))
except Exception as e:
import traceback
print(f"{self.__class__.__name__}】自动填充失败,跳过: {e}")

View File

@ -1825,6 +1825,11 @@ class WaterQualityGUI(QMainWindow):
for step_id, step_display in steps:
item = QListWidgetItem(f" └─ {step_display}")
item.setData(Qt.UserRole, step_id)
# 隐藏4个冗余回归步骤树节点
if step_id in ("step6_5", "step6_75", "step8_5", "step8_75"):
item.setHidden(True)
self.step_name_map[step_display] = step_id
# 设置步骤项的样式
@ -1905,9 +1910,11 @@ class WaterQualityGUI(QMainWindow):
self.step6_5_panel = Step6_5Panel()
self.step_stack.addTab(self.create_scroll_area(self.step6_5_panel), QIcon(self.get_icon_path("6.png")), "回归建模")
self.step_stack.tabBar().setTabVisible(7, False) # 隐藏回归建模 Tab
self.step6_75_panel = Step6_75Panel()
self.step_stack.addTab(self.create_scroll_area(self.step6_75_panel), QIcon(self.get_icon_path("6.png")), "自定义回归建模")
self.step_stack.tabBar().setTabVisible(8, False) # 隐藏自定义回归建模 Tab
self.step7_panel = Step7Panel()
self.step_stack.addTab(self.create_scroll_area(self.step7_panel), QIcon(self.get_icon_path("7.png")), "采样点布设")
@ -1917,9 +1924,11 @@ class WaterQualityGUI(QMainWindow):
self.step8_5_panel = Step8_5Panel()
self.step_stack.addTab(self.create_scroll_area(self.step8_5_panel), QIcon(self.get_icon_path("8.png")), "回归预测")
self.step_stack.tabBar().setTabVisible(11, False) # 隐藏回归预测 Tab
self.step8_75_panel = Step8_75Panel()
self.step_stack.addTab(self.create_scroll_area(self.step8_75_panel), QIcon(self.get_icon_path("8.png")), "自定义回归预测")
self.step_stack.tabBar().setTabVisible(12, False) # 隐藏自定义回归预测 Tab
self.step9_panel = Step9Panel()
self.step_stack.addTab(self.create_scroll_area(self.step9_panel), QIcon(self.get_icon_path("10.png")), "专题图生成")

View File

@ -1003,55 +1003,66 @@ class ReportGenerator:
Returns:
保存的文件路径
"""
from modeling_batch import WaterQualityModelingBatch
from src.core.modeling.modeling_batch import WaterQualityModelingBatch
import joblib
modeler = WaterQualityModelingBatch(models_dir)
# 需要先加载训练结果
# 这里假设results已经存储在modeler中或者需要从保存的文件中读取
# 由于modeling_batch.py的结构我们需要另一种方式来获取所有结果
# 尝试遍历模型目录,查找所有保存的结果
models_path = Path(models_dir)
all_results = []
# 遍历所有目标参数文件夹
for target_folder in models_path.iterdir():
if not target_folder.is_dir():
continue
# 递归扫描 *.joblib 和 *.pkl兼容 artifacts_dir/target_name/ 的所有子目录层级
model_files = list(models_path.rglob("*.joblib")) + list(models_path.rglob("*.pkl"))
target_name = target_folder.name
for model_file in model_files:
# 目标参数取直系父目录名(符合 artifacts_dir/target_name/ 结构)
target_name = model_file.parent.name
stem = model_file.stem
# 查找所有模型文件
for model_file in target_folder.rglob("*.pkl"):
# 从文件名提取信息(假设格式为:{preprocess}_{model}_{split}.pkl
model_info = {
# 文件名格式:{safe_target}_{preprocess}_{model_name}.joblib
# 使用 split('_', 2) 最多切 3 段,目标 1 段、预处理 1 段、模型 1 段
parts = stem.split('_', 2)
preprocess = parts[1] if len(parts) > 1 else 'Unknown'
model_name_str = parts[2] if len(parts) > 2 else stem
# 尝试从 joblib/pkl 读取元数据,提取性能指标
metrics = {}
try:
data = joblib.load(model_file)
metadata = data.get('metadata', {})
metrics = {
'train_r2': metadata.get('train_r2', 'N/A'),
'test_r2': metadata.get('test_r2', 'N/A'),
'test_rmse': metadata.get('test_rmse', 'N/A'),
'train_rmse': metadata.get('train_rmse', 'N/A'),
'train_mae': metadata.get('train_mae', 'N/A'),
'test_mae': metadata.get('test_mae', 'N/A'),
'cv_mean': metadata.get('cv_mean', 'N/A'),
}
except Exception:
pass # 加载失败时 metrics 保持为空字典,摘要中该列为 N/A
all_results.append({
'target': target_name,
'model_file': str(model_file),
'preprocess': 'Unknown',
'model': 'Unknown',
'split_method': 'Unknown'
}
'preprocess': preprocess,
'model': model_name_str,
**metrics,
})
# 尝试从文件名解析
parts = model_file.stem.split('_')
if len(parts) >= 3:
model_info['preprocess'] = parts[0]
model_info['model'] = parts[1]
model_info['split_method'] = parts[2]
all_results.append(model_info)
# 如果有训练结果数据,使用实际指标
# 否则创建一个基本的摘要
summary_data = []
for result in all_results:
summary_data.append({
'目标参数': result['target'],
'预处理方法': result['preprocess'],
'模型名称': result['model'],
'划分方法': result['split_method'],
'模型文件': result['model_file']
'模型文件': result['model_file'],
'训练集R²': result.get('train_r2', 'N/A'),
'测试集R²': result.get('test_r2', 'N/A'),
'测试集RMSE': result.get('test_rmse', 'N/A'),
'训练集RMSE': result.get('train_rmse', 'N/A'),
'训练集MAE': result.get('train_mae', 'N/A'),
'测试集MAE': result.get('test_mae', 'N/A'),
'CV均值': result.get('cv_mean', 'N/A'),
})
if not summary_data:
@ -1060,8 +1071,14 @@ class ReportGenerator:
'目标参数': 'No Data',
'预处理方法': 'N/A',
'模型名称': 'N/A',
'划分方法': 'N/A',
'模型文件': 'N/A'
'模型文件': 'N/A',
'训练集R²': 'N/A',
'测试集R²': 'N/A',
'测试集RMSE': 'N/A',
'训练集RMSE': 'N/A',
'训练集MAE': 'N/A',
'测试集MAE': 'N/A',
'CV均值': 'N/A',
}]
df_summary = pd.DataFrame(summary_data)

View File

@ -96,8 +96,14 @@ class BandMathCalculator:
print(f"计算表达式: {calc_expression}")
# 安全地计算表达式
result = eval(calc_expression)
# 【新增安全防护】引入 numpy 命名空间,让 eval 引擎安全识别 nan 与 inf
import numpy as np
try:
# 即使 calc_expression 含有纯字符 nan也能被 np.nan 安全接管
result = eval(calc_expression, {"__builtins__": None}, {"nan": np.nan, "inf": np.inf, "np": np})
except Exception as e:
print(f"⚠️ 警告:公式计算异常 ({e}),该点赋值为 nan")
result = np.nan
# 返回结果
if var_name: