micro_plastic/spectral_shape_class.py
import os
import cv2
import matplotlib
import numpy as np
from bil2rgb import process_bil_files
from classification_model.Parallel.predict_plastic import load_model, predict_with_model
from mask import detect_microplastic_mask_from_array
from shape_spectral import process_images
from shape_spectral_background import process_images_background
from extact_shape import shape_correct_background, extract_features
import time
matplotlib.use('TkAgg')
# ----------------------------
# Configuration: edit the values below directly
# ----------------------------
BIL_PATH = r"D:/Data/MPData5.bil"
OUTPUT_PATH = r'D:\Data\MPData\pridict\5.bil'
PRIMARY_MODEL_PATH = r"D:\plastic\plastic\modelsave\svm.m"
PRIMARY_MODEL_TYPE = 'SVM'
PRIMARY_PROCESS_METHODS1 = 'SS'
PRIMARY_PROCESS_METHODS2 = 'None'
SECONDARY_MODEL_PATH = r"D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m"  # keep as None if secondary classification is not needed
SECONDARY_MODEL_TYPE = 'SVM'
SECONDARY_PROCESS_METHODS1 = 'None'
SECONDARY_PROCESS_METHODS2 = 'None'
SECONDARY_TARGET_CLASSES = [1, 2]
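# Note (inferred from the ENVI class list written in save_envi_classification below,
# where output value = prediction + 1): primary-model predictions 1 and 2 appear to
# correspond to HDPE and LDPE, which is why SECONDARY_TARGET_CLASSES = [1, 2] routes
# those samples to the HDPE/LDPE secondary model.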
def read_hdr_file(bil_path):
    hdr_path = bil_path.replace('.bil', '.hdr')
    with open(hdr_path, 'r') as f:
        header = f.readlines()
    samples, lines = None, None
    for line in header:
        if line.startswith('samples'):
            samples = int(line.split('=')[-1].strip())
        if line.startswith('lines'):
            lines = int(line.split('=')[-1].strip())
    return samples, lines
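# For reference, the ENVI-style header lines read_hdr_file looks for (illustrative
# values only, not taken from a real file):
#   samples = 640
#   lines   = 1200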
def shrink_contours(bil_path, df, shrink_pixels=1):
    """
    Shrink every contour in the DataFrame to keep neighbouring plastics from touching.
    Args:
        bil_path: path to the BIL file, used to obtain the image dimensions
        df: DataFrame containing a 'contour' column
        shrink_pixels: number of pixels to shrink by (default 1)
    Returns:
        Updated DataFrame with the 'contour' column shrunk
    """
    samples, lines = read_hdr_file(bil_path)
    # Erosion kernel
    kernel = np.ones((2 * shrink_pixels + 1, 2 * shrink_pixels + 1), np.uint8)
    # Temporary mask used for processing
    temp_mask = np.zeros((lines, samples), dtype=np.uint8)
    # Work on a copy of the DataFrame
    df = df.copy()
    # Update the contour of each row
    for idx, row in df.iterrows():
        contour = row['contour']
        if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
            continue
        try:
            contour_array = np.array(contour, dtype=np.int32)
            if len(contour_array.shape) == 1:
                continue
            # Clear the temporary mask
            temp_mask.fill(0)
            # Fill the contour
            cv2.fillPoly(temp_mask, [contour_array], 255)
            # Erode the mask
            eroded_mask = cv2.erode(temp_mask, kernel, iterations=1)
            # Re-extract the contour
            contours, _ = cv2.findContours(eroded_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if len(contours) > 0:
                # Keep the largest contour if several remain
                largest_contour = max(contours, key=cv2.contourArea)
                # Convert back to a list to match the original format
                if len(largest_contour) >= 3:
                    updated_contour = largest_contour.reshape(-1, 2).tolist()
                    df.at[idx, 'contour'] = updated_contour
        except Exception:
            # If processing fails, keep the original contour
            continue
    return df
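# Example call (hypothetical DataFrame df_pre), shrinking by 2 px instead of the default 1:
#   df_pre = shrink_contours(BIL_PATH, df_pre, shrink_pixels=2)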
def save_envi_classification(bil_path, df, savepath):
    samples, lines = read_hdr_file(bil_path)
    classification_result = np.zeros((lines, samples), dtype=np.uint16)
    for _, row in df.iterrows():
        contour = row['contour']
        prediction = int(row['Predictions']) + 1
        contour = np.array(contour, dtype=np.int32)
        # Optionally reset values of 10 in classification_result to 0 first
        # classification_result[(classification_result == 10)] = 0
        cv2.fillPoly(classification_result, [contour], prediction)
    output_path = savepath
    with open(output_path, 'wb') as f:
        classification_result.tofile(f)
    header_content = f"""ENVI
description = {{
Classification Result.}}
samples = {samples}
lines = {lines}
bands = 1
header offset = 0
file type = ENVI Standard
data type = 2
interleave = bil
classes = 11
class = {{ background, ABS, HDPE, LDPE, PA6, PET, PP, PS, PTFE, PVC, background2 }}
single pixel area = 0.000036
unit = mm2
byte order = 0
wavelength units = nm
"""
    filename, ext = os.path.splitext(savepath)
    # Replace the extension with '.hdr'
    header_filename = filename + '.hdr'
    with open(header_filename, 'w') as header_file:
        header_file.write(header_content)
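# Sketch of reading the saved map back for inspection (assumes the shapes written above):
#   samples, lines = read_hdr_file(BIL_PATH)
#   cls_map = np.fromfile(OUTPUT_PATH, dtype=np.uint16).reshape(lines, samples)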
def change_hdr_file(bil_path):
    # Wavelength information (nm) to append to the header
    wavelength_info = """wavelength = {898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2}"""
    # Convert the .bil path to the .hdr path
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    # Check that the .hdr file exists
    if not os.path.exists(hdr_path):
        print(f"Error: matching HDR file not found: {hdr_path}")
        return
    # Read the file contents
    with open(hdr_path, 'r') as file:
        content = file.read()
    # Check whether wavelength information is already present
    if 'wavelength' in content:
        print(f"File {os.path.basename(hdr_path)} already contains wavelength information; no changes needed.")
        return
    # Check whether the file ends with a newline
    needs_newline = not content.endswith('\n')
    # Append the wavelength information
    with open(hdr_path, 'a') as file:
        if needs_newline:
            file.write('\n')  # make sure the new content starts on a new line
        file.write(wavelength_info + '\n')
    print(f"Successfully added wavelength information to file: {os.path.basename(hdr_path)}")
def main():
    bil_path = BIL_PATH
    output_path = OUTPUT_PATH
    model_path = PRIMARY_MODEL_PATH
    # Record the overall start time
    total_start_time = time.time()
    # Process the BIL file to generate an RGB image
    print("Processing BIL file to generate RGB image...\n")
    rgb_img = process_bil_files(bil_path)
    # Update the hdr file
    change_hdr_file(bil_path)
    segmentation_start_time = time.time()
    # Generate the mask; 'mask' is a 16-bit labelled plastic mask
    print("Generating mask ...\n")
    mask, filter_mask_original = detect_microplastic_mask_from_array(
        image=rgb_img,  # pass the cv2.imread-style array directly
        filter_method='threshold',
        diameter=None,
        flow_threshold=0.4,
        cellprob_threshold=-1
    )
    # Extract features
    print("Extracting features from BIL file...\n")
    df = process_images(bil_path, mask)
    # Background correction
    print("Applying background correction...\n")
    df_correct = process_images_background(bil_path, mask)
    df.iloc[:, 1:169] = df.iloc[:, 1:169].div(df_correct, axis=1)
    # Data cleaning
    print("Cleaning data...\n")
    df = df.dropna()
    df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
    df = df[df['area'] >= 500]
    # Select the columns to drop: spectral columns at indices 87-109 plus the trailing columns from -10 to -2
    cols_to_remove = df.columns[np.r_[87:110, -10:-1]]
    # cols_to_remove = df.columns[87:110]
    # Drop the selected columns, keeping the DataFrame structure
    df = df.drop(columns=cols_to_remove)
    segmentation_time = time.time() - segmentation_start_time
    # Keep the DataFrame structure (do not convert to a numpy array; .values would drop column names)
    df = df.iloc[:, :]
    # Classification stage
    classification_start_time = time.time()
    # Predict classes
    print("Predicting classes...\n")
    loaded_model = load_model(model_path)
    df_pre = predict_with_model(
        df,
        model_path,
        model_type=PRIMARY_MODEL_TYPE,
        ProcessMethods1=PRIMARY_PROCESS_METHODS1,
        ProcessMethods2=PRIMARY_PROCESS_METHODS2
    )
    # Secondary classification for HDPE and LDPE
    # Background correction of the shape information
    df_correct = shape_correct_background(bil_path, mask, filter_mask_original)
    # Select contours whose first-stage prediction is in SECONDARY_TARGET_CLASSES
    target_classes = set(SECONDARY_TARGET_CLASSES or [])
    mask_secondary = df_pre['Predictions'].isin(target_classes)
    if mask_secondary.any():
        # Create a new mask (mask_second) containing only the target-class contours
        mask_second = np.zeros_like(mask, dtype=np.uint16)
        for idx in df_pre[mask_secondary].index:
            contour = df_pre.loc[idx, 'contour']
            if isinstance(contour, list) and len(contour) > 0:
                contour_array = np.array(contour, dtype=np.int32)
                cv2.fillPoly(mask_second, [contour_array], idx + 1)  # use index + 1 as the label
        # Extract features
        df_shape = extract_features(df_correct, mask_second)
        # Use the first 13 columns as the model input features
        if len(df_shape.columns) >= 13:
            df_shape = df_shape.iloc[:, :13]
    else:
        print("No samples from target classes found; skipping secondary classification.\n")
        return
    # Secondary classification: predict with the second model and update the results
    if SECONDARY_MODEL_PATH:
        print(f"Running secondary classification for classes: {sorted(target_classes)}")
        df_secondary = predict_with_model(
            df_shape,
            SECONDARY_MODEL_PATH,
            model_type=SECONDARY_MODEL_TYPE,
            ProcessMethods1=SECONDARY_PROCESS_METHODS1,
            ProcessMethods2=SECONDARY_PROCESS_METHODS2
        )
        df_pre.loc[mask_secondary, 'Predictions'] = df_secondary['Predictions'].values + 1
    else:
        print("Secondary model path not provided; skipping secondary classification.\n")
    # Identify background shadows misclassified as class 7, using boundary sharpness:
    # genuine class-7 samples have sharp boundaries, background shadows have blurred ones
    class_7_mask = df_pre['Predictions'] == 7
    if class_7_mask.any():
        print(f"Processing {class_7_mask.sum()} samples with class 7 to identify background shadows...\n")
        # Convert a PIL Image to a numpy array if necessary
        if hasattr(rgb_img, 'mode'):  # check whether this is a PIL Image
            rgb_img_array = np.array(rgb_img)
        else:
            rgb_img_array = rgb_img
        # Convert to grayscale for the gradient computation
        if len(rgb_img_array.shape) == 3:
            gray_img = cv2.cvtColor(rgb_img_array, cv2.COLOR_RGB2GRAY)
        else:
            gray_img = rgb_img_array
        # Compute the gradient magnitude with the Sobel operator
        grad_x = cv2.Sobel(gray_img, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(gray_img, cv2.CV_64F, 0, 1, ksize=3)
        gradient_magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)
        # First collect the edge gradients of all class-7 samples to determine the threshold
        all_class7_gradients = []
        valid_indices = []
        for idx in df_pre[class_7_mask].index:
            try:
                contour = df_pre.loc[idx, 'contour']
                if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
                    continue
                contour_array = np.array(contour, dtype=np.int32)
                if len(contour_array.shape) == 1:
                    continue
                mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
                cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)
                edge_gradients = gradient_magnitude[mask_img > 0]
                if len(edge_gradients) > 0:
                    all_class7_gradients.extend(edge_gradients)
                    valid_indices.append(idx)
            except Exception:
                continue
        # Determine the threshold from the gradient distribution of class-7 samples:
        # edges below a low percentile (30%) of that distribution are treated as background shadows
        if len(all_class7_gradients) > 0:
            gradient_threshold = np.percentile(all_class7_gradients, 30)  # 30th percentile of class-7 edge gradients
        else:
            gradient_threshold = np.percentile(gradient_magnitude, 30)  # fall back to the whole image if no valid samples
        print(f"Gradient threshold for class 7: {gradient_threshold:.2f}\n")
        # Check each class-7 sample and decide whether it is a background shadow
        indices_to_update = []
        for idx in valid_indices:
            try:
                contour = df_pre.loc[idx, 'contour']
                contour_array = np.array(contour, dtype=np.int32)
                # Contour mask drawn 2 px wide to sample the edge
                mask_img = np.zeros(gray_img.shape, dtype=np.uint8)
                cv2.drawContours(mask_img, [contour_array], -1, 255, thickness=2)
                # Gradient values along the contour edge
                edge_gradients = gradient_magnitude[mask_img > 0]
                if len(edge_gradients) == 0:
                    continue
                # Mean gradient strength along the edge
                mean_gradient = np.mean(edge_gradients)
                # A mean gradient below the threshold means a blurred boundary, i.e. a background shadow
                if mean_gradient < gradient_threshold:
                    indices_to_update.append(idx)
                    print(
                        f"Sample {idx}: mean_gradient={mean_gradient:.2f}, threshold={gradient_threshold:.2f} -> identified as background shadow")
            except Exception as e:
                print(f"Error processing sample at index {idx}: {str(e)}")
                continue
        # Re-label the background-shadow samples from class 7 to class 9
        if indices_to_update:
            df_pre.loc[indices_to_update, 'Predictions'] = 9
            print(f"Updated {len(indices_to_update)} samples from class 7 to class 9 (background shadows)\n")
        else:
            print("No samples needed to be updated from class 7\n")
    classification_time = time.time() - classification_start_time
    df_pre = shrink_contours(bil_path, df_pre, shrink_pixels=1)
    # Save the ENVI classification result
    print("Saving ENVI classification results...\n")
    save_envi_classification(bil_path, df_pre, output_path)
    print(f"ENVI classification results saved to: {output_path}")
    # Compute the total elapsed time
    total_time = time.time() - total_start_time
    # Print timing statistics
    print(f"\n{'=' * 60}")
    print("Processing complete")
    print(f"{'=' * 60}")
    print(f"Segmentation time: {segmentation_time:.2f} s")
    print(f"Classification time: {classification_time:.2f} s")
    print(f"Total time: {total_time:.2f} s")
    print(f"{'=' * 60}")
    print(f"Results saved to: {output_path}")


if __name__ == "__main__":
    main()
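# Usage note: edit the path constants at the top of this file (BIL_PATH, OUTPUT_PATH and
# the model paths), then run the script directly, e.g. `python spectral_shape_class.py`.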