353 lines
14 KiB
Python
353 lines
14 KiB
Python
from bil2rgb import process_bil_files
|
||
from shape_spectral import process_images
|
||
import cv2
|
||
# from classification_model.Parallel.predict_plastic import predict_and_save # 本脚本不分类,可移除
|
||
import numpy as np
|
||
import os
|
||
import matplotlib
|
||
import pandas as pd
|
||
from shape_spectral_background import process_images_background
|
||
from mask import detect_microplastic_mask_from_array
|
||
import plantcv as pcv
|
||
|
||
# 直接复用 main.py 中的成熟实现,避免重复逻辑和不一致
|
||
from main import (
|
||
TRAIN_WAVELENGTHS,
|
||
read_wavelengths_from_hdr,
|
||
resample_spectra_matrix,
|
||
apply_background_no_resample,
|
||
change_hdr_file,
|
||
)
|
||
|
||
matplotlib.use('TkAgg')
|
||
#####用于提取背景滤纸的样本,目的是在训练时加入滤纸的光谱,以减少滤纸与ftpe的误判
|
||
|
||
def apply_background_and_optional_resample_for_samples(df, bg_spectrum, bil_path):
|
||
# 先做背景校正(自动识别以 wavelength_ 或 band_ 开头的光谱列,且长度不一致时尾部对齐)
|
||
df = apply_background_no_resample(df, bg_spectrum)
|
||
|
||
# 再判断是否需要重采样到训练波长
|
||
src_waves = read_wavelengths_from_hdr(bil_path)
|
||
need_resample = (
|
||
src_waves.size > 0 and (
|
||
src_waves.size != len(TRAIN_WAVELENGTHS) or
|
||
not np.allclose(src_waves, TRAIN_WAVELENGTHS, atol=1e-2)
|
||
)
|
||
)
|
||
|
||
if not need_resample:
|
||
return df
|
||
|
||
# 识别光谱列并重采样
|
||
spec_cols = [c for c in df.columns if isinstance(c, str) and (c.startswith('wavelength_') or c.startswith('band_'))]
|
||
if not spec_cols:
|
||
raise ValueError("未找到光谱列(以 wavelength_ 或 band_ 开头)")
|
||
|
||
X_src = df[spec_cols].to_numpy(dtype=np.float64)
|
||
X_dst = resample_spectra_matrix(X_src, src_waves, TRAIN_WAVELENGTHS)
|
||
|
||
# 用 band_{i} 替换光谱列,保持与 main.py 一致
|
||
spec_col_names = [f"band_{i+1}" for i in range(len(TRAIN_WAVELENGTHS))]
|
||
df = pd.concat([
|
||
df.drop(columns=spec_cols),
|
||
pd.DataFrame(X_dst, columns=spec_col_names, index=df.index)
|
||
], axis=1)
|
||
return df
|
||
|
||
|
||
def read_hdr_file(bil_path):
|
||
hdr_path = bil_path.replace('.bil', '.hdr')
|
||
with open(hdr_path, 'r') as f:
|
||
header = f.readlines()
|
||
|
||
samples, lines = None, None
|
||
|
||
for line in header:
|
||
if line.startswith('samples'):
|
||
samples = int(line.split('=')[-1].strip())
|
||
if line.startswith('lines'):
|
||
lines = int(line.split('=')[-1].strip())
|
||
|
||
return samples, lines
|
||
|
||
|
||
def save_envi_classification(bil_path, df, savepath):
|
||
samples, lines = read_hdr_file(bil_path)
|
||
classification_result = np.zeros((lines, samples), dtype=np.uint16)
|
||
|
||
for _, row in df.iterrows():
|
||
contour = row['contour']
|
||
prediction = int(row['Predictions']) + 1
|
||
contour = np.array(contour, dtype=np.int32)
|
||
# 先将 classification_result 中的 10 和 11 替换为 0
|
||
classification_result[(classification_result == 10)] = 0
|
||
|
||
cv2.fillPoly(classification_result, [contour], prediction)
|
||
|
||
output_path = savepath
|
||
|
||
with open(output_path, 'wb') as f:
|
||
classification_result.tofile(f)
|
||
|
||
header_content = f"""ENVI
|
||
description = {{
|
||
Classification Result.}}
|
||
samples = {samples}
|
||
lines = {lines}
|
||
bands = 1
|
||
header offset = 0
|
||
file type = ENVI Standard
|
||
data type = 2
|
||
interleave = bil
|
||
classes = 11
|
||
class = {{ background, ABS, HDPE, LDPE, PA6, PET, PP, PS, PTFE, PVC,background2 }}
|
||
single pixel area = 0.000036
|
||
unit = mm2
|
||
byte order = 0
|
||
wavelength units = nm
|
||
"""
|
||
filename, ext = os.path.splitext(savepath)
|
||
# 替换扩展名为 '.hdr'
|
||
header_filename = filename + '.hdr'
|
||
|
||
with open(header_filename, 'w') as header_file:
|
||
header_file.write(header_content)
|
||
|
||
|
||
def change_hdr_file(bil_path):
|
||
# 定义要追加的波长信息
|
||
wavelength_info = """wavelength = {898.82, 903.64, 908.46, 913.28, 918.1, 922.92, 927.75, 932.57, 937.4, 942.22, 947.05, 951.88, 956.71, 961.54, 966.38, 971.21, 976.05, 980.88, 985.72, 990.56, 995.4, 1000.2, 1005.1, 1009.9, 1014.8, 1019.6, 1024.5, 1029.3, 1034.2, 1039, 1043.9, 1048.7, 1053.6, 1058.4, 1063.3, 1068.2, 1073, 1077.9, 1082.7, 1087.6, 1092.5, 1097.3, 1102.2, 1107.1, 1111.9, 1116.8, 1121.7, 1126.6, 1131.4, 1136.3, 1141.2, 1146.1, 1150.9, 1155.8, 1160.7, 1165.6, 1170.5, 1175.4, 1180.2, 1185.1, 1190, 1194.9, 1199.8, 1204.7, 1209.6, 1214.5, 1219.4, 1224.3, 1229.2, 1234.1, 1239, 1243.9, 1248.8, 1253.7, 1258.6, 1263.5, 1268.4, 1273.3, 1278.2, 1283.1, 1288.1, 1293, 1297.9, 1302.8, 1307.7, 1312.6, 1317.6, 1322.5, 1327.4, 1332.3, 1337.3, 1342.2, 1347.1, 1352, 1357, 1361.9, 1366.8, 1371.8, 1376.7, 1381.6, 1386.6, 1391.5, 1396.5, 1401.4, 1406.3, 1411.3, 1416.2, 1421.2, 1426.1, 1431.1, 1436, 1441, 1445.9, 1450.9, 1455.8, 1460.8, 1465.8, 1470.7, 1475.7, 1480.6, 1485.6, 1490.6, 1495.5, 1500.5, 1505.5, 1510.4, 1515.4, 1520.4, 1525.3, 1530.3, 1535.3, 1540.3, 1545.2, 1550.2, 1555.2, 1560.2, 1565.2, 1570.1, 1575.1, 1580.1, 1585.1, 1590.1, 1595.1, 1600.1, 1605.1, 1610, 1615, 1620, 1625, 1630, 1635, 1640, 1645, 1650, 1655, 1660, 1665, 1670.1, 1675.1, 1680.1, 1685.1, 1690.1, 1695.1, 1700.1, 1705.1, 1710.2, 1715.2, 1720.2}"""
|
||
|
||
# 将.bil路径转换为.hdr路径
|
||
hdr_path = os.path.splitext(bil_path)[0] + '.hdr'
|
||
|
||
# 检查.hdr文件是否存在
|
||
if not os.path.exists(hdr_path):
|
||
print(f"错误: 找不到对应的HDR文件: {hdr_path}")
|
||
return
|
||
|
||
# 读取文件内容
|
||
with open(hdr_path, 'r') as file:
|
||
content = file.read()
|
||
|
||
# 检查是否已包含波长信息
|
||
if 'wavelength' in content:
|
||
print(f"文件 {os.path.basename(hdr_path)} 已包含波长信息,无需修改。")
|
||
return
|
||
|
||
# 检查文件是否以换行符结尾
|
||
needs_newline = not content.endswith('\n')
|
||
|
||
# 追加波长信息
|
||
with open(hdr_path, 'a') as file:
|
||
if needs_newline:
|
||
file.write('\n') # 确保新内容从新行开始
|
||
file.write(wavelength_info + '\n')
|
||
|
||
print(f"已成功添加波长信息到文件: {os.path.basename(hdr_path)}")
|
||
|
||
|
||
def generate_new_mask(filter_mask_original, mask, num_masks=50, bil_path=None):
|
||
# 根据滤纸掩膜和微塑料掩膜生成新的掩膜,在滤纸掩膜内塑料掩膜外,随机位置生成大小为35*35大小的掩膜,数量为50个
|
||
# filter_mask_original为滤纸掩膜,mask为微塑料掩膜
|
||
# filter_mask_original为二值图像,mask为16位的塑料标签掩膜
|
||
# 生成新的掩膜,在滤纸掩膜内,塑料掩膜外,随机位置生成大小为35*35大小的掩膜,数量为50个
|
||
|
||
# 确保mask是二值图像(非零值表示塑料区域)
|
||
mask_binary = (mask > 0).astype(np.uint8)
|
||
|
||
# 找到滤纸掩膜内且塑料掩膜外的区域(滤纸为真,塑料为假)
|
||
# filter_mask_original应该是二值图像,非零表示滤纸区域
|
||
filter_mask_binary = (filter_mask_original > 0).astype(np.uint8)
|
||
valid_region = filter_mask_binary & (~mask_binary)
|
||
|
||
# 获取图像尺寸
|
||
height, width = valid_region.shape
|
||
mask_size = 35
|
||
half_size = mask_size // 2
|
||
|
||
# 初始化新的掩膜数组
|
||
new_mask_array = np.zeros((height, width), dtype=np.uint16)
|
||
|
||
# 找到所有可以放置35x35掩膜的有效中心点
|
||
# 确保掩膜完全在有效区域内
|
||
valid_centers = []
|
||
for y in range(half_size, height - half_size):
|
||
for x in range(half_size, width - half_size):
|
||
# 检查以(x, y)为中心,大小为35x35的区域是否完全在有效区域内
|
||
y_start, y_end = y - half_size, y + half_size + 1
|
||
x_start, x_end = x - half_size, x + half_size + 1
|
||
region = valid_region[y_start:y_end, x_start:x_end]
|
||
if np.all(region > 0): # 整个35x35区域都在有效区域内
|
||
valid_centers.append((y, x))
|
||
|
||
# 如果有效中心点不足50个,则使用所有可用的中心点
|
||
num_masks = min(num_masks, len(valid_centers))
|
||
|
||
if num_masks == 0:
|
||
print("Warning: No valid positions found for generating masks.")
|
||
return new_mask_array
|
||
|
||
# 随机选择50个(或更少)中心点
|
||
if len(valid_centers) > num_masks:
|
||
selected_centers = np.random.choice(len(valid_centers), size=num_masks, replace=False)
|
||
selected_centers = [valid_centers[i] for i in selected_centers]
|
||
else:
|
||
selected_centers = valid_centers
|
||
|
||
# 在每个选定的中心点生成35x35的掩膜
|
||
mask_value = 1 # 可以设置为不同的值来区分不同的掩膜
|
||
for y, x in selected_centers:
|
||
y_start, y_end = y - half_size, y + half_size + 1
|
||
x_start, x_end = x - half_size, x + half_size + 1
|
||
new_mask_array[y_start:y_end, x_start:x_end] = mask_value
|
||
mask_value += 1
|
||
print(f"Generated {len(selected_centers)} masks of size {mask_size}x{mask_size}")
|
||
|
||
# 保存滤纸掩膜,塑料掩膜,以及新生成的掩膜为不同颜色,保存至同一个图片上,保存至bil_path的masks文件夹下
|
||
if bil_path is not None:
|
||
# 确保输出目录存在
|
||
os.makedirs(os.path.join(os.path.dirname(bil_path), 'masks'), exist_ok=True)
|
||
|
||
# 创建RGB图像用于可视化(黑色背景)
|
||
height, width = filter_mask_original.shape
|
||
combined_visualization = np.zeros((height, width, 3), dtype=np.uint8)
|
||
|
||
# 滤纸掩膜用蓝色表示
|
||
combined_visualization[:, :, 2] = filter_mask_binary * 100 # B通道
|
||
|
||
# 塑料掩膜用红色表示
|
||
combined_visualization[:, :, 0] = mask_binary * 255 # R通道
|
||
|
||
# 新生成的掩膜用绿色表示
|
||
new_mask_binary = (new_mask_array > 0).astype(np.uint8)
|
||
combined_visualization[:, :, 1] = np.maximum(combined_visualization[:, :, 1], new_mask_binary * 255) # G通道
|
||
|
||
# 获取文件名(不含扩展名)
|
||
filename = os.path.splitext(os.path.basename(bil_path))[0]
|
||
output_path = os.path.join(os.path.join(os.path.dirname(bil_path), 'masks'),
|
||
f"{filename}_mask_visualization.png")
|
||
|
||
# 保存图像
|
||
cv2.imwrite(output_path, combined_visualization)
|
||
print(f"Saved mask visualization to: {output_path}")
|
||
|
||
# 合并掩膜:将新生成的掩膜和原塑料掩膜合并
|
||
# 新掩膜使用不同的标签值,避免与原掩膜冲突
|
||
combined_mask = mask.copy().astype(np.uint16)
|
||
# 将新掩膜添加到合并掩膜中(使用较大的标签值,如1000+)
|
||
combined_mask[new_mask_array > 0] = new_mask_array[new_mask_array > 0] + 1000
|
||
|
||
return new_mask_array
|
||
|
||
|
||
def process_single_bil(bil_path, num_masks=50, rng_seed=None):
|
||
"""
|
||
处理单个BIL文件,生成滤纸背景样本的光谱特征行,并返回DataFrame
|
||
"""
|
||
try:
|
||
print(f"\n{'=' * 60}")
|
||
print(f"Processing: {os.path.basename(bil_path)}")
|
||
print(f"{'=' * 60}")
|
||
|
||
# 处理BIL文件生成RGB图像
|
||
print("Processing BIL file to generate RGB image...")
|
||
rgb_img = process_bil_files(bil_path)
|
||
|
||
# HDR:仅在缺失 wavelength 时补齐;并尽量对齐训练相机波长(与 main.py 一致)
|
||
change_hdr_file(bil_path, TRAIN_WAVELENGTHS)
|
||
|
||
# 生成掩膜:返回塑料掩膜 + 滤纸掩膜
|
||
print("Generating mask...")
|
||
mask, filter_mask_original = detect_microplastic_mask_from_array(
|
||
image=rgb_img,
|
||
filter_method='threshold',
|
||
diameter=None,
|
||
flow_threshold=0.4,
|
||
cellprob_threshold=-1,
|
||
model_path=None,
|
||
detect_filter=True
|
||
)
|
||
|
||
# 生成新的随机背景小块掩膜(在滤纸内且不与塑料重叠)
|
||
if rng_seed is not None:
|
||
np.random.seed(int(rng_seed))
|
||
new_mask_array = generate_new_mask(filter_mask_original, mask, num_masks=num_masks, bil_path=bil_path)
|
||
|
||
# 提取光谱与形状特征(仅限新背景小块)
|
||
print("Extracting features from BIL file...")
|
||
pcv.observations = {} # 清理plantcv状态
|
||
df = process_images(bil_path, new_mask_array)
|
||
|
||
# 背景校正(用整图的滤纸背景光谱作为除数)+ 可选重采样到训练相机波长
|
||
print("Applying background correction (+ optional resample)...")
|
||
bg_spectrum = process_images_background(bil_path, mask)
|
||
df = apply_background_and_optional_resample_for_samples(df, bg_spectrum, bil_path)
|
||
|
||
# 数据清理:去NA、轮廓点数不足、面积过小过滤(与 main.py 对齐)
|
||
print("Cleaning data...")
|
||
df = df.dropna()
|
||
df = df[df['contour'].apply(lambda x: len(x) > 1 if isinstance(x, list) else True)]
|
||
df = df[df['area'] >= 500]
|
||
|
||
# 添加文件名列(不含扩展名)
|
||
filename = os.path.splitext(os.path.basename(bil_path))[0]
|
||
df.insert(0, 'filename', filename)
|
||
|
||
print(f"Extracted {len(df)} background objects from {os.path.basename(bil_path)}")
|
||
return df
|
||
|
||
except Exception as e:
|
||
print(f"Error processing {bil_path}: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return None
|
||
|
||
|
||
def main():
|
||
# 支持文件或目录;新增可调样本数量与随机种子,便于复现
|
||
bil_path_or_folder = r"D:\Data\Traindata-11"
|
||
output_csv_path = r"E:\plastic\plastic\output\滤纸样本光谱\11.csv"
|
||
num_masks = 50
|
||
rng_seed = 42
|
||
|
||
os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)
|
||
|
||
if os.path.isfile(bil_path_or_folder):
|
||
bil_files = [bil_path_or_folder]
|
||
elif os.path.isdir(bil_path_or_folder):
|
||
bil_files = [os.path.join(bil_path_or_folder, f) for f in os.listdir(bil_path_or_folder) if f.endswith('.bil')]
|
||
print(f"Found {len(bil_files)} BIL files to process")
|
||
else:
|
||
print(f"Error: {bil_path_or_folder} is not a valid file or directory")
|
||
return
|
||
|
||
is_first_row = True
|
||
total_objects = 0
|
||
|
||
for i, bil_path in enumerate(bil_files, 1):
|
||
print(f"\n[{i}/{len(bil_files)}] Processing file...")
|
||
df = process_single_bil(bil_path, num_masks=num_masks, rng_seed=rng_seed)
|
||
|
||
if df is not None and len(df) > 0:
|
||
df.to_csv(
|
||
output_csv_path,
|
||
mode='a' if not is_first_row else 'w',
|
||
index=False,
|
||
header=is_first_row
|
||
)
|
||
total_objects += len(df)
|
||
is_first_row = False
|
||
print(f" -> {len(df)} objects appended to CSV file")
|
||
|
||
if total_objects > 0:
|
||
print(f"\nSummary:")
|
||
print(f" Total files processed: {len(bil_files)}")
|
||
print(f" Total background objects collected: {total_objects}")
|
||
print(f" Output file: {output_csv_path}")
|
||
else:
|
||
print("\nNo results to save.")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main() |