Initial commit of WQ_GUI
This commit is contained in:
226
src/utils/band_math.py
Normal file
226
src/utils/band_math.py
Normal file
@ -0,0 +1,226 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import re
|
||||
|
||||
|
||||
class BandMathCalculator:
|
||||
def __init__(self, csv_file):
|
||||
"""
|
||||
初始化计算器
|
||||
csv_file: 包含光谱反射率的CSV文件路径
|
||||
"""
|
||||
self.df = pd.read_csv(csv_file)
|
||||
self.wavelengths = self._extract_wavelengths()
|
||||
|
||||
def _extract_wavelengths(self):
|
||||
"""从列名中提取波长信息"""
|
||||
wavelengths = []
|
||||
for col in self.df.columns:
|
||||
# 尝试从列名中提取数字(波长)
|
||||
numbers = re.findall(r'\d+\.?\d*', str(col))
|
||||
if numbers:
|
||||
wavelengths.append(float(numbers[0]))
|
||||
else:
|
||||
wavelengths.append(None)
|
||||
return wavelengths
|
||||
|
||||
def _find_closest_wavelength(self, target_wavelength):
|
||||
"""找到最接近目标波长的列索引"""
|
||||
valid_indices = [i for i, wl in enumerate(self.wavelengths) if wl is not None]
|
||||
if not valid_indices:
|
||||
raise ValueError("未找到有效的波长列")
|
||||
|
||||
# 计算与目标波长的差值
|
||||
differences = [abs(self.wavelengths[i] - target_wavelength) for i in valid_indices]
|
||||
min_diff_index = np.argmin(differences)
|
||||
closest_index = valid_indices[min_diff_index]
|
||||
closest_wavelength = self.wavelengths[closest_index]
|
||||
|
||||
print(
|
||||
f"目标波长 {target_wavelength}nm -> 最接近波长 {closest_wavelength}nm (列: {self.df.columns[closest_index]})")
|
||||
return closest_index
|
||||
|
||||
def _parse_expression(self, expression):
|
||||
"""解析表达式,提取所有波段变量 - 支持大小写"""
|
||||
# 匹配 w或W后面跟着数字的格式的变量
|
||||
pattern = r'[wW](\d+\.?\d*)'
|
||||
matches = re.findall(pattern, expression)
|
||||
return matches # 返回字符串列表,如 ['686', '672', '715', '672']
|
||||
|
||||
def _create_substitution_dict(self, variables, row_index=0):
|
||||
"""创建变量替换字典 - 支持大小写"""
|
||||
substitution_dict = {}
|
||||
for var in variables:
|
||||
wavelength = float(var) # 将字符串转换为浮点数
|
||||
col_index = self._find_closest_wavelength(wavelength)
|
||||
value = self.df.iloc[row_index, col_index]
|
||||
# 同时添加小写和大写版本的变量
|
||||
substitution_dict[f'w{var}'] = value
|
||||
substitution_dict[f'W{var}'] = value
|
||||
return substitution_dict
|
||||
|
||||
def calculate(self, expression, row_index=0):
|
||||
"""
|
||||
计算自定义波段表达式
|
||||
|
||||
参数:
|
||||
expression: 波段计算表达式,如 'chl=w560/w760'
|
||||
row_index: 要计算的数据行索引,默认为第0行
|
||||
|
||||
返回:
|
||||
计算结果
|
||||
"""
|
||||
try:
|
||||
# 提取表达式中的计算部分
|
||||
if '=' in expression:
|
||||
# 如果包含赋值,只取等号右边的计算部分
|
||||
calc_part = expression.split('=')[1].strip()
|
||||
var_name = expression.split('=')[0].strip()
|
||||
else:
|
||||
calc_part = expression.strip()
|
||||
var_name = None
|
||||
|
||||
# 解析变量
|
||||
variables = self._parse_expression(calc_part)
|
||||
print(f"解析到的波长变量: {variables}")
|
||||
|
||||
# 创建替换字典
|
||||
sub_dict = self._create_substitution_dict(variables, row_index)
|
||||
print(f"变量值: {sub_dict}")
|
||||
|
||||
# 替换表达式中的变量 - 使用安全的字符串替换
|
||||
calc_expression = calc_part
|
||||
for var_pattern, value in sub_dict.items():
|
||||
# 确保替换完整的变量名,避免部分匹配
|
||||
calc_expression = re.sub(r'\b' + re.escape(var_pattern) + r'\b', f"({value})", calc_expression)
|
||||
|
||||
print(f"计算表达式: {calc_expression}")
|
||||
|
||||
# 安全地计算表达式
|
||||
result = eval(calc_expression)
|
||||
|
||||
# 返回结果
|
||||
if var_name:
|
||||
return {var_name: result}
|
||||
else:
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(f"计算错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
def calculate_all_rows(self, expression):
|
||||
"""为所有行计算表达式"""
|
||||
results = []
|
||||
for i in range(len(self.df)):
|
||||
print(f"\n--- 计算第 {i} 行 ---")
|
||||
result = self.calculate(expression, i)
|
||||
if result is not None:
|
||||
if isinstance(result, dict):
|
||||
results.append(list(result.values())[0])
|
||||
else:
|
||||
results.append(result)
|
||||
else:
|
||||
# 如果计算失败,添加NaN值以保持结果数量一致
|
||||
results.append(np.nan)
|
||||
print(f"第 {i} 行计算失败,使用NaN填充")
|
||||
return results
|
||||
|
||||
def process_formulas_from_csv(self, formula_csv_file, formula_names=None, output_file=None):
|
||||
"""
|
||||
从公式CSV文件中批量计算并添加到数据文件中
|
||||
|
||||
参数:
|
||||
formula_csv_file: 公式CSV文件路径,第一列为公式名称,第三列为具体公式
|
||||
formula_names: 要计算的公式名称列表,如果为None则计算所有公式
|
||||
output_file: 输出文件路径,如果为None则自动生成
|
||||
|
||||
返回:
|
||||
包含计算结果的新DataFrame
|
||||
"""
|
||||
# 读取公式CSV文件
|
||||
try:
|
||||
formulas_df = pd.read_csv(formula_csv_file)
|
||||
print(f"读取到 {len(formulas_df)} 个公式")
|
||||
|
||||
# 检查CSV格式,假设第一列为公式名称,第三列为具体公式
|
||||
if len(formulas_df.columns) < 3:
|
||||
raise ValueError("公式CSV文件需要至少3列")
|
||||
|
||||
formula_name_col = formulas_df.columns[0] # 第一列:公式名称
|
||||
formula_expr_col = formulas_df.columns[2] # 第三列:具体公式
|
||||
|
||||
# 创建结果DataFrame的副本
|
||||
result_df = self.df.copy()
|
||||
|
||||
# 如果指定了公式名称,则只计算这些公式
|
||||
if formula_names is not None:
|
||||
if isinstance(formula_names, str):
|
||||
formula_names = [formula_names] # 转换为列表
|
||||
|
||||
# 筛选出指定的公式
|
||||
selected_formulas = formulas_df[formulas_df[formula_name_col].isin(formula_names)]
|
||||
print(f"找到 {len(selected_formulas)} 个指定的公式")
|
||||
|
||||
if len(selected_formulas) == 0:
|
||||
print(f"警告: 未找到指定的公式: {formula_names}")
|
||||
return result_df
|
||||
|
||||
formulas_to_process = selected_formulas
|
||||
else:
|
||||
# 计算所有公式
|
||||
formulas_to_process = formulas_df
|
||||
|
||||
# 为每个公式计算所有行
|
||||
for _, row in formulas_to_process.iterrows():
|
||||
formula_name = row[formula_name_col]
|
||||
formula_expr = row[formula_expr_col]
|
||||
|
||||
if pd.isna(formula_name) or pd.isna(formula_expr):
|
||||
print(f"跳过空公式: {row}")
|
||||
continue
|
||||
|
||||
print(f"\n计算公式: {formula_name} = {formula_expr}")
|
||||
|
||||
# 计算所有行的结果
|
||||
results = self.calculate_all_rows(formula_expr)
|
||||
|
||||
# 将结果添加到DataFrame
|
||||
result_df[formula_name] = results
|
||||
print(f"公式 '{formula_name}' 计算完成,添加到数据中")
|
||||
|
||||
# 保存结果
|
||||
if output_file is None:
|
||||
# 自动生成输出文件名
|
||||
import os
|
||||
base_name = os.path.splitext(os.path.basename(formula_csv_file))[0]
|
||||
output_file = f"band_math_results_{base_name}.csv"
|
||||
|
||||
result_df.to_csv(output_file, index=False)
|
||||
print(f"结果已保存到: {output_file}")
|
||||
|
||||
return result_df
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理公式CSV文件时出错: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
# 更新使用示例
|
||||
if __name__ == "__main__":
|
||||
# 创建计算器实例
|
||||
calculator = BandMathCalculator(r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\training_spectra.csv")
|
||||
|
||||
# 示例1: 计算所有公式
|
||||
# result_df = calculator.process_formulas_from_csv(r"E:\code\WQ\封装\sub\水质参数.csv", "enhanced_data.csv")
|
||||
|
||||
# 示例2: 计算指定公式
|
||||
result_df = calculator.process_formulas_from_csv(
|
||||
r"E:\code\WQ\封装\sub\水质参数.csv",
|
||||
formula_names=["BGA_Am09KBBI", "BGA_Be162B643sub629"],
|
||||
output_file=r"E:\code\WQ\pipeline_result\work_dir\5_training_spectra\enhanced_data.csv"
|
||||
)
|
||||
Reference in New Issue
Block a user