update:更新readme

This commit is contained in:
2026-04-14 08:57:29 +08:00
parent 10fd2b00d4
commit 987be5ad9b
4 changed files with 718 additions and 119 deletions

View File

@ -1,7 +1,7 @@
from bil2rgb import process_bil_files
from shape_spectral import process_images
import cv2
from classification_model.Parallel.predict_plastic import predict_and_save
# from classification_model.Parallel.predict_plastic import predict_and_save # 本脚本不分类,可移除
import numpy as np
import os
import matplotlib
@ -10,9 +10,51 @@ from shape_spectral_background import process_images_background
from mask import detect_microplastic_mask_from_array
import plantcv as pcv
# 直接复用 main.py 中的成熟实现,避免重复逻辑和不一致
from main import (
TRAIN_WAVELENGTHS,
read_wavelengths_from_hdr,
resample_spectra_matrix,
apply_background_no_resample,
change_hdr_file,
)
matplotlib.use('TkAgg')
def apply_background_and_optional_resample_for_samples(df, bg_spectrum, bil_path):
# 先做背景校正(自动识别以 wavelength_ 或 band_ 开头的光谱列,且长度不一致时尾部对齐)
df = apply_background_no_resample(df, bg_spectrum)
# 再判断是否需要重采样到训练波长
src_waves = read_wavelengths_from_hdr(bil_path)
need_resample = (
src_waves.size > 0 and (
src_waves.size != len(TRAIN_WAVELENGTHS) or
not np.allclose(src_waves, TRAIN_WAVELENGTHS, atol=1e-2)
)
)
if not need_resample:
return df
# 识别光谱列并重采样
spec_cols = [c for c in df.columns if isinstance(c, str) and (c.startswith('wavelength_') or c.startswith('band_'))]
if not spec_cols:
raise ValueError("未找到光谱列(以 wavelength_ 或 band_ 开头)")
X_src = df[spec_cols].to_numpy(dtype=np.float64)
X_dst = resample_spectra_matrix(X_src, src_waves, TRAIN_WAVELENGTHS)
# 用 band_{i} 替换光谱列,保持与 main.py 一致
spec_col_names = [f"band_{i+1}" for i in range(len(TRAIN_WAVELENGTHS))]
df = pd.concat([
df.drop(columns=spec_cols),
pd.DataFrame(X_dst, columns=spec_col_names, index=df.index)
], axis=1)
return df
def read_hdr_file(bil_path):
hdr_path = bil_path.replace('.bil', '.hdr')
with open(hdr_path, 'r') as f:
@ -199,9 +241,9 @@ def generate_new_mask(filter_mask_original, mask, num_masks=50, bil_path=None):
return new_mask_array
def process_single_bil(bil_path):
def process_single_bil(bil_path, num_masks=50, rng_seed=None):
"""
处理单个BIL文件
处理单个BIL文件生成滤纸背景样本的光谱特征行并返回DataFrame
"""
try:
print(f"\n{'=' * 60}")
@ -212,45 +254,47 @@ def process_single_bil(bil_path):
print("Processing BIL file to generate RGB image...")
rgb_img = process_bil_files(bil_path)
# 修改hdr
change_hdr_file(bil_path)
# HDR仅在缺失 wavelength 时补齐;并尽量对齐训练相机波长(与 main.py 一致)
change_hdr_file(bil_path, TRAIN_WAVELENGTHS)
# 生成掩膜mask为16位的塑料标签掩膜
# 生成掩膜:返回塑料掩膜 + 滤纸掩膜
print("Generating mask...")
mask, filter_mask_original = detect_microplastic_mask_from_array(
image=rgb_img,
filter_method='threshold',
diameter=None,
flow_threshold=0.4,
cellprob_threshold=0.0
cellprob_threshold=-1,
model_path=None,
detect_filter=True
)
# 根据滤纸掩膜和微塑料掩膜生成新的掩膜在滤纸掩膜内塑料掩膜外随机位置生成大小为35*35大小的掩膜数量为50个
new_mask_array = generate_new_mask(filter_mask_original, mask)
# 生成新的随机背景小块掩膜在滤纸内且不与塑料重叠)
if rng_seed is not None:
np.random.seed(int(rng_seed))
new_mask_array = generate_new_mask(filter_mask_original, mask, num_masks=num_masks, bil_path=bil_path)
# 提取特征
# 提取光谱与形状特征(仅限新背景小块)
print("Extracting features from BIL file...")
# 清理plantcv的observations确保只包含当前处理的塑料掩膜数据
pcv.observations = {}
pcv.observations = {} # 清理plantcv状态
df = process_images(bil_path, new_mask_array)
# 背景校正
print("Applying background correction...")
df_correct = process_images_background(bil_path, mask)
df.iloc[:, 1:169] = df.iloc[:, 1:169].div(df_correct, axis=1)
# 背景校正(用整图的滤纸背景光谱作为除数)+ 可选重采样到训练相机波长
print("Applying background correction (+ optional resample)...")
bg_spectrum = process_images_background(bil_path, mask)
df = apply_background_and_optional_resample_for_samples(df, bg_spectrum, bil_path)
# 数据清理
# 数据清理去NA、轮廓点数不足、面积过小过滤与 main.py 对齐)
print("Cleaning data...")
df = df.dropna()
df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
df = df[df['area'] >= 400]
df = df[df['area'] >= 500]
# 添加文件名列(不含扩展名)
filename = os.path.splitext(os.path.basename(bil_path))[0]
df.insert(0, 'filename', filename)
print(f"Extracted {len(df)} objects from {os.path.basename(bil_path)}")
print(f"Extracted {len(df)} background objects from {os.path.basename(bil_path)}")
return df
except Exception as e:
@ -261,50 +305,45 @@ def process_single_bil(bil_path):
def main():
# 单个文件或文件夹路径
# 支持文件或目录;新增可调样本数量与随机种子,便于复现
bil_path_or_folder = r"D:\Data\Traindata-11"
output_csv_path = r"E:\plastic\plastic\output\滤纸样本光谱\11.csv"
num_masks = 50
rng_seed = 42
# 确保输出目录存在
output_dir = os.path.dirname(output_csv_path)
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
# 判断是文件还是文件夹
if os.path.isfile(bil_path_or_folder):
bil_files = [bil_path_or_folder]
elif os.path.isdir(bil_path_or_folder):
# 搜索所有.bil文件
bil_files = [os.path.join(bil_path_or_folder, f) for f in os.listdir(bil_path_or_folder) if f.endswith('.bil')]
print(f"Found {len(bil_files)} BIL files to process")
else:
print(f"Error: {bil_path_or_folder} is not a valid file or directory")
return
# 初始化CSV文件写入表头
is_first_row = True
total_objects = 0
for i, bil_path in enumerate(bil_files, 1):
print(f"\n[{i}/{len(bil_files)}] Processing file...")
df = process_single_bil(bil_path)
df = process_single_bil(bil_path, num_masks=num_masks, rng_seed=rng_seed)
if df is not None and len(df) > 0:
# 边处理边写入CSV
df.to_csv(
output_csv_path,
mode='a' if not is_first_row else 'w', # 第一行写入模式为'w',后续追加'w'
mode='a' if not is_first_row else 'w',
index=False,
header=is_first_row # 只在第一行写入表头
header=is_first_row
)
total_objects += len(df)
is_first_row = False
print(f" -> {len(df)} objects appended to CSV file")
# 显示统计信息
if total_objects > 0:
print(f"\nSummary:")
print(f" Total files processed: {len(bil_files)}")
print(f" Total objects detected: {total_objects}")
print(f" Total background objects collected: {total_objects}")
print(f" Output file: {output_csv_path}")
else:
print("\nNo results to save.")