内容部分修改

This commit is contained in:
DXC
2026-05-11 17:38:29 +08:00
parent bf4237b160
commit 170d347e21
8 changed files with 284 additions and 47 deletions

View File

@ -40,16 +40,19 @@ class WaterQualityInference:
self.best_model_info = None
self.loaded_model_data = None
def load_sampling_data(self, csv_path: str) -> Tuple[pd.DataFrame, pd.DataFrame]:
def load_sampling_data(self, csv_path: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
"""
加载sampling生成的CSV数据
加载sampling生成的CSV数据(兼容 WQI 增强版 CSV)
Args:
csv_path: CSV文件路径,前两列为经纬度,其余列为光谱数据
csv_path: CSV文件路径
旧版:x_coord,y_coord,pixel_x,pixel_y,波长...
新版:x_coord,y_coord,WQI_...,波长...
Returns:
coords: 经纬度数据 (DataFrame)
spectra: 光谱数据 (DataFrame)
coords: 经纬度数据 (DataFrame, 2列)
spectra: 光谱数据 (DataFrame, 跳过 WQI 列)
wqi_df: WQI 指数列 (DataFrame;取所有列名以 "WQI_" 开头的列,新版 CSV 通常为 45 列,无此类列时为空 DataFrame)
"""
print(f"正在加载采样数据: {csv_path}")
@ -71,15 +74,35 @@ class WaterQualityInference:
coords = data.iloc[:, :2].copy()
coords.columns = ['longitude', 'latitude']
# 从第5列开始为光谱数据跳过第2、3、4列的其他信息
spectra = data.iloc[:, 4:].copy()
# 动态识别光谱列(兼容 sampling_spectra.csv 列顺序变更)
# 列名约定:波长为纯数字字符串,如 "374.285004";WQI 为 "WQI_xxx" 前缀
# 旧版 CSV(无WQI):x_coord,y_coord,pixel_x,pixel_y,波长... → 取 [4:]
# 新版 CSV(有WQI):x_coord,y_coord,WQI_...,波长... → 过滤 WQI 列后取光谱
all_cols = list(data.columns)
spectral_col_indices = []
wqi_col_indices = []
for i, col in enumerate(all_cols):
col_str = str(col)
if col_str.startswith('WQI_'):
wqi_col_indices.append(i)
elif col_str.replace('.', '').lstrip('-').isdigit():
# 波长列:纯数字字符串
spectral_col_indices.append(i)
else:
# 其他元数据列(x_coord/y_coord/pixel_x/pixel_y):不计入光谱;坐标由 coords 取前两列
pass
# 光谱列 = 纯数字列(WQI 已被排除);若未识别到任何纯数字列,则回退为取第5列起的全部列
spectra = data.iloc[:, spectral_col_indices].copy() if spectral_col_indices else data.iloc[:, 4:].copy()
# WQI 列(用于追加到预测结果输出)
wqi_df = data.iloc[:, wqi_col_indices].copy() if wqi_col_indices else pd.DataFrame()
print(f" 经纬度数据形状: {coords.shape}")
print(f" 光谱数据形状: {spectra.shape}")
print(f" 光谱数据形状: {spectra.shape} (自动识别波长列,排除 {len(wqi_col_indices)} 个WQI列)")
print(f" 经纬度范围: 经度[{coords['longitude'].min():.6f}, {coords['longitude'].max():.6f}], "
f"纬度[{coords['latitude'].min():.6f}, {coords['latitude'].max():.6f}]")
return coords, spectra
return coords, spectra, wqi_df
def random(self, data, label, test_ratio=0.2, random_state=123):
"""
@ -519,6 +542,69 @@ class WaterQualityInference:
print(f"正在应用预处理方法: {actual_preprocess_method}")
print(f"原始光谱数据形状: {spectra.shape}")
# ---- 自动特征补全:50 光谱 → 补全至模型训练时的 95 维(WQI 指数) ----
# 触发条件:模型期望 n_features_in_ 个特征,但当前 spectra 列数不足
# 原因:training_spectra.csv 含 50 光谱 + 45 WQI,sampling_spectra.csv 只有 50 光谱
# 做法:使用与训练端(calculate_all_indices)完全一致的算法列表,实时补全缺失的 45 个 WQI 列
model = self.loaded_model_data['model']
expected_features = getattr(model, 'n_features_in_', None)
# ---- 自动特征补全:50 光谱 → 补全至模型训练时的 n_features_in_ 维(WQI 指数) ----
if expected_features is not None and spectra.shape[1] < expected_features:
print(f"[特征补全] 检测到特征缺口:当前 {spectra.shape[1]} 列 < 模型期望 {expected_features} 列,"
f"正在从光谱数据实时计算 WQI 指数...")
try:
from src.utils.water_index import WaterQualityIndexCalculator
calc = WaterQualityIndexCalculator()
# 提取纯计算方法(排除 find_closest_wavelength 和 calculate_all_indices
# 以及不返回 Series 的辅助方法)
algorithm_methods = []
for m in dir(calc):
if m.startswith('_'):
continue
if m in ['find_closest_wavelength', 'calculate_all_indices']:
continue
attr = getattr(calc, m)
if callable(attr):
algorithm_methods.append(m)
original_col_count = spectra.shape[1]
for algo_name in algorithm_methods:
try:
algo_func = getattr(calc, algo_name)
result = algo_func(spectra)
# 只追加返回 Series 且长度为样本数的合法结果
if isinstance(result, pd.Series) and len(result) == len(spectra):
spectra[algo_name] = result.values
else:
spectra[algo_name] = np.nan
except Exception:
spectra[algo_name] = np.nan
print(f"[特征补全] 完成!光谱列已扩充至 {spectra.shape[1]}"
f"(追加了 {spectra.shape[1] - original_col_count} 个 WQI 指数)")
except Exception as e:
print(f"[特征补全] 失败,将使用原始光谱特征: {e}")
# ---- 防线 1:强制维度对齐(物理截断)----
if expected_features is not None and spectra.shape[1] > expected_features:
print(f"[精准对齐] 正在将 {spectra.shape[1]} 维特征截断为模型要求的 {expected_features}")
spectra = spectra.iloc[:, :expected_features]
elif expected_features is not None and spectra.shape[1] < expected_features:
# 维度不足时填充 0
padding_cols = expected_features - spectra.shape[1]
for i in range(padding_cols):
spectra[f'_padding_{i}'] = 0.0
print(f"[精准对齐] 特征不足,填充 {padding_cols} 列 0")
# ---- 防线 2:彻底清洗无穷大数值 ----
# 防止 WQI 计算中除零/溢出产生 np.inf / -np.inf 导致预处理崩溃
spectra = spectra.replace([np.inf, -np.inf], np.nan)
spectra = spectra.fillna(0)
print(f"[特征对齐] 最终输入维度: {spectra.shape}")
try:
# 应用预处理
spectra_processed = Preprocessing(actual_preprocess_method, spectra)
@ -573,7 +659,8 @@ class WaterQualityInference:
raise
def save_predictions(self, coords: pd.DataFrame, predictions: np.ndarray,
output_path: str, prediction_column: str = 'prediction'):
output_path: str, prediction_column: str = 'prediction',
wqi_columns: Optional[pd.DataFrame] = None):
"""
保存预测结果
@ -582,11 +669,15 @@ class WaterQualityInference:
predictions: 预测结果
output_path: 输出文件路径
prediction_column: 预测列名称
wqi_columns: 可选的 WQI 指数列 (DataFrame);不为 None 且非空时,重置索引后按列拼接到坐标列之后、预测列之前
"""
print(f"正在保存预测结果到: {output_path}")
# 创建结果DataFrame
result_df = coords.copy()
# 追加 WQI 水质指数列(如 sampling_spectra.csv 注入了 45 列指数)
if wqi_columns is not None and not wqi_columns.empty:
result_df = pd.concat([result_df, wqi_columns.reset_index(drop=True)], axis=1)
result_df[prediction_column] = predictions
# 确保输出目录存在
@ -659,10 +750,10 @@ class WaterQualityInference:
else:
self.load_best_model(metric=metric)
# 2. 加载采样数据
# 2. 加载采样数据(coords=坐标, spectra=纯光谱, wqi_df=45个WQI指数列)
print("\n步骤2: 加载采样数据")
print("-" * 40)
coords, spectra = self.load_sampling_data(sampling_csv_path)
coords, spectra, wqi_df = self.load_sampling_data(sampling_csv_path)
# 3. 数据预处理
print("\n步骤3: 数据预处理")
@ -674,10 +765,11 @@ class WaterQualityInference:
print("-" * 40)
predictions = self.predict(spectra_processed)
# 5. 保存预测结果
# 5. 保存预测结果(透传 WQI 列至最终输出文件)
print("\n步骤5: 保存预测结果")
print("-" * 40)
result_df = self.save_predictions(coords, predictions, output_csv_path, prediction_column)
result_df = self.save_predictions(coords, predictions, output_csv_path,
prediction_column, wqi_df)
print("\n" + "=" * 80)
print("推理流程完成!")
@ -747,10 +839,11 @@ class WaterQualityInference:
output_file = output_path / f"prediction_{csv_file.name}"
# 执行推理
coords, spectra = self.load_sampling_data(str(csv_file))
coords, spectra, wqi_df = self.load_sampling_data(str(csv_file))
spectra_processed = self.preprocess_spectra(spectra)
predictions = self.predict(spectra_processed)
result_df = self.save_predictions(coords, predictions, str(output_file), prediction_column)
result_df = self.save_predictions(coords, predictions, str(output_file),
prediction_column, wqi_df)
results[csv_file.name] = {
'output_file': str(output_file),
@ -908,10 +1001,11 @@ class WaterQualityInference:
output_file = output_path / f"{file_stem}{file_ext}"
# 执行推理
coords, spectra = self.load_sampling_data(str(csv_file))
coords, spectra, wqi_df = self.load_sampling_data(str(csv_file))
spectra_processed = self.preprocess_spectra(spectra)
predictions = self.predict(spectra_processed)
result_df = self.save_predictions(coords, predictions, str(output_file), prediction_column)
result_df = self.save_predictions(coords, predictions, str(output_file),
prediction_column, wqi_df)
results[file_stem] = {
'input_file': str(csv_file),