增加生成训练数据
This commit is contained in:
517
train_sample.py
Normal file
517
train_sample.py
Normal file
@ -0,0 +1,517 @@
|
||||
import os
|
||||
import cv2
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
import argparse
|
||||
import pandas as pd
|
||||
from bil2rgb import process_bil_files
|
||||
from mask import detect_microplastic_mask_from_array
|
||||
from shape_spectral import process_images
|
||||
from shape_spectral_background import process_images_background
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
|
||||
matplotlib.use('TkAgg')
|
||||
|
||||
# 训练相机波长(237通道)
|
||||
TRAIN_WAVELENGTHS = [912.36, 915.68, 919, 922.31, 925.63, 928.95, 932.27, 935.59, 938.91, 942.23, 945.55, 948.87, 952.18, 955.5, 958.82, 962.14, 965.46, 968.78, 972.1, 975.42, 978.74, 982.06, 985.38, 988.7, 992.02, 995.34, 998.65, 1002, 1005.3, 1008.6, 1011.9, 1015.3, 1018.6, 1021.9, 1025.2, 1028.5, 1031.9, 1035.2, 1038.5, 1041.8, 1045.1, 1048.5, 1051.8, 1055.1, 1058.4, 1061.7, 1065.1, 1068.4, 1071.7, 1075, 1078.3, 1081.7, 1085, 1088.3, 1091.6, 1094.9, 1098.3, 1101.6, 1104.9, 1108.2, 1111.5, 1114.9, 1118.2, 1121.5, 1124.8, 1128.1, 1131.5, 1134.8, 1138.1, 1141.4, 1144.8, 1148.1, 1151.4, 1154.7, 1158, 1161.4, 1164.7, 1168, 1171.3, 1174.6, 1178, 1181.3, 1184.6, 1187.9, 1191.3, 1194.6, 1197.9, 1201.2, 1204.5, 1207.9, 1211.2, 1214.5, 1217.8, 1221.2, 1224.5, 1227.8, 1231.1, 1234.4, 1237.8, 1241.1, 1244.4, 1247.7, 1251.1, 1254.4, 1257.7, 1261, 1264.3, 1267.7, 1271, 1274.3, 1277.6, 1281, 1284.3, 1287.6, 1290.9, 1294.2, 1297.6, 1300.9, 1304.2, 1307.5, 1310.9, 1314.2, 1317.5, 1320.8, 1324.2, 1327.5, 1330.8, 1334.1, 1337.4, 1340.8, 1344.1, 1347.4, 1350.7, 1354.1, 1357.4, 1360.7, 1364, 1367.4, 1370.7, 1374, 1377.3, 1380.7, 1384, 1387.3, 1390.6, 1394, 1397.3, 1400.6, 1403.9, 1407.2, 1410.6, 1413.9, 1417.2, 1420.5, 1423.9, 1427.2, 1430.5, 1433.8, 1437.2, 1440.5, 1443.8, 1447.1, 1450.5, 1453.8, 1457.1, 1460.4, 1463.8, 1467.1, 1470.4, 1473.7, 1477.1, 1480.4, 1483.7, 1487, 1490.4, 1493.7, 1497, 1500.3, 1503.7, 1507, 1510.3, 1513.6, 1517, 1520.3, 1523.6, 1526.9, 1530.3, 1533.6, 1536.9, 1540.2, 1543.6, 1546.9, 1550.2, 1553.6, 1556.9, 1560.2, 1563.5, 1566.9, 1570.2, 1573.5, 1576.8, 1580.2, 1583.5, 1586.8, 1590.1, 1593.5, 1596.8, 1600.1, 1603.4, 1606.8, 1610.1, 1613.4, 1616.7, 1620.1, 1623.4, 1626.7, 1630.1, 1633.4, 1636.7, 1640, 1643.4, 1646.7, 1650, 1653.3, 1656.7, 1660, 1663.3, 1666.7, 1670, 1673.3, 1676.6, 1680, 1683.3, 1686.6, 1689.9, 1693.3, 1696.6, 1699.9, 1703.3, 1706.6]
|
||||
|
||||
# 微塑料类型映射
|
||||
PLASTIC_TYPE_MAPPING = {
|
||||
'ABS': 0,
|
||||
'HDPE': 1,
|
||||
'LDPE': 2,
|
||||
'PA6': 3,
|
||||
'PET': 4,
|
||||
'PP': 5,
|
||||
'PS': 6,
|
||||
'PTFE': 7,
|
||||
'PVC': 8,
|
||||
}
|
||||
|
||||
|
||||
def get_plastic_label_from_filename(filename):
    """Return the numeric label for the microplastic type named in *filename*.

    Performs a case-insensitive substring match against the keys of
    PLASTIC_TYPE_MAPPING; the first matching key (in mapping order) wins.

    Returns:
        The mapped integer label, or None when no known type name occurs
        anywhere in the filename.
    """
    upper_name = filename.upper()
    return next(
        (code for name, code in PLASTIC_TYPE_MAPPING.items() if name in upper_name),
        None,
    )
|
||||
|
||||
|
||||
def read_wavelengths_from_hdr(bil_path):
    """Parse the ENVI header next to *bil_path* and return its wavelengths.

    Looks for a ``wavelength = { ... }`` section in the sibling ``.hdr`` file
    and returns the comma-separated values as a float64 array. An empty array
    is returned when the header is missing, has no wavelength entry, or the
    values fail to parse.
    """
    header_file = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(header_file):
        return np.array([], dtype=np.float64)
    with open(header_file, 'r') as handle:
        text = handle.read()
    if 'wavelength' not in text:
        return np.array([], dtype=np.float64)
    # Take everything after the keyword, then the span between the braces.
    tail = text.split('wavelength', 1)[1]
    body = tail[tail.find('{') + 1: tail.find('}')]
    tokens = [t.strip() for t in body.split(',') if t.strip()]
    try:
        return np.array([float(t) for t in tokens], dtype=np.float64)
    except Exception:
        # Malformed numbers degrade to "no wavelengths" rather than raising.
        return np.array([], dtype=np.float64)
|
||||
|
||||
|
||||
def resample_spectra_matrix(X, src_waves, dst_waves):
    """Linearly resample each spectrum row of *X* onto new wavelengths.

    Args:
        X: 2-D array-like, one spectrum per row, sampled at *src_waves*.
        src_waves: source wavelength grid (must be increasing for np.interp).
        dst_waves: target wavelength grid.

    Returns:
        float64 array of shape (rows, len(dst_waves)); when either grid is
        empty, *X* is returned unresampled (as float64).
    """
    src = np.asarray(src_waves, dtype=np.float64)
    dst = np.asarray(dst_waves, dtype=np.float64)
    spectra = np.asarray(X, dtype=np.float64)
    if not src.size or not dst.size:
        return spectra
    resampled = np.empty((spectra.shape[0], dst.size), dtype=np.float64)
    for idx, spectrum in enumerate(spectra):
        # Endpoint extrapolation keeps every target band populated.
        resampled[idx] = np.interp(dst, src, spectrum,
                                   left=spectrum[0], right=spectrum[-1])
    return resampled
|
||||
|
||||
|
||||
def read_hdr_file(bil_path):
    """Read the ENVI header next to *bil_path* and return (samples, lines).

    Returns:
        Tuple of the integer ``samples`` and ``lines`` header fields;
        either element is None when its field is absent.

    Raises:
        FileNotFoundError: if the sibling ``.hdr`` file does not exist.
    """
    # BUG FIX: this used bil_path.replace('.bil', '.hdr'), which also mangles
    # any '.bil' occurring earlier in the path (e.g. a directory name).
    # os.path.splitext only swaps the final extension, matching how every
    # other function in this module derives the header path.
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    with open(hdr_path, 'r') as f:
        header = f.readlines()

    samples, lines = None, None

    for line in header:
        if line.startswith('samples'):
            samples = int(line.split('=')[-1].strip())
        if line.startswith('lines'):
            lines = int(line.split('=')[-1].strip())

    return samples, lines
|
||||
|
||||
|
||||
def change_hdr_file(bil_path, wavelengths=None):
    """Append a ``wavelength`` block to the ENVI header beside *bil_path*.

    The header is modified only when it exists and does not already contain
    a ``wavelength`` field; otherwise a notice is printed and nothing is
    written. *wavelengths* must be a non-empty sequence of numbers for the
    append to happen.
    """
    header_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(header_path):
        print(f"错误: 找不到对应的HDR文件: {header_path}")
        return

    # Only append when the wavelength field is genuinely missing.
    with open(header_path, 'r', encoding='utf-8', errors='ignore') as handle:
        existing = handle.read()

    if 'wavelength' in existing:
        print(f"{os.path.basename(header_path)} 已包含 wavelength 字段,跳过追加。")
        return

    if wavelengths is None or len(wavelengths) == 0:
        print("HDR 缺少 wavelength,但未提供 wavelengths,跳过写入以避免错误。")
        return

    formatted = ", ".join(str(float(value)) for value in wavelengths)
    with open(header_path, 'a', encoding='utf-8', errors='ignore') as handle:
        # Keep the new field on its own line even when the file lacks a
        # trailing newline.
        if not existing.endswith('\n'):
            handle.write('\n')
        handle.write("wavelength = {" + formatted + "}\n")

    print(f"已在 {os.path.basename(header_path)} 末尾追加 wavelength 字段。")
|
||||
|
||||
|
||||
def validate_inputs(bil_path, output_dir):
    """Validate the input BIL/HDR pair and the output directory.

    Checks that the BIL file and its sibling HDR file exist, that
    *output_dir* can be created, and that the cube has enough bands for the
    downstream RGB conversion.

    Raises:
        FileNotFoundError: when the BIL or HDR file is missing.
        RuntimeError: when the output directory cannot be created or the
            header cannot be read.
        ValueError: when the cube has 159 or fewer bands.
    """
    # Both the BIL cube and its ENVI header must exist.
    if not os.path.exists(bil_path):
        raise FileNotFoundError(f"BIL文件不存在: {bil_path}")

    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        raise FileNotFoundError(f"HDR文件不存在: {hdr_path}")

    # Output directory must exist or be creatable.
    if output_dir and not os.path.exists(output_dir):
        try:
            os.makedirs(output_dir, exist_ok=True)
        except Exception as e:
            raise RuntimeError(f"无法创建输出目录: {output_dir}") from e

    # Read the band count; bil2rgb needs band indices 9, 59 and 159.
    try:
        from spectral.io import envi
        img = envi.open(hdr_path, bil_path)
        n_bands = img.nbands
    except Exception as e:
        raise RuntimeError(f"无法读取BIL文件头信息: {bil_path}") from e
    # BUG FIX: this check previously sat inside the try above, so its
    # ValueError was caught by the except and re-raised as a misleading
    # "cannot read header" RuntimeError. Raising it here preserves the
    # specific band-count message.
    if n_bands <= 159:
        raise ValueError(f"BIL文件波段数不足: 需要至少160个波段,但只有{n_bands}个")
|
||||
|
||||
|
||||
def generate_rgb(bil_path):
    """Build an RGB preview image from the BIL cube.

    Thin wrapper over ``process_bil_files`` that normalizes any failure into
    a RuntimeError carrying the offending path.
    """
    try:
        return process_bil_files(bil_path)
    except Exception as e:
        raise RuntimeError(f"生成RGB图像失败: bil_path={bil_path}") from e
|
||||
|
||||
|
||||
def run_segmentation(rgb_img, segmentation_model_path=None):
    """Detect microplastic particles in *rgb_img* and return the masks.

    Wraps ``detect_microplastic_mask_from_array`` with this pipeline's fixed
    settings (threshold filtering, flow_threshold 0.4, cellprob_threshold -1,
    filter detection on).

    Returns:
        (mask, filter_mask_original) as produced by the detector.

    Raises:
        RuntimeError: when the underlying detection call fails.
    """
    fixed_settings = dict(
        filter_method='threshold',
        diameter=None,
        flow_threshold=0.4,
        cellprob_threshold=-1,
        model_path=segmentation_model_path,
        detect_filter=True,
    )
    try:
        particle_mask, filter_mask = detect_microplastic_mask_from_array(
            image=rgb_img, **fixed_settings
        )
        return particle_mask, filter_mask
    except Exception as e:
        raise RuntimeError("分割失败: 无法检测微塑料颗粒") from e
|
||||
|
||||
|
||||
def extract_primary_features(bil_path, mask):
    """Extract per-particle shape/spectral features from the masked cube.

    Thin wrapper over ``process_images`` that normalizes any failure into a
    RuntimeError carrying the offending path.
    """
    try:
        return process_images(bil_path, mask)
    except Exception as e:
        raise RuntimeError(f"特征提取失败: bil_path={bil_path}") from e
|
||||
|
||||
|
||||
def compute_background_spectrum(bil_path, mask):
    """Compute the background spectrum of the cube outside the mask.

    Thin wrapper over ``process_images_background`` that normalizes any
    failure into a RuntimeError carrying the offending path.
    """
    try:
        return process_images_background(bil_path, mask)
    except Exception as e:
        raise RuntimeError(f"背景光谱计算失败: bil_path={bil_path}") from e
|
||||
|
||||
|
||||
def apply_background_correction(df, bg_spectrum):
    """Divide the spectral columns of *df* by the background spectrum.

    Spectral columns are those whose names start with ``wavelength_``. The
    background vector is tail-aligned to those columns (the last
    min(len(columns), len(background)) entries of each are used) so a length
    mismatch never raises. No wavelength resampling is performed.

    Returns:
        (corrected_df, untouched_copy_of_df, flattened_background_array)

    Raises:
        ValueError: when *df* has no ``wavelength_*`` columns.
    """
    wavelength_cols = [
        c for c in df.columns
        if isinstance(c, str) and c.startswith('wavelength_')
    ]
    if not wavelength_cols:
        raise ValueError("未找到光谱列(以wavelength_开头的列)")

    # Preserve the uncorrected spectra for the caller.
    untouched = df.copy()

    background = np.asarray(bg_spectrum, dtype=np.float64).ravel()

    # Tail-align columns and background so the division always lines up.
    count = min(len(wavelength_cols), background.shape[0])
    tail_cols = wavelength_cols[-count:]

    corrected = df.copy()
    corrected.loc[:, tail_cols] = corrected.loc[:, tail_cols].div(
        background[-count:], axis=1
    )

    return corrected, untouched, background
|
||||
|
||||
|
||||
def clean_and_select_columns(df):
|
||||
"""数据清理和列选择"""
|
||||
# 移除NaN值
|
||||
df = df.dropna()
|
||||
|
||||
# 过滤轮廓点数不足的样本
|
||||
df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
|
||||
|
||||
# 过滤面积过小的样本
|
||||
df = df[df['area'] >= 500]
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def rename_wavelength_columns(df, prefix=''):
    """
    Strip the 'wavelength_' prefix from spectral column names.

    Each column named ``wavelength_<value>`` is renamed to
    ``<prefix><value>``; with the default empty *prefix* the column name
    becomes the bare wavelength value. Other columns are left untouched.

    Args:
        df: DataFrame whose columns may include 'wavelength_*' names.
        prefix: string prepended to the stripped wavelength value.
            (BUG FIX: this parameter was previously accepted and documented
            but never used; the default '' preserves the old behavior.)

    Returns:
        The DataFrame with renamed columns (a new object when any column
        matched, otherwise the input unchanged).
    """
    new_columns = {}
    for col in df.columns:
        if isinstance(col, str) and col.startswith('wavelength_'):
            # Keep just the numeric part, then apply the requested prefix.
            new_columns[col] = prefix + col.replace('wavelength_', '')
    if new_columns:
        df = df.rename(columns=new_columns)
    return df
|
||||
|
||||
|
||||
def save_spectra_to_csv(df_corrected, df_original, bg_spectrum, bil_path, output_dir, plastic_label,
                        all_corrected_data, all_original_data, all_background_data):
    """
    Save the three spectra tables as per-file CSVs and collect them for merging.

    Writes, under *output_dir*:
      - corrected_spectra/<base>_corrected.csv   (background-corrected spectra)
      - original_spectra/<base>_original.csv     (uncorrected spectra)
      - background_spectra/<base>_background.csv (the background spectrum)

    Every output row is tagged with a leading 'source_file' column. When
    *plastic_label* is not None it is written into the first non-spectral
    column of the corrected/original tables, and into a 'plastic_type'
    column of the background table. The three ``all_*_data`` lists are
    mutated in place: the frames written here are appended so the caller can
    concatenate them later.
    """
    base_name = os.path.splitext(os.path.basename(bil_path))[0]

    # Create the per-kind output subdirectories.
    corrected_dir = os.path.join(output_dir, 'corrected_spectra')
    original_dir = os.path.join(output_dir, 'original_spectra')
    background_dir = os.path.join(output_dir, 'background_spectra')

    os.makedirs(corrected_dir, exist_ok=True)
    os.makedirs(original_dir, exist_ok=True)
    os.makedirs(background_dir, exist_ok=True)

    # Spectral columns still carry the 'wavelength_' prefix at this point.
    spec_cols = [c for c in df_corrected.columns if isinstance(c, str) and c.startswith('wavelength_')]

    # Strip the 'wavelength_' prefix from the spectral column names.
    df_corrected_renamed = rename_wavelength_columns(df_corrected.copy())
    df_original_renamed = rename_wavelength_columns(df_original.copy())

    # Recover the renamed wavelength columns: everything that is not one of
    # the original non-spectral column names.
    wavelength_cols = [c for c in df_corrected_renamed.columns if c not in
                       [col for col in df_corrected.columns if isinstance(col, str) and not col.startswith('wavelength_')]]

    # --- background-corrected spectra ---
    df_corrected_out = df_corrected_renamed.copy()
    if plastic_label is not None:
        if len(df_corrected_out) > 0:
            non_spec_cols = [c for c in df_corrected_out.columns if c not in wavelength_cols]
            if non_spec_cols:
                # NOTE(review): the label OVERWRITES the first non-spectral
                # column instead of being added as a new column — confirm
                # this is the intended training-label layout.
                first_col = non_spec_cols[0]
                df_corrected_out[first_col] = plastic_label

    # Tag every row with the originating file.
    df_corrected_out.insert(0, 'source_file', base_name)

    corrected_path = os.path.join(corrected_dir, f"{base_name}_corrected.csv")
    df_corrected_out.to_csv(corrected_path, index=False)
    print(f" 背景校正光谱已保存: {corrected_path}")

    # Collect for the combined CSV.
    all_corrected_data.append(df_corrected_out)

    # --- original (uncorrected) spectra ---
    df_original_out = df_original_renamed.copy()
    if plastic_label is not None:
        if len(df_original_out) > 0:
            non_spec_cols = [c for c in df_original_out.columns if c not in wavelength_cols]
            if non_spec_cols:
                # Same label-overwrite convention as the corrected table.
                first_col = non_spec_cols[0]
                df_original_out[first_col] = plastic_label

    # Tag every row with the originating file.
    df_original_out.insert(0, 'source_file', base_name)

    original_path = os.path.join(original_dir, f"{base_name}_original.csv")
    df_original_out.to_csv(original_path, index=False)
    print(f" 原始光谱已保存: {original_path}")

    # Collect for the combined CSV.
    all_original_data.append(df_original_out)

    # --- background spectrum ---
    # Wavelength names with the 'wavelength_' prefix removed, tail-aligned
    # to the background vector when it is shorter than the column list.
    wavelength_names = [col.replace('wavelength_', '') for col in spec_cols[-len(bg_spectrum):]] if len(spec_cols) >= len(bg_spectrum) else [col.replace('wavelength_', '') for col in spec_cols]

    bg_df = pd.DataFrame({
        'wavelength': wavelength_names,
        'background_value': bg_spectrum
    })
    bg_df.insert(0, 'source_file', base_name)
    bg_df.insert(1, 'plastic_type', plastic_label if plastic_label is not None else 'unknown')

    background_path = os.path.join(background_dir, f"{base_name}_background.csv")
    bg_df.to_csv(background_path, index=False)
    print(f" 背景光谱已保存: {background_path}")

    # Collect for the combined CSV.
    all_background_data.append(bg_df)
|
||||
|
||||
|
||||
def save_combined_csv(all_corrected_data, all_original_data, all_background_data, output_dir):
    """
    Concatenate the collected per-file frames and write one merged CSV per
    spectra kind under ``<output_dir>/combined``. Empty lists are skipped.
    """
    merged_dir = os.path.join(output_dir, 'combined')
    os.makedirs(merged_dir, exist_ok=True)

    # (frames, output file name, saved-message template) per spectra kind.
    jobs = [
        (all_corrected_data, 'all_corrected_spectra.csv', "\n 合并背景校正光谱已保存: {}"),
        (all_original_data, 'all_original_spectra.csv', " 合并原始光谱已保存: {}"),
        (all_background_data, 'all_background_spectra.csv', " 合并背景光谱已保存: {}"),
    ]
    for frames, csv_name, saved_template in jobs:
        if not frames:
            continue
        merged = pd.concat(frames, ignore_index=True)
        target = os.path.join(merged_dir, csv_name)
        merged.to_csv(target, index=False)
        print(saved_template.format(target))
        print(f" 总行数: {len(merged)}")
|
||||
|
||||
|
||||
def process_single_bil(bil_path, output_dir, segmentation_model_path=None,
                       all_corrected_data=None, all_original_data=None, all_background_data=None):
    """Run the full pipeline on one BIL file: RGB -> mask -> spectra -> CSVs.

    The produced frames are appended to the three ``all_*_data`` lists when
    provided, so the caller can merge them across files afterwards.

    Returns:
        True on success, False on any failure. Errors are printed, never
        raised, so a batch run can continue with the next file.
    """
    try:
        print(f"\n处理文件: {bil_path}")

        # Derive the training label from the microplastic name in the filename.
        filename = os.path.basename(bil_path)
        plastic_label = get_plastic_label_from_filename(filename)
        if plastic_label is not None:
            print(f" 检测到微塑料类型: {list(PLASTIC_TYPE_MAPPING.keys())[list(PLASTIC_TYPE_MAPPING.values()).index(plastic_label)]} -> {plastic_label}")
        else:
            print(f" 警告: 无法从文件名识别微塑料类型")

        # Validate the BIL/HDR pair and the output directory.
        validate_inputs(bil_path, output_dir)

        # BUG FIX: the 237-entry wavelength list was duplicated inline here,
        # identical to the module constant; reuse TRAIN_WAVELENGTHS so the
        # two copies cannot drift apart.
        bands = TRAIN_WAVELENGTHS

        # Ensure the HDR file carries a wavelength field.
        change_hdr_file(bil_path, bands)

        # Build the RGB preview used for segmentation.
        print(" 生成RGB图像...")
        rgb_img = generate_rgb(bil_path)

        # Segmentation stage.
        print(" 生成掩膜...")
        mask, filter_mask_original = run_segmentation(rgb_img, segmentation_model_path)

        # Per-particle spectral features.
        print(" 提取光谱特征...")
        df = extract_primary_features(bil_path, mask)

        # Background spectrum and correction.
        print(" 计算背景光谱并应用校正...")
        bg_spectrum = compute_background_spectrum(bil_path, mask)

        # Apply background correction (no resampling).
        df_corrected, df_original, bg_spectrum_aligned = apply_background_correction(df, bg_spectrum)

        # Data cleaning.
        print(" 清理数据...")
        df_corrected = clean_and_select_columns(df_corrected)
        df_original = clean_and_select_columns(df_original)

        # Save the three spectra tables (also collected for merging).
        save_spectra_to_csv(df_corrected, df_original, bg_spectrum_aligned, bil_path, output_dir, plastic_label,
                            all_corrected_data, all_original_data, all_background_data)

        # BUG FIX: the completion message printed a literal "(unknown)"
        # placeholder; report the processed file instead.
        print(f" 处理完成: {bil_path}")
        return True

    except Exception as e:
        print(f" 处理失败: {bil_path} - {e}")
        return False
|
||||
|
||||
|
||||
def parse_arguments():
    """Parse the command-line options for the batch spectra pipeline.

    Returns:
        argparse.Namespace with ``input_dir`` and ``output_dir`` attributes,
        both required.
    """
    cli = argparse.ArgumentParser(
        description='批量处理高光谱图像,提取并保存背景校正光谱、原始光谱和背景光谱'
    )
    # Folder containing the input BIL files.
    cli.add_argument('--input_dir', required=True, help='包含输入BIL文件的文件夹路径')
    # Folder receiving the CSV outputs.
    cli.add_argument('--output_dir', required=True, help='保存CSV输出结果的文件夹路径')
    return cli.parse_args()
|
||||
|
||||
|
||||
def main():
    """Batch entry point: process every BIL file in the input directory.

    Parses the CLI, runs :func:`process_single_bil` over each ``.bil`` file
    found (sorted), merges the collected spectra into combined CSVs, and
    prints a timing/success summary.
    """
    args = parse_arguments()

    input_dir = args.input_dir
    output_dir = args.output_dir

    # Record the overall start time for the summary.
    total_start_time = time.time()

    # The input directory must exist.
    if not os.path.exists(input_dir):
        print(f"错误: 输入目录不存在: {input_dir}")
        return

    # Create the output directory.
    os.makedirs(output_dir, exist_ok=True)

    # Collect all BIL files, processed in sorted (deterministic) order.
    bil_files = [f for f in os.listdir(input_dir) if f.endswith('.bil')]
    bil_files.sort()

    if not bil_files:
        print(f"警告: 在 {input_dir} 中未找到BIL文件")
        return

    print(f"\n{'=' * 60}")
    print(f"找到 {len(bil_files)} 个BIL文件需要处理")
    print(f"{'=' * 60}")

    # Lists that accumulate the per-file frames for the combined CSVs.
    all_corrected_data = []
    all_original_data = []
    all_background_data = []

    # Progress is reported through a tqdm bar.
    success_count = 0
    fail_count = 0

    for bil_file in tqdm(bil_files, desc="处理进度", unit="文件"):
        bil_path = os.path.join(input_dir, bil_file)
        if process_single_bil(bil_path, output_dir,
                              all_corrected_data=all_corrected_data,
                              all_original_data=all_original_data,
                              all_background_data=all_background_data):
            success_count += 1
        else:
            fail_count += 1

    # Write the merged CSV files (skipped when nothing was collected).
    if all_corrected_data or all_original_data or all_background_data:
        print(f"\n{'=' * 60}")
        print("正在生成合并的CSV文件...")
        save_combined_csv(all_corrected_data, all_original_data, all_background_data, output_dir)

    # Total elapsed time.
    total_time = time.time() - total_start_time

    # Print the run summary.
    print(f"\n{'=' * 60}")
    print("处理完成总结")
    print(f"{'=' * 60}")
    print(f"成功处理: {success_count} 个文件")
    print(f"失败: {fail_count} 个文件")
    print(f"总耗时: {total_time:.2f} 秒")
    print(f"平均每个文件: {total_time / len(bil_files):.2f} 秒")
    print(f"{'=' * 60}")
    print(f"结果已保存至: {output_dir}")
    print(f" - 单独文件:")
    print(f"   - 背景校正光谱: {os.path.join(output_dir, 'corrected_spectra')}")
    print(f"   - 原始光谱: {os.path.join(output_dir, 'original_spectra')}")
    print(f"   - 背景光谱: {os.path.join(output_dir, 'background_spectra')}")
    print(f" - 合并文件:")
    print(f"   - 合并后的光谱数据: {os.path.join(output_dir, 'combined')}")
|
||||
|
||||
|
||||
# Script entry point: run the batch pipeline only when executed directly.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user