import os import cv2 import matplotlib import numpy as np import argparse from bil2rgb import process_bil_files from classification_model.Parallel.predict_plastic import load_model, predict_with_model from mask import detect_microplastic_mask_from_array from shape_spectral import process_images from shape_spectral_background import process_images_background from extact_shape import shape_correct_background, extract_features import time matplotlib.use('TkAgg') # def parse_arguments(): # """解析命令行参数""" # parser = argparse.ArgumentParser(description='Microplastic spectral shape classification') # # # 必需参数 # parser.add_argument('--bil_path', required=True, help='Path to input BIL file') # parser.add_argument('--output_path', required=True, help='Path to output classification result') # parser.add_argument('--model_path', required=True, help='Path to primary classification model') # # # 可选参数 # # parser.add_argument('--primary_model_type', default='SVM', help='Type of primary model (default: SVM)') # # parser.add_argument('--primary_process_methods1', default='SS', help='Primary process method 1 (default: SS)') # # parser.add_argument('--primary_process_methods2', default='None', help='Primary process method 2 (default: None)') # # # parser.add_argument('--secondary_model', default="D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m", help='Path to secondary classification model') # # parser.add_argument('--secondary_model_type', default='SVM', help='Type of secondary model (default: SVM)') # # parser.add_argument('--secondary_process_methods1', default='None', # # help='Secondary process method 1 (default: None)') # # parser.add_argument('--secondary_process_methods2', default='None', # # help='Secondary process method 2 (default: None)') # # parser.add_argument('--secondary_target_classes', nargs='+', type=int, default=[1,2], # # help='Target classes for secondary classification (space separated)') # # return parser.parse_args() # ---------------------------- # 
# Configuration: edit the values below directly.
# ----------------------------
BIL_PATH = r"D:/Data/Test/PET_bottle2.bil"
OUTPUT_PATH = r"D:/Data/PET_bottle2_class_test.bil"

# Raw strings: these Windows paths contain sequences such as "\p", "\s" and "\H"
# that are invalid escapes in ordinary string literals (SyntaxWarning on 3.12+).
PRIMARY_MODEL_PATH = r"D:\plastic\plastic\modelsave\svm.m"
PRIMARY_MODEL_TYPE = 'SVM'
PRIMARY_PROCESS_METHODS1 = 'SS'
PRIMARY_PROCESS_METHODS2 = 'None'

# Set SECONDARY_MODEL_PATH to None (or "") to skip the secondary (HDPE/LDPE) classification.
SECONDARY_MODEL_PATH = r"D:\plastic\plastic\modelsave\HDPELDPE_model\svm.m"
SECONDARY_MODEL_TYPE = 'SVM'
SECONDARY_PROCESS_METHODS1 = 'None'
SECONDARY_PROCESS_METHODS2 = 'None'
# Primary-model labels that are re-classified by the secondary model.
SECONDARY_TARGET_CLASSES = (1, 2)

# ENVI "wavelength" entry appended to headers that lack one (168 band centres, nm).
_WAVELENGTH_INFO = """wavelength = {898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2}"""


def _hdr_path_for(bil_path):
    """Return the ENVI header path paired with *bil_path* (same stem, ``.hdr`` extension)."""
    # splitext only touches the final extension; str.replace('.bil', ...) would
    # also rewrite '.bil' occurring in directory names.
    return os.path.splitext(bil_path)[0] + '.hdr'


def read_hdr_file(bil_path):
    """Read the image width and height from the ENVI header next to *bil_path*.

    Args:
        bil_path: Path to the ``.bil`` file. The header is expected at the same
            path with the extension replaced by ``.hdr``.

    Returns:
        tuple[int, int]: ``(samples, lines)`` - the number of columns and rows.

    Raises:
        OSError: if the header file cannot be opened.
        ValueError: if the header has no ``samples`` or no ``lines`` entry.
    """
    hdr_path = _hdr_path_for(bil_path)
    samples, lines = None, None
    with open(hdr_path, 'r') as f:
        for raw_line in f:
            key, sep, value = raw_line.partition('=')
            if not sep:
                continue
            key = key.strip().lower()
            if key == 'samples':
                samples = int(value.strip())
            elif key == 'lines':
                lines = int(value.strip())
    if samples is None or lines is None:
        raise ValueError(f"Header {hdr_path} is missing a 'samples' and/or 'lines' entry")
    return samples, lines


def shrink_contours(bil_path, df, shrink_pixels=1):
    """Erode every contour in ``df['contour']`` so neighbouring particles stop touching.

    Each contour is filled on an image-sized canvas, eroded once with a square
    ``(2 * shrink_pixels + 1)`` kernel, and replaced by the largest external
    contour of the eroded shape. Contours that are malformed, vanish under
    erosion, or fail to process keep their original value (best-effort).

    Args:
        bil_path: Path to the BIL file; its ``.hdr`` supplies the image size.
        df: DataFrame with a ``contour`` column of point lists / arrays.
        shrink_pixels: Erosion radius in pixels (default 1); must be >= 0.

    Returns:
        A copy of ``df``; every successfully shrunk contour is stored as a
        list of 2-element points.
    """
    samples, lines = read_hdr_file(bil_path)
    kernel_size = 2 * shrink_pixels + 1
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    # One reusable canvas, cleared before each contour is drawn.
    canvas = np.zeros((lines, samples), dtype=np.uint8)
    df = df.copy()

    for idx, contour in df['contour'].items():
        if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
            continue
        try:
            contour_array = np.array(contour, dtype=np.int32)
            if contour_array.ndim == 1:
                continue
            canvas.fill(0)
            cv2.fillPoly(canvas, [contour_array], 255)
            eroded = cv2.erode(canvas, kernel, iterations=1)
            # findContours returns (image, contours, hierarchy) on OpenCV 3.x and
            # (contours, hierarchy) on 4.x; index -2 is the contour list on both.
            found = cv2.findContours(eroded, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
            if len(found) == 0:
                continue  # the object disappeared under erosion: keep the original
            # If erosion split the object, keep only the largest piece.
            largest = max(found, key=cv2.contourArea)
            if len(largest) >= 3:
                df.at[idx, 'contour'] = largest.reshape(-1, 2).tolist()
        except (cv2.error, ValueError, TypeError):
            # Best-effort: leave this contour unchanged if it cannot be processed.
            continue
    return df


def save_envi_classification(bil_path, df, savepath):
    """Rasterise the classified contours and write an ENVI classification image.

    Each row's contour is filled with ``Predictions + 1`` (0 is background).
    Pixel values 10 and 11 have no entry in the header's class list and are
    mapped back to background; ``main`` uses prediction 9 (value 10) to mark
    background shadows.

    Args:
        bil_path: Source BIL file; its ``.hdr`` supplies the image size.
        df: DataFrame with ``contour`` and ``Predictions`` columns.
        savepath: Output raster path; the header is written next to it with
            the extension replaced by ``.hdr``.
    """
    samples, lines = read_hdr_file(bil_path)
    classification_result = np.zeros((lines, samples), dtype=np.uint16)

    for contour, prediction in zip(df['contour'], df['Predictions']):
        if not isinstance(contour, (list, np.ndarray)) or len(contour) == 0:
            continue
        cv2.fillPoly(classification_result, [np.array(contour, dtype=np.int32)], int(prediction) + 1)

    # Clear 10/11 once, after all contours are drawn, so the last-drawn contour
    # is cleared as well (previously this ran before each fill, inside the loop).
    classification_result[np.isin(classification_result, (10, 11))] = 0

    # The header below declares data type 2 (16-bit signed) and byte order 0
    # (little endian); write exactly that layout regardless of platform.
    with open(savepath, 'wb') as f:
        classification_result.astype('<i2').tofile(f)

    # NOTE(review): ENVI's standard keys for labelled rasters are
    # "file type = ENVI Classification" and "class names"; confirm downstream
    # readers expect the "class" key used here.
    header_content = f"""ENVI
description = {{ Classification Result.}}
samples = {samples}
lines = {lines}
bands = 1
header offset = 0
file type = ENVI Standard
data type = 2
interleave = bil
classes = 10
class = {{ background, ABS, HDPE, LDPE, PA6, PET, PP, PS, PTFE, PVC }}
single pixel area = 0.000036
unit = mm2
byte order = 0
wavelength units = nm
"""
    header_filename = os.path.splitext(savepath)[0] + '.hdr'
    with open(header_filename, 'w') as header_file:
        header_file.write(header_content)


def _has_wavelength_key(content):
    """Return True if the header text defines a ``wavelength`` key (not just ``wavelength units``)."""
    for line in content.splitlines():
        key, sep, _ = line.partition('=')
        if sep and key.strip().lower() == 'wavelength':
            return True
    return False


def change_hdr_file(bil_path):
    """Append the band-centre ``wavelength`` list to the header of *bil_path* if it is missing.

    Prints a message and returns without changes if the header does not
    exist or already defines ``wavelength``.
    """
    hdr_path = _hdr_path_for(bil_path)

    if not os.path.exists(hdr_path):
        print(f"错误: 找不到对应的HDR文件: {hdr_path}")
        return

    with open(hdr_path, 'r') as file:
        content = file.read()

    if _has_wavelength_key(content):
        print(f"File {os.path.basename(hdr_path)} already contains wavelength information; no changes needed.")
        return

    with open(hdr_path, 'a') as file:
        if not content.endswith('\n'):
            file.write('\n')  # make sure the new entry starts on its own line
        file.write(_WAVELENGTH_INFO + '\n')

    print(f"Successfully added wavelength information to file: {os.path.basename(hdr_path)}")


def _run_secondary_classification(df_pre, bil_path, mask, filter_mask_original, model_path,
                                  model_type, process_methods1, process_methods2, target_classes):
    """Re-classify rows whose primary prediction is in *target_classes* with a second model.

    Shape features are extracted from a background-corrected image restricted
    to the target contours (label = DataFrame index + 1), and the secondary
    predictions (+1) overwrite ``df_pre['Predictions']`` for those rows in
    place. Skips with a message when there are no target rows, no model path,
    or the prediction count does not match the target row count.
    """
    target_classes = set(target_classes or [])
    target_rows = df_pre['Predictions'].isin(target_classes)
    if not target_rows.any():
        print("No samples from target classes found; skipping secondary classification.\n")
        return
    if not model_path:
        print("Secondary model path not provided; skipping secondary classification.\n")
        return

    # Background correction of the image information used for shape features.
    shape_corrected = shape_correct_background(bil_path, mask, filter_mask_original)

    # Label mask containing only the target-class contours.
    mask_second = np.zeros_like(mask, dtype=np.uint16)
    for idx, contour in df_pre.loc[target_rows, 'contour'].items():
        if isinstance(contour, list) and len(contour) > 0:
            cv2.fillPoly(mask_second, [np.array(contour, dtype=np.int32)], idx + 1)

    df_shape = extract_features(shape_corrected, mask_second)
    # Keep the first 13 feature columns as model input.
    # NOTE(review): an earlier comment said "columns 2 to 13"; confirm whether column 0 belongs.
    if len(df_shape.columns) >= 13:
        df_shape = df_shape.iloc[:, :13]

    print(f"Running secondary classification for classes: {sorted(target_classes)}")
    df_secondary = predict_with_model(
        df_shape,
        model_path,
        model_type=model_type,
        ProcessMethods1=process_methods1,
        ProcessMethods2=process_methods2,
    )

    expected = int(target_rows.sum())
    if len(df_secondary) != expected:
        # Can happen when a contour is fully overdrawn by another in mask_second.
        print(f"Secondary classification returned {len(df_secondary)} predictions "
              f"for {expected} samples; keeping primary predictions.\n")
        return
    # Assumes extract_features returns rows in ascending label order, which
    # matches the row order of df_pre[target_rows] — TODO confirm.
    df_pre.loc[target_rows, 'Predictions'] = df_secondary['Predictions'].values + 1


def _to_gray(rgb_img):
    """Return a single-channel array for *rgb_img* (a PIL image or a NumPy array)."""
    image = np.array(rgb_img) if hasattr(rgb_img, 'mode') else rgb_img
    if image.ndim != 3:
        return image
    channels = image.shape[2]
    if channels == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    if channels == 1:
        return image[:, :, 0]
    return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)


def _reclassify_class7_shadows(df_pre, rgb_img):
    """Relabel class-7 detections with blurry outlines as class 9 (background shadow).

    Genuine class-7 particles have sharp boundaries, shadows soft ones. For
    each class-7 contour, the Sobel gradient magnitude is sampled on a
    2-pixel-wide outline. Contours whose mean edge gradient is below the 30th
    percentile of all pooled class-7 edge gradients are set to 9. The
    percentile falls back to the whole image if no contour yields samples.
    Modifies ``df_pre['Predictions']`` in place.
    """
    class_7_mask = df_pre['Predictions'] == 7
    if not class_7_mask.any():
        return
    print(f"Processing {class_7_mask.sum()} samples with class 7 to identify background shadows...\n")

    gray_img = _to_gray(rgb_img)
    grad_x = cv2.Sobel(gray_img, cv2.CV_64F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(gray_img, cv2.CV_64F, 0, 1, ksize=3)
    gradient_magnitude = np.sqrt(grad_x ** 2 + grad_y ** 2)

    # Sample edge gradients once per contour; reused for the threshold and the decision.
    edge_gradients = {}
    edge_mask = np.zeros(gray_img.shape, dtype=np.uint8)
    for idx, contour in df_pre.loc[class_7_mask, 'contour'].items():
        if not isinstance(contour, (list, np.ndarray)) or len(contour) < 3:
            continue
        try:
            contour_array = np.array(contour, dtype=np.int32)
            if contour_array.ndim == 1:
                continue
            edge_mask.fill(0)
            cv2.drawContours(edge_mask, [contour_array], -1, 255, thickness=2)
        except (cv2.error, ValueError, TypeError) as e:
            print(f"Error processing sample at index {idx}: {str(e)}")
            continue
        values = gradient_magnitude[edge_mask > 0]
        if values.size > 0:
            edge_gradients[idx] = values

    if edge_gradients:
        gradient_threshold = np.percentile(np.concatenate(list(edge_gradients.values())), 30)
    else:
        gradient_threshold = np.percentile(gradient_magnitude, 30)
    print(f"Gradient threshold for class 7: {gradient_threshold:.2f}\n")

    indices_to_update = []
    for idx, values in edge_gradients.items():
        mean_gradient = np.mean(values)
        if mean_gradient < gradient_threshold:
            indices_to_update.append(idx)
            print(
                f"Sample {idx}: mean_gradient={mean_gradient:.2f}, threshold={gradient_threshold:.2f} -> identified as background shadow")

    if indices_to_update:
        df_pre.loc[indices_to_update, 'Predictions'] = 9
        print(f"Updated {len(indices_to_update)} samples from class 7 to class 9 (background shadows)\n")
    else:
        print("No samples needed to be updated from class 7\n")


def main():
    """Run the full pipeline on the configured BIL file.

    The steps are:
    1. Render RGB and patch the header's wavelength list.
    2. Segment particles.
    3. Extract background-corrected spectra.
    4. Clean the data.
    5. Run the primary and optional secondary classification.
    6. Filter class-7 shadows.
    7. Shrink contours.
    8. Write the ENVI classification and print timings.
    """
    bil_path = BIL_PATH
    output_path = OUTPUT_PATH

    total_start_time = time.perf_counter()

    print("Processing BIL file to generate RGB image...\n")
    rgb_img = process_bil_files(bil_path)

    change_hdr_file(bil_path)

    segmentation_start_time = time.perf_counter()
    # mask is a 16-bit label mask of the detected plastic particles.
    print("Generating mask ...\n")
    mask, filter_mask_original = detect_microplastic_mask_from_array(
        image=rgb_img,
        filter_method='threshold',
        diameter=None,
        flow_threshold=0.4,
        cellprob_threshold=-1
    )

    print("Extracting features from BIL file...\n")
    df = process_images(bil_path, mask)

    print("Applying background correction...\n")
    spectral_background = process_images_background(bil_path, mask)
    # Columns 1..168 hold the per-band values; divide by the background spectrum.
    df.iloc[:, 1:169] = df.iloc[:, 1:169].div(spectral_background, axis=1)

    print("Cleaning data...\n")
    df = df.dropna()
    df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
    df = df[df['area'] >= 500]
    # Drop columns at 0-based positions 87..109 and -10..-2.
    cols_to_remove = df.columns[np.r_[87:110, -10:-1]]
    df = df.drop(columns=cols_to_remove)
    segmentation_time = time.perf_counter() - segmentation_start_time

    classification_start_time = time.perf_counter()
    print("Predicting classes...\n")
    df_pre = predict_with_model(
        df,
        PRIMARY_MODEL_PATH,
        model_type=PRIMARY_MODEL_TYPE,
        ProcessMethods1=PRIMARY_PROCESS_METHODS1,
        ProcessMethods2=PRIMARY_PROCESS_METHODS2
    )

    _run_secondary_classification(
        df_pre, bil_path, mask, filter_mask_original,
        model_path=SECONDARY_MODEL_PATH,
        model_type=SECONDARY_MODEL_TYPE,
        process_methods1=SECONDARY_PROCESS_METHODS1,
        process_methods2=SECONDARY_PROCESS_METHODS2,
        target_classes=SECONDARY_TARGET_CLASSES,
    )
    _reclassify_class7_shadows(df_pre, rgb_img)
    classification_time = time.perf_counter() - classification_start_time

    df_pre = shrink_contours(bil_path, df_pre, shrink_pixels=1)

    print("Saving ENVI classification results...\n")
    save_envi_classification(bil_path, df_pre, output_path)
    print(f"ENVI classification results saved to: {output_path}")

    total_time = time.perf_counter() - total_start_time

    print(f"\n{'=' * 60}")
    print("处理完成")
    print(f"{'=' * 60}")
    print(f"分割耗时: {segmentation_time:.2f} 秒")
    print(f"分类耗时: {classification_time:.2f} 秒")
    print(f"总耗时: {total_time:.2f} 秒")
    print(f"{'=' * 60}")
    print(f"结果已保存至: {output_path}")


if __name__ == "__main__":
    main()