254 lines
9.9 KiB
Python
254 lines
9.9 KiB
Python
import pandas as pd
|
||
import numpy as np
|
||
import re
|
||
|
||
|
||
class BandMathCalculator:
    """Evaluate band-math expressions against a table of spectral reflectance.

    The table is loaded with ``pandas.read_csv``. Each column is associated
    with the first number found in its name, which is treated as that
    column's wavelength. Expressions reference bands as ``w<number>`` or
    ``W<number>`` (for example ``w560/w760``); every reference is resolved to
    the column whose wavelength is numerically closest to ``<number>``.
    """

    # A band reference: the letter w/W followed by an integer or decimal number.
    _BAND_REF_RE = re.compile(r'[wW](\d+\.?\d*)')
    # The same token as a whole word, used for single-pass substitution.
    _BAND_TOKEN_RE = re.compile(r'\b[wW]\d+\.?\d*(?!\w)')

    def __init__(self, csv_file):
        """Load the spectra table and index its columns by wavelength.

        Args:
            csv_file: Path (or any object accepted by ``pandas.read_csv``) of
                the CSV file holding the spectral reflectance values.
        """
        self.df = pd.read_csv(csv_file)
        self.wavelengths = self._extract_wavelengths()

    def _extract_wavelengths(self):
        """Return one entry per column: its wavelength, or None if unnamed.

        The wavelength is the first number appearing in the column name
        (converted to ``float``); columns whose name has no digits map to
        ``None``.
        """
        wavelengths = []
        for col in self.df.columns:
            numbers = re.findall(r'\d+\.?\d*', str(col))
            wavelengths.append(float(numbers[0]) if numbers else None)
        return wavelengths

    def _find_closest_wavelength(self, target_wavelength):
        """Return the positional index of the column nearest to a wavelength.

        Args:
            target_wavelength: Wavelength to look up.

        Returns:
            Index into ``self.df.columns`` of the closest wavelength column.
            Ties are resolved in favour of the left-most column.

        Raises:
            ValueError: If no column name contains a number.
        """
        valid_indices = [i for i, wl in enumerate(self.wavelengths) if wl is not None]
        if not valid_indices:
            raise ValueError("未找到有效的波长列")

        # min() returns the first minimum, matching argmin's tie-breaking.
        closest_index = min(
            valid_indices,
            key=lambda i: abs(self.wavelengths[i] - target_wavelength),
        )
        closest_wavelength = self.wavelengths[closest_index]

        print(
            f"目标波长 {target_wavelength}nm -> 最接近波长 {closest_wavelength}nm (列: {self.df.columns[closest_index]})")
        return closest_index

    def _parse_expression(self, expression):
        """Extract the numeric part of every band reference in an expression.

        Both ``w`` and ``W`` prefixes are recognised.

        Returns:
            A list of strings in order of appearance, duplicates included,
            e.g. ``['686', '672', '715', '672']``.
        """
        return self._BAND_REF_RE.findall(expression)

    def _create_substitution_dict(self, variables, row_index=0):
        """Map band variable names to their values on one data row.

        Args:
            variables: Wavelength strings as returned by ``_parse_expression``.
            row_index: Positional index of the data row to read.

        Returns:
            A dict containing, for every wavelength string ``v``, both the
            ``'w' + v`` and ``'W' + v`` keys bound to the value of the
            closest-wavelength column on that row.
        """
        substitution_dict = {}
        # dict.fromkeys drops repeated references while keeping their order, so
        # each distinct band is looked up (and logged) only once.
        for var in dict.fromkeys(variables):
            col_index = self._find_closest_wavelength(float(var))
            value = self.df.iloc[row_index, col_index]
            substitution_dict[f'w{var}'] = value
            substitution_dict[f'W{var}'] = value
        return substitution_dict

    def calculate(self, expression, row_index=0):
        """Evaluate a band expression on a single data row.

        Args:
            expression: Band expression, optionally with an assignment target,
                e.g. ``'chl=w560/w760'`` or ``'w560/w760'``.
            row_index: Positional index of the data row to use (default 0).

        Returns:
            ``{name: value}`` when the expression has an ``=`` assignment,
            otherwise the bare value. The value is NaN when evaluating the
            substituted expression raises. ``None`` is returned when any
            other step fails (the traceback is printed).
        """
        try:
            if '=' in expression:
                # Only the text between the first and second '=' is evaluated.
                pieces = expression.split('=')
                var_name = pieces[0].strip()
                calc_part = pieces[1].strip()
            else:
                var_name = None
                calc_part = expression.strip()

            variables = self._parse_expression(calc_part)
            print(f"解析到的波长变量: {variables}")

            sub_dict = self._create_substitution_dict(variables, row_index)
            print(f"变量值: {sub_dict}")

            def _band_value(match):
                token = match.group(0)
                if token in sub_dict:
                    return f"({sub_dict[token]})"
                return token

            # Substitute every band token in ONE pass. Replacing the variables
            # one after another would let 'w560' also rewrite the prefix of
            # 'w560.5' (the '.' counts as a word boundary) and corrupt the
            # expression.
            calc_expression = self._BAND_TOKEN_RE.sub(_band_value, calc_part)
            print(f"计算表达式: {calc_expression}")

            # SECURITY NOTE: eval() runs text taken from the formula source.
            # Clearing __builtins__ is NOT a sandbox; only evaluate formulas
            # that come from a trusted file.
            # 'nan' / 'inf' are provided because missing values are rendered
            # into the expression as those bare words.
            try:
                result = eval(calc_expression, {"__builtins__": None}, {"nan": np.nan, "inf": np.inf, "np": np})
            except Exception as e:
                print(f"⚠️ 警告:公式计算异常 ({e}),该点赋值为 nan")
                result = np.nan

            if var_name:
                return {var_name: result}
            return result

        except Exception as e:
            print(f"计算错误: {e}")
            import traceback
            traceback.print_exc()
            return None

    def calculate_all_rows(self, expression):
        """Evaluate an expression on every row of the table.

        Returns:
            A list with one value per row. Rows for which ``calculate``
            returns ``None`` are filled with NaN so the length always equals
            the number of rows.
        """
        results = []
        for i in range(len(self.df)):
            print(f"\n--- 计算第 {i} 行 ---")
            result = self.calculate(expression, i)
            if result is None:
                results.append(np.nan)
                print(f"第 {i} 行计算失败,使用NaN填充")
            elif isinstance(result, dict):
                results.append(next(iter(result.values())))
            else:
                results.append(result)
        return results

    def _parse_coeff(self, coeff_str: str) -> np.ndarray:
        """Parse a Coefficient cell into a polynomial coefficient array.

        - ``"1.0"``     -> ``[1.0]``
        - ``"a,b,c"``   -> ``[a, b, c]`` (highest degree first, as expected by
          ``np.polyval``)
        - ``"1.0,2.0"`` -> ``[1.0, 2.0]`` (linear ``y = a*x + b``)

        Missing or empty input (NaN, ``""``, ``"nan"``, ``"None"``) yields
        ``[1.0]``.

        Raises:
            ValueError: If a comma-separated part is not a valid float.
        """
        s = str(coeff_str).strip() if not pd.isna(coeff_str) else ""
        if s in ("", "nan", "None"):
            return np.array([1.0])
        return np.array([float(part.strip()) for part in s.split(",")])

    def process_formulas_from_csv(self, formula_csv_file, formula_names=None, output_file=None):
        """Compute formulas listed in a CSV file and append them as columns.

        Two formula-file layouts are supported:
          - New: has the columns Formula_Name, Formula_Type, Formula and
            Coefficient (Category and Reference may also be present).
          - Legacy (>= 3 columns): first column is the formula name and the
            third column is the expression.

        When Formula_Type is ``'concentration'`` the computed values are
        additionally passed through ``np.polyval`` with the parsed
        Coefficient.

        Args:
            formula_csv_file: Path of the formula CSV file.
            formula_names: Name or list of names of the formulas to compute;
                ``None`` computes all of them.
            output_file: Destination CSV path; when ``None`` a name of the
                form ``band_math_results_<formula file stem>.csv`` is used.

        Returns:
            A copy of the spectra table with one extra column per computed
            formula. If none of the requested names exist, the unmodified
            copy is returned and nothing is written. ``None`` is returned if
            an error occurs (the traceback is printed).
        """
        try:
            formulas_df = pd.read_csv(formula_csv_file)
            print(f"读取到 {len(formulas_df)} 个公式")

            required = {"Formula_Name", "Formula_Type", "Formula", "Coefficient"}
            if required.issubset(formulas_df.columns):
                name_col = "Formula_Name"
                type_col = "Formula_Type"
                expr_col = "Formula"
                coeff_col = "Coefficient"
            else:
                # Legacy layout: positional columns, no type/coefficient.
                name_col = formulas_df.columns[0]
                type_col = None
                expr_col = formulas_df.columns[2]
                coeff_col = None

            result_df = self.df.copy()

            if formula_names is not None:
                if isinstance(formula_names, str):
                    formula_names = [formula_names]
                selected = formulas_df[formulas_df[name_col].isin(formula_names)]
                print(f"找到 {len(selected)} 个指定公式")
                if len(selected) == 0:
                    print(f"警告: 未找到指定公式: {formula_names}")
                    return result_df
                formulas_to_process = selected
            else:
                formulas_to_process = formulas_df

            for _, row in formulas_to_process.iterrows():
                formula_name = row[name_col]
                formula_expr = row[expr_col]

                if pd.isna(formula_name) or pd.isna(formula_expr):
                    print(f"跳过空公式: {row.to_dict()}")
                    continue

                # Missing type defaults to a plain ratio; missing coefficient
                # defaults to the identity polynomial "1.0".
                ftype = str(row[type_col]).strip().lower() if type_col and not pd.isna(row.get(type_col)) else "ratio"
                coeff_str = str(row[coeff_col]).strip() if coeff_col and not pd.isna(row.get(coeff_col)) else "1.0"

                print(f"\n计算公式: {formula_name} = {formula_expr} [type={ftype}, coeff={coeff_str}]")

                results = self.calculate_all_rows(formula_expr)

                if ftype == "concentration":
                    coeff = self._parse_coeff(coeff_str)
                    results = np.polyval(coeff, np.array(results))

                result_df[formula_name] = results
                print(f"公式 '{formula_name}' 计算完成")

            if output_file is None:
                import os
                base_name = os.path.splitext(os.path.basename(formula_csv_file))[0]
                output_file = f"band_math_results_{base_name}.csv"

            result_df.to_csv(output_file, index=False)
            print(f"结果已保存到: {output_file}")

            return result_df

        except Exception as e:
            print(f"处理公式CSV文件时出错: {e}")
            import traceback
            traceback.print_exc()
            return None
|
||
|
||
|
||
# Updated usage example
|
||
if __name__ == "__main__":
    # Script entry point: load the training spectra, compute the selected
    # formulas, and write the augmented table next to the input data.
    spectra_csv = r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\training_spectra.csv"
    formula_csv = r"E:\code\WQ\封装\sub\水质参数.csv"
    enhanced_csv = r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\enhanced_data.csv"

    calculator = BandMathCalculator(spectra_csv)

    # Example 1: compute every formula in the formula file.
    # NOTE: the output path must be passed by keyword; the second positional
    # parameter of process_formulas_from_csv is formula_names.
    # result_df = calculator.process_formulas_from_csv(formula_csv, output_file="enhanced_data.csv")

    # Example 2: compute only the named formulas.
    wanted_formulas = ["BGA_Am09KBBI", "BGA_Be162B643sub629"]
    result_df = calculator.process_formulas_from_csv(
        formula_csv,
        formula_names=wanted_formulas,
        output_file=enhanced_csv,
    )