diff --git a/__pycache__/bil2rgb.cpython-312.pyc b/__pycache__/bil2rgb.cpython-312.pyc new file mode 100644 index 0000000..92feb4e Binary files /dev/null and b/__pycache__/bil2rgb.cpython-312.pyc differ diff --git a/__pycache__/mask.cpython-312.pyc b/__pycache__/mask.cpython-312.pyc new file mode 100644 index 0000000..a62dcd5 Binary files /dev/null and b/__pycache__/mask.cpython-312.pyc differ diff --git a/fliter_sample_spectral.py b/fliter_sample_spectral.py index b844b06..20f4453 100644 --- a/fliter_sample_spectral.py +++ b/fliter_sample_spectral.py @@ -20,7 +20,7 @@ from main import ( ) matplotlib.use('TkAgg') - +#####用于提取背景滤纸的样本,目的是在训练时加入滤纸的光谱,以减少滤纸与ftpe的误判 def apply_background_and_optional_resample_for_samples(df, bg_spectrum, bil_path): # 先做背景校正(自动识别以 wavelength_ 或 band_ 开头的光谱列,且长度不一致时尾部对齐) diff --git a/only_mask.py b/only_mask.py index 9e4a339..90d5538 100644 --- a/only_mask.py +++ b/only_mask.py @@ -107,8 +107,7 @@ def change_hdr_file(bil_path): def main(): bil_path = "D:\Data\MPData7.bil" - output_path = r'E:\plastic\plastic\output\20251113\数据增强\mpdata7.bil' - model_path = r"E:\plastic\plastic\output\20251113\一阶导数\catboost.m" + # 处理BIL文件生成RGB图像 print("Processing BIL file to generate RGB image...\n") diff --git a/shape_spectral.py b/shape_spectral.py index 49315e1..bd65240 100644 --- a/shape_spectral.py +++ b/shape_spectral.py @@ -390,9 +390,11 @@ def process_images(full_bil_path, mask, outdir='None', debug="None"): return combined_data -# # # 示例:批量处理指定路径下的光谱图像和掩膜文件 -# bil_path = r'D:\WQ\test\Traindata-05' -# mask_path = r'D:\WQ\test\mask' -# outdir = r"D:\WQ\test" # 输出文件夹路径 -# -# process_images(bil_path, mask_path, outdir) \ No newline at end of file +if __name__ == "__main__": + + # # 示例:批量处理指定路径下的光谱图像和掩膜文件 + bil_path = r'D:\WQ\test\Traindata-05' + mask_path = r'D:\WQ\test\mask' + outdir = r"D:\WQ\test" # 输出文件夹路径 + + process_images(bil_path, mask_path, outdir) \ No newline at end of file diff --git a/train_sample.py b/train_sample.py new file mode 100644 
import os
import cv2
import matplotlib
import numpy as np
import argparse
import pandas as pd
from bil2rgb import process_bil_files
from mask import detect_microplastic_mask_from_array
from shape_spectral import process_images
from shape_spectral_background import process_images_background
import time
from tqdm import tqdm

matplotlib.use('TkAgg')

# Training-camera wavelengths in nm (237 channels). Also written into HDR
# files that lack a ``wavelength`` field (see change_hdr_file) so downstream
# tools can map band indices to wavelengths.
TRAIN_WAVELENGTHS = [
    912.36, 915.68, 919, 922.31, 925.63, 928.95, 932.27, 935.59, 938.91, 942.23,
    945.55, 948.87, 952.18, 955.5, 958.82, 962.14, 965.46, 968.78, 972.1, 975.42,
    978.74, 982.06, 985.38, 988.7, 992.02, 995.34, 998.65, 1002, 1005.3, 1008.6,
    1011.9, 1015.3, 1018.6, 1021.9, 1025.2, 1028.5, 1031.9, 1035.2, 1038.5, 1041.8,
    1045.1, 1048.5, 1051.8, 1055.1, 1058.4, 1061.7, 1065.1, 1068.4, 1071.7, 1075,
    1078.3, 1081.7, 1085, 1088.3, 1091.6, 1094.9, 1098.3, 1101.6, 1104.9, 1108.2,
    1111.5, 1114.9, 1118.2, 1121.5, 1124.8, 1128.1, 1131.5, 1134.8, 1138.1, 1141.4,
    1144.8, 1148.1, 1151.4, 1154.7, 1158, 1161.4, 1164.7, 1168, 1171.3, 1174.6,
    1178, 1181.3, 1184.6, 1187.9, 1191.3, 1194.6, 1197.9, 1201.2, 1204.5, 1207.9,
    1211.2, 1214.5, 1217.8, 1221.2, 1224.5, 1227.8, 1231.1, 1234.4, 1237.8, 1241.1,
    1244.4, 1247.7, 1251.1, 1254.4, 1257.7, 1261, 1264.3, 1267.7, 1271, 1274.3,
    1277.6, 1281, 1284.3, 1287.6, 1290.9, 1294.2, 1297.6, 1300.9, 1304.2, 1307.5,
    1310.9, 1314.2, 1317.5, 1320.8, 1324.2, 1327.5, 1330.8, 1334.1, 1337.4, 1340.8,
    1344.1, 1347.4, 1350.7, 1354.1, 1357.4, 1360.7, 1364, 1367.4, 1370.7, 1374,
    1377.3, 1380.7, 1384, 1387.3, 1390.6, 1394, 1397.3, 1400.6, 1403.9, 1407.2,
    1410.6, 1413.9, 1417.2, 1420.5, 1423.9, 1427.2, 1430.5, 1433.8, 1437.2, 1440.5,
    1443.8, 1447.1, 1450.5, 1453.8, 1457.1, 1460.4, 1463.8, 1467.1, 1470.4, 1473.7,
    1477.1, 1480.4, 1483.7, 1487, 1490.4, 1493.7, 1497, 1500.3, 1503.7, 1507,
    1510.3, 1513.6, 1517, 1520.3, 1523.6, 1526.9, 1530.3, 1533.6, 1536.9, 1540.2,
    1543.6, 1546.9, 1550.2, 1553.6, 1556.9, 1560.2, 1563.5, 1566.9, 1570.2, 1573.5,
    1576.8, 1580.2, 1583.5, 1586.8, 1590.1, 1593.5, 1596.8, 1600.1, 1603.4, 1606.8,
    1610.1, 1613.4, 1616.7, 1620.1, 1623.4, 1626.7, 1630.1, 1633.4, 1636.7, 1640,
    1643.4, 1646.7, 1650, 1653.3, 1656.7, 1660, 1663.3, 1666.7, 1670, 1673.3,
    1676.6, 1680, 1683.3, 1686.6, 1689.9, 1693.3, 1696.6, 1699.9, 1703.3, 1706.6,
]

# Mapping from microplastic polymer name to integer class label.
PLASTIC_TYPE_MAPPING = {
    'ABS': 0,
    'HDPE': 1,
    'LDPE': 2,
    'PA6': 3,
    'PET': 4,
    'PP': 5,
    'PS': 6,
    'PTFE': 7,
    'PVC': 8,
}


def get_plastic_label_from_filename(filename):
    """Return the integer label for the first plastic-type name contained in
    *filename* (case-insensitive substring match), or ``None`` if no known
    type name appears.
    """
    filename_upper = filename.upper()
    for plastic_type, label in PLASTIC_TYPE_MAPPING.items():
        if plastic_type in filename_upper:
            return label
    return None


def read_wavelengths_from_hdr(bil_path):
    """Parse the ``wavelength = {...}`` field from the HDR file that
    accompanies *bil_path*.

    Returns an empty float64 array when the HDR is missing, has no
    wavelength field, is malformed, or the values cannot be parsed.
    """
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        return np.array([], dtype=np.float64)
    with open(hdr_path, 'r') as f:
        txt = f.read()
    if 'wavelength' not in txt:
        return np.array([], dtype=np.float64)
    seg = txt.split('wavelength', 1)[1]
    # Robustness fix: the original sliced seg[find('{')+1:find('}')] without
    # checking that both braces exist; guard against a malformed field.
    lo = seg.find('{')
    hi = seg.find('}')
    if lo == -1 or hi == -1 or hi <= lo:
        return np.array([], dtype=np.float64)
    vals = [v.strip() for v in seg[lo + 1:hi].split(',') if v.strip()]
    try:
        return np.array([float(v) for v in vals], dtype=np.float64)
    except ValueError:
        return np.array([], dtype=np.float64)


def resample_spectra_matrix(X, src_waves, dst_waves):
    """Linearly resample each row of spectra matrix *X* from the *src_waves*
    grid onto the *dst_waves* grid.

    Out-of-range targets are clamped to the endpoint values so no channel is
    left undefined. If either wavelength grid is empty, *X* is returned
    (as float64) unchanged.
    """
    src = np.asarray(src_waves, dtype=np.float64)
    dst = np.asarray(dst_waves, dtype=np.float64)
    X = np.asarray(X, dtype=np.float64)
    if src.size == 0 or dst.size == 0:
        return X
    out = np.empty((X.shape[0], dst.size), dtype=np.float64)
    for i in range(X.shape[0]):
        row = X[i]
        out[i] = np.interp(dst, src, row, left=row[0], right=row[-1])
    return out


def read_hdr_file(bil_path):
    """Read the ``samples`` and ``lines`` fields from the companion HDR file.

    Returns a ``(samples, lines)`` tuple; either element is ``None`` when its
    field is absent from the header.
    """
    # Fix: build the HDR path with splitext like the sibling helpers do;
    # str.replace('.bil', '.hdr') would corrupt paths containing '.bil'
    # anywhere other than the extension.
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    with open(hdr_path, 'r') as f:
        header = f.readlines()

    samples, lines = None, None
    for line in header:
        if line.startswith('samples'):
            samples = int(line.split('=')[-1].strip())
        if line.startswith('lines'):
            lines = int(line.split('=')[-1].strip())

    return samples, lines


def change_hdr_file(bil_path, wavelengths=None):
    """Append a ``wavelength = {...}`` field to the companion HDR file when it
    is missing.

    Does nothing (with a console message) when the HDR cannot be found,
    already contains a wavelength field, or *wavelengths* is not provided.
    """
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        print(f"错误: 找不到对应的HDR文件: {hdr_path}")
        return

    # Only write when the wavelength field is actually missing.
    with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as file:
        content = file.read()

    if 'wavelength' in content:
        print(f"{os.path.basename(hdr_path)} 已包含 wavelength 字段,跳过追加。")
        return

    if wavelengths is None or len(wavelengths) == 0:
        print("HDR 缺少 wavelength,但未提供 wavelengths,跳过写入以避免错误。")
        return

    # Keep the appended field on its own line even if the file lacks a
    # trailing newline.
    needs_newline = not content.endswith('\n')
    wavelength_info = "wavelength = {" + ", ".join(str(float(w)) for w in wavelengths) + "}\n"

    with open(hdr_path, 'a', encoding='utf-8', errors='ignore') as file:
        if needs_newline:
            file.write('\n')
        file.write(wavelength_info)

    print(f"已在 {os.path.basename(hdr_path)} 末尾追加 wavelength 字段。")


def validate_inputs(bil_path, output_dir):
    """Validate that the BIL/HDR pair exists, the output directory is
    creatable, and the cube has enough bands for RGB extraction.

    Raises FileNotFoundError / RuntimeError / ValueError on failure.
    """
    if not os.path.exists(bil_path):
        raise FileNotFoundError(f"BIL文件不存在: {bil_path}")

    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        raise FileNotFoundError(f"HDR文件不存在: {hdr_path}")

    if output_dir and not os.path.exists(output_dir):
        try:
            os.makedirs(output_dir, exist_ok=True)
        except Exception as e:
            raise RuntimeError(f"无法创建输出目录: {output_dir}") from e

    try:
        from spectral.io import envi
        img = envi.open(hdr_path, bil_path)
        n_bands = img.nbands
        # bil2rgb reads band indices 9, 59 and 159, so at least 160 bands
        # are required.
        if n_bands <= 159:
            raise ValueError(f"BIL文件波段数不足: 需要至少160个波段,但只有{n_bands}个")
    except Exception as e:
        raise RuntimeError(f"无法读取BIL文件头信息: {bil_path}") from e


def generate_rgb(bil_path):
    """Render an RGB preview image from the BIL cube via bil2rgb."""
    try:
        return process_bil_files(bil_path)
    except Exception as e:
        raise RuntimeError(f"生成RGB图像失败: bil_path={bil_path}") from e


def run_segmentation(rgb_img, segmentation_model_path=None):
    """Run particle segmentation on *rgb_img*.

    Returns ``(mask, filter_mask_original)`` from the detector.
    """
    try:
        mask, filter_mask_original = detect_microplastic_mask_from_array(
            image=rgb_img,
            filter_method='threshold',
            diameter=None,
            flow_threshold=0.4,
            cellprob_threshold=-1,
            model_path=segmentation_model_path,
            detect_filter=True
        )
        return mask, filter_mask_original
    except Exception as e:
        raise RuntimeError("分割失败: 无法检测微塑料颗粒") from e


def extract_primary_features(bil_path, mask):
    """Extract per-particle spectral features as a DataFrame."""
    try:
        return process_images(bil_path, mask)
    except Exception as e:
        raise RuntimeError(f"特征提取失败: bil_path={bil_path}") from e


def compute_background_spectrum(bil_path, mask):
    """Compute the mean background spectrum outside the particle mask."""
    try:
        return process_images_background(bil_path, mask)
    except Exception as e:
        raise RuntimeError(f"背景光谱计算失败: bil_path={bil_path}") from e


def apply_background_correction(df, bg_spectrum):
    """Divide every ``wavelength_*`` column of *df* by the background
    spectrum, without resampling.

    When the number of spectral columns differs from the background length,
    the two are tail-aligned to the shorter length.

    Returns ``(df_corrected, df_original, bg)`` where *df_original* is an
    untouched copy of the input and *bg* the background as a flat float64
    array.
    """
    spec_cols = [c for c in df.columns if isinstance(c, str) and c.startswith('wavelength_')]

    if not spec_cols:
        raise ValueError("未找到光谱列(以wavelength_开头的列)")

    # Preserve the uncorrected spectra for separate export.
    df_original = df.copy()

    bg = np.asarray(bg_spectrum, dtype=np.float64).ravel()

    # Tail-align to the shorter length to avoid a shape mismatch.
    n = min(len(spec_cols), bg.shape[0])
    use_cols = spec_cols[-n:]

    df_corrected = df.copy()
    df_corrected.loc[:, use_cols] = df_corrected.loc[:, use_cols].div(bg[-n:], axis=1)

    return df_corrected, df_original, bg


def clean_and_select_columns(df, min_area=500):
    """Drop NaN rows, samples with degenerate contours, and samples whose
    area is below *min_area* pixels.

    *min_area* was previously a hard-coded constant; the default keeps the
    original behavior.
    """
    df = df.dropna()

    # Drop samples whose contour has at most one point (degenerate shape).
    df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]

    # Drop samples smaller than the area threshold.
    df = df[df['area'] >= min_area]

    return df


def rename_wavelength_columns(df, prefix=''):
    """Strip the ``wavelength_`` prefix from spectral column names, replacing
    it with *prefix* (empty by default, i.e. bare wavelength values).

    Fix: the original accepted *prefix* but never used it; it is now honored,
    with the default preserving the previous behavior.
    """
    new_columns = {}
    for col in df.columns:
        if isinstance(col, str) and col.startswith('wavelength_'):
            new_columns[col] = prefix + col.replace('wavelength_', '')
    if new_columns:
        df = df.rename(columns=new_columns)
    return df


def _attach_label(df_out, wavelength_cols, plastic_label):
    """Write *plastic_label* into the first non-spectral column of *df_out*
    in place (no-op when the label is None, the frame is empty, or there is
    no non-spectral column)."""
    if plastic_label is not None and len(df_out) > 0:
        non_spec_cols = [c for c in df_out.columns if c not in wavelength_cols]
        if non_spec_cols:
            df_out[non_spec_cols[0]] = plastic_label


def save_spectra_to_csv(df_corrected, df_original, bg_spectrum, bil_path, output_dir, plastic_label,
                        all_corrected_data, all_original_data, all_background_data):
    """Save three spectrum variants as per-file CSVs and collect them for the
    later combined export:

    - background-corrected spectra   -> <output_dir>/corrected_spectra/
    - original (uncorrected) spectra -> <output_dir>/original_spectra/
    - the background spectrum itself -> <output_dir>/background_spectra/

    Each frame gets a ``source_file`` column so rows remain traceable after
    concatenation.
    """
    base_name = os.path.splitext(os.path.basename(bil_path))[0]

    # Per-variant output subdirectories.
    corrected_dir = os.path.join(output_dir, 'corrected_spectra')
    original_dir = os.path.join(output_dir, 'original_spectra')
    background_dir = os.path.join(output_dir, 'background_spectra')

    os.makedirs(corrected_dir, exist_ok=True)
    os.makedirs(original_dir, exist_ok=True)
    os.makedirs(background_dir, exist_ok=True)

    spec_cols = [c for c in df_corrected.columns if isinstance(c, str) and c.startswith('wavelength_')]

    # Strip the 'wavelength_' prefix from spectral column names.
    df_corrected_renamed = rename_wavelength_columns(df_corrected.copy())
    df_original_renamed = rename_wavelength_columns(df_original.copy())

    # Spectral columns after renaming = renamed columns not among the
    # original non-spectral columns.
    non_spec_original = [col for col in df_corrected.columns
                         if isinstance(col, str) and not col.startswith('wavelength_')]
    wavelength_cols = [c for c in df_corrected_renamed.columns if c not in non_spec_original]

    # --- background-corrected spectra ---
    df_corrected_out = df_corrected_renamed.copy()
    _attach_label(df_corrected_out, wavelength_cols, plastic_label)
    df_corrected_out.insert(0, 'source_file', base_name)

    corrected_path = os.path.join(corrected_dir, f"{base_name}_corrected.csv")
    df_corrected_out.to_csv(corrected_path, index=False)
    print(f"  背景校正光谱已保存: {corrected_path}")
    all_corrected_data.append(df_corrected_out)

    # --- original spectra ---
    df_original_out = df_original_renamed.copy()
    _attach_label(df_original_out, wavelength_cols, plastic_label)
    df_original_out.insert(0, 'source_file', base_name)

    original_path = os.path.join(original_dir, f"{base_name}_original.csv")
    df_original_out.to_csv(original_path, index=False)
    print(f"  原始光谱已保存: {original_path}")
    all_original_data.append(df_original_out)

    # --- background spectrum ---
    # Tail-align column names with the background length, mirroring
    # apply_background_correction.
    tail = spec_cols[-len(bg_spectrum):] if len(spec_cols) >= len(bg_spectrum) else spec_cols
    wavelength_names = [col.replace('wavelength_', '') for col in tail]

    bg_df = pd.DataFrame({
        'wavelength': wavelength_names,
        'background_value': bg_spectrum
    })
    bg_df.insert(0, 'source_file', base_name)
    bg_df.insert(1, 'plastic_type', plastic_label if plastic_label is not None else 'unknown')

    background_path = os.path.join(background_dir, f"{base_name}_background.csv")
    bg_df.to_csv(background_path, index=False)
    print(f"  背景光谱已保存: {background_path}")
    all_background_data.append(bg_df)


def save_combined_csv(all_corrected_data, all_original_data, all_background_data, output_dir):
    """Concatenate the collected per-file frames and write one combined CSV
    per spectrum variant under <output_dir>/combined/."""
    combined_dir = os.path.join(output_dir, 'combined')
    os.makedirs(combined_dir, exist_ok=True)

    if all_corrected_data:
        combined_corrected = pd.concat(all_corrected_data, ignore_index=True)
        corrected_combined_path = os.path.join(combined_dir, 'all_corrected_spectra.csv')
        combined_corrected.to_csv(corrected_combined_path, index=False)
        print(f"\n  合并背景校正光谱已保存: {corrected_combined_path}")
        print(f"    总行数: {len(combined_corrected)}")

    if all_original_data:
        combined_original = pd.concat(all_original_data, ignore_index=True)
        original_combined_path = os.path.join(combined_dir, 'all_original_spectra.csv')
        combined_original.to_csv(original_combined_path, index=False)
        print(f"  合并原始光谱已保存: {original_combined_path}")
        print(f"    总行数: {len(combined_original)}")

    if all_background_data:
        combined_background = pd.concat(all_background_data, ignore_index=True)
        background_combined_path = os.path.join(combined_dir, 'all_background_spectra.csv')
        combined_background.to_csv(background_combined_path, index=False)
        print(f"  合并背景光谱已保存: {background_combined_path}")
        print(f"    总行数: {len(combined_background)}")


def process_single_bil(bil_path, output_dir, segmentation_model_path=None,
                       all_corrected_data=None, all_original_data=None, all_background_data=None):
    """Run the full pipeline for one BIL file: validate, patch HDR, render
    RGB, segment, extract spectra, background-correct, clean and export.

    Returns True on success, False on failure (the error is printed, not
    raised, so batch processing continues).
    """
    try:
        print(f"\n处理文件: {bil_path}")

        # Derive the class label from the file name.
        filename = os.path.basename(bil_path)
        plastic_label = get_plastic_label_from_filename(filename)
        if plastic_label is not None:
            print(f"  检测到微塑料类型: {list(PLASTIC_TYPE_MAPPING.keys())[list(PLASTIC_TYPE_MAPPING.values()).index(plastic_label)]} -> {plastic_label}")
        else:
            print(f"  警告: 无法从文件名识别微塑料类型")

        validate_inputs(bil_path, output_dir)

        # Fix: the original re-declared a byte-identical copy of the
        # 237-value wavelength list here; use the module constant instead.
        change_hdr_file(bil_path, TRAIN_WAVELENGTHS)

        print("  生成RGB图像...")
        rgb_img = generate_rgb(bil_path)

        print("  生成掩膜...")
        mask, filter_mask_original = run_segmentation(rgb_img, segmentation_model_path)

        print("  提取光谱特征...")
        df = extract_primary_features(bil_path, mask)

        print("  计算背景光谱并应用校正...")
        bg_spectrum = compute_background_spectrum(bil_path, mask)

        # Background correction without resampling.
        df_corrected, df_original, bg_spectrum_aligned = apply_background_correction(df, bg_spectrum)

        print("  清理数据...")
        df_corrected = clean_and_select_columns(df_corrected)
        df_original = clean_and_select_columns(df_original)

        # Save the three variants and collect them for the combined export.
        save_spectra_to_csv(df_corrected, df_original, bg_spectrum_aligned, bil_path, output_dir, plastic_label,
                            all_corrected_data, all_original_data, all_background_data)

        # Fix: the original f-string had no placeholder; report the file.
        print(f"  处理完成: {bil_path}")
        return True

    except Exception as e:
        print(f"  处理失败: {bil_path} - {e}")
        return False


def parse_arguments():
    """Parse the command-line arguments (input and output directories)."""
    parser = argparse.ArgumentParser(
        description='批量处理高光谱图像,提取并保存背景校正光谱、原始光谱和背景光谱'
    )
    parser.add_argument('--input_dir', required=True, help='包含输入BIL文件的文件夹路径')
    parser.add_argument('--output_dir', required=True, help='保存CSV输出结果的文件夹路径')
    return parser.parse_args()


def main():
    """Batch-process every ``*.bil`` file in the input directory, then write
    combined CSVs and a timing summary."""
    args = parse_arguments()

    input_dir = args.input_dir
    output_dir = args.output_dir

    total_start_time = time.time()

    if not os.path.exists(input_dir):
        print(f"错误: 输入目录不存在: {input_dir}")
        return

    os.makedirs(output_dir, exist_ok=True)

    # Deterministic processing order.
    bil_files = sorted(f for f in os.listdir(input_dir) if f.endswith('.bil'))

    if not bil_files:
        print(f"警告: 在 {input_dir} 中未找到BIL文件")
        return

    print(f"\n{'=' * 60}")
    print(f"找到 {len(bil_files)} 个BIL文件需要处理")
    print(f"{'=' * 60}")

    # Accumulators for the combined CSV export.
    all_corrected_data = []
    all_original_data = []
    all_background_data = []

    success_count = 0
    fail_count = 0

    for bil_file in tqdm(bil_files, desc="处理进度", unit="文件"):
        bil_path = os.path.join(input_dir, bil_file)
        if process_single_bil(bil_path, output_dir,
                              all_corrected_data=all_corrected_data,
                              all_original_data=all_original_data,
                              all_background_data=all_background_data):
            success_count += 1
        else:
            fail_count += 1

    if all_corrected_data or all_original_data or all_background_data:
        print(f"\n{'=' * 60}")
        print("正在生成合并的CSV文件...")
        save_combined_csv(all_corrected_data, all_original_data, all_background_data, output_dir)

    total_time = time.time() - total_start_time

    print(f"\n{'=' * 60}")
    print("处理完成总结")
    print(f"{'=' * 60}")
    print(f"成功处理: {success_count} 个文件")
    print(f"失败: {fail_count} 个文件")
    print(f"总耗时: {total_time:.2f} 秒")
    print(f"平均每个文件: {total_time / len(bil_files):.2f} 秒")
    print(f"{'=' * 60}")
    print(f"结果已保存至: {output_dir}")
    print(f"  - 单独文件:")
    print(f"    - 背景校正光谱: {os.path.join(output_dir, 'corrected_spectra')}")
    print(f"    - 原始光谱: {os.path.join(output_dir, 'original_spectra')}")
    print(f"    - 背景光谱: {os.path.join(output_dir, 'background_spectra')}")
    print(f"  - 合并文件:")
    print(f"    - 合并后的光谱数据: {os.path.join(output_dir, 'combined')}")


if __name__ == "__main__":
    main()