Files
micro_plastic/time.py
2026-02-25 09:42:51 +08:00

184 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import time
from bil2rgb import process_bil_files
from unet_pytorch import predict_rgb
from shape_spectral import process_images
import cv2
from classification_model.Parallel.predict_plastic import predict_and_save
import numpy as np
import os
import matplotlib
from shape_spectral_background import process_images_background
matplotlib.use('TkAgg')
def read_hdr_file(bil_path):
    """Read the image size from the ENVI ``.hdr`` header that pairs with *bil_path*.

    The header is expected next to the ``.bil`` file with the same base name
    (``foo.bil`` -> ``foo.hdr``).

    Args:
        bil_path: Path to the ``.bil`` raster file.

    Returns:
        A ``(samples, lines)`` tuple of ints (ENVI: pixels per line, number of
        lines). An element is ``None`` when its key is absent from the header.

    Raises:
        FileNotFoundError: If the header file does not exist.
        ValueError: If a ``samples`` / ``lines`` value is not an integer.
    """
    # Swap only the final extension; str.replace('.bil', '.hdr') would also
    # rewrite any '.bil' that happens to appear in a directory name.
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    samples, lines = None, None
    with open(hdr_path, 'r') as f:
        for line in f:
            key, sep, value = line.partition('=')
            if not sep:
                # Not a "key = value" line (e.g. continuation text of a
                # multi-line {...} value); skip it rather than parse it.
                continue
            key = key.strip().lower()
            if key == 'samples':
                samples = int(value.strip())
            elif key == 'lines':
                lines = int(value.strip())
    return samples, lines
def save_envi_classification(bil_path, df, savepath):
    """Rasterise per-contour class predictions into an ENVI classification file.

    Every row of *df* contributes one polygon (``contour``) filled with its
    class value (``Predictions`` + 1) in a single-band uint8 image whose size
    comes from the ``.hdr`` header of *bil_path*. The image is written to
    ``<savepath>/<base>_predict.dat`` with a matching ``<base>_predict.hdr``,
    where ``<base>`` is the file name of *bil_path* without its extension.

    Args:
        bil_path: Path to the source ``.bil`` file.
        df: DataFrame with a ``contour`` column (polygon vertices convertible
            to an int32 array for ``cv2.fillPoly``) and a ``Predictions``
            column (integer-like model class index).
        savepath: Output directory; created if it does not exist.

    Raises:
        ValueError: If the source header provides no ``samples`` or ``lines``.
    """
    # Value 0 is background; values 1..9 correspond to the remaining names.
    # The same names (and their count) are written into the output header.
    class_names = ('back', 'ABS', 'HDPE', 'LDPE', 'PA6', 'PET', 'PP', 'PS', 'PTFE', 'PVC')

    if not os.path.isdir(savepath):
        os.makedirs(savepath, exist_ok=True)
        print(f"创建目录: {savepath}")

    # Output image has the same dimensions as the source .bil file
    samples, lines = read_hdr_file(bil_path)
    if samples is None or lines is None:
        raise ValueError(f"Could not read 'samples'/'lines' from the header of {bil_path}")

    classification_result = np.zeros((lines, samples), dtype=np.uint8)
    # Later rows overwrite earlier ones where polygons overlap.
    for contour, raw_prediction in zip(df['contour'], df['Predictions']):
        prediction = int(raw_prediction) + 1
        # Values 10 and 11 are mapped to background, as is anything else that
        # falls outside the classes declared in the header.
        if not 0 < prediction < len(class_names):
            prediction = 0
        polygon = np.asarray(contour, dtype=np.int32)
        cv2.fillPoly(classification_result, [polygon], prediction)

    # Output names derive from the source .bil base name
    base_name = os.path.splitext(os.path.basename(bil_path))[0]
    output_path = os.path.join(savepath, f"{base_name}_predict.dat")
    classification_result.tofile(output_path)

    # ENVI does not treat '#' as a comment marker, so no inline notes may go
    # in these lines: anything after the '=' becomes part of the value.
    header_lines = [
        'ENVI',
        'description = {',
        'Classification Result.}',
        f'samples = {samples}',
        f'lines = {lines}',
        'bands = 1',
        'header offset = 0',
        'file type = ENVI Standard',
        'data type = 1',  # 1 = 8-bit unsigned byte, matching np.uint8 above
        'interleave = bil',
        f'classes = {len(class_names)}',
        'class = { ' + ', '.join(class_names) + ' }',
        'single pixel area = 0.000036',
        'unit = mm2',
        'byte order = 0',
        'wavelength units = nm',
    ]
    header_content = '\n'.join(header_lines) + '\n'
    header_filename = os.path.join(savepath, f"{base_name}_predict.hdr")
    with open(header_filename, 'w') as header_file:
        header_file.write(header_content)
def change_hdr_file(bil_path):
    """Append the built-in band wavelength list to the ENVI header of *bil_path*.

    The header is located by replacing the extension of *bil_path* with
    ``.hdr``. The header is left untouched (only a message is printed) when
    it is missing, when it already has a ``wavelength`` key, or when its
    ``bands`` value differs from the number of built-in wavelengths.

    Args:
        bil_path: Path to the ``.bil`` file whose header should be updated.
    """
    # Wavelength list (nm) appended verbatim to the header
    wavelength_info = """wavelength = {898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2}"""
    band_count = wavelength_info.count(',') + 1

    # Map the .bil path to its .hdr path
    hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
    if not os.path.exists(hdr_path):
        print(f"错误: 找不到对应的HDR文件: {hdr_path}")
        return

    with open(hdr_path, 'r') as file:
        content = file.read()

    # Collect the header's keys so that e.g. 'wavelength units' is not
    # mistaken for an existing 'wavelength' entry (a plain substring test
    # would be).
    header_keys = {}
    for line in content.splitlines():
        key, sep, value = line.partition('=')
        if sep:
            header_keys[key.strip().lower()] = value.strip()

    if 'wavelength' in header_keys:
        print(f"文件 {os.path.basename(hdr_path)} 已包含波长信息,无需修改。")
        return

    # Appending a list whose length disagrees with 'bands' would leave an
    # inconsistent header, so skip the file instead.
    declared_bands = header_keys.get('bands')
    if declared_bands is not None and declared_bands.isdigit() and int(declared_bands) != band_count:
        print(f"错误: 文件 {os.path.basename(hdr_path)} 的波段数为 {declared_bands},与内置波长数量 {band_count} 不一致,未修改。")
        return

    # Make sure the new entry starts on its own line
    needs_newline = not content.endswith('\n')
    with open(hdr_path, 'a') as file:
        if needs_newline:
            file.write('\n')
        file.write(wavelength_info + '\n')
    print(f"已成功添加波长信息到文件: {os.path.basename(hdr_path)}")
# Record the overall start time. perf_counter is monotonic, so the durations
# measured below are not affected by wall-clock adjustments.
total_start_time = time.perf_counter()
# Per-step durations in seconds, keyed by step number (used in the final report)
step_times = {}
#--------------------------------------------Input parameters------------------------------------
# Source .bil file (converted to a stretched false-colour RGB image in step 1)
bil_path = r"E:\Data\Traindata-06\MPData4.bil"
output_path = r'C:\Users\HyperSpec\test\1+2'
unet_path = r"./unet_pytorch/logs/best_epoch_weights.pth"
model_path = r"D:\plastic\plastic\classification_model\modelsave\svm.m"

# Step 1: read the BIL file and build an RGB image
print("步骤1: 读取BIL文件并生成RGB图像...")
start_time = time.perf_counter()
rgb_img = process_bil_files(bil_path)
step_times[1] = time.perf_counter() - start_time
print(f"步骤1完成耗时: {step_times[1]:.2f}")
# Append wavelength metadata to the .hdr file (not counted in step 1's time)
change_hdr_file(bil_path)

# Step 2: generate the segmentation mask with UNet
print("步骤2: UNet生成掩膜...")
start_time = time.perf_counter()
mask = np.array(predict_rgb.rgb2mask(rgb_img, unet_path))
# Keep a single channel when the mask comes back as a multi-channel image
if mask.ndim == 3:
    mask = mask[:, :, 0]
step_times[2] = time.perf_counter() - start_time
print(f"步骤2完成耗时: {step_times[2]:.2f}")

# Step 3: extract per-object features from the BIL file
print("步骤3: 处理BIL文件提取特征...")
start_time = time.perf_counter()
df = process_images(bil_path, mask)
step_times[3] = time.perf_counter() - start_time
print(f"步骤3完成耗时: {step_times[3]:.2f}")

# Step 4: mean reflectance of the background filter paper
print("步骤4: 获取背景滤纸的平均反射率...")
start_time = time.perf_counter()
df_correct = process_images_background(bil_path, mask)
step_times[4] = time.perf_counter() - start_time
print(f"步骤4完成耗时: {step_times[4]:.2f}")

# Step 5: data preprocessing
print("步骤5: 数据预处理...")
start_time = time.perf_counter()
# Divide columns 2-169 (iloc 1:169) by the background values from step 4.
# NOTE(review): with axis=1 a Series divisor is aligned on column labels —
# confirm df_correct is indexed like these columns (or is a plain 168-value array).
df.iloc[:, 1:169] = df.iloc[:, 1:169].div(df_correct, axis=1)
# Drop rows containing NaN
df = df.dropna()
# Drop rows whose contour has only one point (only list-typed contours are checked)
df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
# Drop rows whose area is below 400
df = df[df['area'] >= 400]
step_times[5] = time.perf_counter() - start_time
print(f"步骤5完成耗时: {step_times[5]:.2f}")

# Step 6: predict classes
print("步骤6: 预测分类结果...")
start_time = time.perf_counter()
# NOTE(review): df_pre is not used below; step 7 reads the 'Predictions' column
# from df, so this relies on predict_and_save adding that column to df in place —
# confirm against classification_model.Parallel.predict_plastic.
df_pre = predict_and_save(df, model_path)
step_times[6] = time.perf_counter() - start_time
print(f"步骤6完成耗时: {step_times[6]:.2f}")

# Step 7: write the ENVI classification result
print("步骤7: 保存输出结果...")
start_time = time.perf_counter()
save_envi_classification(bil_path, df, output_path)
step_times[7] = time.perf_counter() - start_time
print(f"步骤7完成耗时: {step_times[7]:.2f}")

# Overall and per-step timing report. (Previously the step-1 line printed the
# total elapsed time instead of step 1's own duration.)
total_time = time.perf_counter() - total_start_time
print(f"所有步骤完成,总耗时: {total_time:.2f}")
step_labels = {
    1: "读取BIL并生成RGB",
    2: "UNet生成掩膜",
    3: "处理BIL提取特征",
    4: "获取背景平均反射率",
    5: "数据预处理",
    6: "预测分类结果",
    7: "保存输出结果",
}
print("详细时间报告:")
for step_no, label in step_labels.items():
    print(f"- {label}: {step_times[step_no]:.2f}")