fix: 修复工作目录与步骤名不对应、回归预测虚数报错、模型加载及预处理名称转换问题,重构可视化并修正勾选联动
This commit is contained in:
654
src/core/prediction/custom_regression_prediction.py
Normal file
654
src/core/prediction/custom_regression_prediction.py
Normal file
@ -0,0 +1,654 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
自定义回归预测模块
|
||||
|
||||
该模块根据9_Custom_Regression_Modeling文件夹中的CSV信息,批量预测水质指数。
|
||||
处理流程:
|
||||
1. 读取9_Custom_Regression_Modeling文件夹中的CSV文件
|
||||
2. 根据r_squared选择最佳模型(指数公式+反演公式)
|
||||
3. 使用指数公式计算光谱指数值
|
||||
4. 使用反演公式计算水质参数值
|
||||
5. 输出包含投影经纬度和预测值的CSV文件
|
||||
|
||||
作者: Assistant
|
||||
创建时间: 2026-04-14
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Union, Tuple, Optional
|
||||
import warnings
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
# 导入水质指数计算器
|
||||
try:
|
||||
from src.utils.water_index import WaterQualityIndexCalculator
|
||||
except ImportError:
|
||||
from ..utils.water_index import WaterQualityIndexCalculator
|
||||
|
||||
|
||||
class CustomRegressionPredictor:
|
||||
"""
|
||||
自定义回归预测器
|
||||
|
||||
基于9_Custom_Regression_Modeling文件夹中的回归模型CSV文件,
|
||||
进行水质参数的批量预测。
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
regression_models_dir: str = "9_Custom_Regression_Modeling",
|
||||
formula_csv_path: Optional[str] = None,
|
||||
output_dir: str = "prediction_results",
|
||||
log_level: int = logging.INFO):
|
||||
"""
|
||||
初始化预测器
|
||||
|
||||
Args:
|
||||
regression_models_dir: 回归模型CSV文件所在目录
|
||||
formula_csv_path: 公式CSV文件路径,用于查找index_formula
|
||||
output_dir: 预测结果输出目录
|
||||
log_level: 日志级别
|
||||
"""
|
||||
self.regression_models_dir = Path(regression_models_dir)
|
||||
self.formula_csv_path = formula_csv_path
|
||||
self.output_dir = Path(output_dir)
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 初始化日志
|
||||
self._setup_logging(log_level)
|
||||
|
||||
# 初始化水质指数计算器
|
||||
self.index_calculator = WaterQualityIndexCalculator()
|
||||
|
||||
# 存储加载的回归模型信息
|
||||
self.regression_models = {}
|
||||
self.best_models = {}
|
||||
|
||||
# 加载公式CSV(如果提供)
|
||||
self.formula_df = None
|
||||
if self.formula_csv_path and Path(self.formula_csv_path).exists():
|
||||
self.formula_df = pd.read_csv(self.formula_csv_path)
|
||||
self.logger.info(f"加载公式CSV: {self.formula_csv_path}, 共 {len(self.formula_df)} 个公式")
|
||||
|
||||
self.logger.info(f"CustomRegressionPredictor初始化完成")
|
||||
self.logger.info(f"回归模型目录: {self.regression_models_dir}")
|
||||
self.logger.info(f"输出目录: {self.output_dir}")
|
||||
|
||||
def _setup_logging(self, log_level: int):
|
||||
"""设置日志配置"""
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.logger.setLevel(log_level)
|
||||
|
||||
# 避免重复添加处理器
|
||||
if not self.logger.handlers:
|
||||
# 创建控制台处理器
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(log_level)
|
||||
|
||||
# 创建格式器
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
console_handler.setFormatter(formatter)
|
||||
|
||||
self.logger.addHandler(console_handler)
|
||||
|
||||
def load_regression_models(self) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
加载9_Custom_Regression_Modeling文件夹中的所有CSV文件
|
||||
|
||||
支持的CSV格式:
|
||||
- 回归结果CSV包含列:y_variable, x_variable, equation, r_squared
|
||||
- 其中 x_variable 是指数名称,需要从 formula_csv 中查找对应的 Formula
|
||||
- equation 是反演公式
|
||||
|
||||
Returns:
|
||||
Dict[str, pd.DataFrame]: 参数名 -> 回归模型数据框
|
||||
"""
|
||||
if not self.regression_models_dir.exists():
|
||||
raise FileNotFoundError(f"回归模型目录不存在: {self.regression_models_dir}")
|
||||
|
||||
csv_files = list(self.regression_models_dir.glob("*.csv"))
|
||||
if not csv_files:
|
||||
raise FileNotFoundError(f"在目录 {self.regression_models_dir} 中未找到CSV文件")
|
||||
|
||||
self.logger.info(f"找到 {len(csv_files)} 个CSV文件")
|
||||
|
||||
for csv_file in csv_files:
|
||||
try:
|
||||
# 跳过all_regression_results.csv,处理单个参数的结果文件
|
||||
if csv_file.name.lower() == "all_regression_results.csv":
|
||||
self.logger.info(f"跳过汇总文件: {csv_file.name}")
|
||||
continue
|
||||
|
||||
# 从文件名推断参数名(移除_regression_results后缀)
|
||||
param_name = csv_file.stem.replace('_regression_results', '').replace('_results', '')
|
||||
self.logger.info(f"加载回归模型文件: {csv_file.name} -> 参数: {param_name}")
|
||||
|
||||
# 读取CSV文件
|
||||
df = pd.read_csv(csv_file)
|
||||
|
||||
# 检查必需的列并进行列名映射
|
||||
# 实际CSV列名: y_variable, x_variable, equation, r_squared
|
||||
# 需要的列名: index_formula, inversion_formula, r_squared
|
||||
|
||||
# 列名映射
|
||||
column_mapping = {}
|
||||
|
||||
# 检查r_squared(可能为R2, r2, R_squared等)
|
||||
r_squared_cols = ['r_squared', 'R_squared', 'R2', 'r2', 'R_squared', 'R²']
|
||||
found_r2 = None
|
||||
for col in r_squared_cols:
|
||||
if col in df.columns:
|
||||
found_r2 = col
|
||||
break
|
||||
if found_r2 and found_r2 != 'r_squared':
|
||||
column_mapping[found_r2] = 'r_squared'
|
||||
|
||||
# 检查equation(反演公式)
|
||||
if 'equation' in df.columns:
|
||||
column_mapping['equation'] = 'inversion_formula'
|
||||
elif 'inversion_formula' not in df.columns:
|
||||
self.logger.warning(f"文件 {csv_file.name} 缺少反演公式列(equation)")
|
||||
continue
|
||||
|
||||
# 检查x_variable(指数名称,需要转换为index_formula)
|
||||
if 'x_variable' in df.columns:
|
||||
# 需要从formula_csv中查找对应的公式
|
||||
if self.formula_df is not None:
|
||||
df['index_formula'] = df['x_variable'].apply(
|
||||
lambda x: self._lookup_formula_from_name(x)
|
||||
)
|
||||
else:
|
||||
# 如果没有formula_csv,直接使用x_variable作为index_formula
|
||||
column_mapping['x_variable'] = 'index_formula'
|
||||
elif 'index_formula' not in df.columns:
|
||||
self.logger.warning(f"文件 {csv_file.name} 缺少指数名称列(x_variable)")
|
||||
continue
|
||||
|
||||
# 应用列名映射
|
||||
if column_mapping:
|
||||
df = df.rename(columns=column_mapping)
|
||||
|
||||
# 验证必需的列
|
||||
required_columns = ['index_formula', 'inversion_formula', 'r_squared']
|
||||
missing_columns = [col for col in required_columns if col not in df.columns]
|
||||
if missing_columns:
|
||||
self.logger.warning(f"文件 {csv_file.name} 缺少必需列: {missing_columns}")
|
||||
continue
|
||||
|
||||
self.regression_models[param_name] = df
|
||||
self.logger.info(f"成功加载参数 {param_name} 的 {len(df)} 个回归模型")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"加载文件 {csv_file.name} 失败: {e}")
|
||||
continue
|
||||
|
||||
if not self.regression_models:
|
||||
raise ValueError("未成功加载任何回归模型")
|
||||
|
||||
return self.regression_models
|
||||
|
||||
def _lookup_formula_from_name(self, formula_name: str) -> str:
|
||||
"""
|
||||
从公式CSV中根据公式名称查找对应的公式
|
||||
|
||||
Args:
|
||||
formula_name: 公式名称(如 BGA_Am09KBBI)
|
||||
|
||||
Returns:
|
||||
公式字符串
|
||||
"""
|
||||
if self.formula_df is None:
|
||||
return formula_name
|
||||
|
||||
# 查找匹配的公式名称
|
||||
# 支持多种列名:Formula_Name, formula_name, name等
|
||||
name_cols = ['Formula_Name', 'formula_name', 'Name', 'name']
|
||||
name_col = None
|
||||
for col in name_cols:
|
||||
if col in self.formula_df.columns:
|
||||
name_col = col
|
||||
break
|
||||
|
||||
if name_col is None:
|
||||
self.logger.warning(f"公式CSV中未找到公式名称列,使用默认列名")
|
||||
name_col = self.formula_df.columns[0] if len(self.formula_df.columns) > 0 else None
|
||||
|
||||
# 查找公式列
|
||||
formula_cols = ['Formula', 'formula', 'Expression', 'expression']
|
||||
formula_col = None
|
||||
for col in formula_cols:
|
||||
if col in self.formula_df.columns:
|
||||
formula_col = col
|
||||
break
|
||||
|
||||
if formula_col is None and len(self.formula_df.columns) > 2:
|
||||
formula_col = self.formula_df.columns[2] # 通常第3列是公式
|
||||
|
||||
if name_col and formula_col:
|
||||
match = self.formula_df[self.formula_df[name_col] == formula_name]
|
||||
if not match.empty:
|
||||
return match[formula_col].iloc[0]
|
||||
|
||||
# 如果未找到,返回原始名称
|
||||
return formula_name
|
||||
|
||||
def select_best_models(self) -> Dict[str, pd.Series]:
|
||||
"""
|
||||
为每个水质参数选择r_squared最高的回归模型
|
||||
|
||||
Returns:
|
||||
Dict[str, pd.Series]: 参数名 -> 最佳模型信息
|
||||
"""
|
||||
if not self.regression_models:
|
||||
self.load_regression_models()
|
||||
|
||||
for param_name, models_df in self.regression_models.items():
|
||||
# 根据r_squared排序,选择最高的
|
||||
best_model = models_df.loc[models_df['r_squared'].idxmax()]
|
||||
self.best_models[param_name] = best_model
|
||||
|
||||
self.logger.info(f"参数 {param_name} 最佳模型:")
|
||||
self.logger.info(f" 指数公式: {best_model['index_formula']}")
|
||||
self.logger.info(f" 反演公式: {best_model['inversion_formula']}")
|
||||
self.logger.info(f" R²: {best_model['r_squared']:.4f}")
|
||||
|
||||
return self.best_models
|
||||
|
||||
def calculate_spectral_index(self,
|
||||
spectral_data: pd.DataFrame,
|
||||
index_formula: str) -> pd.Series:
|
||||
"""
|
||||
根据指数公式计算光谱指数值
|
||||
|
||||
Args:
|
||||
spectral_data: 光谱数据,列名为波长
|
||||
index_formula: 指数公式字符串
|
||||
|
||||
Returns:
|
||||
pd.Series: 计算得到的指数值
|
||||
"""
|
||||
try:
|
||||
# 检查公式是否为水质指数计算器中的标准函数
|
||||
if hasattr(self.index_calculator, index_formula):
|
||||
# 使用标准水质指数计算器
|
||||
calculator_func = getattr(self.index_calculator, index_formula)
|
||||
index_values = calculator_func(spectral_data)
|
||||
self.logger.debug(f"使用标准指数函数 {index_formula} 计算完成")
|
||||
return index_values
|
||||
|
||||
# 否则尝试解析自定义公式
|
||||
# 支持常见的指数公式格式,如 (R865-R644)/(R458+R529) 或 w681 - w665
|
||||
formula = index_formula.strip()
|
||||
|
||||
# 替换波长变量为对应的列值
|
||||
# 查找所有R或w开头的波长变量(支持 R681, w681, R_681 等格式)
|
||||
wavelength_pattern = r'[Rw](\d+)'
|
||||
wavelengths = re.findall(wavelength_pattern, formula)
|
||||
|
||||
# 构建可执行的表达式
|
||||
expression = formula
|
||||
for wl in set(wavelengths):
|
||||
wl_int = int(wl)
|
||||
# 找到最接近的波长列
|
||||
closest_col = self.index_calculator.find_closest_wavelength(
|
||||
spectral_data.columns.tolist(), wl_int
|
||||
)
|
||||
# 替换公式中的变量(支持 R数字 和 w数字 格式)
|
||||
expression = expression.replace(f'R{wl}', f'spectral_data["{closest_col}"]')
|
||||
expression = expression.replace(f'w{wl}', f'spectral_data["{closest_col}"]')
|
||||
|
||||
# 评估表达式
|
||||
index_values = eval(expression)
|
||||
self.logger.debug(f"使用自定义公式 {index_formula} 计算完成")
|
||||
|
||||
return pd.Series(index_values, index=spectral_data.index)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"计算光谱指数失败,公式: {index_formula}, 错误: {e}")
|
||||
# 返回NaN值
|
||||
return pd.Series(np.nan, index=spectral_data.index)
|
||||
|
||||
def apply_inversion_formula(self,
|
||||
index_values: pd.Series,
|
||||
inversion_formula: str) -> pd.Series:
|
||||
"""
|
||||
根据反演公式将指数值转换为水质参数值
|
||||
|
||||
Args:
|
||||
index_values: 光谱指数值
|
||||
inversion_formula: 反演公式字符串,如 "y = 11.27 + 0.82*x" 或 "11.27 + 0.82*x"
|
||||
支持的函数:ln(x)自然对数, log10(x)常用对数, exp(x)指数, sqrt(x)平方根
|
||||
|
||||
Returns:
|
||||
pd.Series: 水质参数预测值
|
||||
"""
|
||||
try:
|
||||
# 处理公式中的常见数学符号
|
||||
formula = inversion_formula.strip()
|
||||
|
||||
# 去除 "y = " 或 "y=" 前缀
|
||||
if formula.lower().startswith('y='):
|
||||
formula = formula[2:].strip()
|
||||
elif formula.lower().startswith('y = '):
|
||||
formula = formula[4:].strip()
|
||||
|
||||
# 替换数学符号
|
||||
formula = formula.replace('^', '**') # 幂运算
|
||||
|
||||
# 定义数学函数环境
|
||||
import math
|
||||
math_env = {
|
||||
'ln': math.log, # 自然对数
|
||||
'log': math.log, # 自然对数(别名)
|
||||
'log10': math.log10, # 常用对数
|
||||
'exp': math.exp, # 指数函数
|
||||
'sqrt': math.sqrt, # 平方根
|
||||
'abs': abs, # 绝对值
|
||||
'pow': pow, # 幂函数
|
||||
}
|
||||
|
||||
self.logger.debug(f"处理后的反演公式: {formula}")
|
||||
|
||||
# 评估每个指数值
|
||||
predicted_values = []
|
||||
for x_val in index_values:
|
||||
try:
|
||||
if pd.isna(x_val):
|
||||
predicted_values.append(np.nan)
|
||||
else:
|
||||
# 评估公式,将 x 作为局部变量传递,同时提供数学函数
|
||||
local_vars = {'x': x_val}
|
||||
local_vars.update(math_env)
|
||||
result = eval(formula, {"__builtins__": {}}, local_vars)
|
||||
|
||||
# 如果是复数,取实部
|
||||
if isinstance(result, complex):
|
||||
result = result.real
|
||||
|
||||
predicted_values.append(result)
|
||||
except Exception as eval_err:
|
||||
self.logger.debug(f"评估公式失败,x={x_val}: {eval_err}")
|
||||
predicted_values.append(np.nan)
|
||||
|
||||
self.logger.debug(f"使用反演公式 {inversion_formula} 计算完成")
|
||||
return pd.Series(predicted_values, index=index_values.index)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"应用反演公式失败,公式: {inversion_formula}, 错误: {e}")
|
||||
return pd.Series(np.nan, index=index_values.index)
|
||||
|
||||
def predict_single_parameter(self,
|
||||
spectral_data: pd.DataFrame,
|
||||
coordinate_data: pd.DataFrame,
|
||||
param_name: str) -> pd.DataFrame:
|
||||
"""
|
||||
预测单个水质参数
|
||||
|
||||
Args:
|
||||
spectral_data: 光谱数据(不包含坐标列)
|
||||
coordinate_data: 坐标数据(包含经纬度或投影坐标)
|
||||
param_name: 水质参数名
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: 包含坐标和预测值的结果
|
||||
"""
|
||||
if param_name not in self.best_models:
|
||||
raise ValueError(f"未找到参数 {param_name} 的最佳模型")
|
||||
|
||||
best_model = self.best_models[param_name]
|
||||
index_formula = best_model['index_formula']
|
||||
inversion_formula = best_model['inversion_formula']
|
||||
|
||||
self.logger.info(f"开始预测参数: {param_name}")
|
||||
|
||||
# 步骤1: 计算光谱指数
|
||||
self.logger.info("步骤1: 计算光谱指数")
|
||||
index_values = self.calculate_spectral_index(spectral_data, index_formula)
|
||||
|
||||
# 步骤2: 应用反演公式
|
||||
self.logger.info("步骤2: 应用反演公式")
|
||||
predicted_values = self.apply_inversion_formula(index_values, inversion_formula)
|
||||
|
||||
# 步骤3: 组合结果
|
||||
result_df = coordinate_data.copy()
|
||||
result_df[f'{param_name}_predicted'] = predicted_values
|
||||
result_df[f'{param_name}_index'] = index_values
|
||||
|
||||
# 添加元数据
|
||||
result_df.attrs = {
|
||||
'parameter': param_name,
|
||||
'index_formula': index_formula,
|
||||
'inversion_formula': inversion_formula,
|
||||
'r_squared': best_model['r_squared'],
|
||||
'prediction_time': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
valid_predictions = result_df[f'{param_name}_predicted'].notna().sum()
|
||||
total_points = len(result_df)
|
||||
self.logger.info(f"参数 {param_name} 预测完成: {valid_predictions}/{total_points} 个有效预测")
|
||||
|
||||
return result_df
|
||||
|
||||
def predict_all_parameters(self,
|
||||
spectral_data: pd.DataFrame,
|
||||
coordinate_data: pd.DataFrame) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
批量预测所有水质参数
|
||||
|
||||
Args:
|
||||
spectral_data: 光谱数据(不包含坐标列)
|
||||
coordinate_data: 坐标数据(包含经纬度或投影坐标)
|
||||
|
||||
Returns:
|
||||
Dict[str, pd.DataFrame]: 参数名 -> 预测结果DataFrame
|
||||
"""
|
||||
if not self.best_models:
|
||||
self.select_best_models()
|
||||
|
||||
prediction_results = {}
|
||||
|
||||
for param_name in self.best_models.keys():
|
||||
try:
|
||||
self.logger.info(f"正在预测参数: {param_name}")
|
||||
result_df = self.predict_single_parameter(
|
||||
spectral_data, coordinate_data, param_name
|
||||
)
|
||||
prediction_results[param_name] = result_df
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"预测参数 {param_name} 失败: {e}")
|
||||
continue
|
||||
|
||||
self.logger.info(f"批量预测完成,成功预测 {len(prediction_results)} 个参数")
|
||||
return prediction_results
|
||||
|
||||
def _convert_complex_to_real(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
将DataFrame中的复数转换为实数(取实部)
|
||||
|
||||
Args:
|
||||
df: 输入DataFrame
|
||||
|
||||
Returns:
|
||||
处理后的DataFrame
|
||||
"""
|
||||
df_clean = df.copy()
|
||||
for col in df_clean.columns:
|
||||
if df_clean[col].dtype == 'object' or 'complex' in str(df_clean[col].dtype):
|
||||
# 检查是否包含复数
|
||||
try:
|
||||
df_clean[col] = df_clean[col].apply(
|
||||
lambda x: x.real if isinstance(x, complex) else x
|
||||
)
|
||||
except:
|
||||
pass
|
||||
return df_clean
|
||||
|
||||
def save_prediction_results(self,
|
||||
prediction_results: Dict[str, pd.DataFrame],
|
||||
filename_prefix: str = "custom_regression_prediction") -> Dict[str, str]:
|
||||
"""
|
||||
保存预测结果为CSV文件
|
||||
|
||||
Args:
|
||||
prediction_results: 预测结果字典
|
||||
filename_prefix: 文件名前缀
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: 参数名 -> 保存的文件路径
|
||||
"""
|
||||
saved_files = {}
|
||||
|
||||
|
||||
for param_name, result_df in prediction_results.items():
|
||||
try:
|
||||
# 构建文件名
|
||||
filename = f"{filename_prefix}_{param_name}.csv"
|
||||
filepath = self.output_dir / filename
|
||||
|
||||
# 处理复数类型,转换为实数
|
||||
result_df_clean = self._convert_complex_to_real(result_df)
|
||||
|
||||
# 保存CSV文件
|
||||
result_df_clean.to_csv(filepath, index=False, encoding='utf-8-sig')
|
||||
saved_files[param_name] = str(filepath)
|
||||
|
||||
self.logger.info(f"参数 {param_name} 预测结果已保存: {filepath}")
|
||||
|
||||
# 打印统计信息
|
||||
predicted_col = f'{param_name}_predicted'
|
||||
if predicted_col in result_df.columns:
|
||||
valid_count = result_df[predicted_col].notna().sum()
|
||||
mean_value = result_df[predicted_col].mean()
|
||||
std_value = result_df[predicted_col].std()
|
||||
|
||||
self.logger.info(f" 有效预测: {valid_count} 个")
|
||||
self.logger.info(f" 平均值: {mean_value:.4f}")
|
||||
self.logger.info(f" 标准差: {std_value:.4f}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"保存参数 {param_name} 结果失败: {e}")
|
||||
continue
|
||||
|
||||
return saved_files
|
||||
|
||||
def run_batch_prediction(self,
|
||||
input_csv_path: str,
|
||||
coordinate_columns: List[str] = None,
|
||||
filename_prefix: str = "custom_regression_prediction") -> Dict[str, str]:
|
||||
"""
|
||||
运行完整的批量预测流程
|
||||
|
||||
Args:
|
||||
input_csv_path: 输入的光谱采样CSV文件路径
|
||||
coordinate_columns: 坐标列名列表,默认为['longitude', 'latitude']
|
||||
filename_prefix: 输出文件名前缀
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: 参数名 -> 保存的文件路径
|
||||
"""
|
||||
if coordinate_columns is None:
|
||||
coordinate_columns = ['longitude', 'latitude']
|
||||
|
||||
self.logger.info("="*50)
|
||||
self.logger.info("开始自定义回归批量预测")
|
||||
self.logger.info("="*50)
|
||||
|
||||
# 1. 加载回归模型
|
||||
self.logger.info("步骤1: 加载回归模型")
|
||||
self.load_regression_models()
|
||||
|
||||
# 2. 选择最佳模型
|
||||
self.logger.info("步骤2: 选择最佳模型")
|
||||
self.select_best_models()
|
||||
|
||||
# 3. 读取输入数据
|
||||
self.logger.info("步骤3: 读取输入数据")
|
||||
input_df = pd.read_csv(input_csv_path)
|
||||
self.logger.info(f"读取输入数据: {len(input_df)} 行, {len(input_df.columns)} 列")
|
||||
|
||||
# 4. 分离坐标和光谱数据
|
||||
self.logger.info("步骤4: 分离坐标和光谱数据")
|
||||
|
||||
# 检查坐标列是否存在
|
||||
missing_coords = [col for col in coordinate_columns if col not in input_df.columns]
|
||||
if missing_coords:
|
||||
# 尝试自动识别坐标列
|
||||
potential_coord_cols = []
|
||||
for col in input_df.columns:
|
||||
if any(coord_name in col.lower() for coord_name in ['lon', 'lat', 'x', 'y']):
|
||||
potential_coord_cols.append(col)
|
||||
|
||||
if len(potential_coord_cols) >= 2:
|
||||
coordinate_columns = potential_coord_cols[:2]
|
||||
self.logger.info(f"自动识别坐标列: {coordinate_columns}")
|
||||
else:
|
||||
raise ValueError(f"未找到坐标列: {missing_coords}")
|
||||
|
||||
coordinate_data = input_df[coordinate_columns].copy()
|
||||
spectral_data = input_df.drop(columns=coordinate_columns)
|
||||
|
||||
self.logger.info(f"坐标数据: {len(coordinate_data)} 行, {len(coordinate_data.columns)} 列")
|
||||
self.logger.info(f"光谱数据: {len(spectral_data)} 行, {len(spectral_data.columns)} 列")
|
||||
|
||||
# 5. 批量预测
|
||||
self.logger.info("步骤5: 执行批量预测")
|
||||
prediction_results = self.predict_all_parameters(spectral_data, coordinate_data)
|
||||
|
||||
# 6. 保存结果
|
||||
self.logger.info("步骤6: 保存预测结果")
|
||||
saved_files = self.save_prediction_results(prediction_results, filename_prefix)
|
||||
|
||||
self.logger.info("="*50)
|
||||
self.logger.info(f"批量预测完成!共生成 {len(saved_files)} 个结果文件")
|
||||
self.logger.info("="*50)
|
||||
|
||||
return saved_files
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数,用于命令行调用"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='自定义回归预测模块')
|
||||
parser.add_argument('--input_csv', required=True, help='输入的光谱采样CSV文件路径')
|
||||
parser.add_argument('--models_dir', default='9_Custom_Regression_Modeling',
|
||||
help='回归模型CSV文件目录')
|
||||
parser.add_argument('--output_dir', default='prediction_results',
|
||||
help='预测结果输出目录')
|
||||
parser.add_argument('--coord_cols', nargs='+', default=['longitude', 'latitude'],
|
||||
help='坐标列名')
|
||||
parser.add_argument('--prefix', default='custom_regression_prediction',
|
||||
help='输出文件名前缀')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 创建预测器
|
||||
predictor = CustomRegressionPredictor(
|
||||
regression_models_dir=args.models_dir,
|
||||
output_dir=args.output_dir
|
||||
)
|
||||
|
||||
# 运行预测
|
||||
saved_files = predictor.run_batch_prediction(
|
||||
input_csv_path=args.input_csv,
|
||||
coordinate_columns=args.coord_cols,
|
||||
filename_prefix=args.prefix
|
||||
)
|
||||
|
||||
print("\n预测结果文件:")
|
||||
for param_name, filepath in saved_files.items():
|
||||
print(f" {param_name}: {filepath}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -609,6 +609,11 @@ class WaterQualityScatterBatch:
|
||||
split_method = best_row['划分方法']
|
||||
preprocess_method = best_row['预处理方法']
|
||||
model_name = best_row['建模方法']
|
||||
|
||||
# 处理 nan/NaN/None 值,转换为 "None" 字符串
|
||||
if pd.isna(preprocess_method) or str(preprocess_method).lower() in ['nan', 'none', '']:
|
||||
preprocess_method = "None"
|
||||
|
||||
best_combination = f"{split_method}_{preprocess_method}_{model_name}"
|
||||
else:
|
||||
# 简化结果文件格式(英文列名)
|
||||
@ -617,11 +622,15 @@ class WaterQualityScatterBatch:
|
||||
parts = best_combination.split('_')
|
||||
if len(parts) < 3:
|
||||
raise ValueError(f"无效的模型组合名称格式: {best_combination}")
|
||||
|
||||
|
||||
split_method = parts[0]
|
||||
preprocess_method = parts[1]
|
||||
model_name = '_'.join(parts[2:])
|
||||
|
||||
# 处理 nan/NaN/None 值,转换为 "None" 字符串
|
||||
if pd.isna(preprocess_method) or str(preprocess_method).lower() in ['nan', 'none', '']:
|
||||
preprocess_method = "None"
|
||||
|
||||
print(f"最佳模型组合: {best_combination}")
|
||||
print(f" 划分方法: {split_method}")
|
||||
print(f" 预处理方法: {preprocess_method}")
|
||||
|
||||
Reference in New Issue
Block a user