Files
WQ_GUI/src/utils/band_math.py

254 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import numpy as np
import re
class BandMathCalculator:
    """Evaluate band-math expressions against a table of spectral reflectances.

    Each column's wavelength is the first number found in its column name
    (e.g. ``"560"`` -> 560.0). Expressions reference bands as ``w<wavelength>``
    or ``W<wavelength>`` (e.g. ``"chl=w560/w760"``); every reference is
    resolved to the column whose wavelength is numerically closest.
    """

    # A band reference: 'w' or 'W' followed by a (possibly fractional)
    # wavelength. Shared by parsing and substitution so both operate on
    # exactly the same tokens.
    _BAND_VAR_RE = re.compile(r'[wW](\d+\.?\d*)')
    # An assignment '=' -- one that is not part of '==', '<=', '>=' or '!='.
    _ASSIGN_RE = re.compile(r'(?<![<>!=])=(?!=)')

    def __init__(self, csv_file):
        """Load the spectra table and index the wavelength of each column.

        Args:
            csv_file: Path (or any source accepted by ``pandas.read_csv``) of
                the CSV file containing spectral reflectances.
        """
        self.df = pd.read_csv(csv_file)
        self.wavelengths = self._extract_wavelengths()

    def _extract_wavelengths(self):
        """Return one wavelength (float) or ``None`` per column of ``self.df``.

        The wavelength is the first number in the column name; names with no
        digits map to ``None``. Any digit counts, so non-spectral columns such
        as ``"Unnamed: 0"`` also receive a wavelength.
        """
        wavelengths = []
        for col in self.df.columns:
            numbers = re.findall(r'\d+\.?\d*', str(col))
            wavelengths.append(float(numbers[0]) if numbers else None)
        return wavelengths

    def _find_closest_wavelength(self, target_wavelength):
        """Return the index of the column whose wavelength is closest to the target.

        Ties resolve to the earliest column. There is no distance limit: the
        nearest column is returned however far away it is.

        Raises:
            ValueError: if no column name contains a wavelength.
        """
        valid_indices = [i for i, wl in enumerate(self.wavelengths) if wl is not None]
        if not valid_indices:
            raise ValueError("未找到有效的波长列")
        differences = [abs(self.wavelengths[i] - target_wavelength) for i in valid_indices]
        closest_index = valid_indices[int(np.argmin(differences))]
        closest_wavelength = self.wavelengths[closest_index]
        print(
            f"目标波长 {target_wavelength}nm -> 最接近波长 {closest_wavelength}nm (列: {self.df.columns[closest_index]})")
        return closest_index

    def _parse_expression(self, expression):
        """Return the wavelength strings of all band references in ``expression``.

        The ``w``/``W`` prefix is case-insensitive; order and duplicates are
        kept, e.g. ``"(w686-w672)/(W715-w672)"`` -> ``['686', '672', '715', '672']``.
        """
        return self._BAND_VAR_RE.findall(expression)

    def _create_substitution_dict(self, variables, row_index=0):
        """Map each band reference to its value in one row of ``self.df``.

        Args:
            variables: Wavelength strings as returned by ``_parse_expression``.
            row_index: Positional (``iloc``) row index.

        Returns:
            dict mapping both ``'w<var>'`` and ``'W<var>'`` to the value of the
            closest-wavelength column in that row.
        """
        substitution_dict = {}
        # dict.fromkeys de-duplicates while keeping order, so a band used
        # several times is resolved (and logged) only once.
        for var in dict.fromkeys(variables):
            col_index = self._find_closest_wavelength(float(var))
            value = self.df.iloc[row_index, col_index]
            substitution_dict[f'w{var}'] = value
            substitution_dict[f'W{var}'] = value
        return substitution_dict

    def _split_assignment(self, expression):
        """Split ``'name = expr'`` into ``(name, expr)``, else return ``(None, expr)``.

        Only the first '=' that is not part of a comparison operator is the
        assignment, so ``'flag = w760 >= w560'`` keeps its ``>=`` and a bare
        ``'w760 >= w560'`` is treated as an expression with no name.
        """
        match = self._ASSIGN_RE.search(expression)
        if match is None:
            return None, expression.strip()
        return expression[:match.start()].strip(), expression[match.end():].strip()

    def calculate(self, expression, row_index=0):
        """Evaluate a band-math expression for a single row.

        Args:
            expression: Expression such as ``'chl=w560/w760'`` or
                ``'w560/w760'``; an optional ``name =`` prefix names the result.
            row_index: Positional row index into ``self.df`` (default 0).

        Returns:
            ``{name: value}`` when the expression has a name, otherwise the
            bare value. Evaluation errors (division by zero, bad syntax,
            unknown names) give ``np.nan``. Any other failure (e.g. no
            wavelength columns, row out of range) is printed and ``None`` is
            returned.
        """
        try:
            var_name, calc_part = self._split_assignment(expression)
            variables = self._parse_expression(calc_part)
            print(f"解析到的波长变量: {variables}")
            sub_dict = self._create_substitution_dict(variables, row_index)
            print(f"变量值: {sub_dict}")
            # Replace every band reference in ONE pass with the parser's own
            # pattern. Substituting key by key with \b anchors corrupted
            # overlapping names: 'w560' also matched the prefix of 'w560.5'.
            calc_expression = self._BAND_VAR_RE.sub(
                lambda m: f"({sub_dict[m.group(0)]})", calc_part)
            print(f"计算表达式: {calc_expression}")
            try:
                # SECURITY: eval() executes text read from the formula file.
                # An empty __builtins__ blocks casual misuse but is NOT a
                # sandbox (object introspection can escape it); only evaluate
                # formula files from trusted sources. 'nan'/'inf' are exposed
                # because NaN/inf cell values are substituted as those words.
                result = eval(calc_expression, {"__builtins__": {}},
                              {"nan": np.nan, "inf": np.inf, "np": np})
            except Exception as e:
                print(f"⚠️ 警告:公式计算异常 ({e}),该点赋值为 nan")
                result = np.nan
            return {var_name: result} if var_name else result
        except Exception as e:
            print(f"计算错误: {e}")
            import traceback
            traceback.print_exc()
            return None

    def calculate_all_rows(self, expression):
        """Evaluate ``expression`` for every row of ``self.df``.

        Returns:
            list with one value per row, in row order. Rows where
            ``calculate`` returned ``None`` are filled with ``np.nan`` so the
            list always lines up with the DataFrame.
        """
        results = []
        for i in range(len(self.df)):
            print(f"\n--- 计算第 {i} 行 ---")
            result = self.calculate(expression, i)
            if result is None:
                results.append(np.nan)
                print(f"{i} 行计算失败使用NaN填充")
            elif isinstance(result, dict):
                results.append(next(iter(result.values())))
            else:
                results.append(result)
        return results

    def _parse_coeff(self, coeff_str: str) -> np.ndarray:
        """Parse a Coefficient cell into an array of numbers.

        - empty / NaN / ``"nan"`` / ``"None"`` -> ``[1.0]``
        - ``"a"`` -> ``[a]``
        - ``"a,b,c"`` -> ``[a, b, c]`` (polynomial, highest degree first)

        Empty comma-separated items (e.g. a trailing comma) are ignored.

        Raises:
            ValueError: if an item is not a number.
        """
        s = "" if pd.isna(coeff_str) else str(coeff_str).strip()
        if s in ("", "nan", "None"):
            return np.array([1.0])
        parts = [float(x) for x in s.split(",") if x.strip()]
        return np.array(parts) if parts else np.array([1.0])

    @staticmethod
    def _apply_coeff(coeff, values):
        """Transform computed index values with a Coefficient array.

        A single coefficient ``a`` is a scale factor (``y = a * x``), so the
        default ``[1.0]`` leaves the values unchanged. Two or more
        coefficients form a polynomial evaluated with ``np.polyval``
        (highest degree first), e.g. ``[a, b]`` -> ``y = a*x + b``.
        """
        coeff = np.asarray(coeff, dtype=float)
        if coeff.size == 1:
            # polyval([a], x) returns the constant a for every sample and
            # discards x entirely; a lone coefficient must mean y = a*x.
            coeff = np.array([coeff[0], 0.0])
        return np.polyval(coeff, np.asarray(values))

    def process_formulas_from_csv(self, formula_csv_file, formula_names=None, output_file=None):
        """Evaluate formulas from a formula CSV and append them as new columns.

        Two formula-CSV layouts are accepted:

        - New (6 columns): Formula_Name, Category, Formula_Type, Formula,
          Coefficient, Reference. This layout is detected by the presence of
          Formula_Name, Formula_Type, Formula and Coefficient.
        - Legacy (>= 3 columns): first column = formula name, third column =
          expression. Every formula is treated as type 'ratio'.

        For Formula_Type 'concentration', the per-row values are then
        transformed with the Coefficient via ``_apply_coeff``. A Coefficient
        that cannot be parsed gives an all-NaN column for that formula
        instead of aborting the whole batch.

        Args:
            formula_csv_file: Path of the formula CSV.
            formula_names: A name or iterable of names to evaluate; ``None``
                evaluates every formula.
            output_file: Result CSV path. If ``None``, writes
                ``band_math_results_<formula file stem>.csv`` relative to the
                current working directory.

        Returns:
            A copy of ``self.df`` with one added column per evaluated formula,
            or ``None`` if processing failed (the error is printed). If none
            of ``formula_names`` is found, the unmodified copy is returned and
            no file is written.
        """
        try:
            formulas_df = pd.read_csv(formula_csv_file)
            print(f"读取到 {len(formulas_df)} 个公式")
            has_new_format = {"Formula_Name", "Formula_Type", "Formula", "Coefficient"}.issubset(
                formulas_df.columns
            )
            if has_new_format:
                name_col, type_col = "Formula_Name", "Formula_Type"
                expr_col, coeff_col = "Formula", "Coefficient"
            else:
                name_col, type_col = formulas_df.columns[0], None
                expr_col, coeff_col = formulas_df.columns[2], None

            result_df = self.df.copy()
            if formula_names is not None:
                # Materialize once so generators survive both isin() and the
                # missing-name check below.
                formula_names = [formula_names] if isinstance(formula_names, str) else list(formula_names)
                formulas_to_process = formulas_df[formulas_df[name_col].isin(formula_names)]
                print(f"找到 {len(formulas_to_process)} 个指定公式")
                if len(formulas_to_process) == 0:
                    print(f"警告: 未找到指定公式: {formula_names}")
                    return result_df
                found = set(formulas_to_process[name_col])
                missing = [name for name in formula_names if name not in found]
                if missing:
                    print(f"警告: 未找到指定公式: {missing}")
            else:
                formulas_to_process = formulas_df

            for _, row in formulas_to_process.iterrows():
                formula_name = row[name_col]
                formula_expr = row[expr_col]
                if pd.isna(formula_name) or pd.isna(formula_expr):
                    print(f"跳过空公式: {row.to_dict()}")
                    continue
                ftype = str(row[type_col]).strip().lower() if type_col and not pd.isna(row.get(type_col)) else "ratio"
                coeff_str = str(row[coeff_col]).strip() if coeff_col and not pd.isna(row.get(coeff_col)) else "1.0"
                print(f"\n计算公式: {formula_name} = {formula_expr} [type={ftype}, coeff={coeff_str}]")
                results = self.calculate_all_rows(formula_expr)
                if ftype == "concentration":
                    try:
                        coeff = self._parse_coeff(coeff_str)
                    except ValueError as e:
                        # One malformed cell must not abort the whole batch;
                        # NaN avoids passing off uncalibrated values as
                        # concentrations.
                        print(f"⚠️ 警告: 公式 '{formula_name}' 的系数无法解析 ({e}),结果填充为 nan")
                        results = [np.nan] * len(results)
                    else:
                        results = self._apply_coeff(coeff, results)
                result_df[formula_name] = results
                print(f"公式 '{formula_name}' 计算完成")

            if output_file is None:
                import os
                base_name = os.path.splitext(os.path.basename(formula_csv_file))[0]
                output_file = f"band_math_results_{base_name}.csv"
            result_df.to_csv(output_file, index=False)
            print(f"结果已保存到: {output_file}")
            return result_df
        except Exception as e:
            print(f"处理公式CSV文件时出错: {e}")
            import traceback
            traceback.print_exc()
            return None
# Usage example. The paths are machine-specific; adjust them before running.
if __name__ == "__main__":
    # Build the calculator from a spectra CSV.
    calculator = BandMathCalculator(r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\training_spectra.csv")
    # Example 1: compute every formula in the formula CSV.
    # output_file must be passed by keyword: the second positional
    # parameter is formula_names.
    # result_df = calculator.process_formulas_from_csv(r"E:\code\WQ\封装\sub\水质参数.csv", output_file="enhanced_data.csv")
    # Example 2: compute only the named formulas.
    result_df = calculator.process_formulas_from_csv(
        r"E:\code\WQ\封装\sub\水质参数.csv",
        formula_names=["BGA_Am09KBBI", "BGA_Be162B643sub629"],
        output_file=r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\enhanced_data.csv",
    )