1728 lines
64 KiB
Python
1728 lines
64 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Delta E* Color Difference Calculation Tool - Complete Refactoring
|
||
|
||
Features:
|
||
- Calculate color differences (Delta E*) in CIE LAB color space
|
||
- Support multiple color difference calculation methods: CIE76, CIE94, CIEDE2000
|
||
- Input: LAB three-band image + standard color CSV file
|
||
- Output: Color difference results as multi-band images or CSV files
|
||
- Support arbitrary range selection for standard colors and target colors
|
||
- Optional heatmap visualization output
|
||
|
||
依赖库:
|
||
- numpy, pandas: 数据处理
|
||
- colour-science: Color difference calculation
|
||
- matplotlib: 可视化
|
||
- spectral: ENVI文件处理
|
||
- tqdm: 进度条(可选)
|
||
|
||
安装依赖:
|
||
pip install numpy pandas colour-science matplotlib spectral tqdm
|
||
"""
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import os
|
||
import sys
|
||
import warnings
|
||
import argparse
|
||
from pathlib import Path
|
||
from typing import Optional, Tuple, List, Union, Dict, Any
|
||
from dataclasses import dataclass, field
|
||
from enum import Enum
|
||
|
||
@dataclass
|
||
class ColorDifferenceConfig:
|
||
"""色差计算配置类"""
|
||
lab_image_path: Optional[str] = None
|
||
standard_colors_path: Optional[str] = None
|
||
method: str = 'CIEDE2000'
|
||
output_dir: Optional[str] = None
|
||
|
||
def __post_init__(self):
|
||
"""参数校验和默认值设置"""
|
||
if not self.lab_image_path:
|
||
raise ValueError("必须指定LAB图像文件路径(lab_image_path)")
|
||
|
||
if not self.standard_colors_path:
|
||
raise ValueError("必须指定标准颜色文件路径(standard_colors_path)")
|
||
|
||
if self.method not in ['CIE76', 'CIE94', 'CIEDE2000']:
|
||
raise ValueError(f"不支持的计算方法: {self.method}")
|
||
|
||
if not self.output_dir:
|
||
self.output_dir = './results'
|
||
|
||
|
||
# 尝试导入可选依赖
|
||
try:
|
||
from colour import delta_E
|
||
COLOUR_AVAILABLE = True
|
||
except ImportError:
|
||
COLOUR_AVAILABLE = False
|
||
print("警告: colour-science库不可用,请安装: pip install colour-science")
|
||
|
||
try:
|
||
import spectral
|
||
SPECTRAL_AVAILABLE = True
|
||
except ImportError:
|
||
SPECTRAL_AVAILABLE = False
|
||
print("警告: spectral库不可用,将无法处理ENVI格式文件")
|
||
|
||
try:
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib.colors as mcolors
|
||
from matplotlib import cm
|
||
MATPLOTLIB_AVAILABLE = True
|
||
except ImportError:
|
||
MATPLOTLIB_AVAILABLE = False
|
||
print("Warning: matplotlib not available, heatmap generation will be disabled")
|
||
|
||
try:
|
||
from tqdm import tqdm
|
||
TQDM_AVAILABLE = True
|
||
except ImportError:
|
||
TQDM_AVAILABLE = False
|
||
print("提示: tqdm库不可用,进度条将不可用。安装: pip install tqdm")
|
||
|
||
|
||
class DeltaEMethod(Enum):
|
||
"""Color difference calculation method enumeration"""
|
||
CIE76 = "CIE76"
|
||
CIE94 = "CIE94"
|
||
CIEDE2000 = "CIEDE2000"
|
||
|
||
@staticmethod
|
||
def to_colour_method(method: str) -> str:
|
||
"""将内部方法标识符映射到colour库的标准方法字符串"""
|
||
mapping = {
|
||
"CIE76": "CIE 1976",
|
||
"CIE94": "CIE 1994",
|
||
"CIEDE2000": "CIE 2000"
|
||
}
|
||
return mapping.get(method, "CIE 2000")
|
||
|
||
|
||
@dataclass
|
||
class ColorData:
|
||
"""颜色数据容器"""
|
||
L: np.ndarray
|
||
a: np.ndarray
|
||
b: np.ndarray
|
||
categories: Optional[List[str]] = None
|
||
metadata: Optional[Dict[str, Any]] = None
|
||
|
||
@property
|
||
def lab_array(self) -> np.ndarray:
|
||
"""返回LAB数组 (..., 3)"""
|
||
return np.stack([self.L, self.a, self.b], axis=-1)
|
||
|
||
@property
|
||
def shape(self) -> tuple:
|
||
"""返回数据形状"""
|
||
return self.L.shape
|
||
|
||
@classmethod
|
||
def from_array(cls, array: np.ndarray, categories: Optional[List[str]] = None) -> 'ColorData':
|
||
"""从数组创建ColorData对象"""
|
||
if array.shape[-1] != 3:
|
||
raise ValueError(f"数组的最后一个维度必须为3 (LAB), 但得到 {array.shape}")
|
||
|
||
if array.ndim == 1:
|
||
array = array.reshape(1, 3)
|
||
|
||
return cls(
|
||
L=array[..., 0],
|
||
a=array[..., 1],
|
||
b=array[..., 2],
|
||
categories=categories
|
||
)
|
||
|
||
|
||
class LabDataValidator:
|
||
"""LAB数据验证器"""
|
||
|
||
@staticmethod
|
||
def validate_lab_range(lab_data: Union[np.ndarray, ColorData],
|
||
data_name: str = "LAB数据",
|
||
warn_only: bool = True) -> bool:
|
||
"""
|
||
验证LAB数据是否在合理范围内
|
||
|
||
参数:
|
||
lab_data: LAB数据数组或ColorData对象
|
||
data_name: 数据名称(用于错误信息)
|
||
warn_only: 是否仅警告而不抛出异常
|
||
|
||
返回:
|
||
bool: 数据是否有效
|
||
"""
|
||
if isinstance(lab_data, ColorData):
|
||
L, a, b = lab_data.L, lab_data.a, lab_data.b
|
||
else:
|
||
L, a, b = lab_data[..., 0], lab_data[..., 1], lab_data[..., 2]
|
||
|
||
L_min, L_max = np.nanmin(L), np.nanmax(L)
|
||
a_min, a_max = np.nanmin(a), np.nanmax(a)
|
||
b_min, b_max = np.nanmin(b), np.nanmax(b)
|
||
|
||
issues = []
|
||
|
||
# 检查L*范围(理论[0, 100],允许微小误差)
|
||
if L_max > 100.5 or L_min < -0.5:
|
||
issues.append(f"L*值超出理论范围[0, 100],实际为[{L_min:.2f}, {L_max:.2f}]")
|
||
|
||
# 检查a*, b*范围(常见[-128, 127])
|
||
if abs(a_max) > 150 or abs(a_min) > 150:
|
||
issues.append(f"a*值超出常见范围[-128, 127],实际为[{a_min:.2f}, {a_max:.2f}]")
|
||
|
||
if abs(b_max) > 150 or abs(b_min) > 150:
|
||
issues.append(f"b*值超出常见范围[-128, 127],实际为[{b_min:.2f}, {b_max:.2f}]")
|
||
|
||
# 检查NaN值
|
||
nan_count = np.sum(np.isnan(L)) + np.sum(np.isnan(a)) + np.sum(np.isnan(b))
|
||
if nan_count > 0:
|
||
issues.append(f"发现{nan_count}个NaN值")
|
||
|
||
if issues:
|
||
message = f"{data_name}验证问题: " + "; ".join(issues)
|
||
if warn_only:
|
||
warnings.warn(message)
|
||
return True # 即使有问题也继续处理
|
||
else:
|
||
raise ValueError(message)
|
||
|
||
return True
|
||
|
||
@staticmethod
|
||
def validate_data_shape(data1: np.ndarray, data2: np.ndarray, operation: str) -> bool:
|
||
"""验证两个数据数组的形状是否兼容"""
|
||
# For color difference calculation, the last dimension must be 3
|
||
if data1.shape[-1] != 3 or data2.shape[-1] != 3:
|
||
raise ValueError(f"{operation}: 两个数组的最后一个维度都必须为3 (LAB)")
|
||
|
||
# 检查广播兼容性
|
||
try:
|
||
np.broadcast_shapes(data1.shape[:-1], data2.shape[:-1])
|
||
return True
|
||
except ValueError as e:
|
||
raise ValueError(f"{operation}: 数组形状不兼容 - {str(e)}")
|
||
|
||
|
||
class ColorStandardManager:
|
||
"""Standard color manager"""
|
||
|
||
def __init__(self):
|
||
self.standards = pd.DataFrame()
|
||
self._color_cache = {}
|
||
|
||
def load_from_csv(self, csv_path: Union[str, Path]) -> pd.DataFrame:
|
||
"""
|
||
Load standard colors from CSV file
|
||
|
||
Parameters:
|
||
csv_path: CSV file path
|
||
|
||
Returns:
|
||
Standard colors DataFrame
|
||
"""
|
||
print(f"Loading standard colors file: {csv_path}")
|
||
|
||
try:
|
||
df = pd.read_csv(csv_path)
|
||
except Exception as e:
|
||
raise IOError(f"无法读取CSV文件 {csv_path}: {str(e)}")
|
||
|
||
# 验证数据列
|
||
if len(df.columns) < 4:
|
||
raise ValueError(f"CSV文件必须至少有4列(类别, L, a, b),但只有 {len(df.columns)} 列")
|
||
|
||
# 识别列
|
||
if 'category' not in df.columns:
|
||
# 如果没有category列,设置默认列名
|
||
df.columns = ['category', 'L', 'a', 'b'] + list(df.columns[4:])
|
||
|
||
# 处理可能的列名变体
|
||
column_mapping = {
|
||
'l': 'L',
|
||
'l*': 'L',
|
||
'L*': 'L',
|
||
'a*': 'a',
|
||
'b*': 'b'
|
||
}
|
||
|
||
for old_col, new_col in column_mapping.items():
|
||
if old_col in df.columns and new_col not in df.columns:
|
||
df.rename(columns={old_col: new_col}, inplace=True)
|
||
print(f"将列 '{old_col}' 重命名为 '{new_col}'")
|
||
|
||
# 确保必要的列存在
|
||
required_cols = ['category', 'L', 'a', 'b']
|
||
for col in required_cols:
|
||
if col not in df.columns:
|
||
raise ValueError(f"CSV文件缺少必需的列: {col}. 可用列: {list(df.columns)}")
|
||
|
||
# 验证数据类型
|
||
for col in ['L', 'a', 'b']:
|
||
if not pd.api.types.is_numeric_dtype(df[col]):
|
||
try:
|
||
df[col] = pd.to_numeric(df[col], errors='coerce')
|
||
except:
|
||
raise ValueError(f"列 {col} 包含非数值数据")
|
||
|
||
# 删除NaN值
|
||
initial_count = len(df)
|
||
df = df.dropna(subset=['L', 'a', 'b'])
|
||
if len(df) < initial_count:
|
||
print(f"警告: 删除了 {initial_count - len(df)} 行包含NaN值的记录")
|
||
|
||
self.standards = df.reset_index(drop=True)
|
||
|
||
print(f"Successfully loaded {len(self.standards)} standard colors")
|
||
print(f"Standard color categories: {', '.join(df['category'].unique().astype(str))}")
|
||
|
||
# Preview data
|
||
if len(self.standards) > 0:
|
||
print("Standard colors preview:")
|
||
print(self.standards.head().to_string())
|
||
|
||
return self.standards
|
||
|
||
def get_standard_colors(self, indices: Optional[List[int]] = None) -> ColorData:
|
||
"""
|
||
获取指定索引的标准色
|
||
|
||
参数:
|
||
indices: 标准色索引列表,None表示全部
|
||
|
||
返回:
|
||
ColorData对象
|
||
"""
|
||
if self.standards.empty:
|
||
raise ValueError("尚未加载标准色数据")
|
||
|
||
if indices is None:
|
||
df = self.standards
|
||
else:
|
||
df = self.standards.iloc[indices]
|
||
|
||
categories = df['category'].tolist()
|
||
lab_array = df[['L', 'a', 'b']].values
|
||
|
||
return ColorData.from_array(lab_array, categories)
|
||
|
||
def get_standard_by_category(self, categories: List[str]) -> ColorData:
|
||
"""
|
||
按类别获取标准色
|
||
|
||
参数:
|
||
categories: 类别名称列表
|
||
|
||
返回:
|
||
ColorData对象
|
||
"""
|
||
if self.standards.empty:
|
||
raise ValueError("尚未加载标准色数据")
|
||
|
||
mask = self.standards['category'].isin(categories)
|
||
df = self.standards[mask]
|
||
|
||
if len(df) == 0:
|
||
raise ValueError(f"未找到类别为 {categories} 的标准色")
|
||
|
||
categories_list = df['category'].tolist()
|
||
lab_array = df[['L', 'a', 'b']].values
|
||
|
||
return ColorData.from_array(lab_array, categories_list)
|
||
|
||
def save_to_csv(self, output_path: Union[str, Path]) -> None:
|
||
"""保存标准色到CSV文件"""
|
||
if self.standards.empty:
|
||
print("警告: 没有标准色数据可保存")
|
||
return
|
||
|
||
self.standards.to_csv(output_path, index=False)
|
||
print(f"标准色已保存到: {output_path}")
|
||
|
||
|
||
class DeltaECalculator:
|
||
"""
|
||
Delta E* Color Difference Calculator
|
||
|
||
Supports calculating color differences in CIE LAB color space
|
||
"""
|
||
|
||
def __init__(self, method: Union[str, DeltaEMethod] = DeltaEMethod.CIEDE2000):
|
||
"""
|
||
Initialize color difference calculator
|
||
|
||
Parameters:
|
||
method: Color difference calculation method ('CIE76', 'CIE94', 'CIEDE2000')
|
||
"""
|
||
if not COLOUR_AVAILABLE:
|
||
raise ImportError("需要安装colour-science库: pip install colour-science")
|
||
|
||
# 处理方法参数
|
||
if isinstance(method, DeltaEMethod):
|
||
self.method = method
|
||
else:
|
||
try:
|
||
self.method = DeltaEMethod(method.upper())
|
||
except ValueError:
|
||
raise ValueError(f"不支持的色差计算方法: {method}。可选: {[e.value for e in DeltaEMethod]}")
|
||
|
||
print(f"初始化Delta E计算器 - 方法: {self.method.value}")
|
||
|
||
def calculate(self, lab1: Union[np.ndarray, ColorData],
|
||
lab2: Union[np.ndarray, ColorData]) -> np.ndarray:
|
||
"""
|
||
计算两个LAB颜色或颜色集之间的Delta E
|
||
|
||
参数:
|
||
lab1: 第一个LAB颜色或颜色集
|
||
lab2: 第二个LAB颜色或颜色集
|
||
|
||
返回:
|
||
delta_e: Delta E值或值数组
|
||
"""
|
||
# 提取数组
|
||
if isinstance(lab1, ColorData):
|
||
arr1 = lab1.lab_array
|
||
else:
|
||
arr1 = lab1
|
||
|
||
if isinstance(lab2, ColorData):
|
||
arr2 = lab2.lab_array
|
||
else:
|
||
arr2 = lab2
|
||
|
||
# 验证数据
|
||
LabDataValidator.validate_data_shape(arr1, arr2, "色差计算")
|
||
|
||
# 转换为colour库的方法名称
|
||
colour_method = DeltaEMethod.to_colour_method(self.method.value)
|
||
|
||
try:
|
||
# 使用colour库的统一API
|
||
result = delta_E(arr1, arr2, method=colour_method)
|
||
return result
|
||
except Exception as e:
|
||
raise RuntimeError(f"色差计算失败(方法: {self.method.value}): {str(e)}")
|
||
|
||
def calculate_image_vs_standards(self, image_data: ColorData,
|
||
standards: ColorData,
|
||
use_progress_bar: bool = True) -> np.ndarray:
|
||
"""
|
||
计算图像中每个像素与每个标准色的Delta E
|
||
|
||
参数:
|
||
image_data: 图像颜色数据
|
||
standards: 标准色数据
|
||
use_progress_bar: 是否显示进度条
|
||
|
||
返回:
|
||
delta_e_image: Delta E图像 (height, width, n_standards)
|
||
"""
|
||
print(f"开始计算图像色差...")
|
||
print(f" 图像大小: {image_data.shape}")
|
||
print(f" 标准色数量: {len(standards.categories) if standards.categories else standards.shape[0]}")
|
||
|
||
# 重塑图像为 (pixels, 3)
|
||
height, width = image_data.shape[:2]
|
||
lab_pixels = image_data.lab_array.reshape(-1, 3)
|
||
n_pixels = lab_pixels.shape[0]
|
||
|
||
# 获取标准色数组
|
||
std_array = standards.lab_array
|
||
n_standards = std_array.shape[0]
|
||
|
||
# 准备结果数组
|
||
delta_e_result = np.zeros((n_pixels, n_standards), dtype=np.float32)
|
||
|
||
# 计算进度条
|
||
if use_progress_bar and TQDM_AVAILABLE:
|
||
iterator = tqdm(range(n_standards), desc="计算标准色", unit="色")
|
||
else:
|
||
iterator = range(n_standards)
|
||
print(f"处理 0/{n_standards} 个标准色", end="")
|
||
|
||
# 对每个标准色计算Delta E
|
||
for i in iterator:
|
||
# 广播计算所有像素与这个标准色的Delta E
|
||
delta_e_result[:, i] = self.calculate(lab_pixels, std_array[i])
|
||
|
||
if not TQDM_AVAILABLE and not use_progress_bar:
|
||
if (i + 1) % max(1, n_standards // 10) == 0 or i + 1 == n_standards:
|
||
print(f"\r处理 {i + 1}/{n_standards} 个标准色", end="")
|
||
|
||
if not TQDM_AVAILABLE and not use_progress_bar:
|
||
print() # 换行
|
||
|
||
# 重塑回图像格式 (height, width, n_standards)
|
||
delta_e_image = delta_e_result.reshape(height, width, n_standards)
|
||
|
||
# 输出统计信息
|
||
print(f"色差计算完成!")
|
||
print(f" ΔE范围: [{delta_e_result.min():.3f}, {delta_e_result.max():.3f}]")
|
||
print(f" ΔE均值: {delta_e_result.mean():.3f} ± {delta_e_result.std():.3f}")
|
||
|
||
# 添加色差解读
|
||
self._interpret_delta_e(delta_e_result.mean())
|
||
|
||
return delta_e_image
|
||
|
||
def calculate_pairwise(self, colors1: ColorData,
|
||
colors2: Optional[ColorData] = None) -> pd.DataFrame:
|
||
"""
|
||
计算两组颜色之间的两两Delta E
|
||
|
||
参数:
|
||
colors1: 第一组颜色数据
|
||
colors2: 第二组颜色数据,如果为None则计算colors1内部的两两比较
|
||
|
||
返回:
|
||
包含Delta E结果的DataFrame
|
||
"""
|
||
if colors2 is None:
|
||
colors2 = colors1
|
||
print(f"计算内部两两Delta E - 颜色数量: {colors1.shape[0]}")
|
||
else:
|
||
print(f"计算两组颜色间Delta E - 组1: {colors1.shape[0]}, 组2: {colors2.shape[0]}")
|
||
|
||
results = []
|
||
|
||
# 获取类别信息
|
||
cats1 = colors1.categories if colors1.categories else [f"Color_{i}" for i in range(colors1.shape[0])]
|
||
cats2 = colors2.categories if colors2.categories else [f"Color_{i}" for i in range(colors2.shape[0])]
|
||
|
||
# 获取LAB数组
|
||
lab1 = colors1.lab_array
|
||
lab2 = colors2.lab_array
|
||
|
||
# 计算所有组合
|
||
for i in range(lab1.shape[0]):
|
||
for j in range(lab2.shape[0]):
|
||
# 如果是内部比较且i>=j,跳过以避免重复(对角线保留)
|
||
if colors2 is colors1 and i > j:
|
||
continue
|
||
|
||
delta_e_value = self.calculate(lab1[i], lab2[j])
|
||
|
||
results.append({
|
||
'reference_index': i,
|
||
'reference_category': cats1[i],
|
||
'reference_L': lab1[i, 0],
|
||
'reference_a': lab1[i, 1],
|
||
'reference_b': lab1[i, 2],
|
||
'target_index': j,
|
||
'target_category': cats2[j],
|
||
'target_L': lab2[j, 0],
|
||
'target_a': lab2[j, 1],
|
||
'target_b': lab2[j, 2],
|
||
'delta_E': delta_e_value,
|
||
'method': self.method.value
|
||
})
|
||
|
||
result_df = pd.DataFrame(results)
|
||
print(f"两两Delta E计算完成,共 {len(result_df)} 对颜色组合")
|
||
|
||
return result_df
|
||
|
||
def _interpret_delta_e(self, mean_delta_e: float) -> None:
|
||
"""根据ΔE平均值提供解读"""
|
||
print(f" ΔE结果解读 ({self.method.value}):")
|
||
|
||
if self.method == DeltaEMethod.CIEDE2000:
|
||
if mean_delta_e < 0.5:
|
||
print(" ΔE<0.5 - 差异极微小,通常不可感知")
|
||
elif mean_delta_e < 1.0:
|
||
print(" ΔE<1.0 - 微小差异,经验丰富的观察者可能感知")
|
||
elif mean_delta_e < 2.0:
|
||
print(" ΔE<2.0 - 可感知的微小差异")
|
||
elif mean_delta_e < 3.5:
|
||
print(" ΔE<3.5 - 中等差异,明显可感知")
|
||
elif mean_delta_e < 5.0:
|
||
print(" ΔE<5.0 - 较大差异")
|
||
else:
|
||
print(" ΔE≥5.0 - 非常大差异,可能被认为是不同颜色")
|
||
else:
|
||
# 对于CIE76和CIE94的通用解读
|
||
if mean_delta_e < 1.0:
|
||
print(" ΔE<1.0 - 差异通常不可感知")
|
||
elif mean_delta_e < 2.3:
|
||
print(" ΔE<2.3 - 微小差异(近距离观察可感知)")
|
||
elif mean_delta_e < 4.0:
|
||
print(" ΔE<4.0 - 可感知的差异")
|
||
elif mean_delta_e < 6.0:
|
||
print(" ΔE<6.0 - 明显差异")
|
||
else:
|
||
print(" ΔE≥6.0 - 非常大差异")
|
||
|
||
|
||
class EnviFileHandler:
|
||
"""ENVI文件处理器"""
|
||
|
||
def __init__(self):
|
||
if not SPECTRAL_AVAILABLE:
|
||
raise ImportError("需要安装spectral库才能处理ENVI格式文件")
|
||
|
||
def load_image(self, image_path: Union[str, Path]) -> Tuple[ColorData, Dict[str, Any]]:
|
||
"""
|
||
使用spectral库加载ENVI格式的高光谱/LAB图像
|
||
|
||
参数:
|
||
image_path: 图像文件路径
|
||
支持格式: .hdr (头文件), .dat/.bip/.bsq/.bil (数据文件)
|
||
spectral库会自动找到对应的文件
|
||
|
||
返回:
|
||
image_data: ColorData对象,包含L, a, b数据
|
||
metadata: ENVI文件的元数据字典
|
||
"""
|
||
image_path = Path(image_path)
|
||
print(f"正在加载ENVI图像: {image_path}")
|
||
|
||
try:
|
||
# 使用spectral库打开图像,自动处理各种格式
|
||
# spectral.open_image 可以接受:
|
||
# - 数据文件路径 (.dat, .bip, .bsq, .bil)
|
||
# - HDR文件路径 (.hdr)
|
||
# - 甚至basename(会自动查找对应的文件)
|
||
img = spectral.open_image(str(image_path))
|
||
|
||
print(f"图像形状: {img.shape}")
|
||
print(f"数据类型: {img.dtype}")
|
||
print(f"波段数: {img.shape[2] if len(img.shape) > 2 else 1}")
|
||
|
||
except Exception as e:
|
||
# 提供详细的错误信息和建议
|
||
error_msg = f"无法打开ENVI图像 {image_path}: {str(e)}"
|
||
|
||
# 检查文件是否存在
|
||
if not image_path.exists():
|
||
error_msg += f"\n文件不存在: {image_path}"
|
||
|
||
# 建议可能的替代文件
|
||
parent_dir = image_path.parent
|
||
if parent_dir.exists():
|
||
# 查找可能的ENVI文件
|
||
envi_files = []
|
||
for ext in ['*.hdr', '*.dat', '*.bip', '*.bsq', '*.bil']:
|
||
envi_files.extend(list(parent_dir.glob(ext)))
|
||
|
||
if envi_files:
|
||
error_msg += f"\n\n目录中的ENVI文件:"
|
||
for f in envi_files[:10]: # 最多显示10个
|
||
error_msg += f"\n {f.name}"
|
||
else:
|
||
error_msg += f"\n\n目录中没有找到ENVI文件"
|
||
|
||
error_msg += f"\n\n使用建议:"
|
||
error_msg += f"\n1. 确保文件路径正确"
|
||
error_msg += f"\n2. 尝试使用数据文件 (.bip, .dat) 而不是头文件 (.hdr)"
|
||
error_msg += f"\n3. 或使用不带扩展名的basename,让spectral自动查找"
|
||
|
||
raise IOError(error_msg)
|
||
|
||
# 检查波段数
|
||
if img.shape[2] != 3:
|
||
raise ValueError(f"LAB图像必须有3个波段,当前波段数: {img.shape[2]}")
|
||
|
||
# 加载数据
|
||
image_array = img.load().astype(np.float32)
|
||
|
||
print(f"图像形状: {image_array.shape}")
|
||
print(f"数据类型: {image_array.dtype}")
|
||
|
||
# 创建ColorData对象
|
||
image_data = ColorData(
|
||
L=image_array[..., 0],
|
||
a=image_array[..., 1],
|
||
b=image_array[..., 2]
|
||
)
|
||
|
||
# 检查数据范围
|
||
self._print_lab_range(image_data, "图像")
|
||
|
||
# 提取元数据
|
||
metadata = {}
|
||
if hasattr(img, 'metadata') and img.metadata:
|
||
metadata = dict(img.metadata)
|
||
|
||
return image_data, metadata
|
||
|
||
def save_color_data(self, color_data: np.ndarray, output_path: Union[str, Path], output_format: str = 'csv',
|
||
input_hdr_path: Optional[Union[str, Path]] = None, method: str = 'DeltaE'):
|
||
"""
|
||
保存颜色数据到文件
|
||
|
||
参数:
|
||
color_data: 颜色数据数组
|
||
output_path: 输出文件路径
|
||
output_format: 输出格式 ('csv' 或 'dat')
|
||
input_hdr_path: 输入ENVI文件的HDR路径,用于复制元数据
|
||
method: 方法名称,用于波段命名
|
||
"""
|
||
output_path = Path(output_path)
|
||
|
||
print(f"保存{method}数据到: {output_path}, 格式: {output_format}")
|
||
|
||
# 根据颜色空间确定列名(这里是Delta E,所以使用通用命名)
|
||
if len(color_data.shape) >= 3 and color_data.shape[-1] <= 10:
|
||
# 如果是多波段数据,假设是Delta E结果
|
||
n_bands = color_data.shape[-1] if len(color_data.shape) == 3 else 1
|
||
column_names = [f'DeltaE_{i+1}' for i in range(n_bands)]
|
||
else:
|
||
column_names = [f'Band_{i+1}' for i in range(color_data.shape[-1])]
|
||
|
||
if output_format == 'csv':
|
||
# 保存为CSV格式
|
||
if len(color_data.shape) == 3: # 图像数据 (height, width, channels)
|
||
# 为图像数据创建CSV,每行是一个像素的颜色值
|
||
color_flat = color_data.reshape(-1, color_data.shape[-1])
|
||
df = pd.DataFrame(color_flat, columns=column_names)
|
||
df.to_csv(output_path, index=False)
|
||
print(f"保存为CSV: {color_flat.shape[0]} 行数据")
|
||
else: # 一维数据 (samples, channels)
|
||
df = pd.DataFrame(color_data, columns=column_names)
|
||
df.to_csv(output_path, index=False)
|
||
print(f"保存为CSV: {color_data.shape[0]} 行数据")
|
||
|
||
elif output_format == 'dat':
|
||
# 保存为ENVI格式的.dat文件,参考classification.py的方式
|
||
try:
|
||
from osgeo import gdal
|
||
GDAL_AVAILABLE = True
|
||
except ImportError:
|
||
GDAL_AVAILABLE = False
|
||
|
||
# 为ENVI格式创建必要的元数据
|
||
if len(color_data.shape) == 3: # 图像数据 (height, width, channels)
|
||
height, width, channels = color_data.shape
|
||
|
||
# 确保目录存在
|
||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||
|
||
print(f"保存ENVI文件 - {method}形状: {color_data.shape}")
|
||
|
||
if GDAL_AVAILABLE and channels <= 4: # 支持最多4个通道
|
||
# 使用GDAL保存,参考classification.py的方式
|
||
driver = gdal.GetDriverByName('ENVI')
|
||
|
||
# 创建多波段ENVI文件
|
||
out_dataset = driver.Create(
|
||
str(output_path), width, height, channels,
|
||
gdal.GDT_Float32, options=['INTERLEAVE=BIL']
|
||
)
|
||
|
||
if out_dataset is None:
|
||
raise ValueError(f"无法创建输出文件: {output_path}")
|
||
|
||
# 写入每个波段的数据
|
||
for band_idx in range(channels):
|
||
band_data = color_data[:, :, band_idx] # (height, width)
|
||
out_band = out_dataset.GetRasterBand(band_idx + 1)
|
||
out_band.WriteArray(band_data.astype(np.float32))
|
||
|
||
# 设置波段名称
|
||
out_band.SetDescription(column_names[band_idx])
|
||
|
||
# 设置地理信息 (如果有的话,这里没有地理信息)
|
||
# out_dataset.SetGeoTransform(geotransform)
|
||
# out_dataset.SetProjection(projection)
|
||
|
||
out_dataset.FlushCache()
|
||
out_dataset = None
|
||
|
||
# 创建HDR文件,保持输入文件的元数据
|
||
self._create_envi_hdr_file(output_path, height, width, channels, 'float32', input_hdr_path, method)
|
||
|
||
print(f"使用GDAL保存为ENVI文件: {output_path}")
|
||
print(f"数据类型: float32, 波段数: {channels}, 大小: {width}x{height}")
|
||
|
||
else:
|
||
# GDAL不可用时,直接保存为二进制文件
|
||
print("GDAL不可用,使用numpy直接保存二进制数据")
|
||
|
||
# 确保目录存在
|
||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||
|
||
# 重新排列数据为BIL格式 (channels, height, width) 并保存
|
||
color_bil = np.transpose(color_data, (2, 0, 1)) # (channels, height, width)
|
||
color_bil.astype(np.float32).tofile(str(output_path))
|
||
|
||
# 创建HDR文件
|
||
self._create_envi_hdr_file(output_path, height, width, channels, 'float32', input_hdr_path, method)
|
||
|
||
print(f"使用numpy保存为ENVI格式二进制文件: {output_path}")
|
||
print(f"数据类型: float32, 波段数: {channels}, 大小: {width}x{height}")
|
||
|
||
else: # 一维数据 (samples, channels)
|
||
# 一维数据保存为CSV
|
||
csv_path = output_path.with_suffix('.csv')
|
||
np.savetxt(csv_path, color_data, delimiter=',',
|
||
header=','.join(column_names), comments='')
|
||
print(f"一维数据保存为CSV: {csv_path}")
|
||
|
||
else:
|
||
raise ValueError(f"不支持的输出格式: {output_format}")
|
||
|
||
print(f"结果已保存到: {output_path}")
|
||
|
||
def _create_envi_hdr_file(self, bil_path: Union[str, Path], height: int, width: int,
|
||
bands: int, data_type: str, input_hdr_path: Optional[Union[str, Path]] = None,
|
||
method: str = 'DeltaE') -> None:
|
||
"""
|
||
创建ENVI头文件,参考classification.py的方式,并保持输入文件的元数据
|
||
|
||
Args:
|
||
bil_path: BIL文件路径
|
||
height: 图像高度
|
||
width: 图像宽度
|
||
bands: 波段数
|
||
data_type: 数据类型 ('float32', 'uint8', 'int16', 等)
|
||
input_hdr_path: 输入ENVI文件的HDR路径,用于复制元数据
|
||
method: 方法名称,用于描述
|
||
"""
|
||
hdr_path = Path(bil_path).with_suffix('.hdr')
|
||
|
||
# 数据类型映射 (ENVI格式)
|
||
dtype_map = {
|
||
'uint8': '1',
|
||
'int16': '2',
|
||
'int32': '3',
|
||
'float32': '4',
|
||
'float64': '5',
|
||
'complex64': '6',
|
||
'complex128': '9',
|
||
'uint16': '12',
|
||
'uint32': '13',
|
||
'int64': '14',
|
||
'uint64': '15'
|
||
}
|
||
|
||
envi_dtype = dtype_map.get(data_type, '4') # 默认为float32
|
||
|
||
# 为Delta E数据确定波段名称
|
||
band_names = [f"DeltaE_Band_{i+1}" for i in range(bands)]
|
||
|
||
with open(hdr_path, 'w') as f:
|
||
f.write("ENVI\n")
|
||
f.write("description = {\n")
|
||
f.write(f" Delta E Color Difference Data - Generated by DeltaECalculator\n")
|
||
f.write(f" Method: {method}\n")
|
||
f.write("}\n")
|
||
f.write(f"samples = {width}\n")
|
||
f.write(f"lines = {height}\n")
|
||
f.write(f"bands = {bands}\n")
|
||
f.write("header offset = 0\n")
|
||
f.write("file type = ENVI Standard\n")
|
||
f.write(f"data type = {envi_dtype}\n")
|
||
f.write("interleave = bil\n")
|
||
f.write("sensor type = Unknown\n")
|
||
f.write("byte order = 0\n")
|
||
|
||
# 波段名称
|
||
f.write("band names = {\n")
|
||
for i, name in enumerate(band_names):
|
||
f.write(f' "{name}"')
|
||
if i < len(band_names) - 1:
|
||
f.write(",")
|
||
f.write("\n")
|
||
f.write("}\n")
|
||
|
||
# 如果有输入HDR文件,尝试复制相关的元数据
|
||
if input_hdr_path and Path(input_hdr_path).exists():
|
||
try:
|
||
self._copy_hdr_metadata(input_hdr_path, f)
|
||
except Exception as e:
|
||
print(f"复制HDR元数据失败: {e}")
|
||
|
||
print(f"ENVI头文件创建完成: {hdr_path}")
|
||
|
||
def _copy_hdr_metadata(self, input_hdr_path: Union[str, Path], output_file) -> None:
|
||
"""
|
||
从输入HDR文件复制元数据到输出HDR文件
|
||
|
||
Args:
|
||
input_hdr_path: 输入HDR文件路径
|
||
output_file: 输出文件对象
|
||
"""
|
||
try:
|
||
with open(input_hdr_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||
content = f.read()
|
||
|
||
# 解析输入HDR文件,提取有用的元数据
|
||
lines = content.split('\n')
|
||
metadata_to_copy = []
|
||
|
||
# 需要复制的元数据字段
|
||
fields_to_copy = [
|
||
'wavelength units', 'wavelength', 'fwhm', 'bbl', 'map info',
|
||
'coordinate system string', 'projection info', 'pixel size',
|
||
'acquisition time', 'sensor type', 'radiance scale factor',
|
||
'reflectance scale factor', 'data gain values', 'data offset values'
|
||
]
|
||
|
||
i = 0
|
||
while i < len(lines):
|
||
line = lines[i].strip()
|
||
if '=' in line:
|
||
key = line.split('=')[0].strip().lower()
|
||
if any(field in key for field in fields_to_copy):
|
||
# 复制这个字段及其可能的后续行
|
||
metadata_to_copy.append(lines[i])
|
||
i += 1
|
||
# 如果是多行值,继续读取
|
||
while i < len(lines) and not ('=' in lines[i] and not lines[i].strip().endswith(',')):
|
||
if lines[i].strip():
|
||
metadata_to_copy.append(lines[i])
|
||
i += 1
|
||
if i >= len(lines):
|
||
break
|
||
continue
|
||
i += 1
|
||
|
||
# 将复制的元数据写入输出文件
|
||
if metadata_to_copy:
|
||
output_file.write("\n")
|
||
for line in metadata_to_copy:
|
||
if line.strip():
|
||
output_file.write(line + "\n")
|
||
|
||
print(f"已从输入HDR文件复制 {len(metadata_to_copy)} 行元数据")
|
||
|
||
except Exception as e:
|
||
print(f"读取输入HDR文件失败: {e}")
|
||
|
||
def save_delta_e_image(self, delta_e_image: np.ndarray,
|
||
output_path: Union[str, Path],
|
||
standard_categories: List[str],
|
||
method: str,
|
||
metadata: Optional[Dict[str, Any]] = None) -> None:
|
||
"""
|
||
保存Delta E结果为ENVI格式(兼容性方法)
|
||
|
||
参数:
|
||
delta_e_image: Delta E图像 (height, width, n_standards)
|
||
output_path: 输出路径
|
||
standard_categories: 标准色类别列表
|
||
method: 色差计算方法
|
||
metadata: 可选的元数据
|
||
"""
|
||
# 根据标准色类别确定列名
|
||
column_names = []
|
||
for i, cat in enumerate(standard_categories):
|
||
column_names.append(f'DeltaE_{cat}_{method}')
|
||
|
||
# 转换数据格式以匹配save_color_data的期望
|
||
# delta_e_image 是 (height, width, n_standards),需要转换为 (height, width, channels)
|
||
color_data = delta_e_image
|
||
|
||
# 调用统一的保存方法
|
||
self.save_color_data(color_data, output_path, 'dat', None, method)
|
||
|
||
def _print_lab_range(self, color_data: ColorData, data_name: str) -> None:
|
||
"""打印LAB值范围"""
|
||
L_min, L_max = color_data.L.min(), color_data.L.max()
|
||
a_min, a_max = color_data.a.min(), color_data.a.max()
|
||
b_min, b_max = color_data.b.min(), color_data.b.max()
|
||
|
||
print(f"{data_name}LAB值范围:")
|
||
print(f" L: [{L_min:.2f}, {L_max:.2f}]")
|
||
print(f" a: [{a_min:.2f}, {a_max:.2f}]")
|
||
print(f" b: [{b_min:.2f}, {b_max:.2f}]")
|
||
|
||
|
||
class DeltaEVisualizer:
|
||
"""Delta E可视化器"""
|
||
|
||
def __init__(self, colormap: str = 'RdYlBu_r'):
|
||
if not MATPLOTLIB_AVAILABLE:
|
||
raise ImportError("需要安装matplotlib库才能生成可视化")
|
||
|
||
self.colormap = colormap
|
||
plt.style.use('default')
|
||
|
||
def create_heatmap(self, delta_e_image: np.ndarray,
|
||
standard_categories: List[str],
|
||
output_path: Union[str, Path],
|
||
title: str = "Delta E Heatmap",
|
||
dpi: int = 150) -> None:
|
||
"""
|
||
Create Delta E heatmap
|
||
|
||
Parameters:
|
||
delta_e_image: Delta E图像 (height, width, n_standards)
|
||
standard_categories: 标准色类别列表
|
||
output_path: 输出路径
|
||
title: 图表标题
|
||
dpi: 图像分辨率
|
||
"""
|
||
print(f"Generating heatmap: {output_path}")
|
||
|
||
n_standards = delta_e_image.shape[2]
|
||
|
||
# 计算子图布局
|
||
if n_standards <= 4:
|
||
nrows, ncols = 1, n_standards
|
||
figsize = (5 * ncols, 5)
|
||
elif n_standards <= 9:
|
||
nrows = 2
|
||
ncols = (n_standards + 1) // 2
|
||
figsize = (5 * ncols, 4 * nrows)
|
||
else:
|
||
nrows = int(np.ceil(np.sqrt(n_standards)))
|
||
ncols = int(np.ceil(n_standards / nrows))
|
||
figsize = (5 * ncols, 4 * nrows)
|
||
|
||
fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
|
||
|
||
# 如果只有一个子图,确保axes是数组
|
||
if n_standards == 1:
|
||
axes = np.array([axes])
|
||
|
||
# 扁平化axes数组以便迭代
|
||
if axes.ndim > 1:
|
||
axes_flat = axes.flatten()
|
||
else:
|
||
axes_flat = axes
|
||
|
||
# 计算全局的最小值和最大值,用于统一颜色条范围
|
||
global_min = np.nanmin(delta_e_image)
|
||
global_max = np.nanmax(delta_e_image)
|
||
|
||
# 为每个标准色创建热图
|
||
for i in range(len(axes_flat)):
|
||
ax = axes_flat[i]
|
||
|
||
if i < n_standards:
|
||
# 提取当前标准色的Delta E数据
|
||
de_data = delta_e_image[:, :, i]
|
||
|
||
# 计算统计数据
|
||
mean_de = np.nanmean(de_data)
|
||
std_de = np.nanstd(de_data)
|
||
|
||
# 显示热图
|
||
im = ax.imshow(de_data,
|
||
cmap=self.colormap,
|
||
vmin=global_min,
|
||
vmax=global_max,
|
||
aspect='auto')
|
||
|
||
# 设置标题
|
||
if i < len(standard_categories):
|
||
category = standard_categories[i]
|
||
else:
|
||
category = f"Standard_{i}"
|
||
|
||
ax.set_title(f"{category}\nΔE={mean_de:.2f}±{std_de:.2f}", fontsize=10)
|
||
ax.axis('off')
|
||
|
||
# 添加颜色条
|
||
plt.colorbar(im, ax=ax, shrink=0.8, pad=0.02)
|
||
else:
|
||
# 隐藏多余的子图
|
||
ax.axis('off')
|
||
|
||
# Add main title
|
||
plt.suptitle(title, fontsize=14, y=1.02)
|
||
|
||
# Adjust layout
|
||
plt.tight_layout()
|
||
|
||
# Save image
|
||
output_path = Path(output_path)
|
||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||
plt.savefig(output_path, dpi=dpi, bbox_inches='tight')
|
||
plt.close()
|
||
|
||
print(f"Heatmap saved: {output_path}")
|
||
|
||
def create_delta_e_histogram(self, delta_e_values: np.ndarray,
|
||
output_path: Union[str, Path],
|
||
title: str = "Delta E Distribution Histogram",
|
||
bins: int = 30) -> None:
|
||
"""
|
||
Create Delta E distribution histogram
|
||
|
||
Parameters:
|
||
delta_e_values: Delta E值数组
|
||
output_path: 输出路径
|
||
title: 图表标题
|
||
bins: 直方图箱数
|
||
"""
|
||
fig, ax = plt.subplots(figsize=(10, 6))
|
||
|
||
# 绘制直方图
|
||
n, bins, patches = ax.hist(delta_e_values.flatten(), bins=bins,
|
||
edgecolor='black', alpha=0.7)
|
||
|
||
# 添加统计信息
|
||
mean_val = np.mean(delta_e_values)
|
||
std_val = np.std(delta_e_values)
|
||
median_val = np.median(delta_e_values)
|
||
|
||
# Add vertical lines marking statistics
|
||
ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.2f}')
|
||
ax.axvline(median_val, color='green', linestyle='--', linewidth=2, label=f'Median: {median_val:.2f}')
|
||
|
||
# Set title and labels
|
||
ax.set_title(title, fontsize=14)
|
||
ax.set_xlabel('Delta E', fontsize=12)
|
||
ax.set_ylabel('Frequency', fontsize=12)
|
||
ax.legend()
|
||
ax.grid(True, alpha=0.3)
|
||
|
||
# Add statistical information text box
|
||
textstr = '\n'.join([
|
||
f'Data points: {delta_e_values.size:,}',
|
||
f'Minimum: {delta_e_values.min():.2f}',
|
||
f'Maximum: {delta_e_values.max():.2f}',
|
||
f'Mean: {mean_val:.2f}',
|
||
f'Standard deviation: {std_val:.2f}',
|
||
f'Median: {median_val:.2f}'
|
||
])
|
||
|
||
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
|
||
ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10,
|
||
verticalalignment='top', bbox=props)
|
||
|
||
# 保存图像
|
||
output_path = Path(output_path)
|
||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||
plt.savefig(output_path, dpi=150, bbox_inches='tight')
|
||
plt.close()
|
||
|
||
print(f"直方图已保存: {output_path}")
|
||
|
||
|
||
class DeltaEApp:
|
||
"""主应用程序类"""
|
||
|
||
def __init__(self):
|
||
self.calculator = None
|
||
self.standard_manager = ColorStandardManager()
|
||
self.envi_handler = None if not SPECTRAL_AVAILABLE else EnviFileHandler()
|
||
self.visualizer = None if not MATPLOTLIB_AVAILABLE else DeltaEVisualizer()
|
||
|
||
def run(self, args):
|
||
"""Run application"""
|
||
print("=" * 60)
|
||
print("Delta E* Color Difference Calculation Tool - Starting")
|
||
print("=" * 60)
|
||
|
||
# Initialize calculator
|
||
self.calculator = DeltaECalculator(method=args.method)
|
||
|
||
if args.mode == 'image':
|
||
self._run_image_mode(args)
|
||
elif args.mode == 'pairwise':
|
||
self._run_pairwise_mode(args)
|
||
else:
|
||
raise ValueError(f"Unknown mode: {args.mode}")
|
||
|
||
print("=" * 60)
|
||
print("Processing completed!")
|
||
print("=" * 60)
|
||
|
||
def _run_image_mode(self, args):
|
||
"""运行图像模式"""
|
||
# 根据输入文件类型加载数据
|
||
input_path = Path(args.input)
|
||
|
||
if input_path.suffix.lower() == '.csv':
|
||
# CSV 文件输入
|
||
print(f"从CSV文件加载LAB数据: {input_path}")
|
||
import pandas as pd
|
||
|
||
df = pd.read_csv(input_path)
|
||
print(f"CSV文件列名: {list(df.columns)}")
|
||
|
||
# 查找LAB列(支持多种命名方式)
|
||
lab_columns = {}
|
||
for col in df.columns:
|
||
col_lower = col.lower()
|
||
if col_lower in ['l', 'l*', 'l_star'] or col_lower.startswith('l '):
|
||
lab_columns['L'] = col
|
||
elif col_lower in ['a', 'a*', 'a_star'] or col_lower.startswith('a '):
|
||
lab_columns['a'] = col
|
||
elif col_lower in ['b', 'b*', 'b_star'] or col_lower.startswith('b '):
|
||
lab_columns['b'] = col
|
||
|
||
if len(lab_columns) < 3:
|
||
raise ValueError(f"CSV文件必须包含L、a、b列。找到的列: {list(lab_columns.keys())}")
|
||
|
||
# 提取LAB数据
|
||
L = df[lab_columns['L']].values
|
||
a = df[lab_columns['a']].values
|
||
b = df[lab_columns['b']].values
|
||
|
||
# 创建ColorData对象
|
||
image_data = ColorData(L=L, a=a, b=b)
|
||
metadata = {}
|
||
|
||
print(f"从CSV加载了 {len(L)} 个LAB样本")
|
||
|
||
else:
|
||
# ENVI 文件输入
|
||
if not SPECTRAL_AVAILABLE:
|
||
raise ImportError("ENVI文件输入需要spectral库,但未安装")
|
||
|
||
image_data, metadata = self.envi_handler.load_image(str(input_path))
|
||
|
||
# 加载标准色
|
||
standards_df = self.standard_manager.load_from_csv(args.standards)
|
||
|
||
# 验证数据
|
||
LabDataValidator.validate_lab_range(image_data, "输入图像", warn_only=True)
|
||
|
||
# 解析标准色范围
|
||
standard_indices = self._parse_range_param(args.standards_range, len(standards_df))
|
||
standard_colors = self.standard_manager.get_standard_colors(standard_indices)
|
||
|
||
# 计算Delta E
|
||
delta_e_image = self.calculator.calculate_image_vs_standards(
|
||
image_data, standard_colors, use_progress_bar=not args.no_progress
|
||
)
|
||
|
||
# 准备标准色类别列表
|
||
if standard_colors.categories:
|
||
standard_categories = standard_colors.categories
|
||
else:
|
||
standard_categories = [f"Standard_{i}" for i in range(len(standard_indices))]
|
||
|
||
# 保存结果
|
||
output_dir = Path(args.output_dir)
|
||
output_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
# 生成输出文件名
|
||
if hasattr(args, 'output_file') and args.output_file:
|
||
# 如果指定了输出文件名
|
||
if Path(args.output_file).is_absolute():
|
||
output_path = Path(args.output_file)
|
||
else:
|
||
output_path = output_dir / args.output_file
|
||
else:
|
||
# 使用默认文件名
|
||
input_name = Path(args.input).stem
|
||
output_filename = f"{input_name}_delta_e_{args.method.lower()}.dat"
|
||
output_path = output_dir / output_filename
|
||
|
||
# 保存ENVI图像
|
||
self.envi_handler.save_delta_e_image(
|
||
delta_e_image, output_path,
|
||
standard_categories, args.method,
|
||
metadata
|
||
)
|
||
|
||
# 保存CSV摘要
|
||
self._save_image_summary_csv(delta_e_image, standard_categories, output_path)
|
||
|
||
# 生成可视化
|
||
if args.create_heatmap and self.visualizer:
|
||
heatmap_path = output_path.with_suffix('.png')
|
||
self.visualizer.create_heatmap(
|
||
delta_e_image, standard_categories, heatmap_path,
|
||
title=f"Delta E ({args.method}) 热图"
|
||
)
|
||
|
||
if args.create_histogram and self.visualizer:
|
||
histogram_path = output_path.with_name(f"{output_path.stem}_histogram.png")
|
||
self.visualizer.create_delta_e_histogram(
|
||
delta_e_image, histogram_path,
|
||
title=f"Delta E ({args.method}) 分布"
|
||
)
|
||
|
||
def _run_pairwise_mode(self, args):
|
||
"""运行两两比较模式"""
|
||
# 加载标准色
|
||
standards_df = self.standard_manager.load_from_csv(args.standards)
|
||
|
||
# 解析范围参数
|
||
all_indices = list(range(len(standards_df)))
|
||
|
||
if args.reference_range and args.target_range:
|
||
ref_indices = self._parse_range_param(args.reference_range, len(standards_df))
|
||
tgt_indices = self._parse_range_param(args.target_range, len(standards_df))
|
||
|
||
ref_colors = self.standard_manager.get_standard_colors(ref_indices)
|
||
tgt_colors = self.standard_manager.get_standard_colors(tgt_indices)
|
||
|
||
# 计算两组之间的Delta E
|
||
result_df = self.calculator.calculate_pairwise(ref_colors, tgt_colors)
|
||
else:
|
||
# 计算所有标准色内部的两两Delta E
|
||
all_colors = self.standard_manager.get_standard_colors(all_indices)
|
||
result_df = self.calculator.calculate_pairwise(all_colors)
|
||
|
||
# 保存结果
|
||
output_dir = Path(args.output_dir)
|
||
output_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
# 生成输出文件名
|
||
if hasattr(args, 'output_file') and args.output_file:
|
||
# 如果指定了输出文件名
|
||
if Path(args.output_file).is_absolute():
|
||
output_path = Path(args.output_file)
|
||
else:
|
||
output_path = output_dir / args.output_file
|
||
else:
|
||
# 使用默认文件名
|
||
output_filename = f"delta_e_pairwise_{args.method.lower()}.csv"
|
||
output_path = output_dir / output_filename
|
||
|
||
result_df.to_csv(output_path, index=False)
|
||
print(f"两两比较结果已保存到: {output_path}")
|
||
|
||
# 生成矩阵热图(如果颜色数量合理)
|
||
if self.visualizer and len(standards_df) <= 20 and args.create_heatmap:
|
||
self._create_pairwise_heatmap(result_df, output_path)
|
||
|
||
def _parse_range_param(self, range_str: Optional[str], max_index: int) -> List[int]:
|
||
"""解析范围参数"""
|
||
if not range_str or range_str.lower() == 'all':
|
||
return list(range(max_index))
|
||
|
||
indices = []
|
||
parts = range_str.split(',')
|
||
|
||
for part in parts:
|
||
part = part.strip()
|
||
if '-' in part:
|
||
# 处理范围
|
||
try:
|
||
start, end = map(int, part.split('-'))
|
||
if start < 0 or end >= max_index or start > end:
|
||
raise ValueError(f"范围 {part} 超出有效范围 [0, {max_index-1}]")
|
||
indices.extend(range(start, end + 1))
|
||
except ValueError as e:
|
||
raise ValueError(f"无效的范围格式: {part}。错误: {str(e)}")
|
||
else:
|
||
# 处理单个索引
|
||
try:
|
||
idx = int(part)
|
||
if idx < 0 or idx >= max_index:
|
||
raise ValueError(f"索引 {idx} 超出有效范围 [0, {max_index-1}]")
|
||
indices.append(idx)
|
||
except ValueError:
|
||
raise ValueError(f"无效的索引: {part}")
|
||
|
||
# 去重并排序
|
||
indices = sorted(set(indices))
|
||
|
||
if not indices:
|
||
raise ValueError("未指定有效的索引")
|
||
|
||
return indices
|
||
|
||
def _save_image_summary_csv(self, delta_e_image: np.ndarray,
|
||
standard_categories: List[str],
|
||
output_path: Path) -> None:
|
||
"""保存图像结果的统计摘要"""
|
||
summary_data = []
|
||
|
||
for i, category in enumerate(standard_categories):
|
||
de_data = delta_e_image[:, :, i]
|
||
de_flat = de_data.flatten()
|
||
|
||
summary_data.append({
|
||
'category': category,
|
||
'min': np.nanmin(de_flat),
|
||
'max': np.nanmax(de_flat),
|
||
'mean': np.nanmean(de_flat),
|
||
'median': np.nanmedian(de_flat),
|
||
'std': np.nanstd(de_flat),
|
||
'q25': np.nanpercentile(de_flat, 25),
|
||
'q75': np.nanpercentile(de_flat, 75),
|
||
'pixels': np.sum(~np.isnan(de_flat))
|
||
})
|
||
|
||
summary_df = pd.DataFrame(summary_data)
|
||
summary_path = output_path.with_name(f"{output_path.stem}_summary.csv")
|
||
summary_df.to_csv(summary_path, index=False, float_format='%.4f')
|
||
print(f"Statistical summary saved to: {summary_path}")
|
||
|
||
def _create_pairwise_heatmap(self, result_df: pd.DataFrame, output_path: Path) -> None:
|
||
"""Create heatmap for pairwise comparison results"""
|
||
try:
|
||
# 提取唯一的类别
|
||
ref_categories = result_df['reference_category'].unique()
|
||
tgt_categories = result_df['target_category'].unique()
|
||
|
||
# 创建矩阵
|
||
matrix = np.zeros((len(ref_categories), len(tgt_categories)))
|
||
|
||
# 填充矩阵
|
||
cat_to_idx = {cat: i for i, cat in enumerate(ref_categories)}
|
||
for _, row in result_df.iterrows():
|
||
i = cat_to_idx[row['reference_category']]
|
||
j = np.where(tgt_categories == row['target_category'])[0][0]
|
||
matrix[i, j] = row['delta_E']
|
||
|
||
# 创建热图
|
||
fig, ax = plt.subplots(figsize=(10, 8))
|
||
im = ax.imshow(matrix, cmap='RdYlBu_r')
|
||
|
||
# 设置刻度
|
||
ax.set_xticks(np.arange(len(tgt_categories)))
|
||
ax.set_yticks(np.arange(len(ref_categories)))
|
||
ax.set_xticklabels(tgt_categories, rotation=45, ha='right')
|
||
ax.set_yticklabels(ref_categories)
|
||
|
||
# 添加颜色条
|
||
plt.colorbar(im, ax=ax, label='Delta E')
|
||
|
||
# Add numerical labels
|
||
for i in range(len(ref_categories)):
|
||
for j in range(len(tgt_categories)):
|
||
text = ax.text(j, i, f'{matrix[i, j]:.2f}',
|
||
ha="center", va="center", color="black", fontsize=8)
|
||
|
||
ax.set_title("Pairwise Delta E Comparison Matrix")
|
||
plt.tight_layout()
|
||
|
||
# Save heatmap
|
||
heatmap_path = output_path.with_suffix('.png')
|
||
plt.savefig(heatmap_path, dpi=150, bbox_inches='tight')
|
||
plt.close()
|
||
|
||
print(f"Pairwise comparison heatmap saved: {heatmap_path}")
|
||
except Exception as e:
|
||
print(f"Error generating pairwise comparison heatmap: {str(e)}")
|
||
|
||
|
||
def parse_arguments():
|
||
"""Parse command line arguments"""
|
||
parser = argparse.ArgumentParser(
|
||
description='Delta E* 色差计算工具 - 完整版',
|
||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||
epilog="""
|
||
使用示例:
|
||
1. Calculate color differences between image and standard colors:
|
||
python delta_e_tool.py --mode image --image input.lab --standards colors.csv --output result.dat
|
||
|
||
2. Calculate color differences for specified standard color range:
|
||
python delta_e_tool.py --mode image --image input.lab --standards colors.csv
|
||
--standards-range "0,1,2,5-10" --output result.dat --method CIE94
|
||
|
||
3. Generate heatmap and histogram:
|
||
python delta_e_tool.py --mode image --image input.lab --standards colors.csv
|
||
--output result.dat --create-heatmap --create-histogram
|
||
|
||
4. Calculate pairwise color differences between standard colors:
|
||
python delta_e_tool.py --mode pairwise --standards colors.csv --output pairwise.csv
|
||
|
||
5. Calculate color differences between two groups of standard colors:
|
||
python delta_e_tool.py --mode pairwise --standards colors.csv
|
||
--reference-range "0-5" --target-range "6-10" --output pairwise.csv
|
||
"""
|
||
)
|
||
|
||
# 基本参数
|
||
parser.add_argument('--mode', required=True, choices=['image', 'pairwise'],
|
||
help='Calculation mode: image (image vs standard colors) or pairwise (pairwise comparison)')
|
||
|
||
# Input parameters
|
||
parser.add_argument('--input', required=True, help='Input file path (ENVI format LAB image or CSV file)')
|
||
parser.add_argument('--standards', required=True, help='Standard colors CSV file path')
|
||
|
||
# Output parameters
|
||
parser.add_argument('--output-dir', default='./results', help='Output directory path')
|
||
parser.add_argument('--create-heatmap', action='store_true', help='Generate heatmap')
|
||
parser.add_argument('--create-histogram', action='store_true', help='Generate histogram')
|
||
|
||
# Calculation parameters
|
||
parser.add_argument('--method', default='CIEDE2000',
|
||
choices=['CIE76', 'CIE94', 'CIEDE2000'],
|
||
help='Color difference calculation method (default: CIEDE2000)')
|
||
|
||
# Range selection parameters
|
||
parser.add_argument('--standards-range', help='Standard color range (e.g.: "0,1,2" or "0-5" or "all")')
|
||
parser.add_argument('--reference-range', help='Reference color range (pairwise mode)')
|
||
parser.add_argument('--target-range', help='Target color range (pairwise mode)')
|
||
|
||
# Other parameters
|
||
parser.add_argument('--no-progress', action='store_true',
|
||
help='Disable progress bar display')
|
||
|
||
return parser.parse_args()
|
||
|
||
|
||
# =============================================================================
|
||
# 直接函数调用接口
|
||
# =============================================================================
|
||
|
||
def calculate_delta_e_image(image_path: Union[str, Path],
|
||
standards_path: Union[str, Path],
|
||
output_path: Union[str, Path],
|
||
method: str = 'CIEDE2000',
|
||
standards_range: Optional[str] = None,
|
||
create_heatmap: bool = False,
|
||
create_histogram: bool = False,
|
||
no_progress: bool = False) -> np.ndarray:
|
||
"""
|
||
Directly calculate Delta E color differences between image and standard colors
|
||
|
||
Parameters:
|
||
image_path: LAB图像文件路径 (ENVI格式)
|
||
standards_path: 标准色CSV文件路径
|
||
output_path: 输出文件路径
|
||
method: 色差计算方法 ('CIE76', 'CIE94', 'CIEDE2000')
|
||
standards_range: 标准色范围 ("0,1,2" 或 "0-5" 或 "all")
|
||
create_heatmap: 是否生成热图
|
||
create_histogram: 是否生成直方图
|
||
no_progress: 是否禁用进度条
|
||
|
||
返回:
|
||
delta_e_image: Delta E图像数组 (height, width, n_standards)
|
||
|
||
示例:
|
||
>>> delta_e = calculate_delta_e_image(
|
||
... image_path="input.lab",
|
||
... standards_path="standards.csv",
|
||
... output_path="result.dat",
|
||
... method="CIEDE2000",
|
||
... create_heatmap=True
|
||
... )
|
||
"""
|
||
# 验证依赖
|
||
if not SPECTRAL_AVAILABLE:
|
||
raise ImportError("需要安装spectral库: pip install spectral")
|
||
|
||
if not COLOUR_AVAILABLE:
|
||
raise ImportError("需要安装colour-science库: pip install colour-science")
|
||
|
||
# 创建应用实例
|
||
app = DeltaEApp()
|
||
|
||
try:
|
||
# 初始化计算器
|
||
app.calculator = DeltaECalculator(method=method)
|
||
|
||
# 加载图像
|
||
image_data, metadata = app.envi_handler.load_image(image_path)
|
||
|
||
# 加载标准色
|
||
standards_df = app.standard_manager.load_from_csv(standards_path)
|
||
|
||
# 验证数据
|
||
LabDataValidator.validate_lab_range(image_data, "输入图像", warn_only=True)
|
||
|
||
# 解析标准色范围
|
||
standard_indices = app._parse_range_param(standards_range, len(standards_df))
|
||
standard_colors = app.standard_manager.get_standard_colors(standard_indices)
|
||
|
||
# 计算Delta E
|
||
delta_e_image = app.calculator.calculate_image_vs_standards(
|
||
image_data, standard_colors, use_progress_bar=not no_progress
|
||
)
|
||
|
||
# 保存结果
|
||
app.envi_handler.save_delta_e_image(
|
||
delta_e_image, output_path,
|
||
standard_colors.categories or [f"Standard_{i}" for i in range(len(standard_indices))],
|
||
method, metadata
|
||
)
|
||
|
||
# 生成可视化
|
||
if create_heatmap and app.visualizer:
|
||
heatmap_path = Path(output_path).with_suffix('.png')
|
||
app.visualizer.create_heatmap(
|
||
delta_e_image,
|
||
standard_colors.categories or [f"Standard_{i}" for i in range(len(standard_indices))],
|
||
heatmap_path,
|
||
title=f"Delta E ({method}) 热图"
|
||
)
|
||
|
||
if create_histogram and app.visualizer:
|
||
histogram_path = Path(output_path).with_name(f"{Path(output_path).stem}_histogram.png")
|
||
app.visualizer.create_delta_e_histogram(
|
||
delta_e_image, histogram_path,
|
||
title=f"Delta E ({method}) 分布"
|
||
)
|
||
|
||
print(f"色差计算完成! 结果已保存到: {output_path}")
|
||
return delta_e_image
|
||
|
||
except Exception as e:
|
||
print(f"计算失败: {str(e)}")
|
||
raise
|
||
|
||
|
||
def calculate_delta_e_pairwise(standards_path: Union[str, Path],
|
||
output_path: Union[str, Path],
|
||
method: str = 'CIEDE2000',
|
||
reference_range: Optional[str] = None,
|
||
target_range: Optional[str] = None,
|
||
create_heatmap: bool = False) -> pd.DataFrame:
|
||
"""
|
||
直接计算标准色之间的两两Delta E
|
||
|
||
参数:
|
||
standards_path: 标准色CSV文件路径
|
||
output_path: 输出CSV文件路径
|
||
method: 色差计算方法 ('CIE76', 'CIE94', 'CIEDE2000')
|
||
reference_range: 参考色范围 ("0,1,2" 或 "0-5" 或 "all")
|
||
target_range: 目标色范围 ("0,1,2" 或 "0-5" 或 "all")
|
||
create_heatmap: 是否生成热图
|
||
|
||
返回:
|
||
result_df: 包含Delta E结果的DataFrame
|
||
|
||
示例:
|
||
>>> df = calculate_delta_e_pairwise(
|
||
... standards_path="standards.csv",
|
||
... output_path="pairwise.csv",
|
||
... method="CIEDE2000"
|
||
... )
|
||
"""
|
||
if not COLOUR_AVAILABLE:
|
||
raise ImportError("需要安装colour-science库: pip install colour-science")
|
||
|
||
# 创建应用实例
|
||
app = DeltaEApp()
|
||
|
||
try:
|
||
# 初始化计算器
|
||
app.calculator = DeltaECalculator(method=method)
|
||
|
||
# 加载标准色
|
||
standards_df = app.standard_manager.load_from_csv(standards_path)
|
||
|
||
# 解析范围参数
|
||
all_indices = list(range(len(standards_df)))
|
||
|
||
if reference_range and target_range:
|
||
ref_indices = app._parse_range_param(reference_range, len(standards_df))
|
||
tgt_indices = app._parse_range_param(target_range, len(standards_df))
|
||
|
||
ref_colors = app.standard_manager.get_standard_colors(ref_indices)
|
||
tgt_colors = app.standard_manager.get_standard_colors(tgt_indices)
|
||
|
||
# 计算两组之间的Delta E
|
||
result_df = app.calculator.calculate_pairwise(ref_colors, tgt_colors)
|
||
else:
|
||
# 计算所有标准色内部的两两Delta E
|
||
all_colors = app.standard_manager.get_standard_colors(all_indices)
|
||
result_df = app.calculator.calculate_pairwise(all_colors)
|
||
|
||
# 保存结果
|
||
result_df.to_csv(output_path, index=False)
|
||
print(f"两两比较结果已保存到: {output_path}")
|
||
|
||
# 生成热图(如果颜色数量合理且启用)
|
||
if create_heatmap and app.visualizer and len(standards_df) <= 20:
|
||
app._create_pairwise_heatmap(result_df, Path(output_path))
|
||
|
||
return result_df
|
||
|
||
except Exception as e:
|
||
print(f"计算失败: {str(e)}")
|
||
raise
|
||
|
||
|
||
def calculate_delta_e_from_arrays(lab_array: np.ndarray,
|
||
standard_lab: np.ndarray,
|
||
method: str = 'CIEDE2000',
|
||
standard_categories: Optional[List[str]] = None) -> Union[np.ndarray, pd.DataFrame]:
|
||
"""
|
||
直接从LAB数组计算Delta E
|
||
|
||
参数:
|
||
lab_array: LAB数组 (..., 3) 或 ColorData对象
|
||
standard_lab: 标准色LAB数组 (n_standards, 3)
|
||
method: 色差计算方法 ('CIE76', 'CIE94', 'CIEDE2000')
|
||
standard_categories: 标准色类别列表
|
||
|
||
返回:
|
||
如果lab_array是图像数据: Delta E数组 (height, width, n_standards)
|
||
如果lab_array是一维数据: 包含Delta E结果的DataFrame
|
||
|
||
示例:
|
||
>>> # 计算图像与标准色的色差
|
||
>>> image_lab = np.random.rand(100, 100, 3) * 100 # 模拟图像数据
|
||
>>> standards = np.array([[50, 0, 0], [70, 20, 10]]) # 两个标准色
|
||
>>> delta_e = calculate_delta_e_from_arrays(image_lab, standards)
|
||
|
||
>>> # 计算样本与标准色的色差
|
||
>>> samples = np.array([[60, 5, 5], [80, -10, 15]])
|
||
>>> df = calculate_delta_e_from_arrays(samples, standards)
|
||
"""
|
||
if not COLOUR_AVAILABLE:
|
||
raise ImportError("需要安装colour-science库: pip install colour-science")
|
||
|
||
# 创建计算器
|
||
calculator = DeltaECalculator(method=method)
|
||
|
||
# 处理输入数据
|
||
if isinstance(lab_array, ColorData):
|
||
lab_data = lab_array
|
||
else:
|
||
lab_data = ColorData.from_array(lab_array)
|
||
|
||
if isinstance(standard_lab, ColorData):
|
||
std_data = standard_lab
|
||
else:
|
||
std_data = ColorData.from_array(standard_lab, standard_categories)
|
||
|
||
# 验证数据
|
||
LabDataValidator.validate_lab_range(lab_data, "输入数据", warn_only=True)
|
||
LabDataValidator.validate_lab_range(std_data, "标准色数据", warn_only=True)
|
||
|
||
# 根据数据维度选择计算方式
|
||
if lab_data.shape[0] > 1 and len(lab_data.shape) == 1:
|
||
# 一维数据(多个样本),返回DataFrame
|
||
return calculator.calculate_pairwise(lab_data, std_data)
|
||
else:
|
||
# 二维或三维数据(图像),返回数组
|
||
return calculator.calculate_image_vs_standards(lab_data, std_data)
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
try:
|
||
# 解析命令行参数
|
||
args = parse_arguments()
|
||
|
||
# 验证必要参数
|
||
if args.mode == 'image' and not args.input:
|
||
print("错误: 图像模式需要指定 --image 参数")
|
||
sys.exit(1)
|
||
|
||
# 创建并运行应用程序
|
||
app = DeltaEApp()
|
||
app.run(args)
|
||
|
||
except ImportError as e:
|
||
print(f"导入错误: {str(e)}")
|
||
print("请确保已安装所有必需的库:")
|
||
print(" pip install numpy pandas colour-science matplotlib spectral")
|
||
if TQDM_AVAILABLE:
|
||
print(" pip install tqdm # 可选,用于进度条")
|
||
sys.exit(1)
|
||
|
||
except FileNotFoundError as e:
|
||
print(f"文件错误: 找不到文件 - {str(e)}")
|
||
sys.exit(1)
|
||
|
||
except ValueError as e:
|
||
print(f"数据错误: {str(e)}")
|
||
sys.exit(1)
|
||
|
||
except Exception as e:
|
||
print(f"未预期的错误: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
sys.exit(1)
|
||
|
||
|
||
# =============================================================================
|
||
# 直接运行示例(可修改参数后直接运行)
|
||
# =============================================================================
|
||
|
||
if __name__ == '__main__':
|
||
exit(main())
|