1300 lines
53 KiB
Python
1300 lines
53 KiB
Python
"""
|
||
光谱到颜色转换工具
|
||
|
||
功能:
|
||
- 将高光谱反射率数据转换为 CIE XYZ 三刺激值
|
||
- 支持多种标准光源(D65、D50、A、F系列等)
|
||
- 支持不同观察者角度(CIE 1931 2° 和 CIE 1964 10°)
|
||
- 输入格式:高光谱图片(ENVI格式)或 CSV 文件
|
||
- 输出格式:dat文件图像或CSV文件
|
||
|
||
依赖库:
|
||
- colour-science: 颜色科学计算
|
||
- spectral: ENVI文件处理
|
||
- numpy, pandas: 数据处理
|
||
"""
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import os
|
||
from pathlib import Path
|
||
from typing import Optional, Tuple, Union, List
|
||
import warnings
|
||
import traceback
|
||
from scipy.interpolate import interp1d
|
||
|
||
# 尝试导入 colour 库
|
||
try:
|
||
import colour
|
||
COLOUR_AVAILABLE = True
|
||
print(f"Colour 库版本: {colour.__version__}")
|
||
|
||
# 新版本的 colour 库使用这些导入方式
|
||
from colour import MSDS_CMFS, SDS_ILLUMINANTS, CCS_ILLUMINANTS
|
||
from colour import SpectralDistribution, SpectralShape
|
||
from colour.colorimetry import sd_to_XYZ
|
||
from colour.models import XYZ_to_sRGB, XYZ_to_Lab
|
||
|
||
# 调试:查看可用的 CMFS
|
||
print("可用的 CMFS 键:", list(MSDS_CMFS.keys())[:5], "...")
|
||
print("可用的光源键:", list(SDS_ILLUMINANTS.keys())[:5], "...")
|
||
|
||
except ImportError as e:
|
||
COLOUR_AVAILABLE = False
|
||
print(f"导入 colour 库时出错: {e}")
|
||
print("警告: colour-science库不可用,请安装: pip install colour-science")
|
||
# 创建空对象以避免后续错误
|
||
MSDS_CMFS = None
|
||
SDS_ILLUMINANTS = None
|
||
SpectralDistribution = None
|
||
sd_to_XYZ = None
|
||
XYZ_to_sRGB = None
|
||
SpectralShape = None
|
||
|
||
# 尝试导入 spectral 库
|
||
try:
|
||
import spectral
|
||
SPECTRAL_AVAILABLE = True
|
||
except ImportError:
|
||
SPECTRAL_AVAILABLE = False
|
||
print("警告: spectral库不可用,将无法处理ENVI格式文件")
|
||
|
||
class SpectralToColorConverter:
|
||
"""
|
||
光谱到颜色转换器
|
||
|
||
支持将高光谱反射率数据转换为多种颜色空间:CIE XYZ、xyY、L*a*b*、L*C*h*
|
||
"""
|
||
|
||
# 标准光源定义 - 映射到 colour 库中的键
|
||
ILLUMINANTS = {
|
||
'D65': 'D65', # 日光(6500K)
|
||
'D50': 'D50', # 日光(5000K)
|
||
'A': 'A', # 白炽灯(2856K)
|
||
'F2': 'F2', # 荧光灯(冷白)
|
||
'F7': 'F7', # 荧光灯(宽带日光)
|
||
'F11': 'F11', # 荧光灯(三色荧光)
|
||
}
|
||
|
||
# 观察者定义 - 映射到 colour 库中的键
|
||
OBSERVERS = {
|
||
'2°': 'CIE 1931 2 Degree Standard Observer',
|
||
'10°': 'CIE 1964 10 Degree Standard Observer'
|
||
}
|
||
|
||
# 输出颜色空间定义
|
||
COLOR_SPACES = {
|
||
'XYZ': 'CIE XYZ',
|
||
'xyY': 'CIE xyY',
|
||
'Lab': 'CIE L*a*b*',
|
||
'LCH': 'CIE L*C*h*'
|
||
}
|
||
|
||
def __init__(self, illuminant: str = 'D65', observer: str = '2°',
|
||
color_space: str = 'XYZ', normalize_xyz: bool = True, use_parallel: bool = True, n_jobs: int = -1):
|
||
"""
|
||
初始化转换器
|
||
|
||
参数:
|
||
illuminant: 光源类型 ('D65', 'D50', 'A', 'F2', 'F7', 'F11')
|
||
observer: 观察者角度 ('2°', '10°')
|
||
color_space: 输出颜色空间 ('XYZ', 'xyY', 'Lab', 'LCH')
|
||
normalize_xyz: 当color_space='XYZ'时,是否归一化到0-1范围 (默认: True)
|
||
use_parallel: 是否使用并行处理(对于大图像)
|
||
n_jobs: 并行作业数 (-1表示使用所有可用CPU核心)
|
||
"""
|
||
if not COLOUR_AVAILABLE:
|
||
raise ImportError("需要安装colour-science库: pip install colour-science")
|
||
|
||
if illuminant not in self.ILLUMINANTS:
|
||
raise ValueError(f"不支持的光源类型: {illuminant}. 支持的类型: {list(self.ILLUMINANTS.keys())}")
|
||
|
||
if observer not in self.OBSERVERS:
|
||
raise ValueError(f"不支持的观察者角度: {observer}. 支持的角度: {list(self.OBSERVERS.keys())}")
|
||
|
||
if color_space not in self.COLOR_SPACES:
|
||
raise ValueError(f"不支持的颜色空间: {color_space}. 支持的类型: {list(self.COLOR_SPACES.keys())}")
|
||
|
||
self.illuminant = illuminant
|
||
self.observer = observer
|
||
self.color_space = color_space
|
||
self.normalize_xyz = normalize_xyz
|
||
self.use_parallel = use_parallel
|
||
self.n_jobs = n_jobs
|
||
|
||
# 初始化colour库的颜色空间
|
||
self._setup_color_space()
|
||
|
||
print(f"初始化完成 - 光源: {illuminant}, 观察者: {observer}, 颜色空间: {color_space}, XYZ归一化: {normalize_xyz}, 并行处理: {use_parallel}")
|
||
|
||
def _setup_color_space(self):
|
||
"""设置颜色空间参数"""
|
||
try:
|
||
# 获取标准观察者的颜色匹配函数 (CMFS)
|
||
observer_key = self.OBSERVERS[self.observer]
|
||
if observer_key not in MSDS_CMFS:
|
||
# 尝试其他可能的键
|
||
if self.observer == '2°':
|
||
observer_key = 'CIE 1931 2 Degree Standard Observer'
|
||
else:
|
||
observer_key = 'CIE 1964 10 Degree Standard Observer'
|
||
|
||
self.cmfs = MSDS_CMFS[observer_key]
|
||
print(f"加载观察者: {observer_key}")
|
||
|
||
# 获取光源的光谱功率分布 (SPD)
|
||
illuminant_key = self.ILLUMINANTS[self.illuminant]
|
||
self.illuminant_spd = SDS_ILLUMINANTS[illuminant_key]
|
||
print(f"加载光源: {illuminant_key}")
|
||
|
||
# 调试信息
|
||
print(f"CMFS 波长范围: {self.cmfs.wavelengths[0]} - {self.cmfs.wavelengths[-1]} nm")
|
||
print(f"光源波长范围: {self.illuminant_spd.wavelengths[0]} - {self.illuminant_spd.wavelengths[-1]} nm")
|
||
|
||
except KeyError as e:
|
||
print(f"可用的 CMFS 键: {list(MSDS_CMFS.keys())}")
|
||
print(f"可用的光源键: {list(SDS_ILLUMINANTS.keys())}")
|
||
raise ValueError(f"无法加载颜色空间参数: {e}。请检查 colour 库版本。")
|
||
|
||
def load_csv_data(self, file_path: Union[str, Path], wavelength_col: str = 'wavelength',
|
||
reflectance_start_col: Optional[str] = None) -> Tuple[np.ndarray, np.ndarray]:
|
||
"""
|
||
加载CSV格式的光谱数据
|
||
|
||
参数:
|
||
file_path: CSV文件路径
|
||
wavelength_col: 波长列名(默认第一列)
|
||
reflectance_start_col: 反射率数据起始列名(如果为None,则使用wavelength_col后的所有列)
|
||
|
||
返回:
|
||
wavelengths: 波长数组 (nm)
|
||
reflectances: 反射率数据数组 (samples x wavelengths)
|
||
"""
|
||
try:
|
||
df = pd.read_csv(file_path)
|
||
print(f"CSV 文件列名: {list(df.columns)}")
|
||
|
||
# 获取波长数据
|
||
if wavelength_col not in df.columns:
|
||
# 如果没有指定波长列,使用第一列
|
||
wavelengths = df.iloc[:, 0].values.astype(float)
|
||
print(f"使用第一列作为波长数据,列名: {df.columns[0]}")
|
||
else:
|
||
wavelengths = df[wavelength_col].values.astype(float)
|
||
print(f"使用指定列作为波长数据,列名: {wavelength_col}")
|
||
|
||
# 获取反射率数据
|
||
if reflectance_start_col is None:
|
||
# 使用除波长列外的所有列
|
||
if wavelength_col in df.columns:
|
||
reflectance_cols = [col for col in df.columns if col != wavelength_col]
|
||
else:
|
||
reflectance_cols = df.columns[1:] # 跳过第一列(波长)
|
||
else:
|
||
# 从指定列开始的所有列
|
||
col_names = list(df.columns)
|
||
start_idx = col_names.index(reflectance_start_col)
|
||
reflectance_cols = col_names[start_idx:]
|
||
|
||
reflectances = df[reflectance_cols].values.astype(float)
|
||
print(f"反射率数据形状: {reflectances.shape}")
|
||
print(f"波长数据形状: {wavelengths.shape}")
|
||
print(f"波长范围: {wavelengths[0]:.1f} - {wavelengths[-1]:.1f} nm")
|
||
|
||
# 确保波长和反射率数据维度匹配
|
||
if reflectances.shape[1] != len(wavelengths):
|
||
print(f"警告: 反射率数据列数 ({reflectances.shape[1]}) 与波长数量 ({len(wavelengths)}) 不匹配")
|
||
# 尝试转置
|
||
if reflectances.shape[0] == len(wavelengths):
|
||
reflectances = reflectances.T
|
||
print(f"已转置反射率数据,新形状: {reflectances.shape}")
|
||
|
||
except Exception as e:
|
||
print(f"加载 CSV 数据时出错: {e}")
|
||
traceback.print_exc()
|
||
raise
|
||
|
||
print(f"从CSV文件加载数据完成: {reflectances.shape[0]} 个样本, {len(wavelengths)} 个波长点")
|
||
return wavelengths, reflectances
|
||
|
||
def load_envi_data(self, file_path: Union[str, Path]) -> Tuple[np.ndarray, np.ndarray]:
|
||
"""
|
||
加载ENVI格式的高光谱数据
|
||
|
||
参数:
|
||
file_path: ENVI文件路径(.dat, .bil, .bsq等)
|
||
|
||
返回:
|
||
wavelengths: 波长数组 (nm)
|
||
reflectances: 反射率数据数组 (height x width x wavelengths)
|
||
"""
|
||
if not SPECTRAL_AVAILABLE:
|
||
raise ImportError("需要安装spectral库才能处理ENVI格式文件")
|
||
|
||
try:
|
||
# 读取ENVI文件
|
||
img = spectral.open_image(str(file_path))
|
||
print(f"ENVI 文件信息: 形状={img.shape}, 数据类型={img.dtype}")
|
||
|
||
# 获取波长信息
|
||
if hasattr(img, 'bands') and hasattr(img.bands, 'centers'):
|
||
wavelengths = np.array(img.bands.centers)
|
||
print(f"从文件头读取波长信息: {len(wavelengths)} 个波段")
|
||
else:
|
||
# 如果没有波长信息,使用默认的可见光范围
|
||
wavelengths = np.linspace(400, 700, img.shape[2])
|
||
warnings.warn("ENVI文件中未找到波长信息,使用默认400-700nm范围")
|
||
|
||
# 获取反射率数据
|
||
reflectances = img.load().astype(float)
|
||
|
||
# 如果数据是整数类型,假设是0-255或0-65535范围,需要归一化到0-1
|
||
if reflectances.dtype in [np.uint8, np.uint16, np.uint32]:
|
||
max_val = np.iinfo(reflectances.dtype).max
|
||
reflectances = reflectances / max_val
|
||
print(f"数据从 {reflectances.dtype} 归一化到 [0, 1]")
|
||
elif reflectances.dtype in [np.uint16]:
|
||
# 特殊处理:如果数据看起来像是0-10000范围(常见的高光谱格式),归一化到0-1
|
||
if reflectances.max() > 1.0 and reflectances.max() <= 10000:
|
||
reflectances = reflectances / 10000.0
|
||
print(f"数据从 {reflectances.dtype} 按10000归一化到 [0, 1]")
|
||
elif reflectances.dtype in [np.float32, np.float64]:
|
||
# 对于浮点数据,如果范围超过1.0,可能是0-10000范围
|
||
if reflectances.max() > 1.0 and reflectances.max() <= 10000:
|
||
reflectances = reflectances
|
||
print(f"浮点数据按10000归一化到 [0, 1]")
|
||
|
||
# 检查数据范围
|
||
print(f"反射率数据范围: [{reflectances.min():.6f}, {reflectances.max():.6f}]")
|
||
print(f"波长范围: {wavelengths[0]:.1f} - {wavelengths[-1]:.1f} nm")
|
||
|
||
except Exception as e:
|
||
print(f"加载 ENVI 数据时出错: {e}")
|
||
traceback.print_exc()
|
||
raise
|
||
|
||
print(f"从ENVI文件加载数据完成: {reflectances.shape}, 波长范围: {wavelengths[0]:.1f}-{wavelengths[-1]:.1f}nm")
|
||
return wavelengths, reflectances
|
||
|
||
def spectral_to_xyz(self, wavelengths: np.ndarray, reflectances: np.ndarray) -> np.ndarray:
|
||
"""
|
||
使用colour库将光谱反射率转换为CIE XYZ三刺激值
|
||
|
||
参数:
|
||
wavelengths: 波长数组 (nm)
|
||
reflectances: 反射率数组 (samples x wavelengths) 或 (height x width x wavelengths)
|
||
|
||
返回:
|
||
xyz: XYZ三刺激值数组 (samples x 3) 或 (height x width x 3)
|
||
"""
|
||
print(f"开始光谱到XYZ转换...")
|
||
print(f"输入反射率形状: {reflectances.shape}")
|
||
print(f"输入波长范围: {wavelengths.min():.1f} - {wavelengths.max():.1f} nm")
|
||
|
||
original_shape = reflectances.shape
|
||
if len(original_shape) == 3: # (height, width, wavelengths)
|
||
original_spatial_shape = original_shape[:2]
|
||
reflectances_flat = reflectances.reshape(-1, original_shape[2])
|
||
print(f"三维数据展平为: {reflectances_flat.shape} (像素数 x 波长数)")
|
||
else: # (samples, wavelengths)
|
||
original_spatial_shape = (original_shape[0],)
|
||
reflectances_flat = reflectances
|
||
|
||
# 确保波长是递增的
|
||
if np.any(np.diff(wavelengths) < 0):
|
||
sort_idx = np.argsort(wavelengths)
|
||
wavelengths = wavelengths[sort_idx]
|
||
reflectances_flat = reflectances_flat[:, sort_idx]
|
||
print("已对波长进行排序")
|
||
|
||
# 获取CMF和光源的波长范围
|
||
cmf_wavelengths = self.cmfs.wavelengths
|
||
illuminant_wavelengths = self.illuminant_spd.wavelengths
|
||
|
||
print(f"CMF波长范围: {cmf_wavelengths.min():.1f} - {cmf_wavelengths.max():.1f} nm")
|
||
print(f"光源波长范围: {illuminant_wavelengths.min():.1f} - {illuminant_wavelengths.max():.1f} nm")
|
||
|
||
# ============================================
|
||
# 确定计算波长范围
|
||
# ============================================
|
||
# 由于CMF定义颜色感知,以CMF范围为准
|
||
# 但需要考虑光源的有效范围
|
||
target_min_wl = cmf_wavelengths.min() # 360 nm
|
||
target_max_wl = min(cmf_wavelengths.max(), illuminant_wavelengths.max()) # 取CMF和光源的最小上限
|
||
|
||
print(f"目标计算波长范围: {target_min_wl:.1f} - {target_max_wl:.1f} nm")
|
||
|
||
# 创建标准波长网格(基于CMF,但不超过光源最大波长)
|
||
# 使用1nm间隔以获得更精确的结果
|
||
standard_wavelengths = np.arange(
|
||
np.ceil(target_min_wl),
|
||
np.floor(target_max_wl) + 1,
|
||
1
|
||
)
|
||
print(f"标准波长网格: {len(standard_wavelengths)} 个点, {standard_wavelengths[0]:.1f}-{standard_wavelengths[-1]:.1f} nm")
|
||
|
||
# 检查波长覆盖情况
|
||
if wavelengths.min() > target_min_wl:
|
||
print(f"警告: 图像数据从 {wavelengths.min():.1f}nm 开始,需要外推到 {target_min_wl:.1f}nm")
|
||
if wavelengths.max() < target_max_wl:
|
||
print(f"警告: 图像数据到 {wavelengths.max():.1f}nm 结束,需要外推到 {target_max_wl:.1f}nm")
|
||
|
||
# 确保波长是唯一的和排序的
|
||
wavelengths_unique, unique_indices = np.unique(wavelengths, return_index=True)
|
||
reflectances_unique = reflectances_flat[:, unique_indices]
|
||
|
||
print(f"去重后波长形状: {wavelengths_unique.shape}, 反射率形状: {reflectances_unique.shape}")
|
||
|
||
# 插值图像数据到标准波长网格,并进行外推
|
||
print(f"开始插值并外推 {reflectances_flat.shape[0]} 个光谱到标准波长范围...")
|
||
|
||
# 创建插值函数 - 使用线性外推
|
||
interp_func = interp1d(
|
||
wavelengths_unique,
|
||
reflectances_unique.T,
|
||
kind='linear',
|
||
bounds_error=False, # 允许外推
|
||
fill_value='extrapolate', # 使用线性外推
|
||
axis=0
|
||
)
|
||
|
||
# 一次性插值所有光谱
|
||
reflectances_interp = interp_func(standard_wavelengths).T
|
||
|
||
# 对反射率进行裁剪,确保在合理范围内 [0, 1.2]
|
||
reflectances_interp = np.clip(reflectances_interp, 0, 1.2)
|
||
|
||
print(f"插值并外推后的反射率形状: {reflectances_interp.shape}")
|
||
print(f"反射率范围: [{reflectances_interp.min():.3f}, {reflectances_interp.max():.3f}]")
|
||
|
||
# ============================================
|
||
# 获取CMF在标准波长上的值
|
||
# ============================================
|
||
print(f"获取CMF在标准波长上的值...")
|
||
|
||
# 首先获取原始的CMF数据
|
||
cmf_values = self.cmfs.values
|
||
|
||
print(f"原始CMF形状: {cmf_values.shape}")
|
||
|
||
# 处理不同格式的CMF数据
|
||
if cmf_values.ndim == 2:
|
||
if cmf_values.shape[1] == 3:
|
||
# 形状为 (wavelengths, 3) - 这是您遇到的情况
|
||
# 我们需要转置为 (3, wavelengths) 以便后续计算
|
||
print(f"CMF格式: (wavelengths, 3) - 进行转置")
|
||
cmf_x = cmf_values[:, 0]
|
||
cmf_y = cmf_values[:, 1]
|
||
cmf_z = cmf_values[:, 2]
|
||
# 将CMF插值到标准波长网格
|
||
cmf_interp_func_x = interp1d(
|
||
cmf_wavelengths, cmf_x,
|
||
kind='linear', bounds_error=False, fill_value=0
|
||
)
|
||
cmf_interp_func_y = interp1d(
|
||
cmf_wavelengths, cmf_y,
|
||
kind='linear', bounds_error=False, fill_value=0
|
||
)
|
||
cmf_interp_func_z = interp1d(
|
||
cmf_wavelengths, cmf_z,
|
||
kind='linear', bounds_error=False, fill_value=0
|
||
)
|
||
cmf_x_interp = cmf_interp_func_x(standard_wavelengths)
|
||
cmf_y_interp = cmf_interp_func_y(standard_wavelengths)
|
||
cmf_z_interp = cmf_interp_func_z(standard_wavelengths)
|
||
|
||
cmf_values = np.array([cmf_x_interp, cmf_y_interp, cmf_z_interp])
|
||
elif cmf_values.shape[0] == 3:
|
||
# 形状为 (3, wavelengths) - 理想的格式
|
||
print(f"CMF格式: (3, wavelengths) - 直接使用")
|
||
# 将CMF插值到标准波长网格
|
||
cmf_interp_func = interp1d(
|
||
cmf_wavelengths, cmf_values,
|
||
kind='linear', bounds_error=False, fill_value=0,
|
||
axis=1
|
||
)
|
||
cmf_values = cmf_interp_func(standard_wavelengths)
|
||
else:
|
||
raise ValueError(f"不支持的CMF二维形状: {cmf_values.shape}")
|
||
elif cmf_values.ndim == 1:
|
||
# 单通道CMF,需要扩展到三通道
|
||
print(f"CMF格式: 一维数组 - 扩展到三通道")
|
||
cmf_interp_func = interp1d(
|
||
cmf_wavelengths, cmf_values,
|
||
kind='linear', bounds_error=False, fill_value=0
|
||
)
|
||
cmf_single = cmf_interp_func(standard_wavelengths)
|
||
cmf_values = np.array([cmf_single, cmf_single, cmf_single])
|
||
else:
|
||
raise ValueError(f"不支持的CMF维度: {cmf_values.ndim}")
|
||
|
||
print(f"处理后CMF形状: {cmf_values.shape}")
|
||
|
||
# ============================================
|
||
# 获取光源SPD在标准波长上的值
|
||
# ============================================
|
||
print(f"将光源SPD插值到标准波长网格...")
|
||
|
||
# 获取光源值
|
||
illuminant_values_raw = self.illuminant_spd.values
|
||
illuminant_wavelengths_raw = self.illuminant_spd.wavelengths
|
||
|
||
# 将光源插值到标准波长网格
|
||
illuminant_interp_func = interp1d(
|
||
illuminant_wavelengths_raw,
|
||
illuminant_values_raw,
|
||
kind='linear',
|
||
bounds_error=False,
|
||
fill_value=0 # 对于超出范围的部分,填充0
|
||
)
|
||
illuminant_values = illuminant_interp_func(standard_wavelengths)
|
||
|
||
# 确保光源值非负
|
||
illuminant_values = np.maximum(illuminant_values, 0)
|
||
|
||
print(f"插值后的光源形状: {illuminant_values.shape}")
|
||
|
||
# ============================================
|
||
# 计算XYZ值 - 向量化计算
|
||
# ============================================
|
||
print(f"开始计算 {reflectances_interp.shape[0]} 个像素的XYZ值...")
|
||
|
||
# XYZ计算公式: X = ∫ R(λ) * S(λ) * x̄(λ) dλ / ∫ S(λ) * x̄(λ) dλ
|
||
# 其中 R(λ) 是反射率,S(λ) 是光源SPD,x̄(λ) 是颜色匹配函数
|
||
|
||
# 计算积分(使用矩形积分)
|
||
delta_lambda = standard_wavelengths[1] - standard_wavelengths[0] # 假设均匀间隔
|
||
|
||
# 扩展维度以便广播运算
|
||
cmf_expanded = cmf_values[np.newaxis, :, :] # (1, 3, wavelengths)
|
||
illuminant_expanded = illuminant_values[np.newaxis, np.newaxis, :] # (1, 1, wavelengths)
|
||
|
||
# 计算分母: S(λ) * CMF(λ) - 这个在所有像素间是相同的
|
||
denominator = illuminant_expanded * cmf_expanded
|
||
xyz_denominator = np.sum(denominator * delta_lambda, axis=2) # (1, 3)
|
||
|
||
print(f"归一化分母: {xyz_denominator[0]}")
|
||
|
||
# 批处理计算以提高性能
|
||
batch_size = min(100000, reflectances_interp.shape[0]) # 每批处理最多100000个像素
|
||
total_pixels = reflectances_interp.shape[0]
|
||
xyz_array = np.zeros((total_pixels, 3))
|
||
|
||
print(f"使用批处理模式,每批 {batch_size} 个像素")
|
||
|
||
for start_idx in range(0, total_pixels, batch_size):
|
||
end_idx = min(start_idx + batch_size, total_pixels)
|
||
batch_reflectances = reflectances_interp[start_idx:end_idx] # (batch_size, wavelengths)
|
||
|
||
if start_idx == 0 or (start_idx // batch_size) % 10 == 0:
|
||
print(f"处理批次: {start_idx}-{end_idx}/{total_pixels} ({start_idx/total_pixels*100:.1f}%)")
|
||
|
||
# 扩展维度以便广播运算
|
||
reflectances_expanded = batch_reflectances[:, np.newaxis, :] # (batch, 1, wavelengths)
|
||
|
||
# 计算分子: R(λ) * S(λ) * CMF(λ)
|
||
numerator = reflectances_expanded * illuminant_expanded * cmf_expanded
|
||
|
||
# 沿波长维度积分
|
||
xyz_numerator = np.sum(numerator * delta_lambda, axis=2) # (batch, 3)
|
||
|
||
# 归一化
|
||
batch_xyz = xyz_numerator / (xyz_denominator + 1e-10) # 避免除零
|
||
|
||
xyz_array[start_idx:end_idx] = batch_xyz
|
||
|
||
print(f"批处理计算完成,XYZ形状: {xyz_array.shape}")
|
||
|
||
# 验证:计算完美漫反射体的XYZ(反射率全为1)
|
||
perfect_reflector = np.ones_like(standard_wavelengths)
|
||
perfect_numerator = np.sum(perfect_reflector * illuminant_values * cmf_values * delta_lambda, axis=1)
|
||
perfect_xyz = perfect_numerator / (xyz_denominator[0] + 1e-10)
|
||
print(f"完美漫反射体(反射率=1)的理论XYZ: {perfect_xyz}")
|
||
|
||
# 恢复原始空间维度
|
||
if len(original_shape) == 3: # 图像数据
|
||
xyz_array = xyz_array.reshape(original_spatial_shape[0], original_spatial_shape[1], 3)
|
||
print(f"恢复三维形状: {xyz_array.shape}")
|
||
elif len(original_spatial_shape) == 1: # 一维数据
|
||
xyz_array = xyz_array.reshape(original_spatial_shape[0], 3)
|
||
|
||
# 检查XYZ值范围
|
||
print(f"XYZ值范围:")
|
||
print(f" X: [{xyz_array[..., 0].min():.3f}, {xyz_array[..., 0].max():.3f}]")
|
||
print(f" Y: [{xyz_array[..., 1].min():.3f}, {xyz_array[..., 1].max():.3f}]")
|
||
print(f" Z: [{xyz_array[..., 2].min():.3f}, {xyz_array[..., 2].max():.3f}]")
|
||
|
||
return xyz_array
|
||
|
||
def spectral_to_xyz_colour(self, wavelengths: np.ndarray, reflectances: np.ndarray) -> np.ndarray:
|
||
"""
|
||
使用colour库的sd_to_XYZ函数将光谱反射率转换为CIE XYZ三刺激值
|
||
这是替代方法,使用colour库的内置函数
|
||
|
||
参数:
|
||
wavelengths: 波长数组 (nm)
|
||
reflectances: 反射率数组 (samples x wavelengths) 或 (height x width x wavelengths)
|
||
|
||
返回:
|
||
xyz: XYZ三刺激值数组 (samples x 3) 或 (height x width x 3)
|
||
"""
|
||
print(f"开始光谱到XYZ转换(使用colour库)...")
|
||
print(f"输入反射率形状: {reflectances.shape}")
|
||
print(f"输入波长范围: {wavelengths.min():.1f} - {wavelengths.max():.1f} nm")
|
||
|
||
original_shape = reflectances.shape
|
||
if len(original_shape) == 3: # (height, width, wavelengths)
|
||
original_spatial_shape = original_shape[:2]
|
||
reflectances_flat = reflectances.reshape(-1, original_shape[2])
|
||
print(f"三维数据展平为: {reflectances_flat.shape} (像素数 x 波长数)")
|
||
else: # (samples, wavelengths)
|
||
original_spatial_shape = (original_shape[0],)
|
||
reflectances_flat = reflectances
|
||
|
||
# 确保波长是递增的
|
||
if np.any(np.diff(wavelengths) < 0):
|
||
sort_idx = np.argsort(wavelengths)
|
||
wavelengths = wavelengths[sort_idx]
|
||
reflectances_flat = reflectances_flat[:, sort_idx]
|
||
print("已对波长进行排序")
|
||
|
||
# 获取CMF和光源的波长范围
|
||
cmf_wavelengths = self.cmfs.wavelengths
|
||
illuminant_wavelengths = self.illuminant_spd.wavelengths
|
||
|
||
print(f"CMF波长范围: {cmf_wavelengths.min():.1f} - {cmf_wavelengths.max():.1f} nm")
|
||
print(f"光源波长范围: {illuminant_wavelengths.min():.1f} - {illuminant_wavelengths.max():.1f} nm")
|
||
|
||
# ============================================
|
||
# 确定计算波长范围 - 使用colour推荐的范围
|
||
# ============================================
|
||
# 使用colour库推荐的范围:通常360-780nm(当use_practice_range=True时)
|
||
target_min_wl = max(360.0, wavelengths.min())
|
||
target_max_wl = min(780.0, wavelengths.max(), illuminant_wavelengths.max())
|
||
|
||
# 确保目标范围有效
|
||
target_min_wl = max(target_min_wl, 360.0)
|
||
target_max_wl = min(target_max_wl, 780.0)
|
||
|
||
print(f"目标计算波长范围: {target_min_wl:.1f} - {target_max_wl:.1f} nm")
|
||
|
||
# 创建标准波长网格(使用5nm间隔,这是ASTM E308标准常用的间隔)
|
||
standard_wavelengths = np.arange(
|
||
np.ceil(target_min_wl),
|
||
np.floor(target_max_wl) + 1,
|
||
5
|
||
)
|
||
print(f"标准波长网格(5nm间隔): {len(standard_wavelengths)} 个点, {standard_wavelengths[0]:.1f}-{standard_wavelengths[-1]:.1f} nm")
|
||
|
||
# 检查波长覆盖情况
|
||
if wavelengths.min() > target_min_wl:
|
||
print(f"警告: 图像数据从 {wavelengths.min():.1f}nm 开始,需要外推到 {target_min_wl:.1f}nm")
|
||
if wavelengths.max() < target_max_wl:
|
||
print(f"警告: 图像数据到 {wavelengths.max():.1f}nm 结束,需要外推到 {target_max_wl:.1f}nm")
|
||
|
||
# 确保波长是唯一的和排序的
|
||
wavelengths_unique, unique_indices = np.unique(wavelengths, return_index=True)
|
||
reflectances_unique = reflectances_flat[:, unique_indices]
|
||
|
||
print(f"去重后波长形状: {wavelengths_unique.shape}, 反射率形状: {reflectances_unique.shape}")
|
||
|
||
# 插值图像数据到标准波长网格
|
||
print(f"开始插值 {reflectances_flat.shape[0]} 个光谱到标准波长范围...")
|
||
|
||
# 创建插值函数 - 使用线性外推
|
||
interp_func = interp1d(
|
||
wavelengths_unique,
|
||
reflectances_unique.T,
|
||
kind='linear',
|
||
bounds_error=False,
|
||
fill_value='extrapolate', # 使用线性外推
|
||
axis=0
|
||
)
|
||
|
||
# 一次性插值所有光谱
|
||
reflectances_interp = interp_func(standard_wavelengths).T
|
||
|
||
# 对反射率进行裁剪,确保在合理范围内 [0, 1.0]
|
||
# colour库期望反射率在0-1范围内
|
||
reflectances_interp = np.clip(reflectances_interp, 0, 1.0)
|
||
|
||
print(f"插值后的反射率形状: {reflectances_interp.shape}")
|
||
print(f"反射率范围: [{reflectances_interp.min():.3f}, {reflectances_interp.max():.3f}]")
|
||
|
||
# ============================================
|
||
# 使用colour库的sd_to_XYZ函数计算XYZ
|
||
# ============================================
|
||
print(f"开始使用colour库计算 {reflectances_interp.shape[0]} 个像素的XYZ值...")
|
||
|
||
# 初始化XYZ数组
|
||
total_pixels = reflectances_interp.shape[0]
|
||
xyz_array = np.zeros((total_pixels, 3))
|
||
|
||
# 批处理计算以提高性能
|
||
batch_size = min(5000, total_pixels) # 每批处理5000个像素
|
||
print(f"使用批处理模式,每批 {batch_size} 个像素")
|
||
|
||
# 创建一个形状对象,用于光谱插值
|
||
shape = SpectralShape(standard_wavelengths[0], standard_wavelengths[-1], 5)
|
||
|
||
# 将CMFS和光源插值到标准波长网格
|
||
cmfs_interp = self.cmfs.interpolate(shape)
|
||
illuminant_interp = self.illuminant_spd.interpolate(shape)
|
||
|
||
# 批处理计算
|
||
for start_idx in range(0, total_pixels, batch_size):
|
||
end_idx = min(start_idx + batch_size, total_pixels)
|
||
|
||
if start_idx % (batch_size * 10) == 0:
|
||
print(f"处理批次: {start_idx}-{end_idx}/{total_pixels} ({start_idx/total_pixels*100:.1f}%)")
|
||
|
||
# 获取当前批次的光谱
|
||
batch_spectra = reflectances_interp[start_idx:end_idx]
|
||
|
||
# 为每个光谱创建SpectralDistribution对象
|
||
batch_xyz = []
|
||
for i in range(batch_spectra.shape[0]):
|
||
try:
|
||
# 创建光谱分布对象
|
||
sd = SpectralDistribution(
|
||
batch_spectra[i],
|
||
standard_wavelengths,
|
||
name=f'Sample_{start_idx + i}'
|
||
)
|
||
|
||
# 使用colour库的sd_to_XYZ函数
|
||
xyz = sd_to_XYZ(
|
||
sd,
|
||
cmfs=cmfs_interp,
|
||
illuminant=illuminant_interp,
|
||
method="Integration" # 或者使用 "ASTM E308"
|
||
)
|
||
|
||
batch_xyz.append(xyz)
|
||
except Exception as e:
|
||
print(f"像素 {start_idx + i} 计算失败: {e}")
|
||
# 使用零值作为后备
|
||
batch_xyz.append([0, 0, 0])
|
||
|
||
# 保存结果
|
||
xyz_array[start_idx:end_idx] = np.array(batch_xyz)
|
||
|
||
print(f"批处理计算完成,XYZ形状: {xyz_array.shape}")
|
||
|
||
# ============================================
|
||
# 验证结果
|
||
# ============================================
|
||
print(f"\n验证计算:")
|
||
|
||
# 验证:计算完美漫反射体(反射率全为1)的XYZ
|
||
perfect_reflector_sd = SpectralDistribution(
|
||
np.ones_like(standard_wavelengths),
|
||
standard_wavelengths,
|
||
name='Perfect Reflector'
|
||
)
|
||
|
||
# 计算完美漫反射体的XYZ
|
||
perfect_xyz = sd_to_XYZ(
|
||
perfect_reflector_sd,
|
||
cmfs=cmfs_interp,
|
||
illuminant=illuminant_interp
|
||
)
|
||
|
||
print(f"完美漫反射体(反射率=1)的XYZ: {perfect_xyz}")
|
||
print(f"注意: 完美漫反射体的Y值应为100(归一化后),当前Y值: {perfect_xyz[1]}")
|
||
|
||
# 恢复原始空间维度
|
||
if len(original_shape) == 3: # 图像数据
|
||
xyz_array = xyz_array.reshape(original_spatial_shape[0], original_spatial_shape[1], 3)
|
||
print(f"恢复三维形状: {xyz_array.shape}")
|
||
elif len(original_spatial_shape) == 1: # 一维数据
|
||
xyz_array = xyz_array.reshape(original_spatial_shape[0], 3)
|
||
|
||
# 检查XYZ值范围
|
||
print(f"\nXYZ值范围:")
|
||
print(f" X: [{xyz_array[..., 0].min():.3f}, {xyz_array[..., 0].max():.3f}]")
|
||
print(f" Y: [{xyz_array[..., 1].min():.3f}, {xyz_array[..., 1].max():.3f}]")
|
||
print(f" Z: [{xyz_array[..., 2].min():.3f}, {xyz_array[..., 2].max():.3f}]")
|
||
|
||
return xyz_array
|
||
|
||
def convert_color_space(self, xyz: np.ndarray) -> np.ndarray:
|
||
"""
|
||
将XYZ转换为指定的颜色空间
|
||
|
||
参数:
|
||
xyz: XYZ数组 (..., 3)
|
||
|
||
返回:
|
||
转换后的颜色数组 (..., 3 或 4)
|
||
"""
|
||
if self.color_space == 'XYZ':
|
||
if self.normalize_xyz:
|
||
# 使用标准白点进行归一化(得到相对XYZ值)
|
||
whitepoint = self._get_whitepoint_xyz()
|
||
return xyz / whitepoint
|
||
else:
|
||
# 返回绝对XYZ值(0-100范围,这是CIE标准)
|
||
return xyz
|
||
elif self.color_space == 'xyY':
|
||
return self.xyz_to_xyy(xyz)
|
||
elif self.color_space == 'Lab':
|
||
return self.xyz_to_lab(xyz)
|
||
elif self.color_space == 'LCH':
|
||
return self.xyz_to_lch(xyz)
|
||
else:
|
||
raise ValueError(f"不支持的颜色空间: {self.color_space}")
|
||
|
||
def _get_whitepoint_xyz(self) -> np.ndarray:
|
||
"""
|
||
获取当前光源的XYZ白点值
|
||
|
||
使用colour库计算标准白点值,方法:
|
||
1. 获取光源的光谱功率分布 (SPD)
|
||
2. 获取标准观察者的颜色匹配函数 (CMFs)
|
||
3. 计算sd_to_XYZ得到相对值
|
||
4. 归一化到Y=100
|
||
|
||
返回:
|
||
白点XYZ值数组 [X, Y, Z],范围0-100
|
||
"""
|
||
try:
|
||
# 1. 获取光源的光谱数据 (SpectralDistribution)
|
||
illuminant_spd = self.illuminant_spd
|
||
|
||
# 2. 获取标准观察者的颜色匹配函数 (CMFs)
|
||
if self.observer == '2°':
|
||
cmfs = MSDS_CMFS['CIE 1931 2 Degree Standard Observer']
|
||
else:
|
||
cmfs = MSDS_CMFS['CIE 1964 10 Degree Standard Observer']
|
||
|
||
# 3. 将光谱转换为XYZ三刺激值 (得到相对值)
|
||
xyz_relative = sd_to_XYZ(illuminant_spd, cmfs)
|
||
|
||
# 4. 归一化到Y=100
|
||
xyz_normalized = (xyz_relative / xyz_relative[1]) * 100
|
||
|
||
print(f"计算得{self.illuminant}光源白点 (X, Y, Z): {xyz_normalized}")
|
||
|
||
return np.array([xyz_normalized[0], xyz_normalized[1], xyz_normalized[2]])
|
||
|
||
except Exception as e:
|
||
print(f"计算白点值失败: {e}, 使用标准值")
|
||
|
||
# 回退到标准白点值
|
||
if self.illuminant == 'D65':
|
||
return np.array([95.047, 100.0, 108.883])
|
||
elif self.illuminant == 'D50':
|
||
return np.array([96.422, 100.0, 82.521])
|
||
elif self.illuminant == 'A':
|
||
return np.array([109.850, 100.0, 35.585])
|
||
else:
|
||
# 默认使用D65
|
||
return np.array([95.047, 100.0, 108.883])
|
||
|
||
def xyz_to_xyy(self, xyz: np.ndarray) -> np.ndarray:
|
||
"""
|
||
将XYZ转换为xyY颜色空间
|
||
|
||
参数:
|
||
xyz: XYZ数组 (..., 3)
|
||
|
||
返回:
|
||
xyy: xyY数组 (..., 3)
|
||
"""
|
||
# 获取白点值用于归一化
|
||
whitepoint = self._get_whitepoint_xyz()
|
||
|
||
# 对XYZ进行白点归一化
|
||
xyz_norm = xyz / whitepoint
|
||
|
||
# XYZ to xyY conversion (使用归一化后的值)
|
||
x = xyz_norm[..., 0] / (xyz_norm[..., 0] + xyz_norm[..., 1] + xyz_norm[..., 2] + 1e-10)
|
||
y = xyz_norm[..., 1] / (xyz_norm[..., 0] + xyz_norm[..., 1] + xyz_norm[..., 2] + 1e-10)
|
||
Y = xyz_norm[..., 1] # 亮度分量(归一化后的Y)
|
||
|
||
# 处理可能的NaN值
|
||
x = np.nan_to_num(x, nan=0.0)
|
||
y = np.nan_to_num(y, nan=0.0)
|
||
Y = np.nan_to_num(Y, nan=0.0)
|
||
|
||
return np.stack([x, y, Y], axis=-1)
|
||
|
||
def xyz_to_lab(self, xyz: np.ndarray) -> np.ndarray:
|
||
"""
|
||
将XYZ转换为L*a*b*颜色空间
|
||
|
||
参数:
|
||
xyz: XYZ数组 (..., 3),绝对XYZ值(范围0-100)
|
||
|
||
返回:
|
||
lab: L*a*b*数组 (..., 3)
|
||
"""
|
||
# 获取白点XYZ值
|
||
whitepoint_xyz = self._get_whitepoint_xyz()
|
||
|
||
# 对XYZ进行白点归一化(得到相对XYZ值,范围[0, 1])
|
||
xyz_normalized = xyz / whitepoint_xyz
|
||
|
||
# 使用 CCS_ILLUMINANTS 获取参考光源的色度坐标 (x, y)
|
||
observer_key = self.OBSERVERS[self.observer]
|
||
illuminant_xy = CCS_ILLUMINANTS[observer_key][self.illuminant]
|
||
|
||
# 使用归一化后的XYZ值和色度坐标进行转换
|
||
lab = XYZ_to_Lab(xyz_normalized, illuminant_xy)
|
||
return lab
|
||
|
||
def xyz_to_lch(self, xyz: np.ndarray) -> np.ndarray:
|
||
"""
|
||
将XYZ转换为L*C*h*颜色空间
|
||
|
||
参数:
|
||
xyz: XYZ数组 (..., 3)
|
||
|
||
返回:
|
||
lch: L*C*h*数组 (..., 3)
|
||
"""
|
||
# 先转换为Lab,然后转换为LCH
|
||
lab = self.xyz_to_lab(xyz)
|
||
|
||
L = lab[..., 0]
|
||
a = lab[..., 1]
|
||
b = lab[..., 2]
|
||
|
||
# 计算色度C和色调h
|
||
C = np.sqrt(a**2 + b**2)
|
||
h = np.arctan2(b, a) * 180 / np.pi # 转换为度
|
||
h = np.where(h < 0, h + 360, h) # 确保h在0-360度范围内
|
||
|
||
# 处理可能的NaN值
|
||
C = np.nan_to_num(C, nan=0.0)
|
||
h = np.nan_to_num(h, nan=0.0)
|
||
|
||
return np.stack([L, C, h], axis=-1)
|
||
|
||
def process_file(self, input_path: Union[str, Path], output_path: Optional[Union[str, Path]] = None,
|
||
output_format: str = 'csv', input_type: Optional[str] = None,
|
||
wavelength_col: str = 'wavelength', reflectance_start_col: Optional[str] = None,
|
||
use_colour_library: bool = False) -> np.ndarray:
|
||
"""
|
||
处理单个文件并转换为颜色
|
||
|
||
参数:
|
||
input_path: 输入文件路径
|
||
output_path: 输出文件路径(如果为None,则不保存文件)
|
||
output_format: 输出格式 ('csv' 或 'dat')
|
||
input_type: 输入类型 ('csv' 或 'envi',如果为None则自动检测)
|
||
wavelength_col: CSV文件的波长列名
|
||
reflectance_start_col: CSV文件的反射率起始列名
|
||
use_colour_library: 是否使用colour库的sd_to_XYZ函数(如果为False则使用自定义实现)
|
||
|
||
返回:
|
||
xyz: XYZ三刺激值数组
|
||
"""
|
||
input_path = Path(input_path)
|
||
|
||
if not input_path.exists():
|
||
raise FileNotFoundError(f"输入文件不存在: {input_path}")
|
||
|
||
# 自动检测输入类型
|
||
if input_type is None:
|
||
if input_path.suffix.lower() == '.csv':
|
||
input_type = 'csv'
|
||
else:
|
||
input_type = 'envi'
|
||
|
||
print(f"输入文件类型: {input_type}")
|
||
|
||
# 加载数据
|
||
if input_type == 'csv':
|
||
wavelengths, reflectances = self.load_csv_data(
|
||
input_path, wavelength_col, reflectance_start_col
|
||
)
|
||
elif input_type == 'envi':
|
||
wavelengths, reflectances = self.load_envi_data(input_path)
|
||
else:
|
||
raise ValueError(f"不支持的输入类型: {input_type}")
|
||
|
||
# 转换为XYZ
|
||
if use_colour_library:
|
||
print("\n使用colour库的sd_to_XYZ函数进行转换...")
|
||
xyz = self.spectral_to_xyz_colour(wavelengths, reflectances)
|
||
else:
|
||
print("\n使用自定义实现进行转换...")
|
||
xyz = self.spectral_to_xyz(wavelengths, reflectances)
|
||
|
||
# 转换为指定的颜色空间
|
||
color_data = self.convert_color_space(xyz)
|
||
print(f"转换为{self.color_space}颜色空间,数据形状: {color_data.shape}")
|
||
|
||
# 保存结果
|
||
if output_path is not None:
|
||
# 如果输入是ENVI格式,传递HDR文件路径以复制元数据
|
||
input_hdr_path = None
|
||
if input_type == 'envi':
|
||
input_hdr_path = Path(input_path).with_suffix('.hdr')
|
||
if not input_hdr_path.exists():
|
||
input_hdr_path = None
|
||
|
||
self.save_color_data(color_data, output_path, output_format, input_hdr_path)
|
||
|
||
return xyz
|
||
|
||
def save_color_data(self, color_data: np.ndarray, output_path: Union[str, Path], output_format: str = 'csv',
|
||
input_hdr_path: Optional[Union[str, Path]] = None):
|
||
"""
|
||
保存颜色数据到文件
|
||
|
||
参数:
|
||
color_data: 颜色数据数组
|
||
output_path: 输出文件路径
|
||
output_format: 输出格式 ('csv' 或 'dat')
|
||
"""
|
||
output_path = Path(output_path)
|
||
|
||
print(f"保存{self.color_space}数据到: {output_path}, 格式: {output_format}")
|
||
|
||
# 根据颜色空间确定列名
|
||
if self.color_space == 'XYZ':
|
||
column_names = ['X', 'Y', 'Z']
|
||
elif self.color_space == 'xyY':
|
||
column_names = ['x', 'y', 'Y']
|
||
elif self.color_space == 'Lab':
|
||
column_names = ['L*', 'a*', 'b*']
|
||
elif self.color_space == 'LCH':
|
||
column_names = ['L*', 'C*', 'h*']
|
||
else:
|
||
column_names = [f'Band_{i+1}' for i in range(color_data.shape[-1])]
|
||
|
||
if output_format == 'csv':
|
||
# 保存为CSV格式
|
||
if len(color_data.shape) == 3: # 图像数据 (height, width, channels)
|
||
# 为图像数据创建CSV,每行是一个像素的颜色值
|
||
color_flat = color_data.reshape(-1, color_data.shape[-1])
|
||
df = pd.DataFrame(color_flat, columns=column_names)
|
||
df.to_csv(output_path, index=False)
|
||
print(f"保存为CSV: {color_flat.shape[0]} 行数据")
|
||
else: # 一维数据 (samples, channels)
|
||
df = pd.DataFrame(color_data, columns=column_names)
|
||
df.to_csv(output_path, index=False)
|
||
print(f"保存为CSV: {color_data.shape[0]} 行数据")
|
||
|
||
elif output_format == 'dat':
|
||
# 保存为ENVI格式的.dat文件,参考classification.py的方式
|
||
try:
|
||
from osgeo import gdal
|
||
GDAL_AVAILABLE = True
|
||
except ImportError:
|
||
GDAL_AVAILABLE = False
|
||
|
||
# 为ENVI格式创建必要的元数据
|
||
if len(color_data.shape) == 3: # 图像数据 (height, width, channels)
|
||
height, width, channels = color_data.shape
|
||
|
||
# 确保目录存在
|
||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||
|
||
print(f"保存ENVI文件 - 原始{self.color_space}形状: {color_data.shape}")
|
||
|
||
if GDAL_AVAILABLE and channels <= 4: # 支持最多4个通道
|
||
# 使用GDAL保存,参考classification.py的方式
|
||
driver = gdal.GetDriverByName('ENVI')
|
||
|
||
# 创建多波段ENVI文件
|
||
out_dataset = driver.Create(
|
||
str(output_path), width, height, channels,
|
||
gdal.GDT_Float32, options=['INTERLEAVE=BIL']
|
||
)
|
||
|
||
if out_dataset is None:
|
||
raise ValueError(f"无法创建输出文件: {output_path}")
|
||
|
||
# 写入每个波段的数据
|
||
for band_idx in range(channels):
|
||
band_data = color_data[:, :, band_idx] # (height, width)
|
||
out_band = out_dataset.GetRasterBand(band_idx + 1)
|
||
out_band.WriteArray(band_data.astype(np.float32))
|
||
|
||
# 设置波段名称
|
||
out_band.SetDescription(column_names[band_idx])
|
||
|
||
# 设置地理信息 (如果有的话,这里没有地理信息)
|
||
# out_dataset.SetGeoTransform(geotransform)
|
||
# out_dataset.SetProjection(projection)
|
||
|
||
out_dataset.FlushCache()
|
||
out_dataset = None
|
||
|
||
# 创建HDR文件,保持输入文件的元数据
|
||
self._create_envi_hdr_file(output_path, height, width, channels, 'float32', input_hdr_path)
|
||
|
||
print(f"使用GDAL保存为ENVI文件: {output_path}")
|
||
print(f"数据类型: float32, 波段数: {channels}, 大小: {width}x{height}")
|
||
|
||
elif SPECTRAL_AVAILABLE:
|
||
# 回退到spectral库保存
|
||
color_transposed = np.transpose(color_data, (2, 0, 1)) # (channels, height, width)
|
||
|
||
# spectral.envi.save_image 需要 HDR 文件路径,而不是 DAT 文件路径
|
||
hdr_path = output_path.with_suffix('.hdr')
|
||
|
||
# 保存为ENVI文件
|
||
spectral.envi.save_image(str(hdr_path), color_transposed, dtype=np.float32,
|
||
interleave='bil')
|
||
|
||
print(f"使用spectral库保存为ENVI文件: HDR={hdr_path}, DAT={output_path}")
|
||
|
||
else:
|
||
raise ImportError("需要安装GDAL或spectral库才能保存为ENVI格式")
|
||
|
||
else: # 一维数据 (samples, channels)
|
||
# 一维数据保存为CSV
|
||
csv_path = output_path.with_suffix('.csv')
|
||
np.savetxt(csv_path, color_data, delimiter=',',
|
||
header=','.join(column_names), comments='')
|
||
print(f"一维数据保存为CSV: {csv_path}")
|
||
|
||
else:
|
||
raise ValueError(f"不支持的输出格式: {output_format}")
|
||
|
||
print(f"结果已保存到: {output_path}")
|
||
|
||
def _create_envi_hdr_file(self, bil_path: Union[str, Path], height: int, width: int,
|
||
bands: int, data_type: str, input_hdr_path: Optional[Union[str, Path]] = None) -> None:
|
||
"""
|
||
创建ENVI头文件,参考classification.py的方式,并保持输入文件的元数据
|
||
|
||
Args:
|
||
bil_path: BIL文件路径
|
||
height: 图像高度
|
||
width: 图像宽度
|
||
bands: 波段数
|
||
data_type: 数据类型 ('float32', 'uint8', 'int16', 等)
|
||
input_hdr_path: 输入ENVI文件的HDR路径,用于复制元数据
|
||
"""
|
||
hdr_path = Path(bil_path).with_suffix('.hdr')
|
||
|
||
# 数据类型映射 (ENVI格式)
|
||
dtype_map = {
|
||
'uint8': '1',
|
||
'int16': '2',
|
||
'int32': '3',
|
||
'float32': '4',
|
||
'float64': '5',
|
||
'complex64': '6',
|
||
'complex128': '9',
|
||
'uint16': '12',
|
||
'uint32': '13',
|
||
'int64': '14',
|
||
'uint64': '15'
|
||
}
|
||
|
||
envi_dtype = dtype_map.get(data_type, '4') # 默认为float32
|
||
|
||
# 根据颜色空间确定波段名称
|
||
if self.color_space == 'XYZ':
|
||
band_names = ["X/Xn (CIE 1931)", "Y/Yn (CIE 1931)", "Z/Zn (CIE 1931)"]
|
||
elif self.color_space == 'xyY':
|
||
band_names = ["x (CIE 1931)", "y (CIE 1931)", "Y/Yn (CIE 1931)"]
|
||
elif self.color_space == 'Lab':
|
||
band_names = ["L* (CIE L*a*b*)", "a* (CIE L*a*b*)", "b* (CIE L*a*b*)"]
|
||
elif self.color_space == 'LCH':
|
||
band_names = ["L* (CIE L*C*h*)", "C* (CIE L*C*h*)", "h* (CIE L*C*h*)"]
|
||
else:
|
||
band_names = [f"Band {i+1}" for i in range(bands)]
|
||
|
||
with open(hdr_path, 'w') as f:
|
||
f.write("ENVI\n")
|
||
f.write("description = {\n")
|
||
f.write(f" {self.COLOR_SPACES[self.color_space]} Color Data - Generated by SpectralToColorConverter\n")
|
||
f.write(f" Illuminant: {self.illuminant}, Observer: {self.observer}\n")
|
||
f.write("}\n")
|
||
f.write(f"samples = {width}\n")
|
||
f.write(f"lines = {height}\n")
|
||
f.write(f"bands = {bands}\n")
|
||
f.write("header offset = 0\n")
|
||
f.write("file type = ENVI Standard\n")
|
||
f.write(f"data type = {envi_dtype}\n")
|
||
f.write("interleave = bil\n")
|
||
f.write("sensor type = Unknown\n")
|
||
f.write("byte order = 0\n")
|
||
|
||
# 波段名称
|
||
f.write("band names = {\n")
|
||
for i, name in enumerate(band_names):
|
||
f.write(f' "{name}"')
|
||
if i < len(band_names) - 1:
|
||
f.write(",")
|
||
f.write("\n")
|
||
f.write("}\n")
|
||
|
||
print(f"ENVI头文件创建完成: {hdr_path}")
|
||
|
||
def main():
|
||
"""
|
||
主函数:命令行使用示例
|
||
"""
|
||
import argparse
|
||
|
||
parser = argparse.ArgumentParser(description='光谱到颜色转换工具')
|
||
parser.add_argument('--input', required=True, help='输入文件路径 (CSV或ENVI格式)')
|
||
parser.add_argument('-o', '--output', help='输出文件路径')
|
||
parser.add_argument('--illuminant', default='D65', choices=['D65', 'D50', 'A', 'F2', 'F7', 'F11'],
|
||
help='光源类型 (默认: D65)')
|
||
parser.add_argument('--observer', default='2°', choices=['2°', '10°'],
|
||
help='观察者角度 (默认: 2°)')
|
||
parser.add_argument('--color-space', default='XYZ', choices=['XYZ', 'xyY', 'Lab', 'LCH'],
|
||
help='输出颜色空间 (默认: XYZ)')
|
||
parser.add_argument('--normalize-xyz', action='store_true', default=True,
|
||
help='当输出XYZ时,归一化到0-1范围 (默认: True)')
|
||
parser.add_argument('--no-normalize-xyz', action='store_true',
|
||
help='当输出XYZ时,不进行归一化,保持0-100范围')
|
||
parser.add_argument('--format', default='csv', choices=['csv', 'dat'],
|
||
help='输出格式 (默认: csv)')
|
||
parser.add_argument('--wavelength-col', default='wavelength',
|
||
help='CSV文件的波长列名 (默认: wavelength)')
|
||
parser.add_argument('--reflectance-start', default=None,
|
||
help='CSV文件的反射率起始列名')
|
||
parser.add_argument('--use-colour', action='store_true',
|
||
help='使用colour库的sd_to_XYZ函数进行转换')
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 处理XYZ归一化参数
|
||
normalize_xyz = args.normalize_xyz and not args.no_normalize_xyz
|
||
|
||
# 创建转换器
|
||
converter = SpectralToColorConverter(
|
||
illuminant=args.illuminant,
|
||
observer=args.observer,
|
||
color_space=args.color_space,
|
||
normalize_xyz=normalize_xyz
|
||
)
|
||
|
||
# 处理文件
|
||
try:
|
||
xyz = converter.process_file(
|
||
args.input,
|
||
args.output,
|
||
args.format,
|
||
wavelength_col=args.wavelength_col,
|
||
reflectance_start_col=args.reflectance_start,
|
||
use_colour_library=args.use_colour
|
||
)
|
||
|
||
print(f"处理完成! XYZ数据形状: {xyz.shape}")
|
||
print(f"XYZ值范围: X[{xyz[..., 0].min():.2f}, {xyz[..., 0].max():.2f}], "
|
||
f"Y[{xyz[..., 1].min():.2f}, {xyz[..., 1].max():.2f}], "
|
||
f"Z[{xyz[..., 2].min():.2f}, {xyz[..., 2].max():.2f}]")
|
||
|
||
except Exception as e:
|
||
print(f"处理文件时出错: {e}")
|
||
traceback.print_exc()
|
||
|
||
|
||
# 测试代码
|
||
def main():
|
||
"""主函数:命令行接口"""
|
||
import argparse
|
||
|
||
parser = argparse.ArgumentParser(
|
||
description='光谱到颜色空间转换工具',
|
||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||
epilog="""
|
||
使用示例:
|
||
1. 将高光谱数据转换为Lab颜色空间:
|
||
python spectral2cie2.py input.bip.hdr -c Lab -o output_lab.dat
|
||
|
||
2. 将CSV光谱数据转换为XYZ颜色空间:
|
||
python spectral2cie2.py input.csv -c XYZ -o output_xyz.csv -w wavelength
|
||
|
||
3. 使用不同光源和观察者:
|
||
python spectral2cie2.py input.hdr -c Lab -i D65 -b "10°" -o output.dat
|
||
|
||
支持的颜色空间:
|
||
XYZ: CIE XYZ色度空间
|
||
xyY: CIE xyY色度空间
|
||
Lab: CIE L*a*b*色度空间
|
||
LCH: CIE L*C*h*色度空间
|
||
|
||
支持的光源:
|
||
D65, D50, A, F2, F7, F11
|
||
|
||
支持的观察者:
|
||
2°, 10°
|
||
"""
|
||
)
|
||
|
||
parser.add_argument('--input', required=True, help='输入文件路径 (.hdr高光谱图像 或 .csv文件)')
|
||
parser.add_argument('-c', '--color_space', default='Lab',
|
||
choices=['XYZ', 'xyY', 'Lab', 'LCH'],
|
||
help='输出颜色空间 (默认: Lab)')
|
||
parser.add_argument('-o', '--output', required=True,
|
||
help='输出文件路径')
|
||
parser.add_argument('-f', '--format', default='dat', choices=['csv', 'dat'],
|
||
help='输出格式 (默认: dat)')
|
||
parser.add_argument('-i', '--illuminant', default='D65',
|
||
choices=['D65', 'D50', 'A', 'F2', 'F7', 'F11'],
|
||
help='标准光源 (默认: D65)')
|
||
parser.add_argument('-b', '--observer', default='2°',
|
||
choices=['2°', '10°'],
|
||
help='观察者 (默认: 2°)')
|
||
parser.add_argument('-w', '--wavelength_col',
|
||
help='CSV文件的波长列名 (默认: wavelength)')
|
||
parser.add_argument('--no_normalize_xyz', action='store_true',
|
||
help='不归一化XYZ值')
|
||
parser.add_argument('--use_numpy', action='store_true',
|
||
help='使用NumPy实现而不是colour库')
|
||
|
||
args = parser.parse_args()
|
||
|
||
try:
|
||
print("=" * 60)
|
||
print("光谱到颜色空间转换工具")
|
||
print("=" * 60)
|
||
print(f"输入文件: {args.input}")
|
||
print(f"颜色空间: {args.color_space}")
|
||
print(f"光源: {args.illuminant}")
|
||
print(f"观察者: {args.observer}")
|
||
print(f"输出文件: {args.output}")
|
||
print(f"输出格式: {args.format}")
|
||
print(f"使用colour库: {not args.use_numpy}")
|
||
print()
|
||
|
||
# 创建转换器
|
||
converter = SpectralToColorConverter(
|
||
illuminant=args.illuminant,
|
||
observer=args.observer,
|
||
color_space=args.color_space,
|
||
normalize_xyz=not args.no_normalize_xyz,
|
||
use_parallel=True
|
||
)
|
||
|
||
# 处理文件
|
||
color_data = converter.process_file(
|
||
args.input,
|
||
args.output,
|
||
args.format,
|
||
input_type=None,
|
||
wavelength_col=args.wavelength_col,
|
||
use_colour_library=not args.use_numpy
|
||
)
|
||
|
||
print("\n" + "=" * 60)
|
||
print("转换完成!")
|
||
print(f"数据形状: {color_data.shape}")
|
||
print(f"数据范围: [{color_data.min():.3f}, {color_data.max():.3f}]")
|
||
print("=" * 60)
|
||
|
||
except Exception as e:
|
||
print(f"✗ 处理失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return 1
|
||
|
||
return 0
|
||
|
||
|
||
if __name__ == '__main__':
|
||
main()
|