fix: 修复工作目录与步骤名不对应、回归预测虚数报错、模型加载及预处理名称转换问题,重构可视化并修正勾选联动

This commit is contained in:
2026-04-14 17:41:38 +08:00
parent b0a94ba1e7
commit 9b7bcfadd1
17 changed files with 12470 additions and 3113 deletions

View File

@ -0,0 +1,654 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
自定义回归预测模块
该模块根据9_Custom_Regression_Modeling文件夹中的CSV信息批量预测水质指数。
处理流程:
1. 读取9_Custom_Regression_Modeling文件夹中的CSV文件
2. 根据r_squared选择最佳模型(指数公式+反演公式)
3. 使用指数公式计算光谱指数值
4. 使用反演公式计算水质参数值
5. 输出包含投影经纬度和预测值的CSV文件
作者: Assistant
创建时间: 2026-04-14
"""
import numpy as np
import pandas as pd
import os
import re
from pathlib import Path
from typing import List, Dict, Union, Tuple, Optional
import warnings
import logging
from datetime import datetime
warnings.filterwarnings('ignore')
# 导入水质指数计算器
try:
from src.utils.water_index import WaterQualityIndexCalculator
except ImportError:
from ..utils.water_index import WaterQualityIndexCalculator
class CustomRegressionPredictor:
    """Predict water-quality parameters from per-parameter regression CSV files.

    Every CSV file in ``regression_models_dir`` holds the candidate regression
    models of one parameter. For each parameter the row with the highest
    ``r_squared`` is selected; its index formula is evaluated on the spectral
    data and the resulting index is passed through its inversion formula.
    """

    # Accepted spellings of the R-squared column, in priority order.
    _R2_COLUMNS = ('r_squared', 'R_squared', 'R2', 'r2', 'R²')
    # Accepted spellings of the name / expression columns of the formula CSV.
    _NAME_COLUMNS = ('Formula_Name', 'formula_name', 'Name', 'name')
    _FORMULA_COLUMNS = ('Formula', 'formula', 'Expression', 'expression')
    # A band reference such as R681, w681, R_681 or R681.5. The look-behind
    # prevents matching in the middle of an identifier or a number.
    _BAND_TOKEN = re.compile(r'(?<![A-Za-z0-9])[Rw]_?(\d+(?:\.\d+)?)')
    # Name fragments that identify a coordinate column during auto-detection.
    _COORD_TOKENS = frozenset({
        'x', 'y', 'lon', 'lat', 'lng', 'longitude', 'latitude',
        'easting', 'northing',
    })

    def __init__(self,
                 regression_models_dir: str = "9_Custom_Regression_Modeling",
                 formula_csv_path: Optional[str] = None,
                 output_dir: str = "prediction_results",
                 log_level: int = logging.INFO):
        """Create the predictor and make sure the output directory exists.

        Args:
            regression_models_dir: Directory holding the regression CSV files.
            formula_csv_path: Optional CSV mapping index names to formulas.
            output_dir: Directory the prediction CSV files are written to.
            log_level: Level applied to the logger and its console handler.
        """
        self.regression_models_dir = Path(regression_models_dir)
        self.formula_csv_path = formula_csv_path
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self._setup_logging(log_level)
        self.index_calculator = WaterQualityIndexCalculator()

        # parameter name -> all candidate models / the selected best model
        self.regression_models = {}
        self.best_models = {}

        self.formula_df = None
        if self.formula_csv_path:
            if Path(self.formula_csv_path).exists():
                self.formula_df = pd.read_csv(self.formula_csv_path)
                self.logger.info(f"加载公式CSV: {self.formula_csv_path}, 共 {len(self.formula_df)} 个公式")
            else:
                # Previously ignored silently, which hid typos in the path.
                self.logger.warning(f"公式CSV文件不存在, 将直接使用指数名称: {self.formula_csv_path}")

        self.logger.info("CustomRegressionPredictor初始化完成")
        self.logger.info(f"回归模型目录: {self.regression_models_dir}")
        self.logger.info(f"输出目录: {self.output_dir}")

    def _setup_logging(self, log_level: int):
        """Attach a console handler to the class logger (only once)."""
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.setLevel(log_level)
        # The logger is shared by name, so guard against duplicate handlers.
        if not self.logger.handlers:
            console_handler = logging.StreamHandler()
            console_handler.setLevel(log_level)
            console_handler.setFormatter(logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            ))
            self.logger.addHandler(console_handler)

    def load_regression_models(self) -> Dict[str, pd.DataFrame]:
        """Load every per-parameter regression CSV from the models directory.

        The summary file ``all_regression_results.csv`` is skipped. Each file
        is normalised to the columns ``index_formula``, ``inversion_formula``
        and ``r_squared``; files that cannot be normalised are skipped with a
        log entry.

        Returns:
            Mapping of parameter name to its normalised model table.

        Raises:
            FileNotFoundError: The directory is missing or holds no CSV file.
            ValueError: No file could be loaded.
        """
        if not self.regression_models_dir.exists():
            raise FileNotFoundError(f"回归模型目录不存在: {self.regression_models_dir}")

        csv_files = sorted(self.regression_models_dir.glob("*.csv"))
        if not csv_files:
            raise FileNotFoundError(f"在目录 {self.regression_models_dir} 中未找到CSV文件")
        self.logger.info(f"找到 {len(csv_files)} 个CSV文件")

        for csv_file in csv_files:
            if csv_file.name.lower() == "all_regression_results.csv":
                self.logger.info(f"跳过汇总文件: {csv_file.name}")
                continue

            # The parameter name is the file stem without the result suffix.
            param_name = csv_file.stem.replace('_regression_results', '').replace('_results', '')
            self.logger.info(f"加载回归模型文件: {csv_file.name} -> 参数: {param_name}")

            try:
                df = self._normalise_model_columns(pd.read_csv(csv_file), csv_file.name)
            except Exception as e:
                # Best effort: one unreadable file must not stop the others.
                self.logger.error(f"加载文件 {csv_file.name} 失败: {e}")
                continue
            if df is None:
                continue

            self.regression_models[param_name] = df
            self.logger.info(f"成功加载参数 {param_name}: {len(df)} 个回归模型")

        if not self.regression_models:
            raise ValueError("未成功加载任何回归模型")
        return self.regression_models

    def _normalise_model_columns(self, df: pd.DataFrame, file_name: str) -> Optional[pd.DataFrame]:
        """Map the raw CSV columns onto the names used by the predictor.

        A column is only renamed when its target name is not already present,
        so the result never contains duplicated column labels.

        Returns:
            The normalised table, or ``None`` when a required column is absent.
        """
        renames = {}

        if 'r_squared' not in df.columns:
            r2_col = next((c for c in self._R2_COLUMNS if c in df.columns), None)
            if r2_col is not None:
                renames[r2_col] = 'r_squared'

        if 'inversion_formula' not in df.columns:
            if 'equation' not in df.columns:
                self.logger.warning(f"文件 {file_name} 缺少反演公式列(equation)")
                return None
            renames['equation'] = 'inversion_formula'

        if 'x_variable' in df.columns:
            if self.formula_df is not None:
                # x_variable is an index name; translate it to its expression.
                df['index_formula'] = df['x_variable'].apply(self._lookup_formula_from_name)
            elif 'index_formula' not in df.columns:
                # No formula table: the name itself is used as the formula.
                renames['x_variable'] = 'index_formula'
        elif 'index_formula' not in df.columns:
            self.logger.warning(f"文件 {file_name} 缺少指数名称列(x_variable)")
            return None

        if renames:
            df = df.rename(columns=renames)

        required_columns = ['index_formula', 'inversion_formula', 'r_squared']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            self.logger.warning(f"文件 {file_name} 缺少必需列: {missing_columns}")
            return None
        return df

    def _lookup_formula_from_name(self, formula_name: str) -> str:
        """Return the expression registered for ``formula_name``.

        Falls back to ``formula_name`` itself when no formula table is loaded,
        the name is unknown, or the stored expression is empty.
        """
        if self.formula_df is None:
            return formula_name

        columns = self.formula_df.columns
        name_col = next((c for c in self._NAME_COLUMNS if c in columns), None)
        if name_col is None:
            self.logger.warning("公式CSV中未找到公式名称列使用默认列名")
            name_col = columns[0] if len(columns) > 0 else None

        formula_col = next((c for c in self._FORMULA_COLUMNS if c in columns), None)
        if formula_col is None and len(columns) > 2:
            formula_col = columns[2]  # by convention the third column

        if name_col is not None and formula_col is not None:
            match = self.formula_df[self.formula_df[name_col] == formula_name]
            if not match.empty:
                formula = match[formula_col].iloc[0]
                # An empty cell is read as NaN; that is not a usable formula.
                if not pd.isna(formula):
                    return formula
        return formula_name

    def select_best_models(self) -> Dict[str, pd.Series]:
        """Pick, per parameter, the model row with the highest ``r_squared``.

        ``r_squared`` is coerced to numbers first; parameters without a single
        valid value are skipped instead of aborting the whole selection.

        Returns:
            Mapping of parameter name to the selected model row.
        """
        if not self.regression_models:
            self.load_regression_models()

        for param_name, models_df in self.regression_models.items():
            scores = pd.to_numeric(models_df['r_squared'], errors='coerce')
            if not scores.notna().any():
                self.logger.warning(f"参数 {param_name} 没有有效的r_squared值, 已跳过")
                continue

            # Positional lookup stays correct even with a non-unique index.
            best_pos = int(np.nanargmax(scores.to_numpy(dtype=float)))
            best_model = models_df.iloc[best_pos].copy()
            best_model['r_squared'] = float(scores.iloc[best_pos])
            self.best_models[param_name] = best_model

            self.logger.info(f"参数 {param_name} 最佳模型:")
            self.logger.info(f" 指数公式: {best_model['index_formula']}")
            self.logger.info(f" 反演公式: {best_model['inversion_formula']}")
            self.logger.info(f" R²: {best_model['r_squared']:.4f}")
        return self.best_models

    def calculate_spectral_index(self,
                                 spectral_data: pd.DataFrame,
                                 index_formula: str) -> pd.Series:
        """Evaluate an index formula on the spectral table.

        If ``index_formula`` names an attribute of the index calculator, that
        attribute is called with ``spectral_data``. Otherwise the text is
        treated as an expression whose band references (``R681``, ``w681``,
        ``R_681``, ``R681.5``) are bound to the closest wavelength column.

        Returns:
            The index values, or an all-NaN series when evaluation fails.
        """
        try:
            if isinstance(index_formula, str) and hasattr(self.index_calculator, index_formula):
                calculator_func = getattr(self.index_calculator, index_formula)
                index_values = calculator_func(spectral_data)
                self.logger.debug(f"使用标准指数函数 {index_formula} 计算完成")
                return index_values

            formula = index_formula.strip()
            columns = spectral_data.columns.tolist()
            aliases = {}   # wavelength token -> variable name in the expression
            bands = {}     # variable name -> column data

            def _bind_band(match):
                token = match.group(1)
                if token not in aliases:
                    wavelength = float(token)
                    target = int(wavelength) if wavelength.is_integer() else wavelength
                    closest_col = self.index_calculator.find_closest_wavelength(columns, target)
                    if closest_col not in spectral_data.columns:
                        closest_col = str(closest_col)
                    alias = f'_band_{len(aliases)}'
                    aliases[token] = alias
                    bands[alias] = spectral_data[closest_col]
                return aliases[token]

            # One regex pass: substituted text is never scanned again, so
            # column names that look like band tokens cannot be corrupted.
            expression = self._BAND_TOKEN.sub(_bind_band, formula)

            namespace = {
                'np': np, 'pd': pd, 'spectral_data': spectral_data,
                'abs': abs, 'min': min, 'max': max, 'pow': pow, 'round': round,
                'ln': np.log, 'log': np.log, 'log10': np.log10,
                'exp': np.exp, 'sqrt': np.sqrt,
            }
            namespace.update(bands)
            # SECURITY: the formula comes from a CSV file and is evaluated as
            # Python. Builtins are removed, but this is not a sandbox; only
            # use formula files from trusted sources.
            index_values = eval(expression, {"__builtins__": {}}, namespace)
            self.logger.debug(f"使用自定义公式 {index_formula} 计算完成")
            return pd.Series(index_values, index=spectral_data.index)
        except Exception as e:
            self.logger.error(f"计算光谱指数失败,公式: {index_formula}, 错误: {e}")
            return pd.Series(np.nan, index=spectral_data.index)

    def apply_inversion_formula(self,
                                index_values: pd.Series,
                                inversion_formula: str) -> pd.Series:
        """Convert index values to parameter values with an inversion formula.

        The formula is an expression in ``x``, optionally prefixed by ``y =``
        (any spacing). ``^`` means exponentiation; ``ln``, ``log``, ``log10``,
        ``exp``, ``sqrt``, ``abs`` and ``pow`` are available. A complex result
        is reduced to its real part; a row that cannot be evaluated gives NaN.

        Returns:
            Predicted values aligned with ``index_values``; all NaN when the
            formula itself is invalid.
        """
        try:
            import math

            formula = re.sub(r'^\s*y\s*=\s*', '', inversion_formula.strip(), flags=re.IGNORECASE)
            formula = formula.replace('^', '**')
            self.logger.debug(f"处理后的反演公式: {formula}")

            # Parse once instead of once per row; a syntax error is reported
            # a single time by the outer handler.
            code = compile(formula, '<inversion_formula>', 'eval')
            math_env = {
                'ln': math.log,
                'log': math.log,
                'log10': math.log10,
                'exp': math.exp,
                'sqrt': math.sqrt,
                'abs': abs,
                'pow': pow,
            }

            predicted_values = []
            for x_val in index_values:
                if pd.isna(x_val):
                    predicted_values.append(np.nan)
                    continue
                try:
                    # SECURITY: CSV-supplied text is evaluated; builtins are
                    # removed but trusted input is still required.
                    result = eval(code, {"__builtins__": {}}, dict(math_env, x=x_val))
                    if isinstance(result, complex):
                        result = result.real
                    predicted_values.append(result)
                except Exception as eval_err:
                    self.logger.debug(f"评估公式失败x={x_val}: {eval_err}")
                    predicted_values.append(np.nan)

            self.logger.debug(f"使用反演公式 {inversion_formula} 计算完成")
            return pd.Series(predicted_values, index=index_values.index)
        except Exception as e:
            self.logger.error(f"应用反演公式失败,公式: {inversion_formula}, 错误: {e}")
            return pd.Series(np.nan, index=index_values.index)

    def predict_single_parameter(self,
                                 spectral_data: pd.DataFrame,
                                 coordinate_data: pd.DataFrame,
                                 param_name: str) -> pd.DataFrame:
        """Predict one parameter with its selected best model.

        Returns:
            A copy of ``coordinate_data`` with ``<param>_predicted`` and
            ``<param>_index`` columns; model details are stored in ``attrs``.

        Raises:
            ValueError: No best model is registered for ``param_name``.
        """
        if param_name not in self.best_models:
            raise ValueError(f"未找到参数 {param_name} 的最佳模型")

        best_model = self.best_models[param_name]
        index_formula = best_model['index_formula']
        inversion_formula = best_model['inversion_formula']
        self.logger.info(f"开始预测参数: {param_name}")

        self.logger.info("步骤1: 计算光谱指数")
        index_values = self.calculate_spectral_index(spectral_data, index_formula)

        self.logger.info("步骤2: 应用反演公式")
        predicted_values = self.apply_inversion_formula(index_values, inversion_formula)

        result_df = coordinate_data.copy()
        result_df[f'{param_name}_predicted'] = predicted_values
        result_df[f'{param_name}_index'] = index_values
        result_df.attrs = {
            'parameter': param_name,
            'index_formula': index_formula,
            'inversion_formula': inversion_formula,
            'r_squared': best_model['r_squared'],
            'prediction_time': datetime.now().isoformat()
        }

        valid_predictions = result_df[f'{param_name}_predicted'].notna().sum()
        total_points = len(result_df)
        self.logger.info(f"参数 {param_name} 预测完成: {valid_predictions}/{total_points} 个有效预测")
        return result_df

    def predict_all_parameters(self,
                               spectral_data: pd.DataFrame,
                               coordinate_data: pd.DataFrame) -> Dict[str, pd.DataFrame]:
        """Predict every parameter that has a best model.

        A parameter whose prediction raises is logged and left out.

        Returns:
            Mapping of parameter name to its prediction table.
        """
        if not self.best_models:
            self.select_best_models()

        prediction_results = {}
        for param_name in self.best_models:
            try:
                self.logger.info(f"正在预测参数: {param_name}")
                prediction_results[param_name] = self.predict_single_parameter(
                    spectral_data, coordinate_data, param_name
                )
            except Exception as e:
                self.logger.error(f"预测参数 {param_name} 失败: {e}")
        self.logger.info(f"批量预测完成,成功预测 {len(prediction_results)} 个参数")
        return prediction_results

    def _convert_complex_to_real(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` in which complex values keep only the real part."""
        df_clean = df.copy()
        for col in df_clean.columns:
            column = df_clean[col]
            if pd.api.types.is_complex_dtype(column.dtype):
                df_clean[col] = np.real(column.to_numpy())
            elif column.dtype == object:
                # Mixed columns: touch only the cells that are complex.
                df_clean[col] = column.map(
                    lambda v: v.real if isinstance(v, complex) else v
                )
        return df_clean

    def save_prediction_results(self,
                                prediction_results: Dict[str, pd.DataFrame],
                                filename_prefix: str = "custom_regression_prediction") -> Dict[str, str]:
        """Write one ``<prefix>_<param>.csv`` per parameter into ``output_dir``.

        A parameter whose file cannot be written is logged and skipped.

        Returns:
            Mapping of parameter name to the path of the written file.
        """
        saved_files = {}
        for param_name, result_df in prediction_results.items():
            try:
                filepath = self.output_dir / f"{filename_prefix}_{param_name}.csv"
                result_df_clean = self._convert_complex_to_real(result_df)
                result_df_clean.to_csv(filepath, index=False, encoding='utf-8-sig')
                saved_files[param_name] = str(filepath)
                self.logger.info(f"参数 {param_name} 预测结果已保存: {filepath}")

                predicted_col = f'{param_name}_predicted'
                if predicted_col in result_df_clean.columns:
                    # Statistics use the cleaned, numeric view of the column.
                    predicted = pd.to_numeric(result_df_clean[predicted_col], errors='coerce')
                    self.logger.info(f" 有效预测: {predicted.notna().sum()}")
                    self.logger.info(f" 平均值: {predicted.mean():.4f}")
                    self.logger.info(f" 标准差: {predicted.std():.4f}")
            except Exception as e:
                self.logger.error(f"保存参数 {param_name} 结果失败: {e}")
        return saved_files

    @classmethod
    def _detect_coordinate_columns(cls, columns) -> List[str]:
        """Return the columns whose name contains a coordinate word.

        Names are split on camel-case boundaries and non-letters, and the
        pieces are compared as whole words, so ``UTM_X`` matches while
        ``Turbidity`` or ``oxygen`` do not.
        """
        detected = []
        for col in columns:
            spaced = re.sub(r'(?<=[a-z])(?=[A-Z])', '_', str(col)).lower()
            words = [w for w in re.split(r'[^a-z]+', spaced) if w]
            if any(w in cls._COORD_TOKENS or w.startswith(('lon', 'lat')) for w in words):
                detected.append(col)
        return detected

    def run_batch_prediction(self,
                             input_csv_path: str,
                             coordinate_columns: Optional[List[str]] = None,
                             filename_prefix: str = "custom_regression_prediction") -> Dict[str, str]:
        """Run the whole workflow: load, select, predict and save.

        Args:
            input_csv_path: CSV with the sampled spectra and coordinates.
            coordinate_columns: Coordinate column names; defaults to
                ``['longitude', 'latitude']``. If any is missing, the first two
                auto-detected coordinate columns are used instead.
            filename_prefix: Prefix of the output file names.

        Returns:
            Mapping of parameter name to the path of the written file.

        Raises:
            ValueError: The coordinate columns cannot be determined.
        """
        if coordinate_columns is None:
            coordinate_columns = ['longitude', 'latitude']

        self.logger.info("="*50)
        self.logger.info("开始自定义回归批量预测")
        self.logger.info("="*50)

        self.logger.info("步骤1: 加载回归模型")
        self.load_regression_models()

        self.logger.info("步骤2: 选择最佳模型")
        self.select_best_models()

        self.logger.info("步骤3: 读取输入数据")
        input_df = pd.read_csv(input_csv_path)
        self.logger.info(f"读取输入数据: {len(input_df)} 行, {len(input_df.columns)}")

        self.logger.info("步骤4: 分离坐标和光谱数据")
        missing_coords = [col for col in coordinate_columns if col not in input_df.columns]
        if missing_coords:
            candidates = self._detect_coordinate_columns(input_df.columns)
            if len(candidates) < 2:
                raise ValueError(f"未找到坐标列: {missing_coords}")
            coordinate_columns = candidates[:2]
            self.logger.info(f"自动识别坐标列: {coordinate_columns}")

        coordinate_data = input_df[coordinate_columns].copy()
        spectral_data = input_df.drop(columns=coordinate_columns)
        self.logger.info(f"坐标数据: {len(coordinate_data)} 行, {len(coordinate_data.columns)}")
        self.logger.info(f"光谱数据: {len(spectral_data)} 行, {len(spectral_data.columns)}")

        self.logger.info("步骤5: 执行批量预测")
        prediction_results = self.predict_all_parameters(spectral_data, coordinate_data)

        self.logger.info("步骤6: 保存预测结果")
        saved_files = self.save_prediction_results(prediction_results, filename_prefix)

        self.logger.info("="*50)
        self.logger.info(f"批量预测完成!共生成 {len(saved_files)} 个结果文件")
        self.logger.info("="*50)
        return saved_files
def main():
    """Command-line entry point: parse the arguments and run a batch prediction.

    ``--formula_csv`` is optional; when given, index names found in the
    regression CSV files are translated to formulas through that file.
    The written result files are printed to standard output.
    """
    import argparse

    parser = argparse.ArgumentParser(description='自定义回归预测模块')
    parser.add_argument('--input_csv', required=True, help='输入的光谱采样CSV文件路径')
    parser.add_argument('--models_dir', default='9_Custom_Regression_Modeling',
                        help='回归模型CSV文件目录')
    parser.add_argument('--output_dir', default='prediction_results',
                        help='预测结果输出目录')
    parser.add_argument('--coord_cols', nargs='+', default=['longitude', 'latitude'],
                        help='坐标列名')
    parser.add_argument('--prefix', default='custom_regression_prediction',
                        help='输出文件名前缀')
    # Without this option the CLI could never resolve index names to formulas.
    parser.add_argument('--formula_csv', default=None,
                        help='公式CSV文件路径(用于查找index_formula)')
    args = parser.parse_args()

    predictor = CustomRegressionPredictor(
        regression_models_dir=args.models_dir,
        formula_csv_path=args.formula_csv,
        output_dir=args.output_dir
    )
    saved_files = predictor.run_batch_prediction(
        input_csv_path=args.input_csv,
        coordinate_columns=args.coord_cols,
        filename_prefix=args.prefix
    )

    print("\n预测结果文件:")
    for param_name, filepath in saved_files.items():
        print(f" {param_name}: {filepath}")


if __name__ == "__main__":
    main()

View File

@ -609,6 +609,11 @@ class WaterQualityScatterBatch:
split_method = best_row['划分方法']
preprocess_method = best_row['预处理方法']
model_name = best_row['建模方法']
# 处理 nan/NaN/None 值,转换为 "None" 字符串
if pd.isna(preprocess_method) or str(preprocess_method).lower() in ['nan', 'none', '']:
preprocess_method = "None"
best_combination = f"{split_method}_{preprocess_method}_{model_name}"
else:
# 简化结果文件格式(英文列名)
@ -617,11 +622,15 @@ class WaterQualityScatterBatch:
parts = best_combination.split('_')
if len(parts) < 3:
raise ValueError(f"无效的模型组合名称格式: {best_combination}")
split_method = parts[0]
preprocess_method = parts[1]
model_name = '_'.join(parts[2:])
# 处理 nan/NaN/None 值,转换为 "None" 字符串
if pd.isna(preprocess_method) or str(preprocess_method).lower() in ['nan', 'none', '']:
preprocess_method = "None"
print(f"最佳模型组合: {best_combination}")
print(f" 划分方法: {split_method}")
print(f" 预处理方法: {preprocess_method}")

File diff suppressed because it is too large Load Diff

View File

@ -58,6 +58,8 @@ from src.core.glint_removal.Hedley import Hedley
from src.core.glint_removal.SUGAR import SUGAR, correction_iterative
from src.utils.water_index import WaterQualityIndexCalculator
from src.core.modeling.regression import SingleVariableRegressionAnalysis
# 导入新的自定义回归预测模块
from src.core.prediction.custom_regression_prediction import CustomRegressionPredictor
# 导入hdr文件处理函数
try:
from src.utils.util import write_fields_to_hdrfile, get_hdr_file_path, find_band_number
@ -106,9 +108,9 @@ class WaterQualityInversionPipeline:
self.processed_data_dir = self.work_dir / "4_processed_data"
self.training_spectra_dir = self.work_dir / "5_training_spectra"
self.indices_dir = self.work_dir / "6_water_quality_indices"
self.models_dir = self.work_dir / "7_models"
self.non_empirical_models_dir = self.work_dir / "8_non_empirical_models"
self.custom_regression_dir = self.work_dir / "9_custom_regression"
self.models_dir = self.work_dir / "7_Supervised_Model_Training"
self.non_empirical_models_dir = self.work_dir / "8_Regression_Modeling"
self.custom_regression_dir = self.work_dir / "9_Custom_Regression_Modeling"
self.sampling_dir = self.work_dir / "10_sampling"
self.prediction_dir = self.work_dir / "11_12_13_predictions"
self.visualization_dir = self.work_dir / "14_visualization"
@ -3379,10 +3381,10 @@ class WaterQualityInversionPipeline:
else:
# 如果output_dir为空使用工作目录
if hasattr(self, 'work_dir') and self.work_dir is not None:
non_empirical_dir = Path(self.work_dir) / "8_non_empirical_models"
non_empirical_dir = Path(self.work_dir) / "8_Regression_Modeling"
else:
# 如果没有工作目录,使用当前目录
non_empirical_dir = Path.cwd() / "8_non_empirical_models"
non_empirical_dir = Path.cwd() / "8_Regression_Modeling"
non_empirical_dir.mkdir(parents=True, exist_ok=True)
# 设置默认参数
@ -3592,9 +3594,17 @@ class WaterQualityInversionPipeline:
# 应用预处理 - 使用spectral_Preprocessing模块
from src.preprocessing.spectral_Preprocessing import Preprocessing
# 调用预处理函数
processed_spectral = Preprocessing(preprocess_method, spectral_data)
# 为SS预处理提供scaler保存路径保存在工作目录的7_Supervised_Model_Training中
save_path = None
if preprocess_method == 'SS':
models_dir = output_dir.parent.parent / "7_Supervised_Model_Training" # 向上两级到工作目录
models_dir.mkdir(parents=True, exist_ok=True)
save_path = str(models_dir / "scaler_params.pkl")
print(f"SS预处理: scaler模型将保存到 {save_path}")
# 调用预处理函数为SS方法传递save_path
processed_spectral = Preprocessing(preprocess_method, spectral_data, save_path=save_path)
# 重新组合数据
if isinstance(processed_spectral, pd.DataFrame):
@ -3712,7 +3722,7 @@ class WaterQualityInversionPipeline:
if non_empirical_models_dir is not None:
final_models_dir = non_empirical_models_dir
else:
default_models_dir = str(self.work_dir / "8_non_empirical_models")
default_models_dir = str(self.work_dir / "8_Regression_Modeling")
if Path(default_models_dir).exists():
final_models_dir = default_models_dir
else:
@ -3812,28 +3822,31 @@ class WaterQualityInversionPipeline:
raise
def step8_75_predict_with_custom_regression(self, sampling_csv_path: str,
formula_csv_file: str,
custom_regression_dir: Optional[str] = None,
formula_names: Optional[List[str]] = None,
prediction_column: str = 'prediction',
formula_csv_path: Optional[str] = None,
coordinate_columns: Optional[List[str]] = None,
output_dir: Optional[str] = None,
filename_prefix: str = "custom_regression_prediction",
enabled: bool = True,
skip_dependency_check: bool = False) -> Dict[str, str]:
"""
步骤8.75: 使用自定义回归模型进行参数预测
使用步骤6.75的自定义回归结果中的all_regression_results.csv文件
选择每个y_variable中r_squared最高的equation
使用采样点光谱数据计算水质指数,然后进行预测
使用新的CustomRegressionPredictor模块基于9_Custom_Regression_Modeling文件夹中的CSV
根据r_squared选择最佳模型批量预测水质参数
Args:
sampling_csv_path: 采样点光谱数据CSV路径来自步骤7
formula_csv_file: 公式CSV文件路径包含水质指数计算公式
custom_regression_dir: 自定义回归结果目录如果为None使用步骤6.75的结果)
formula_names: 要计算的公式名称列表如果为None则计算所有公式
prediction_column: 预测结果列名
custom_regression_dir: 自定义回归模型目录9_Custom_Regression_Modeling
formula_csv_path: 公式CSV文件路径用于查找index_formula
coordinate_columns: 坐标列名列表,默认为['longitude', 'latitude']或自动识别
output_dir: 输出目录默认为prediction_dir
filename_prefix: 输出文件名前缀
enabled: 是否启用
skip_dependency_check: 是否跳过依赖检查
Returns:
预测结果文件路径字典(键为y_variable名)
预测结果文件路径字典(键为参数名)
"""
print("\n" + "="*80)
print("步骤8.75: 使用自定义回归模型进行参数预测")
@ -3848,135 +3861,54 @@ class WaterQualityInversionPipeline:
self._record_step_time("步骤8.75: 自定义回归模型预测", step_start_time, step_end_time, status="skipped")
return {}
# 检查公式CSV文件是否存在
if not Path(formula_csv_file).exists():
raise FileNotFoundError(f"公式CSV文件不存在: {formula_csv_file}")
# 检查采样点CSV文件是否存在
if not Path(sampling_csv_path).exists():
raise FileNotFoundError(f"采样点CSV文件不存在: {sampling_csv_path}")
# 确定自定义回归结果目录
# 确定自定义回归模型目录
if custom_regression_dir is not None:
final_regression_dir = custom_regression_dir
else:
default_regression_dir = str(self.custom_regression_dir)
if Path(default_regression_dir).exists():
final_regression_dir = default_regression_dir
else:
final_regression_dir = str(self.custom_regression_dir)
if not Path(final_regression_dir).exists():
if skip_dependency_check:
raise ValueError("必须提供custom_regression_dir参数才能独立运行步骤8.75")
else:
raise ValueError("请先执行步骤6.75: 自定义回归分析或提供custom_regression_dir参数")
# 读取all_regression_results.csv文件
regression_results_path = Path(final_regression_dir) / "all_regression_results.csv"
if not regression_results_path.exists():
raise FileNotFoundError(f"未找到自定义回归结果文件: {regression_results_path}")
# 读取回归结果
regression_df = pd.read_csv(regression_results_path)
# 确定输出目录
if output_dir is None:
prediction_output_dir = str(self.prediction_dir)
else:
prediction_output_dir = output_dir
# 使用band_math.py计算水质指数
print("正在使用采样点光谱数据计算水质指数...")
from src.utils.band_math import BandMathCalculator
# 创建计算器实例
calculator = BandMathCalculator(sampling_csv_path)
# 计算所有公式
indices_df = calculator.process_formulas_from_csv(
formula_csv_file,
formula_names=formula_names,
output_file=str(self.prediction_dir / "water_quality_indices.csv")
# 创建CustomRegressionPredictor实例
predictor = CustomRegressionPredictor(
regression_models_dir=final_regression_dir,
formula_csv_path=formula_csv_path,
output_dir=prediction_output_dir
)
if indices_df is None:
raise ValueError("水质指数计算失败")
# 运行批量预测
print(f"开始使用自定义回归模块进行批量预测...")
print(f" 采样点数据: {sampling_csv_path}")
print(f" 回归模型目录: {final_regression_dir}")
print(f" 输出目录: {prediction_output_dir}")
# 读取采样点数据(包含坐标信息)
sampling_df = pd.read_csv(sampling_csv_path)
# 获取所有唯一的y_variable
y_variables = regression_df['y_variable'].unique()
prediction_files = {}
for y_var in y_variables:
try:
# 筛选当前y_variable的所有回归结果
y_var_results = regression_df[regression_df['y_variable'] == y_var]
# 找到r_squared最高的记录
best_result = y_var_results.loc[y_var_results['r_squared'].idxmax()]
# 解析equation
equation = best_result['equation']
x_variable = best_result['x_variable']
print(f"{y_var} 选择最佳方程: {equation} (R² = {best_result['r_squared']:.4f})")
# 检查x_variable是否在水质指数数据中存在
if x_variable not in indices_df.columns:
print(f"警告: x_variable '{x_variable}' 不在水质指数数据中,跳过 {y_var}")
continue
# 合并采样点坐标和水质指数数据
# 假设采样点数据和水质指数数据有相同的行数和顺序
if len(sampling_df) != len(indices_df):
print(f"警告: 采样点数据({len(sampling_df)}行)和水质指数数据({len(indices_df)}行)行数不一致")
# 只取前min(len(sampling_df), len(indices_df))行
min_rows = min(len(sampling_df), len(indices_df))
merged_df = pd.concat([
sampling_df.iloc[:min_rows].reset_index(drop=True),
indices_df.iloc[:min_rows].reset_index(drop=True)
], axis=1)
else:
merged_df = pd.concat([sampling_df, indices_df], axis=1)
# 应用回归方程进行预测
# 使用eval函数安全地计算表达式
try:
# 创建局部命名空间,包含需要的变量和数学函数
local_vars = {x_variable: merged_df[x_variable].values}
# 添加数学函数到命名空间
import math
local_vars.update({
'exp': math.exp,
'log': math.log,
'log10': math.log10,
'sqrt': math.sqrt,
'sin': math.sin,
'cos': math.cos,
'tan': math.tan,
'pi': math.pi,
'e': math.e
})
# 替换方程中的变量名为实际值
# 这里需要确保方程格式正确,例如: "a*x + b"
prediction_values = eval(equation, {"__builtins__": {}}, local_vars)
# 创建预测结果DataFrame
result_df = merged_df[['UTM_X', 'UTM_Y']].copy()
result_df[prediction_column] = prediction_values
# 保存预测结果
output_filename = f"custom_regression_{y_var}.csv"
output_path = str(self.prediction_dir / output_filename)
result_df.to_csv(output_path, index=False)
prediction_files[y_var] = output_path
print(f"成功为 {y_var} 生成预测结果: {output_path}")
except Exception as e:
print(f"应用方程 {equation} 进行预测时出错: {e}")
continue
except Exception as e:
print(f"处理 y_variable {y_var} 时出错: {e}")
continue
saved_files = predictor.run_batch_prediction(
input_csv_path=sampling_csv_path,
coordinate_columns=coordinate_columns,
filename_prefix=filename_prefix
)
step_end_time = time.time()
self._record_step_time("步骤8.75: 自定义回归模型预测", step_start_time, step_end_time)
return prediction_files
print(f"自定义回归预测完成,生成 {len(saved_files)} 个预测文件:")
for param_name, filepath in saved_files.items():
print(f" {param_name}: {filepath}")
return saved_files
except Exception as e:
step_end_time = time.time()
@ -4113,6 +4045,14 @@ def main():
'prediction_column': 'prediction',
'enabled': True # 是否启用非经验模型预测
},
'step8_75': {
'custom_regression_dir': None, # 自定义回归模型目录None表示使用9_Custom_Regression_Modeling
'formula_csv_path': None, # 公式CSV文件路径用于查找index_formula如water_quality_formulas.csv
'coordinate_columns': None, # 坐标列名None表示自动识别
'output_dir': None, # 输出目录None表示使用prediction_dir
'filename_prefix': 'custom_regression_prediction', # 输出文件名前缀
'enabled': True # 是否启用自定义回归预测
},
'step9': {
'boundary_shp_path': r"D:\BaiduNetdiskDownload\yaobao\roi\roi.shp" ,
'resolution': 30,

View File

@ -660,7 +660,7 @@ class VisualizationWorkerThread(QThread):
self.failed.emit("训练光谱 CSV 无效或不存在请确认已选择步骤5输出的文件。")
return
if not models_dir or not Path(models_dir).is_dir():
self.failed.emit("模型目录无效或不存在请确认步骤6已生成 7_models 下的参数子文件夹。")
self.failed.emit("模型目录无效或不存在请确认步骤6已生成 7_Supervised_Model_Training 下的参数子文件夹。")
return
pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
scatter_paths = pipeline.generate_model_scatter_plots(
@ -672,12 +672,111 @@ class VisualizationWorkerThread(QThread):
from src.postprocessing.visualization_reports import WaterQualityVisualization
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
parts = []
# 获取训练数据CSV路径多个图表类型共用
training_csv = wp / "5_training_spectra" / "training_spectra.csv"
# 生成散点图
if self.extra.get("gen_scatter"):
if training_csv.is_file():
models_dir = wp / "7_Supervised_Model_Training"
if models_dir.is_dir() and any(d.is_dir() for d in models_dir.iterdir()):
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
scatter_paths = pipeline.generate_model_scatter_plots(
training_csv_path=str(training_csv),
models_dir=str(models_dir),
)
count = len(scatter_paths) if scatter_paths else 0
parts.append(f"散点图: {count}")
else:
parts.append("散点图: 跳过(无模型目录)")
else:
parts.append("散点图: 跳过(无训练数据)")
# 生成光谱图
if self.extra.get("gen_spectrum"):
if training_csv.is_file():
import pandas as pd
df = pd.read_csv(training_csv)
# 推断水质参数列(光谱波段列之前的数值型列)
wl_col = _viz_infer_wavelength_start_column(df)
if isinstance(wl_col, str):
idx = int(df.columns.get_loc(wl_col)) + 1
else:
idx = int(wl_col)
param_cols = []
if idx > 0 and idx < len(df.columns):
param_cols = [
c for c in df.columns[:idx]
if df[c].dtype.kind in 'iuf' and df[c].notna().sum() > 0
]
if param_cols:
# plot_spectrum_by_parameter 接受单个参数列,逐个调用
spectrum_paths = []
for param_col in param_cols:
try:
path = viz.plot_spectrum_by_parameter(
csv_path=str(training_csv),
parameter_column=param_col,
wavelength_start_column=wl_col,
n_groups=5,
)
if path:
spectrum_paths.append(path)
except Exception as e:
print(f"生成光谱图失败 ({param_col}): {e}")
count = len(spectrum_paths)
parts.append(f"光谱图: {count}")
else:
parts.append("光谱图: 跳过(无可用参数列)")
else:
parts.append("光谱图: 跳过(无训练数据)")
# 生成统计图
if self.extra.get("gen_boxplots"):
if training_csv.is_file():
import pandas as pd
df = pd.read_csv(training_csv)
# **只统计水质参数列(数值型),排除波长列**
# 获取水质参数列(数值型且不是波长、不是坐标列)
exclude_cols = ['longitude', 'latitude', 'lon', 'lat', 'x', 'y', 'coord', 'coordinate']
param_cols = [
c for c in df.select_dtypes(include=[np.number]).columns
if not any(exc in c.lower() for exc in exclude_cols)
]
# 排除光谱波长列:找到波长开始位置,只取之前的数值列
wl = _viz_infer_wavelength_start_column(df)
if isinstance(wl, str):
idx = int(df.columns.get_loc(wl)) + 1
else:
idx = int(wl)
if 0 < idx < len(df.columns):
meta_set = set(df.columns[:idx])
param_cols = [c for c in param_cols if c in meta_set]
if param_cols:
output_dict = viz.plot_statistical_charts(
csv_path=str(training_csv),
parameter_columns=param_cols,
)
# plot_statistical_charts 返回字典,统计值非空
count = len([v for v in output_dict.values() if v]) if output_dict else 0
parts.append(f"统计图: {count}")
else:
parts.append("统计图: 跳过(无可用水质参数列)")
else:
parts.append("统计图: 跳过(无训练数据)")
# 生成掩膜/耀斑预览图
if self.extra.get("gen_mask_glint"):
preview_paths = viz.generate_glint_deglint_previews(
work_dir=str(wp),
output_subdir="glint_deglint_previews",
)
parts.append(f"掩膜/耀斑预览 {len(preview_paths) if preview_paths else 0}")
parts.append(f"掩膜/耀斑预览: {len(preview_paths) if preview_paths else 0}")
# 生成采样点地图
if self.extra.get("gen_sampling_map"):
hyperspectral_files = []
deglint_dir = wp / "3_deglint"
@ -3103,35 +3202,37 @@ class Step8_75Panel(QWidget):
)
layout.addWidget(self.sampling_csv_file)
# 公式CSV文件选择
# 自定义回归模型目录选择9_Custom_Regression_Modeling
self.regression_models_dir = FileSelectWidget(
"回归模型目录:",
"Directories;;All Files (*.*)"
)
self.regression_models_dir.label.setText("回归模型目录:")
# 修改浏览按钮为选择目录
self.regression_models_dir.browse_btn.clicked.disconnect()
self.regression_models_dir.browse_btn.clicked.connect(self.browse_regression_models_dir)
self.regression_models_dir.set_path("9_Custom_Regression_Modeling") # 设置默认值
layout.addWidget(self.regression_models_dir)
# 公式CSV文件选择用于查找index_formula
self.formula_csv_file = FileSelectWidget(
"公式CSV文件:",
"CSV Files (*.csv);;All Files (*.*)"
)
self.formula_csv_file.label.setText("公式CSV文件:")
layout.addWidget(self.formula_csv_file)
# 模型目录选择
self.models_dir_file = FileSelectWidget(
"模型目录:",
# 输出目录选择
self.output_dir_widget = FileSelectWidget(
"输出目录:",
"Directories;;All Files (*.*)"
)
self.models_dir_file.label.setText("模型目录:")
self.output_dir_widget.label.setText("输出目录:")
# 修改浏览按钮为选择目录
self.models_dir_file.browse_btn.clicked.disconnect()
self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
layout.addWidget(self.models_dir_file)
# 参数设置
params_group = QGroupBox("预测参数")
params_layout = QFormLayout()
# 预测列名
self.prediction_column = QLineEdit()
self.prediction_column.setText("prediction")
params_layout.addRow("预测列名:", self.prediction_column)
params_group.setLayout(params_layout)
layout.addWidget(params_group)
self.output_dir_widget.browse_btn.clicked.disconnect()
self.output_dir_widget.browse_btn.clicked.connect(self.browse_output_dir)
self.output_dir_widget.line_edit.setPlaceholderText("留空使用默认prediction目录")
layout.addWidget(self.output_dir_widget)
# 启用步骤
self.enable_checkbox = QCheckBox("启用此步骤")
@ -3162,45 +3263,59 @@ class Step8_75Panel(QWidget):
layout.addStretch()
self.setLayout(layout)
# NOTE(review): this span looks like a flattened diff hunk in which the older
# `browse_models_dir` and the newer `browse_regression_models_dir` share a single
# body (indentation is also lost) - confirm which lines are live before relying on it.
def browse_models_dir(self):
"""Browse for the model directory."""
dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", "")
def browse_regression_models_dir(self):
"""Browse for the regression-model directory and store the chosen path."""
dir_path = QFileDialog.getExistingDirectory(self, "选择回归模型目录", "")
# Only update the path widgets when the dialog returned a non-empty path.
if dir_path:
self.models_dir_file.set_path(dir_path)
self.regression_models_dir.set_path(dir_path)
def browse_output_dir(self):
    """Let the user pick the output directory and store it in the path widget.

    The dialog starts at the directory currently shown in the widget (when
    one is set) so that re-browsing continues from the previous choice
    instead of always starting from the default location.
    """
    start_dir = self.output_dir_widget.get_path() or ""
    dir_path = QFileDialog.getExistingDirectory(self, "选择输出目录", start_dir)
    # Keep the current value when the dialog returned an empty path.
    if dir_path:
        self.output_dir_widget.set_path(dir_path)
# NOTE(review): this span appears to interleave pre- and post-change lines of a
# diff (e.g. both `formula_csv_file` and `formula_csv_path` are written, and
# `custom_regression_dir` is assigned from two different widgets) and has lost
# its indentation - confirm the live version before editing the logic.
def get_config(self):
"""Collect the panel settings into a plain dict.

Returns:
    dict: always contains the keys set in the initial literal below; each
    path entry is added only when the corresponding widget returns a
    non-empty value.
"""
config = {
'prediction_column': self.prediction_column.text(),
'enabled': self.enable_checkbox.isChecked()
}
# Sampling-spectra CSV path
sampling_csv_path = self.sampling_csv_file.get_path()
if sampling_csv_path:
config['sampling_csv_path'] = sampling_csv_path
# Regression-model directory
regression_models_dir = self.regression_models_dir.get_path()
if regression_models_dir:
config['custom_regression_dir'] = regression_models_dir
# Formula CSV file path
formula_csv_path = self.formula_csv_file.get_path()
if formula_csv_path:
config['formula_csv_file'] = formula_csv_path
# Model directory (stored under the same 'custom_regression_dir' key as above)
models_dir = self.models_dir_file.get_path()
if models_dir:
config['custom_regression_dir'] = models_dir
config['formula_csv_path'] = formula_csv_path
# Output directory
output_dir = self.output_dir_widget.get_path()
if output_dir:
config['output_dir'] = output_dir
return config
# NOTE(review): this span appears to interleave pre- and post-change lines of a
# diff (both `formula_csv_file` and `formula_csv_path` are read, and
# `custom_regression_dir` is pushed to two widgets) and has lost its
# indentation - confirm the live version before editing the logic.
def set_config(self, config):
"""Apply a settings dict to the panel widgets.

Each key is optional; a widget is only touched when its key is present
in ``config``.

Args:
    config: mapping using the same keys that ``get_config`` writes.
"""
if 'prediction_column' in config:
self.prediction_column.setText(config['prediction_column'])
if 'sampling_csv_path' in config:
self.sampling_csv_file.set_path(config['sampling_csv_path'])
# Both spellings of the formula-CSV key feed the same widget.
if 'formula_csv_file' in config:
self.formula_csv_file.set_path(config['formula_csv_file'])
if 'custom_regression_dir' in config:
self.models_dir_file.set_path(config['custom_regression_dir'])
self.regression_models_dir.set_path(config['custom_regression_dir'])
if 'formula_csv_path' in config:
self.formula_csv_file.set_path(config['formula_csv_path'])
if 'output_dir' in config:
self.output_dir_widget.set_path(config['output_dir'])
if 'enabled' in config:
self.enable_checkbox.setChecked(config['enabled'])
@ -3213,9 +3328,9 @@ class Step8_75Panel(QWidget):
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件")
return
formula_csv_path = self.formula_csv_file.get_path()
if not formula_csv_path:
QMessageBox.warning(self, "输入错误", "请选择公式CSV文件")
regression_models_dir = self.regression_models_dir.get_path()
if not regression_models_dir:
QMessageBox.warning(self, "输入错误", "请选择回归模型目录")
return
# 获取配置
@ -3323,8 +3438,7 @@ class ImageCategoryTree(QTreeWidget):
("光谱分析", ["spectrum", "spectral", "band", "wavelength"], "📈"),
("统计图表", ["boxplot", "histogram", "heatmap", "statistics", "stats"], "📉"),
("处理结果", ["mask", "glint", "deglint", "preview", "overlay", "water_mask"], "🖼️"),
("采样分析", ["sampling", "flight_path", "point_map", "trajectory"], "📍"),
("其他图表", [], "📁"),
("含量分布图", [], "📁"),
]
def __init__(self, parent=None):
@ -3376,7 +3490,7 @@ class ImageCategoryTree(QTreeWidget):
# 根据文件名关键词确定类别
category = self._determine_category(file_path.name)
category_item = self.category_items.get(category, self.category_items["其他图表"])
category_item = self.category_items.get(category, self.category_items["含量分布图"])
# 创建图像项
image_item = QTreeWidgetItem(category_item)
@ -3394,7 +3508,7 @@ class ImageCategoryTree(QTreeWidget):
if any(keyword in filename_lower for keyword in keywords):
return category_name
return "其他图表"
return "含量分布图"
def scan_directory(self, work_dir: str):
"""扫描目录中的所有图像文件"""
@ -3682,11 +3796,7 @@ class VisualizationPanel(QWidget):
def _viz_set_busy(self, busy: bool):
for w in (
getattr(self, "gen_all_btn", None),
getattr(self, "gen_scatter_btn", None),
getattr(self, "gen_spectrum_btn", None),
getattr(self, "gen_stats_btn", None),
getattr(self, "gen_mask_glint_btn", None),
getattr(self, "gen_sampling_map_btn", None),
getattr(self, "scan_btn", None),
):
if w is not None:
w.setEnabled(not busy)
@ -3728,7 +3838,13 @@ class VisualizationPanel(QWidget):
return [c for c in meta if pd.api.types.is_numeric_dtype(df[c])]
def _statistics_param_columns(self, df: pd.DataFrame) -> List[str]:
"""统计图用的参数列;若存在光谱波段,则只统计波段前的字段。"""
"""统计图用的参数列**只统计水质参数列(数值型),排除波长列**。
- 包括:数值型的水质参数(浓度、含量等)
- 排除:光谱波长列(虽然也是数值型,但不是水质参数)
- 排除坐标列UTM_X, UTM_Y, lat, lon等
若存在光谱波段,则只统计波段前的数值字段。
"""
# 选择数值类型列
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
wl = _viz_infer_wavelength_start_column(df)
if isinstance(wl, str):
@ -3737,6 +3853,7 @@ class VisualizationPanel(QWidget):
idx = int(wl)
coord_kw = ("utm", "lat", "lon")
if 0 < idx < len(df.columns):
# 只取波长开始之前的列(水质参数区域)
meta_set = set(df.columns[:idx])
return [
col
@ -3744,6 +3861,7 @@ class VisualizationPanel(QWidget):
if col in meta_set and not any(x in str(col).lower() for x in coord_kw)
]
return [
# 如果没有找到波长列,排除坐标相关列
col
for col in numeric_cols
if not any(x in str(col).lower() for x in coord_kw + ("x", "y"))
@ -3825,7 +3943,7 @@ class VisualizationPanel(QWidget):
QMessageBox.warning(
self,
"提示",
"未生成任何散点图。请确认 7_models 下已有各参数子目录及模型文件,"
"未生成任何散点图。请确认 7_Supervised_Model_Training 下已有各参数子目录及模型文件,"
"且训练 CSV 与建模时一致。",
)
elif t == "generate_all_selected":
@ -3870,36 +3988,22 @@ class VisualizationPanel(QWidget):
self.image_tree = ImageCategoryTree()
self.image_tree.itemClicked.connect(self.on_tree_item_clicked)
tree_layout.addWidget(self.image_tree)
# 生成按钮组
gen_btn_layout = QHBoxLayout()
self.gen_all_btn = QPushButton("🚀 生成全部")
self.gen_all_btn.setToolTip("生成所有类型的可视化图表")
self.gen_all_btn.setStyleSheet("background-color: #4CAF50; color: white; font-weight: bold;")
self.gen_all_btn.clicked.connect(self.generate_all_visualizations)
gen_btn_layout.addWidget(self.gen_all_btn)
self.scan_btn = QPushButton("📁 扫描")
self.scan_btn.setToolTip("扫描工作目录中的图像文件")
self.scan_btn.clicked.connect(self.scan_work_directory)
gen_btn_layout.addWidget(self.scan_btn)
tree_layout.addLayout(gen_btn_layout)
tree_group.setLayout(tree_layout)
left_layout.addWidget(tree_group, 1)
# 可视化配置
config_group = QGroupBox("可视化配置")
config_layout = QVBoxLayout()
self.gen_scatter = QCheckBox("模型评估散点图")
self.gen_scatter.setChecked(True)
config_layout.addWidget(self.gen_scatter)
self.gen_spectrum = QCheckBox("光谱曲线图")
self.gen_spectrum.setChecked(True)
config_layout.addWidget(self.gen_spectrum)
self.gen_boxplots = QCheckBox("统计图表")
self.gen_boxplots.setChecked(True)
config_layout.addWidget(self.gen_boxplots)
@ -3912,6 +4016,27 @@ class VisualizationPanel(QWidget):
self.gen_sampling_map.setChecked(True)
config_layout.addWidget(self.gen_sampling_map)
# 添加分隔线
config_layout.addSpacing(10)
line = QFrame()
line.setFrameShape(QFrame.HLine)
line.setStyleSheet("color: #ddd;")
config_layout.addWidget(line)
config_layout.addSpacing(10)
# 生成全部按钮
self.gen_all_btn = QPushButton("🚀 生成全部")
self.gen_all_btn.setToolTip("生成所有类型的可视化图表")
self.gen_all_btn.setStyleSheet("background-color: #4CAF50; color: white; font-weight: bold;")
self.gen_all_btn.clicked.connect(self.generate_all_visualizations)
config_layout.addWidget(self.gen_all_btn)
# 扫描按钮
self.scan_btn = QPushButton("📁 扫描目录")
self.scan_btn.setToolTip("扫描工作目录中的图像文件")
self.scan_btn.clicked.connect(self.scan_work_directory)
config_layout.addWidget(self.scan_btn)
config_group.setLayout(config_layout)
left_layout.addWidget(config_group)
@ -3928,43 +4053,7 @@ class VisualizationPanel(QWidget):
self.image_viewer = ImageViewerWidget()
self.image_viewer.refresh_btn.clicked.connect(self.scan_work_directory)
right_layout.addWidget(self.image_viewer, 1)
# 生成特定图表按钮组
specific_group = QGroupBox("生成特定图表")
specific_layout = QHBoxLayout()
self.gen_scatter_btn = QPushButton("📊 散点图")
self.gen_scatter_btn.setToolTip(
"基于工作目录下 5_training_spectra/training_spectra.csv 与 7_models 生成模型评估散点图"
)
self.gen_scatter_btn.clicked.connect(lambda: self.generate_chart('scatter'))
specific_layout.addWidget(self.gen_scatter_btn)
self.gen_spectrum_btn = QPushButton("📈 光谱图")
self.gen_spectrum_btn.setToolTip(
"基于 5_training_spectra/training_spectra.csv为每个数值型水质参数各生成一张光谱对比图无需选择"
)
self.gen_spectrum_btn.clicked.connect(lambda: self.generate_chart('spectrum'))
specific_layout.addWidget(self.gen_spectrum_btn)
self.gen_stats_btn = QPushButton("📉 统计图")
self.gen_stats_btn.setToolTip(
"基于工作目录下 5_training_spectra/training_spectra.csv 生成箱线图、直方图与相关性热力图"
)
self.gen_stats_btn.clicked.connect(lambda: self.generate_chart('statistics'))
specific_layout.addWidget(self.gen_stats_btn)
self.gen_mask_glint_btn = QPushButton("🖼️ 掩膜图")
self.gen_mask_glint_btn.clicked.connect(lambda: self.generate_mask_glint_previews())
specific_layout.addWidget(self.gen_mask_glint_btn)
self.gen_sampling_map_btn = QPushButton("📍 采样点图")
self.gen_sampling_map_btn.clicked.connect(lambda: self.generate_sampling_point_map())
specific_layout.addWidget(self.gen_sampling_map_btn)
specific_group.setLayout(specific_layout)
right_layout.addWidget(specific_group)
right_panel.setLayout(right_layout)
main_layout.addWidget(right_panel, 1)
@ -3999,6 +4088,9 @@ class VisualizationPanel(QWidget):
print(f"扫描工作目录: {work_path}")
self.image_tree.scan_directory(str(work_path))
# 设置三个预测步骤的默认输出路径
self._setup_prediction_output_dirs(work_path)
# 如果有图像,自动选择第一个
viz_dir = work_path / "14_visualization"
if viz_dir.exists():
@ -4006,6 +4098,41 @@ class VisualizationPanel(QWidget):
if image_files:
self.image_viewer.load_image(str(image_files[0]))
def _setup_prediction_output_dirs(self, work_path: Path):
    """Create the per-step prediction folders and push them to the step panels.

    Three sub-folders are created under ``<work_path>/11_12_13_predictions``
    (machine-learning, regression and custom-regression prediction).  For
    each one, if ``self`` exposes the matching step panel and that panel has
    the expected path widget, the widget is pointed at the folder.

    Every folder is handled independently, so a failure on one of them
    (for example a permission error, or a widget that rejects the path) no
    longer prevents the remaining folders from being prepared.  Failures are
    reported on stdout and never raised.

    Args:
        work_path: Root of the current work directory.
    """
    base_prediction_dir = work_path / "11_12_13_predictions"

    # (summary label, sub-folder name, panel attribute on self,
    #  path-widget attribute on that panel)
    targets = (
        ("ML", "Machine_Learning_Prediction", "step8_panel", "output_file"),
        ("Reg", "Regression_Model_Prediction", "step8_5_panel", "output_file"),
        ("Custom", "Custom_Regression_Prediction", "step8_75_panel", "output_dir_widget"),
    )

    prepared = []
    for label, folder, panel_attr, widget_attr in targets:
        target_dir = base_prediction_dir / folder
        try:
            target_dir.mkdir(parents=True, exist_ok=True)
            # A missing panel or widget is not an error: the folder is still
            # created, there is simply nothing to update.
            widget = getattr(getattr(self, panel_attr, None), widget_attr, None)
            if widget is not None:
                widget.set_path(str(target_dir))
        except Exception as e:  # best-effort GUI helper: report, never raise
            print(f"设置预测输出目录失败: {e}")
            continue
        prepared.append(f" {label}: {target_dir}")

    if prepared:
        print("预测输出目录已设置:\n" + "\n".join(prepared))
def on_tree_item_clicked(self, item, column):
"""目录树项点击事件"""
data = item.data(0, Qt.UserRole)
@ -4022,37 +4149,37 @@ class VisualizationPanel(QWidget):
if not self.work_dir:
QMessageBox.warning(self, "警告", "请先选择工作目录!")
return
work_path = Path(self.work_dir)
if not work_path.exists():
QMessageBox.warning(self, "警告", "工作目录不存在!")
return
reply = QMessageBox.question(
self, "确认生成",
"将按左侧勾选项在后台生成可视化(掩膜/耀斑预览、采样点图等),可能需要较长时间。\n是否继续?",
QMessageBox.Yes | QMessageBox.No
)
if reply != QMessageBox.Yes:
return
if self.gen_scatter.isChecked():
print("生成散点图...(占位,请用建模/可视化流程生成)")
if self.gen_spectrum.isChecked():
print("生成光谱图...(占位,请用下方「光谱图」按钮)")
if self.gen_boxplots.isChecked():
print("生成统计图...(占位,请用下方「统计图」按钮)")
if not self.gen_mask_glint.isChecked() and not self.gen_sampling_map.isChecked():
# 检查是否有任何选项被勾选
if not (self.gen_scatter.isChecked() or self.gen_spectrum.isChecked() or
self.gen_boxplots.isChecked() or self.gen_mask_glint.isChecked() or
self.gen_sampling_map.isChecked()):
QMessageBox.information(
self,
"提示",
"请至少勾选「掩膜和耀斑缩略图」或「采样点地图」以执行后台批量任务",
"请至少勾选一项可视化配置选项以生成图表",
)
return
reply = QMessageBox.question(
self, "确认生成",
"将根据左侧勾选项在后台生成可视化图表,可能需要较长时间。\n是否继续?",
QMessageBox.Yes | QMessageBox.No
)
if reply != QMessageBox.Yes:
return
# 收集所有选中的任务
extra = {
"gen_scatter": self.gen_scatter.isChecked(),
"gen_spectrum": self.gen_spectrum.isChecked(),
"gen_boxplots": self.gen_boxplots.isChecked(),
"gen_mask_glint": self.gen_mask_glint.isChecked(),
"gen_sampling_map": self.gen_sampling_map.isChecked(),
}
@ -4082,7 +4209,7 @@ class VisualizationPanel(QWidget):
)
return
training_csv = training_spectra_csv
models_dir = work_path / "7_models"
models_dir = work_path / "7_Supervised_Model_Training"
if not models_dir.is_dir() or not any(
d.is_dir() for d in models_dir.iterdir()
):
@ -4284,7 +4411,6 @@ class VisualizationPanel(QWidget):
'generate_scatter': self.gen_scatter.isChecked(),
'generate_boxplots': self.gen_boxplots.isChecked(),
'generate_spectrum': self.gen_spectrum.isChecked(),
'generate_statistics': self.gen_stats_btn.isChecked(),
'generate_glint_previews': self.gen_mask_glint.isChecked(),
'generate_sampling_maps': self.gen_sampling_map.isChecked(),
'scatter_config': {
@ -4314,8 +4440,6 @@ class VisualizationPanel(QWidget):
self.gen_boxplots.setChecked(config['generate_boxplots'])
if 'generate_spectrum' in config:
self.gen_spectrum.setChecked(config['generate_spectrum'])
if 'generate_statistics' in config:
self.gen_stats_btn.setChecked(config['generate_statistics'])
if 'generate_glint_previews' in config:
self.gen_mask_glint.setChecked(config['generate_glint_previews'])
if 'generate_sampling_maps' in config:
@ -4755,7 +4879,7 @@ class Step6_5Panel(QWidget):
"输出模型目录:",
"Directories;;All Files (*.*)"
)
self.output_dir.line_edit.setPlaceholderText("8_non_empirical_models")
self.output_dir.line_edit.setPlaceholderText("8_Regression_Modeling")
# 修改浏览按钮为选择目录
self.output_dir.browse_btn.clicked.disconnect()
self.output_dir.browse_btn.clicked.connect(self.browse_output_dir)
@ -4825,9 +4949,9 @@ class Step6_5Panel(QWidget):
# 如果output_dir为空使用工作目录或当前目录
main_window = self.parent().window()
if hasattr(main_window, 'work_dir') and main_window.work_dir:
output_dir = str(Path(main_window.work_dir) / "8_non_empirical_models")
output_dir = str(Path(main_window.work_dir) / "8_Regression_Modeling")
else:
output_dir = str(Path.cwd() / "8_non_empirical_models")
output_dir = str(Path.cwd() / "8_Regression_Modeling")
config['output_dir'] = output_dir
# 添加训练数据路径(用于独立运行)
@ -4952,7 +5076,8 @@ class Step6_75Panel(QWidget):
# 创建滚动区域来容纳自变量选择
x_scroll = QScrollArea()
x_scroll.setWidgetResizable(True)
x_scroll.setMaximumHeight(200)
x_scroll.setMinimumHeight(250)
x_scroll.setMaximumHeight(350)
x_widget = QWidget()
self.x_columns_layout = QGridLayout()
@ -4982,7 +5107,8 @@ class Step6_75Panel(QWidget):
# 创建滚动区域来容纳因变量选择
y_scroll = QScrollArea()
y_scroll.setWidgetResizable(True)
y_scroll.setMaximumHeight(150)
y_scroll.setMinimumHeight(200)
y_scroll.setMaximumHeight(300)
y_widget = QWidget()
self.y_columns_layout = QGridLayout()
@ -5044,7 +5170,7 @@ class Step6_75Panel(QWidget):
output_layout = QFormLayout()
self.output_dir = QLineEdit()
self.output_dir.setText("9_custom_regression")
self.output_dir.setText("9_Custom_Regression_Modeling")
output_layout.addRow("输出目录名:", self.output_dir)
output_group.setLayout(output_layout)
@ -5202,7 +5328,7 @@ class Step6_75Panel(QWidget):
checkbox.setChecked(method in selected_methods)
if 'output_dir' in config:
self.output_dir.setText(config['output_dir'] or "9_custom_regression")
self.output_dir.setText(config['output_dir'] or "9_Custom_Regression_Modeling")
if 'enabled' in config:
self.enable_checkbox.setChecked(config['enabled'])
@ -5574,26 +5700,28 @@ class WaterQualityGUI(QMainWindow):
self.step_list = QListWidget()
self.step_list.setStyleSheet(ModernStylesheet.get_sidebar_stylesheet())
# 定义阶段结构
# 定义阶段结构
self.process_stages = {
"阶段一:数据预处理": [
"阶段一:影像预处理": [
("step1", "1. 水域掩膜生成"),
("step2", "2. 耀斑区域识别"),
("step3", "3. 耀斑去除与修复"),
("step4", "4. 数据标准化处理"),
],
"阶段二:特征提取与建模": [
"阶段二:样本数据准备 ": [
("step4", "4. 数据标准化处理"),
("step5", "5. 光谱特征提取"),
("step5_5", "6. 水质参数指数计算"),
("step6", "7. 监督学习模型训练"),
("step6_5", "8. 经验统计回归"),
("step6_75", "9. 自定义回归模型"),
],
"阶段三:应用与可视化": [
"阶段三:模型构建与训练": [
("step6", "7. 机器学习模型训练"),
("step6_5", "8. 回归模型训练"),
("step6_75", "9. 自定义回归模型训练"),
],
"阶段四:预测与成果输出 ": [
("step7", "10. 采样点布设"),
("step8", "11. 基于监督学习预测"),
("step8_5", "12. 基于统计回归预测"),
("step8_75", "13. 基于自定义回归预测"),
("step8", "11. 机器学习学习预测"),
("step8_5", "12. 回归预测"),
("step8_75", "13. 自定义回归预测"),
("step9", "14. 专题图生成"),
("step9_viz", "15. 可视化分析"),
("step_report", "16. 分析报告生成"),

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,15 @@
ENVI
description = {
work_dir\10_sampling\sampling_spectra_valid_area.bsq}
samples = 11363
lines = 10408
bands = 1
header offset = 0
file type = ENVI Standard
data type = 4
interleave = bsq
byte order = 0
map info = {UTM, 1, 1, 600742.055, 4613386.65, 0.2, 0.2, 51, North,WGS-84}
coordinate system string = {PROJCS["unnamed",GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Transverse_Mercator"],PARAMETER["False_Easting",500000.0],PARAMETER["False_Northing",0.0],PARAMETER["Central_Meridian",123.0],PARAMETER["Scale_Factor",0.9996],PARAMETER["Latitude_Of_Origin",0.0],UNIT["Meter",1.0]]}
band names = {
Band 1}

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 MiB

Binary file not shown.

View File

@ -0,0 +1,15 @@
ENVI
description = {
work_dir\1_water_mask\water_mask_from_shp.dat}
samples = 11363
lines = 10408
bands = 1
header offset = 0
file type = ENVI Standard
data type = 4
interleave = bsq
byte order = 0
map info = {UTM, 1, 1, 600742.055, 4613386.65, 0.2, 0.2, 51, North,WGS-84}
coordinate system string = {PROJCS["unnamed",GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Transverse_Mercator"],PARAMETER["False_Easting",500000.0],PARAMETER["False_Northing",0.0],PARAMETER["Central_Meridian",123.0],PARAMETER["Scale_Factor",0.9996],PARAMETER["Latitude_Of_Origin",0.0],UNIT["Meter",1.0]]}
band names = {
Band 1}

View File

@ -0,0 +1,15 @@
ENVI
description = {
work_dir\1_water_mask\water_mask_from_shp__tmp_delete.dat}
samples = 7411
lines = 5368
bands = 1
header offset = 0
file type = ENVI Standard
data type = 1
interleave = bsq
byte order = 0
map info = {UTM, 1, 1, 600947.955, 4612951.75, 0.2, 0.2, 51, North,WGS-84}
coordinate system string = {PROJCS["WGS_1984_UTM_Zone_51N",GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Transverse_Mercator"],PARAMETER["False_Easting",500000.0],PARAMETER["False_Northing",0.0],PARAMETER["Central_Meridian",123.0],PARAMETER["Scale_Factor",0.9996],PARAMETER["Latitude_Of_Origin",0.0],UNIT["Meter",1.0]]}
band names = {
Band 1}

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 MiB

View File

@ -281,11 +281,13 @@ class WaterQualityVisualization:
def plot_statistical_charts(self, csv_path: str, parameter_columns: List[str],
output_dir: Optional[str] = None) -> Dict[str, str]:
"""
绘制统计图表:箱线图、直方图、相关性热力图
绘制统计图表:**只针对水质参数列**(数值型,排除波长列)
- 水质参数列(如浓度、含量等数值型参数)使用箱线图/直方图/相关性热力图
- 排除光谱波长列(虽然也是数值型,但不是水质参数)
Args:
csv_path: CSV文件路径
parameter_columns: 参数列名列表
parameter_columns: **水质参数**列名列表(数值型,已排除波长列)
output_dir: 输出目录
Returns:
@ -301,12 +303,16 @@ class WaterQualityVisualization:
output_paths = {}
# 水质参数统计图表(针对数值型参数,排除波长列)
# 假设传入的 parameter_columns 已经是过滤后的水质参数列
numeric_cols = [col for col in parameter_columns if col in df.columns and pd.api.types.is_numeric_dtype(df[col])]
# 1. 箱线图
if len(parameter_columns) > 0:
if len(numeric_cols) > 0:
fig, ax = plt.subplots(figsize=(12, 6))
data_for_boxplot = [df[col].dropna() for col in parameter_columns if col in df.columns]
data_for_boxplot = [df[col].dropna() for col in numeric_cols]
if data_for_boxplot:
ax.boxplot(data_for_boxplot, labels=[col for col in parameter_columns if col in df.columns])
ax.boxplot(data_for_boxplot, labels=numeric_cols)
ax.set_ylabel('数值', fontsize=12, fontweight='bold')
ax.set_title('水质参数箱线图', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')
@ -318,51 +324,51 @@ class WaterQualityVisualization:
plt.close()
output_paths['boxplot'] = str(boxplot_path)
# 2. 直方图
for col in parameter_columns:
if col not in df.columns:
continue
# 2. 直方图 (每个水质参数列)
for col in numeric_cols:
fig, ax = plt.subplots(figsize=(10, 6))
data = df[col].dropna()
ax.hist(data, bins=30, edgecolor='black', alpha=0.7, color='skyblue')
ax.set_xlabel(f'{col} 数值', fontsize=12, fontweight='bold')
ax.set_ylabel('频数', fontsize=12, fontweight='bold')
ax.set_title(f'{col} 分布直方图', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')
# 添加统计信息
mean_val = data.mean()
std_val = data.std()
ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'均值: {mean_val:.4f}')
ax.legend()
plt.tight_layout()
safe_name = "".join(c for c in col if c.isalnum() or c in ('-', '_', '.'))
hist_path = output_dir / f"{safe_name}_histogram.png"
plt.savefig(hist_path, dpi=300, bbox_inches='tight')
plt.close()
output_paths[f'histogram_{col}'] = str(hist_path)
# 3. 相关性热力图
if len(parameter_columns) >= 2:
valid_cols = [col for col in parameter_columns if col in df.columns]
if len(valid_cols) >= 2:
corr_matrix = df[valid_cols].corr()
if len(data) > 1:
ax.hist(data, bins=30, edgecolor='black', alpha=0.7, color='skyblue')
ax.set_xlabel(f'{col} 数值', fontsize=12, fontweight='bold')
ax.set_ylabel('频数', fontsize=12, fontweight='bold')
ax.set_title(f'{col} 分布直方图', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')
# 添加统计信息
mean_val = data.mean()
std_val = data.std() if len(data) > 1 else 0
ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'均值: {mean_val:.4f}')
ax.legend()
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(corr_matrix, annot=True, fmt='.3f', cmap='coolwarm',
center=0, square=True, linewidths=1, cbar_kws={"shrink": 0.8},
ax=ax, vmin=-1, vmax=1)
ax.set_title('水质参数相关性热力图', fontsize=14, fontweight='bold')
plt.tight_layout()
heatmap_path = output_dir / "correlation_heatmap.png"
plt.savefig(heatmap_path, dpi=300, bbox_inches='tight')
safe_name = "".join(c for c in col if c.isalnum() or c in ('-', '_', '.'))
hist_path = output_dir / f"{safe_name}_histogram.png"
plt.savefig(hist_path, dpi=300, bbox_inches='tight')
plt.close()
output_paths['heatmap'] = str(heatmap_path)
output_paths[f'histogram_{col}'] = str(hist_path)
print(f"统计图表已保存到: {output_dir}")
# 3. 相关性热力图
if len(numeric_cols) >= 2:
corr_matrix = df[numeric_cols].corr()
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(corr_matrix, annot=True, fmt='.3f', cmap='coolwarm',
center=0, square=True, linewidths=1, cbar_kws={"shrink": 0.8},
ax=ax, vmin=-1, vmax=1)
ax.set_title('水质参数相关性热力图', fontsize=14, fontweight='bold')
plt.tight_layout()
heatmap_path = output_dir / "correlation_heatmap.png"
plt.savefig(heatmap_path, dpi=300, bbox_inches='tight')
plt.close()
output_paths['heatmap'] = str(heatmap_path)
if not output_paths:
print("警告: 没有生成任何统计图表(可能无合适的水质参数列)")
else:
print(f"统计图表已保存到: {output_dir},共 {len(output_paths)} 个文件")
return output_paths
def plot_distribution_map_enhanced(self, prediction_csv_path: str,

View File

@ -11,16 +11,24 @@ def MMS(input_spectrum):
output_spectrum = MinMaxScaler().fit_transform(input_spectrum)
return output_spectrum
# 标准化
# 标准化 (StandardScaler)
def SS(input_spectrum, save_path=None):
    """Standardize spectra with ``StandardScaler`` and optionally persist the scaler.

    Args:
        input_spectrum: Input spectral data (NumPy array or DataFrame) passed
            to ``StandardScaler.fit_transform``.
        save_path: Optional file path. When given, the fitted scaler is dumped
            there with joblib so it can be reloaded at prediction time; a
            missing parent directory is created first.

    Returns:
        The standardized data returned by ``fit_transform``.
    """
    # Fit the scaler on the data and transform it in one step.
    scaler = StandardScaler()
    output_spectrum = scaler.fit_transform(input_spectrum)

    if save_path:
        import os

        # dirname is "" for a bare file name, and os.makedirs("") raises
        # FileNotFoundError, so only create a directory when there is one.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        joblib.dump(scaler, save_path)
        print(f"SS Scaler parameters saved to: {save_path}")

    return output_spectrum
@ -124,7 +132,14 @@ def wave(input_spectrum):
return output_spectrum
# 通用预处理函数
def Preprocessing(method, input_spectrum):
def Preprocessing(method, input_spectrum, save_path=None):
"""通用预处理函数
Args:
method: 预处理方法名称
input_spectrum: 输入光谱数据
save_path: 可选的模型保存路径仅SS方法使用保存到7_Supervised_Model_Training目录
"""
if isinstance(input_spectrum, np.ndarray):
input_spectrum = pd.DataFrame(input_spectrum)
if method == "None":
@ -132,7 +147,11 @@ def Preprocessing(method, input_spectrum):
elif method == 'MMS':
output_spectrum = MMS(input_spectrum.values)
elif method == 'SS':
output_spectrum = SS(input_spectrum.values, r'E:\code\WQ\models/scaler_params.pkl')
# SS预处理模型保存到工作目录的7_Supervised_Model_Training/scaler_params.pkl
# 如果调用者没有提供save_path则使用默认路径
if not save_path:
save_path = r'E:\code\WQ\models\scaler_params.pkl'
output_spectrum = SS(input_spectrum.values, save_path)
elif method == 'CT':
output_spectrum = CT(input_spectrum.values)
elif method == 'SNV':

View File

@ -16,78 +16,46 @@ except ImportError:
def get_wavelengths_from_bil_header(bil_file):
    """Read the band wavelengths from the ENVI header next to an image file.

    The header path is ``bil_file`` with its extension replaced by ``.hdr``;
    it is parsed with ``spectral.io.envi.read_envi_header``.  Entries that
    cannot be converted to ``float`` are skipped individually, and
    non-positive values are dropped, so one malformed token no longer
    discards the whole wavelength list.

    Args:
        bil_file: Path to the BIL image file.

    Returns:
        A list of wavelengths as floats, or ``None`` when the header file is
        missing, contains no usable wavelength values, or cannot be read.
    """
    try:
        header_file = os.path.splitext(bil_file)[0] + ".hdr"
        if not os.path.exists(header_file):
            print(f"警告: 找不到头文件 {header_file}")
            return None

        # Imported lazily so that a missing header is reported even when the
        # spectral package is unavailable.
        import spectral.io.envi as envi
        header = envi.read_envi_header(header_file)

        raw_values = header.get('wavelength', None)
        if raw_values is None:
            print("警告: 头文件中未找到波长信息")
            return None

        # A plain string is split into tokens; braces and commas are treated
        # as separators.
        if isinstance(raw_values, str):
            raw_values = raw_values.strip('{}').replace(',', ' ').split()

        wavelengths = []
        for token in raw_values:
            try:
                value = float(token)
            except (TypeError, ValueError):
                # Skip the unparsable entry instead of failing the whole read.
                continue
            # Zero, negative and NaN entries are not valid wavelengths.
            if value > 0:
                wavelengths.append(value)

        if not wavelengths:
            # Avoid min()/max() on an empty list being reported as a read error.
            print("警告: 头文件中未找到波长信息")
            return None

        print(f"从头文件读取到 {len(wavelengths)} 个波长值")
        print(f"波长范围: {min(wavelengths):.2f} ~ {max(wavelengths):.2f}")
        return wavelengths

    except Exception as e:
        print(f"读取头文件波长信息时发生错误: {str(e)}")
        return None
@ -1021,10 +989,10 @@ def get_coor_base_interval(water_mask, severe_glint=None, output_csvpath=None, i
# 使用示例
if __name__ == "__main__":
# 新功能使用示例
bil_file = r"D:\BaiduNetdiskDownload\yaobao\result3.bsq"
water_mask_shp = r"D:\BaiduNetdiskDownload\yaobao\roi\roi.shp"
severe_glint = r"D:\BaiduNetdiskDownload\yaobao\find_glint\result3_glint_otsu"
output_csvpath = r"D:\BaiduNetdiskDownload\yaobao\csv\spectral_sampling_results.csv"
bil_file = r"E:\wq_gui_test\3_deglint\deglint_goodman.bsq"
water_mask_shp = r"E:\wq_gui_test\1_water_mask\water_mask_from_shp.dat"
severe_glint = r"E:\wq_gui_test\2_glint\severe_glint_area.dat"
output_csvpath = r"E:\wq_gui_test\10_sampling\sampling_spectra.csv"
# 设置参数
interval = 50 # 基础采样点间隔像元数当use_adaptive_sampling=False时使用