初始提交
This commit is contained in:
377
extact_shape.py
Normal file
377
extact_shape.py
Normal file
@ -0,0 +1,377 @@
|
||||
from bil2rgb import process_bil_files
|
||||
from shape_spectral import process_images
|
||||
import cv2
|
||||
from classification_model.Parallel.predict_plastic import predict_and_save
|
||||
import numpy as np
|
||||
import os
|
||||
import matplotlib
|
||||
import pandas as pd
|
||||
from mask import detect_microplastic_mask_from_array
|
||||
|
||||
# Select matplotlib's 'TkAgg' backend as soon as this module is imported.
# NOTE(review): presumably required by plotting done inside the imported
# helper modules -- confirm, since this file itself draws no figures.
matplotlib.use('TkAgg')
|
||||
|
||||
|
||||
def extract_features(df_correct, mask):
    """Compute per-particle intensity statistics and radial features.

    Args:
        df_correct: Intensity array indexed with the boolean particle mask,
            so its leading dimensions must match ``mask``.
        mask: 2-D label image. ``0`` is background; every other value marks
            the pixels of one particle. It is cast to ``uint16`` when it has
            another dtype (labels above 65535 would therefore wrap).

    Returns:
        pandas.DataFrame with one row per particle and the columns
        ``ID, mean, median, percentile_90, std, cv, I_inner, I_mid, I_outer,
        I_center, I_edge, R1, R2, area, contour, center_of_mass``.
        The columns are present even when no particle is found, so callers
        can filter on them without a ``KeyError``.
    """
    columns = [
        'ID', 'mean', 'median', 'percentile_90', 'std', 'cv',
        'I_inner', 'I_mid', 'I_outer', 'I_center', 'I_edge',
        'R1', 'R2', 'area', 'contour', 'center_of_mass',
    ]

    # Make sure the mask has the expected label dtype.
    if mask.dtype != np.uint16:
        mask = mask.astype(np.uint16)

    # Every non-zero label is one particle.
    labels = np.unique(mask)
    labels = labels[labels != 0]

    features_list = []
    for label_id in labels:
        particle_mask = mask == label_id
        y_coords, x_coords = np.nonzero(particle_mask)
        if y_coords.size == 0:
            continue

        # Boolean indexing returns pixels in the same row-major order as
        # np.nonzero, so particle_data[i] belongs to (y_coords[i], x_coords[i]).
        particle_data = df_correct[particle_mask]

        # Basic statistics.
        mean_value = np.mean(particle_data)
        median_value = np.median(particle_data)  # 50th percentile
        percentile_90 = np.percentile(particle_data, 90)
        std_value = np.std(particle_data)
        # Coefficient of variation; defined as 0 when the mean is 0.
        cv_value = std_value / mean_value if mean_value != 0 else 0

        # Centroid of the particle pixels.
        center_y = np.mean(y_coords)
        center_x = np.mean(x_coords)

        contours, _ = cv2.findContours(
            particle_mask.astype(np.uint8),
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE,
        )
        if len(contours) == 0:
            continue
        # A label may consist of several disconnected pieces; keep the
        # largest outline rather than whichever one OpenCV lists first.
        contour = max(contours, key=cv2.contourArea)
        contour_points = contour.reshape(-1, 2)  # (n, 1, 2) -> (n, 2) as x, y

        # Area in pixels.
        area = int(y_coords.size)

        # Largest centroid-to-outline distance, used to normalise radii.
        edge_dist = np.sqrt(
            (contour_points[:, 0] - center_x) ** 2
            + (contour_points[:, 1] - center_y) ** 2
        )
        max_radius = edge_dist.max() if edge_dist.size > 0 else 1.0

        # Defaults used for tiny particles (radius < 3 px) and for empty rings.
        I_inner = I_mid = I_outer = mean_value
        R1, R2 = 1.0, 0.0

        if max_radius >= 3:
            # Rings by normalised radius: [0, 0.3] inner, (0.3, 0.7] mid,
            # (0.7, ...) outer.
            norm_dist = np.sqrt(
                (x_coords - center_x) ** 2 + (y_coords - center_y) ** 2
            ) / max_radius
            inner_sel = norm_dist <= 0.3
            outer_sel = norm_dist > 0.7
            mid_sel = ~inner_sel & ~outer_sel

            if inner_sel.any():
                I_inner = np.mean(particle_data[inner_sel])
            if mid_sel.any():
                I_mid = np.mean(particle_data[mid_sel])
            if outer_sel.any():
                I_outer = np.mean(particle_data[outer_sel])

            # Centre-to-edge contrast as a ratio (R1) and a difference (R2).
            R1 = I_inner / I_outer if I_outer != 0 else 1.0
            R2 = I_outer - I_inner

        features_list.append({
            'ID': label_id,
            'mean': mean_value,
            'median': median_value,
            'percentile_90': percentile_90,
            'std': std_value,
            'cv': cv_value,
            'I_inner': I_inner,
            'I_mid': I_mid,
            'I_outer': I_outer,
            # "center" and "edge" are the same pixel sets as inner and outer.
            'I_center': I_inner,
            'I_edge': I_outer,
            'R1': R1,
            'R2': R2,
            'area': area,
            'contour': contour_points.tolist(),
            'center_of_mass': (center_x, center_y),
        })

    return pd.DataFrame(features_list, columns=columns)
|
||||
|
||||
|
||||
def shape_correct_background(bil_path, mask, filter_mask_original, band_index=159):
    """Return one band of a BIL image normalised by the filter-paper mean.

    Args:
        bil_path: Path of the ENVI data file; the header is expected next to
            it with the same base name and a ``.hdr`` extension.
        mask: Particle mask; pixels with ``mask > 0`` are removed from the
            filter-paper mask.
        filter_mask_original: Filter-paper mask; pixels ``> 0`` belong to it.
            The array is copied, not modified.
        band_index: Zero-based index of the band to read. The default 159 is
            the 160th band.

    Returns:
        The selected band divided by the mean of that band over the
        paper-only pixels (or over the whole band when no paper-only pixel
        exists).
    """
    from spectral.io import envi

    # splitext swaps only the final extension; str.replace would also rewrite
    # a '.bil' appearing in a directory name and would miss '.BIL'.
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    img = envi.open(hdr_path, bil_path)
    band = img.read_band(band_index)

    # Paper-only mask = filter-paper mask minus the particle mask.
    paper_mask = filter_mask_original.copy()
    paper_mask[mask > 0] = 0

    paper_values = band[paper_mask > 0]
    if paper_values.size == 0:
        print("Warning: 滤纸掩膜区域内无数据,使用全局平均值")
        paper_mean = np.mean(band)
    else:
        paper_mean = np.mean(paper_values)

    # NOTE(review): a zero paper mean would produce inf/nan here -- confirm
    # whether that can happen with real data.
    return band / paper_mean
|
||||
|
||||
|
||||
def read_hdr_file(bil_path):
    """Read the image dimensions from the ENVI header next to a BIL file.

    Args:
        bil_path: Path of the data file; the header path is derived by
            swapping the final extension for ``.hdr``.

    Returns:
        Tuple ``(samples, lines)`` as ints. An entry is ``None`` when the
        corresponding key is not present in the header.

    Raises:
        OSError: If the header file cannot be opened.
        ValueError: If a ``samples``/``lines`` value is not an integer.
    """
    # splitext swaps only the final extension; str.replace would also rewrite
    # a '.bil' appearing in a directory name and would miss '.BIL'.
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'

    samples, lines = None, None
    with open(hdr_path, 'r') as f:
        for line in f:
            key, sep, value = line.partition('=')
            if not sep:
                continue
            # Compare the whole key so indentation or casing does not hide
            # it and longer keys sharing the prefix are not picked up.
            key = key.strip().lower()
            if key == 'samples':
                samples = int(value.strip())
            elif key == 'lines':
                lines = int(value.strip())

    return samples, lines
|
||||
|
||||
|
||||
def save_envi_classification(bil_path, df, savepath):
    """Rasterise predicted particle classes and save them as an ENVI image.

    Args:
        bil_path: Source BIL file; its header supplies the raster size.
        df: DataFrame with a ``contour`` column (list of ``[x, y]`` points)
            and a ``Predictions`` column (class index; 1 is added so that 0
            stays background).
        savepath: Output path of the raw ``uint16`` raster. The header is
            written next to it with the extension replaced by ``.hdr``.
    """
    samples, lines = read_hdr_file(bil_path)
    classification_result = np.zeros((lines, samples), dtype=np.uint16)

    for _, row in df.iterrows():
        class_id = int(row['Predictions']) + 1
        polygon = np.array(row['contour'], dtype=np.int32)
        cv2.fillPoly(classification_result, [polygon], class_id)

    # Class 10 is folded back into background. This is done once after all
    # polygons are drawn: doing it before each fill left the last row's
    # class-10 polygon in place and rescanned the raster for every row.
    classification_result[classification_result == 10] = 0

    with open(savepath, 'wb') as f:
        classification_result.tofile(f)

    # NOTE(review): the header declares data type 2 while the raster is
    # written as uint16; the bytes agree for the small class ids used here.
    header_content = f"""ENVI
description = {{
Classification Result.}}
samples = {samples}
lines = {lines}
bands = 1
header offset = 0
file type = ENVI Standard
data type = 2
interleave = bil
classes = 11
class = {{ background, ABS, HDPE, LDPE, PA6, PET, PP, PS, PTFE, PVC,background2 }}
single pixel area = 0.000036
unit = mm2
byte order = 0
wavelength units = nm
"""
    # The header shares the raster's base name and uses the '.hdr' extension.
    header_filename = os.path.splitext(savepath)[0] + '.hdr'
    with open(header_filename, 'w') as header_file:
        header_file.write(header_content)
|
||||
|
||||
|
||||
def change_hdr_file(bil_path):
    """Append the sensor wavelength list to the ENVI header of a BIL file.

    The header path is ``bil_path`` with its extension replaced by ``.hdr``.
    Nothing is written when the header is missing or already has a
    ``wavelength = ...`` entry. A ``wavelength units = ...`` entry alone does
    not count as wavelength information.

    Args:
        bil_path: Path of the BIL data file.
    """
    # Wavelength entry to append.
    wavelength_info = """wavelength = {898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2}"""

    # Derive the .hdr path from the .bil path.
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'

    if not os.path.exists(hdr_path):
        print(f"错误: 找不到对应的HDR文件: {hdr_path}")
        return

    with open(hdr_path, 'r') as file:
        content = file.read()

    # Look for a key that is exactly "wavelength". A plain substring test
    # would also match "wavelength units = ..." and wrongly skip the header.
    has_wavelength = False
    for line in content.splitlines():
        key, sep, _ = line.partition('=')
        if sep and key.strip().lower() == 'wavelength':
            has_wavelength = True
            break

    if has_wavelength:
        print(f"文件 {os.path.basename(hdr_path)} 已包含波长信息,无需修改。")
        return

    # Start the new entry on its own line if the file lacks a final newline.
    needs_newline = not content.endswith('\n')

    with open(hdr_path, 'a') as file:
        if needs_newline:
            file.write('\n')
        file.write(wavelength_info + '\n')

    print(f"已成功添加波长信息到文件: {os.path.basename(hdr_path)}")
|
||||
|
||||
|
||||
def process_single_bil(bil_path, min_area=400):
    """Run the full feature-extraction pipeline on one BIL file.

    Steps: render an RGB image, add wavelength info to the header, detect
    the particle mask, background-correct one band, extract per-particle
    features and filter them.

    Args:
        bil_path: Path of the BIL file to process.
        min_area: Particles with fewer pixels than this are dropped.

    Returns:
        DataFrame of particle features with a leading ``filename`` column
        (base name without extension). An empty DataFrame, without the
        ``filename`` column, is returned when no particle was detected.
        ``None`` is returned when any step raised; the error and traceback
        are printed.
    """
    try:
        banner = '=' * 60
        print(f"\n{banner}")
        print(f"Processing: {os.path.basename(bil_path)}")
        print(f"{banner}")

        # Render the BIL file as an RGB image.
        print("Processing BIL file to generate RGB image...")
        rgb_img = process_bil_files(bil_path)

        # Make sure the header carries the wavelength list.
        change_hdr_file(bil_path)

        # Build the particle label mask and the filter-paper mask.
        print("Generating mask...")
        mask, filter_mask_original = detect_microplastic_mask_from_array(
            image=rgb_img,
            filter_method='threshold',
            diameter=None,
            flow_threshold=0.4,
            cellprob_threshold=0.0,
        )

        # Background-corrected band data.
        df_correct = shape_correct_background(bil_path, mask, filter_mask_original)

        # Per-particle features.
        df = extract_features(df_correct, mask)

        # No particle at all: the frame may have no columns, so skip the
        # column-based filters instead of failing with a KeyError.
        if df.empty:
            print(f"Extracted 0 objects from {os.path.basename(bil_path)}")
            return df

        print("Cleaning data...")
        df = df.dropna()
        # Keep rows whose contour has more than one point.
        df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
        df = df[df['area'] >= min_area]

        # Prepend the source file name (without extension).
        filename = os.path.splitext(os.path.basename(bil_path))[0]
        df.insert(0, 'filename', filename)

        print(f"Extracted {len(df)} objects from {os.path.basename(bil_path)}")

        return df

    except Exception as e:
        # Per-file boundary: report the failure and let the batch continue.
        print(f"Error processing {bil_path}: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
|
||||
|
||||
|
||||
def main(bil_path_or_folder=r"D:\Data\Traindata-11\LDPE7.bil",
         output_csv_path=r"D:\Data\Traindata-05\HDPELDPE\LDPE7.csv"):
    """Extract particle features from one BIL file or a folder of them.

    Results are written to a single CSV as each file finishes, so a failure
    late in the batch does not lose earlier results.

    Args:
        bil_path_or_folder: A ``.bil`` file, or a directory whose ``.bil``
            files (extension matched case-insensitively) are processed in
            sorted order.
        output_csv_path: Destination CSV. Its parent directory is created
            when missing. An existing file is overwritten by the first
            non-empty result.
    """
    # Work out what to process before touching the output location.
    if os.path.isfile(bil_path_or_folder):
        bil_files = [bil_path_or_folder]
    elif os.path.isdir(bil_path_or_folder):
        bil_files = sorted(
            os.path.join(bil_path_or_folder, f)
            for f in os.listdir(bil_path_or_folder)
            if f.lower().endswith('.bil')
        )
        print(f"Found {len(bil_files)} BIL files to process")
    else:
        print(f"Error: {bil_path_or_folder} is not a valid file or directory")
        return

    # A bare file name has no directory part; os.makedirs('') would raise.
    output_dir = os.path.dirname(output_csv_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    is_first_write = True
    total_objects = 0

    for i, bil_path in enumerate(bil_files, 1):
        print(f"\n[{i}/{len(bil_files)}] Processing file...")
        df = process_single_bil(bil_path)

        if df is not None and len(df) > 0:
            # First write creates the file with a header row ('w'); later
            # writes append rows only ('a').
            df.to_csv(
                output_csv_path,
                mode='w' if is_first_write else 'a',
                index=False,
                header=is_first_write,
            )
            total_objects += len(df)
            is_first_write = False
            print(f"  -> {len(df)} objects appended to CSV file")

    if total_objects > 0:
        print(f"\nSummary:")
        print(f"  Total files processed: {len(bil_files)}")
        print(f"  Total objects detected: {total_objects}")
        print(f"  Output file: {output_csv_path}")
    else:
        print("\nNo results to save.")
|
||||
|
||||
|
||||
# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user