Files
micro_plastic/fliter_sample_spectral.py
2026-02-25 09:42:51 +08:00

314 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from bil2rgb import process_bil_files
from shape_spectral import process_images
import cv2
from classification_model.Parallel.predict_plastic import predict_and_save
import numpy as np
import os
import matplotlib
import pandas as pd
from shape_spectral_background import process_images_background
from mask import detect_microplastic_mask_from_array
import plantcv as pcv
# Select matplotlib's Tk ("TkAgg") backend for the whole process at import time.
# NOTE(review): nothing in this file plots directly; presumably one of the
# imported helper modules does — confirm this is still needed.
matplotlib.use('TkAgg')
def read_hdr_file(bil_path):
    """Read the image dimensions from the ENVI header that accompanies a BIL file.

    The header path is derived by replacing only the file extension with
    ``.hdr``, so directory names that happen to contain ``.bil`` are left
    untouched.

    Args:
        bil_path: Path to the ``.bil`` raster file.

    Returns:
        A ``(samples, lines)`` tuple of ints. Either value is ``None`` when the
        corresponding key is absent from the header.

    Raises:
        OSError: If the header file cannot be opened.
        ValueError: If a ``samples``/``lines`` value is not an integer.
    """
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    samples, lines = None, None
    with open(hdr_path, 'r') as f:
        for raw_line in f:
            # ENVI header entries look like "key = value"; compare the whole key
            # rather than a prefix so longer key names cannot match by accident.
            key, sep, value = raw_line.partition('=')
            if not sep:
                continue
            key = key.strip().lower()
            if key == 'samples':
                samples = int(value.strip())
            elif key == 'lines':
                lines = int(value.strip())
    return samples, lines
def save_envi_classification(bil_path, df, savepath):
    """Rasterise per-object predictions into a single-band ENVI classification image.

    Each row's contour polygon is filled with ``Predictions + 1`` (label 0 is
    "background" in the class list written to the header). Labels 10 and 11
    are reset to 0 after all polygons are drawn; label 10 is "background2" in
    that class list.

    Args:
        bil_path: Source BIL file; its ``.hdr`` supplies the image size.
        df: DataFrame with a ``contour`` column (a point list that
            ``cv2.fillPoly`` accepts — presumably (x, y) pixel coordinates) and
            a ``Predictions`` column (presumably a zero-based class index).
        savepath: Output raster path. The ENVI header is written next to it
            with the extension replaced by ``.hdr``.

    Raises:
        ValueError: If the source header lacks ``samples`` or ``lines``.
    """
    samples, lines = read_hdr_file(bil_path)
    if samples is None or lines is None:
        raise ValueError(f"'samples'/'lines' missing from the header of {bil_path}")
    classification_result = np.zeros((lines, samples), dtype=np.uint16)
    for _, row in df.iterrows():
        contour = np.array(row['contour'], dtype=np.int32)
        prediction = int(row['Predictions']) + 1
        cv2.fillPoly(classification_result, [contour], prediction)
    # Clear the background-like labels once, after every polygon is drawn.
    # Doing it inside the loop rescanned the whole image per row and still let
    # the last polygon's label 10 survive.
    classification_result[np.isin(classification_result, (10, 11))] = 0
    with open(savepath, 'wb') as f:
        # The header declares "byte order = 0" (little-endian); write it explicitly.
        classification_result.astype('<u2', copy=False).tofile(f)
    # ENVI "data type = 12" is unsigned 16-bit, matching the uint16 raster above.
    header_content = f"""ENVI
description = {{
Classification Result.}}
samples = {samples}
lines = {lines}
bands = 1
header offset = 0
file type = ENVI Standard
data type = 12
interleave = bil
classes = 11
class = {{ background, ABS, HDPE, LDPE, PA6, PET, PP, PS, PTFE, PVC,background2 }}
single pixel area = 0.000036
unit = mm2
byte order = 0
wavelength units = nm
"""
    header_filename = os.path.splitext(savepath)[0] + '.hdr'
    with open(header_filename, 'w') as header_file:
        header_file.write(header_content)
def change_hdr_file(bil_path):
    """Append a fixed band-wavelength list to the ENVI header of a BIL file.

    The header is ``<bil_path without extension>.hdr``. Nothing is changed
    when the header is missing (an error message is printed) or when it
    already has a ``wavelength = ...`` entry. Keys that merely start with
    "wavelength" (such as ``wavelength units``) do not count as an existing
    wavelength list.

    Args:
        bil_path: Path to the ``.bil`` raster file.
    """
    # Hard-coded band-centre wavelengths (presumably nm) appended verbatim.
    wavelength_info = """wavelength = {898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2}"""
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        print(f"错误: 找不到对应的HDR文件: {hdr_path}")
        return
    with open(hdr_path, 'r') as file:
        content = file.read()
    # Only a real "wavelength = ..." entry counts. A plain substring test also
    # matched "wavelength units = nm" and skipped headers lacking the list.
    has_wavelength = any(
        line.partition('=')[0].strip().lower() == 'wavelength'
        for line in content.splitlines()
        if '=' in line
    )
    if has_wavelength:
        print(f"文件 {os.path.basename(hdr_path)} 已包含波长信息,无需修改。")
        return
    with open(hdr_path, 'a') as file:
        # Ensure the appended entry starts on its own line.
        if not content.endswith('\n'):
            file.write('\n')
        file.write(wavelength_info + '\n')
    print(f"已成功添加波长信息到文件: {os.path.basename(hdr_path)}")
def _find_valid_centers(valid_region, half_size):
    """Return (y, x) centres whose (2*half_size+1)-wide square lies wholly in valid_region, row-major."""
    window = 2 * half_size + 1
    height, width = valid_region.shape
    if height < window or width < window:
        return []
    inside = (valid_region > 0).astype(np.int64)
    # Summed-area table padded with a leading zero row/column:
    # sat[i, j] == inside[:i, :j].sum()
    sat = np.zeros((height + 1, width + 1), dtype=np.int64)
    sat[1:, 1:] = inside.cumsum(axis=0).cumsum(axis=1)
    # window_sums[i, j] == inside[i:i + window, j:j + window].sum(), i.e. the
    # window centred on (i + half_size, j + half_size); O(1) per window.
    window_sums = (sat[window:, window:] - sat[:-window, window:]
                   - sat[window:, :-window] + sat[:-window, :-window])
    # np.nonzero yields row-major order, matching a y-then-x scan.
    ys, xs = np.nonzero(window_sums == window * window)
    return list(zip((ys + half_size).tolist(), (xs + half_size).tolist()))


def _save_mask_visualization(bil_path, filter_mask_binary, mask_binary, new_mask_array):
    """Save a PNG overlay (filter=blue, plastic=red, patches=green) to <dirname(bil_path)>/masks."""
    masks_dir = os.path.join(os.path.dirname(bil_path), 'masks')
    os.makedirs(masks_dir, exist_ok=True)
    height, width = filter_mask_binary.shape
    combined_visualization = np.zeros((height, width, 3), dtype=np.uint8)
    # cv2.imwrite expects BGR channel order: 0 = blue, 1 = green, 2 = red.
    combined_visualization[:, :, 0] = filter_mask_binary * 100  # filter paper -> blue
    combined_visualization[:, :, 2] = mask_binary * 255  # plastic -> red
    combined_visualization[:, :, 1] = (new_mask_array > 0).astype(np.uint8) * 255  # patches -> green
    filename = os.path.splitext(os.path.basename(bil_path))[0]
    output_path = os.path.join(masks_dir, f"{filename}_mask_visualization.png")
    cv2.imwrite(output_path, combined_visualization)
    print(f"Saved mask visualization to: {output_path}")


def generate_new_mask(filter_mask_original, mask, num_masks=50, bil_path=None, mask_size=35):
    """Place labelled square patches on filter paper that is free of plastic.

    Candidate positions are centres whose whole square lies inside the
    filter-paper mask and outside the plastic mask. Up to ``num_masks`` of them
    are drawn with ``np.random.choice`` (without replacement) and painted with
    labels 1, 2, ... in that order. Patches may overlap; the later label wins.

    Args:
        filter_mask_original: 2-D array; non-zero pixels are filter paper.
        mask: 2-D plastic label mask of the same shape; pixels > 0 are plastic.
        num_masks: Maximum number of patches to place.
        bil_path: If given, a colour overlay is saved to
            ``<dirname(bil_path)>/masks/<stem>_mask_visualization.png``.
        mask_size: Patch side length in pixels. The square spans
            ``2 * (mask_size // 2) + 1`` pixels, so an even size rounds up by one.

    Returns:
        A ``uint16`` array with 0 for background and 1..k for the k placed
        patches. It is all zeros (and a warning is printed) if nothing fits.
    """
    mask_binary = (mask > 0).astype(np.uint8)
    filter_mask_binary = (filter_mask_original > 0).astype(np.uint8)
    # Filter paper AND NOT plastic.
    valid_region = (filter_mask_binary > 0) & (mask_binary == 0)
    height, width = valid_region.shape
    half_size = mask_size // 2
    new_mask_array = np.zeros((height, width), dtype=np.uint16)

    valid_centers = _find_valid_centers(valid_region, half_size)
    num_masks = min(num_masks, len(valid_centers))
    if num_masks == 0:
        print("Warning: No valid positions found for generating masks.")
        return new_mask_array

    if len(valid_centers) > num_masks:
        picks = np.random.choice(len(valid_centers), size=num_masks, replace=False)
        selected_centers = [valid_centers[i] for i in picks]
    else:
        selected_centers = valid_centers

    for label, (y, x) in enumerate(selected_centers, start=1):
        new_mask_array[y - half_size:y + half_size + 1, x - half_size:x + half_size + 1] = label
    print(f"Generated {len(selected_centers)} masks of size {mask_size}x{mask_size}")

    if bil_path is not None:
        _save_mask_visualization(bil_path, filter_mask_binary, mask_binary, new_mask_array)
    return new_mask_array
def process_single_bil(bil_path):
    """
    Run the full feature-extraction pipeline for one BIL file.

    Steps, in order:
    1. Render an RGB image with ``process_bil_files``.
    2. Append wavelengths to the header with ``change_hdr_file``.
    3. Detect plastic and filter-paper masks.
    4. Sample square patches on the filter paper with ``generate_new_mask``.
    5. Extract features for those patches, then divide the spectral columns
       (positional columns 1..168) by a background reference computed with
       the plastic mask.
    6. Drop incomplete rows, degenerate contours and objects with
       ``area < 400``.

    Returns:
        The cleaned DataFrame with a leading ``filename`` column (the file
        stem), or ``None`` if any step raised; the traceback is printed.
    """
    try:
        base_name = os.path.basename(bil_path)
        banner = '=' * 60
        print(f"\n{banner}")
        print(f"Processing: {base_name}")
        print(banner)

        print("Processing BIL file to generate RGB image...")
        rgb_img = process_bil_files(bil_path)
        change_hdr_file(bil_path)

        print("Generating mask...")
        plastic_labels, filter_mask = detect_microplastic_mask_from_array(
            image=rgb_img,
            filter_method='threshold',
            diameter=None,
            flow_threshold=0.4,
            cellprob_threshold=0.0,
        )
        # Features are extracted from the sampled filter-paper patches, not
        # from the plastic objects themselves.
        sample_patches = generate_new_mask(filter_mask, plastic_labels)

        print("Extracting features from BIL file...")
        # NOTE(review): this rebinds an attribute on the `plantcv` package object
        # imported at file top; whether process_images reads observations from
        # there is not visible here — confirm the reset actually takes effect.
        pcv.observations = {}
        features = process_images(bil_path, sample_patches)

        print("Applying background correction...")
        # The reference comes from the plastic mask; its exact shape/index must
        # align with the spectral columns for the division — TODO confirm.
        background = process_images_background(bil_path, plastic_labels)
        spectral_block = features.iloc[:, 1:169]
        features.iloc[:, 1:169] = spectral_block.div(background, axis=1)

        print("Cleaning data...")
        features = features.dropna()
        usable_contour = features['contour'].apply(
            lambda c: (not isinstance(c, list)) or len(c) > 1
        )
        features = features[usable_contour]
        features = features[features['area'] >= 400]

        stem = os.path.splitext(base_name)[0]
        features.insert(0, 'filename', stem)
        print(f"Extracted {len(features)} objects from {base_name}")
        return features
    except Exception as exc:
        print(f"Error processing {bil_path}: {str(exc)}")
        import traceback
        traceback.print_exc()
        return None
def main(bil_path_or_folder=r"D:\Data\Traindata-11",
         output_csv_path=r"E:\plastic\plastic\output\滤纸样本光谱\11.csv"):
    """Extract features from one BIL file, or every BIL file in a folder, into one CSV.

    Results are written incrementally: the first file that yields objects
    creates (overwrites) the CSV with a header, and later files append to it.
    Files that fail or yield no objects are skipped.

    Args:
        bil_path_or_folder: A single ``.bil`` file or a directory to scan
            (non-recursively, extension matched case-insensitively).
        output_csv_path: Destination CSV path; its parent directory is created
            if needed.
    """
    output_dir = os.path.dirname(output_csv_path)
    # A bare filename has no directory component; os.makedirs('') would raise.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if os.path.isfile(bil_path_or_folder):
        bil_files = [bil_path_or_folder]
    elif os.path.isdir(bil_path_or_folder):
        # Sorted so the CSV row order is reproducible (os.listdir order is arbitrary).
        bil_files = sorted(
            os.path.join(bil_path_or_folder, name)
            for name in os.listdir(bil_path_or_folder)
            if name.lower().endswith('.bil')
        )
        print(f"Found {len(bil_files)} BIL files to process")
    else:
        print(f"Error: {bil_path_or_folder} is not a valid file or directory")
        return

    header_written = False
    total_objects = 0
    for i, bil_path in enumerate(bil_files, 1):
        print(f"\n[{i}/{len(bil_files)}] Processing file...")
        df = process_single_bil(bil_path)
        if df is None or len(df) == 0:
            continue
        # First batch: overwrite ('w') and write the header; afterwards append ('a').
        df.to_csv(
            output_csv_path,
            mode='a' if header_written else 'w',
            index=False,
            header=not header_written,
        )
        header_written = True
        total_objects += len(df)
        print(f" -> {len(df)} objects appended to CSV file")

    if total_objects > 0:
        print("\nSummary:")
        print(f" Total files processed: {len(bil_files)}")
        print(f" Total objects detected: {total_objects}")
        print(f" Output file: {output_csv_path}")
    else:
        print("\nNo results to save.")
# Run the batch extraction only when this file is executed as a script.
if __name__ == "__main__":
    main()