Files
HSI/classfication_method/bil2png.py

464 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import glob
class BilToPngConverter:
    """
    Convert ENVI raster files (BIL/BSQ/BIP interleave) into PNG images.
    """

    # Integer-valued ENVI header keys this converter reads, mapped to the
    # keys used in the dict returned by _parse_envi_header.
    _INT_HEADER_KEYS = {
        'samples': 'cols',
        'lines': 'rows',
        'bands': 'bands',
        'data type': 'data_type',
        'byte order': 'byte_order',
        'header offset': 'header_offset',
    }

    def __init__(self, bil_folder, output_folder=None):
        """
        Set up the converter and create the output folder.

        Args:
            bil_folder: folder containing the .bil data files.
            output_folder: folder for the PNG output; when None, a
                'png_output' sub-folder of bil_folder is used.
        """
        self.bil_folder = Path(bil_folder)
        if output_folder is None:
            self.output_folder = self.bil_folder / 'png_output'
        else:
            self.output_folder = Path(output_folder)
        self.output_folder.mkdir(parents=True, exist_ok=True)

    def find_bil_files(self):
        """
        Return all ``.bil`` files (any letter case) directly inside bil_folder.

        Directories and hidden files (names starting with '.') are skipped,
        matching what glob('*.bil') would return; the list is sorted by path.
        """
        if self.bil_folder.is_dir():
            bil_files = sorted(
                (p for p in self.bil_folder.iterdir()
                 if p.is_file()
                 and p.suffix.lower() == '.bil'
                 and not p.name.startswith('.')),
                key=str,
            )
        else:
            bil_files = []
        print(f"找到 {len(bil_files)} 个BIL文件")
        return bil_files

    def _parse_envi_header(self, header_path):
        """
        Parse an ENVI header file into a dict of the fields used here.

        Keys: rows, cols, bands, data_type (None when absent), interleave
        (default 'bsq'), byte_order (default 0), header_offset (default 0).
        Multi-line ``{...}`` values (description, wavelength, ...) are skipped
        so their contents cannot be mistaken for top-level keys. A malformed
        integer value is reported and skipped; the rest is still parsed.
        """
        info = {
            'rows': None, 'cols': None, 'bands': None, 'data_type': None,
            'interleave': 'bsq', 'byte_order': 0, 'header_offset': 0,
        }
        try:
            # errors='replace': non-UTF-8 descriptions must not lose the sizes.
            with open(header_path, 'r', encoding='utf-8', errors='replace') as f:
                content = f.read()
        except OSError as e:
            print(f"读取头文件时出错: {e}")
            return info

        in_braces = False
        for raw_line in content.splitlines():
            line = raw_line.strip()
            if in_braces:
                if '}' in line:
                    in_braces = False
                continue
            if '=' not in line:
                continue
            key, value = line.split('=', 1)
            key = key.strip().lower()
            value = value.strip()
            if value.startswith('{') and '}' not in value:
                in_braces = True
                continue
            if key == 'interleave':
                info['interleave'] = value.lower()  # bsq, bil, bip
            elif key in self._INT_HEADER_KEYS:
                try:
                    info[self._INT_HEADER_KEYS[key]] = int(value)
                except ValueError as e:
                    print(f"读取头文件时出错: {e}")

        print(f"从头文件读取: {info['rows']}x{info['cols']}x{info['bands']}, "
              f"interleave={info['interleave']}, data_type={info['data_type']}")
        return info

    def read_envi_header(self, header_path):
        """
        Read an ENVI header and return its size/format fields.

        Returns:
            (rows, cols, bands, data_type, interleave, byte_order); the first
            four are None when missing from the header.
        """
        info = self._parse_envi_header(header_path)
        return (info['rows'], info['cols'], info['bands'],
                info['data_type'], info['interleave'], info['byte_order'])

    def get_numpy_dtype(self, envi_data_type):
        """Map an ENVI 'data type' code to a numpy dtype (float32 if unknown)."""
        dtype_map = {
            1: np.uint8,     # 8-bit byte
            2: np.int16,     # 16-bit signed integer
            3: np.int32,     # 32-bit signed long integer
            4: np.float32,   # 32-bit floating point
            5: np.float64,   # 64-bit double precision floating point
            12: np.uint16,   # 16-bit unsigned integer
            13: np.uint32,   # 32-bit unsigned long integer
            14: np.int64,    # 64-bit signed long integer
            15: np.uint64,   # 64-bit unsigned long integer
        }
        if envi_data_type not in dtype_map:
            print(f"警告: 不支持的ENVI数据类型 {envi_data_type},按float32读取")
        return dtype_map.get(envi_data_type, np.float32)

    @staticmethod
    def _find_header(bil_path):
        """Return the header next to bil_path (name.hdr, name.HDR or name.bil.hdr), or None."""
        candidates = (
            bil_path.with_suffix('.hdr'),
            bil_path.with_suffix('.HDR'),
            bil_path.with_name(bil_path.name + '.hdr'),
        )
        for candidate in candidates:
            if candidate.exists():
                return candidate
        return None

    def read_envi_file(self, bil_path, rows=None, cols=None, bands=None):
        """
        Read an ENVI raster (BSQ, BIL or BIP interleave).

        Sizes, data type, interleave, byte order and header offset come from
        the matching ENVI header when one exists; rows/cols/bands given here
        only fill in values the header lacks. Without a header, rows/cols/bands
        are required and the data is read as little-endian float32 BSQ.

        Args:
            bil_path: path to the data file (str or Path).
            rows: image height in lines.
            cols: image width in samples.
            bands: number of bands.

        Returns:
            numpy array of shape (rows, cols, bands) in native byte order.

        Raises:
            ValueError: when the image size cannot be determined, or the file
                content does not match the size (re-raised from reshape).
        """
        bil_path = Path(bil_path)
        header_path = self._find_header(bil_path)
        if header_path is not None:
            print(f"正在读取头文件: {header_path}")
            info = self._parse_envi_header(header_path)
            rows = info['rows'] if info['rows'] is not None else rows
            cols = info['cols'] if info['cols'] is not None else cols
            bands = info['bands'] if info['bands'] is not None else bands
            if rows is None or cols is None or bands is None:
                raise ValueError(f"头文件 {header_path} 缺少尺寸信息,请手动指定rows, cols, bands参数")
            data_type = info['data_type'] if info['data_type'] is not None else 4  # default float32
            dtype = np.dtype(self.get_numpy_dtype(data_type))
            interleave = info['interleave']
            byte_order = info['byte_order']
            header_offset = info['header_offset']
        else:
            if rows is None or cols is None or bands is None:
                raise ValueError(f"找不到头文件 {bil_path.with_suffix('.hdr')}请手动指定rows, cols, bands参数")
            print(f"未找到头文件,使用手动指定的尺寸: {rows}x{cols}x{bands}")
            dtype = np.dtype(np.float32)
            interleave = 'bsq'
            byte_order = 0
            header_offset = 0

        # ENVI 'byte order': 0 = little-endian (LSB first), 1 = big-endian.
        dtype = dtype.newbyteorder('>' if byte_order == 1 else '<')

        expected_size = header_offset + rows * cols * bands * dtype.itemsize
        actual_size = bil_path.stat().st_size
        if expected_size != actual_size:
            print(f"警告: 文件大小不匹配。预期: {expected_size}字节,实际: {actual_size}字节")
            print("这可能是由于数据压缩或头部信息导致")

        print(f"正在读取数据文件: {bil_path.name}, 尺寸: {rows}x{cols}x{bands}, 格式: {interleave}")
        raw_data = None  # so the error report below never hits an unbound name
        try:
            raw_data = np.fromfile(bil_path, dtype=dtype, offset=header_offset)
            if interleave == 'bil':
                # Line-interleaved: each line stores all bands in turn.
                data = raw_data.reshape(rows, bands, cols).transpose(0, 2, 1)
            elif interleave == 'bip':
                # Pixel-interleaved: already (rows, cols, bands).
                data = raw_data.reshape(rows, cols, bands)
            else:
                if interleave != 'bsq':
                    print(f"未知的交织格式: {interleave}尝试BSQ格式")
                # Band-sequential: whole band planes one after another.
                data = raw_data.reshape(bands, rows, cols).transpose(1, 2, 0)
            # Downstream numpy/matplotlib code expects native byte order.
            data = data.astype(dtype.newbyteorder('='), copy=False)
            print(f"成功读取,数据形状: {data.shape}, 数据类型: {data.dtype}")
            return data
        except Exception as e:
            print(f"读取数据文件时出错: {e}")
            if raw_data is not None:
                print(f"原始数据大小: {len(raw_data)}, 期望大小: {rows * cols * bands}")
            raise

    @staticmethod
    def _normalize_band(band_data, percentile_low, percentile_high):
        """
        Percentile-stretch one band to [0, 1] as float64, ignoring NaNs.

        When the percentile window collapses (e.g. a mostly-constant class
        map), falls back to a min/max stretch. NaN pixels stay NaN.
        """
        # Work in float so integer inputs can be divided safely.
        band = np.asarray(band_data, dtype=np.float64)
        low = np.nanpercentile(band, percentile_low)
        high = np.nanpercentile(band, percentile_high)
        if high > low:
            return (np.clip(band, low, high) - low) / (high - low)
        stretched = band - np.nanmin(band)
        peak = np.nanmax(stretched)
        if peak > 0:
            stretched = stretched / peak
        return stretched

    def normalize_image(self, image_data, percentile_low=2, percentile_high=98):
        """
        Normalize image data to the 0-1 range with a percentile stretch.

        Multi-band (rows, cols, bands) input is stretched band by band and
        returned as float32 of the same shape; 2-D or single-band input is
        returned as a 2-D float array.

        Args:
            image_data: 2-D array or (rows, cols, bands) array.
            percentile_low: lower clipping percentile.
            percentile_high: upper clipping percentile.

        Returns:
            Array clipped to [0, 1] (NaN pixels remain NaN).
        """
        data = np.asarray(image_data)
        if data.ndim == 3 and data.shape[2] > 1:
            normalized = np.empty(data.shape, dtype=np.float32)
            for band_index in range(data.shape[2]):
                normalized[:, :, band_index] = self._normalize_band(
                    data[:, :, band_index], percentile_low, percentile_high)
        else:
            if data.ndim == 3:
                # Drop only the band axis so 1-row/1-column images stay 2-D.
                data = data[:, :, 0]
            normalized = self._normalize_band(data, percentile_low, percentile_high)
        return np.clip(normalized, 0, 1)

    def apply_colormap(self, image_data, colormap='viridis'):
        """
        Map normalized (0-1) values to RGB with a Matplotlib colormap.

        Args:
            image_data: normalized 2-D array.
            colormap: colormap name (or Colormap instance).

        Returns:
            float RGB array with a trailing channel axis of size 3.
        """
        # pyplot.get_cmap works on all Matplotlib versions; matplotlib.cm.get_cmap
        # (reached via plt.cm) was removed in Matplotlib 3.9.
        cmap = plt.get_cmap(colormap)
        colored = cmap(np.asarray(image_data))
        if colored.shape[-1] == 4:
            colored = colored[..., :3]  # drop alpha
        return colored

    def extract_title_from_filename(self, filename):
        """
        Return the second '_'-separated field of the file stem, used as title.

        Falls back to the whole stem when it contains no underscore.
        """
        stem = Path(filename).stem
        parts = stem.split('_')
        if len(parts) >= 2:
            return parts[1]
        return stem

    def envi_to_png(self, bil_path, output_name=None,
                    rows=None, cols=None, bands=None,
                    colormap='viridis', dpi=150):
        """
        Convert one ENVI file to a PNG in output_folder.

        Images with 3+ bands are shown as an RGB composite of bands 1-3;
        1- or 2-band images show band 1 through `colormap`.

        Args:
            bil_path: data file path (str or Path).
            output_name: PNG file name; defaults to the data file stem. '.png'
                is appended when missing.
            rows, cols, bands: sizes used when no header provides them.
            colormap: colormap for single-band display.
            dpi: resolution of the saved PNG.

        Returns:
            Path of the written PNG, or None when conversion failed (the
            error is printed).
        """
        bil_path = Path(bil_path)
        fig = None
        try:
            image_data = self.read_envi_file(bil_path, rows, cols, bands)
            title = self.extract_title_from_filename(bil_path.name)

            n_bands = image_data.shape[2] if image_data.ndim == 3 else 1
            if n_bands >= 3:
                # Stretch bands 1-3 independently and use them as R, G, B.
                display_image = self.normalize_image(image_data[:, :, :3])
                print("使用RGB合成波段1,2,3")
            else:
                # One or two bands: a 2-band cube cannot be shown as RGB, so
                # display band 1 through the colormap.
                single_band = image_data[:, :, 0] if image_data.ndim == 3 else image_data
                display_image = self.apply_colormap(self.normalize_image(single_band), colormap)
                print("使用单波段 + 颜色映射")

            if output_name is None:
                output_name = bil_path.stem + '.png'
            elif not output_name.lower().endswith('.png'):
                output_name += '.png'
            output_path = self.output_folder / output_name

            fig = plt.figure(figsize=(10, 8), dpi=100)
            plt.imshow(display_image)
            plt.title(f'{title}', fontsize=16, fontweight='bold')
            plt.axis('off')
            plt.tight_layout()
            plt.savefig(output_path, dpi=dpi, bbox_inches='tight', pad_inches=0.1)
            print(f"已保存: {output_path}")
            return output_path
        except Exception as e:
            print(f"转换文件 {bil_path.name} 时出错: {e}")
            import traceback
            traceback.print_exc()
            return None
        finally:
            # Close the figure even if imshow/savefig raised, so batch runs
            # do not accumulate open figures.
            if fig is not None:
                plt.close(fig)

    def batch_convert(self, rows=None, cols=None, bands=None, colormap='viridis'):
        """
        Convert every .bil file in bil_folder to PNG.

        Args:
            rows, cols, bands: sizes used for files without a usable header.
            colormap: colormap for single-band images.

        Returns:
            List of PNG paths that were written successfully.
        """
        bil_files = self.find_bil_files()
        if not bil_files:
            print("未找到BIL文件")
            return []

        successful_conversions = []
        print(f"\n开始批量转换 {len(bil_files)} 个文件...")
        print("=" * 50)
        for i, bil_path in enumerate(bil_files, 1):
            print(f"\n处理文件 {i}/{len(bil_files)}: {bil_path.name}")
            print("-" * 30)
            output_path = self.envi_to_png(
                bil_path,
                output_name=f"{bil_path.stem}.png",
                rows=rows,
                cols=cols,
                bands=bands,
                colormap=colormap,
            )
            if output_path:
                successful_conversions.append(output_path)

        print("\n" + "=" * 50)
        print("批量转换完成!")
        print(f"成功转换: {len(successful_conversions)}/{len(bil_files)} 个文件")
        print(f"输出文件夹: {self.output_folder}")
        return successful_conversions
def main(bil_folder=None):
    """
    Convert every BIL file in a folder to PNG, prompting for sizes if needed.

    Args:
        bil_folder: folder holding the .bil (and optional .hdr) files.
            Defaults to the hard-coded folder used by the original script.
    """
    if bil_folder is None:
        bil_folder = r"E:\code\spectronon\single_classsfication\output"
    # isdir rather than exists: a plain file here would make the converter fail later.
    if not os.path.isdir(bil_folder):
        print(f"错误: 文件夹 '{bil_folder}' 不存在")
        return

    converter = BilToPngConverter(bil_folder)
    # Automatic pass: works whenever each .bil has an ENVI header next to it.
    # Without headers, sizes must be given explicitly, e.g.
    # converter.batch_convert(rows=512, cols=512, bands=1)
    successful = converter.batch_convert()
    if successful:
        return
    # Asking for image sizes is pointless when there is nothing to convert.
    if not converter.find_bil_files():
        return

    print("\n自动转换失败,可能需要手动指定图像尺寸。")
    try:
        rows = int(input("请输入图像行数(高度): "))
        cols = int(input("请输入图像列数(宽度): "))
        bands = int(input("请输入波段数单波段图像请输入1: "))
    except ValueError:
        print("请输入有效的数字")
        return
    except EOFError:
        # stdin is closed (non-interactive run): no sizes can be obtained.
        return
    converter.batch_convert(rows=rows, cols=cols, bands=bands)


if __name__ == "__main__":
    import sys
    # An optional first command-line argument overrides the default folder.
    main(sys.argv[1] if len(sys.argv) > 1 else None)