fix: 修复工作目录与步骤名不对应、回归预测虚数报错、模型加载及预处理名称转换问题,重构可视化并修正勾选联动
This commit is contained in:
654
src/core/prediction/custom_regression_prediction.py
Normal file
654
src/core/prediction/custom_regression_prediction.py
Normal file
@ -0,0 +1,654 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
自定义回归预测模块
|
||||
|
||||
该模块根据9_Custom_Regression_Modeling文件夹中的CSV信息,批量预测水质指数。
|
||||
处理流程:
|
||||
1. 读取9_Custom_Regression_Modeling文件夹中的CSV文件
|
||||
2. 根据r_squared选择最佳模型(指数公式+反演公式)
|
||||
3. 使用指数公式计算光谱指数值
|
||||
4. 使用反演公式计算水质参数值
|
||||
5. 输出包含投影经纬度和预测值的CSV文件
|
||||
|
||||
作者: Assistant
|
||||
创建时间: 2026-04-14
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Union, Tuple, Optional
|
||||
import warnings
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
# 导入水质指数计算器
|
||||
try:
|
||||
from src.utils.water_index import WaterQualityIndexCalculator
|
||||
except ImportError:
|
||||
from ..utils.water_index import WaterQualityIndexCalculator
|
||||
|
||||
|
||||
class CustomRegressionPredictor:
|
||||
"""
|
||||
自定义回归预测器
|
||||
|
||||
基于9_Custom_Regression_Modeling文件夹中的回归模型CSV文件,
|
||||
进行水质参数的批量预测。
|
||||
"""
|
||||
|
||||
def __init__(self,
             regression_models_dir: str = "9_Custom_Regression_Modeling",
             formula_csv_path: Optional[str] = None,
             output_dir: str = "prediction_results",
             log_level: int = logging.INFO):
    """
    Set up paths, logging, the index calculator and the optional formula table.

    Args:
        regression_models_dir: Directory holding the per-parameter regression CSV files.
        formula_csv_path: Optional CSV used to translate index names into index formulas.
        output_dir: Directory for prediction CSVs; created here if it does not exist.
        log_level: Level applied to the instance logger and its console handler.
    """
    self.regression_models_dir = Path(regression_models_dir)
    self.formula_csv_path = formula_csv_path
    self.output_dir = Path(output_dir)
    self.output_dir.mkdir(parents=True, exist_ok=True)

    # The logger must exist before anything below tries to report progress.
    self._setup_logging(log_level)

    self.index_calculator = WaterQualityIndexCalculator()

    # Filled by load_regression_models() / select_best_models().
    self.regression_models = {}
    self.best_models = {}

    # Optional name -> formula lookup table.
    self.formula_df = None
    if self.formula_csv_path:
        if Path(self.formula_csv_path).exists():
            self.formula_df = pd.read_csv(self.formula_csv_path)
            self.logger.info(f"加载公式CSV: {self.formula_csv_path}, 共 {len(self.formula_df)} 个公式")
        else:
            # A path was given but is missing: say so instead of silently
            # falling back to using index names as formulas.
            self.logger.warning(f"公式CSV文件不存在,将直接使用指数名称: {self.formula_csv_path}")

    self.logger.info(f"CustomRegressionPredictor初始化完成")
    self.logger.info(f"回归模型目录: {self.regression_models_dir}")
    self.logger.info(f"输出目录: {self.output_dir}")
|
||||
|
||||
def _setup_logging(self, log_level: int):
    """
    Create ``self.logger`` and give it a console handler on first use.

    The logger is named after the instance's class. A handler is only added
    when the logger has none yet, so repeated calls do not duplicate output.

    Args:
        log_level: Level applied to the logger (always) and to a newly
            created console handler.
    """
    logger = logging.getLogger(self.__class__.__name__)
    logger.setLevel(log_level)
    self.logger = logger

    # Guard clause: the logger is shared per name, so it may already be wired up.
    if logger.handlers:
        return

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(log_level)
    stream_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    logger.addHandler(stream_handler)
|
||||
|
||||
def load_regression_models(self) -> Dict[str, pd.DataFrame]:
    """
    Load every per-parameter regression CSV from ``self.regression_models_dir``.

    Each file is normalised to the columns ``index_formula``,
    ``inversion_formula`` and ``r_squared``:

    - ``r_squared`` may be spelled ``R_squared``, ``R2``, ``r2`` or ``R²``.
    - ``equation`` is accepted as ``inversion_formula``.
    - ``x_variable`` (an index name) becomes ``index_formula``; when a formula
      table is loaded it is translated through ``_lookup_formula_from_name``.

    An alias is only renamed when the canonical column is absent, so a file
    that carries both spellings never ends up with duplicate column labels.
    ``all_regression_results.csv`` is skipped. Files that fail to load or lack
    a required column are logged and skipped.

    Returns:
        Mapping of parameter name (derived from the file name) to its
        normalised DataFrame. This is ``self.regression_models``.

    Raises:
        FileNotFoundError: The directory is missing or contains no CSV file.
        ValueError: No file could be loaded successfully.
    """
    if not self.regression_models_dir.exists():
        raise FileNotFoundError(f"回归模型目录不存在: {self.regression_models_dir}")

    # Sorted so that load order and log output do not depend on the filesystem.
    csv_files = sorted(self.regression_models_dir.glob("*.csv"))
    if not csv_files:
        raise FileNotFoundError(f"在目录 {self.regression_models_dir} 中未找到CSV文件")

    self.logger.info(f"找到 {len(csv_files)} 个CSV文件")

    r_squared_aliases = ('R_squared', 'R2', 'r2', 'R²')
    required_columns = ['index_formula', 'inversion_formula', 'r_squared']

    for csv_file in csv_files:
        try:
            # The summary file mixes all parameters; only per-parameter files are used.
            if csv_file.name.lower() == "all_regression_results.csv":
                self.logger.info(f"跳过汇总文件: {csv_file.name}")
                continue

            # Parameter name = file stem without the result suffixes.
            param_name = csv_file.stem.replace('_regression_results', '').replace('_results', '')
            self.logger.info(f"加载回归模型文件: {csv_file.name} -> 参数: {param_name}")

            df = pd.read_csv(csv_file)
            column_mapping = {}

            # r_squared: rename the first alias found, unless the canonical name exists.
            if 'r_squared' not in df.columns:
                r2_alias = next((col for col in r_squared_aliases if col in df.columns), None)
                if r2_alias is not None:
                    column_mapping[r2_alias] = 'r_squared'

            # inversion_formula: keep an existing column, else adopt "equation".
            if 'inversion_formula' not in df.columns:
                if 'equation' in df.columns:
                    column_mapping['equation'] = 'inversion_formula'
                else:
                    self.logger.warning(f"文件 {csv_file.name} 缺少反演公式列(equation)")
                    continue

            # index_formula: resolve index names through the formula table when
            # one is loaded; otherwise use the name itself as the formula.
            if 'x_variable' in df.columns:
                if self.formula_df is not None:
                    df['index_formula'] = df['x_variable'].apply(self._lookup_formula_from_name)
                elif 'index_formula' not in df.columns:
                    column_mapping['x_variable'] = 'index_formula'
            elif 'index_formula' not in df.columns:
                self.logger.warning(f"文件 {csv_file.name} 缺少指数名称列(x_variable)")
                continue

            if column_mapping:
                df = df.rename(columns=column_mapping)

            missing_columns = [col for col in required_columns if col not in df.columns]
            if missing_columns:
                self.logger.warning(f"文件 {csv_file.name} 缺少必需列: {missing_columns}")
                continue

            self.regression_models[param_name] = df
            self.logger.info(f"成功加载参数 {param_name} 的 {len(df)} 个回归模型")

        except Exception as e:
            # Best effort: one unreadable file must not block the others.
            self.logger.error(f"加载文件 {csv_file.name} 失败: {e}")
            continue

    if not self.regression_models:
        raise ValueError("未成功加载任何回归模型")

    return self.regression_models
|
||||
|
||||
def _lookup_formula_from_name(self, formula_name: str) -> str:
    """
    Translate an index name into its formula using ``self.formula_df``.

    The name column is the first of ``Formula_Name``, ``formula_name``,
    ``Name``, ``name`` that exists, else the table's first column (with a
    warning). The formula column is the first of ``Formula``, ``formula``,
    ``Expression``, ``expression`` that exists, else the third column when
    the table has more than two columns.

    Args:
        formula_name: Index name to look up (compared with ``==``).

    Returns:
        The formula from the first matching row. ``formula_name`` itself is
        returned when no table is loaded, no row matches, the columns cannot
        be determined, or the matched formula cell is empty/NaN.
    """
    if self.formula_df is None:
        return formula_name

    columns = self.formula_df.columns

    name_col = next(
        (col for col in ('Formula_Name', 'formula_name', 'Name', 'name') if col in columns),
        None,
    )
    if name_col is None:
        self.logger.warning(f"公式CSV中未找到公式名称列,使用默认列名")
        name_col = columns[0] if len(columns) > 0 else None

    formula_col = next(
        (col for col in ('Formula', 'formula', 'Expression', 'expression') if col in columns),
        None,
    )
    if formula_col is None and len(columns) > 2:
        # Positional fallback: the third column is taken as the formula.
        formula_col = columns[2]

    if name_col is not None and formula_col is not None:
        matches = self.formula_df.loc[self.formula_df[name_col] == formula_name, formula_col]
        if not matches.empty:
            formula = matches.iloc[0]
            # An empty cell would otherwise leak NaN/None to callers that
            # expect a string; fall through to the name instead.
            if pd.notna(formula) and str(formula).strip():
                return formula

    return formula_name
|
||||
|
||||
def select_best_models(self) -> Dict[str, pd.Series]:
    """
    Pick, for every loaded parameter, the row with the highest ``r_squared``.

    Models are loaded first via ``load_regression_models`` when none are in
    memory. ``r_squared`` is coerced to numeric, so text or missing values
    are ignored. A parameter with no usable ``r_squared`` at all is skipped
    with a warning rather than aborting selection for the other parameters.

    Returns:
        Mapping of parameter name to the selected row. This is
        ``self.best_models``.
    """
    if not self.regression_models:
        self.load_regression_models()

    for param_name, models_df in self.regression_models.items():
        scores = pd.to_numeric(models_df['r_squared'], errors='coerce')

        if not scores.notna().any():
            # Empty table or only unparsable scores: nothing to rank.
            self.logger.warning(f"参数 {param_name} 没有有效的r_squared值,已跳过")
            # Drop any selection left over from an earlier call.
            self.best_models.pop(param_name, None)
            continue

        # idxmax skips NaN, so partially filled columns still rank correctly.
        best_label = scores.idxmax()
        best_model = models_df.loc[best_label]
        self.best_models[param_name] = best_model

        self.logger.info(f"参数 {param_name} 最佳模型:")
        self.logger.info(f" 指数公式: {best_model['index_formula']}")
        self.logger.info(f" 反演公式: {best_model['inversion_formula']}")
        self.logger.info(f" R²: {scores.loc[best_label]:.4f}")

    return self.best_models
|
||||
|
||||
def calculate_spectral_index(self,
                             spectral_data: pd.DataFrame,
                             index_formula: str) -> pd.Series:
    """
    Evaluate a spectral index for every row of ``spectral_data``.

    If ``index_formula`` names an attribute of ``self.index_calculator``,
    that attribute is called with ``spectral_data`` and its result returned
    as is. Otherwise the text is treated as an expression in which tokens
    such as ``R865`` or ``w681`` (the letter ``R`` or ``w`` followed by
    digits) stand for spectral columns. Each token is replaced by a
    reference to the column returned by
    ``self.index_calculator.find_closest_wavelength(columns, wavelength)``
    and the expression is evaluated.

    Tokens are substituted in a single pass, so inserted column references
    are never rescanned and a token that is a prefix of another (``R68`` /
    ``R681``) cannot corrupt it.

    Args:
        spectral_data: Spectral table; column labels are handed to
            ``find_closest_wavelength``.
        index_formula: Calculator attribute name or custom expression.

    Returns:
        The index values. For the custom-expression path they are indexed
        like ``spectral_data``. On any error an all-NaN Series with that
        index is returned and the error is logged.
    """
    try:
        # Built-in index: delegate to the calculator method of that name.
        if hasattr(self.index_calculator, index_formula):
            calculator_func = getattr(self.index_calculator, index_formula)
            index_values = calculator_func(spectral_data)
            self.logger.debug(f"使用标准指数函数 {index_formula} 计算完成")
            return index_values

        formula = index_formula.strip()
        column_labels = spectral_data.columns.tolist()
        references = {}  # wavelength text -> column reference, resolved once

        def _column_reference(match):
            wavelength = match.group(1)
            if wavelength not in references:
                closest_col = self.index_calculator.find_closest_wavelength(
                    column_labels, int(wavelength)
                )
                references[wavelength] = f'spectral_data["{closest_col}"]'
            return references[wavelength]

        expression = re.sub(r'[Rw](\d+)', _column_reference, formula)

        # NOTE(security): the expression comes from CSV content and is
        # evaluated with full access to this scope. Only use formula files
        # from a trusted source.
        index_values = eval(expression)
        self.logger.debug(f"使用自定义公式 {index_formula} 计算完成")

        return pd.Series(index_values, index=spectral_data.index)

    except Exception as e:
        self.logger.error(f"计算光谱指数失败,公式: {index_formula}, 错误: {e}")
        return pd.Series(np.nan, index=spectral_data.index)
|
||||
|
||||
def apply_inversion_formula(self,
                            index_values: pd.Series,
                            inversion_formula: str) -> pd.Series:
    """
    Turn spectral index values into parameter predictions.

    The formula is an expression in ``x``, optionally written as
    ``y = ...`` (any spacing around ``=``, case-insensitive). ``^`` is
    read as power. Available functions: ``ln`` and ``log`` (natural log),
    ``log10``, ``exp``, ``sqrt``, ``abs``, ``pow``. Builtins are not
    exposed to the expression.

    The formula is compiled once and evaluated per value. A NaN input, or
    a value for which evaluation fails (e.g. log of a negative number),
    yields NaN. A complex result is reduced to its real part.

    Args:
        index_values: Index values; the result keeps this index.
        inversion_formula: Formula text, e.g. ``"y = 11.27 + 0.82*x"``.

    Returns:
        Predicted values aligned with ``index_values``. If the formula
        cannot be prepared or compiled, an all-NaN Series is returned and
        the error is logged.
    """
    try:
        formula = inversion_formula.strip()

        # Drop an optional "y =" left-hand side; "(?!=)" keeps "==" intact.
        formula = re.sub(r'^y\s*=(?!=)\s*', '', formula, flags=re.IGNORECASE)

        formula = formula.replace('^', '**')

        import math
        math_env = {
            'ln': math.log,
            'log': math.log,
            'log10': math.log10,
            'exp': math.exp,
            'sqrt': math.sqrt,
            'abs': abs,
            'pow': pow,
        }

        self.logger.debug(f"处理后的反演公式: {formula}")

        # Parse once: a malformed formula is reported a single time by the
        # outer handler instead of failing silently on every row.
        compiled_formula = compile(formula, '<inversion_formula>', 'eval')

        predicted_values = []
        for x_val in index_values:
            try:
                if pd.isna(x_val):
                    predicted_values.append(np.nan)
                    continue

                local_vars = {'x': x_val}
                local_vars.update(math_env)
                # NOTE(security): formulas come from CSV content; builtins
                # are withheld, but only trusted files should be used.
                result = eval(compiled_formula, {"__builtins__": {}}, local_vars)

                # e.g. a negative base with a fractional exponent.
                if isinstance(result, complex):
                    result = result.real

                predicted_values.append(result)
            except Exception as eval_err:
                self.logger.debug(f"评估公式失败,x={x_val}: {eval_err}")
                predicted_values.append(np.nan)

        self.logger.debug(f"使用反演公式 {inversion_formula} 计算完成")
        return pd.Series(predicted_values, index=index_values.index)

    except Exception as e:
        self.logger.error(f"应用反演公式失败,公式: {inversion_formula}, 错误: {e}")
        return pd.Series(np.nan, index=index_values.index)
|
||||
|
||||
def predict_single_parameter(self,
                             spectral_data: pd.DataFrame,
                             coordinate_data: pd.DataFrame,
                             param_name: str) -> pd.DataFrame:
    """
    Predict one parameter with its previously selected best model.

    The model's ``index_formula`` is evaluated through
    ``calculate_spectral_index`` and the result is passed, together with
    the model's ``inversion_formula``, to ``apply_inversion_formula``.

    Args:
        spectral_data: Spectral table handed to the index calculation.
        coordinate_data: Coordinate table; copied, never modified.
        param_name: Key into ``self.best_models``.

    Returns:
        A copy of ``coordinate_data`` with the extra columns
        ``<param_name>_predicted`` and ``<param_name>_index``. Its
        ``attrs`` hold the parameter name, both formulas, the model's
        ``r_squared`` and an ISO timestamp.

    Raises:
        ValueError: No best model is stored for ``param_name``.
    """
    if param_name not in self.best_models:
        raise ValueError(f"未找到参数 {param_name} 的最佳模型")

    model = self.best_models[param_name]
    index_formula, inversion_formula = model['index_formula'], model['inversion_formula']
    predicted_col = f'{param_name}_predicted'
    index_col = f'{param_name}_index'

    self.logger.info(f"开始预测参数: {param_name}")

    # Stage 1: spectral index from the raw spectra.
    self.logger.info("步骤1: 计算光谱指数")
    spectral_index = self.calculate_spectral_index(spectral_data, index_formula)

    # Stage 2: index -> parameter value.
    self.logger.info("步骤2: 应用反演公式")
    predictions = self.apply_inversion_formula(spectral_index, inversion_formula)

    # Stage 3: attach both series to a copy of the coordinates.
    result_df = coordinate_data.copy()
    result_df[predicted_col] = predictions
    result_df[index_col] = spectral_index

    result_df.attrs = dict(
        parameter=param_name,
        index_formula=index_formula,
        inversion_formula=inversion_formula,
        r_squared=model['r_squared'],
        prediction_time=datetime.now().isoformat(),
    )

    total_points = len(result_df)
    valid_predictions = result_df[predicted_col].notna().sum()
    self.logger.info(f"参数 {param_name} 预测完成: {valid_predictions}/{total_points} 个有效预测")

    return result_df
|
||||
|
||||
def predict_all_parameters(self,
                           spectral_data: pd.DataFrame,
                           coordinate_data: pd.DataFrame) -> Dict[str, pd.DataFrame]:
    """
    Run ``predict_single_parameter`` for every entry in ``self.best_models``.

    Best models are selected first (``select_best_models``) when none are
    stored. A parameter whose prediction raises is logged and left out of
    the result; the remaining parameters are still processed.

    Args:
        spectral_data: Spectral table forwarded to each prediction.
        coordinate_data: Coordinate table forwarded to each prediction.

    Returns:
        Mapping of parameter name to its prediction DataFrame, containing
        only the parameters that succeeded.
    """
    if not self.best_models:
        self.select_best_models()

    prediction_results = {}

    for param_name in self.best_models.keys():
        self.logger.info(f"正在预测参数: {param_name}")
        try:
            frame = self.predict_single_parameter(spectral_data, coordinate_data, param_name)
        except Exception as e:
            # Isolate failures so one bad model does not stop the batch.
            self.logger.error(f"预测参数 {param_name} 失败: {e}")
        else:
            prediction_results[param_name] = frame

    self.logger.info(f"批量预测完成,成功预测 {len(prediction_results)} 个参数")
    return prediction_results
|
||||
|
||||
def _convert_complex_to_real(self, df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` in which complex values are replaced by their real part.

    Only columns whose dtype is ``object`` or contains ``complex`` are
    inspected; other values in those columns are left untouched. The input
    frame is not modified.

    Args:
        df: Frame to clean.

    Returns:
        The cleaned copy. A column whose conversion fails is kept as it was
        and the failure is logged at debug level.
    """
    df_clean = df.copy()
    for col in df_clean.columns:
        column = df_clean[col]
        if column.dtype == 'object' or 'complex' in str(column.dtype):
            try:
                df_clean[col] = column.apply(
                    lambda x: x.real if isinstance(x, complex) else x
                )
            except Exception as err:
                # Best effort, but no longer a bare except: KeyboardInterrupt
                # and SystemExit propagate, and the failure leaves a trace.
                self.logger.debug(f"列 {col} 复数转换失败,保持原值: {err}")
    return df_clean
|
||||
|
||||
def save_prediction_results(self,
                            prediction_results: Dict[str, pd.DataFrame],
                            filename_prefix: str = "custom_regression_prediction") -> Dict[str, str]:
    """
    Write one CSV per parameter into ``self.output_dir``.

    Each frame is passed through ``_convert_complex_to_real`` and written
    as ``<filename_prefix>_<param_name>.csv`` (no index, ``utf-8-sig``).
    Count, mean and standard deviation of the ``<param_name>_predicted``
    column are logged when that column exists. A parameter that raises is
    logged and processing moves on to the next one.

    Args:
        prediction_results: Mapping of parameter name to prediction frame.
        filename_prefix: Leading part of every output file name.

    Returns:
        Mapping of parameter name to the path of the file written for it.
    """
    saved_files = {}

    for param_name, result_df in prediction_results.items():
        try:
            filepath = self.output_dir / f"{filename_prefix}_{param_name}.csv"

            # Complex values are reduced to reals before they reach the CSV.
            cleaned = self._convert_complex_to_real(result_df)
            cleaned.to_csv(filepath, index=False, encoding='utf-8-sig')
            saved_files[param_name] = str(filepath)

            self.logger.info(f"参数 {param_name} 预测结果已保存: {filepath}")

            # Summary statistics are only available with a prediction column.
            predicted_col = f'{param_name}_predicted'
            if predicted_col not in result_df.columns:
                continue

            predicted = result_df[predicted_col]
            valid_count = predicted.notna().sum()
            mean_value = predicted.mean()
            std_value = predicted.std()

            self.logger.info(f" 有效预测: {valid_count} 个")
            self.logger.info(f" 平均值: {mean_value:.4f}")
            self.logger.info(f" 标准差: {std_value:.4f}")

        except Exception as e:
            self.logger.error(f"保存参数 {param_name} 结果失败: {e}")
            continue

    return saved_files
|
||||
|
||||
def run_batch_prediction(self,
                         input_csv_path: str,
                         coordinate_columns: List[str] = None,
                         filename_prefix: str = "custom_regression_prediction") -> Dict[str, str]:
    """
    Run the full pipeline: load models, pick the best, predict, save.

    The input CSV is split into coordinate columns and the remaining
    (spectral) columns. When any requested coordinate column is missing,
    the coordinate columns are detected automatically:

    1. Strict pass: a column qualifies when one of its letter-only name
       parts (split on anything that is not a letter) is a coordinate word
       such as ``x``, ``y``, ``lon``, ``lat``, ``longitude``, ``latitude``.
       ``UTM_X`` qualifies; ``pixel_id`` does not.
    2. Fallback: if the strict pass finds fewer than two columns, any
       column whose lower-cased name contains ``lon``, ``lat``, ``x`` or
       ``y`` qualifies.

    The first two qualifying columns are used.

    Args:
        input_csv_path: Path of the sampled-spectra CSV.
        coordinate_columns: Coordinate column names; defaults to
            ``['longitude', 'latitude']``.
        filename_prefix: Prefix for the output file names.

    Returns:
        Mapping of parameter name to the path of its saved prediction file.

    Raises:
        ValueError: Coordinate columns are missing and fewer than two
            candidates could be detected.
    """
    if coordinate_columns is None:
        coordinate_columns = ['longitude', 'latitude']

    self.logger.info("="*50)
    self.logger.info("开始自定义回归批量预测")
    self.logger.info("="*50)

    self.logger.info("步骤1: 加载回归模型")
    self.load_regression_models()

    self.logger.info("步骤2: 选择最佳模型")
    self.select_best_models()

    self.logger.info("步骤3: 读取输入数据")
    input_df = pd.read_csv(input_csv_path)
    self.logger.info(f"读取输入数据: {len(input_df)} 行, {len(input_df.columns)} 列")

    self.logger.info("步骤4: 分离坐标和光谱数据")

    missing_coords = [col for col in coordinate_columns if col not in input_df.columns]
    if missing_coords:
        # Strict pass: whole name parts only, so a stray "x"/"y" inside a
        # word (pixel, oxygen, turbidity...) is not taken for a coordinate.
        coord_words = {'x', 'y', 'lon', 'lat', 'lng',
                       'longitude', 'latitude', 'easting', 'northing'}
        potential_coord_cols = [
            col for col in input_df.columns
            if coord_words.intersection(re.split(r'[^a-z]+', str(col).lower()))
        ]

        if len(potential_coord_cols) < 2:
            # Fallback: the original loose substring rule.
            potential_coord_cols = [
                col for col in input_df.columns
                if any(coord_name in str(col).lower()
                       for coord_name in ['lon', 'lat', 'x', 'y'])
            ]

        if len(potential_coord_cols) >= 2:
            coordinate_columns = potential_coord_cols[:2]
            self.logger.info(f"自动识别坐标列: {coordinate_columns}")
        else:
            raise ValueError(f"未找到坐标列: {missing_coords}")

    coordinate_data = input_df[coordinate_columns].copy()
    spectral_data = input_df.drop(columns=coordinate_columns)

    self.logger.info(f"坐标数据: {len(coordinate_data)} 行, {len(coordinate_data.columns)} 列")
    self.logger.info(f"光谱数据: {len(spectral_data)} 行, {len(spectral_data.columns)} 列")

    self.logger.info("步骤5: 执行批量预测")
    prediction_results = self.predict_all_parameters(spectral_data, coordinate_data)

    self.logger.info("步骤6: 保存预测结果")
    saved_files = self.save_prediction_results(prediction_results, filename_prefix)

    self.logger.info("="*50)
    self.logger.info(f"批量预测完成!共生成 {len(saved_files)} 个结果文件")
    self.logger.info("="*50)

    return saved_files
|
||||
|
||||
|
||||
def main():
    """
    Command-line entry point.

    Parses the arguments, builds a ``CustomRegressionPredictor``, runs the
    batch prediction and prints the path of every result file.

    ``--formula_csv`` is optional; when given, it is passed to the
    predictor as ``formula_csv_path``. Without it the behaviour is the
    same as before the option existed.
    """
    import argparse

    parser = argparse.ArgumentParser(description='自定义回归预测模块')
    parser.add_argument('--input_csv', required=True, help='输入的光谱采样CSV文件路径')
    parser.add_argument('--models_dir', default='9_Custom_Regression_Modeling',
                        help='回归模型CSV文件目录')
    parser.add_argument('--formula_csv', default=None,
                        help='公式CSV文件路径,用于根据指数名称查找指数公式(可选)')
    parser.add_argument('--output_dir', default='prediction_results',
                        help='预测结果输出目录')
    parser.add_argument('--coord_cols', nargs='+', default=['longitude', 'latitude'],
                        help='坐标列名')
    parser.add_argument('--prefix', default='custom_regression_prediction',
                        help='输出文件名前缀')

    args = parser.parse_args()

    predictor = CustomRegressionPredictor(
        regression_models_dir=args.models_dir,
        formula_csv_path=args.formula_csv,
        output_dir=args.output_dir
    )

    saved_files = predictor.run_batch_prediction(
        input_csv_path=args.input_csv,
        coordinate_columns=args.coord_cols,
        filename_prefix=args.prefix
    )

    print("\n预测结果文件:")
    for param_name, filepath in saved_files.items():
        print(f" {param_name}: {filepath}")


if __name__ == "__main__":
    main()
|
||||
@ -609,6 +609,11 @@ class WaterQualityScatterBatch:
|
||||
split_method = best_row['划分方法']
|
||||
preprocess_method = best_row['预处理方法']
|
||||
model_name = best_row['建模方法']
|
||||
|
||||
# 处理 nan/NaN/None 值,转换为 "None" 字符串
|
||||
if pd.isna(preprocess_method) or str(preprocess_method).lower() in ['nan', 'none', '']:
|
||||
preprocess_method = "None"
|
||||
|
||||
best_combination = f"{split_method}_{preprocess_method}_{model_name}"
|
||||
else:
|
||||
# 简化结果文件格式(英文列名)
|
||||
@ -617,11 +622,15 @@ class WaterQualityScatterBatch:
|
||||
parts = best_combination.split('_')
|
||||
if len(parts) < 3:
|
||||
raise ValueError(f"无效的模型组合名称格式: {best_combination}")
|
||||
|
||||
|
||||
split_method = parts[0]
|
||||
preprocess_method = parts[1]
|
||||
model_name = '_'.join(parts[2:])
|
||||
|
||||
# 处理 nan/NaN/None 值,转换为 "None" 字符串
|
||||
if pd.isna(preprocess_method) or str(preprocess_method).lower() in ['nan', 'none', '']:
|
||||
preprocess_method = "None"
|
||||
|
||||
print(f"最佳模型组合: {best_combination}")
|
||||
print(f" 划分方法: {split_method}")
|
||||
print(f" 预处理方法: {preprocess_method}")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -58,6 +58,8 @@ from src.core.glint_removal.Hedley import Hedley
|
||||
from src.core.glint_removal.SUGAR import SUGAR, correction_iterative
|
||||
from src.utils.water_index import WaterQualityIndexCalculator
|
||||
from src.core.modeling.regression import SingleVariableRegressionAnalysis
|
||||
# 导入新的自定义回归预测模块
|
||||
from src.core.prediction.custom_regression_prediction import CustomRegressionPredictor
|
||||
# 导入hdr文件处理函数
|
||||
try:
|
||||
from src.utils.util import write_fields_to_hdrfile, get_hdr_file_path, find_band_number
|
||||
@ -106,9 +108,9 @@ class WaterQualityInversionPipeline:
|
||||
self.processed_data_dir = self.work_dir / "4_processed_data"
|
||||
self.training_spectra_dir = self.work_dir / "5_training_spectra"
|
||||
self.indices_dir = self.work_dir / "6_water_quality_indices"
|
||||
self.models_dir = self.work_dir / "7_models"
|
||||
self.non_empirical_models_dir = self.work_dir / "8_non_empirical_models"
|
||||
self.custom_regression_dir = self.work_dir / "9_custom_regression"
|
||||
self.models_dir = self.work_dir / "7_Supervised_Model_Training"
|
||||
self.non_empirical_models_dir = self.work_dir / "8_Regression_Modeling"
|
||||
self.custom_regression_dir = self.work_dir / "9_Custom_Regression_Modeling"
|
||||
self.sampling_dir = self.work_dir / "10_sampling"
|
||||
self.prediction_dir = self.work_dir / "11_12_13_predictions"
|
||||
self.visualization_dir = self.work_dir / "14_visualization"
|
||||
@ -3379,10 +3381,10 @@ class WaterQualityInversionPipeline:
|
||||
else:
|
||||
# 如果output_dir为空,使用工作目录
|
||||
if hasattr(self, 'work_dir') and self.work_dir is not None:
|
||||
non_empirical_dir = Path(self.work_dir) / "8_non_empirical_models"
|
||||
non_empirical_dir = Path(self.work_dir) / "8_Regression_Modeling"
|
||||
else:
|
||||
# 如果没有工作目录,使用当前目录
|
||||
non_empirical_dir = Path.cwd() / "8_non_empirical_models"
|
||||
non_empirical_dir = Path.cwd() / "8_Regression_Modeling"
|
||||
non_empirical_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 设置默认参数
|
||||
@ -3592,9 +3594,17 @@ class WaterQualityInversionPipeline:
|
||||
|
||||
# 应用预处理 - 使用spectral_Preprocessing模块
|
||||
from src.preprocessing.spectral_Preprocessing import Preprocessing
|
||||
|
||||
# 调用预处理函数
|
||||
processed_spectral = Preprocessing(preprocess_method, spectral_data)
|
||||
|
||||
# 为SS预处理提供scaler保存路径,保存在工作目录的7_Supervised_Model_Training中
|
||||
save_path = None
|
||||
if preprocess_method == 'SS':
|
||||
models_dir = output_dir.parent.parent / "7_Supervised_Model_Training" # 向上两级到工作目录
|
||||
models_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_path = str(models_dir / "scaler_params.pkl")
|
||||
print(f"SS预处理: scaler模型将保存到 {save_path}")
|
||||
|
||||
# 调用预处理函数(为SS方法传递save_path)
|
||||
processed_spectral = Preprocessing(preprocess_method, spectral_data, save_path=save_path)
|
||||
|
||||
# 重新组合数据
|
||||
if isinstance(processed_spectral, pd.DataFrame):
|
||||
@ -3712,7 +3722,7 @@ class WaterQualityInversionPipeline:
|
||||
if non_empirical_models_dir is not None:
|
||||
final_models_dir = non_empirical_models_dir
|
||||
else:
|
||||
default_models_dir = str(self.work_dir / "8_non_empirical_models")
|
||||
default_models_dir = str(self.work_dir / "8_Regression_Modeling")
|
||||
if Path(default_models_dir).exists():
|
||||
final_models_dir = default_models_dir
|
||||
else:
|
||||
@ -3812,28 +3822,31 @@ class WaterQualityInversionPipeline:
|
||||
raise
|
||||
|
||||
def step8_75_predict_with_custom_regression(self, sampling_csv_path: str,
|
||||
formula_csv_file: str,
|
||||
custom_regression_dir: Optional[str] = None,
|
||||
formula_names: Optional[List[str]] = None,
|
||||
prediction_column: str = 'prediction',
|
||||
formula_csv_path: Optional[str] = None,
|
||||
coordinate_columns: Optional[List[str]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
filename_prefix: str = "custom_regression_prediction",
|
||||
enabled: bool = True,
|
||||
skip_dependency_check: bool = False) -> Dict[str, str]:
|
||||
"""
|
||||
步骤8.75: 使用自定义回归模型进行参数预测
|
||||
|
||||
使用步骤6.75的自定义回归结果中的all_regression_results.csv文件,
|
||||
选择每个y_variable中r_squared最高的equation,
|
||||
使用采样点光谱数据计算水质指数,然后进行预测
|
||||
使用新的CustomRegressionPredictor模块,基于9_Custom_Regression_Modeling文件夹中的CSV,
|
||||
根据r_squared选择最佳模型,批量预测水质参数
|
||||
|
||||
Args:
|
||||
sampling_csv_path: 采样点光谱数据CSV路径(来自步骤7)
|
||||
formula_csv_file: 公式CSV文件路径,包含水质指数计算公式
|
||||
custom_regression_dir: 自定义回归结果目录(如果为None,使用步骤6.75的结果)
|
||||
formula_names: 要计算的公式名称列表,如果为None则计算所有公式
|
||||
prediction_column: 预测结果列名
|
||||
custom_regression_dir: 自定义回归模型目录(9_Custom_Regression_Modeling)
|
||||
formula_csv_path: 公式CSV文件路径,用于查找index_formula
|
||||
coordinate_columns: 坐标列名列表,默认为['longitude', 'latitude']或自动识别
|
||||
output_dir: 输出目录,默认为prediction_dir
|
||||
filename_prefix: 输出文件名前缀
|
||||
enabled: 是否启用
|
||||
skip_dependency_check: 是否跳过依赖检查
|
||||
|
||||
Returns:
|
||||
预测结果文件路径字典(键为y_variable名)
|
||||
预测结果文件路径字典(键为参数名)
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("步骤8.75: 使用自定义回归模型进行参数预测")
|
||||
@ -3848,135 +3861,54 @@ class WaterQualityInversionPipeline:
|
||||
self._record_step_time("步骤8.75: 自定义回归模型预测", step_start_time, step_end_time, status="skipped")
|
||||
return {}
|
||||
|
||||
# 检查公式CSV文件是否存在
|
||||
if not Path(formula_csv_file).exists():
|
||||
raise FileNotFoundError(f"公式CSV文件不存在: {formula_csv_file}")
|
||||
# 检查采样点CSV文件是否存在
|
||||
if not Path(sampling_csv_path).exists():
|
||||
raise FileNotFoundError(f"采样点CSV文件不存在: {sampling_csv_path}")
|
||||
|
||||
# 确定自定义回归结果目录
|
||||
# 确定自定义回归模型目录
|
||||
if custom_regression_dir is not None:
|
||||
final_regression_dir = custom_regression_dir
|
||||
else:
|
||||
default_regression_dir = str(self.custom_regression_dir)
|
||||
if Path(default_regression_dir).exists():
|
||||
final_regression_dir = default_regression_dir
|
||||
else:
|
||||
final_regression_dir = str(self.custom_regression_dir)
|
||||
if not Path(final_regression_dir).exists():
|
||||
if skip_dependency_check:
|
||||
raise ValueError("必须提供custom_regression_dir参数才能独立运行步骤8.75")
|
||||
else:
|
||||
raise ValueError("请先执行步骤6.75: 自定义回归分析,或提供custom_regression_dir参数")
|
||||
|
||||
# 读取all_regression_results.csv文件
|
||||
regression_results_path = Path(final_regression_dir) / "all_regression_results.csv"
|
||||
if not regression_results_path.exists():
|
||||
raise FileNotFoundError(f"未找到自定义回归结果文件: {regression_results_path}")
|
||||
|
||||
# 读取回归结果
|
||||
regression_df = pd.read_csv(regression_results_path)
|
||||
# 确定输出目录
|
||||
if output_dir is None:
|
||||
prediction_output_dir = str(self.prediction_dir)
|
||||
else:
|
||||
prediction_output_dir = output_dir
|
||||
|
||||
# 使用band_math.py计算水质指数
|
||||
print("正在使用采样点光谱数据计算水质指数...")
|
||||
from src.utils.band_math import BandMathCalculator
|
||||
|
||||
# 创建计算器实例
|
||||
calculator = BandMathCalculator(sampling_csv_path)
|
||||
|
||||
# 计算所有公式
|
||||
indices_df = calculator.process_formulas_from_csv(
|
||||
formula_csv_file,
|
||||
formula_names=formula_names,
|
||||
output_file=str(self.prediction_dir / "water_quality_indices.csv")
|
||||
# 创建CustomRegressionPredictor实例
|
||||
predictor = CustomRegressionPredictor(
|
||||
regression_models_dir=final_regression_dir,
|
||||
formula_csv_path=formula_csv_path,
|
||||
output_dir=prediction_output_dir
|
||||
)
|
||||
|
||||
if indices_df is None:
|
||||
raise ValueError("水质指数计算失败")
|
||||
# 运行批量预测
|
||||
print(f"开始使用自定义回归模块进行批量预测...")
|
||||
print(f" 采样点数据: {sampling_csv_path}")
|
||||
print(f" 回归模型目录: {final_regression_dir}")
|
||||
print(f" 输出目录: {prediction_output_dir}")
|
||||
|
||||
# 读取采样点数据(包含坐标信息)
|
||||
sampling_df = pd.read_csv(sampling_csv_path)
|
||||
|
||||
# 获取所有唯一的y_variable
|
||||
y_variables = regression_df['y_variable'].unique()
|
||||
|
||||
prediction_files = {}
|
||||
|
||||
for y_var in y_variables:
|
||||
try:
|
||||
# 筛选当前y_variable的所有回归结果
|
||||
y_var_results = regression_df[regression_df['y_variable'] == y_var]
|
||||
|
||||
# 找到r_squared最高的记录
|
||||
best_result = y_var_results.loc[y_var_results['r_squared'].idxmax()]
|
||||
|
||||
# 解析equation
|
||||
equation = best_result['equation']
|
||||
x_variable = best_result['x_variable']
|
||||
|
||||
print(f"为 {y_var} 选择最佳方程: {equation} (R² = {best_result['r_squared']:.4f})")
|
||||
|
||||
# 检查x_variable是否在水质指数数据中存在
|
||||
if x_variable not in indices_df.columns:
|
||||
print(f"警告: x_variable '{x_variable}' 不在水质指数数据中,跳过 {y_var}")
|
||||
continue
|
||||
|
||||
# 合并采样点坐标和水质指数数据
|
||||
# 假设采样点数据和水质指数数据有相同的行数和顺序
|
||||
if len(sampling_df) != len(indices_df):
|
||||
print(f"警告: 采样点数据({len(sampling_df)}行)和水质指数数据({len(indices_df)}行)行数不一致")
|
||||
# 只取前min(len(sampling_df), len(indices_df))行
|
||||
min_rows = min(len(sampling_df), len(indices_df))
|
||||
merged_df = pd.concat([
|
||||
sampling_df.iloc[:min_rows].reset_index(drop=True),
|
||||
indices_df.iloc[:min_rows].reset_index(drop=True)
|
||||
], axis=1)
|
||||
else:
|
||||
merged_df = pd.concat([sampling_df, indices_df], axis=1)
|
||||
|
||||
# 应用回归方程进行预测
|
||||
# 使用eval函数安全地计算表达式
|
||||
try:
|
||||
# 创建局部命名空间,包含需要的变量和数学函数
|
||||
local_vars = {x_variable: merged_df[x_variable].values}
|
||||
# 添加数学函数到命名空间
|
||||
import math
|
||||
local_vars.update({
|
||||
'exp': math.exp,
|
||||
'log': math.log,
|
||||
'log10': math.log10,
|
||||
'sqrt': math.sqrt,
|
||||
'sin': math.sin,
|
||||
'cos': math.cos,
|
||||
'tan': math.tan,
|
||||
'pi': math.pi,
|
||||
'e': math.e
|
||||
})
|
||||
|
||||
# 替换方程中的变量名为实际值
|
||||
# 这里需要确保方程格式正确,例如: "a*x + b"
|
||||
prediction_values = eval(equation, {"__builtins__": {}}, local_vars)
|
||||
|
||||
# 创建预测结果DataFrame
|
||||
result_df = merged_df[['UTM_X', 'UTM_Y']].copy()
|
||||
result_df[prediction_column] = prediction_values
|
||||
|
||||
# 保存预测结果
|
||||
output_filename = f"custom_regression_{y_var}.csv"
|
||||
output_path = str(self.prediction_dir / output_filename)
|
||||
result_df.to_csv(output_path, index=False)
|
||||
|
||||
prediction_files[y_var] = output_path
|
||||
print(f"成功为 {y_var} 生成预测结果: {output_path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"应用方程 {equation} 进行预测时出错: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理 y_variable {y_var} 时出错: {e}")
|
||||
continue
|
||||
saved_files = predictor.run_batch_prediction(
|
||||
input_csv_path=sampling_csv_path,
|
||||
coordinate_columns=coordinate_columns,
|
||||
filename_prefix=filename_prefix
|
||||
)
|
||||
|
||||
step_end_time = time.time()
|
||||
self._record_step_time("步骤8.75: 自定义回归模型预测", step_start_time, step_end_time)
|
||||
|
||||
return prediction_files
|
||||
print(f"自定义回归预测完成,生成 {len(saved_files)} 个预测文件:")
|
||||
for param_name, filepath in saved_files.items():
|
||||
print(f" {param_name}: {filepath}")
|
||||
|
||||
return saved_files
|
||||
|
||||
except Exception as e:
|
||||
step_end_time = time.time()
|
||||
@ -4113,6 +4045,14 @@ def main():
|
||||
'prediction_column': 'prediction',
|
||||
'enabled': True # 是否启用非经验模型预测
|
||||
},
|
||||
'step8_75': {
|
||||
'custom_regression_dir': None, # 自定义回归模型目录(None表示使用9_Custom_Regression_Modeling)
|
||||
'formula_csv_path': None, # 公式CSV文件路径,用于查找index_formula(如water_quality_formulas.csv)
|
||||
'coordinate_columns': None, # 坐标列名(None表示自动识别)
|
||||
'output_dir': None, # 输出目录(None表示使用prediction_dir)
|
||||
'filename_prefix': 'custom_regression_prediction', # 输出文件名前缀
|
||||
'enabled': True # 是否启用自定义回归预测
|
||||
},
|
||||
'step9': {
|
||||
'boundary_shp_path': r"D:\BaiduNetdiskDownload\yaobao\roi\roi.shp" ,
|
||||
'resolution': 30,
|
||||
|
||||
@ -660,7 +660,7 @@ class VisualizationWorkerThread(QThread):
|
||||
self.failed.emit("训练光谱 CSV 无效或不存在,请确认已选择步骤5输出的文件。")
|
||||
return
|
||||
if not models_dir or not Path(models_dir).is_dir():
|
||||
self.failed.emit("模型目录无效或不存在,请确认步骤6已生成 7_models 下的参数子文件夹。")
|
||||
self.failed.emit("模型目录无效或不存在,请确认步骤6已生成 7_Supervised_Model_Training 下的参数子文件夹。")
|
||||
return
|
||||
pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
|
||||
scatter_paths = pipeline.generate_model_scatter_plots(
|
||||
@ -672,12 +672,111 @@ class VisualizationWorkerThread(QThread):
|
||||
from src.postprocessing.visualization_reports import WaterQualityVisualization
|
||||
viz = WaterQualityVisualization(output_dir=str(wp / "14_visualization"))
|
||||
parts = []
|
||||
|
||||
# 获取训练数据CSV路径(多个图表类型共用)
|
||||
training_csv = wp / "5_training_spectra" / "training_spectra.csv"
|
||||
|
||||
# 生成散点图
|
||||
if self.extra.get("gen_scatter"):
|
||||
if training_csv.is_file():
|
||||
models_dir = wp / "7_Supervised_Model_Training"
|
||||
if models_dir.is_dir() and any(d.is_dir() for d in models_dir.iterdir()):
|
||||
from src.core.water_quality_inversion_pipeline_GUI import WaterQualityInversionPipeline
|
||||
pipeline = WaterQualityInversionPipeline(work_dir=str(wp))
|
||||
scatter_paths = pipeline.generate_model_scatter_plots(
|
||||
training_csv_path=str(training_csv),
|
||||
models_dir=str(models_dir),
|
||||
)
|
||||
count = len(scatter_paths) if scatter_paths else 0
|
||||
parts.append(f"散点图: {count} 个")
|
||||
else:
|
||||
parts.append("散点图: 跳过(无模型目录)")
|
||||
else:
|
||||
parts.append("散点图: 跳过(无训练数据)")
|
||||
|
||||
# 生成光谱图
|
||||
if self.extra.get("gen_spectrum"):
|
||||
if training_csv.is_file():
|
||||
import pandas as pd
|
||||
df = pd.read_csv(training_csv)
|
||||
# 推断水质参数列(光谱波段列之前的数值型列)
|
||||
wl_col = _viz_infer_wavelength_start_column(df)
|
||||
if isinstance(wl_col, str):
|
||||
idx = int(df.columns.get_loc(wl_col)) + 1
|
||||
else:
|
||||
idx = int(wl_col)
|
||||
param_cols = []
|
||||
if idx > 0 and idx < len(df.columns):
|
||||
param_cols = [
|
||||
c for c in df.columns[:idx]
|
||||
if df[c].dtype.kind in 'iuf' and df[c].notna().sum() > 0
|
||||
]
|
||||
if param_cols:
|
||||
# plot_spectrum_by_parameter 接受单个参数列,逐个调用
|
||||
spectrum_paths = []
|
||||
for param_col in param_cols:
|
||||
try:
|
||||
path = viz.plot_spectrum_by_parameter(
|
||||
csv_path=str(training_csv),
|
||||
parameter_column=param_col,
|
||||
wavelength_start_column=wl_col,
|
||||
n_groups=5,
|
||||
)
|
||||
if path:
|
||||
spectrum_paths.append(path)
|
||||
except Exception as e:
|
||||
print(f"生成光谱图失败 ({param_col}): {e}")
|
||||
count = len(spectrum_paths)
|
||||
parts.append(f"光谱图: {count} 个")
|
||||
else:
|
||||
parts.append("光谱图: 跳过(无可用参数列)")
|
||||
else:
|
||||
parts.append("光谱图: 跳过(无训练数据)")
|
||||
|
||||
# 生成统计图
|
||||
if self.extra.get("gen_boxplots"):
|
||||
if training_csv.is_file():
|
||||
import pandas as pd
|
||||
df = pd.read_csv(training_csv)
|
||||
# **只统计水质参数列(数值型),排除波长列**
|
||||
# 获取水质参数列(数值型且不是波长、不是坐标列)
|
||||
exclude_cols = ['longitude', 'latitude', 'lon', 'lat', 'x', 'y', 'coord', 'coordinate']
|
||||
param_cols = [
|
||||
c for c in df.select_dtypes(include=[np.number]).columns
|
||||
if not any(exc in c.lower() for exc in exclude_cols)
|
||||
]
|
||||
# 排除光谱波长列:找到波长开始位置,只取之前的数值列
|
||||
wl = _viz_infer_wavelength_start_column(df)
|
||||
if isinstance(wl, str):
|
||||
idx = int(df.columns.get_loc(wl)) + 1
|
||||
else:
|
||||
idx = int(wl)
|
||||
if 0 < idx < len(df.columns):
|
||||
meta_set = set(df.columns[:idx])
|
||||
param_cols = [c for c in param_cols if c in meta_set]
|
||||
|
||||
if param_cols:
|
||||
output_dict = viz.plot_statistical_charts(
|
||||
csv_path=str(training_csv),
|
||||
parameter_columns=param_cols,
|
||||
)
|
||||
# plot_statistical_charts 返回字典,统计值非空
|
||||
count = len([v for v in output_dict.values() if v]) if output_dict else 0
|
||||
parts.append(f"统计图: {count} 个")
|
||||
else:
|
||||
parts.append("统计图: 跳过(无可用水质参数列)")
|
||||
else:
|
||||
parts.append("统计图: 跳过(无训练数据)")
|
||||
|
||||
# 生成掩膜/耀斑预览图
|
||||
if self.extra.get("gen_mask_glint"):
|
||||
preview_paths = viz.generate_glint_deglint_previews(
|
||||
work_dir=str(wp),
|
||||
output_subdir="glint_deglint_previews",
|
||||
)
|
||||
parts.append(f"掩膜/耀斑预览 {len(preview_paths) if preview_paths else 0} 个")
|
||||
parts.append(f"掩膜/耀斑预览: {len(preview_paths) if preview_paths else 0} 个")
|
||||
|
||||
# 生成采样点地图
|
||||
if self.extra.get("gen_sampling_map"):
|
||||
hyperspectral_files = []
|
||||
deglint_dir = wp / "3_deglint"
|
||||
@ -3103,35 +3202,37 @@ class Step8_75Panel(QWidget):
|
||||
)
|
||||
layout.addWidget(self.sampling_csv_file)
|
||||
|
||||
# 公式CSV文件选择
|
||||
# 自定义回归模型目录选择(9_Custom_Regression_Modeling)
|
||||
self.regression_models_dir = FileSelectWidget(
|
||||
"回归模型目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.regression_models_dir.label.setText("回归模型目录:")
|
||||
# 修改浏览按钮为选择目录
|
||||
self.regression_models_dir.browse_btn.clicked.disconnect()
|
||||
self.regression_models_dir.browse_btn.clicked.connect(self.browse_regression_models_dir)
|
||||
self.regression_models_dir.set_path("9_Custom_Regression_Modeling") # 设置默认值
|
||||
layout.addWidget(self.regression_models_dir)
|
||||
|
||||
# 公式CSV文件选择(用于查找index_formula)
|
||||
self.formula_csv_file = FileSelectWidget(
|
||||
"公式CSV文件:",
|
||||
"CSV Files (*.csv);;All Files (*.*)"
|
||||
)
|
||||
self.formula_csv_file.label.setText("公式CSV文件:")
|
||||
layout.addWidget(self.formula_csv_file)
|
||||
|
||||
# 模型目录选择
|
||||
self.models_dir_file = FileSelectWidget(
|
||||
"模型目录:",
|
||||
# 输出目录选择
|
||||
self.output_dir_widget = FileSelectWidget(
|
||||
"输出目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.models_dir_file.label.setText("模型目录:")
|
||||
self.output_dir_widget.label.setText("输出目录:")
|
||||
# 修改浏览按钮为选择目录
|
||||
self.models_dir_file.browse_btn.clicked.disconnect()
|
||||
self.models_dir_file.browse_btn.clicked.connect(self.browse_models_dir)
|
||||
layout.addWidget(self.models_dir_file)
|
||||
|
||||
# 参数设置
|
||||
params_group = QGroupBox("预测参数")
|
||||
params_layout = QFormLayout()
|
||||
|
||||
# 预测列名
|
||||
self.prediction_column = QLineEdit()
|
||||
self.prediction_column.setText("prediction")
|
||||
params_layout.addRow("预测列名:", self.prediction_column)
|
||||
|
||||
params_group.setLayout(params_layout)
|
||||
layout.addWidget(params_group)
|
||||
self.output_dir_widget.browse_btn.clicked.disconnect()
|
||||
self.output_dir_widget.browse_btn.clicked.connect(self.browse_output_dir)
|
||||
self.output_dir_widget.line_edit.setPlaceholderText("留空使用默认prediction目录")
|
||||
layout.addWidget(self.output_dir_widget)
|
||||
|
||||
# 启用步骤
|
||||
self.enable_checkbox = QCheckBox("启用此步骤")
|
||||
@ -3162,45 +3263,59 @@ class Step8_75Panel(QWidget):
|
||||
layout.addStretch()
|
||||
self.setLayout(layout)
|
||||
|
||||
def browse_models_dir(self):
|
||||
"""浏览模型目录"""
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择模型目录", "")
|
||||
def browse_regression_models_dir(self):
|
||||
"""浏览回归模型目录"""
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择回归模型目录", "")
|
||||
if dir_path:
|
||||
self.models_dir_file.set_path(dir_path)
|
||||
self.regression_models_dir.set_path(dir_path)
|
||||
|
||||
def browse_output_dir(self):
|
||||
"""浏览输出目录"""
|
||||
dir_path = QFileDialog.getExistingDirectory(self, "选择输出目录", "")
|
||||
if dir_path:
|
||||
self.output_dir_widget.set_path(dir_path)
|
||||
|
||||
def get_config(self):
|
||||
"""获取配置"""
|
||||
config = {
|
||||
'prediction_column': self.prediction_column.text(),
|
||||
'enabled': self.enable_checkbox.isChecked()
|
||||
}
|
||||
|
||||
# 添加采样光谱CSV路径
|
||||
sampling_csv_path = self.sampling_csv_file.get_path()
|
||||
if sampling_csv_path:
|
||||
config['sampling_csv_path'] = sampling_csv_path
|
||||
|
||||
# 添加回归模型目录路径
|
||||
regression_models_dir = self.regression_models_dir.get_path()
|
||||
if regression_models_dir:
|
||||
config['custom_regression_dir'] = regression_models_dir
|
||||
|
||||
# 添加公式CSV文件路径
|
||||
formula_csv_path = self.formula_csv_file.get_path()
|
||||
if formula_csv_path:
|
||||
config['formula_csv_file'] = formula_csv_path
|
||||
# 添加模型目录路径
|
||||
models_dir = self.models_dir_file.get_path()
|
||||
if models_dir:
|
||||
config['custom_regression_dir'] = models_dir
|
||||
config['formula_csv_path'] = formula_csv_path
|
||||
|
||||
# 添加输出目录路径
|
||||
output_dir = self.output_dir_widget.get_path()
|
||||
if output_dir:
|
||||
config['output_dir'] = output_dir
|
||||
|
||||
return config
|
||||
|
||||
def set_config(self, config):
|
||||
"""设置配置"""
|
||||
if 'prediction_column' in config:
|
||||
self.prediction_column.setText(config['prediction_column'])
|
||||
|
||||
if 'sampling_csv_path' in config:
|
||||
self.sampling_csv_file.set_path(config['sampling_csv_path'])
|
||||
|
||||
if 'formula_csv_file' in config:
|
||||
self.formula_csv_file.set_path(config['formula_csv_file'])
|
||||
|
||||
if 'custom_regression_dir' in config:
|
||||
self.models_dir_file.set_path(config['custom_regression_dir'])
|
||||
self.regression_models_dir.set_path(config['custom_regression_dir'])
|
||||
|
||||
if 'formula_csv_path' in config:
|
||||
self.formula_csv_file.set_path(config['formula_csv_path'])
|
||||
|
||||
if 'output_dir' in config:
|
||||
self.output_dir_widget.set_path(config['output_dir'])
|
||||
|
||||
if 'enabled' in config:
|
||||
self.enable_checkbox.setChecked(config['enabled'])
|
||||
@ -3213,9 +3328,9 @@ class Step8_75Panel(QWidget):
|
||||
QMessageBox.warning(self, "输入错误", "请选择采样光谱CSV文件!")
|
||||
return
|
||||
|
||||
formula_csv_path = self.formula_csv_file.get_path()
|
||||
if not formula_csv_path:
|
||||
QMessageBox.warning(self, "输入错误", "请选择公式CSV文件!")
|
||||
regression_models_dir = self.regression_models_dir.get_path()
|
||||
if not regression_models_dir:
|
||||
QMessageBox.warning(self, "输入错误", "请选择回归模型目录!")
|
||||
return
|
||||
|
||||
# 获取配置
|
||||
@ -3323,8 +3438,7 @@ class ImageCategoryTree(QTreeWidget):
|
||||
("光谱分析", ["spectrum", "spectral", "band", "wavelength"], "📈"),
|
||||
("统计图表", ["boxplot", "histogram", "heatmap", "statistics", "stats"], "📉"),
|
||||
("处理结果", ["mask", "glint", "deglint", "preview", "overlay", "water_mask"], "🖼️"),
|
||||
("采样分析", ["sampling", "flight_path", "point_map", "trajectory"], "📍"),
|
||||
("其他图表", [], "📁"),
|
||||
("含量分布图", [], "📁"),
|
||||
]
|
||||
|
||||
def __init__(self, parent=None):
|
||||
@ -3376,7 +3490,7 @@ class ImageCategoryTree(QTreeWidget):
|
||||
|
||||
# 根据文件名关键词确定类别
|
||||
category = self._determine_category(file_path.name)
|
||||
category_item = self.category_items.get(category, self.category_items["其他图表"])
|
||||
category_item = self.category_items.get(category, self.category_items["含量分布图"])
|
||||
|
||||
# 创建图像项
|
||||
image_item = QTreeWidgetItem(category_item)
|
||||
@ -3394,7 +3508,7 @@ class ImageCategoryTree(QTreeWidget):
|
||||
if any(keyword in filename_lower for keyword in keywords):
|
||||
return category_name
|
||||
|
||||
return "其他图表"
|
||||
return "含量分布图"
|
||||
|
||||
def scan_directory(self, work_dir: str):
|
||||
"""扫描目录中的所有图像文件"""
|
||||
@ -3682,11 +3796,7 @@ class VisualizationPanel(QWidget):
|
||||
def _viz_set_busy(self, busy: bool):
|
||||
for w in (
|
||||
getattr(self, "gen_all_btn", None),
|
||||
getattr(self, "gen_scatter_btn", None),
|
||||
getattr(self, "gen_spectrum_btn", None),
|
||||
getattr(self, "gen_stats_btn", None),
|
||||
getattr(self, "gen_mask_glint_btn", None),
|
||||
getattr(self, "gen_sampling_map_btn", None),
|
||||
getattr(self, "scan_btn", None),
|
||||
):
|
||||
if w is not None:
|
||||
w.setEnabled(not busy)
|
||||
@ -3728,7 +3838,13 @@ class VisualizationPanel(QWidget):
|
||||
return [c for c in meta if pd.api.types.is_numeric_dtype(df[c])]
|
||||
|
||||
def _statistics_param_columns(self, df: pd.DataFrame) -> List[str]:
|
||||
"""统计图用的参数列;若存在光谱波段,则只统计波段前的字段。"""
|
||||
"""统计图用的参数列:**只统计水质参数列(数值型),排除波长列**。
|
||||
- 包括:数值型的水质参数(浓度、含量等)
|
||||
- 排除:光谱波长列(虽然也是数值型,但不是水质参数)
|
||||
- 排除:坐标列(UTM_X, UTM_Y, lat, lon等)
|
||||
若存在光谱波段,则只统计波段前的数值字段。
|
||||
"""
|
||||
# 选择数值类型列
|
||||
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
|
||||
wl = _viz_infer_wavelength_start_column(df)
|
||||
if isinstance(wl, str):
|
||||
@ -3737,6 +3853,7 @@ class VisualizationPanel(QWidget):
|
||||
idx = int(wl)
|
||||
coord_kw = ("utm", "lat", "lon")
|
||||
if 0 < idx < len(df.columns):
|
||||
# 只取波长开始之前的列(水质参数区域)
|
||||
meta_set = set(df.columns[:idx])
|
||||
return [
|
||||
col
|
||||
@ -3744,6 +3861,7 @@ class VisualizationPanel(QWidget):
|
||||
if col in meta_set and not any(x in str(col).lower() for x in coord_kw)
|
||||
]
|
||||
return [
|
||||
# 如果没有找到波长列,排除坐标相关列
|
||||
col
|
||||
for col in numeric_cols
|
||||
if not any(x in str(col).lower() for x in coord_kw + ("x", "y"))
|
||||
@ -3825,7 +3943,7 @@ class VisualizationPanel(QWidget):
|
||||
QMessageBox.warning(
|
||||
self,
|
||||
"提示",
|
||||
"未生成任何散点图。请确认 7_models 下已有各参数子目录及模型文件,"
|
||||
"未生成任何散点图。请确认 7_Supervised_Model_Training 下已有各参数子目录及模型文件,"
|
||||
"且训练 CSV 与建模时一致。",
|
||||
)
|
||||
elif t == "generate_all_selected":
|
||||
@ -3870,36 +3988,22 @@ class VisualizationPanel(QWidget):
|
||||
self.image_tree = ImageCategoryTree()
|
||||
self.image_tree.itemClicked.connect(self.on_tree_item_clicked)
|
||||
tree_layout.addWidget(self.image_tree)
|
||||
|
||||
# 生成按钮组
|
||||
gen_btn_layout = QHBoxLayout()
|
||||
self.gen_all_btn = QPushButton("🚀 生成全部")
|
||||
self.gen_all_btn.setToolTip("生成所有类型的可视化图表")
|
||||
self.gen_all_btn.setStyleSheet("background-color: #4CAF50; color: white; font-weight: bold;")
|
||||
self.gen_all_btn.clicked.connect(self.generate_all_visualizations)
|
||||
gen_btn_layout.addWidget(self.gen_all_btn)
|
||||
|
||||
self.scan_btn = QPushButton("📁 扫描")
|
||||
self.scan_btn.setToolTip("扫描工作目录中的图像文件")
|
||||
self.scan_btn.clicked.connect(self.scan_work_directory)
|
||||
gen_btn_layout.addWidget(self.scan_btn)
|
||||
|
||||
tree_layout.addLayout(gen_btn_layout)
|
||||
|
||||
tree_group.setLayout(tree_layout)
|
||||
left_layout.addWidget(tree_group, 1)
|
||||
|
||||
|
||||
# 可视化配置
|
||||
config_group = QGroupBox("可视化配置")
|
||||
config_layout = QVBoxLayout()
|
||||
|
||||
|
||||
self.gen_scatter = QCheckBox("模型评估散点图")
|
||||
self.gen_scatter.setChecked(True)
|
||||
config_layout.addWidget(self.gen_scatter)
|
||||
|
||||
|
||||
self.gen_spectrum = QCheckBox("光谱曲线图")
|
||||
self.gen_spectrum.setChecked(True)
|
||||
config_layout.addWidget(self.gen_spectrum)
|
||||
|
||||
|
||||
self.gen_boxplots = QCheckBox("统计图表")
|
||||
self.gen_boxplots.setChecked(True)
|
||||
config_layout.addWidget(self.gen_boxplots)
|
||||
@ -3912,6 +4016,27 @@ class VisualizationPanel(QWidget):
|
||||
self.gen_sampling_map.setChecked(True)
|
||||
config_layout.addWidget(self.gen_sampling_map)
|
||||
|
||||
# 添加分隔线
|
||||
config_layout.addSpacing(10)
|
||||
line = QFrame()
|
||||
line.setFrameShape(QFrame.HLine)
|
||||
line.setStyleSheet("color: #ddd;")
|
||||
config_layout.addWidget(line)
|
||||
config_layout.addSpacing(10)
|
||||
|
||||
# 生成全部按钮
|
||||
self.gen_all_btn = QPushButton("🚀 生成全部")
|
||||
self.gen_all_btn.setToolTip("生成所有类型的可视化图表")
|
||||
self.gen_all_btn.setStyleSheet("background-color: #4CAF50; color: white; font-weight: bold;")
|
||||
self.gen_all_btn.clicked.connect(self.generate_all_visualizations)
|
||||
config_layout.addWidget(self.gen_all_btn)
|
||||
|
||||
# 扫描按钮
|
||||
self.scan_btn = QPushButton("📁 扫描目录")
|
||||
self.scan_btn.setToolTip("扫描工作目录中的图像文件")
|
||||
self.scan_btn.clicked.connect(self.scan_work_directory)
|
||||
config_layout.addWidget(self.scan_btn)
|
||||
|
||||
config_group.setLayout(config_layout)
|
||||
left_layout.addWidget(config_group)
|
||||
|
||||
@ -3928,43 +4053,7 @@ class VisualizationPanel(QWidget):
|
||||
self.image_viewer = ImageViewerWidget()
|
||||
self.image_viewer.refresh_btn.clicked.connect(self.scan_work_directory)
|
||||
right_layout.addWidget(self.image_viewer, 1)
|
||||
|
||||
# 生成特定图表按钮组
|
||||
specific_group = QGroupBox("生成特定图表")
|
||||
specific_layout = QHBoxLayout()
|
||||
|
||||
self.gen_scatter_btn = QPushButton("📊 散点图")
|
||||
self.gen_scatter_btn.setToolTip(
|
||||
"基于工作目录下 5_training_spectra/training_spectra.csv 与 7_models 生成模型评估散点图"
|
||||
)
|
||||
self.gen_scatter_btn.clicked.connect(lambda: self.generate_chart('scatter'))
|
||||
specific_layout.addWidget(self.gen_scatter_btn)
|
||||
|
||||
self.gen_spectrum_btn = QPushButton("📈 光谱图")
|
||||
self.gen_spectrum_btn.setToolTip(
|
||||
"基于 5_training_spectra/training_spectra.csv,为每个数值型水质参数各生成一张光谱对比图(无需选择)"
|
||||
)
|
||||
self.gen_spectrum_btn.clicked.connect(lambda: self.generate_chart('spectrum'))
|
||||
specific_layout.addWidget(self.gen_spectrum_btn)
|
||||
|
||||
self.gen_stats_btn = QPushButton("📉 统计图")
|
||||
self.gen_stats_btn.setToolTip(
|
||||
"基于工作目录下 5_training_spectra/training_spectra.csv 生成箱线图、直方图与相关性热力图"
|
||||
)
|
||||
self.gen_stats_btn.clicked.connect(lambda: self.generate_chart('statistics'))
|
||||
specific_layout.addWidget(self.gen_stats_btn)
|
||||
|
||||
self.gen_mask_glint_btn = QPushButton("🖼️ 掩膜图")
|
||||
self.gen_mask_glint_btn.clicked.connect(lambda: self.generate_mask_glint_previews())
|
||||
specific_layout.addWidget(self.gen_mask_glint_btn)
|
||||
|
||||
self.gen_sampling_map_btn = QPushButton("📍 采样点图")
|
||||
self.gen_sampling_map_btn.clicked.connect(lambda: self.generate_sampling_point_map())
|
||||
specific_layout.addWidget(self.gen_sampling_map_btn)
|
||||
|
||||
specific_group.setLayout(specific_layout)
|
||||
right_layout.addWidget(specific_group)
|
||||
|
||||
right_panel.setLayout(right_layout)
|
||||
main_layout.addWidget(right_panel, 1)
|
||||
|
||||
@ -3999,6 +4088,9 @@ class VisualizationPanel(QWidget):
|
||||
print(f"扫描工作目录: {work_path}")
|
||||
self.image_tree.scan_directory(str(work_path))
|
||||
|
||||
# 设置三个预测步骤的默认输出路径
|
||||
self._setup_prediction_output_dirs(work_path)
|
||||
|
||||
# 如果有图像,自动选择第一个
|
||||
viz_dir = work_path / "14_visualization"
|
||||
if viz_dir.exists():
|
||||
@ -4006,6 +4098,41 @@ class VisualizationPanel(QWidget):
|
||||
if image_files:
|
||||
self.image_viewer.load_image(str(image_files[0]))
|
||||
|
||||
def _setup_prediction_output_dirs(self, work_path: Path):
|
||||
"""
|
||||
设置三个预测步骤的默认输出目录
|
||||
在11_12_13_predictions下创建三个子文件夹
|
||||
"""
|
||||
try:
|
||||
# 基础预测目录
|
||||
base_prediction_dir = work_path / "11_12_13_predictions"
|
||||
|
||||
# 三个子文件夹路径
|
||||
ml_dir = base_prediction_dir / "Machine_Learning_Prediction"
|
||||
reg_dir = base_prediction_dir / "Regression_Model_Prediction"
|
||||
custom_dir = base_prediction_dir / "Custom_Regression_Prediction"
|
||||
|
||||
# 创建目录(如果不存在)
|
||||
ml_dir.mkdir(parents=True, exist_ok=True)
|
||||
reg_dir.mkdir(parents=True, exist_ok=True)
|
||||
custom_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 设置Step8Panel(机器学习预测)的默认输出路径
|
||||
if hasattr(self, 'step8_panel') and hasattr(self.step8_panel, 'output_file'):
|
||||
self.step8_panel.output_file.set_path(str(ml_dir))
|
||||
|
||||
# 设置Step8_5Panel(回归模型预测)的默认输出路径
|
||||
if hasattr(self, 'step8_5_panel') and hasattr(self.step8_5_panel, 'output_file'):
|
||||
self.step8_5_panel.output_file.set_path(str(reg_dir))
|
||||
|
||||
# 设置Step8_75Panel(自定义回归预测)的默认输出路径
|
||||
if hasattr(self, 'step8_75_panel') and hasattr(self.step8_75_panel, 'output_dir_widget'):
|
||||
self.step8_75_panel.output_dir_widget.set_path(str(custom_dir))
|
||||
|
||||
print(f"预测输出目录已设置:\n ML: {ml_dir}\n Reg: {reg_dir}\n Custom: {custom_dir}")
|
||||
except Exception as e:
|
||||
print(f"设置预测输出目录失败: {e}")
|
||||
|
||||
def on_tree_item_clicked(self, item, column):
|
||||
"""目录树项点击事件"""
|
||||
data = item.data(0, Qt.UserRole)
|
||||
@ -4022,37 +4149,37 @@ class VisualizationPanel(QWidget):
|
||||
if not self.work_dir:
|
||||
QMessageBox.warning(self, "警告", "请先选择工作目录!")
|
||||
return
|
||||
|
||||
|
||||
work_path = Path(self.work_dir)
|
||||
if not work_path.exists():
|
||||
QMessageBox.warning(self, "警告", "工作目录不存在!")
|
||||
return
|
||||
|
||||
reply = QMessageBox.question(
|
||||
self, "确认生成",
|
||||
"将按左侧勾选项在后台生成可视化(掩膜/耀斑预览、采样点图等),可能需要较长时间。\n是否继续?",
|
||||
QMessageBox.Yes | QMessageBox.No
|
||||
)
|
||||
|
||||
if reply != QMessageBox.Yes:
|
||||
return
|
||||
|
||||
if self.gen_scatter.isChecked():
|
||||
print("生成散点图...(占位,请用建模/可视化流程生成)")
|
||||
if self.gen_spectrum.isChecked():
|
||||
print("生成光谱图...(占位,请用下方「光谱图」按钮)")
|
||||
if self.gen_boxplots.isChecked():
|
||||
print("生成统计图...(占位,请用下方「统计图」按钮)")
|
||||
|
||||
if not self.gen_mask_glint.isChecked() and not self.gen_sampling_map.isChecked():
|
||||
# 检查是否有任何选项被勾选
|
||||
if not (self.gen_scatter.isChecked() or self.gen_spectrum.isChecked() or
|
||||
self.gen_boxplots.isChecked() or self.gen_mask_glint.isChecked() or
|
||||
self.gen_sampling_map.isChecked()):
|
||||
QMessageBox.information(
|
||||
self,
|
||||
"提示",
|
||||
"请至少勾选「掩膜和耀斑缩略图」或「采样点地图」以执行后台批量任务。",
|
||||
"请至少勾选一项可视化配置选项以生成图表。",
|
||||
)
|
||||
return
|
||||
|
||||
reply = QMessageBox.question(
|
||||
self, "确认生成",
|
||||
"将根据左侧勾选项在后台生成可视化图表,可能需要较长时间。\n是否继续?",
|
||||
QMessageBox.Yes | QMessageBox.No
|
||||
)
|
||||
|
||||
if reply != QMessageBox.Yes:
|
||||
return
|
||||
|
||||
# 收集所有选中的任务
|
||||
extra = {
|
||||
"gen_scatter": self.gen_scatter.isChecked(),
|
||||
"gen_spectrum": self.gen_spectrum.isChecked(),
|
||||
"gen_boxplots": self.gen_boxplots.isChecked(),
|
||||
"gen_mask_glint": self.gen_mask_glint.isChecked(),
|
||||
"gen_sampling_map": self.gen_sampling_map.isChecked(),
|
||||
}
|
||||
@ -4082,7 +4209,7 @@ class VisualizationPanel(QWidget):
|
||||
)
|
||||
return
|
||||
training_csv = training_spectra_csv
|
||||
models_dir = work_path / "7_models"
|
||||
models_dir = work_path / "7_Supervised_Model_Training"
|
||||
if not models_dir.is_dir() or not any(
|
||||
d.is_dir() for d in models_dir.iterdir()
|
||||
):
|
||||
@ -4284,7 +4411,6 @@ class VisualizationPanel(QWidget):
|
||||
'generate_scatter': self.gen_scatter.isChecked(),
|
||||
'generate_boxplots': self.gen_boxplots.isChecked(),
|
||||
'generate_spectrum': self.gen_spectrum.isChecked(),
|
||||
'generate_statistics': self.gen_stats_btn.isChecked(),
|
||||
'generate_glint_previews': self.gen_mask_glint.isChecked(),
|
||||
'generate_sampling_maps': self.gen_sampling_map.isChecked(),
|
||||
'scatter_config': {
|
||||
@ -4314,8 +4440,6 @@ class VisualizationPanel(QWidget):
|
||||
self.gen_boxplots.setChecked(config['generate_boxplots'])
|
||||
if 'generate_spectrum' in config:
|
||||
self.gen_spectrum.setChecked(config['generate_spectrum'])
|
||||
if 'generate_statistics' in config:
|
||||
self.gen_stats_btn.setChecked(config['generate_statistics'])
|
||||
if 'generate_glint_previews' in config:
|
||||
self.gen_mask_glint.setChecked(config['generate_glint_previews'])
|
||||
if 'generate_sampling_maps' in config:
|
||||
@ -4755,7 +4879,7 @@ class Step6_5Panel(QWidget):
|
||||
"输出模型目录:",
|
||||
"Directories;;All Files (*.*)"
|
||||
)
|
||||
self.output_dir.line_edit.setPlaceholderText("8_non_empirical_models")
|
||||
self.output_dir.line_edit.setPlaceholderText("8_Regression_Modeling")
|
||||
# 修改浏览按钮为选择目录
|
||||
self.output_dir.browse_btn.clicked.disconnect()
|
||||
self.output_dir.browse_btn.clicked.connect(self.browse_output_dir)
|
||||
@ -4825,9 +4949,9 @@ class Step6_5Panel(QWidget):
|
||||
# 如果output_dir为空,使用工作目录或当前目录
|
||||
main_window = self.parent().window()
|
||||
if hasattr(main_window, 'work_dir') and main_window.work_dir:
|
||||
output_dir = str(Path(main_window.work_dir) / "8_non_empirical_models")
|
||||
output_dir = str(Path(main_window.work_dir) / "8_Regression_Modeling")
|
||||
else:
|
||||
output_dir = str(Path.cwd() / "8_non_empirical_models")
|
||||
output_dir = str(Path.cwd() / "8_Regression_Modeling")
|
||||
config['output_dir'] = output_dir
|
||||
|
||||
# 添加训练数据路径(用于独立运行)
|
||||
@ -4952,7 +5076,8 @@ class Step6_75Panel(QWidget):
|
||||
# 创建滚动区域来容纳自变量选择
|
||||
x_scroll = QScrollArea()
|
||||
x_scroll.setWidgetResizable(True)
|
||||
x_scroll.setMaximumHeight(200)
|
||||
x_scroll.setMinimumHeight(250)
|
||||
x_scroll.setMaximumHeight(350)
|
||||
|
||||
x_widget = QWidget()
|
||||
self.x_columns_layout = QGridLayout()
|
||||
@ -4982,7 +5107,8 @@ class Step6_75Panel(QWidget):
|
||||
# 创建滚动区域来容纳因变量选择
|
||||
y_scroll = QScrollArea()
|
||||
y_scroll.setWidgetResizable(True)
|
||||
y_scroll.setMaximumHeight(150)
|
||||
y_scroll.setMinimumHeight(200)
|
||||
y_scroll.setMaximumHeight(300)
|
||||
|
||||
y_widget = QWidget()
|
||||
self.y_columns_layout = QGridLayout()
|
||||
@ -5044,7 +5170,7 @@ class Step6_75Panel(QWidget):
|
||||
output_layout = QFormLayout()
|
||||
|
||||
self.output_dir = QLineEdit()
|
||||
self.output_dir.setText("9_custom_regression")
|
||||
self.output_dir.setText("9_Custom_Regression_Modeling")
|
||||
output_layout.addRow("输出目录名:", self.output_dir)
|
||||
|
||||
output_group.setLayout(output_layout)
|
||||
@ -5202,7 +5328,7 @@ class Step6_75Panel(QWidget):
|
||||
checkbox.setChecked(method in selected_methods)
|
||||
|
||||
if 'output_dir' in config:
|
||||
self.output_dir.setText(config['output_dir'] or "9_custom_regression")
|
||||
self.output_dir.setText(config['output_dir'] or "9_Custom_Regression_Modeling")
|
||||
if 'enabled' in config:
|
||||
self.enable_checkbox.setChecked(config['enabled'])
|
||||
|
||||
@ -5574,26 +5700,28 @@ class WaterQualityGUI(QMainWindow):
|
||||
self.step_list = QListWidget()
|
||||
self.step_list.setStyleSheet(ModernStylesheet.get_sidebar_stylesheet())
|
||||
|
||||
# 定义三阶段结构
|
||||
# 定义四阶段结构
|
||||
self.process_stages = {
|
||||
"阶段一:数据预处理": [
|
||||
"阶段一:影像预处理": [
|
||||
("step1", "1. 水域掩膜生成"),
|
||||
("step2", "2. 耀斑区域识别"),
|
||||
("step3", "3. 耀斑去除与修复"),
|
||||
("step4", "4. 数据标准化处理"),
|
||||
],
|
||||
"阶段二:特征提取与建模": [
|
||||
"阶段二:样本数据准备 ": [
|
||||
("step4", "4. 数据标准化处理"),
|
||||
("step5", "5. 光谱特征提取"),
|
||||
("step5_5", "6. 水质参数指数计算"),
|
||||
("step6", "7. 监督学习模型训练"),
|
||||
("step6_5", "8. 经验统计回归"),
|
||||
("step6_75", "9. 自定义回归模型"),
|
||||
],
|
||||
"阶段三:应用与可视化": [
|
||||
"阶段三:模型构建与训练": [
|
||||
("step6", "7. 机器学习模型训练"),
|
||||
("step6_5", "8. 回归模型训练"),
|
||||
("step6_75", "9. 自定义回归模型训练"),
|
||||
],
|
||||
"阶段四:预测与成果输出 ": [
|
||||
("step7", "10. 采样点布设"),
|
||||
("step8", "11. 基于监督学习预测"),
|
||||
("step8_5", "12. 基于统计回归预测"),
|
||||
("step8_75", "13. 基于自定义回归预测"),
|
||||
("step8", "11. 机器学习学习预测"),
|
||||
("step8_5", "12. 回归预测"),
|
||||
("step8_75", "13. 自定义回归预测"),
|
||||
("step9", "14. 专题图生成"),
|
||||
("step9_viz", "15. 可视化分析"),
|
||||
("step_report", "16. 分析报告生成"),
|
||||
|
||||
11311
src/gui/work_dir/10_sampling/sampling_spectra.csv
Normal file
11311
src/gui/work_dir/10_sampling/sampling_spectra.csv
Normal file
File diff suppressed because it is too large
Load Diff
BIN
src/gui/work_dir/10_sampling/sampling_spectra_valid_area.bsq
Normal file
BIN
src/gui/work_dir/10_sampling/sampling_spectra_valid_area.bsq
Normal file
Binary file not shown.
15
src/gui/work_dir/10_sampling/sampling_spectra_valid_area.hdr
Normal file
15
src/gui/work_dir/10_sampling/sampling_spectra_valid_area.hdr
Normal file
@ -0,0 +1,15 @@
|
||||
ENVI
|
||||
description = {
|
||||
work_dir\10_sampling\sampling_spectra_valid_area.bsq}
|
||||
samples = 11363
|
||||
lines = 10408
|
||||
bands = 1
|
||||
header offset = 0
|
||||
file type = ENVI Standard
|
||||
data type = 4
|
||||
interleave = bsq
|
||||
byte order = 0
|
||||
map info = {UTM, 1, 1, 600742.055, 4613386.65, 0.2, 0.2, 51, North,WGS-84}
|
||||
coordinate system string = {PROJCS["unnamed",GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Transverse_Mercator"],PARAMETER["False_Easting",500000.0],PARAMETER["False_Northing",0.0],PARAMETER["Central_Meridian",123.0],PARAMETER["Scale_Factor",0.9996],PARAMETER["Latitude_Of_Origin",0.0],UNIT["Meter",1.0]]}
|
||||
band names = {
|
||||
Band 1}
|
||||
BIN
src/gui/work_dir/1_water_mask/hsi_preview.png
Normal file
BIN
src/gui/work_dir/1_water_mask/hsi_preview.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.6 MiB |
BIN
src/gui/work_dir/1_water_mask/water_mask_from_shp.dat
Normal file
BIN
src/gui/work_dir/1_water_mask/water_mask_from_shp.dat
Normal file
Binary file not shown.
15
src/gui/work_dir/1_water_mask/water_mask_from_shp.hdr
Normal file
15
src/gui/work_dir/1_water_mask/water_mask_from_shp.hdr
Normal file
@ -0,0 +1,15 @@
|
||||
ENVI
|
||||
description = {
|
||||
work_dir\1_water_mask\water_mask_from_shp.dat}
|
||||
samples = 11363
|
||||
lines = 10408
|
||||
bands = 1
|
||||
header offset = 0
|
||||
file type = ENVI Standard
|
||||
data type = 4
|
||||
interleave = bsq
|
||||
byte order = 0
|
||||
map info = {UTM, 1, 1, 600742.055, 4613386.65, 0.2, 0.2, 51, North,WGS-84}
|
||||
coordinate system string = {PROJCS["unnamed",GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Transverse_Mercator"],PARAMETER["False_Easting",500000.0],PARAMETER["False_Northing",0.0],PARAMETER["Central_Meridian",123.0],PARAMETER["Scale_Factor",0.9996],PARAMETER["Latitude_Of_Origin",0.0],UNIT["Meter",1.0]]}
|
||||
band names = {
|
||||
Band 1}
|
||||
@ -0,0 +1,15 @@
|
||||
ENVI
|
||||
description = {
|
||||
work_dir\1_water_mask\water_mask_from_shp__tmp_delete.dat}
|
||||
samples = 7411
|
||||
lines = 5368
|
||||
bands = 1
|
||||
header offset = 0
|
||||
file type = ENVI Standard
|
||||
data type = 1
|
||||
interleave = bsq
|
||||
byte order = 0
|
||||
map info = {UTM, 1, 1, 600947.955, 4612951.75, 0.2, 0.2, 51, North,WGS-84}
|
||||
coordinate system string = {PROJCS["WGS_1984_UTM_Zone_51N",GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Transverse_Mercator"],PARAMETER["False_Easting",500000.0],PARAMETER["False_Northing",0.0],PARAMETER["Central_Meridian",123.0],PARAMETER["Scale_Factor",0.9996],PARAMETER["Latitude_Of_Origin",0.0],UNIT["Meter",1.0]]}
|
||||
band names = {
|
||||
Band 1}
|
||||
BIN
src/gui/work_dir/1_water_mask/water_mask_overlay.png
Normal file
BIN
src/gui/work_dir/1_water_mask/water_mask_overlay.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.6 MiB |
@ -281,11 +281,13 @@ class WaterQualityVisualization:
|
||||
def plot_statistical_charts(self, csv_path: str, parameter_columns: List[str],
|
||||
output_dir: Optional[str] = None) -> Dict[str, str]:
|
||||
"""
|
||||
绘制统计图表:箱线图、直方图、相关性热力图
|
||||
绘制统计图表:**只针对水质参数列**(数值型,排除波长列)
|
||||
- 水质参数列(如浓度、含量等数值型参数)使用箱线图/直方图/相关性热力图
|
||||
- 排除光谱波长列(虽然也是数值型,但不是水质参数)
|
||||
|
||||
Args:
|
||||
csv_path: CSV文件路径
|
||||
parameter_columns: 参数列名列表
|
||||
parameter_columns: **水质参数**列名列表(数值型,已排除波长列)
|
||||
output_dir: 输出目录
|
||||
|
||||
Returns:
|
||||
@ -301,12 +303,16 @@ class WaterQualityVisualization:
|
||||
|
||||
output_paths = {}
|
||||
|
||||
# 水质参数统计图表(针对数值型参数,排除波长列)
|
||||
# 假设传入的 parameter_columns 已经是过滤后的水质参数列
|
||||
numeric_cols = [col for col in parameter_columns if col in df.columns and pd.api.types.is_numeric_dtype(df[col])]
|
||||
|
||||
# 1. 箱线图
|
||||
if len(parameter_columns) > 0:
|
||||
if len(numeric_cols) > 0:
|
||||
fig, ax = plt.subplots(figsize=(12, 6))
|
||||
data_for_boxplot = [df[col].dropna() for col in parameter_columns if col in df.columns]
|
||||
data_for_boxplot = [df[col].dropna() for col in numeric_cols]
|
||||
if data_for_boxplot:
|
||||
ax.boxplot(data_for_boxplot, labels=[col for col in parameter_columns if col in df.columns])
|
||||
ax.boxplot(data_for_boxplot, labels=numeric_cols)
|
||||
ax.set_ylabel('数值', fontsize=12, fontweight='bold')
|
||||
ax.set_title('水质参数箱线图', fontsize=14, fontweight='bold')
|
||||
ax.grid(True, alpha=0.3, axis='y')
|
||||
@ -318,51 +324,51 @@ class WaterQualityVisualization:
|
||||
plt.close()
|
||||
output_paths['boxplot'] = str(boxplot_path)
|
||||
|
||||
# 2. 直方图
|
||||
for col in parameter_columns:
|
||||
if col not in df.columns:
|
||||
continue
|
||||
# 2. 直方图 (每个水质参数列)
|
||||
for col in numeric_cols:
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
data = df[col].dropna()
|
||||
ax.hist(data, bins=30, edgecolor='black', alpha=0.7, color='skyblue')
|
||||
ax.set_xlabel(f'{col} 数值', fontsize=12, fontweight='bold')
|
||||
ax.set_ylabel('频数', fontsize=12, fontweight='bold')
|
||||
ax.set_title(f'{col} 分布直方图', fontsize=14, fontweight='bold')
|
||||
ax.grid(True, alpha=0.3, axis='y')
|
||||
|
||||
# 添加统计信息
|
||||
mean_val = data.mean()
|
||||
std_val = data.std()
|
||||
ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'均值: {mean_val:.4f}')
|
||||
ax.legend()
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
safe_name = "".join(c for c in col if c.isalnum() or c in ('-', '_', '.'))
|
||||
hist_path = output_dir / f"{safe_name}_histogram.png"
|
||||
plt.savefig(hist_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
output_paths[f'histogram_{col}'] = str(hist_path)
|
||||
|
||||
# 3. 相关性热力图
|
||||
if len(parameter_columns) >= 2:
|
||||
valid_cols = [col for col in parameter_columns if col in df.columns]
|
||||
if len(valid_cols) >= 2:
|
||||
corr_matrix = df[valid_cols].corr()
|
||||
if len(data) > 1:
|
||||
ax.hist(data, bins=30, edgecolor='black', alpha=0.7, color='skyblue')
|
||||
ax.set_xlabel(f'{col} 数值', fontsize=12, fontweight='bold')
|
||||
ax.set_ylabel('频数', fontsize=12, fontweight='bold')
|
||||
ax.set_title(f'{col} 分布直方图', fontsize=14, fontweight='bold')
|
||||
ax.grid(True, alpha=0.3, axis='y')
|
||||
|
||||
# 添加统计信息
|
||||
mean_val = data.mean()
|
||||
std_val = data.std() if len(data) > 1 else 0
|
||||
ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'均值: {mean_val:.4f}')
|
||||
ax.legend()
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 8))
|
||||
sns.heatmap(corr_matrix, annot=True, fmt='.3f', cmap='coolwarm',
|
||||
center=0, square=True, linewidths=1, cbar_kws={"shrink": 0.8},
|
||||
ax=ax, vmin=-1, vmax=1)
|
||||
ax.set_title('水质参数相关性热力图', fontsize=14, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
|
||||
heatmap_path = output_dir / "correlation_heatmap.png"
|
||||
plt.savefig(heatmap_path, dpi=300, bbox_inches='tight')
|
||||
safe_name = "".join(c for c in col if c.isalnum() or c in ('-', '_', '.'))
|
||||
hist_path = output_dir / f"{safe_name}_histogram.png"
|
||||
plt.savefig(hist_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
output_paths['heatmap'] = str(heatmap_path)
|
||||
output_paths[f'histogram_{col}'] = str(hist_path)
|
||||
|
||||
print(f"统计图表已保存到: {output_dir}")
|
||||
# 3. 相关性热力图
|
||||
if len(numeric_cols) >= 2:
|
||||
corr_matrix = df[numeric_cols].corr()
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 8))
|
||||
sns.heatmap(corr_matrix, annot=True, fmt='.3f', cmap='coolwarm',
|
||||
center=0, square=True, linewidths=1, cbar_kws={"shrink": 0.8},
|
||||
ax=ax, vmin=-1, vmax=1)
|
||||
ax.set_title('水质参数相关性热力图', fontsize=14, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
|
||||
heatmap_path = output_dir / "correlation_heatmap.png"
|
||||
plt.savefig(heatmap_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
output_paths['heatmap'] = str(heatmap_path)
|
||||
|
||||
if not output_paths:
|
||||
print("警告: 没有生成任何统计图表(可能无合适的水质参数列)")
|
||||
else:
|
||||
print(f"统计图表已保存到: {output_dir},共 {len(output_paths)} 个文件")
|
||||
return output_paths
|
||||
|
||||
def plot_distribution_map_enhanced(self, prediction_csv_path: str,
|
||||
|
||||
@ -11,16 +11,24 @@ def MMS(input_spectrum):
|
||||
output_spectrum = MinMaxScaler().fit_transform(input_spectrum)
|
||||
return output_spectrum
|
||||
|
||||
# 标准化
|
||||
# 标准化 (StandardScaler)
|
||||
def SS(input_spectrum, save_path=None):
|
||||
"""标准化预处理,使用 StandardScaler 拟合并可选保存模型参数。
|
||||
|
||||
Args:
|
||||
input_spectrum: 输入光谱数据 (numpy array or DataFrame)
|
||||
save_path: scaler模型保存路径。如果提供则保存到该路径(推荐保存到7_Supervised_Model_Training目录)
|
||||
"""
|
||||
# 初始化 StandardScaler 并拟合数据
|
||||
scaler = StandardScaler()
|
||||
output_spectrum = scaler.fit_transform(input_spectrum)
|
||||
|
||||
# 如果指定了保存路径,保存 scaler 对象
|
||||
# 如果指定了保存路径,保存 scaler 对象(用于后续预测时加载)
|
||||
if save_path:
|
||||
import os
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
joblib.dump(scaler, save_path)
|
||||
print(f"Scaler parameters saved to {save_path}")
|
||||
print(f"SS Scaler parameters saved to: {save_path}")
|
||||
|
||||
return output_spectrum
|
||||
|
||||
@ -124,7 +132,14 @@ def wave(input_spectrum):
|
||||
return output_spectrum
|
||||
|
||||
# 通用预处理函数
|
||||
def Preprocessing(method, input_spectrum):
|
||||
def Preprocessing(method, input_spectrum, save_path=None):
|
||||
"""通用预处理函数
|
||||
|
||||
Args:
|
||||
method: 预处理方法名称
|
||||
input_spectrum: 输入光谱数据
|
||||
save_path: 可选的模型保存路径(仅SS方法使用,保存到7_Supervised_Model_Training目录)
|
||||
"""
|
||||
if isinstance(input_spectrum, np.ndarray):
|
||||
input_spectrum = pd.DataFrame(input_spectrum)
|
||||
if method == "None":
|
||||
@ -132,7 +147,11 @@ def Preprocessing(method, input_spectrum):
|
||||
elif method == 'MMS':
|
||||
output_spectrum = MMS(input_spectrum.values)
|
||||
elif method == 'SS':
|
||||
output_spectrum = SS(input_spectrum.values, r'E:\code\WQ\models/scaler_params.pkl')
|
||||
# SS预处理模型保存到工作目录的7_Supervised_Model_Training/scaler_params.pkl
|
||||
# 如果调用者没有提供save_path,则使用默认路径
|
||||
if not save_path:
|
||||
save_path = r'E:\code\WQ\models\scaler_params.pkl'
|
||||
output_spectrum = SS(input_spectrum.values, save_path)
|
||||
elif method == 'CT':
|
||||
output_spectrum = CT(input_spectrum.values)
|
||||
elif method == 'SNV':
|
||||
|
||||
@ -16,78 +16,46 @@ except ImportError:
|
||||
|
||||
def get_wavelengths_from_bil_header(bil_file):
|
||||
"""
|
||||
从BIL文件的头文件中读取波长信息
|
||||
从BIL文件的头文件中读取波长信息(使用spectral库)
|
||||
|
||||
参数:
|
||||
bil_file: str - BIL文件路径
|
||||
bil_file: str - BIL文件路径
|
||||
|
||||
返回:
|
||||
list - 波长列表,如果无法获取则返回None
|
||||
list - 波长列表,如果无法获取则返回None
|
||||
"""
|
||||
try:
|
||||
# 获取头文件路径(通常与BIL文件同目录,后缀为.hdr)
|
||||
# 获取头文件路径
|
||||
header_file = os.path.splitext(bil_file)[0] + ".hdr"
|
||||
|
||||
if not os.path.exists(header_file):
|
||||
print(f"警告: 找不到头文件 {header_file}")
|
||||
return None
|
||||
|
||||
wavelengths = []
|
||||
# 使用spectral库读取头文件
|
||||
import spectral.io.envi as envi
|
||||
header = envi.read_envi_header(header_file)
|
||||
|
||||
with open(header_file, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
# 获取波长信息
|
||||
wavelengths = header.get('wavelength', None)
|
||||
|
||||
if wavelengths is not None:
|
||||
# 确保是列表形式
|
||||
if isinstance(wavelengths, str):
|
||||
# 如果是字符串,解析为列表
|
||||
wavelengths = wavelengths.strip('{}').replace(',', ' ').split()
|
||||
wavelengths = [float(w.strip()) for w in wavelengths if w.strip()]
|
||||
|
||||
# 查找包含波长信息的行
|
||||
wavelength_lines = []
|
||||
in_wavelength_block = False
|
||||
# 过滤掉0值和无效值
|
||||
wavelengths = [float(w) for w in wavelengths if float(w) > 0]
|
||||
|
||||
for line in lines:
|
||||
stripped_line = line.strip()
|
||||
|
||||
# 检测波长块的开始(精确匹配 wavelength = )
|
||||
if stripped_line.startswith('wavelength ='):
|
||||
in_wavelength_block = True
|
||||
# 提取第一行的波长信息
|
||||
wavelength_str = stripped_line.replace('wavelength =', '').strip()
|
||||
if wavelength_str.startswith('{'):
|
||||
wavelength_str = wavelength_str[1:].strip()
|
||||
wavelength_lines.append(wavelength_str)
|
||||
# 检测波长块的中间行
|
||||
elif in_wavelength_block:
|
||||
if '}' in stripped_line:
|
||||
# 波长块结束
|
||||
end_str = stripped_line.replace('}', '').strip()
|
||||
if end_str:
|
||||
wavelength_lines.append(end_str)
|
||||
in_wavelength_block = False
|
||||
else:
|
||||
wavelength_lines.append(stripped_line)
|
||||
print(f"从头文件读取到 {len(wavelengths)} 个波长值")
|
||||
print(f"波长范围: {min(wavelengths):.2f} ~ {max(wavelengths):.2f}")
|
||||
return wavelengths
|
||||
else:
|
||||
print("警告: 头文件中未找到波长信息")
|
||||
return None
|
||||
|
||||
if wavelength_lines:
|
||||
# 合并所有波长行
|
||||
combined_wavelengths = ' '.join(wavelength_lines)
|
||||
# 移除所有花括号和逗号
|
||||
combined_wavelengths = combined_wavelengths.replace('{', '').replace('}', '').strip()
|
||||
|
||||
# 分割波长值(支持逗号和空格分隔)
|
||||
wavelength_values = []
|
||||
for part in combined_wavelengths.split(','):
|
||||
part = part.strip()
|
||||
if part:
|
||||
# 处理可能的多值情况(空格分隔)
|
||||
for value in part.split():
|
||||
if value.strip():
|
||||
try:
|
||||
wavelength_values.append(float(value.strip()))
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
print(f"从头文件读取到 {len(wavelength_values)} 个波长值")
|
||||
return wavelength_values
|
||||
else:
|
||||
print("警告: 头文件中未找到波长信息")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"读取头文件波长信息时发生错误: {str(e)}")
|
||||
return None
|
||||
@ -1021,10 +989,10 @@ def get_coor_base_interval(water_mask, severe_glint=None, output_csvpath=None, i
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
# 新功能使用示例
|
||||
bil_file = r"D:\BaiduNetdiskDownload\yaobao\result3.bsq"
|
||||
water_mask_shp = r"D:\BaiduNetdiskDownload\yaobao\roi\roi.shp"
|
||||
severe_glint = r"D:\BaiduNetdiskDownload\yaobao\find_glint\result3_glint_otsu"
|
||||
output_csvpath = r"D:\BaiduNetdiskDownload\yaobao\csv\spectral_sampling_results.csv"
|
||||
bil_file = r"E:\wq_gui_test\3_deglint\deglint_goodman.bsq"
|
||||
water_mask_shp = r"E:\wq_gui_test\1_water_mask\water_mask_from_shp.dat"
|
||||
severe_glint = r"E:\wq_gui_test\2_glint\severe_glint_area.dat"
|
||||
output_csvpath = r"E:\wq_gui_test\10_sampling\sampling_spectra.csv"
|
||||
|
||||
# 设置参数
|
||||
interval = 50 # 基础采样点间隔(像元数),当use_adaptive_sampling=False时使用
|
||||
|
||||
Reference in New Issue
Block a user