Files
micro_plastic/extact_shape.py
2026-02-25 09:42:51 +08:00

377 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from bil2rgb import process_bil_files
from shape_spectral import process_images
import cv2
from classification_model.Parallel.predict_plastic import predict_and_save
import numpy as np
import os
import matplotlib
import pandas as pd
from mask import detect_microplastic_mask_from_array
# Select matplotlib's 'TkAgg' (Tk GUI) backend for this process.
# NOTE(review): no plotting is visible in this file; presumably one of the
# imported helper modules draws figures — confirm this is still needed.
matplotlib.use('TkAgg')
# Column order of the DataFrame returned by extract_features. Declared once so
# that an image without any labelled particle still yields a frame with these
# columns (the caller indexes 'contour' and 'area' right afterwards).
_FEATURE_COLUMNS = [
    'ID', 'mean', 'median', 'percentile_90', 'std', 'cv',
    'I_inner', 'I_mid', 'I_outer', 'I_center', 'I_edge',
    'R1', 'R2', 'area', 'contour', 'center_of_mass',
]


def _radial_ring_means(pixel_values, y_coords, x_coords, center_y, center_x,
                       max_radius, fallback):
    """Return the mean intensity of the inner, middle and outer radial rings.

    Each pixel's distance to the centroid is normalised by ``max_radius`` and
    binned into the inner ``[0, 0.3]``, middle ``(0.3, 0.7]`` and outer
    ``> 0.7`` ring. A ring that receives no pixel yields ``fallback``.
    ``pixel_values[i]`` must belong to pixel ``(y_coords[i], x_coords[i])``.
    """
    distance = np.sqrt((x_coords - center_x) ** 2 + (y_coords - center_y) ** 2)
    normalized = distance / max_radius
    in_inner = normalized <= 0.3
    in_mid = ~in_inner & (normalized <= 0.7)
    in_outer = ~(in_inner | in_mid)

    def ring_mean(selector):
        ring = pixel_values[selector]
        return np.mean(ring) if ring.size > 0 else fallback

    return ring_mean(in_inner), ring_mean(in_mid), ring_mean(in_outer)


def extract_features(df_correct, mask):
    """Compute per-particle intensity, radial and shape features.

    For every non-zero label in ``mask`` the pixels of ``df_correct`` under
    that label are collected and summarised as: mean, median (50th
    percentile), 90th percentile, standard deviation, coefficient of
    variation, mean intensity of three radial rings around the centroid
    (inner / middle / outer; "center" and "edge" are the inner and outer
    ring), ``R1 = I_center / I_edge``, ``R2 = I_edge - I_center``, pixel
    area, outer contour (list of ``[x, y]`` points) and centroid ``(x, y)``.

    Args:
        df_correct: per-pixel values indexed like ``mask`` (in this file the
            background-corrected band); assumed to share ``mask``'s 2-D shape.
        mask: label image, 0 = background; converted to uint16.

    Returns:
        pandas.DataFrame with one row per particle and the columns listed in
        ``_FEATURE_COLUMNS``. It is empty, but still has those columns, when
        ``mask`` contains no particle.
    """
    # NOTE: labels above 65535 would wrap here; kept because the pipeline
    # produces 16-bit label masks.
    if mask.dtype != np.uint16:
        mask = mask.astype(np.uint16)

    labels = np.unique(mask)
    labels = labels[labels != 0]  # drop background

    features_list = []
    for label_id in labels:
        particle_mask = mask == label_id
        particle_data = df_correct[particle_mask]
        if len(particle_data) == 0:
            continue

        # Basic intensity statistics.
        mean_value = np.mean(particle_data)
        median_value = np.median(particle_data)
        percentile_90 = np.percentile(particle_data, 90)
        std_value = np.std(particle_data)
        cv_value = std_value / mean_value if mean_value != 0 else 0

        # Pixel coordinates, in the same row-major order as particle_data.
        y_coords, x_coords = np.where(particle_mask)
        center_y = np.mean(y_coords)
        center_x = np.mean(x_coords)

        contours, _ = cv2.findContours(
            particle_mask.astype(np.uint8),
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE,
        )
        if len(contours) == 0:
            continue
        # A label can consist of several disconnected pieces; keep the largest
        # outline rather than whichever contour OpenCV happens to list first.
        contour = max(contours, key=cv2.contourArea)
        contour_points = contour.reshape(-1, 2)  # (n, 1, 2) -> (n, 2) [x, y]

        # Particle "radius" R: farthest contour point from the centroid.
        contour_distances = np.sqrt(
            (contour_points[:, 0] - center_x) ** 2
            + (contour_points[:, 1] - center_y) ** 2
        )
        max_radius = contour_distances.max() if contour_distances.size else 1.0

        if max_radius < 3:
            # Too small for meaningful rings: every ring takes the particle mean.
            I_inner = I_mid = I_outer = mean_value
            I_center = I_edge = mean_value
            R1, R2 = 1.0, 0.0
        else:
            I_inner, I_mid, I_outer = _radial_ring_means(
                particle_data, y_coords, x_coords, center_y, center_x,
                max_radius, mean_value,
            )
            I_center, I_edge = I_inner, I_outer
            R1 = I_center / I_edge if I_edge != 0 else 1.0
            R2 = I_edge - I_center

        features_list.append({
            'ID': label_id,
            'mean': mean_value,
            'median': median_value,
            'percentile_90': percentile_90,
            'std': std_value,
            'cv': cv_value,
            'I_inner': I_inner,
            'I_mid': I_mid,
            'I_outer': I_outer,
            'I_center': I_center,
            'I_edge': I_edge,
            'R1': R1,
            'R2': R2,
            'area': len(y_coords),
            # Same [[x, y], ...] layout as the process_images output.
            'contour': contour_points.tolist(),
            'center_of_mass': (center_x, center_y),
        })

    return pd.DataFrame(features_list, columns=_FEATURE_COLUMNS)
def shape_correct_background(bil_path, mask, filter_mask_original, band_index=159):
    """Normalise one band of a BIL cube by the mean of the bare filter paper.

    The filter-paper region is ``filter_mask_original > 0`` minus every
    particle pixel (``mask > 0``). The selected band is divided by the mean
    of that region. If no paper pixel remains, a warning is printed and the
    whole-band mean is used instead.

    Args:
        bil_path: path of the ENVI ``.bil`` file; its header must sit next to
            it with the same stem and a ``.hdr`` extension.
        mask: particle label mask; assumed to have the band's 2-D shape.
        filter_mask_original: filter-region mask; pixels ``> 0`` are filter.
        band_index: zero-based band to read. The default 159 is the 160th
            band, the value previously hard-coded here.

    Returns:
        The background-corrected band as a numpy array.

    Raises:
        ValueError: if the reference mean is zero (division would give inf/nan).
    """
    from spectral.io import envi

    # splitext only swaps the final extension; str.replace would also rewrite
    # any '.bil' that occurs in a directory name.
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    img = envi.open(hdr_path, bil_path)
    band = img.read_band(band_index)

    # Filter paper only: inside the filter mask but outside every particle.
    paper_mask = filter_mask_original.copy()
    paper_mask[mask > 0] = 0
    paper_values = band[paper_mask > 0]

    if len(paper_values) == 0:
        print("Warning: 滤纸掩膜区域内无数据,使用全局平均值")
        paper_mean = np.mean(band)
    else:
        paper_mean = np.mean(paper_values)

    if paper_mean == 0:
        raise ValueError(
            f"Background reference mean is zero for band {band_index} of {bil_path}"
        )
    return band / paper_mean
def read_hdr_file(bil_path):
    """Read the raster width and height from the ENVI header of a BIL file.

    The header path is ``bil_path`` with its extension replaced by ``.hdr``.
    Header lines have the form ``key = value``. Keys are compared exactly,
    ignoring case and surrounding whitespace.

    Returns:
        ``(samples, lines)`` as ints; either entry is ``None`` when the header
        does not define it. If a key repeats, the last occurrence wins.
    """
    # splitext only swaps the final extension; str.replace would also rewrite
    # any '.bil' that occurs in a directory name.
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    samples, lines = None, None
    with open(hdr_path, 'r') as f:
        for raw_line in f:
            key, sep, value = raw_line.partition('=')
            if not sep:
                continue
            key = key.strip().lower()
            if key == 'samples':
                samples = int(value.strip())
            elif key == 'lines':
                lines = int(value.strip())
    return samples, lines
def save_envi_classification(bil_path, df, savepath):
    """Rasterise per-particle predictions into an ENVI classification image.

    Each row of ``df`` is drawn as a filled polygon whose pixel value is
    ``Predictions + 1``, so 0 stays background. The raster is written to
    ``savepath`` as raw uint16. An ENVI header is written next to it with the
    same stem and a ``.hdr`` extension.

    Args:
        bil_path: source BIL file; its header supplies the raster size.
        df: DataFrame with a ``contour`` column (``[[x, y], ...]`` points)
            and an integer-like ``Predictions`` column.
        savepath: output path of the raw classification raster.
    """
    samples, lines = read_hdr_file(bil_path)
    classification_result = np.zeros((lines, samples), dtype=np.uint16)

    for _, row in df.iterrows():
        polygon = np.array(row['contour'], dtype=np.int32)
        class_value = int(row['Predictions']) + 1
        cv2.fillPoly(classification_result, [polygon], class_value)

    # Reset 10 ('background2' in the class list below) and 11 (past the end of
    # that list) to 0. This runs once after all polygons are drawn. Resetting
    # before each polygon was quadratic, left the last-drawn 10 in place and
    # never cleared 11.
    classification_result[np.isin(classification_result, (10, 11))] = 0

    classification_result.tofile(savepath)

    # NOTE(review): ENVI documents 'data type = 2' as 16-bit *signed* integer,
    # while the raster is uint16 (ENVI type 12). Every value written here is
    # <= 11, so the bytes are identical. 'class' is also not ENVI's standard
    # 'class names' keyword. Confirm what consumes this header before changing
    # either.
    header_content = f"""ENVI
description = {{
Classification Result.}}
samples = {samples}
lines = {lines}
bands = 1
header offset = 0
file type = ENVI Standard
data type = 2
interleave = bil
classes = 11
class = {{ background, ABS, HDPE, LDPE, PA6, PET, PP, PS, PTFE, PVC,background2 }}
single pixel area = 0.000036
unit = mm2
byte order = 0
wavelength units = nm
"""
    header_filename = os.path.splitext(savepath)[0] + '.hdr'
    with open(header_filename, 'w') as header_file:
        header_file.write(header_content)
def change_hdr_file(bil_path):
    """Append a fixed ``wavelength = {...}`` entry to a BIL file's ENVI header.

    The header is ``bil_path`` with its extension replaced by ``.hdr``. Nothing
    is changed if the header is missing (an error is printed), or if it already
    has a ``wavelength = ...`` entry. The latter makes repeated calls
    idempotent. The appended entry always starts on a new line.
    """
    import re

    # Wavelength list (nm) appended verbatim to the header.
    wavelength_info = """wavelength = {898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2}"""

    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        print(f"错误: 找不到对应的HDR文件: {hdr_path}")
        return

    with open(hdr_path, 'r') as file:
        content = file.read()

    # Only a real 'wavelength = ...' entry counts as present. A bare substring
    # test would also match keys such as 'wavelength units' and wrongly skip.
    if re.search(r'^[ \t]*wavelength[ \t]*=', content, flags=re.MULTILINE | re.IGNORECASE):
        print(f"文件 {os.path.basename(hdr_path)} 已包含波长信息,无需修改。")
        return

    needs_newline = not content.endswith('\n')
    with open(hdr_path, 'a') as file:
        if needs_newline:
            file.write('\n')  # start the new entry on its own line
        file.write(wavelength_info + '\n')
    print(f"已成功添加波长信息到文件: {os.path.basename(hdr_path)}")
def process_single_bil(bil_path):
    """Run the feature-extraction pipeline on one BIL file.

    The steps are:

    - render an image from the cube (``process_bil_files``);
    - append wavelength info to its header;
    - segment particles;
    - background-correct band 160;
    - extract per-particle features;
    - drop unusable rows: NaNs, contours with fewer than two points, and
      areas below 400 pixels.

    Returns:
        A DataFrame with a leading ``filename`` column (file stem) and one row
        per kept particle. If no particle is segmented, an empty DataFrame is
        returned instead. ``None`` is returned if any step raised; the error
        and traceback are printed.
    """
    try:
        print(f"\n{'=' * 60}")
        print(f"Processing: {os.path.basename(bil_path)}")
        print(f"{'=' * 60}")

        print("Processing BIL file to generate RGB image...")
        rgb_img = process_bil_files(bil_path)

        # Add wavelength metadata to the .hdr if it is missing.
        change_hdr_file(bil_path)

        # mask: 16-bit particle label mask; filter_mask_original: filter region.
        print("Generating mask...")
        mask, filter_mask_original = detect_microplastic_mask_from_array(
            image=rgb_img,
            filter_method='threshold',
            diameter=None,
            flow_threshold=0.4,
            cellprob_threshold=0.0
        )

        # Band 160 normalised by the filter-paper background.
        df_correct = shape_correct_background(bil_path, mask, filter_mask_original)

        df = extract_features(df_correct, mask)
        if df.empty:
            # No particle found. The column filters below index 'contour' and
            # 'area', which a column-less empty frame does not have.
            print(f"Extracted {len(df)} objects from {os.path.basename(bil_path)}")
            return df

        print("Cleaning data...")
        df = df.dropna()
        df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
        df = df[df['area'] >= 400]

        # File stem (no extension) identifies the source image of every row.
        filename = os.path.splitext(os.path.basename(bil_path))[0]
        df.insert(0, 'filename', filename)

        print(f"Extracted {len(df)} objects from {os.path.basename(bil_path)}")
        return df

    except Exception as e:
        # Batch boundary: report and skip this file so the remaining files run.
        print(f"Error processing {bil_path}: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
def main(bil_path_or_folder=r"D:\Data\Traindata-11\LDPE7.bil",
         output_csv_path=r"D:\Data\Traindata-05\HDPELDPE\LDPE7.csv"):
    """Process one BIL file or every BIL file in a folder into a single CSV.

    Results are written per file as soon as they are available. The first
    write truncates ``output_csv_path`` and includes the header; later writes
    append rows without a header. Files that fail or yield no objects add
    nothing.

    Args:
        bil_path_or_folder: a ``.bil`` file, or a directory searched
            (non-recursively, case-insensitively) for ``.bil`` files.
        output_csv_path: destination CSV; its directory is created if needed.
    """
    output_dir = os.path.dirname(output_csv_path)
    if output_dir:  # a bare filename has no directory; os.makedirs('') raises
        os.makedirs(output_dir, exist_ok=True)

    if os.path.isfile(bil_path_or_folder):
        bil_files = [bil_path_or_folder]
    elif os.path.isdir(bil_path_or_folder):
        # Sorted so the CSV row order does not depend on directory listing order.
        bil_files = sorted(
            os.path.join(bil_path_or_folder, name)
            for name in os.listdir(bil_path_or_folder)
            if name.lower().endswith('.bil')
        )
        print(f"Found {len(bil_files)} BIL files to process")
    else:
        print(f"Error: {bil_path_or_folder} is not a valid file or directory")
        return

    is_first_write = True
    total_objects = 0
    for i, bil_path in enumerate(bil_files, 1):
        print(f"\n[{i}/{len(bil_files)}] Processing file...")
        df = process_single_bil(bil_path)
        if df is None or len(df) == 0:
            continue
        # First write: mode 'w' (truncate) with header; afterwards append ('a')
        # without repeating the header.
        df.to_csv(
            output_csv_path,
            mode='w' if is_first_write else 'a',
            index=False,
            header=is_first_write,
        )
        total_objects += len(df)
        is_first_write = False
        print(f" -> {len(df)} objects appended to CSV file")

    if total_objects > 0:
        print(f"\nSummary:")
        print(f" Total files processed: {len(bil_files)}")
        print(f" Total objects detected: {total_objects}")
        print(f" Output file: {output_csv_path}")
    else:
        print("\nNo results to save.")
if __name__ == "__main__":
    # Run the batch pipeline only when executed as a script, not on import.
    main()