内容部分修改
This commit is contained in:
@ -40,16 +40,19 @@ class WaterQualityInference:
|
||||
self.best_model_info = None
|
||||
self.loaded_model_data = None
|
||||
|
||||
def load_sampling_data(self, csv_path: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
|
||||
def load_sampling_data(self, csv_path: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
||||
"""
|
||||
加载sampling生成的CSV数据
|
||||
加载sampling生成的CSV数据(兼容 WQI 增强版 CSV)
|
||||
|
||||
Args:
|
||||
csv_path: CSV文件路径,前两列为经纬度,其余列为光谱数据
|
||||
csv_path: CSV文件路径
|
||||
旧版:x_coord,y_coord,pixel_x,pixel_y,波长...
|
||||
新版:x_coord,y_coord,WQI_...,波长...
|
||||
|
||||
Returns:
|
||||
coords: 经纬度数据 (DataFrame)
|
||||
spectra: 光谱数据 (DataFrame)
|
||||
coords: 经纬度数据 (DataFrame, 2列)
|
||||
spectra: 纯光谱数据 (DataFrame, 跳过 WQI 列)
|
||||
wqi_df: WQI 指数列 (DataFrame, 0或45列)
|
||||
"""
|
||||
print(f"正在加载采样数据: {csv_path}")
|
||||
|
||||
@ -71,15 +74,35 @@ class WaterQualityInference:
|
||||
coords = data.iloc[:, :2].copy()
|
||||
coords.columns = ['longitude', 'latitude']
|
||||
|
||||
# 从第5列开始为光谱数据(前两列为经纬度,跳过第3、4列的 pixel_x、pixel_y)
|
||||
spectra = data.iloc[:, 4:].copy()
|
||||
# 动态识别光谱列(兼容 sampling_spectra.csv 列顺序变更)
|
||||
# 列名约定:波长为纯数字字符串如 "374.285004";WQI 为 "WQI_xxx" 前缀
|
||||
# 旧版 CSV(无WQI):x_coord,y_coord,pixel_x,pixel_y,波长... → 取 [4:]
|
||||
# 新版 CSV(有WQI):x_coord,y_coord,WQI_...,波长... → 过滤 WQI 列后取光谱
|
||||
all_cols = list(data.columns)
|
||||
spectral_col_indices = []
|
||||
wqi_col_indices = []
|
||||
for i, col in enumerate(all_cols):
|
||||
col_str = str(col)
|
||||
if col_str.startswith('WQI_'):
|
||||
wqi_col_indices.append(i)
|
||||
elif col_str.replace('.', '').lstrip('-').isdigit():
|
||||
# 波长列:纯数字字符串
|
||||
spectral_col_indices.append(i)
|
||||
else:
|
||||
# 其他元数据列(x_coord/y_coord/pixel_x/pixel_y),由 coords 接收
|
||||
pass
|
||||
|
||||
# 光谱列 = 纯数字列(WQI 已被排除)
|
||||
spectra = data.iloc[:, spectral_col_indices].copy() if spectral_col_indices else data.iloc[:, 4:].copy()
|
||||
# WQI 列(用于追加到预测结果输出)
|
||||
wqi_df = data.iloc[:, wqi_col_indices].copy() if wqi_col_indices else pd.DataFrame()
|
||||
|
||||
print(f" 经纬度数据形状: {coords.shape}")
|
||||
print(f" 光谱数据形状: {spectra.shape}")
|
||||
print(f" 光谱数据形状: {spectra.shape} (自动识别波长列,排除 {len(wqi_col_indices)} 个WQI列)")
|
||||
print(f" 经纬度范围: 经度[{coords['longitude'].min():.6f}, {coords['longitude'].max():.6f}], "
|
||||
f"纬度[{coords['latitude'].min():.6f}, {coords['latitude'].max():.6f}]")
|
||||
|
||||
return coords, spectra
|
||||
return coords, spectra, wqi_df
|
||||
|
||||
def random(self, data, label, test_ratio=0.2, random_state=123):
|
||||
"""
|
||||
@ -519,6 +542,69 @@ class WaterQualityInference:
|
||||
print(f"正在应用预处理方法: {actual_preprocess_method}")
|
||||
print(f"原始光谱数据形状: {spectra.shape}")
|
||||
|
||||
# ---- 自动特征补全:50 光谱 → 补全至模型训练时的 95 维(WQI 指数) ----
|
||||
# 触发条件:模型期望 n_features_in_ 个特征,但当前 spectra 列数不足
|
||||
# 原因:training_spectra.csv 含 50 光谱 + 45 WQI;sampling_spectra.csv 只有 50 光谱
|
||||
# 做法:与训练端(calculate_all_indices)完全一致的算法列表,实时补全缺失的 45 个 WQI 列
|
||||
model = self.loaded_model_data['model']
|
||||
expected_features = getattr(model, 'n_features_in_', None)
|
||||
|
||||
# ---- 自动特征补全:50 光谱 → 补全至模型训练时的 n_features_in_ 维(WQI 指数) ----
|
||||
if expected_features is not None and spectra.shape[1] < expected_features:
|
||||
print(f"[特征补全] 检测到特征缺口:当前 {spectra.shape[1]} 列 < 模型期望 {expected_features} 列,"
|
||||
f"正在从光谱数据实时计算 WQI 指数...")
|
||||
try:
|
||||
from src.utils.water_index import WaterQualityIndexCalculator
|
||||
calc = WaterQualityIndexCalculator()
|
||||
|
||||
# 提取纯计算方法(排除 find_closest_wavelength 和 calculate_all_indices,
|
||||
# 以及不返回 Series 的辅助方法)
|
||||
algorithm_methods = []
|
||||
for m in dir(calc):
|
||||
if m.startswith('_'):
|
||||
continue
|
||||
if m in ['find_closest_wavelength', 'calculate_all_indices']:
|
||||
continue
|
||||
attr = getattr(calc, m)
|
||||
if callable(attr):
|
||||
algorithm_methods.append(m)
|
||||
|
||||
original_col_count = spectra.shape[1]
|
||||
for algo_name in algorithm_methods:
|
||||
try:
|
||||
algo_func = getattr(calc, algo_name)
|
||||
result = algo_func(spectra)
|
||||
# 只追加返回 Series 且长度为样本数的合法结果
|
||||
if isinstance(result, pd.Series) and len(result) == len(spectra):
|
||||
spectra[algo_name] = result.values
|
||||
else:
|
||||
spectra[algo_name] = np.nan
|
||||
except Exception:
|
||||
spectra[algo_name] = np.nan
|
||||
|
||||
print(f"[特征补全] 完成!光谱列已扩充至 {spectra.shape[1]} 列"
|
||||
f"(追加了 {spectra.shape[1] - original_col_count} 个 WQI 指数)")
|
||||
except Exception as e:
|
||||
print(f"[特征补全] 失败,将使用原始光谱特征: {e}")
|
||||
|
||||
# ---- 防线 1:强制维度对齐(物理截断)----
|
||||
if expected_features is not None and spectra.shape[1] > expected_features:
|
||||
print(f"[精准对齐] 正在将 {spectra.shape[1]} 维特征截断为模型要求的 {expected_features} 维")
|
||||
spectra = spectra.iloc[:, :expected_features]
|
||||
elif expected_features is not None and spectra.shape[1] < expected_features:
|
||||
# 维度不足时填充 0
|
||||
padding_cols = expected_features - spectra.shape[1]
|
||||
for i in range(padding_cols):
|
||||
spectra[f'_padding_{i}'] = 0.0
|
||||
print(f"[精准对齐] 特征不足,填充 {padding_cols} 列 0")
|
||||
|
||||
# ---- 防线 2:彻底清洗无穷大数值----
|
||||
# 防止 WQI 计算中除零/溢出产生 np.inf / -np.inf 导致预处理崩溃
|
||||
spectra = spectra.replace([np.inf, -np.inf], np.nan)
|
||||
spectra = spectra.fillna(0)
|
||||
|
||||
print(f"[特征对齐] 最终输入维度: {spectra.shape}")
|
||||
|
||||
try:
|
||||
# 应用预处理
|
||||
spectra_processed = Preprocessing(actual_preprocess_method, spectra)
|
||||
@ -573,7 +659,8 @@ class WaterQualityInference:
|
||||
raise
|
||||
|
||||
def save_predictions(self, coords: pd.DataFrame, predictions: np.ndarray,
|
||||
output_path: str, prediction_column: str = 'prediction'):
|
||||
output_path: str, prediction_column: str = 'prediction',
|
||||
wqi_columns: Optional[pd.DataFrame] = None):
|
||||
"""
|
||||
保存预测结果
|
||||
|
||||
@ -582,11 +669,15 @@ class WaterQualityInference:
|
||||
predictions: 预测结果
|
||||
output_path: 输出文件路径
|
||||
prediction_column: 预测列名称
|
||||
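wqi_columns: optional DataFrame of WQI index columns; when not None and non-empty, it is concatenated (with its index reset) to the right of the coordinate columns before the prediction column is added. Defaults to None (no WQI columns appended).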
|
||||
"""
|
||||
print(f"正在保存预测结果到: {output_path}")
|
||||
|
||||
# 创建结果DataFrame
|
||||
result_df = coords.copy()
|
||||
# 追加 WQI 水质指数列(如 sampling_spectra.csv 注入了 45 列指数)
|
||||
if wqi_columns is not None and not wqi_columns.empty:
|
||||
result_df = pd.concat([result_df, wqi_columns.reset_index(drop=True)], axis=1)
|
||||
result_df[prediction_column] = predictions
|
||||
|
||||
# 确保输出目录存在
|
||||
@ -659,10 +750,10 @@ class WaterQualityInference:
|
||||
else:
|
||||
self.load_best_model(metric=metric)
|
||||
|
||||
# 2. 加载采样数据
|
||||
# 2. 加载采样数据(coords=坐标, spectra=纯光谱, wqi_df=45个WQI指数列)
|
||||
print("\n步骤2: 加载采样数据")
|
||||
print("-" * 40)
|
||||
coords, spectra = self.load_sampling_data(sampling_csv_path)
|
||||
coords, spectra, wqi_df = self.load_sampling_data(sampling_csv_path)
|
||||
|
||||
# 3. 数据预处理
|
||||
print("\n步骤3: 数据预处理")
|
||||
@ -674,10 +765,11 @@ class WaterQualityInference:
|
||||
print("-" * 40)
|
||||
predictions = self.predict(spectra_processed)
|
||||
|
||||
# 5. 保存预测结果
|
||||
# 5. 保存预测结果(透传 WQI 列至最终输出文件)
|
||||
print("\n步骤5: 保存预测结果")
|
||||
print("-" * 40)
|
||||
result_df = self.save_predictions(coords, predictions, output_csv_path, prediction_column)
|
||||
result_df = self.save_predictions(coords, predictions, output_csv_path,
|
||||
prediction_column, wqi_df)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("推理流程完成!")
|
||||
@ -747,10 +839,11 @@ class WaterQualityInference:
|
||||
output_file = output_path / f"prediction_{csv_file.name}"
|
||||
|
||||
# 执行推理
|
||||
coords, spectra = self.load_sampling_data(str(csv_file))
|
||||
coords, spectra, wqi_df = self.load_sampling_data(str(csv_file))
|
||||
spectra_processed = self.preprocess_spectra(spectra)
|
||||
predictions = self.predict(spectra_processed)
|
||||
result_df = self.save_predictions(coords, predictions, str(output_file), prediction_column)
|
||||
result_df = self.save_predictions(coords, predictions, str(output_file),
|
||||
prediction_column, wqi_df)
|
||||
|
||||
results[csv_file.name] = {
|
||||
'output_file': str(output_file),
|
||||
@ -908,10 +1001,11 @@ class WaterQualityInference:
|
||||
output_file = output_path / f"{file_stem}{file_ext}"
|
||||
|
||||
# 执行推理
|
||||
coords, spectra = self.load_sampling_data(str(csv_file))
|
||||
coords, spectra, wqi_df = self.load_sampling_data(str(csv_file))
|
||||
spectra_processed = self.preprocess_spectra(spectra)
|
||||
predictions = self.predict(spectra_processed)
|
||||
result_df = self.save_predictions(coords, predictions, str(output_file), prediction_column)
|
||||
result_df = self.save_predictions(coords, predictions, str(output_file),
|
||||
prediction_column, wqi_df)
|
||||
|
||||
results[file_stem] = {
|
||||
'input_file': str(csv_file),
|
||||
|
||||
Reference in New Issue
Block a user