import os
import re
import traceback

import numpy as np
import pandas as pd


class BandMathCalculator:
    """Evaluate band-math formulas against a table of spectral reflectance.

    Each column of the table whose name contains a number is treated as a
    band at that wavelength.  Formulas reference bands as ``w<wavelength>``
    or ``W<wavelength>`` (for example ``chl=w560/w760``); every reference is
    resolved to the column whose wavelength is closest to the requested one.
    """

    # A band reference: ``w``/``W`` followed by an integer or decimal number.
    # Group 1 is the wavelength text.  The boundaries on both sides keep the
    # pattern from matching inside longer identifiers (``pow2``, ``w560abc``).
    _BAND_VARIABLE = re.compile(r'\b[wW](\d+(?:\.\d+)?)(?!\w)')

    # A plain assignment ``=``, i.e. one that is not part of ==, !=, <= or >=.
    _ASSIGNMENT = re.compile(r'(?<![=!<>])=(?!=)')

    def __init__(self, csv_file):
        """Load the reflectance table and index its wavelength columns.

        Args:
            csv_file: Path of the CSV file holding the spectral reflectance.
        """
        self.df = pd.read_csv(csv_file)
        self.wavelengths = self._extract_wavelengths()

    def _extract_wavelengths(self):
        """Return one entry per column: the first number in its name, or None.

        The number is parsed as a float and taken to be the wavelength of the
        column; columns whose name contains no digits map to ``None``.
        """
        wavelengths = []
        for col in self.df.columns:
            match = re.search(r'\d+\.?\d*', str(col))
            wavelengths.append(float(match.group()) if match else None)
        return wavelengths

    def _find_closest_wavelength(self, target_wavelength):
        """Return the positional index of the column nearest to a wavelength.

        Args:
            target_wavelength: Requested wavelength (same unit as the column
                names; the log message labels it "nm").

        Returns:
            Index into ``self.df.columns``.  On ties the leftmost column wins.

        Raises:
            ValueError: If no column name yielded a wavelength.
        """
        candidates = [
            (index, wavelength)
            for index, wavelength in enumerate(self.wavelengths)
            if wavelength is not None
        ]
        if not candidates:
            raise ValueError("未找到有效的波长列")

        # min() keeps the first of several equally close candidates.
        closest_index, closest_wavelength = min(
            candidates, key=lambda item: abs(item[1] - target_wavelength)
        )
        print(
            f"目标波长 {target_wavelength}nm -> 最接近波长 {closest_wavelength}nm (列: {self.df.columns[closest_index]})")
        return closest_index

    def _parse_expression(self, expression):
        """Return the wavelength text of every band reference, in order.

        Duplicates are kept, e.g. ``['686', '672', '715', '672']``.  Both
        ``w`` and ``W`` prefixes are recognised.
        """
        return self._BAND_VARIABLE.findall(expression)

    def _create_substitution_dict(self, variables, row_index=0):
        """Map each band variable name to its value in one data row.

        Args:
            variables: Wavelength strings as returned by ``_parse_expression``.
            row_index: Positional index of the row to read.

        Returns:
            Dict holding both the ``w<wl>`` and ``W<wl>`` spelling of every
            variable, each mapped to the value of the closest column.
        """
        substitution_dict = {}
        # dict.fromkeys de-duplicates while preserving order, so a band used
        # several times in a formula is resolved (and logged) only once.
        for var in dict.fromkeys(variables):
            col_index = self._find_closest_wavelength(float(var))
            value = self.df.iloc[row_index, col_index]
            substitution_dict[f'w{var}'] = value
            substitution_dict[f'W{var}'] = value
        return substitution_dict

    @classmethod
    def _split_assignment(cls, expression):
        """Split ``name = formula`` into ``(name, formula)``.

        Only a plain ``=`` counts as an assignment; comparison operators are
        left inside the formula.  Without an assignment the name is ``None``.
        """
        match = cls._ASSIGNMENT.search(expression)
        if match is None:
            return None, expression.strip()
        return expression[:match.start()].strip(), expression[match.end():].strip()

    def calculate(self, expression, row_index=0):
        """Evaluate a band-math expression for a single data row.

        Args:
            expression: Formula text such as ``'chl=w560/w760'``.  The part
                left of the assignment ``=`` (if any) names the result.
            row_index: Positional index of the row to evaluate (default 0).

        Returns:
            ``{name: value}`` when the expression has a non-empty name,
            otherwise the bare value.  The value is ``nan`` when evaluating
            the formula raised (e.g. division by zero), and the whole result
            is ``None`` when parsing or band lookup failed.
        """
        try:
            var_name, calc_part = self._split_assignment(expression)

            variables = self._parse_expression(calc_part)
            print(f"解析到的波长变量: {variables}")

            sub_dict = self._create_substitution_dict(variables, row_index)
            print(f"变量值: {sub_dict}")

            # Replace every band reference in a single pass.  Substituting
            # one variable at a time would let ``w560`` clobber the prefix of
            # ``w560.5`` when both appear in the same formula.
            calc_expression = self._BAND_VARIABLE.sub(
                lambda match: f"({sub_dict[match.group(0)]})", calc_part
            )
            print(f"计算表达式: {calc_expression}")

            # SECURITY NOTE(review): eval() runs text that originates from a
            # formula file.  Builtins are disabled, but exposing ``np`` still
            # gives the expression a large attack surface; only feed this
            # method formulas from a trusted source.
            # ``nan``/``inf`` are provided because missing or infinite cell
            # values are substituted into the expression as those bare words.
            try:
                result = eval(calc_expression, {"__builtins__": None},
                              {"nan": np.nan, "inf": np.inf, "np": np})
            except Exception as e:
                print(f"⚠️ 警告:公式计算异常 ({e}),该点赋值为 nan")
                result = np.nan

            if var_name:
                return {var_name: result}
            return result

        except Exception as e:
            print(f"计算错误: {e}")
            traceback.print_exc()
            return None

    def calculate_all_rows(self, expression):
        """Evaluate an expression for every row of the table.

        Returns:
            A list with one value per row.  Rows for which ``calculate``
            returned ``None`` are filled with ``nan`` so the list length
            always equals the number of rows.
        """
        results = []
        for i in range(len(self.df)):
            print(f"\n--- 计算第 {i} 行 ---")
            result = self.calculate(expression, i)
            if result is None:
                # Keep the result aligned with the table rows.
                results.append(np.nan)
                print(f"第 {i} 行计算失败,使用NaN填充")
            elif isinstance(result, dict):
                results.append(next(iter(result.values())))
            else:
                results.append(result)
        return results

    def _parse_coeff(self, coeff_str: str) -> np.ndarray:
        """Parse a Coefficient cell into a polynomial coefficient array.

        - ``"1.0"``     -> ``[1.0]``
        - ``"a,b,c"``   -> ``[a, b, c]`` (highest degree first, ready for
          ``np.polyval``)
        - ``"1.0,2.0"`` -> ``[1.0, 2.0]`` (linear ``y = a*x + b``)

        Missing values (NaN/None/empty text) yield ``[1.0]``.  Empty tokens,
        such as the one produced by a trailing comma, are ignored, and the
        full-width comma is accepted as a separator.

        Raises:
            ValueError: If a non-empty token is not a valid float.
        """
        text = "" if pd.isna(coeff_str) else str(coeff_str).strip()
        if text in ("", "nan", "None"):
            return np.array([1.0])

        tokens = (token.strip() for token in text.replace("，", ",").split(","))
        parts = [float(token) for token in tokens if token]
        if not parts:
            return np.array([1.0])
        return np.array(parts)

    def process_formulas_from_csv(self, formula_csv_file, formula_names=None, output_file=None):
        """Evaluate formulas listed in a CSV file and append them as columns.

        Two formula-file layouts are supported:

        - New: contains the columns ``Formula_Name``, ``Formula_Type``,
          ``Formula`` and ``Coefficient`` (typically alongside ``Category``
          and ``Reference``).
        - Legacy (at least 3 columns): first column is the formula name,
          third column is the expression.

        When ``Formula_Type`` is ``'concentration'`` the per-row result is
        additionally passed through ``np.polyval`` with the parsed
        ``Coefficient``.

        Args:
            formula_csv_file: Path of the formula CSV file.
            formula_names: Name or list of names to evaluate; ``None``
                evaluates every formula in the file.
            output_file: Destination CSV path.  When ``None`` the file
                ``band_math_results_<formula file stem>.csv`` is written to
                the current working directory.

        Returns:
            A copy of the data table with one added column per formula.
            If none of the requested names exist, the unmodified copy is
            returned and nothing is written.  ``None`` is returned when any
            error occurs (the error is printed with its traceback).
        """
        try:
            formulas_df = pd.read_csv(formula_csv_file)
            print(f"读取到 {len(formulas_df)} 个公式")

            new_format_columns = {"Formula_Name", "Formula_Type", "Formula", "Coefficient"}
            if new_format_columns.issubset(formulas_df.columns):
                name_col = "Formula_Name"
                type_col = "Formula_Type"
                expr_col = "Formula"
                coeff_col = "Coefficient"
            else:
                name_col = formulas_df.columns[0]
                type_col = None
                expr_col = formulas_df.columns[2]
                coeff_col = None

            result_df = self.df.copy()

            if formula_names is not None:
                if isinstance(formula_names, str):
                    formula_names = [formula_names]
                selected = formulas_df[formulas_df[name_col].isin(formula_names)]
                print(f"找到 {len(selected)} 个指定公式")
                if len(selected) == 0:
                    print(f"警告: 未找到指定公式: {formula_names}")
                    return result_df
                formulas_to_process = selected
            else:
                formulas_to_process = formulas_df

            for _, row in formulas_to_process.iterrows():
                formula_name = row[name_col]
                formula_expr = row[expr_col]

                if pd.isna(formula_name) or pd.isna(formula_expr):
                    print(f"跳过空公式: {row.to_dict()}")
                    continue

                # Legacy files have no type/coefficient columns, and cells
                # may be empty: fall back to a plain ratio with coeff "1.0".
                ftype = "ratio"
                if type_col and not pd.isna(row[type_col]):
                    ftype = str(row[type_col]).strip().lower()
                coeff_str = "1.0"
                if coeff_col and not pd.isna(row[coeff_col]):
                    coeff_str = str(row[coeff_col]).strip()

                print(f"\n计算公式: {formula_name} = {formula_expr} [type={ftype}, coeff={coeff_str}]")

                results = self.calculate_all_rows(formula_expr)

                if ftype == "concentration":
                    coeff = self._parse_coeff(coeff_str)
                    # NOTE(review): with a single coefficient (including the
                    # "1.0" default) np.polyval evaluates a degree-0
                    # polynomial, i.e. every row becomes that constant.
                    # Confirm this is intended rather than a scale factor.
                    results = np.polyval(coeff, np.array(results))

                result_df[formula_name] = results
                print(f"公式 '{formula_name}' 计算完成")

            if output_file is None:
                base_name = os.path.splitext(os.path.basename(formula_csv_file))[0]
                output_file = f"band_math_results_{base_name}.csv"

            result_df.to_csv(output_file, index=False)
            print(f"结果已保存到: {output_file}")

            return result_df

        except Exception as e:
            print(f"处理公式CSV文件时出错: {e}")
            traceback.print_exc()
            return None
# Usage example (updated)
if __name__ == "__main__":
    # Input spectra table, formula definitions and the destination file.
    spectra_csv = r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\training_spectra.csv"
    formula_csv = r"E:\code\WQ\封装\sub\水质参数.csv"
    enhanced_csv = r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\enhanced_data.csv"

    # Create the calculator instance.
    calculator = BandMathCalculator(spectra_csv)

    # Example 1: evaluate every formula in the formula file.
    # (The output path must be passed by keyword; the second positional
    # parameter is formula_names.)
    # result_df = calculator.process_formulas_from_csv(formula_csv, output_file="enhanced_data.csv")

    # Example 2: evaluate only the named formulas.
    result_df = calculator.process_formulas_from_csv(
        formula_csv,
        formula_names=["BGA_Am09KBBI", "BGA_Be162B643sub629"],
        output_file=enhanced_csv,
    )