update:更新readme
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
from bil2rgb import process_bil_files
|
||||
from shape_spectral import process_images
|
||||
import cv2
|
||||
from classification_model.Parallel.predict_plastic import predict_and_save
|
||||
# from classification_model.Parallel.predict_plastic import predict_and_save # 本脚本不分类,可移除
|
||||
import numpy as np
|
||||
import os
|
||||
import matplotlib
|
||||
@ -10,9 +10,51 @@ from shape_spectral_background import process_images_background
|
||||
from mask import detect_microplastic_mask_from_array
|
||||
import plantcv as pcv
|
||||
|
||||
# 直接复用 main.py 中的成熟实现,避免重复逻辑和不一致
|
||||
from main import (
|
||||
TRAIN_WAVELENGTHS,
|
||||
read_wavelengths_from_hdr,
|
||||
resample_spectra_matrix,
|
||||
apply_background_no_resample,
|
||||
change_hdr_file,
|
||||
)
|
||||
|
||||
matplotlib.use('TkAgg')
|
||||
|
||||
|
||||
def apply_background_and_optional_resample_for_samples(df, bg_spectrum, bil_path):
|
||||
# 先做背景校正(自动识别以 wavelength_ 或 band_ 开头的光谱列,且长度不一致时尾部对齐)
|
||||
df = apply_background_no_resample(df, bg_spectrum)
|
||||
|
||||
# 再判断是否需要重采样到训练波长
|
||||
src_waves = read_wavelengths_from_hdr(bil_path)
|
||||
need_resample = (
|
||||
src_waves.size > 0 and (
|
||||
src_waves.size != len(TRAIN_WAVELENGTHS) or
|
||||
not np.allclose(src_waves, TRAIN_WAVELENGTHS, atol=1e-2)
|
||||
)
|
||||
)
|
||||
|
||||
if not need_resample:
|
||||
return df
|
||||
|
||||
# 识别光谱列并重采样
|
||||
spec_cols = [c for c in df.columns if isinstance(c, str) and (c.startswith('wavelength_') or c.startswith('band_'))]
|
||||
if not spec_cols:
|
||||
raise ValueError("未找到光谱列(以 wavelength_ 或 band_ 开头)")
|
||||
|
||||
X_src = df[spec_cols].to_numpy(dtype=np.float64)
|
||||
X_dst = resample_spectra_matrix(X_src, src_waves, TRAIN_WAVELENGTHS)
|
||||
|
||||
# 用 band_{i} 替换光谱列,保持与 main.py 一致
|
||||
spec_col_names = [f"band_{i+1}" for i in range(len(TRAIN_WAVELENGTHS))]
|
||||
df = pd.concat([
|
||||
df.drop(columns=spec_cols),
|
||||
pd.DataFrame(X_dst, columns=spec_col_names, index=df.index)
|
||||
], axis=1)
|
||||
return df
|
||||
|
||||
|
||||
def read_hdr_file(bil_path):
|
||||
hdr_path = bil_path.replace('.bil', '.hdr')
|
||||
with open(hdr_path, 'r') as f:
|
||||
@ -199,9 +241,9 @@ def generate_new_mask(filter_mask_original, mask, num_masks=50, bil_path=None):
|
||||
return new_mask_array
|
||||
|
||||
|
||||
def process_single_bil(bil_path):
|
||||
def process_single_bil(bil_path, num_masks=50, rng_seed=None):
|
||||
"""
|
||||
处理单个BIL文件
|
||||
处理单个BIL文件,生成滤纸背景样本的光谱特征行,并返回DataFrame
|
||||
"""
|
||||
try:
|
||||
print(f"\n{'=' * 60}")
|
||||
@ -212,45 +254,47 @@ def process_single_bil(bil_path):
|
||||
print("Processing BIL file to generate RGB image...")
|
||||
rgb_img = process_bil_files(bil_path)
|
||||
|
||||
# 修改hdr
|
||||
change_hdr_file(bil_path)
|
||||
# HDR:仅在缺失 wavelength 时补齐;并尽量对齐训练相机波长(与 main.py 一致)
|
||||
change_hdr_file(bil_path, TRAIN_WAVELENGTHS)
|
||||
|
||||
# 生成掩膜,mask为16位的塑料标签掩膜
|
||||
# 生成掩膜:返回塑料掩膜 + 滤纸掩膜
|
||||
print("Generating mask...")
|
||||
mask, filter_mask_original = detect_microplastic_mask_from_array(
|
||||
image=rgb_img,
|
||||
filter_method='threshold',
|
||||
diameter=None,
|
||||
flow_threshold=0.4,
|
||||
cellprob_threshold=0.0
|
||||
cellprob_threshold=-1,
|
||||
model_path=None,
|
||||
detect_filter=True
|
||||
)
|
||||
|
||||
# 根据滤纸掩膜和微塑料掩膜生成新的掩膜,在滤纸掩膜内塑料掩膜外,随机位置生成大小为35*35大小的掩膜,数量为50个
|
||||
new_mask_array = generate_new_mask(filter_mask_original, mask)
|
||||
# 生成新的随机背景小块掩膜(在滤纸内且不与塑料重叠)
|
||||
if rng_seed is not None:
|
||||
np.random.seed(int(rng_seed))
|
||||
new_mask_array = generate_new_mask(filter_mask_original, mask, num_masks=num_masks, bil_path=bil_path)
|
||||
|
||||
# 提取特征
|
||||
# 提取光谱与形状特征(仅限新背景小块)
|
||||
print("Extracting features from BIL file...")
|
||||
# 清理plantcv的observations,确保只包含当前处理的塑料掩膜数据
|
||||
pcv.observations = {}
|
||||
pcv.observations = {} # 清理plantcv状态
|
||||
df = process_images(bil_path, new_mask_array)
|
||||
|
||||
# 背景校正
|
||||
print("Applying background correction...")
|
||||
df_correct = process_images_background(bil_path, mask)
|
||||
df.iloc[:, 1:169] = df.iloc[:, 1:169].div(df_correct, axis=1)
|
||||
# 背景校正(用整图的滤纸背景光谱作为除数)+ 可选重采样到训练相机波长
|
||||
print("Applying background correction (+ optional resample)...")
|
||||
bg_spectrum = process_images_background(bil_path, mask)
|
||||
df = apply_background_and_optional_resample_for_samples(df, bg_spectrum, bil_path)
|
||||
|
||||
# 数据清理
|
||||
# 数据清理:去NA、轮廓点数不足、面积过小过滤(与 main.py 对齐)
|
||||
print("Cleaning data...")
|
||||
df = df.dropna()
|
||||
df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
|
||||
df = df[df['area'] >= 400]
|
||||
df = df[df['area'] >= 500]
|
||||
|
||||
# 添加文件名列(不含扩展名)
|
||||
filename = os.path.splitext(os.path.basename(bil_path))[0]
|
||||
df.insert(0, 'filename', filename)
|
||||
|
||||
print(f"Extracted {len(df)} objects from {os.path.basename(bil_path)}")
|
||||
|
||||
print(f"Extracted {len(df)} background objects from {os.path.basename(bil_path)}")
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
@ -261,50 +305,45 @@ def process_single_bil(bil_path):
|
||||
|
||||
|
||||
def main():
|
||||
# 单个文件或文件夹路径
|
||||
# 支持文件或目录;新增可调样本数量与随机种子,便于复现
|
||||
bil_path_or_folder = r"D:\Data\Traindata-11"
|
||||
output_csv_path = r"E:\plastic\plastic\output\滤纸样本光谱\11.csv"
|
||||
num_masks = 50
|
||||
rng_seed = 42
|
||||
|
||||
# 确保输出目录存在
|
||||
output_dir = os.path.dirname(output_csv_path)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
|
||||
|
||||
# 判断是文件还是文件夹
|
||||
if os.path.isfile(bil_path_or_folder):
|
||||
bil_files = [bil_path_or_folder]
|
||||
elif os.path.isdir(bil_path_or_folder):
|
||||
# 搜索所有.bil文件
|
||||
bil_files = [os.path.join(bil_path_or_folder, f) for f in os.listdir(bil_path_or_folder) if f.endswith('.bil')]
|
||||
print(f"Found {len(bil_files)} BIL files to process")
|
||||
else:
|
||||
print(f"Error: {bil_path_or_folder} is not a valid file or directory")
|
||||
return
|
||||
|
||||
# 初始化CSV文件(写入表头)
|
||||
is_first_row = True
|
||||
total_objects = 0
|
||||
|
||||
for i, bil_path in enumerate(bil_files, 1):
|
||||
print(f"\n[{i}/{len(bil_files)}] Processing file...")
|
||||
df = process_single_bil(bil_path)
|
||||
df = process_single_bil(bil_path, num_masks=num_masks, rng_seed=rng_seed)
|
||||
|
||||
if df is not None and len(df) > 0:
|
||||
# 边处理边写入CSV
|
||||
df.to_csv(
|
||||
output_csv_path,
|
||||
mode='a' if not is_first_row else 'w', # 第一行写入模式为'w',后续追加'w'
|
||||
mode='a' if not is_first_row else 'w',
|
||||
index=False,
|
||||
header=is_first_row # 只在第一行写入表头
|
||||
header=is_first_row
|
||||
)
|
||||
total_objects += len(df)
|
||||
is_first_row = False
|
||||
print(f" -> {len(df)} objects appended to CSV file")
|
||||
|
||||
# 显示统计信息
|
||||
if total_objects > 0:
|
||||
print(f"\nSummary:")
|
||||
print(f" Total files processed: {len(bil_files)}")
|
||||
print(f" Total objects detected: {total_objects}")
|
||||
print(f" Total background objects collected: {total_objects}")
|
||||
print(f" Output file: {output_csv_path}")
|
||||
else:
|
||||
print("\nNo results to save.")
|
||||
|
||||
Reference in New Issue
Block a user