"""Batch pipeline: extract background-corrected, original, and background spectra
from hyperspectral BIL images of microplastic samples, and save them as CSV.

Per-file flow: HDR repair -> RGB preview -> segmentation mask -> per-particle
spectral features -> background spectrum -> background correction -> cleaning
-> per-file CSVs plus combined CSVs across all inputs.
"""
import os
import cv2
import matplotlib
import numpy as np
import argparse
import pandas as pd
from bil2rgb import process_bil_files
from mask import detect_microplastic_mask_from_array
from shape_spectral import process_images
from shape_spectral_background import process_images_background
import time
from tqdm import tqdm

matplotlib.use('TkAgg')

# Training-camera wavelengths (237 channels, nm).
TRAIN_WAVELENGTHS = [
    912.36, 915.68, 919, 922.31, 925.63, 928.95, 932.27, 935.59, 938.91, 942.23,
    945.55, 948.87, 952.18, 955.5, 958.82, 962.14, 965.46, 968.78, 972.1, 975.42,
    978.74, 982.06, 985.38, 988.7, 992.02, 995.34, 998.65, 1002, 1005.3, 1008.6,
    1011.9, 1015.3, 1018.6, 1021.9, 1025.2, 1028.5, 1031.9, 1035.2, 1038.5, 1041.8,
    1045.1, 1048.5, 1051.8, 1055.1, 1058.4, 1061.7, 1065.1, 1068.4, 1071.7, 1075,
    1078.3, 1081.7, 1085, 1088.3, 1091.6, 1094.9, 1098.3, 1101.6, 1104.9, 1108.2,
    1111.5, 1114.9, 1118.2, 1121.5, 1124.8, 1128.1, 1131.5, 1134.8, 1138.1, 1141.4,
    1144.8, 1148.1, 1151.4, 1154.7, 1158, 1161.4, 1164.7, 1168, 1171.3, 1174.6,
    1178, 1181.3, 1184.6, 1187.9, 1191.3, 1194.6, 1197.9, 1201.2, 1204.5, 1207.9,
    1211.2, 1214.5, 1217.8, 1221.2, 1224.5, 1227.8, 1231.1, 1234.4, 1237.8, 1241.1,
    1244.4, 1247.7, 1251.1, 1254.4, 1257.7, 1261, 1264.3, 1267.7, 1271, 1274.3,
    1277.6, 1281, 1284.3, 1287.6, 1290.9, 1294.2, 1297.6, 1300.9, 1304.2, 1307.5,
    1310.9, 1314.2, 1317.5, 1320.8, 1324.2, 1327.5, 1330.8, 1334.1, 1337.4, 1340.8,
    1344.1, 1347.4, 1350.7, 1354.1, 1357.4, 1360.7, 1364, 1367.4, 1370.7, 1374,
    1377.3, 1380.7, 1384, 1387.3, 1390.6, 1394, 1397.3, 1400.6, 1403.9, 1407.2,
    1410.6, 1413.9, 1417.2, 1420.5, 1423.9, 1427.2, 1430.5, 1433.8, 1437.2, 1440.5,
    1443.8, 1447.1, 1450.5, 1453.8, 1457.1, 1460.4, 1463.8, 1467.1, 1470.4, 1473.7,
    1477.1, 1480.4, 1483.7, 1487, 1490.4, 1493.7, 1497, 1500.3, 1503.7, 1507,
    1510.3, 1513.6, 1517, 1520.3, 1523.6, 1526.9, 1530.3, 1533.6, 1536.9, 1540.2,
    1543.6, 1546.9, 1550.2, 1553.6, 1556.9, 1560.2, 1563.5, 1566.9, 1570.2, 1573.5,
    1576.8, 1580.2, 1583.5, 1586.8, 1590.1, 1593.5, 1596.8, 1600.1, 1603.4, 1606.8,
    1610.1, 1613.4, 1616.7, 1620.1, 1623.4, 1626.7, 1630.1, 1633.4, 1636.7, 1640,
    1643.4, 1646.7, 1650, 1653.3, 1656.7, 1660, 1663.3, 1666.7, 1670, 1673.3,
    1676.6, 1680, 1683.3, 1686.6, 1689.9, 1693.3, 1696.6, 1699.9, 1703.3, 1706.6,
]

# Microplastic type -> integer class label.
PLASTIC_TYPE_MAPPING = {
    'ABS': 0,
    'HDPE': 1,
    'LDPE': 2,
    'PA6': 3,
    'PET': 4,
    'PP': 5,
    'PS': 6,
    'PTFE': 7,
    'PVC': 8,
}

# Reverse lookup (label -> type name); values in PLASTIC_TYPE_MAPPING are unique.
_LABEL_TO_PLASTIC = {v: k for k, v in PLASTIC_TYPE_MAPPING.items()}


def get_plastic_label_from_filename(filename):
    """Return the integer label for the first plastic-type name found in
    *filename* (case-insensitive substring match), or None if none matches.

    NOTE(review): dict order decides ties — e.g. 'HDPE' in a filename also
    contains no earlier key, but 'PP' matched inside a longer token would win
    over a later key; relies on PLASTIC_TYPE_MAPPING's insertion order.
    """
    filename_upper = filename.upper()
    for plastic_type, label in PLASTIC_TYPE_MAPPING.items():
        if plastic_type in filename_upper:
            return label
    return None


def read_wavelengths_from_hdr(bil_path):
    """Parse the 'wavelength = {...}' block from the ENVI .hdr next to
    *bil_path*. Returns a float64 array; empty array on any failure."""
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        return np.array([], dtype=np.float64)
    with open(hdr_path, 'r') as f:
        txt = f.read()
    if 'wavelength' not in txt:
        return np.array([], dtype=np.float64)
    # Slice out the brace-delimited list following the first 'wavelength' token.
    seg = txt.split('wavelength', 1)[1]
    seg = seg[seg.find('{') + 1: seg.find('}')]
    vals = [v.strip() for v in seg.split(',') if v.strip()]
    try:
        waves = np.array([float(v) for v in vals], dtype=np.float64)
    except Exception:
        waves = np.array([], dtype=np.float64)
    return waves


def resample_spectra_matrix(X, src_waves, dst_waves):
    """Linearly interpolate each row of *X* (n_samples x n_src_bands) from
    *src_waves* onto *dst_waves*. Out-of-range targets clamp to the endpoint
    values so no band is left missing. Returns *X* unchanged if either
    wavelength grid is empty."""
    src = np.asarray(src_waves, dtype=np.float64)
    dst = np.asarray(dst_waves, dtype=np.float64)
    X = np.asarray(X, dtype=np.float64)
    if src.size == 0 or dst.size == 0:
        return X
    out = np.empty((X.shape[0], dst.size), dtype=np.float64)
    for i in range(X.shape[0]):
        row = X[i]
        # Endpoint extrapolation avoids NaN/missing dimensions at the edges.
        out[i] = np.interp(dst, src, row, left=row[0], right=row[-1])
    return out


def read_hdr_file(bil_path):
    """Return (samples, lines) from the ENVI header next to *bil_path*;
    either may be None if the key is absent."""
    # Use splitext (consistent with the rest of the file) rather than
    # str.replace, which would corrupt paths containing '.bil' elsewhere.
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    with open(hdr_path, 'r') as f:
        header = f.readlines()
    samples, lines = None, None
    for line in header:
        if line.startswith('samples'):
            samples = int(line.split('=')[-1].strip())
        if line.startswith('lines'):
            lines = int(line.split('=')[-1].strip())
    return samples, lines


def change_hdr_file(bil_path, wavelengths=None):
    """Append a 'wavelength = {...}' field to the .hdr next to *bil_path*,
    but only when the header lacks one and *wavelengths* is provided."""
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        print(f"错误: 找不到对应的HDR文件: {hdr_path}")
        return
    # Only write when the wavelength field is missing.
    with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as file:
        content = file.read()
    if 'wavelength' in content:
        print(f"{os.path.basename(hdr_path)} 已包含 wavelength 字段,跳过追加。")
        return
    if wavelengths is None or len(wavelengths) == 0:
        print("HDR 缺少 wavelength,但未提供 wavelengths,跳过写入以避免错误。")
        return
    needs_newline = not content.endswith('\n')
    wavelength_info = "wavelength = {" + ", ".join(str(float(w)) for w in wavelengths) + "}\n"
    with open(hdr_path, 'a', encoding='utf-8', errors='ignore') as file:
        if needs_newline:
            file.write('\n')
        file.write(wavelength_info)
    print(f"已在 {os.path.basename(hdr_path)} 末尾追加 wavelength 字段。")


def validate_inputs(bil_path, output_dir):
    """Validate that the BIL/HDR pair exists, the output directory is
    creatable, and the cube has enough bands for RGB extraction.

    Raises FileNotFoundError / RuntimeError / (wrapped) ValueError on failure.
    """
    if not os.path.exists(bil_path):
        raise FileNotFoundError(f"BIL文件不存在: {bil_path}")
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        raise FileNotFoundError(f"HDR文件不存在: {hdr_path}")
    if output_dir and not os.path.exists(output_dir):
        try:
            os.makedirs(output_dir, exist_ok=True)
        except Exception as e:
            raise RuntimeError(f"无法创建输出目录: {output_dir}") from e
    try:
        from spectral.io import envi
        img = envi.open(hdr_path, bil_path)
        n_bands = img.nbands
        # bil2rgb reads band indices 9, 59, 159, so at least 160 bands needed.
        if n_bands <= 159:
            raise ValueError(f"BIL文件波段数不足: 需要至少160个波段,但只有{n_bands}个")
    except Exception as e:
        raise RuntimeError(f"无法读取BIL文件头信息: {bil_path}") from e


def generate_rgb(bil_path):
    """Render an RGB preview image from the BIL cube via bil2rgb."""
    try:
        rgb_img = process_bil_files(bil_path)
        return rgb_img
    except Exception as e:
        raise RuntimeError(f"生成RGB图像失败: bil_path={bil_path}") from e


def run_segmentation(rgb_img, segmentation_model_path=None):
    """Run particle segmentation on *rgb_img*; returns (mask, filter_mask)."""
    try:
        mask, filter_mask_original = detect_microplastic_mask_from_array(
            image=rgb_img,
            filter_method='threshold',
            diameter=None,
            flow_threshold=0.4,
            cellprob_threshold=-1,
            model_path=segmentation_model_path,
            detect_filter=True,
        )
        return mask, filter_mask_original
    except Exception as e:
        raise RuntimeError("分割失败: 无法检测微塑料颗粒") from e


def extract_primary_features(bil_path, mask):
    """Extract per-particle shape + spectral features as a DataFrame."""
    try:
        df = process_images(bil_path, mask)
        return df
    except Exception as e:
        raise RuntimeError(f"特征提取失败: bil_path={bil_path}") from e


def compute_background_spectrum(bil_path, mask):
    """Compute the background reference spectrum outside the particle mask."""
    try:
        bg_spectrum = process_images_background(bil_path, mask)
        return bg_spectrum
    except Exception as e:
        raise RuntimeError(f"背景光谱计算失败: bil_path={bil_path}") from e


def apply_background_correction(df, bg_spectrum):
    """Divide the spectral columns of *df* by the background spectrum.

    No resampling is performed; if the number of 'wavelength_*' columns and
    background bands differ, the trailing min-length slice of both is aligned.

    Returns (df_corrected, df_original_copy, bg_array).
    Raises ValueError when no 'wavelength_*' columns exist.
    NOTE(review): zero entries in the background produce inf/NaN rows that
    are later dropped by clean_and_select_columns — confirm this is intended.
    """
    spec_cols = [c for c in df.columns if isinstance(c, str) and c.startswith('wavelength_')]
    if not spec_cols:
        raise ValueError("未找到光谱列(以wavelength_开头的列)")
    df_original = df.copy()
    bg = np.asarray(bg_spectrum, dtype=np.float64).ravel()
    # Tail-align both sides to the shorter length to avoid shape mismatches.
    n = min(len(spec_cols), bg.shape[0])
    use_cols = spec_cols[-n:]
    df_corrected = df.copy()
    df_corrected.loc[:, use_cols] = df_corrected.loc[:, use_cols].div(bg[-n:], axis=1)
    return df_corrected, df_original, bg


def clean_and_select_columns(df):
    """Drop NaN rows, degenerate contours (<=1 point), and particles with
    area below 500 pixels."""
    df = df.dropna()
    # Non-list 'contour' values pass through untouched.
    df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
    df = df[df['area'] >= 500]
    return df


def rename_wavelength_columns(df, prefix=''):
    """Strip the 'wavelength_' prefix from spectral column names, leaving the
    bare wavelength value as the column name.

    *prefix* is accepted for backward compatibility but currently unused.
    """
    new_columns = {}
    for col in df.columns:
        if isinstance(col, str) and col.startswith('wavelength_'):
            new_columns[col] = col.replace('wavelength_', '')
    if new_columns:
        df = df.rename(columns=new_columns)
    return df


def save_spectra_to_csv(df_corrected, df_original, bg_spectrum, bil_path, output_dir,
                        plastic_label, all_corrected_data, all_original_data,
                        all_background_data):
    """Save three CSVs for one BIL file — corrected, original, and background
    spectra — and append each DataFrame to the corresponding accumulator list
    for later combined output."""
    base_name = os.path.splitext(os.path.basename(bil_path))[0]

    # Per-kind output subdirectories.
    corrected_dir = os.path.join(output_dir, 'corrected_spectra')
    original_dir = os.path.join(output_dir, 'original_spectra')
    background_dir = os.path.join(output_dir, 'background_spectra')
    os.makedirs(corrected_dir, exist_ok=True)
    os.makedirs(original_dir, exist_ok=True)
    os.makedirs(background_dir, exist_ok=True)

    spec_cols = [c for c in df_corrected.columns
                 if isinstance(c, str) and c.startswith('wavelength_')]

    # Column names after the 'wavelength_' prefix is removed.
    df_corrected_renamed = rename_wavelength_columns(df_corrected.copy())
    df_original_renamed = rename_wavelength_columns(df_original.copy())
    # Direct prefix strip (O(n)) — equivalent to diffing renamed vs. original
    # column sets, without the nested membership scan.
    wavelength_cols = [c.replace('wavelength_', '') for c in spec_cols]

    # --- corrected spectra ---
    df_corrected_out = df_corrected_renamed.copy()
    if plastic_label is not None:
        if len(df_corrected_out) > 0:
            non_spec_cols = [c for c in df_corrected_out.columns if c not in wavelength_cols]
            if non_spec_cols:
                # Overwrite the first non-spectral column with the class label.
                first_col = non_spec_cols[0]
                df_corrected_out[first_col] = plastic_label
    df_corrected_out.insert(0, 'source_file', base_name)
    corrected_path = os.path.join(corrected_dir, f"{base_name}_corrected.csv")
    df_corrected_out.to_csv(corrected_path, index=False)
    print(f" 背景校正光谱已保存: {corrected_path}")
    all_corrected_data.append(df_corrected_out)

    # --- original spectra ---
    df_original_out = df_original_renamed.copy()
    if plastic_label is not None:
        if len(df_original_out) > 0:
            non_spec_cols = [c for c in df_original_out.columns if c not in wavelength_cols]
            if non_spec_cols:
                first_col = non_spec_cols[0]
                df_original_out[first_col] = plastic_label
    df_original_out.insert(0, 'source_file', base_name)
    original_path = os.path.join(original_dir, f"{base_name}_original.csv")
    df_original_out.to_csv(original_path, index=False)
    print(f" 原始光谱已保存: {original_path}")
    all_original_data.append(df_original_out)

    # --- background spectrum ---
    # Tail-align wavelength names to the background length (prefix stripped).
    wavelength_names = (
        [col.replace('wavelength_', '') for col in spec_cols[-len(bg_spectrum):]]
        if len(spec_cols) >= len(bg_spectrum)
        else [col.replace('wavelength_', '') for col in spec_cols]
    )
    bg_df = pd.DataFrame({
        'wavelength': wavelength_names,
        'background_value': bg_spectrum,
    })
    bg_df.insert(0, 'source_file', base_name)
    bg_df.insert(1, 'plastic_type', plastic_label if plastic_label is not None else 'unknown')
    background_path = os.path.join(background_dir, f"{base_name}_background.csv")
    bg_df.to_csv(background_path, index=False)
    print(f" 背景光谱已保存: {background_path}")
    all_background_data.append(bg_df)


def save_combined_csv(all_corrected_data, all_original_data, all_background_data, output_dir):
    """Concatenate the accumulated per-file DataFrames and write one combined
    CSV per spectrum kind under <output_dir>/combined."""
    combined_dir = os.path.join(output_dir, 'combined')
    os.makedirs(combined_dir, exist_ok=True)

    if all_corrected_data:
        combined_corrected = pd.concat(all_corrected_data, ignore_index=True)
        corrected_combined_path = os.path.join(combined_dir, 'all_corrected_spectra.csv')
        combined_corrected.to_csv(corrected_combined_path, index=False)
        print(f"\n 合并背景校正光谱已保存: {corrected_combined_path}")
        print(f" 总行数: {len(combined_corrected)}")

    if all_original_data:
        combined_original = pd.concat(all_original_data, ignore_index=True)
        original_combined_path = os.path.join(combined_dir, 'all_original_spectra.csv')
        combined_original.to_csv(original_combined_path, index=False)
        print(f" 合并原始光谱已保存: {original_combined_path}")
        print(f" 总行数: {len(combined_original)}")

    if all_background_data:
        combined_background = pd.concat(all_background_data, ignore_index=True)
        background_combined_path = os.path.join(combined_dir, 'all_background_spectra.csv')
        combined_background.to_csv(background_combined_path, index=False)
        print(f" 合并背景光谱已保存: {background_combined_path}")
        print(f" 总行数: {len(combined_background)}")


def process_single_bil(bil_path, output_dir, segmentation_model_path=None,
                       all_corrected_data=None, all_original_data=None,
                       all_background_data=None):
    """Run the full pipeline on one BIL file; returns True on success.

    All exceptions are caught and reported so one bad file does not abort the
    batch.
    """
    try:
        print(f"\n处理文件: {bil_path}")

        # Derive the class label from the file name.
        filename = os.path.basename(bil_path)
        plastic_label = get_plastic_label_from_filename(filename)
        if plastic_label is not None:
            print(f" 检测到微塑料类型: {_LABEL_TO_PLASTIC[plastic_label]} -> {plastic_label}")
        else:
            print(f" 警告: 无法从文件名识别微塑料类型")

        validate_inputs(bil_path, output_dir)

        # Repair the HDR with the shared module constant instead of a
        # duplicated literal wavelength list (was a 237-value copy of
        # TRAIN_WAVELENGTHS).
        change_hdr_file(bil_path, TRAIN_WAVELENGTHS)

        print(" 生成RGB图像...")
        rgb_img = generate_rgb(bil_path)

        print(" 生成掩膜...")
        mask, filter_mask_original = run_segmentation(rgb_img, segmentation_model_path)

        print(" 提取光谱特征...")
        df = extract_primary_features(bil_path, mask)

        print(" 计算背景光谱并应用校正...")
        bg_spectrum = compute_background_spectrum(bil_path, mask)
        # Background correction without resampling.
        df_corrected, df_original, bg_spectrum_aligned = apply_background_correction(df, bg_spectrum)

        print(" 清理数据...")
        df_corrected = clean_and_select_columns(df_corrected)
        df_original = clean_and_select_columns(df_original)

        # Save the three spectrum kinds and feed the accumulator lists.
        save_spectra_to_csv(df_corrected, df_original, bg_spectrum_aligned,
                            bil_path, output_dir, plastic_label,
                            all_corrected_data, all_original_data, all_background_data)

        # Fixed: previously printed the literal placeholder '(unknown)'.
        print(f" 处理完成: {bil_path}")
        return True
    except Exception as e:
        print(f" 处理失败: {bil_path} - {e}")
        return False


def parse_arguments():
    """Parse command-line arguments: --input_dir and --output_dir."""
    parser = argparse.ArgumentParser(
        description='批量处理高光谱图像,提取并保存背景校正光谱、原始光谱和背景光谱'
    )
    parser.add_argument('--input_dir', required=True,
                        help='包含输入BIL文件的文件夹路径')
    parser.add_argument('--output_dir', required=True,
                        help='保存CSV输出结果的文件夹路径')
    return parser.parse_args()


def main():
    """Batch entry point: process every .bil in --input_dir, write per-file
    and combined CSVs into --output_dir, and print a timing summary."""
    args = parse_arguments()
    input_dir = args.input_dir
    output_dir = args.output_dir

    total_start_time = time.time()

    if not os.path.exists(input_dir):
        print(f"错误: 输入目录不存在: {input_dir}")
        return
    os.makedirs(output_dir, exist_ok=True)

    bil_files = [f for f in os.listdir(input_dir) if f.endswith('.bil')]
    bil_files.sort()
    if not bil_files:
        print(f"警告: 在 {input_dir} 中未找到BIL文件")
        return

    print(f"\n{'=' * 60}")
    print(f"找到 {len(bil_files)} 个BIL文件需要处理")
    print(f"{'=' * 60}")

    # Accumulators shared across files for the combined CSVs.
    all_corrected_data = []
    all_original_data = []
    all_background_data = []

    success_count = 0
    fail_count = 0
    for bil_file in tqdm(bil_files, desc="处理进度", unit="文件"):
        bil_path = os.path.join(input_dir, bil_file)
        if process_single_bil(bil_path, output_dir,
                              all_corrected_data=all_corrected_data,
                              all_original_data=all_original_data,
                              all_background_data=all_background_data):
            success_count += 1
        else:
            fail_count += 1

    if all_corrected_data or all_original_data or all_background_data:
        print(f"\n{'=' * 60}")
        print("正在生成合并的CSV文件...")
        save_combined_csv(all_corrected_data, all_original_data,
                          all_background_data, output_dir)

    total_time = time.time() - total_start_time

    print(f"\n{'=' * 60}")
    print("处理完成总结")
    print(f"{'=' * 60}")
    print(f"成功处理: {success_count} 个文件")
    print(f"失败: {fail_count} 个文件")
    print(f"总耗时: {total_time:.2f} 秒")
    print(f"平均每个文件: {total_time / len(bil_files):.2f} 秒")
    print(f"{'=' * 60}")
    print(f"结果已保存至: {output_dir}")
    print(f" - 单独文件:")
    print(f" - 背景校正光谱: {os.path.join(output_dir, 'corrected_spectra')}")
    print(f" - 原始光谱: {os.path.join(output_dir, 'original_spectra')}")
    print(f" - 背景光谱: {os.path.join(output_dir, 'background_spectra')}")
    print(f" - 合并文件:")
    print(f" - 合并后的光谱数据: {os.path.join(output_dir, 'combined')}")


if __name__ == "__main__":
    main()