初始提交
This commit is contained in:
184
time.py
Normal file
184
time.py
Normal file
@ -0,0 +1,184 @@
|
||||
import time
|
||||
from bil2rgb import process_bil_files
|
||||
from unet_pytorch import predict_rgb
|
||||
from shape_spectral import process_images
|
||||
import cv2
|
||||
from classification_model.Parallel.predict_plastic import predict_and_save
|
||||
import numpy as np
|
||||
import os
|
||||
import matplotlib
|
||||
from shape_spectral_background import process_images_background
|
||||
matplotlib.use('TkAgg')
|
||||
|
||||
def read_hdr_file(bil_path):
|
||||
# 读取 .bil 文件对应的 .hdr 文件
|
||||
hdr_path = bil_path.replace('.bil', '.hdr') # 假设 .hdr 文件与 .bil 文件同名
|
||||
with open(hdr_path, 'r') as f:
|
||||
header = f.readlines()
|
||||
|
||||
samples, lines = None, None
|
||||
|
||||
# 从 .hdr 文件中提取样本和行数
|
||||
for line in header:
|
||||
if line.startswith('samples'):
|
||||
samples = int(line.split('=')[-1].strip())
|
||||
if line.startswith('lines'):
|
||||
lines = int(line.split('=')[-1].strip())
|
||||
|
||||
return samples, lines
|
||||
|
||||
def save_envi_classification(bil_path, df, savepath):
|
||||
if not os.path.exists(savepath):
|
||||
os.makedirs(savepath)
|
||||
print(f"创建目录: {savepath}")
|
||||
# 获取 .bil 文件的图像尺寸
|
||||
samples, lines = read_hdr_file(bil_path)
|
||||
|
||||
# 初始化一个大小为 (lines, samples) 的空白图像
|
||||
classification_result = np.zeros((lines, samples), dtype=np.uint8)
|
||||
|
||||
# 循环遍历 df 中的每个轮廓和对应的类别
|
||||
for _, row in df.iterrows():
|
||||
contour = row['contour'] # 轮廓为一个列表
|
||||
|
||||
prediction = int(row['Predictions']) + 1 # 预测类别
|
||||
|
||||
# 将预测类别10和11转换为0
|
||||
if prediction == 10 or prediction == 11:
|
||||
prediction = 0
|
||||
|
||||
# 将该轮廓的所有点设置为预测类别
|
||||
contour = np.array(contour, dtype=np.int32)
|
||||
cv2.fillPoly(classification_result, [contour], prediction)
|
||||
|
||||
# 根据原 .bil 文件名生成输出路径
|
||||
base_name = os.path.splitext(os.path.basename(bil_path))[0] # 获取文件名,不带扩展名
|
||||
output_path = os.path.join(savepath, f"{base_name}_predict.dat")
|
||||
|
||||
# 保存 ENVI 格式的文件
|
||||
with open(output_path, 'wb') as f:
|
||||
classification_result.tofile(f)
|
||||
|
||||
# 保存头文件(关键修正点)
|
||||
header_content = f"""ENVI
|
||||
description = {{
|
||||
Classification Result.}}
|
||||
samples = {samples}
|
||||
lines = {lines}
|
||||
bands = 1
|
||||
header offset = 0
|
||||
file type = ENVI Standard
|
||||
data type = 1 # 修正为8-bit byte类型
|
||||
interleave = bil
|
||||
classes = 10
|
||||
class = {{ back, ABS, HDPE, LDPE, PA6, PET, PP, PS, PTFE, PVC }}
|
||||
single pixel area = 0.000036
|
||||
unit = mm2
|
||||
byte order = 0
|
||||
wavelength units = nm
|
||||
"""
|
||||
header_filename = os.path.join(savepath, f"{base_name}_predict.hdr")
|
||||
with open(header_filename, 'w') as header_file:
|
||||
header_file.write(header_content)
|
||||
|
||||
def change_hdr_file(bil_path):
|
||||
# 定义要追加的波长信息
|
||||
wavelength_info = """wavelength = {898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2}"""
|
||||
|
||||
# 将.bil路径转换为.hdr路径
|
||||
hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
|
||||
|
||||
# 检查.hdr文件是否存在
|
||||
if not os.path.exists(hdr_path):
|
||||
print(f"错误: 找不到对应的HDR文件: {hdr_path}")
|
||||
return
|
||||
|
||||
# 读取文件内容
|
||||
with open(hdr_path, 'r') as file:
|
||||
content = file.read()
|
||||
|
||||
# 检查是否已包含波长信息
|
||||
if 'wavelength' in content:
|
||||
print(f"文件 {os.path.basename(hdr_path)} 已包含波长信息,无需修改。")
|
||||
return
|
||||
|
||||
# 检查文件是否以换行符结尾
|
||||
needs_newline = not content.endswith('\n')
|
||||
|
||||
# 追加波长信息
|
||||
with open(hdr_path, 'a') as file:
|
||||
if needs_newline:
|
||||
file.write('\n') # 确保新内容从新行开始
|
||||
file.write(wavelength_info + '\n')
|
||||
|
||||
print(f"已成功添加波长信息到文件: {os.path.basename(hdr_path)}")
|
||||
# 记录总开始时间
|
||||
total_start_time = time.time()
|
||||
|
||||
#--------------------------------------------输入参数------------------------------------
|
||||
# 读取bil文件生成rgb,假彩色显示,拉伸
|
||||
bil_path = r"E:\Data\Traindata-06\MPData4.bil"
|
||||
output_path = r'C:\Users\HyperSpec\test\1+2'
|
||||
unet_path = r"./unet_pytorch/logs/best_epoch_weights.pth"
|
||||
model_path = r"D:\plastic\plastic\classification_model\modelsave\svm.m"
|
||||
|
||||
# 步骤1: 读取BIL文件并生成RGB图像
|
||||
print("步骤1: 读取BIL文件并生成RGB图像...")
|
||||
start_time = time.time()
|
||||
rgb_img = process_bil_files(bil_path)
|
||||
print(f"步骤1完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
|
||||
#修改hdr文件
|
||||
change_hdr_file(bil_path)
|
||||
|
||||
# 步骤2: UNet生成掩膜
|
||||
print("步骤2: UNet生成掩膜...")
|
||||
start_time = time.time()
|
||||
mask = predict_rgb.rgb2mask(rgb_img, unet_path)
|
||||
mask = np.array(mask)
|
||||
mask = mask[:, :, 0]
|
||||
print(f"步骤2完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
|
||||
# 步骤3: 处理BIL文件提取特征
|
||||
print("步骤3: 处理BIL文件提取特征...")
|
||||
start_time = time.time()
|
||||
df = process_images(bil_path, mask)
|
||||
print(f"步骤3完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
|
||||
# 步骤4: 获取背景滤纸的平均反射率
|
||||
print("步骤4: 获取背景滤纸的平均反射率...")
|
||||
start_time = time.time()
|
||||
df_correct = process_images_background(bil_path, mask)
|
||||
print(f"步骤4完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
|
||||
# 步骤5: 数据预处理
|
||||
print("步骤5: 数据预处理...")
|
||||
start_time = time.time()
|
||||
# 对 df 的第 2-169 列(索引 1:169)进行除法
|
||||
df.iloc[:, 1:169] = df.iloc[:, 1:169].div(df_correct, axis=1)
|
||||
# 去除nan行
|
||||
df = df.dropna()
|
||||
# 去除轮廓只有一个点的数据行
|
||||
df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
|
||||
# 去除面积太小的数据行
|
||||
df = df[df['area'] >= 400]
|
||||
print(f"步骤5完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
|
||||
# 步骤6: 预测分类结果
|
||||
print("步骤6: 预测分类结果...")
|
||||
start_time = time.time()
|
||||
df_pre = predict_and_save(df, model_path)
|
||||
print(f"步骤6完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
|
||||
# 步骤7: 保存输出结果
|
||||
print("步骤7: 保存输出结果...")
|
||||
start_time = time.time()
|
||||
save_envi_classification(bil_path, df, output_path)
|
||||
print(f"步骤7完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
|
||||
# 输出总耗时
|
||||
total_time = time.time() - total_start_time
|
||||
print(f"所有步骤完成,总耗时: {total_time:.2f}秒")
|
||||
print(f"详细时间报告:")
|
||||
print(f"- 读取BIL并生成RGB: {time.time() - total_start_time:.2f}秒")
|
||||
# 这里可以添加更多详细时间报告...
|
||||
Reference in New Issue
Block a user