Files
micro_plastic/main.py
2026-02-25 09:42:51 +08:00

514 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import matplotlib
import numpy as np
import argparse
import pandas as pd
from bil2rgb import process_bil_files
from classification_model.Parallel.predict_plastic import load_model, predict_with_model
from mask import detect_microplastic_mask_from_array
from shape_spectral import process_images
from shape_spectral_background import process_images_background
from extact_shape import shape_correct_background, extract_features
import time
matplotlib.use('TkAgg')
# Wavelengths of the training camera (168 channels)
TRAIN_WAVELENGTHS = [
898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2
]
def read_wavelengths_from_hdr(bil_path):
    """Read the wavelength list from the ENVI header that sits next to a BIL file.

    The header path is ``bil_path`` with its extension replaced by ``.hdr``.
    Only a genuine ``wavelength = { ... }`` field is parsed.  Look-alike keys
    such as ``wavelength units`` are ignored, so the braces of an unrelated
    field (``fwhm``, ``band names`` ...) are never mistaken for the list.

    Args:
        bil_path: Path to the ``.bil`` data file.

    Returns:
        A 1-D ``float64`` array with one entry per listed wavelength, or an
        empty ``float64`` array when the header file does not exist, has no
        wavelength field, or the field holds a value that is not a number.
    """
    import re  # local import so this block does not depend on file-level changes

    no_waves = np.array([], dtype=np.float64)

    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        return no_waves
    with open(hdr_path, 'r') as f:
        txt = f.read()

    # Anchor on "wavelength" followed directly by "=" so "wavelength units"
    # cannot match; [^}]* lets the list span several lines.
    match = re.search(
        r'^[ \t]*wavelength[ \t]*=[ \t]*\{([^}]*)\}',
        txt,
        re.IGNORECASE | re.MULTILINE,
    )
    if match is None:
        return no_waves

    vals = [v.strip() for v in match.group(1).split(',') if v.strip()]
    try:
        return np.array([float(v) for v in vals], dtype=np.float64)
    except ValueError:
        # A non-numeric entry makes the whole list untrustworthy.
        return no_waves
def resample_spectra_matrix(X, src_waves, dst_waves):
    """Linearly interpolate each row of ``X`` from one wavelength grid onto another.

    Args:
        X: 2-D array-like; each row is one spectrum sampled at ``src_waves``.
        src_waves: Wavelengths the columns of ``X`` correspond to.
        dst_waves: Wavelengths to resample onto.

    Returns:
        A ``float64`` array with one row per input row and one column per
        destination wavelength.  If either wavelength grid is empty, ``X`` is
        returned as a ``float64`` array without resampling.
    """
    source_grid = np.asarray(src_waves, dtype=np.float64)
    target_grid = np.asarray(dst_waves, dtype=np.float64)
    spectra = np.asarray(X, dtype=np.float64)

    # Nothing to interpolate between when a grid is missing.
    if 0 in (source_grid.size, target_grid.size):
        return spectra

    resampled = np.empty((spectra.shape[0], target_grid.size), dtype=np.float64)
    for row_idx, spectrum in enumerate(spectra):
        # Targets outside the source range take the nearest end value
        # (clamped, not extrapolated), so no output column is left undefined.
        resampled[row_idx] = np.interp(
            target_grid,
            source_grid,
            spectrum,
            left=spectrum[0],
            right=spectrum[-1],
        )
    return resampled
def parse_arguments():
    """Parse the command-line arguments of the classification script.

    Returns:
        An ``argparse.Namespace`` with the required attributes ``bil_path``,
        ``output_path`` and ``model_path``.
    """
    required_options = (
        ('--bil_path', 'Path to input BIL file'),
        ('--output_path', 'Path to output classification result'),
        ('--model_path', 'Path to primary classification model'),
    )

    cli = argparse.ArgumentParser(description='Microplastic spectral shape classification')
    for flag, help_text in required_options:
        cli.add_argument(flag, required=True, help=help_text)

    # Optional arguments are currently disabled: the primary model type and
    # process methods, and the secondary model path, type, process methods and
    # target classes are not exposed here; main() sets them directly.
    return cli.parse_args()
# ----------------------------
# Configuration parameters: edit them directly here
# ----------------------------
# BIL_PATH = r"D:/Data/Test/PET_bottle2.bil"
# OUTPUT_PATH = r'D:/Data/PET_bottle2_class.bil'
#
# PRIMARY_MODEL_PATH = r"D:\plastic\plastic\modelsave\svm.m"
# PRIMARY_MODEL_TYPE = 'SVM'
# PRIMARY_PROCESS_METHODS1 = 'SS'
# PRIMARY_PROCESS_METHODS2 = 'None'
#
# SECONDARY_MODEL_PATH = "D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m" # keep as None if secondary classification is not needed
# SECONDARY_MODEL_TYPE = 'SVM'
# SECONDARY_PROCESS_METHODS1 = 'None'
# SECONDARY_PROCESS_METHODS2 = 'None'
# SECONDARY_TARGET_CLASSES = [1, 2]
def read_hdr_file(bil_path):
    """Read the image dimensions from the ENVI header that sits next to a BIL file.

    The header path is ``bil_path`` with its extension replaced by ``.hdr``
    (the same rule the other header helpers in this file use), so directory
    names containing ``.bil`` and upper-case extensions are handled.

    Args:
        bil_path: Path to the ``.bil`` data file.

    Returns:
        A ``(samples, lines)`` tuple of ints.  An entry is ``None`` when the
        header has no such field.

    Raises:
        OSError: If the header file cannot be opened.
        ValueError: If a ``samples``/``lines`` value is not an integer.
    """
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    dims = {'samples': None, 'lines': None}
    with open(hdr_path, 'r') as f:
        for line in f:
            # Split on the first "=" and compare the exact key, so padding
            # around the key or a missing space before "=" does not matter.
            key, sep, value = line.partition('=')
            key = key.strip().lower()
            if sep and key in dims:
                dims[key] = int(value.strip())
    return dims['samples'], dims['lines']
def shrink_contours(bil_path, df, shrink_pixels=1):
    """Shrink every contour in ``df`` so that neighbouring plastics do not touch.

    Each contour is rasterised, eroded with a square kernel and traced again;
    the largest resulting outline replaces the original one.

    Args:
        bil_path: Path to the BIL file; its header supplies the image size.
        df: DataFrame with a ``contour`` column.
        shrink_pixels: Number of pixels to shrink by (default 1).

    Returns:
        A copy of ``df`` whose ``contour`` column holds the shrunken contours.
        A row keeps its original contour when the contour is not a list/array
        of at least 3 points, when nothing usable remains after erosion, or
        when processing raises an exception.
    """
    samples, lines = read_hdr_file(bil_path)

    # Square erosion kernel of side 2 * shrink_pixels + 1.
    side = 2 * shrink_pixels + 1
    erosion_kernel = np.ones((side, side), np.uint8)

    # One scratch raster, wiped and reused for every contour.
    scratch = np.zeros((lines, samples), dtype=np.uint8)

    shrunk = df.copy()
    for row_label, record in shrunk.iterrows():
        outline = record['contour']
        if not (isinstance(outline, (list, np.ndarray)) and len(outline) >= 3):
            continue
        try:
            points = np.array(outline, dtype=np.int32)
            if points.ndim == 1:
                continue

            scratch.fill(0)
            cv2.fillPoly(scratch, [points], 255)
            eroded = cv2.erode(scratch, erosion_kernel, iterations=1)
            found, _ = cv2.findContours(eroded, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if len(found) == 0:
                continue

            # Erosion may split the shape; keep only the largest piece.
            biggest = max(found, key=cv2.contourArea)
            if len(biggest) >= 3:
                # Store as a list of [x, y] pairs, like the input format.
                shrunk.at[row_label, 'contour'] = biggest.reshape(-1, 2).tolist()
        except Exception:
            # Best effort: on any failure the original contour is kept.
            continue
    return shrunk
def save_envi_classification(bil_path, df, savepath):
    """Rasterise the classified contours and save them as an ENVI classification file.

    Every row of ``df`` is painted into a ``uint16`` raster with the value
    ``Predictions + 1`` (0 stays background).  The raster is written to
    ``savepath`` and a matching header to ``savepath`` with its extension
    replaced by ``.hdr``.

    Args:
        bil_path: Path to the source BIL file; its header supplies the raster
            size (``samples`` x ``lines``).
        df: DataFrame with a ``contour`` column (polygon points) and a
            ``Predictions`` column (class index).
        savepath: Output path of the classification raster.
    """
    samples, lines = read_hdr_file(bil_path)
    classification_result = np.zeros((lines, samples), dtype=np.uint16)

    for _, row in df.iterrows():
        contour = np.array(row['contour'], dtype=np.int32)
        prediction = int(row['Predictions']) + 1
        cv2.fillPoly(classification_result, [contour], prediction)

    # Label 10 has no entry in the 10-class header below, so it is reset to
    # background.  Doing it once after all contours are painted also covers
    # the last row, which clearing at the start of each iteration missed.
    classification_result[classification_result == 10] = 0

    with open(savepath, 'wb') as f:
        classification_result.tofile(f)

    # NOTE(review): the raster is uint16 while "data type = 2" denotes signed
    # 16-bit in ENVI; the stored bytes are the same for these small labels.
    header_content = f"""ENVI
description = {{
Classification Result.}}
samples = {samples}
lines = {lines}
bands = 1
header offset = 0
file type = ENVI Standard
data type = 2
interleave = bil
classes = 10
class = {{ background, ABS, HDPE, LDPE, PA6, PET, PP, PS, PTFE, PVC }}
single pixel area = 0.000036
unit = mm2
byte order = 0
wavelength units = nm
"""
    # The header sits next to the raster, with the extension swapped for '.hdr'.
    header_filename = os.path.splitext(savepath)[0] + '.hdr'
    with open(header_filename, 'w') as header_file:
        header_file.write(header_content)
def change_hdr_file(bil_path, wavelengths=None):
    """Make sure the ENVI header next to ``bil_path`` carries the given wavelengths.

    With ``wavelengths=None`` the header is never modified; a message reports
    whether a wavelength field is already present.  When ``wavelengths`` is
    given, an existing ``wavelength = { ... }`` field is replaced in place
    (rather than a second, conflicting field being appended); otherwise the
    field is appended at the end of the header.

    Args:
        bil_path: Path to the ``.bil`` data file; the header is the same path
            with the extension replaced by ``.hdr``.
        wavelengths: Optional iterable of wavelength values to write.

    Returns:
        None.  A missing header file is reported with a message, not raised.
    """
    import re  # local import so this block does not depend on file-level changes

    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        print(f"错误: 找不到对应的HDR文件: {hdr_path}")
        return
    with open(hdr_path, 'r') as file:
        content = file.read()

    # Match only the real wavelength list; "wavelength units = ..." must not
    # count as wavelength information.
    field = re.compile(
        r'^[ \t]*wavelength[ \t]*=[ \t]*\{[^}]*\}[ \t]*\n?',
        re.IGNORECASE | re.MULTILINE,
    )
    existing = field.search(content)

    if wavelengths is None:
        if existing is not None:
            print(f"File {os.path.basename(hdr_path)} already contains wavelength information; no changes needed.")
        else:
            print(f"No wavelengths provided and HDR lacks wavelength; skipping write to avoid wrong bands.")
        return

    wavelength_info = "wavelength = {" + ", ".join(str(float(w)) for w in wavelengths) + "}\n"
    if existing is not None:
        # Overwrite the stale list where it stands so the header keeps a
        # single wavelength field.
        new_content = content[:existing.start()] + wavelength_info + content[existing.end():]
    else:
        separator = '' if content.endswith('\n') else '\n'
        new_content = content + separator + wavelength_info
    with open(hdr_path, 'w') as file:
        file.write(new_content)
    print(f"Successfully ensured wavelength information in file: {os.path.basename(hdr_path)}")
def _reclassify_class7_shadows(df_pre, rgb_img):
    """Relabel class-7 samples with a blurry outline as class 9 (background shadow).

    For every class-7 sample the Sobel gradient magnitude is sampled along a
    2-pixel-wide band on its contour.  The threshold is the 30th percentile of
    all those edge gradients; a sample whose mean edge gradient is below it is
    treated as a background shadow.  ``df_pre`` is modified in place.

    Args:
        df_pre: DataFrame with ``Predictions`` and ``contour`` columns.
        rgb_img: RGB image as a PIL image or a numpy array.
    """
    class_7_mask = df_pre['Predictions'] == 7
    if not class_7_mask.any():
        return
    print(f"Processing {class_7_mask.sum()} samples with class 7 to identify background shadows...\n")

    # A PIL image exposes ``mode``; convert it to a numpy array first.
    if hasattr(rgb_img, 'mode'):
        rgb_img_array = np.array(rgb_img)
    else:
        rgb_img_array = rgb_img
    # Gradients are computed on a grayscale image.
    if len(rgb_img_array.shape) == 3:
        gray_img = cv2.cvtColor(rgb_img_array, cv2.COLOR_RGB2GRAY)
    else:
        gray_img = rgb_img_array

    grad_x = cv2.Sobel(gray_img, cv2.CV_64F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(gray_img, cv2.CV_64F, 0, 1, ksize=3)
    gradient_magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)

    # Edge gradients per sample, computed once and reused for both the
    # threshold and the per-sample decision.
    edge_gradients_by_index = {}
    for idx in df_pre[class_7_mask].index:
        try:
            contour = df_pre.loc[idx, 'contour']
            if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
                continue
            contour_array = np.array(contour, dtype=np.int32)
            if len(contour_array.shape) == 1:
                continue
            # Contour drawn 2 pixels thick to select the edge pixels.
            mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
            cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)
            edge_gradients = gradient_magnitude[mask_img > 0]
            if len(edge_gradients) > 0:
                edge_gradients_by_index[idx] = edge_gradients
        except Exception as e:
            # Skip the sample but report why, instead of hiding the failure.
            print(f"Error processing sample at index {idx}: {str(e)}")
            continue

    if edge_gradients_by_index:
        # 30th percentile of the edge gradients of all class-7 samples.
        all_class7_gradients = np.concatenate(list(edge_gradients_by_index.values()))
        gradient_threshold = np.percentile(all_class7_gradients, 30)
    else:
        # No usable sample: fall back to the 30th percentile of the whole image.
        gradient_threshold = np.percentile(gradient_magnitude, 30)
    print(f"Gradient threshold for class 7: {gradient_threshold:.2f}\n")

    indices_to_update = []
    for idx, edge_gradients in edge_gradients_by_index.items():
        mean_gradient = np.mean(edge_gradients)
        # A weak mean edge gradient means a blurry boundary: background shadow.
        if mean_gradient < gradient_threshold:
            indices_to_update.append(idx)
            print(
                f"Sample {idx}: mean_gradient={mean_gradient:.2f}, threshold={gradient_threshold:.2f} -> identified as background shadow")

    if indices_to_update:
        df_pre.loc[indices_to_update, 'Predictions'] = 9
        print(f"Updated {len(indices_to_update)} samples from class 7 to class 9 (background shadows)\n")
    else:
        print("No samples needed to be updated from class 7\n")


def main():
    """Run the microplastic classification pipeline on one BIL file.

    Steps: build an RGB image from the BIL file, segment it into a label mask,
    extract per-object features, divide the spectral columns by the background
    correction, resample the spectra to ``TRAIN_WAVELENGTHS`` when the header
    wavelengths differ, clean and select columns, predict with the primary
    model, re-predict the target classes with the secondary model, relabel
    blurry class-7 objects as class 9, shrink the contours, write the ENVI
    classification result and print timing statistics.
    """
    args = parse_arguments()
    bil_path = args.bil_path
    output_path = args.output_path
    primary_model_path = args.model_path
    primary_model_type = 'SVM'
    primary_process_methods1 = 'SS'
    primary_process_methods2 = "None"
    # NOTE(review): the secondary model path is machine-specific and set here.
    # Raw string: same value as before, without invalid "\c", "\p", ... escapes.
    secondary_model_path = r"E:\code\plastic\plastic20260224\plastic\plastic\modelsave\HDPELDPE_model\svm.m"
    secondary_model_type = 'SVM'
    secondary_process_methods1 = 'None'
    secondary_process_methods2 = 'None'
    secondary_target_classes = [1, 2]

    # Start of the overall timing.
    total_start_time = time.time()

    # Build the RGB image from the BIL file.
    print("Processing BIL file to generate RGB image...\n")
    rgb_img = process_bil_files(bil_path)

    # No wavelengths are passed, so the header is only inspected, never changed.
    change_hdr_file(bil_path)

    segmentation_start_time = time.time()
    # Generate the label mask of the plastics.
    print("Generating mask ...\n")
    mask, filter_mask_original = detect_microplastic_mask_from_array(
        image=rgb_img,
        filter_method='threshold',
        diameter=None,
        flow_threshold=0.4,
        cellprob_threshold=-1,
        detect_filter=True
    )

    # Extract features.
    print("Extracting features from BIL file...\n")
    df = process_images(bil_path, mask)

    # Background correction.
    print("Applying background correction...\n")
    df_correct = process_images_background(bil_path, mask)

    # Spectral columns are taken to start at the second column and to be
    # contiguous.
    # NOTE(review): len(df_correct) is used as the number of spectral columns;
    # confirm it has one entry per band.
    spec_start = 1
    spec_len_src = len(df_correct)
    spec_end_src = spec_start + spec_len_src
    # Normalise: divide each spectral column by the background correction.
    df.iloc[:, spec_start:spec_end_src] = df.iloc[:, spec_start:spec_end_src].div(df_correct, axis=1)

    # Resample to the training wavelengths when the header lists wavelengths
    # that differ from them in count or value.
    src_waves = read_wavelengths_from_hdr(bil_path)
    need_resample = (src_waves.size > 0) and (
        src_waves.size != len(TRAIN_WAVELENGTHS)
        or not np.allclose(src_waves, TRAIN_WAVELENGTHS, atol=1e-2)
    )
    if need_resample:
        print(f"Resampling spectra: src {src_waves.size} bands -> dst {len(TRAIN_WAVELENGTHS)} bands\n")
        X_src = df.iloc[:, spec_start:spec_end_src].to_numpy(dtype=np.float64)
        X_dst = resample_spectra_matrix(X_src, src_waves, TRAIN_WAVELENGTHS)
        # Replace the spectral columns with the training-grid ones; the
        # remaining shape/statistical features are kept unchanged.
        spec_col_names = [f"band_{i+1}" for i in range(len(TRAIN_WAVELENGTHS))]
        df = pd.concat(
            [
                df.iloc[:, :spec_start],
                pd.DataFrame(X_dst, columns=spec_col_names, index=df.index),
                df.iloc[:, spec_end_src:]
            ],
            axis=1
        )
        # New end of the spectral column range.
        spec_end_src = spec_start + len(TRAIN_WAVELENGTHS)

    # Data cleaning.
    print("Cleaning data...\n")
    df = df.dropna()
    df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
    df = df[df['area'] >= 500]

    # Positional column selection: drops columns 1-4, 87-109, 166-168 and the
    # 10th-from-last through 2nd-from-last.  The fixed positions assume the
    # spectra are already on the training wavelength grid.
    cols_to_remove = df.columns[np.r_[1:5, 87:110, 166:169, -10:-1]]
    df = df.drop(columns=cols_to_remove)

    segmentation_time = time.time() - segmentation_start_time

    # Classification stage.
    classification_start_time = time.time()
    print("Predicting classes...\n")
    loaded_model = load_model(primary_model_path)

    # Sanity check: the number of numeric features should equal the scaler's
    # input dimension.  Best effort only: a mismatch or a missing scaler file
    # is reported and the pipeline continues.
    # NOTE(review): joblib.load unpickles the file; only use trusted model
    # directories.
    try:
        import joblib
        scaler = joblib.load(os.path.join(os.path.dirname(primary_model_path), 'scaler_params.pkl'))
        numeric_cols = [c for c in df.columns[1:] if np.issubdtype(df[c].dtype, np.number) and c != 'contour']
        if hasattr(scaler, 'mean_') and len(numeric_cols) != scaler.mean_.shape[0]:
            print(f"Warning: Feature dimension mismatch: current {len(numeric_cols)} != scaler {scaler.mean_.shape[0]}. Check resampling and column selection.")
    except Exception as e:
        print(f"Warning: could not verify feature dimension against scaler: {e}")

    df_pre = predict_with_model(
        df,
        primary_model_path,
        model_type=primary_model_type,
        ProcessMethods1=primary_process_methods1,
        ProcessMethods2=primary_process_methods2
    )

    # Secondary classification (HDPE vs LDPE) for the samples the primary
    # model put into one of the target classes.
    target_classes = set(secondary_target_classes or [])
    mask_secondary = df_pre['Predictions'].isin(target_classes)
    if mask_secondary.any():
        print(f"Running secondary classification for classes: {sorted(target_classes)}")
        # Background correction of the image information.
        df_correct = shape_correct_background(bil_path, mask, filter_mask_original)
        # New label mask holding only the target-class contours.
        mask_second = np.zeros_like(mask, dtype=np.uint16)
        for idx in df_pre[mask_secondary].index:
            contour = df_pre.loc[idx, 'contour']
            if isinstance(contour, list) and len(contour) > 0:
                contour_array = np.array(contour, dtype=np.int32)
                # Label = row index + 1, so that 0 stays background.
                cv2.fillPoly(mask_second, [contour_array], idx + 1)
        # Extract the shape features.
        df_shape = extract_features(df_correct, mask_second)
        # Keep at most the first 13 columns as the secondary model input.
        if len(df_shape.columns) >= 13:
            df_shape = df_shape.iloc[:, :13]
        if secondary_model_path:
            df_secondary = predict_with_model(
                df_shape,
                secondary_model_path,
                model_type=secondary_model_type,
                ProcessMethods1=secondary_process_methods1,
                ProcessMethods2=secondary_process_methods2
            )
            # NOTE(review): the +1 presumably maps the secondary model's
            # labels onto the primary class indices; this also assumes one
            # secondary prediction per selected row, in the same order.
            df_pre.loc[mask_secondary, 'Predictions'] = df_secondary['Predictions'].values + 1
        else:
            print("Secondary model path not provided; skipping secondary classification.\n")
    else:
        print("No samples from target classes found; skipping secondary classification.\n")

    # Class-7 objects with a blurry boundary are background shadows: relabel
    # them as class 9.
    _reclassify_class7_shadows(df_pre, rgb_img)

    classification_time = time.time() - classification_start_time

    df_pre = shrink_contours(bil_path, df_pre, shrink_pixels=1)

    # Save the ENVI classification result.
    print("Saving ENVI classification results...\n")
    save_envi_classification(bil_path, df_pre, output_path)
    print(f"ENVI classification results saved to: {output_path}")

    # Total elapsed time.
    total_time = time.time() - total_start_time

    # Timing summary.
    print(f"\n{'=' * 60}")
    print(f"处理完成")
    print(f"{'=' * 60}")
    print(f"分割耗时: {segmentation_time:.2f}")
    print(f"分类耗时: {classification_time:.2f}")
    print(f"总耗时: {total_time:.2f}")
    print(f"{'=' * 60}")
    print(f"结果已保存至: {output_path}")
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()