# micro_plastic/多模型.py
import os
import cv2
import matplotlib
import numpy as np
from bil2rgb import process_bil_files
from classification_model.Parallel.predict_plastic import load_model, predict_with_model
from mask import detect_microplastic_mask_from_array
from shape_spectral import process_images
from shape_spectral_background import process_images_background
from get_glcm import calcu_glcm, calcu_glcm_variance
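
# Use the TkAgg backend for any interactive matplotlib windows opened downstream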
matplotlib.use('TkAgg')
# ----------------------------
# Configuration: edit the values below directly
# ----------------------------
BIL_PATH = r"D:/Data/Test/PET_bottle2.bil"
OUTPUT_PATH = r'D:/Data/PET_bottle2_class.bil'
PRIMARY_MODEL_PATH = r"D:\plastic\plastic\modelsave\svm.m"
PRIMARY_MODEL_TYPE = 'SVM'
PRIMARY_PROCESS_METHODS1 = 'SS'
PRIMARY_PROCESS_METHODS2 = 'None'
SECONDARY_MODEL_PATH = r"D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m"  # set to None to skip secondary classification
SECONDARY_MODEL_TYPE = 'SVM'
SECONDARY_PROCESS_METHODS1 = 'None'
SECONDARY_PROCESS_METHODS2 = 'None'
SECONDARY_TARGET_CLASSES = [1, 2]
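# Note: 'Predictions' values are written to the ENVI file with a +1 offset (see
# save_envi_classification), so prediction p corresponds to entry p + 1 of the
# ENVI class list, e.g. 1 = HDPE, 2 = LDPE, 7 = PTFE, 9 = background2.
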
def read_hdr_file(bil_path):
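    """Read the 'samples' and 'lines' fields from the ENVI header next to bil_path.

    Returns (samples, lines); either value is None if the field is missing.
    """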
    hdr_path = bil_path.replace('.bil', '.hdr')
    with open(hdr_path, 'r') as f:
        header = f.readlines()
    samples, lines = None, None
    for line in header:
        if line.startswith('samples'):
            samples = int(line.split('=')[-1].strip())
        if line.startswith('lines'):
            lines = int(line.split('=')[-1].strip())
    return samples, lines
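
# For reference, a minimal ENVI header from which read_hdr_file picks up
# 'samples' and 'lines' looks like (illustrative values):
#   ENVI
#   samples = 640
#   lines = 1500
#   bands = 168
#   interleave = bil
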
def save_envi_classification(bil_path, df, savepath):
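    """Rasterize the per-contour predictions in df into a single-band uint16
    ENVI classification image at savepath, and write the matching .hdr file.

    Each contour is filled with its 'Predictions' value + 1, so that 0 stays
    reserved for the background class.
    """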
    samples, lines = read_hdr_file(bil_path)
    classification_result = np.zeros((lines, samples), dtype=np.uint16)
    for _, row in df.iterrows():
        contour = row['contour']
        prediction = int(row['Predictions']) + 1
        contour = np.array(contour, dtype=np.int32)
        # Optionally reset existing values of 10 in classification_result to 0 first
        # classification_result[(classification_result == 10)] = 0
        cv2.fillPoly(classification_result, [contour], prediction)
    output_path = savepath
    with open(output_path, 'wb') as f:
        classification_result.tofile(f)
    header_content = f"""ENVI
description = {{
Classification Result.}}
samples = {samples}
lines = {lines}
bands = 1
header offset = 0
file type = ENVI Standard
data type = 2
interleave = bil
classes = 11
class = {{ background, ABS, HDPE, LDPE, PA6, PET, PP, PS, PTFE, PVC, background2 }}
single pixel area = 0.000036
unit = mm2
byte order = 0
wavelength units = nm
"""
    filename, ext = os.path.splitext(savepath)
    # Swap the extension for '.hdr'
    header_filename = filename + '.hdr'
    with open(header_filename, 'w') as header_file:
        header_file.write(header_content)

def change_hdr_file(bil_path):
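    """Append the fixed wavelength list (898.82-1720.2 nm) to the .hdr file
    next to bil_path, unless wavelength information is already present.
    """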
    # Wavelength list to append to the header
    wavelength_info = """wavelength = {898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2}"""
    # Convert the .bil path to the matching .hdr path
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    # Make sure the .hdr file exists
    if not os.path.exists(hdr_path):
        print(f"Error: matching HDR file not found: {hdr_path}")
        return
    # Read the current header content
    with open(hdr_path, 'r') as file:
        content = file.read()
    # Skip if a wavelength list is already present
    # (check for 'wavelength =' so that a 'wavelength units' line alone does not trigger a false match)
    if 'wavelength =' in content:
        print(f"File {os.path.basename(hdr_path)} already contains wavelength information; no changes needed.")
        return
    # Check whether the file ends with a newline
    needs_newline = not content.endswith('\n')
    # Append the wavelength list
    with open(hdr_path, 'a') as file:
        if needs_newline:
            file.write('\n')  # make sure the new content starts on a fresh line
        file.write(wavelength_info + '\n')
    print(f"Successfully added wavelength information to file: {os.path.basename(hdr_path)}")

def main():
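    """Run the full pipeline: BIL -> RGB -> mask -> features -> background
    correction -> primary model prediction -> optional secondary classification
    -> rule-based post-processing -> ENVI classification output.
    """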
    bil_path = BIL_PATH
    output_path = OUTPUT_PATH
    model_path = PRIMARY_MODEL_PATH

    # Process the BIL file into an RGB image
    print("Processing BIL file to generate RGB image...\n")
    rgb_img = process_bil_files(bil_path)

    # Update the header file with the wavelength list
    change_hdr_file(bil_path)

    # Generate the mask; mask is a 16-bit plastic label mask
    print("Generating mask ...\n")
    mask, filter_mask_original = detect_microplastic_mask_from_array(
        image=rgb_img,  # pass the cv2.imread-style array directly
        filter_method='threshold',
        diameter=None,
        flow_threshold=0.4,
        cellprob_threshold=-1
    )
    # Extract features
    print("Extracting features from BIL file...\n")
    df = process_images(bil_path, mask)

    # Background correction: divide the spectral feature columns (1-168) by the background reference
    print("Applying background correction...\n")
    df_correct = process_images_background(bil_path, mask)
    df.iloc[:, 1:169] = df.iloc[:, 1:169].div(df_correct, axis=1)
    # Clean the data
    print("Cleaning data...\n")
    df = df.dropna()
    df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
    df = df[df['area'] >= 500]

    # Column names to drop: columns 87-109 plus the 10th-from-last through 2nd-from-last (0-based)
    cols_to_remove = df.columns[np.r_[87:110, -10:-1]]
    # cols_to_remove = df.columns[87:110]
    # Drop the selected columns while keeping the DataFrame structure
    df = df.drop(columns=cols_to_remove)
    # Keep the DataFrame structure rather than converting to a numpy array
    # (.values would lose the column names); as written, this slice keeps every column
    df = df.iloc[:, :]

    # Save the original feature data (before the first prediction) for the secondary model
    df_original = df.copy()
    # Predict classes
    print("Predicting classes...\n")
    loaded_model = load_model(model_path)
    df_pre = predict_with_model(
        df,
        model_path,
        model_type=PRIMARY_MODEL_TYPE,
        ProcessMethods1=PRIMARY_PROCESS_METHODS1,
        ProcessMethods2=PRIMARY_PROCESS_METHODS2
    )
    # Secondary classification: re-predict the target classes using the original feature values
    if SECONDARY_MODEL_PATH:
        target_classes = set(SECONDARY_TARGET_CLASSES or [])
        if target_classes:
            print(f"Running secondary classification for classes: {sorted(target_classes)}\n")
            mask_secondary = df_pre['Predictions'].isin(target_classes)
            if mask_secondary.any():
                # Use the original feature data, not the data from the first prediction
                df_secondary_input = df_original.loc[mask_secondary].copy()
                df_secondary = predict_with_model(
                    df_secondary_input,
                    SECONDARY_MODEL_PATH,
                    model_type=SECONDARY_MODEL_TYPE,
                    ProcessMethods1=SECONDARY_PROCESS_METHODS1,
                    ProcessMethods2=SECONDARY_PROCESS_METHODS2
                )
                df_pre.loc[mask_secondary, 'Predictions'] = df_secondary['Predictions'].values
            else:
                print("No samples from target classes found; skipping secondary classification.\n")
        else:
            print("Secondary target classes not provided; skipping secondary classification.\n")
    else:
        print("Secondary model path not provided; skipping secondary classification.\n")
    # Identify background shadows misclassified as class 7 via a boundary-sharpness feature:
    # genuine class 7 objects have sharp boundaries, background shadows have blurred ones
    class_7_mask = df_pre['Predictions'] == 7
    if class_7_mask.any():
        print(f"Processing {class_7_mask.sum()} samples with class 7 to identify background shadows...\n")
        # Convert a PIL Image to a numpy array if necessary
        if hasattr(rgb_img, 'mode'):  # PIL Images have a .mode attribute
            rgb_img_array = np.array(rgb_img)
        else:
            rgb_img_array = rgb_img
        # Convert to grayscale (for the gradient computation)
        if len(rgb_img_array.shape) == 3:
            gray_img = cv2.cvtColor(rgb_img_array, cv2.COLOR_RGB2GRAY)
        else:
            gray_img = rgb_img_array
        # Compute the gradient image with the Sobel operator
        grad_x = cv2.Sobel(gray_img, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(gray_img, cv2.CV_64F, 0, 1, ksize=3)
        gradient_magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)
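        # The gradient magnitude is a per-pixel edge-strength map; each contour below
        # is scored by the mean gradient along its boundary, so soft, shadow-like
        # boundaries get low scores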
        # First collect the edge-gradient values of all class 7 samples to set the threshold
        all_class7_gradients = []
        valid_indices = []
        for idx in df_pre[class_7_mask].index:
            try:
                contour = df_pre.loc[idx, 'contour']
                if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
                    continue
                contour_array = np.array(contour, dtype=np.int32)
                if len(contour_array.shape) == 1:
                    continue
                mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
                cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)
                edge_gradients = gradient_magnitude[mask_img > 0]
                if len(edge_gradients) > 0:
                    all_class7_gradients.extend(edge_gradients)
                    valid_indices.append(idx)
            except Exception:
                continue
        # Set the threshold from the gradient distribution of the class 7 samples:
        # the median is the reference, and anything below a low percentile (here 30%)
        # is treated as a background shadow
        if len(all_class7_gradients) > 0:
            gradient_threshold = np.percentile(all_class7_gradients, 30)  # 30th percentile of the class 7 edge gradients
        else:
            gradient_threshold = np.percentile(gradient_magnitude, 30)  # fall back to the whole image if no valid samples
        print(f"Gradient threshold for class 7: {gradient_threshold:.2f}\n")
        # Check each class 7 sample and decide whether it is a background shadow
        indices_to_update = []
        for idx in valid_indices:
            try:
                contour = df_pre.loc[idx, 'contour']
                contour_array = np.array(contour, dtype=np.int32)
                # Contour mask with a 2-pixel line width, used to extract the edge
                mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
                cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)
                # Extract the gradient values along the contour edge
                edge_gradients = gradient_magnitude[mask_img > 0]
                if len(edge_gradients) == 0:
                    continue
                # Mean gradient strength along the contour edge
                mean_gradient = np.mean(edge_gradients)
                # A mean gradient below the threshold means a blurred boundary, i.e. a background shadow
                if mean_gradient < gradient_threshold:
                    indices_to_update.append(idx)
                    print(
                        f"Sample {idx}: mean_gradient={mean_gradient:.2f}, threshold={gradient_threshold:.2f} -> identified as background shadow")
            except Exception as e:
                print(f"Error processing sample at index {idx}: {str(e)}")
                continue
        # Reassign background shadows from class 7 to class 9
        if indices_to_update:
            df_pre.loc[indices_to_update, 'Predictions'] = 9  # prediction 9 maps to background2 after the +1 offset in save_envi_classification
            print(f"Updated {len(indices_to_update)} samples from class 7 to class 9 (background shadows)\n")
        else:
            print("No samples needed to be updated from class 7\n")
    # Distinguish class 1 (HDPE) from class 2 (LDPE) via brightness and uniformity:
    # class 2 (LDPE) is brighter and uneven (high brightness + high std deviation),
    # class 1 (HDPE) is darker and uniform (low brightness + low std deviation)
    class_1_2_mask = df_pre['Predictions'].isin([1, 2])
    if class_1_2_mask.any():
        print(
            f"Processing {class_1_2_mask.sum()} samples with class 1 (HDPE) or class 2 (LDPE) to distinguish them...\n")
        # Convert a PIL Image to a numpy array (if not converted already)
        if hasattr(rgb_img, 'mode'):  # PIL Images have a .mode attribute
            rgb_img_array = np.array(rgb_img)
        else:
            rgb_img_array = rgb_img
        # Convert to grayscale (for the brightness computation)
        if len(rgb_img_array.shape) == 3:
            gray_img = cv2.cvtColor(rgb_img_array, cv2.COLOR_RGB2GRAY)
        else:
            gray_img = rgb_img_array
        # Collect brightness and std deviation of all class 1 and 2 samples to set thresholds
        all_brightnesses = []
        all_std_devs = []
        valid_indices = []
        for idx in df_pre[class_1_2_mask].index:
            try:
                contour = df_pre.loc[idx, 'contour']
                if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
                    continue
                contour_array = np.array(contour, dtype=np.int32)
                if len(contour_array.shape) == 1:
                    continue
                # Filled contour mask
                full_mask = np.zeros(gray_img.shape, dtype=np.uint8)
                cv2.fillPoly(full_mask, [contour_array], 255)
                # Shrink the contour mask inwards so edge pixels do not include background
                # (erode once with a small kernel to get the shrunken mask)
                contour_rect = cv2.boundingRect(contour_array)
                inner_kernel_size = max(2, min(min(contour_rect[2], contour_rect[3]) // 20, 5))
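                # Kernel side = 1/20 of the bounding box's shorter side, clamped
                # to [2, 5] px, so small particles are not eroded away entirely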
                inner_kernel = np.ones((inner_kernel_size, inner_kernel_size), np.uint8)
                inner_mask = cv2.erode(full_mask, inner_kernel, iterations=1)
                # Pixel values of the contour's interior region
                inner_pixels = gray_img[inner_mask > 0]
                if len(inner_pixels) > 0:
                    mean_brightness = np.mean(inner_pixels)
                    std_brightness = np.std(inner_pixels)
                    all_brightnesses.append(mean_brightness)
                    all_std_devs.append(std_brightness)
                    valid_indices.append(idx)
            except Exception as e:
                print(f"Error processing sample at index {idx} for brightness/std: {str(e)}")
                continue
        # Set thresholds from the brightness and std-deviation distributions of the class 1/2 samples
        # (LDPE: brighter and uneven; HDPE: darker and uniform)
        if len(all_brightnesses) > 0 and len(all_std_devs) > 0:
            # Use the medians as thresholds
            brightness_threshold = np.median(all_brightnesses)
            std_threshold = np.median(all_std_devs)
        else:
            brightness_threshold = 128  # default: midpoint of the 0-255 range
            std_threshold = 20  # default std-deviation threshold
        print(f"Brightness threshold: {brightness_threshold:.3f}")
        print(f"Standard deviation threshold: {std_threshold:.3f}\n")
        # Check each class 1 or 2 sample and decide whether it is HDPE or LDPE
        indices_to_update_to_ldpe = []  # samples to change to LDPE (prediction 2)
        indices_to_update_to_hdpe = []  # samples to change to HDPE (prediction 1)
        for idx in valid_indices:
            try:
                contour = df_pre.loc[idx, 'contour']
                contour_array = np.array(contour, dtype=np.int32)
                current_prediction = df_pre.loc[idx, 'Predictions']
                # Filled contour mask
                full_mask = np.zeros(gray_img.shape, dtype=np.uint8)
                cv2.fillPoly(full_mask, [contour_array], 255)
                # Shrink the contour mask inwards so edge pixels do not include background
                # (erode once with a small kernel to get the shrunken mask)
                contour_rect = cv2.boundingRect(contour_array)
                inner_kernel_size = max(2, min(min(contour_rect[2], contour_rect[3]) // 20, 5))
                inner_kernel = np.ones((inner_kernel_size, inner_kernel_size), np.uint8)
                inner_mask = cv2.erode(full_mask, inner_kernel, iterations=1)
                # Pixel values of the contour's interior region
                inner_pixels = gray_img[inner_mask > 0]
                if len(inner_pixels) > 0:
                    mean_brightness = np.mean(inner_pixels)
                    std_brightness = np.std(inner_pixels)
                    # Decision logic:
                    # class 2 (LDPE): brighter and uneven (high brightness + high std)
                    # class 1 (HDPE): darker and uniform (low brightness + low std)
                    is_bright = mean_brightness > brightness_threshold
                    is_uneven = std_brightness > std_threshold
                    if is_bright and is_uneven:
                        # More likely LDPE (prediction 2)
                        if current_prediction == 1:  # currently HDPE -> change to LDPE
                            indices_to_update_to_ldpe.append(idx)
                            print(
                                f"Sample {idx}: brightness={mean_brightness:.3f} (>{brightness_threshold:.3f}), "
                                f"std={std_brightness:.3f} (>{std_threshold:.3f}) -> changed from HDPE to LDPE")
                    elif not is_bright and not is_uneven:
                        # More likely HDPE (prediction 1)
                        if current_prediction == 2:  # currently LDPE -> change to HDPE
                            indices_to_update_to_hdpe.append(idx)
                            print(
                                f"Sample {idx}: brightness={mean_brightness:.3f} (<={brightness_threshold:.3f}), "
                                f"std={std_brightness:.3f} (<={std_threshold:.3f}) -> changed from LDPE to HDPE")
            except Exception as e:
                print(f"Error processing sample at index {idx}: {str(e)}")
                continue
        # Apply the updates to the classification results
        if indices_to_update_to_ldpe:
            df_pre.loc[indices_to_update_to_ldpe, 'Predictions'] = 2  # change to LDPE (prediction 2)
            print(f"Updated {len(indices_to_update_to_ldpe)} samples from HDPE to LDPE\n")
        if indices_to_update_to_hdpe:
            df_pre.loc[indices_to_update_to_hdpe, 'Predictions'] = 1  # change to HDPE (prediction 1)
            print(f"Updated {len(indices_to_update_to_hdpe)} samples from LDPE to HDPE\n")
        if not indices_to_update_to_ldpe and not indices_to_update_to_hdpe:
            print("No samples needed to be updated between HDPE and LDPE\n")

    # Save the ENVI classification results
    print("Saving ENVI classification results...\n")
    save_envi_classification(bil_path, df_pre, output_path)
    print(f"ENVI classification results saved to: {output_path}")


if __name__ == "__main__":
    main()
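
# Usage: edit the configuration constants at the top of this file, then run
#   python 多模型.py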