import os
import cv2
import matplotlib
import numpy as np
import argparse
import pandas as pd
from bil2rgb import process_bil_files
from mask import detect_microplastic_mask_from_array
from shape_spectral import process_images
from shape_spectral_background import process_images_background
import time
from tqdm import tqdm
matplotlib.use('TkAgg')
# Training camera wavelengths, 237 channels
TRAIN_WAVELENGTHS = [912.36, 915.68, 919, 922.31, 925.63, 928.95, 932.27, 935.59, 938.91, 942.23, 945.55, 948.87, 952.18, 955.5, 958.82, 962.14, 965.46, 968.78, 972.1, 975.42, 978.74, 982.06, 985.38, 988.7, 992.02, 995.34, 998.65, 1002, 1005.3, 1008.6, 1011.9, 1015.3, 1018.6, 1021.9, 1025.2, 1028.5, 1031.9, 1035.2, 1038.5, 1041.8, 1045.1, 1048.5, 1051.8, 1055.1, 1058.4, 1061.7, 1065.1, 1068.4, 1071.7, 1075, 1078.3, 1081.7, 1085, 1088.3, 1091.6, 1094.9, 1098.3, 1101.6, 1104.9, 1108.2, 1111.5, 1114.9, 1118.2, 1121.5, 1124.8, 1128.1, 1131.5, 1134.8, 1138.1, 1141.4, 1144.8, 1148.1, 1151.4, 1154.7, 1158, 1161.4, 1164.7, 1168, 1171.3, 1174.6, 1178, 1181.3, 1184.6, 1187.9, 1191.3, 1194.6, 1197.9, 1201.2, 1204.5, 1207.9, 1211.2, 1214.5, 1217.8, 1221.2, 1224.5, 1227.8, 1231.1, 1234.4, 1237.8, 1241.1, 1244.4, 1247.7, 1251.1, 1254.4, 1257.7, 1261, 1264.3, 1267.7, 1271, 1274.3, 1277.6, 1281, 1284.3, 1287.6, 1290.9, 1294.2, 1297.6, 1300.9, 1304.2, 1307.5, 1310.9, 1314.2, 1317.5, 1320.8, 1324.2, 1327.5, 1330.8, 1334.1, 1337.4, 1340.8, 1344.1, 1347.4, 1350.7, 1354.1, 1357.4, 1360.7, 1364, 1367.4, 1370.7, 1374, 1377.3, 1380.7, 1384, 1387.3, 1390.6, 1394, 1397.3, 1400.6, 1403.9, 1407.2, 1410.6, 1413.9, 1417.2, 1420.5, 1423.9, 1427.2, 1430.5, 1433.8, 1437.2, 1440.5, 1443.8, 1447.1, 1450.5, 1453.8, 1457.1, 1460.4, 1463.8, 1467.1, 1470.4, 1473.7, 1477.1, 1480.4, 1483.7, 1487, 1490.4, 1493.7, 1497, 1500.3, 1503.7, 1507, 1510.3, 1513.6, 1517, 1520.3, 1523.6, 1526.9, 1530.3, 1533.6, 1536.9, 1540.2, 1543.6, 1546.9, 1550.2, 1553.6, 1556.9, 1560.2, 1563.5, 1566.9, 1570.2, 1573.5, 1576.8, 1580.2, 1583.5, 1586.8, 1590.1, 1593.5, 1596.8, 1600.1, 1603.4, 1606.8, 1610.1, 1613.4, 1616.7, 1620.1, 1623.4, 1626.7, 1630.1, 1633.4, 1636.7, 1640, 1643.4, 1646.7, 1650, 1653.3, 1656.7, 1660, 1663.3, 1666.7, 1670, 1673.3, 1676.6, 1680, 1683.3, 1686.6, 1689.9, 1693.3, 1696.6, 1699.9, 1703.3, 1706.6]
# Microplastic type mapping
PLASTIC_TYPE_MAPPING = {
    'ABS': 0,
    'HDPE': 1,
    'LDPE': 2,
    'PA6': 3,
    'PET': 4,
    'PP': 5,
    'PS': 6,
    'PTFE': 7,
    'PVC': 8,
}

def get_plastic_label_from_filename(filename):
    """
    Extract the microplastic type from a file name and return its numeric label.
    The label is returned as soon as the file name contains a known plastic name.
    """
    filename_upper = filename.upper()
    for plastic_type, label in PLASTIC_TYPE_MAPPING.items():
        if plastic_type in filename_upper:
            return label
    return None
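# Illustrative examples (the file names below are hypothetical):
#   get_plastic_label_from_filename('PET_sample_01.bil')  -> 4
#   get_plastic_label_from_filename('unknown_sample.bil') -> None
# Note: matching is substring-based, so short keys such as 'PP' or 'PS' can match
# unintended file names.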

def read_wavelengths_from_hdr(bil_path):
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        return np.array([], dtype=np.float64)
    with open(hdr_path, 'r') as f:
        txt = f.read()
    if 'wavelength' not in txt:
        return np.array([], dtype=np.float64)
    seg = txt.split('wavelength', 1)[1]
    seg = seg[seg.find('{') + 1: seg.find('}')]
    vals = [v.strip() for v in seg.split(',') if v.strip()]
    try:
        waves = np.array([float(v) for v in vals], dtype=np.float64)
    except Exception:
        waves = np.array([], dtype=np.float64)
    return waves
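# The parser above expects an ENVI-style header entry of the form (illustrative):
#   wavelength = {912.36, 915.68, 919, ..., 1706.6}
# It splits on the first occurrence of 'wavelength' and reads the values inside
# the first '{...}' that follows.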

def resample_spectra_matrix(X, src_waves, dst_waves):
    src = np.asarray(src_waves, dtype=np.float64)
    dst = np.asarray(dst_waves, dtype=np.float64)
    X = np.asarray(X, dtype=np.float64)
    if src.size == 0 or dst.size == 0:
        return X
    # Linear interpolation; out-of-range points are extrapolated with the endpoint
    # values so that no output band is left missing.
    out = np.empty((X.shape[0], dst.size), dtype=np.float64)
    for i in range(X.shape[0]):
        row = X[i]
        out[i] = np.interp(dst, src, row, left=row[0], right=row[-1])
    return out
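# Minimal usage sketch (illustrative; 'some_file.bil' is a hypothetical path):
# resample an (n_samples, n_src_bands) matrix onto the 237 training wavelengths.
#   src_waves = read_wavelengths_from_hdr('some_file.bil')
#   X_on_train_grid = resample_spectra_matrix(X, src_waves, TRAIN_WAVELENGTHS)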

def read_hdr_file(bil_path):
    hdr_path = bil_path.replace('.bil', '.hdr')
    with open(hdr_path, 'r') as f:
        header = f.readlines()
    samples, lines = None, None
    for line in header:
        if line.startswith('samples'):
            samples = int(line.split('=')[-1].strip())
        if line.startswith('lines'):
            lines = int(line.split('=')[-1].strip())
    return samples, lines

def change_hdr_file(bil_path, wavelengths=None):
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        print(f"Error: corresponding HDR file not found: {hdr_path}")
        return
    # Only attempt to write when the wavelength field is missing.
    with open(hdr_path, 'r', encoding='utf-8', errors='ignore') as file:
        content = file.read()
    if 'wavelength' in content:
        print(f"{os.path.basename(hdr_path)} already contains a wavelength field; skipping append.")
        return
    if wavelengths is None or len(wavelengths) == 0:
        print("HDR is missing the wavelength field and no wavelengths were provided; skipping write to avoid errors.")
        return
    needs_newline = not content.endswith('\n')
    wavelength_info = "wavelength = {" + ", ".join(str(float(w)) for w in wavelengths) + "}\n"
    with open(hdr_path, 'a', encoding='utf-8', errors='ignore') as file:
        if needs_newline:
            file.write('\n')
        file.write(wavelength_info)
    print(f"Appended wavelength field to the end of {os.path.basename(hdr_path)}.")

def validate_inputs(bil_path, output_dir):
    """Validate the input files and parameters."""
    # Check that the BIL and HDR files exist
    if not os.path.exists(bil_path):
        raise FileNotFoundError(f"BIL file does not exist: {bil_path}")
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        raise FileNotFoundError(f"HDR file does not exist: {hdr_path}")
    # Check that the output directory can be created / written to
    if output_dir and not os.path.exists(output_dir):
        try:
            os.makedirs(output_dir, exist_ok=True)
        except Exception as e:
            raise RuntimeError(f"Cannot create output directory: {output_dir}") from e
    # Check that the BIL file has enough bands
    try:
        from spectral.io import envi
        img = envi.open(hdr_path, bil_path)
        n_bands = img.nbands
    except Exception as e:
        raise RuntimeError(f"Cannot read BIL header information: {bil_path}") from e
    # bil2rgb needs band indices 9, 59 and 159
    if n_bands <= 159:
        raise ValueError(f"Insufficient number of bands in BIL file: at least 160 are required but only {n_bands} are present")

def generate_rgb(bil_path):
    """Process the BIL file and generate an RGB image."""
    try:
        rgb_img = process_bil_files(bil_path)
        return rgb_img
    except Exception as e:
        raise RuntimeError(f"Failed to generate RGB image: bil_path={bil_path}") from e

def run_segmentation(rgb_img, segmentation_model_path=None):
    """Run segmentation to obtain the particle mask."""
    try:
        mask, filter_mask_original = detect_microplastic_mask_from_array(
            image=rgb_img,
            filter_method='threshold',
            diameter=None,
            flow_threshold=0.4,
            cellprob_threshold=-1,
            model_path=segmentation_model_path,
            detect_filter=True
        )
        return mask, filter_mask_original
    except Exception as e:
        raise RuntimeError("Segmentation failed: unable to detect microplastic particles") from e

def extract_primary_features(bil_path, mask):
    """Extract the primary features."""
    try:
        df = process_images(bil_path, mask)
        return df
    except Exception as e:
        raise RuntimeError(f"Feature extraction failed: bil_path={bil_path}") from e

def compute_background_spectrum(bil_path, mask):
    """Compute the background spectrum."""
    try:
        bg_spectrum = process_images_background(bil_path, mask)
        return bg_spectrum
    except Exception as e:
        raise RuntimeError(f"Background spectrum computation failed: bil_path={bil_path}") from e

def apply_background_correction(df, bg_spectrum):
    """Apply background correction; no resampling is performed."""
    # Identify the spectral columns (all columns whose name starts with 'wavelength_')
    spec_cols = [c for c in df.columns if isinstance(c, str) and c.startswith('wavelength_')]
    if not spec_cols:
        raise ValueError("No spectral columns found (columns starting with 'wavelength_')")
    # Keep a copy of the original spectra
    df_original = df.copy()
    # Background correction: divide each spectral column by the background spectrum
    bg = np.asarray(bg_spectrum, dtype=np.float64).ravel()
    # Align from the tail and take the minimum length to avoid dimension mismatches
    n = min(len(spec_cols), bg.shape[0])
    use_cols = spec_cols[-n:]
    df_corrected = df.copy()
    df_corrected.loc[:, use_cols] = df_corrected.loc[:, use_cols].div(bg[-n:], axis=1)
    return df_corrected, df_original, bg
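# Worked example with illustrative numbers: if a particle's raw value at one band
# is 0.42 and the background value at that band is 0.60, the corrected value is
# 0.42 / 0.60 = 0.70, i.e. a ratio relative to the background at that wavelength.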

def clean_and_select_columns(df):
    """Clean the data and select columns."""
    # Drop rows containing NaN values
    df = df.dropna()
    # Filter out samples whose contour has too few points
    df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
    # Filter out samples whose area is too small
    df = df[df['area'] >= 500]
    return df
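# Note: the 'area' threshold of 500 is assumed to be in mask pixels; it may need
# adjusting for cameras with a different spatial resolution.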

def rename_wavelength_columns(df, prefix=''):
    """
    Remove the 'wavelength_' prefix from column names, replacing it with the given
    prefix (with the default empty prefix, only the wavelength value remains).
    """
    new_columns = {}
    for col in df.columns:
        if isinstance(col, str) and col.startswith('wavelength_'):
            # Extract the wavelength value
            wavelength_value = col.replace('wavelength_', '')
            new_columns[col] = f"{prefix}{wavelength_value}"
    if new_columns:
        df = df.rename(columns=new_columns)
    return df
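# Illustrative renames: 'wavelength_912.36' -> '912.36' with the default prefix,
# or 'corr_912.36' with prefix='corr_' (the 'corr_' prefix is only an example).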

def save_spectra_to_csv(df_corrected, df_original, bg_spectrum, bil_path, output_dir, plastic_label,
                        all_corrected_data, all_original_data, all_background_data):
    """
    Save the three kinds of spectra as CSV files:
    - background-corrected spectra
    - original spectra
    - background spectra
    The data are also collected for the combined output.
    """
    base_name = os.path.splitext(os.path.basename(bil_path))[0]
    # Create the output subdirectories
    corrected_dir = os.path.join(output_dir, 'corrected_spectra')
    original_dir = os.path.join(output_dir, 'original_spectra')
    background_dir = os.path.join(output_dir, 'background_spectra')
    os.makedirs(corrected_dir, exist_ok=True)
    os.makedirs(original_dir, exist_ok=True)
    os.makedirs(background_dir, exist_ok=True)
    # Spectral columns
    spec_cols = [c for c in df_corrected.columns if isinstance(c, str) and c.startswith('wavelength_')]
    # Remove the 'wavelength_' prefix from the column names
    df_corrected_renamed = rename_wavelength_columns(df_corrected.copy())
    df_original_renamed = rename_wavelength_columns(df_original.copy())
    # New wavelength column names (prefix removed)
    non_spectral_original = [col for col in df_corrected.columns
                             if isinstance(col, str) and not col.startswith('wavelength_')]
    wavelength_cols = [c for c in df_corrected_renamed.columns if c not in non_spectral_original]
    # Save the background-corrected spectra
    df_corrected_out = df_corrected_renamed.copy()
    if plastic_label is not None:
        if len(df_corrected_out) > 0:
            non_spec_cols = [c for c in df_corrected_out.columns if c not in wavelength_cols]
            if non_spec_cols:
                first_col = non_spec_cols[0]
                df_corrected_out[first_col] = plastic_label
    # Add a source-file column so rows can be traced back to their origin
    df_corrected_out.insert(0, 'source_file', base_name)
    corrected_path = os.path.join(corrected_dir, f"{base_name}_corrected.csv")
    df_corrected_out.to_csv(corrected_path, index=False)
    print(f"  Background-corrected spectra saved: {corrected_path}")
    # Collect for the combined output
    all_corrected_data.append(df_corrected_out)
    # Save the original spectra
    df_original_out = df_original_renamed.copy()
    if plastic_label is not None:
        if len(df_original_out) > 0:
            non_spec_cols = [c for c in df_original_out.columns if c not in wavelength_cols]
            if non_spec_cols:
                first_col = non_spec_cols[0]
                df_original_out[first_col] = plastic_label
    # Add a source-file column so rows can be traced back to their origin
    df_original_out.insert(0, 'source_file', base_name)
    original_path = os.path.join(original_dir, f"{base_name}_original.csv")
    df_original_out.to_csv(original_path, index=False)
    print(f"  Original spectra saved: {original_path}")
    # Collect for the combined output
    all_original_data.append(df_original_out)
    # Save the background spectrum (with the 'wavelength_' prefix removed)
    wavelength_names = ([col.replace('wavelength_', '') for col in spec_cols[-len(bg_spectrum):]]
                        if len(spec_cols) >= len(bg_spectrum)
                        else [col.replace('wavelength_', '') for col in spec_cols])
    bg_df = pd.DataFrame({
        'wavelength': wavelength_names,
        'background_value': bg_spectrum
    })
    bg_df.insert(0, 'source_file', base_name)
    bg_df.insert(1, 'plastic_type', plastic_label if plastic_label is not None else 'unknown')
    background_path = os.path.join(background_dir, f"{base_name}_background.csv")
    bg_df.to_csv(background_path, index=False)
    print(f"  Background spectrum saved: {background_path}")
    # Collect for the combined output
    all_background_data.append(bg_df)
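# Resulting layout under output_dir (directory names taken from the code above):
#   corrected_spectra/<base>_corrected.csv    background-corrected per-particle spectra
#   original_spectra/<base>_original.csv      uncorrected per-particle spectra
#   background_spectra/<base>_background.csv  one background value per wavelength
#   combined/                                 merged CSVs written by save_combined_csv()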

def save_combined_csv(all_corrected_data, all_original_data, all_background_data, output_dir):
    """
    Merge all collected data and save them as combined CSV files.
    """
    combined_dir = os.path.join(output_dir, 'combined')
    os.makedirs(combined_dir, exist_ok=True)
    # Merge the background-corrected spectra
    if all_corrected_data:
        combined_corrected = pd.concat(all_corrected_data, ignore_index=True)
        corrected_combined_path = os.path.join(combined_dir, 'all_corrected_spectra.csv')
        combined_corrected.to_csv(corrected_combined_path, index=False)
        print(f"\n  Combined background-corrected spectra saved: {corrected_combined_path}")
        print(f"  Total rows: {len(combined_corrected)}")
    # Merge the original spectra
    if all_original_data:
        combined_original = pd.concat(all_original_data, ignore_index=True)
        original_combined_path = os.path.join(combined_dir, 'all_original_spectra.csv')
        combined_original.to_csv(original_combined_path, index=False)
        print(f"  Combined original spectra saved: {original_combined_path}")
        print(f"  Total rows: {len(combined_original)}")
    # Merge the background spectra
    if all_background_data:
        combined_background = pd.concat(all_background_data, ignore_index=True)
        background_combined_path = os.path.join(combined_dir, 'all_background_spectra.csv')
        combined_background.to_csv(background_combined_path, index=False)
        print(f"  Combined background spectra saved: {background_combined_path}")
        print(f"  Total rows: {len(combined_background)}")

def process_single_bil(bil_path, output_dir, segmentation_model_path=None,
                       all_corrected_data=None, all_original_data=None, all_background_data=None):
    """Process a single BIL file."""
    try:
        print(f"\nProcessing file: {bil_path}")
        # Get the microplastic label from the file name
        filename = os.path.basename(bil_path)
        plastic_label = get_plastic_label_from_filename(filename)
        if plastic_label is not None:
            print(f"  Detected microplastic type: {list(PLASTIC_TYPE_MAPPING.keys())[list(PLASTIC_TYPE_MAPPING.values()).index(plastic_label)]} -> {plastic_label}")
        else:
            print("  Warning: unable to identify the microplastic type from the file name")
        # Validate the input
        validate_inputs(bil_path, output_dir)
        # Training camera wavelengths (the module-level TRAIN_WAVELENGTHS list)
        bands = TRAIN_WAVELENGTHS
        # Update the HDR file
        change_hdr_file(bil_path, bands)
        # Process the BIL file and generate an RGB image
        print("  Generating RGB image...")
        rgb_img = generate_rgb(bil_path)
        # Segmentation stage
        print("  Generating mask...")
        mask, filter_mask_original = run_segmentation(rgb_img, segmentation_model_path)
        # Feature extraction
        print("  Extracting spectral features...")
        df = extract_primary_features(bil_path, mask)
        # Background correction
        print("  Computing background spectrum and applying correction...")
        bg_spectrum = compute_background_spectrum(bil_path, mask)
        # Apply the background correction (no resampling)
        df_corrected, df_original, bg_spectrum_aligned = apply_background_correction(df, bg_spectrum)
        # Data cleaning
        print("  Cleaning data...")
        df_corrected = clean_and_select_columns(df_corrected)
        df_original = clean_and_select_columns(df_original)
        # Save the three kinds of spectra (and collect them for the combined output)
        save_spectra_to_csv(df_corrected, df_original, bg_spectrum_aligned, bil_path, output_dir, plastic_label,
                            all_corrected_data, all_original_data, all_background_data)
        print(f"  Finished: {filename}")
        return True
    except Exception as e:
        print(f"  Failed: {bil_path} - {e}")
        return False
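# Illustrative single-file call (the paths below are hypothetical):
#   process_single_bil('data/PET_sample_01.bil', 'output',
#                      all_corrected_data=[], all_original_data=[], all_background_data=[])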

def parse_arguments():
    """Parse the command-line arguments."""
    parser = argparse.ArgumentParser(
        description='Batch-process hyperspectral images and save background-corrected, original and background spectra'
    )
    # Input folder containing the BIL files
    parser.add_argument('--input_dir', required=True, help='Path to the folder containing the input BIL files')
    # Output folder
    parser.add_argument('--output_dir', required=True, help='Path to the folder where the CSV results are saved')
    return parser.parse_args()

def main():
    """Main entry point."""
    args = parse_arguments()
    input_dir = args.input_dir
    output_dir = args.output_dir
    # Record the overall start time
    total_start_time = time.time()
    # Check the input directory
    if not os.path.exists(input_dir):
        print(f"Error: input directory does not exist: {input_dir}")
        return
    # Create the output directory
    os.makedirs(output_dir, exist_ok=True)
    # Collect all BIL files
    bil_files = [f for f in os.listdir(input_dir) if f.endswith('.bil')]
    bil_files.sort()
    if not bil_files:
        print(f"Warning: no BIL files found in {input_dir}")
        return
    print(f"\n{'=' * 60}")
    print(f"Found {len(bil_files)} BIL files to process")
    print(f"{'=' * 60}")
    # Lists used to collect all of the data
    all_corrected_data = []
    all_original_data = []
    all_background_data = []
    # Show progress with tqdm
    success_count = 0
    fail_count = 0
    for bil_file in tqdm(bil_files, desc="Progress", unit="file"):
        bil_path = os.path.join(input_dir, bil_file)
        if process_single_bil(bil_path, output_dir,
                              all_corrected_data=all_corrected_data,
                              all_original_data=all_original_data,
                              all_background_data=all_background_data):
            success_count += 1
        else:
            fail_count += 1
    # Save the combined CSV files
    if all_corrected_data or all_original_data or all_background_data:
        print(f"\n{'=' * 60}")
        print("Generating combined CSV files...")
        save_combined_csv(all_corrected_data, all_original_data, all_background_data, output_dir)
    # Total elapsed time
    total_time = time.time() - total_start_time
    # Print the summary
    print(f"\n{'=' * 60}")
    print("Processing summary")
    print(f"{'=' * 60}")
    print(f"Successfully processed: {success_count} files")
    print(f"Failed: {fail_count} files")
    print(f"Total time: {total_time:.2f} s")
    print(f"Average per file: {total_time / len(bil_files):.2f} s")
    print(f"{'=' * 60}")
    print(f"Results saved to: {output_dir}")
    print("  - Individual files:")
    print(f"    - Background-corrected spectra: {os.path.join(output_dir, 'corrected_spectra')}")
    print(f"    - Original spectra: {os.path.join(output_dir, 'original_spectra')}")
    print(f"    - Background spectra: {os.path.join(output_dir, 'background_spectra')}")
    print("  - Combined files:")
    print(f"    - Merged spectral data: {os.path.join(output_dir, 'combined')}")


if __name__ == "__main__":
    main()
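# Example invocation (directory names are illustrative):
#   python train_sample.py --input_dir ./bil_data --output_dir ./train_output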