import pandas as pd import numpy as np import matplotlib.pyplot as plt def plot_spectra(csv_file_path, chla_file_path): """ 读取CSV文件并根据chl-a指数是否为0将光谱分为两部分绘制 用红线标出650.88nm, 670.41nm, 706.54nm波长位置 :param csv_file_path: 原始光谱数据CSV文件路径 :param chla_file_path: 包含chl-a指数的CSV文件路径 """ # 读取原始光谱数据 df_spectral = pd.read_csv(csv_file_path) wavelengths = df_spectral.columns[1:].astype(float) # 假设从第2列开始是光谱数据 # 读取包含chl-a指数的结果文件 df_chla = pd.read_csv(chla_file_path) # 确保两个文件的行数匹配 if len(df_spectral) != len(df_chla): raise ValueError("光谱数据文件和水质参数文件行数不匹配") # 分离光谱数据 spectral_data = df_spectral.iloc[:, 1:].values.astype(float) # 获取chl-a列 chla_values = df_chla['chl-a'].values # 根据chl-a值创建分组 group1_indices = np.where(chla_values == 0)[0] # chl-a=0的索引 group2_indices = np.where(chla_values != 0)[0] # chl-a≠0的索引 # 创建两个图表 plt.figure(figsize=(15, 10)) # 定义要标记的波长位置 highlight_wavelengths = [650.88, 670.41, 706.54] colors = ['red', 'green', 'blue'] # 为每个波长使用不同的颜色 # 第一组:chl-a=0 ax1 = plt.subplot(2, 1, 1) if len(group1_indices) > 0: group1_data = spectral_data[group1_indices, :] for i in range(group1_data.shape[0]): plt.plot(wavelengths, group1_data[i, :], linewidth=1, alpha=0.7) mean_spectrum = np.mean(group1_data, axis=0) plt.plot(wavelengths, mean_spectrum, 'k-', linewidth=2.5, label='Average Spectrum') # 添加波长标记线 y_min, y_max = plt.ylim() for i, wave in enumerate(highlight_wavelengths): plt.axvline(x=wave, color=colors[i], linestyle='--', alpha=0.7, label=f'{wave}nm' if i == 0 else '') plt.text(wave, y_max * 0.95, f'{wave}nm', fontsize=10, color=colors[i], ha='center', va='top') plt.title(f'Spectral Data with chl-a = 0 (n={len(group1_indices)})') plt.ylabel('Reflectance') plt.legend() plt.grid(True, linestyle='--', alpha=0.3) # 第二组:chl-a≠0 ax2 = plt.subplot(2, 1, 2, sharex=ax1) # 共享x轴范围 if len(group2_indices) > 0: group2_data = spectral_data[group2_indices, :] for i in range(group2_data.shape[0]): plt.plot(wavelengths, group2_data[i, :], linewidth=1, alpha=0.7) mean_spectrum = np.mean(group2_data, axis=0) plt.plot(wavelengths, mean_spectrum, 'k-', linewidth=2.5, label='Average Spectrum') # 添加波长标记线 y_min, y_max = plt.ylim() for i, wave in enumerate(highlight_wavelengths): plt.axvline(x=wave, color=colors[i], linestyle='--', alpha=0.7, label=f'{wave}nm' if i == 0 else '') plt.text(wave, y_max * 0.95, f'{wave}nm', fontsize=10, color=colors[i], ha='center', va='top') plt.title(f'Spectral Data with chl-a ≠ 0 (n={len(group2_indices)})') plt.xlabel('Wavelength (nm)') plt.ylabel('Reflectance') plt.legend() plt.grid(True, linestyle='--', alpha=0.3) # 调整布局并保存 plt.tight_layout() plt.savefig( r'D:\WQ\zhanghuilai\hyperspectral-inversion\data\input\一代高光谱\2025-07-11_2025-07-21\plot\spectral_groups_highlighted.png', dpi=300 ) plt.show() # 使用示例 if __name__ == "__main__": spectral_csv = r"D:\WQ\zhanghuilai\hyperspectral-inversion\data\input\一代高光谱\2025-07-11_2025-07-21\筛选.csv" chla_csv = r"D:\WQ\zhanghuilai\hyperspectral-inversion\data\input\一代高光谱\2025-07-11_2025-07-21\不去耀斑index.csv" plot_spectra(spectral_csv, chla_csv)