Files
micro_plastic/fliter_sample_spectral.py
2026-04-16 13:11:05 +08:00

353 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from bil2rgb import process_bil_files
from shape_spectral import process_images
import cv2
# from classification_model.Parallel.predict_plastic import predict_and_save # 本脚本不分类,可移除
import numpy as np
import os
import matplotlib
import pandas as pd
from shape_spectral_background import process_images_background
from mask import detect_microplastic_mask_from_array
import plantcv as pcv
# 直接复用 main.py 中的成熟实现,避免重复逻辑和不一致
from main import (
TRAIN_WAVELENGTHS,
read_wavelengths_from_hdr,
resample_spectra_matrix,
apply_background_no_resample,
change_hdr_file,
)
# Select an interactive backend for any matplotlib windows opened downstream.
matplotlib.use('TkAgg')
# Purpose of this script: extract filter-paper (background) samples so their
# spectra can be added to the training set, reducing misclassification
# between the filter paper and PTFE.
def apply_background_and_optional_resample_for_samples(df, bg_spectrum, bil_path):
    """Background-correct the sample spectra, then resample them to the
    training wavelength grid only when the source camera grid differs.

    Parameters:
        df: DataFrame whose spectral columns start with 'wavelength_' or 'band_'.
        bg_spectrum: background (filter-paper) spectrum used as the divisor.
        bil_path: path of the source .bil; its .hdr supplies the wavelengths.

    Returns the corrected (and possibly resampled) DataFrame.

    Raises ValueError when no spectral columns can be found after correction.
    """
    # Background correction first; the shared helper auto-detects spectral
    # columns by prefix and tail-aligns when lengths differ.
    corrected = apply_background_no_resample(df, bg_spectrum)

    # Decide whether resampling onto TRAIN_WAVELENGTHS is needed at all.
    source_grid = read_wavelengths_from_hdr(bil_path)
    if source_grid.size == 0:
        # No wavelength info in the header -> nothing to resample against.
        return corrected
    grids_match = (
        source_grid.size == len(TRAIN_WAVELENGTHS)
        and np.allclose(source_grid, TRAIN_WAVELENGTHS, atol=1e-2)
    )
    if grids_match:
        return corrected

    # Locate the spectral columns and resample them onto the training grid.
    spectral_columns = [
        col for col in corrected.columns
        if isinstance(col, str) and (col.startswith('wavelength_') or col.startswith('band_'))
    ]
    if not spectral_columns:
        raise ValueError("未找到光谱列(以 wavelength_ 或 band_ 开头)")
    source_matrix = corrected[spectral_columns].to_numpy(dtype=np.float64)
    resampled = resample_spectra_matrix(source_matrix, source_grid, TRAIN_WAVELENGTHS)

    # Replace the spectral columns with band_{i} names, mirroring main.py.
    new_names = [f"band_{i+1}" for i in range(len(TRAIN_WAVELENGTHS))]
    resampled_frame = pd.DataFrame(resampled, columns=new_names, index=corrected.index)
    return pd.concat([corrected.drop(columns=spectral_columns), resampled_frame], axis=1)
def read_hdr_file(bil_path):
    """Read the image dimensions from the ENVI header next to a BIL file.

    Parameters:
        bil_path: path to the .bil file; the matching .hdr must sit beside it.

    Returns:
        (samples, lines) as ints; either is None if its key is missing.
    """
    # Use splitext rather than str.replace: replace('.bil', '.hdr') would
    # rewrite EVERY '.bil' occurrence in the path, not just the extension.
    # This also matches how change_hdr_file builds the .hdr path.
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    samples, lines = None, None
    with open(hdr_path, 'r') as f:
        for line in f:
            # ENVI headers use lines of the form 'key = value'.
            if line.startswith('samples'):
                samples = int(line.split('=')[-1].strip())
            elif line.startswith('lines'):
                lines = int(line.split('=')[-1].strip())
    return samples, lines
def save_envi_classification(bil_path, df, savepath):
    """Rasterize per-object predictions into an ENVI classification image.

    Each row of *df* must carry a polygon ('contour') and an integer class
    ('Predictions'); the written class value is Predictions + 1 so that 0
    stays reserved for the background. Writes the raw raster to *savepath*
    and a matching ENVI header to '<savepath without ext>.hdr'.

    Parameters:
        bil_path: source .bil whose .hdr supplies the (samples, lines) size.
        df: DataFrame with 'contour' (list of [x, y] points) and 'Predictions'.
        savepath: output path for the binary classification raster.
    """
    samples, lines = read_hdr_file(bil_path)
    classification_result = np.zeros((lines, samples), dtype=np.uint16)
    for _, row in df.iterrows():
        contour = row['contour']
        # Shift labels by one so 0 remains "unclassified background".
        prediction = int(row['Predictions']) + 1
        contour = np.array(contour, dtype=np.int32)
        # NOTE(review): this zeroes every pixel currently labelled 10 on EVERY
        # iteration, erasing class-10 polygons drawn by earlier rows; the
        # original comment said 10 AND 11 should be reset, but only 10 is.
        # Intent unclear -- confirm before relying on class-10 output.
        classification_result[(classification_result == 10)] = 0
        cv2.fillPoly(classification_result, [contour], prediction)
    output_path = savepath
    with open(output_path, 'wb') as f:
        classification_result.tofile(f)
    # ENVI header for the raster just written.  NOTE(review): the array is
    # uint16 while "data type = 2" is ENVI's signed 16-bit code (uint16 would
    # be 12); harmless while class values stay small, but worth confirming.
    header_content = f"""ENVI
description = {{
Classification Result.}}
samples = {samples}
lines = {lines}
bands = 1
header offset = 0
file type = ENVI Standard
data type = 2
interleave = bil
classes = 11
class = {{ background, ABS, HDPE, LDPE, PA6, PET, PP, PS, PTFE, PVC,background2 }}
single pixel area = 0.000036
unit = mm2
byte order = 0
wavelength units = nm
"""
    filename, ext = os.path.splitext(savepath)
    # Swap the extension for '.hdr' so the header sits next to the raster.
    header_filename = filename + '.hdr'
    with open(header_filename, 'w') as header_file:
        header_file.write(header_content)
def change_hdr_file(bil_path, train_wavelengths=None):
    """Append a wavelength list to the .hdr next to *bil_path* if missing.

    NOTE: this definition shadows the ``change_hdr_file`` imported from
    ``main`` at the top of the file, and ``process_single_bil`` calls it with
    TWO arguments. The optional ``train_wavelengths`` parameter makes that
    call work (previously it raised a TypeError that the caller's broad
    ``except`` silently swallowed) while staying backward compatible with
    the original one-argument usage.

    Parameters:
        bil_path: path to the .bil file; its sibling .hdr is modified in place.
        train_wavelengths: optional sequence of wavelengths (nm) to write
            instead of the hard-coded camera grid below.

    The header is left untouched when it already mentions 'wavelength'
    (including 'wavelength units') or when no .hdr file exists.
    """
    if train_wavelengths is None:
        # Hard-coded wavelength grid of the acquisition camera.
        wavelength_info = """wavelength = {898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2}"""
    else:
        # Format the supplied grid the same way as the default ({v1, v2, ...}).
        formatted = ", ".join(f"{w:g}" for w in train_wavelengths)
        wavelength_info = "wavelength = {" + formatted + "}"
    # Derive the .hdr path from the .bil path.
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        print(f"错误: 找不到对应的HDR文件: {hdr_path}")
        return
    with open(hdr_path, 'r') as file:
        content = file.read()
    # Skip files that already carry any wavelength information.
    if 'wavelength' in content:
        print(f"文件 {os.path.basename(hdr_path)} 已包含波长信息,无需修改。")
        return
    needs_newline = not content.endswith('\n')
    with open(hdr_path, 'a') as file:
        if needs_newline:
            file.write('\n')  # ensure the new entry starts on its own line
        file.write(wavelength_info + '\n')
    print(f"已成功添加波长信息到文件: {os.path.basename(hdr_path)}")
def generate_new_mask(filter_mask_original, mask, num_masks=50, bil_path=None):
    """Scatter random 35x35 background patches inside the filter paper.

    Patches are placed fully inside the filter-paper mask and fully outside
    the plastic mask; each patch gets its own label (1, 2, 3, ...) so they
    can later be treated as separate objects.

    Parameters:
        filter_mask_original: image, non-zero where the filter paper is.
        mask: uint16 label image, non-zero where plastic particles are.
        num_masks: maximum number of patches (fewer if space runs out).
        bil_path: when given, an RGB overlay of all three masks is saved
            under '<dir of bil>/masks/'.

    Returns:
        uint16 label image of the generated patches (same shape as input).
    """
    # Binarize both inputs: any non-zero value counts as "occupied".
    mask_binary = (mask > 0).astype(np.uint8)
    filter_mask_binary = (filter_mask_original > 0).astype(np.uint8)
    # Inside the filter paper AND outside the plastic.
    valid_region = filter_mask_binary & (~mask_binary)
    height, width = valid_region.shape
    mask_size = 35
    half_size = mask_size // 2
    new_mask_array = np.zeros((height, width), dtype=np.uint16)
    # Collect every centre whose full 35x35 neighbourhood lies in the valid
    # region (and therefore entirely on-image as well).
    valid_centers = []
    for y in range(half_size, height - half_size):
        for x in range(half_size, width - half_size):
            y_start, y_end = y - half_size, y + half_size + 1
            x_start, x_end = x - half_size, x + half_size + 1
            if np.all(valid_region[y_start:y_end, x_start:x_end] > 0):
                valid_centers.append((y, x))
    # Cap the request at what is actually available.
    num_masks = min(num_masks, len(valid_centers))
    if num_masks == 0:
        print("Warning: No valid positions found for generating masks.")
        return new_mask_array
    # Sample centres without replacement; note patches may still overlap each
    # other, since validity is only checked against the input masks.
    if len(valid_centers) > num_masks:
        selected_centers = np.random.choice(len(valid_centers), size=num_masks, replace=False)
        selected_centers = [valid_centers[i] for i in selected_centers]
    else:
        selected_centers = valid_centers
    mask_value = 1  # distinct label per patch
    for y, x in selected_centers:
        y_start, y_end = y - half_size, y + half_size + 1
        x_start, x_end = x - half_size, x + half_size + 1
        new_mask_array[y_start:y_end, x_start:x_end] = mask_value
        mask_value += 1
    print(f"Generated {len(selected_centers)} masks of size {mask_size}x{mask_size}")
    # Optional visualization: filter paper = blue, plastic = red, new patches
    # = green, saved under a 'masks' folder next to the BIL file.
    if bil_path is not None:
        masks_dir = os.path.join(os.path.dirname(bil_path), 'masks')
        os.makedirs(masks_dir, exist_ok=True)
        height, width = filter_mask_original.shape
        combined_visualization = np.zeros((height, width, 3), dtype=np.uint8)
        combined_visualization[:, :, 2] = filter_mask_binary * 100  # B channel
        combined_visualization[:, :, 0] = mask_binary * 255  # R channel
        new_mask_binary = (new_mask_array > 0).astype(np.uint8)
        combined_visualization[:, :, 1] = np.maximum(combined_visualization[:, :, 1], new_mask_binary * 255)  # G channel
        # BUG FIX: the output name previously used the literal "(unknown)"
        # and ignored the base name computed from bil_path.
        filename = os.path.splitext(os.path.basename(bil_path))[0]
        output_path = os.path.join(masks_dir, f"{filename}_mask_visualization.png")
        cv2.imwrite(output_path, combined_visualization)
        print(f"Saved mask visualization to: {output_path}")
    # (A merged plastic+patch label image used to be built here but was never
    # used or returned; that dead code has been removed.)
    return new_mask_array
def process_single_bil(bil_path, num_masks=50, rng_seed=None):
    """Process one BIL file and return a DataFrame of filter-paper
    (background) sample rows with spectral/shape features.

    Pipeline: BIL -> RGB preview -> plastic + filter masks -> random 35x35
    background patches -> feature extraction -> background correction and
    optional resampling -> cleaning. Returns None on any error.

    Parameters:
        bil_path: path to the .bil file to process.
        num_masks: number of random background patches to sample.
        rng_seed: optional seed for numpy's RNG, for reproducible patches.
    """
    try:
        print(f"\n{'=' * 60}")
        print(f"Processing: {os.path.basename(bil_path)}")
        print(f"{'=' * 60}")
        # Build an RGB preview image from the BIL cube.
        print("Processing BIL file to generate RGB image...")
        rgb_img = process_bil_files(bil_path)
        # Fill in HDR wavelengths when absent, aligned to the training grid.
        # NOTE(review): the one-argument change_hdr_file defined in THIS file
        # shadows the import from main, so this two-argument call raises a
        # TypeError (swallowed below) unless the local def accepts it -- verify.
        change_hdr_file(bil_path, TRAIN_WAVELENGTHS)
        # Segment plastics and the filter paper from the RGB preview.
        print("Generating mask...")
        mask, filter_mask_original = detect_microplastic_mask_from_array(
            image=rgb_img,
            filter_method='threshold',
            diameter=None,
            flow_threshold=0.4,
            cellprob_threshold=-1,
            model_path=None,
            detect_filter=True
        )
        # Random background patches inside the paper, outside the plastic;
        # seeding here makes patch placement reproducible.
        if rng_seed is not None:
            np.random.seed(int(rng_seed))
        new_mask_array = generate_new_mask(filter_mask_original, mask, num_masks=num_masks, bil_path=bil_path)
        # Extract spectral/shape features for the background patches only.
        print("Extracting features from BIL file...")
        pcv.observations = {}  # reset plantcv's accumulated state between files
        df = process_images(bil_path, new_mask_array)
        # Divide by the whole-image filter-paper spectrum, then resample to
        # the training wavelengths if the camera grid differs.
        print("Applying background correction (+ optional resample)...")
        bg_spectrum = process_images_background(bil_path, mask)
        df = apply_background_and_optional_resample_for_samples(df, bg_spectrum, bil_path)
        # Cleaning: drop NA rows, degenerate contours, and tiny areas
        # (thresholds mirror main.py).
        print("Cleaning data...")
        df = df.dropna()
        df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
        df = df[df['area'] >= 500]
        # Prepend the source file name (without extension) as a column.
        filename = os.path.splitext(os.path.basename(bil_path))[0]
        df.insert(0, 'filename', filename)
        print(f"Extracted {len(df)} background objects from {os.path.basename(bil_path)}")
        return df
    except Exception as e:
        # Best-effort batch processing: report and continue with None.
        print(f"Error processing {bil_path}: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
def main():
    """Collect filter-paper background spectra from BIL files into one CSV.

    The input may be a single .bil file or a directory of them; the seed
    keeps the random patch placement reproducible across runs.
    """
    bil_path_or_folder = r"D:\Data\Traindata-11"
    output_csv_path = r"E:\plastic\plastic\output\滤纸样本光谱\11.csv"
    num_masks = 50
    rng_seed = 42

    os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)

    # Resolve the input into a list of .bil files.
    if os.path.isfile(bil_path_or_folder):
        bil_files = [bil_path_or_folder]
    elif os.path.isdir(bil_path_or_folder):
        bil_files = [
            os.path.join(bil_path_or_folder, name)
            for name in os.listdir(bil_path_or_folder)
            if name.endswith('.bil')
        ]
        print(f"Found {len(bil_files)} BIL files to process")
    else:
        print(f"Error: {bil_path_or_folder} is not a valid file or directory")
        return

    # Stream results into the CSV: the first successful batch overwrites the
    # file and writes the header, later ones append without a header.
    first_write = True
    total_objects = 0
    for index, path in enumerate(bil_files, 1):
        print(f"\n[{index}/{len(bil_files)}] Processing file...")
        batch = process_single_bil(path, num_masks=num_masks, rng_seed=rng_seed)
        if batch is None or len(batch) == 0:
            continue
        batch.to_csv(
            output_csv_path,
            mode='w' if first_write else 'a',
            index=False,
            header=first_write
        )
        total_objects += len(batch)
        first_write = False
        print(f" -> {len(batch)} objects appended to CSV file")

    if total_objects > 0:
        print(f"\nSummary:")
        print(f" Total files processed: {len(bil_files)}")
        print(f" Total background objects collected: {total_objects}")
        print(f" Output file: {output_csv_path}")
    else:
        print("\nNo results to save.")
if __name__ == "__main__":
    main()