初始提交
This commit is contained in:
162
onlyspectral_background.py
Normal file
162
onlyspectral_background.py
Normal file
@ -0,0 +1,162 @@
|
||||
import os
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
def process_single_bil_background(full_bil_path, *, band_index=100,
                                  threshold=5500, scale=10000):
    """Extract the mean filter-paper (background) spectrum from one BIL file.

    A single band of the hyperspectral cube is thresholded to find bright
    pixels, the binary mask is cleaned with erosion and dilation, and the
    largest connected component of the cleaned mask is used as the
    filter-paper region. No externally supplied mask is required.

    Args:
        full_bil_path: Path of the ENVI-format BIL file to read.
        band_index: Index of the band used to build the mask (default 100).
        threshold: Threshold passed to ``pcv.threshold.binary`` with
            ``object_type='light'`` (default 5500).
        scale: Divisor applied to the mean spectrum as normalization
            (default 10000).

    Returns:
        A 1-D array with one value per band: the mean over the pixels of the
        largest connected region, divided by ``scale``. If no region is
        found, an all-zero array with one entry per band is returned.

    Raises:
        ImportError: If ``plantcv`` or ``cv2`` cannot be imported.
    """
    try:
        from plantcv import plantcv as pcv
        import cv2
    except ImportError as exc:
        # Chain the original error so the actually missing module stays visible.
        raise ImportError(
            "需要导入plantcv和cv2库:from plantcv import plantcv as pcv; import cv2"
        ) from exc

    # Global PlantCV settings. Debug output is switched off with the real
    # ``None`` value rather than the string "None".
    pcv.params.debug = None
    pcv.params.dpi = 100
    pcv.params.text_size = 1
    pcv.params.text_thickness = 1

    # Read the hyperspectral image (ENVI format) and pick the masking band.
    spectral_array = pcv.readimage(filename=full_bil_path, mode='envi')
    cube = spectral_array.array_data
    mask_band = cube[:, :, band_index]

    # Keep bright pixels only; these are treated as filter-paper candidates.
    paper_mask = pcv.threshold.binary(
        gray_img=mask_band, threshold=threshold, object_type='light'
    )

    # Erode then dilate with the same kernel and iteration count.
    paper_mask = pcv.erode(gray_img=paper_mask, ksize=3, i=4)
    paper_mask = pcv.dilate(gray_img=paper_mask, ksize=3, i=4)

    # NOTE(review): the original comment said "fill large holes", but
    # pcv.fill is documented as removing objects smaller than `size` pixels
    # -- confirm which behaviour is intended.
    paper_mask = pcv.fill(bin_img=paper_mask, size=10000)

    # Label connected regions; label 0 is the background of the mask.
    num_labels, labels = cv2.connectedComponents(paper_mask.astype(np.uint8))

    print(f" 检测到 {num_labels - 1} 个滤纸区域")

    if num_labels <= 1:
        print(" 警告: 未找到有效的滤纸区域")
        return np.zeros(cube.shape[-1])

    # Count the pixels of every label in one pass instead of rescanning the
    # label image once per label. argmax returns the first maximum, which
    # matches the "strictly greater" selection of a sequential scan.
    areas = np.bincount(labels.ravel(), minlength=num_labels)
    largest_label = 1 + int(np.argmax(areas[1:]))

    # Average the spectra of all pixels in the largest region, then normalize.
    region_spectra = cube[labels == largest_label, :]
    return np.mean(region_spectra, axis=0) / scale
|
||||
|
||||
|
||||
def batch_process_bil_background(bil_folder, output_csv_path):
    """Extract the background spectrum of every BIL file in a folder into a CSV.

    Each ``.bil`` file (extension matched case-insensitively) directly inside
    ``bil_folder`` is passed to ``process_single_bil_background``. Spectra
    that are non-empty and have a positive sum are collected into a
    wide-format table: one row per file, one ``band_<k>`` column per band,
    plus a leading ``filename`` column holding the file stem.

    Args:
        bil_folder: Folder that contains the ``.bil`` files.
        output_csv_path: Destination CSV path; missing parent folders are
            created.

    Returns:
        The DataFrame that was written to the CSV, or ``None`` when no BIL
        file was found or no file produced a valid spectrum.

    Raises:
        ValueError: If ``bil_folder`` does not exist.
    """
    import traceback

    bil_folder = Path(bil_folder)
    if not bil_folder.exists():
        raise ValueError(f"文件夹不存在: {bil_folder}")

    # Match the extension case-insensitively in a single pass. Globbing
    # '*.bil' and '*.BIL' separately lists every file twice where pattern
    # matching ignores case (Windows) and misses mixed-case names such as
    # '.Bil' elsewhere. Sorting makes the row order deterministic.
    bil_files = sorted(
        path for path in bil_folder.glob('*') if path.suffix.lower() == '.bil'
    )

    if not bil_files:
        print(f"警告: 在 {bil_folder} 中未找到任何.bil文件")
        return None

    print(f"找到 {len(bil_files)} 个BIL文件")

    # Make sure the output directory exists before the long-running loop.
    output_path = Path(output_csv_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    all_spectra = []
    filenames = []
    failed_files = []

    for i, bil_path in enumerate(bil_files, 1):
        print(f"\n[{i}/{len(bil_files)}] 处理文件: {bil_path.name}")
        print(f"{'=' * 60}")

        try:
            spectrum = process_single_bil_background(str(bil_path))

            # An all-zero spectrum signals that no region was found.
            if spectrum is not None and len(spectrum) > 0 and np.sum(spectrum) > 0:
                all_spectra.append(spectrum)
                filenames.append(bil_path.stem)
                print(f" ✓ 成功提取光谱 ({len(spectrum)} 个波段)")
            else:
                failed_files.append(bil_path.name)
                print(" ✗ 未能提取有效光谱")

        except Exception as e:
            # Best effort: one bad file must not abort the whole batch.
            failed_files.append(bil_path.name)
            print(f" ✗ 处理失败: {str(e)}")
            traceback.print_exc()
            continue

    if not all_spectra:
        print("\n没有有效的结果可以保存")
        return None

    # Wide format: one row per file, one column per band.
    n_bands = len(all_spectra[0])
    column_names = [f'band_{band}' for band in range(n_bands)]

    df = pd.DataFrame(all_spectra, columns=column_names, index=filenames)
    df.index.name = 'filename'
    df = df.reset_index()

    df.to_csv(output_csv_path, index=False)

    # Summary of the run.
    print(f"\n{'=' * 60}")
    print("批量处理完成!")
    print(f" 总共文件数: {len(bil_files)}")
    print(f" 成功处理: {len(all_spectra)}")
    print(f" 失败文件: {len(failed_files)}")
    if failed_files:
        print(f" 失败列表: {', '.join(failed_files)}")
    print(f" 光谱波段数: {n_bands}")
    print(f" 输出文件: {output_csv_path}")
    print(f"{'=' * 60}")

    return df
|
||||
|
||||
|
||||
|
||||
# Default input folder and output CSV used when this file is run as a script.
bil_folder = r"D:\Data\Traindata-09\whitebord"
output_csv = r"D:\Data\Traindata-09\whitebord\filter_paper_spectra.csv"

if __name__ == "__main__":
    # Run the batch job only when executed directly. Without this guard,
    # importing the module starts the batch and raises ValueError wherever
    # the hard-coded folder does not exist.
    batch_process_bil_background(bil_folder, output_csv)
|
||||
Reference in New Issue
Block a user