1096 lines
42 KiB
Python
1096 lines
42 KiB
Python
"""
|
||
回归预测工具包
|
||
支持使用训练好的回归模型对高光谱图像进行预测
|
||
包含图像读取、遮罩处理、批量预测、结果可视化等功能
|
||
"""
|
||
|
||
import os
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from pathlib import Path
|
||
import glob
|
||
import joblib
|
||
import json
|
||
import warnings
|
||
import argparse
|
||
from typing import Optional, List, Dict, Any, Union, Tuple
|
||
from dataclasses import dataclass, field
|
||
import time
|
||
|
||
# 导入回归分析器和所有必要的类(为了pickle兼容性)
|
||
try:
|
||
# 当作为模块导入时使用相对导入
|
||
from .regression import (
|
||
RegressionAnalyzer,
|
||
ExtremeLearningMachine,
|
||
GeneralizedAdditiveModel,
|
||
LSTMRegressor,
|
||
GRURegressor
|
||
)
|
||
except ImportError:
|
||
# 当直接运行脚本时使用绝对导入
|
||
import sys
|
||
import os
|
||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||
if current_dir not in sys.path:
|
||
sys.path.insert(0, current_dir)
|
||
|
||
from regression import (
|
||
RegressionAnalyzer,
|
||
ExtremeLearningMachine,
|
||
GeneralizedAdditiveModel,
|
||
LSTMRegressor,
|
||
GRURegressor
|
||
)
|
||
|
||
warnings.filterwarnings('ignore')
|
||
|
||
|
||
@dataclass
class PredictionConfig:
    """Configuration for a prediction run.

    Attributes:
        image_path: path to the hyperspectral image (file, or directory in batch mode).
        mask_path: optional mask file path.
        model_path: model file path, model directory, or list of file paths.
        output_dir: directory where result files are written.
        use_mask: whether the mask is applied during prediction.
        batch_mode: process a directory of images instead of a single file.
        colormap: matplotlib colormap name used for visualisation.
        dpi: resolution of the saved PNG figures.
        save_individual: write one result file per model (name includes the model).
    """
    image_path: str = ""
    mask_path: Optional[str] = None
    model_path: Union[str, List[str]] = ""
    output_dir: str = "prediction_results"
    use_mask: bool = True
    batch_mode: bool = False
    colormap: str = 'viridis'
    dpi: int = 300
    save_individual: bool = True
|
||
|
||
|
||
class ImageReader:
    """Reader for hyperspectral imagery (ENVI raw formats) and mask files.

    NOTE(review): the original decorated this class with ``@dataclass``; the
    class has no fields and defines its own ``__init__``, so the decorator
    was a no-op and has been removed.
    """

    def __init__(self):
        pass

    def read_envi_header(self, header_path: Path) -> Tuple[Optional[int], Optional[int], Optional[int], Optional[int], str, int]:
        """Parse an ENVI header file for image dimensions and storage format.

        Parameters:
            header_path: path to the ``.hdr`` file.

        Returns:
            ``(rows, cols, bands, data_type, interleave, byte_order)``; the
            first four are ``None`` when the header does not define them.
            ``interleave`` defaults to ``'bsq'`` and ``byte_order`` to 0.
        """
        rows, cols, bands = None, None, None
        data_type = None
        interleave = 'bsq'  # ENVI default interleave
        byte_order = 0  # 0 = little-endian (ENVI default)

        try:
            with open(header_path, 'r', encoding='utf-8') as f:
                content = f.read()

            # ENVI headers are simple "key = value" lines.
            lines = content.split('\n')
            for line in lines:
                line = line.strip()
                if '=' in line:
                    key, value = line.split('=', 1)
                    key = key.strip()
                    value = value.strip()

                    if key.lower() == 'samples':
                        cols = int(value)
                    elif key.lower() == 'lines':
                        rows = int(value)
                    elif key.lower() == 'bands':
                        bands = int(value)
                    elif key.lower() == 'data type':
                        data_type = int(value)
                    elif key.lower() == 'interleave':
                        interleave = value.lower()  # bsq, bil or bip
                    elif key.lower() == 'byte order':
                        byte_order = int(value)

            if rows and cols and bands:
                print(f"从头文件读取: {rows}x{cols}x{bands}, interleave={interleave}, data_type={data_type}")

        except Exception as e:
            print(f"读取头文件时出错: {e}")

        return rows, cols, bands, data_type, interleave, byte_order

    def get_numpy_dtype(self, envi_data_type: int) -> np.dtype:
        """Map an ENVI numeric data-type code to a numpy scalar type.

        Unknown codes fall back to ``np.float32``.
        """
        dtype_map = {
            1: np.uint8,     # 8-bit byte
            2: np.int16,     # 16-bit signed integer
            3: np.int32,     # 32-bit signed long integer
            4: np.float32,   # 32-bit floating point
            5: np.float64,   # 64-bit double precision floating point
            12: np.uint16,   # 16-bit unsigned integer
            13: np.uint32,   # 32-bit unsigned long integer
            14: np.int64,    # 64-bit signed long integer
            15: np.uint64,   # 64-bit unsigned long integer
        }

        return dtype_map.get(envi_data_type, np.float32)

    def read_envi_file(self, data_path: Path, rows: Optional[int] = None, cols: Optional[int] = None,
                       bands: Optional[int] = None) -> Optional[np.ndarray]:
        """Read an ENVI file (BSQ, BIL or BIP interleave).

        Tries the ``spectral`` library first; falls back to a manual binary
        reader driven by the accompanying ``.hdr`` file, or by the explicit
        rows/cols/bands arguments when no header can be found.

        Parameters:
            data_path: data file (or ``.hdr``) path.
            rows: image height, used only when no header is available.
            cols: image width, used only when no header is available.
            bands: band count, used only when no header is available.

        Returns:
            array of shape ``(rows, cols, bands)``, or ``None`` on failure.
        """
        # Preferred path: spectral handles header/data pairing itself.
        try:
            import spectral
            print("使用spectral库读取ENVI文件")
            img = spectral.open_image(str(data_path))
            data = img.load()
            print(f"spectral库成功读取,数据形状: {data.shape}, 数据类型: {data.dtype}")
            return data
        except ImportError:
            print("spectral库不可用,使用自定义方法读取")
        except Exception as e:
            print(f"spectral库读取失败: {e},使用自定义方法")

        # Fallback: locate the header ourselves.
        if data_path.suffix.lower() == '.hdr':
            header_path = data_path
        else:
            # Input is the data file; try "<name>.hdr" first.
            header_path = data_path.with_suffix('.hdr')

            if not header_path.exists():
                # e.g. "image.bil.hdr" (extension kept)
                alt_header_path = data_path.parent / f"{data_path.name}.hdr"
                if alt_header_path.exists():
                    header_path = alt_header_path
                else:
                    # or "<stem>.hdr" (extension dropped)
                    base_header_path = data_path.parent / f"{data_path.stem}.hdr"
                    if base_header_path.exists():
                        header_path = base_header_path

        if header_path.exists():
            print(f"正在读取头文件: {header_path}")
            rows, cols, bands, data_type, interleave, byte_order = self.read_envi_header(header_path)

            if data_type is None:
                data_type = 4  # default to float32

            dtype = self.get_numpy_dtype(data_type)
            # BUGFIX: honour the header's byte order. ENVI "byte order = 1"
            # means big-endian; the original reader always assumed native
            # byte order, silently scrambling big-endian files.
            if byte_order == 1:
                dtype = np.dtype(dtype).newbyteorder('>')

            # Resolve the actual data file path.
            if data_path.suffix.lower() == '.hdr':
                # Header given: data file is the same name minus ".hdr".
                actual_data_path = data_path.with_suffix('')
                print(f"输入是头文件,计算出的数据文件路径: {actual_data_path}")
                if not actual_data_path.exists():
                    bil_path = actual_data_path.with_suffix('.bil')
                    if bil_path.exists():
                        actual_data_path = bil_path
                        print(f"找到BIL数据文件: {actual_data_path}")
                    else:
                        print(f"未找到数据文件,尝试的文件: {actual_data_path}, {bil_path}")
            else:
                actual_data_path = data_path

        else:
            # No header: fall back to caller-supplied dimensions.
            if rows is None or cols is None or bands is None:
                print(f"找不到头文件 {header_path},无法读取图像")
                return None
            print(f"未找到头文件,使用手动指定的尺寸: {rows}x{cols}x{bands}")
            dtype = np.float32
            interleave = 'bsq'  # assume band-sequential
            actual_data_path = data_path

        try:
            print(f"正在读取数据文件: {actual_data_path}, 尺寸: {rows}x{cols}x{bands}, 格式: {interleave}, 数据类型: {dtype}")

            # Sanity-check the file size before parsing.
            file_size = actual_data_path.stat().st_size
            expected_bytes = rows * cols * bands * np.dtype(dtype).itemsize
            print(f"文件大小: {file_size} 字节, 期望: {expected_bytes} 字节")

            if file_size != expected_bytes:
                print(f"文件大小不匹配!文件: {file_size} 字节, 计算: {expected_bytes} 字节")
                # Header may be wrong; read anyway and warn.
                print("尝试继续读取,但结果可能不正确...")

            with open(actual_data_path, 'rb') as f:
                raw_data = np.frombuffer(f.read(), dtype=dtype)

            # Verify the element count matches the declared dimensions.
            expected_size = rows * cols * bands
            if len(raw_data) != expected_size:
                print(f"数据大小不匹配。期望: {expected_size}, 实际: {len(raw_data)}")
                if len(raw_data) < expected_size:
                    print("数据不完整,无法继续处理")
                    return None
                else:
                    print("数据超出期望大小,将截取所需部分")
                    raw_data = raw_data[:expected_size]

            # Reshape according to the interleave, then normalise every
            # layout to (rows, cols, bands).
            if interleave == 'bsq':
                # Band Sequential: band 1 in full, then band 2, ...
                data = raw_data.reshape(bands, rows, cols)
                data = np.transpose(data, (1, 2, 0))

            elif interleave == 'bil':
                # Band Interleaved by Line: all bands of row 1, then row 2, ...
                data = raw_data.reshape(rows, bands, cols)
                data = np.transpose(data, (0, 2, 1))

            elif interleave == 'bip':
                # Band Interleaved by Pixel: already (rows, cols, bands).
                data = raw_data.reshape(rows, cols, bands)
            else:
                print(f"未知的交织格式: {interleave},尝试BSQ格式")
                data = raw_data.reshape(bands, rows, cols)
                data = np.transpose(data, (1, 2, 0))

            print(f"成功读取,数据形状: {data.shape}, 数据类型: {data.dtype}")
            return data

        except Exception as e:
            print(f"读取数据文件时出错: {e}")
            return None

    def read_mask_file(self, mask_path: Path, reference_image_path: Optional[Path] = None) -> Optional[np.ndarray]:
        """Read a mask file; the format is chosen by the file extension.

        Supports ENVI rasters (.bil/.bsq/.bip), TIFF and common image
        formats via PIL, Shapefiles (rasterised) and .dat files.

        Parameters:
            mask_path: mask file path.
            reference_image_path: image whose geo-referencing is used when
                rasterising a Shapefile mask.

        Returns:
            binary uint8 mask (non-zero source values become 1), or ``None``
            on failure.
        """
        try:
            if mask_path.suffix.lower() in ['.bil', '.bsq', '.bip']:
                # ENVI raster mask.
                mask_data = self.read_envi_file(mask_path)
                if mask_data is not None:
                    # Drop the singleton band axis, then threshold to binary.
                    mask_data = mask_data.squeeze()
                    mask_data = (mask_data > 0).astype(np.uint8)
            elif mask_path.suffix.lower() in ['.tif', '.tiff']:
                # TIFF mask (requires Pillow).
                try:
                    from PIL import Image
                    mask_img = Image.open(mask_path)
                    mask_data = np.array(mask_img)
                    mask_data = (mask_data > 0).astype(np.uint8)
                except ImportError:
                    print("需要安装PIL(Pillow)来读取TIFF文件")
                    return None
            elif mask_path.suffix.lower() in ['.png', '.jpg', '.jpeg']:
                # Plain image mask.
                try:
                    from PIL import Image
                    mask_img = Image.open(mask_path)
                    mask_data = np.array(mask_img)
                    # Collapse colour images to grey before thresholding.
                    if len(mask_data.shape) == 3:
                        mask_data = np.mean(mask_data, axis=2)
                    mask_data = (mask_data > 0).astype(np.uint8)
                except ImportError:
                    print("需要安装PIL(Pillow)来读取图像文件")
                    return None
            elif mask_path.suffix.lower() == '.shp':
                # Vector mask: rasterise against the reference image grid.
                mask_data = self._rasterize_shapefile(mask_path, reference_image_path)
            elif mask_path.suffix.lower() == '.dat':
                mask_data = self._read_dat_file(mask_path)
            else:
                print(f"不支持的遮罩文件格式: {mask_path.suffix}")
                return None

            if mask_data is not None:
                print(f"成功读取遮罩文件: {mask_path.name}, 形状: {mask_data.shape}")
            return mask_data

        except Exception as e:
            print(f"读取遮罩文件时出错: {e}")
            return None

    def _rasterize_shapefile(self, shp_path: Path, reference_image_path: Optional[Path] = None) -> Optional[np.ndarray]:
        """Rasterise a Shapefile into a binary mask.

        Uses the reference image's grid and geo-transform when available;
        otherwise burns the geometries onto a 1000x1000 grid covering the
        shapefile's own bounds. Requires geopandas + rasterio.
        """
        try:
            import geopandas as gpd
            import rasterio
            from rasterio import features
            from rasterio.transform import from_bounds

            gdf = gpd.read_file(shp_path)

            if gdf.empty:
                print("Shapefile为空")
                return None

            # Geo-referencing: resolved below either from the reference
            # image or from the shapefile bounds.
            height = width = transform = crs = None

            if reference_image_path and reference_image_path.exists():
                # ENVI quirk: rasterio wants the data file, not the header.
                rasterio_path = str(reference_image_path)
                if reference_image_path.suffix.lower() == '.hdr':
                    data_path = reference_image_path.with_suffix('')
                    if data_path.exists():
                        rasterio_path = str(data_path)
                        print(f"使用ENVI数据文件作为参考: {rasterio_path}")
                    else:
                        bil_path = reference_image_path.with_suffix('.bil')
                        if bil_path.exists():
                            rasterio_path = str(bil_path)
                            print(f"使用ENVI .bil文件作为参考: {rasterio_path}")
                        else:
                            print(f"找不到对应的ENVI数据文件,跳过参考图像")
                            reference_image_path = None

                if reference_image_path:  # still valid after the checks above
                    try:
                        with rasterio.open(rasterio_path) as ref_src:
                            height, width = ref_src.height, ref_src.width
                            transform = ref_src.transform
                            crs = ref_src.crs

                        print(f"使用参考图像的地理信息: {width}x{height}")
                    except Exception as e:
                        print(f"读取参考图像失败: {e},将使用Shapefile边界")
                        reference_image_path = None

            if transform is None:
                # BUGFIX: the original only took this branch when no
                # reference path was given at all, so a reference image that
                # failed to open left width/height undefined and the
                # promised Shapefile-bounds fallback never ran.
                bounds = gdf.total_bounds  # [minx, miny, maxx, maxy]
                width, height = 1000, 1000  # default grid size
                transform = from_bounds(bounds[0], bounds[1], bounds[2], bounds[3], width, height)
                crs = gdf.crs

                print(f"使用Shapefile边界创建栅格: {width}x{height}")

            # Burn every geometry as value 1 into a zero-initialised grid.
            mask_data = np.zeros((height, width), dtype=np.uint8)

            shapes = ((geom, 1) for geom in gdf.geometry)
            burned = features.rasterize(
                shapes=shapes,
                out=mask_data,
                transform=transform,
                all_touched=True,
                dtype=np.uint8
            )

            return burned

        except ImportError as e:
            print(f"栅格化Shapefile需要安装必要的库: {e}")
            print("请安装: pip install geopandas rasterio")
            return None
        except Exception as e:
            print(f"栅格化Shapefile时出错: {e}")
            return None

    def _read_dat_file(self, dat_path: Path) -> Optional[np.ndarray]:
        """Read a .dat mask: first as CSV text, then as raw float32 binary.

        The binary path guesses a square shape when possible; otherwise the
        data stays one-dimensional. Returns a binary uint8 mask or ``None``.
        """
        try:
            # First attempt: comma-separated text.
            try:
                data = np.loadtxt(dat_path, delimiter=',')
                print(f"作为CSV格式读取DAT文件,形状: {data.shape}")

                # Multi-column data: keep the first column as the mask.
                if len(data.shape) > 1:
                    mask_data = data[:, 0] if data.shape[1] > 1 else data.flatten()
                else:
                    mask_data = data.flatten()

            except ValueError:
                # Not text: treat as raw binary.
                print("尝试作为二进制文件读取DAT文件")

                with open(dat_path, 'rb') as f:
                    raw_data = f.read()

                # Assume float32; there is no header to tell us otherwise.
                try:
                    data = np.frombuffer(raw_data, dtype=np.float32)
                    print(f"作为float32二进制读取,元素数量: {len(data)}")

                    # Guess a square 2-D layout when the count allows it.
                    n = int(np.sqrt(len(data)))
                    if n * n == len(data):
                        mask_data = data.reshape(n, n)
                        print(f"重塑为方阵: {n}x{n}")
                    else:
                        mask_data = data
                        print("保持一维数组")

                except Exception as e:
                    print(f"二进制读取失败: {e}")
                    return None

            # Threshold to a binary mask.
            mask_data = (mask_data > 0).astype(np.uint8)

            return mask_data

        except Exception as e:
            print(f"读取DAT文件时出错: {e}")
            return None
|
||
|
||
|
||
class RegressionPredictor:
    """Applies trained regression models to hyperspectral imagery.

    Loads one or more pickled models (plus optional per-model scalers and
    JSON metadata), predicts a value per masked-in pixel, and writes PNG
    visualisations and single-band ENVI rasters of the results.

    NOTE(review): the original decorated this class with ``@dataclass``; the
    class has no fields and defines its own ``__init__``, so the decorator
    was a no-op and has been removed.
    """

    def __init__(self, config: "PredictionConfig"):
        """
        Initialise the predictor.

        Parameters:
            config: prediction configuration object.
        """
        self.config = config
        self.image_reader = ImageReader()
        self.models = {}       # model name -> fitted estimator
        self.scalers = {}      # model name -> fitted scaler, or None
        self.model_info = {}   # model name -> metadata dict

        # Create the output directory up front so every save path is valid.
        self.output_dir = Path(config.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def load_models(self) -> bool:
        """Load the model file(s) named by the config.

        Accepts a single .pkl file, a directory of .pkl files, or a list of
        file paths. Returns True when every model loads.
        """
        try:
            if isinstance(self.config.model_path, str):
                model_path = Path(self.config.model_path)
                if model_path.is_file():
                    return self._load_single_model(model_path)
                elif model_path.is_dir():
                    return self._load_models_from_directory(model_path)
                else:
                    print(f"模型路径不存在: {model_path}")
                    return False
            elif isinstance(self.config.model_path, list):
                # Explicit list: every entry must load.
                for model_path in self.config.model_path:
                    if not self._load_single_model(Path(model_path)):
                        print(f"加载模型失败: {model_path}")
                        return False
                return True
            else:
                print("无效的模型路径格式")
                return False

        except Exception as e:
            print(f"加载模型时出错: {e}")
            return False

    def _load_single_model(self, model_path: Path) -> bool:
        """Load one .pkl model plus its sibling scaler/info files if present.

        A missing scaler or info file is tolerated (scaler becomes None,
        info gets a stub); a missing model file is an error.
        """
        try:
            model_file = None
            scaler_file = None
            info_file = None

            if model_path.suffix == '.pkl':
                model_file = model_path
                # Companion file naming: scaler_<last stem token>.pkl and
                # info_<stem>.json — assumes the training script's scheme;
                # TODO confirm for model names with several underscores.
                scaler_file = model_path.parent / f"scaler_{model_path.stem.split('_')[-1]}.pkl"
                info_file = model_path.parent / f"info_{model_path.stem}.json"
            else:
                print(f"不支持的模型文件格式: {model_path}")
                return False

            # The estimator itself.
            if model_file.exists():
                model = joblib.load(model_file)
                model_name = model_file.stem
                self.models[model_name] = model
                print(f"加载模型: {model_name}")
            else:
                print(f"模型文件不存在: {model_file}")
                return False

            # Optional scaler.
            if scaler_file and scaler_file.exists():
                scaler = joblib.load(scaler_file)
                self.scalers[model_name] = scaler
                print(f"加载标准化器: {scaler_file.name}")
            else:
                print(f"标准化器文件不存在,将不使用标准化: {scaler_file}")
                self.scalers[model_name] = None

            # Optional metadata.
            if info_file and info_file.exists():
                with open(info_file, 'r', encoding='utf-8') as f:
                    info = json.load(f)
                self.model_info[model_name] = info
                print(f"加载模型信息: {info_file.name}")
            else:
                print(f"模型信息文件不存在: {info_file}")
                self.model_info[model_name] = {
                    'model_name': model_name,
                    'timestamp': 'unknown'
                }

            return True

        except Exception as e:
            print(f"加载单个模型时出错: {e}")
            return False

    def _load_models_from_directory(self, model_dir: Path) -> bool:
        """Load every non-scaler .pkl file found directly in *model_dir*."""
        try:
            model_files = list(model_dir.glob('*.pkl'))
            # Scaler files share the directory; exclude them by prefix.
            model_files = [f for f in model_files if not f.name.startswith('scaler_')]

            if not model_files:
                print(f"在目录 {model_dir} 中未找到模型文件")
                return False

            print(f"在目录中找到 {len(model_files)} 个模型文件")

            for model_file in model_files:
                if not self._load_single_model(model_file):
                    print(f"加载模型失败: {model_file}")
                    return False

            return True

        except Exception as e:
            print(f"从目录加载模型时出错: {e}")
            return False

    def load_image_and_mask(self) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Load the configured image and, optionally, its mask.

        Returns (image, mask). The mask is None when disabled or unreadable;
        both are None when the image itself cannot be read.
        """
        image_path = Path(self.config.image_path)
        if not image_path.exists():
            print(f"图像文件不存在: {image_path}")
            return None, None

        image_data = self.image_reader.read_envi_file(image_path)
        if image_data is None:
            print("无法读取图像文件")
            return None, None

        mask_data = None
        if self.config.use_mask and self.config.mask_path:
            mask_path = Path(self.config.mask_path)
            if mask_path.exists():
                # Shapefile masks need the image for geo-referencing.
                reference_path = image_path if mask_path.suffix.lower() == '.shp' else None
                mask_data = self.image_reader.read_mask_file(mask_path, reference_path)
                if mask_data is None:
                    print("无法读取遮罩文件,将不使用遮罩")
            else:
                print(f"遮罩文件不存在: {mask_path},将不使用遮罩")

        return image_data, mask_data

    def preprocess_image(self, image_data: np.ndarray, mask_data: Optional[np.ndarray] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """Flatten the cube to (pixels, bands) and apply the optional mask.

        Returns (X, valid_mask): X holds only the masked-in rows when a mask
        is used; valid_mask is the flat boolean selector, or None.
        """
        rows, cols, bands = image_data.shape

        # One row per pixel, one column per band.
        X_image = image_data.reshape(-1, bands)

        valid_mask = None
        if mask_data is not None and self.config.use_mask:
            if mask_data.shape[:2] != (rows, cols):
                print(f"遮罩尺寸不匹配: 图像{rows}x{cols}, 遮罩{mask_data.shape[:2]}")
                try:
                    # order=0 (nearest neighbour) keeps the mask binary.
                    from skimage.transform import resize
                    mask_data_resized = resize(mask_data.astype(float), (rows, cols), order=0, preserve_range=True)
                    mask_data = mask_data_resized.astype(np.uint8)
                    print("已调整遮罩尺寸以匹配图像")
                except ImportError:
                    print("需要安装scikit-image来调整遮罩尺寸")
                    mask_data = None

            if mask_data is not None:
                valid_mask = mask_data.reshape(-1) > 0
                print(f"使用遮罩,有效像素: {valid_mask.sum()}/{len(valid_mask)}")

                # Keep only the masked-in pixels.
                X_image = X_image[valid_mask]
        else:
            print("不使用遮罩,处理所有像素")

        return X_image, valid_mask

    def predict_image(self, image_data: np.ndarray, mask_data: Optional[np.ndarray] = None) -> Dict[str, np.ndarray]:
        """Predict with every loaded model.

        Returns a dict mapping model name -> (rows, cols) float32 map;
        pixels excluded by the mask are NaN. Models that fail are skipped.
        """
        X_image, valid_mask = self.preprocess_image(image_data, mask_data)

        if len(X_image) == 0:
            print("没有有效像素进行预测")
            return {}

        predictions = {}

        for model_name, model in self.models.items():
            try:
                print(f"使用模型 {model_name} 进行预测...")

                # Apply the model's scaler when one was saved with it.
                scaler = self.scalers.get(model_name)
                if scaler:
                    X_scaled = scaler.transform(X_image)
                else:
                    X_scaled = X_image

                y_pred = model.predict(X_scaled)

                # Rebuild a full-size image; NaN marks masked-out pixels.
                rows, cols, _ = image_data.shape
                pred_image = np.full((rows * cols,), np.nan, dtype=np.float32)

                if valid_mask is not None:
                    pred_image[valid_mask] = y_pred
                else:
                    pred_image[:] = y_pred

                pred_image = pred_image.reshape(rows, cols)

                predictions[model_name] = pred_image
                print(f"模型 {model_name} 预测完成")

            except Exception as e:
                print(f"使用模型 {model_name} 预测时出错: {e}")
                continue

        return predictions

    def save_prediction_results(self, predictions: Dict[str, np.ndarray], image_data: np.ndarray,
                                mask_data: Optional[np.ndarray] = None) -> List[Path]:
        """Write each model's prediction as a PNG figure plus an ENVI raster.

        Returns the list of file paths written; failures on one model do
        not stop the others.
        """
        saved_files = []
        image_name = Path(self.config.image_path).stem

        for model_name, pred_image in predictions.items():
            try:
                if self.config.save_individual:
                    output_name = f"{image_name}_{model_name}_prediction.png"
                else:
                    output_name = f"{image_name}_prediction.png"

                output_path = self.output_dir / output_name

                self._visualize_prediction(pred_image, model_name, output_path, mask_data)

                saved_files.append(output_path)
                print(f"保存预测结果PNG: {output_path}")

                # Also export the raw values as a single-band ENVI raster.
                envi_files = self._save_envi_prediction(pred_image, model_name, image_name, image_data.shape)
                saved_files.extend(envi_files)

            except Exception as e:
                print(f"保存模型 {model_name} 的预测结果时出错: {e}")
                continue

        return saved_files

    def _save_envi_prediction(self, pred_image: np.ndarray, model_name: str, image_name: str,
                              original_shape: tuple) -> List[Path]:
        """Save one prediction map as a single-band float32 ENVI BIL file.

        Returns the [data, header] paths written, or [] on failure.
        """
        saved_files = []

        try:
            if self.config.save_individual:
                base_name = f"{image_name}_{model_name}_prediction"
            else:
                base_name = f"{image_name}_prediction"

            bil_path = self.output_dir / f"{base_name}.bil"
            hdr_path = self.output_dir / f"{base_name}.hdr"

            # The raster cannot carry NaN markers; fill with the minimum.
            pred_image_clean = pred_image.copy()
            if np.any(np.isnan(pred_image_clean)):
                min_val = np.nanmin(pred_image_clean)
                pred_image_clean = np.nan_to_num(pred_image_clean, nan=min_val)

            pred_image_clean = pred_image_clean.astype(np.float32)

            rows, cols = pred_image_clean.shape

            # With a single band, BIL is plain row-major order.
            pred_image_clean.tofile(bil_path)
            print(f"保存ENVI数据文件: {bil_path}")

            self._create_envi_header(hdr_path, rows, cols, 1, bil_path, pred_image_clean.dtype)

            saved_files.extend([bil_path, hdr_path])
            print(f"保存ENVI头文件: {hdr_path}")

        except Exception as e:
            print(f"保存ENVI格式预测结果时出错: {e}")

        return saved_files

    def _create_envi_header(self, hdr_path: Path, rows: int, cols: int, bands: int,
                            data_path: Path, dtype: np.dtype):
        """Write an ENVI header describing a BIL raster.

        *data_path* is accepted for interface compatibility but not used —
        ENVI pairs header and data by file name.
        """
        try:
            envi_data_type = self._get_envi_data_type(dtype)

            header_content = f"""ENVI
description = {{
  Regression prediction result saved as ENVI format}}
samples = {cols}
lines = {rows}
bands = {bands}
header offset = 0
file type = ENVI Standard
data type = {envi_data_type}
interleave = bil
sensor type = Unknown
byte order = 0
"""

            with open(hdr_path, 'w', encoding='utf-8') as f:
                f.write(header_content)

        except Exception as e:
            print(f"创建ENVI头文件时出错: {e}")

    def _get_envi_data_type(self, dtype: np.dtype) -> int:
        """Map a numpy dtype to the ENVI data-type code (default 4/float32)."""
        dtype_map = {
            np.uint8: 1,     # 8-bit byte
            np.int16: 2,     # 16-bit signed integer
            np.int32: 3,     # 32-bit signed long integer
            np.float32: 4,   # 32-bit floating point
            np.float64: 5,   # 64-bit double precision floating point
            np.uint16: 12,   # 16-bit unsigned integer
            np.uint32: 13,   # 32-bit unsigned long integer
            np.int64: 14,    # 64-bit signed long integer
            np.uint64: 15,   # 64-bit unsigned long integer
        }

        return dtype_map.get(dtype.type, 4)  # default float32

    def _visualize_prediction(self, pred_image: np.ndarray, model_name: str, output_path: Path,
                              mask_data: Optional[np.ndarray] = None):
        """Render a prediction map to PNG.

        Left panel: the prediction with a colorbar. Right panel: a mask
        overlay when a mask exists, otherwise a value histogram. A stats box
        (min/max/mean/std/valid pixels) is added to the figure.
        """
        # Replace NaN (masked-out pixels) with the minimum so imshow scales sanely.
        pred_image_clean = pred_image.copy()
        if np.any(np.isnan(pred_image_clean)):
            min_val = np.nanmin(pred_image_clean)
            pred_image_clean = np.nan_to_num(pred_image_clean, nan=min_val)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

        # Left: raw prediction map.
        im1 = ax1.imshow(pred_image_clean, cmap=self.config.colormap)
        ax1.set_title(f'{model_name} - Prediction Results')
        ax1.axis('off')
        plt.colorbar(im1, ax=ax1, shrink=0.8)

        if mask_data is not None:
            # Right: colormapped prediction with the mask boundary drawn on top.
            rgb_image = np.zeros((*pred_image_clean.shape, 3))

            # Normalise to 0-1 before applying the colormap.
            pred_norm = (pred_image_clean - np.nanmin(pred_image_clean)) / (np.nanmax(pred_image_clean) - np.nanmin(pred_image_clean))
            pred_norm = np.nan_to_num(pred_norm, nan=0)

            # plt.get_cmap: plt.cm.get_cmap was deprecated and removed in
            # matplotlib 3.9.
            cmap = plt.get_cmap(self.config.colormap)
            colored_pred = cmap(pred_norm)
            rgb_image = colored_pred[:, :, :3]

            # BUGFIX: the skimage import itself can raise ImportError, so it
            # must live inside the try block that handles it; previously a
            # missing skimage aborted the whole figure.
            try:
                from skimage import measure
                contours = measure.find_contours(mask_data, 0.5)
                for contour in contours:
                    ax2.plot(contour[:, 1], contour[:, 0], color='red', linewidth=1, alpha=0.7)
            except ImportError:
                pass

            ax2.imshow(rgb_image)
            ax2.set_title(f'{model_name} - With Mask Overlay')
        else:
            # Right: distribution of predicted values.
            ax2.hist(pred_image_clean.flatten(), bins=50, alpha=0.7, color='blue', edgecolor='black')
            ax2.set_title(f'{model_name} - Prediction Distribution')
            ax2.set_xlabel('Predicted Value')
            ax2.set_ylabel('Frequency')

        # Axes are hidden only for the image panel, not the histogram.
        if mask_data is not None:
            ax2.axis('off')

        stats_text = f"""
Statistics:
Min: {np.nanmin(pred_image):.4f}
Max: {np.nanmax(pred_image):.4f}
Mean: {np.nanmean(pred_image):.4f}
Std: {np.nanstd(pred_image):.4f}
Valid Pixels: {np.sum(~np.isnan(pred_image))}/{pred_image.size}
"""

        fig.suptitle(f'Regression Prediction Results - {model_name}', fontsize=14, fontweight='bold')
        plt.figtext(0.02, 0.02, stats_text, fontsize=10, verticalalignment='bottom',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

        plt.tight_layout()

        plt.savefig(output_path, dpi=self.config.dpi, bbox_inches='tight')
        plt.close()

    def predict_single_image(self) -> bool:
        """Run the full pipeline for the image currently in the config."""
        print(f"开始预测图像: {self.config.image_path}")

        image_data, mask_data = self.load_image_and_mask()
        if image_data is None:
            return False

        # Models may already be cached (batch mode loads them once up
        # front); only hit the disk when the cache is empty.
        if not self.models and not self.load_models():
            return False

        predictions = self.predict_image(image_data, mask_data)
        if not predictions:
            return False

        saved_files = self.save_prediction_results(predictions, image_data, mask_data)

        print(f"预测完成!生成了 {len(saved_files)} 个结果文件")
        print(f"输出目录: {self.output_dir}")

        return True

    def predict_batch_images(self) -> bool:
        """Predict every supported image under the configured path.

        Returns True when at least one image succeeds.
        """
        print("开始批量预测...")

        image_path = Path(self.config.image_path)
        if image_path.is_file():
            image_files = [image_path]
        elif image_path.is_dir():
            # Collect every ENVI raster in the directory.
            extensions = ['.bil', '.bsq', '.bip']
            image_files = []
            for ext in extensions:
                image_files.extend(list(image_path.glob(f'*{ext}')))
            image_files = sorted(list(set(image_files)))  # de-duplicate, stable order
        else:
            print(f"无效的图像路径: {image_path}")
            return False

        if not image_files:
            print("未找到图像文件")
            return False

        print(f"找到 {len(image_files)} 个图像文件")

        # Load the models once; predict_single_image reuses the cache.
        if not self.load_models():
            return False

        total_success = 0

        for i, image_file in enumerate(image_files, 1):
            print(f"\n处理图像 {i}/{len(image_files)}: {image_file.name}")
            print("-" * 50)

            # Point the shared config at the current image.
            self.config.image_path = str(image_file)

            if self.predict_single_image():
                total_success += 1
            else:
                print(f"预测失败: {image_file.name}")

        print("\n批量预测完成!")
        print(f"成功预测: {total_success}/{len(image_files)} 个图像")

        return total_success > 0

    def run_prediction(self) -> bool:
        """Dispatch to batch or single-image prediction; never raises.

        (The original also computed elapsed times it never used; those dead
        locals were removed.)
        """
        try:
            if self.config.batch_mode:
                return self.predict_batch_images()
            return self.predict_single_image()

        except Exception as e:
            # Top-level guard so one failed run cannot crash the caller.
            print(f"错误详情: {e}")
            return False
|
||
|
||
|
||
def create_default_config(image_path: str, model_path: str, mask_path: Optional[str] = None,
                          output_dir: str = "prediction_results", use_mask: bool = True) -> PredictionConfig:
    """Build a PredictionConfig preset for a single-image run.

    Only the paths and the mask switch are configurable here; the
    visualisation options keep their defaults (viridis colormap, 300 dpi,
    per-model output files) and batch mode is off.
    """
    settings = dict(
        image_path=image_path,
        mask_path=mask_path,
        model_path=model_path,
        output_dir=output_dir,
        use_mask=use_mask,
        batch_mode=False,
        colormap='viridis',
        dpi=300,
        save_individual=True,
    )
    return PredictionConfig(**settings)
|
||
|
||
|
||
# def main():
|
||
# """主函数"""
|
||
# parser = argparse.ArgumentParser(description='回归预测工具 - 对高光谱图像进行含量预测')
|
||
|
||
# # 必需参数
|
||
# parser.add_argument('image_path', help='高光谱图像文件路径或目录')
|
||
# parser.add_argument('model_path', help='模型文件路径(单个文件或目录)')
|
||
|
||
# # 可选参数
|
||
# parser.add_argument('--mask_path', help='遮罩文件路径')
|
||
# parser.add_argument('--output_dir', default='prediction_results', help='输出目录')
|
||
# parser.add_argument('--no_mask', action='store_true', help='不使用遮罩')
|
||
# parser.add_argument('--batch', action='store_true', help='批量处理模式')
|
||
# parser.add_argument('--colormap', default='viridis', help='颜色映射')
|
||
# parser.add_argument('--dpi', type=int, default=300, help='输出图像DPI')
|
||
# parser.add_argument('--no_individual', action='store_true', help='不保存单个模型结果')
|
||
|
||
# args = parser.parse_args()
|
||
|
||
# # 创建配置
|
||
# config = PredictionConfig(
|
||
# image_path=args.image_path,
|
||
# mask_path=args.mask_path,
|
||
# model_path=args.model_path,
|
||
# output_dir=args.output_dir,
|
||
# use_mask=not args.no_mask,
|
||
# batch_mode=args.batch,
|
||
# colormap=args.colormap,
|
||
# dpi=args.dpi,
|
||
# save_individual=not args.no_individual
|
||
# )
|
||
|
||
# # 创建预测器并运行
|
||
# predictor = RegressionPredictor(config)
|
||
# success = predictor.run_prediction()
|
||
|
||
# return 0 if success else 1
|
||
def main():
    """Script entry point: run batch prediction with hard-coded local paths.

    A CLI (argparse) variant of this entry point is kept, commented out,
    earlier in the file; this version wires fixed paths in directly.

    Returns 0 on success, 1 on failure (used as the process exit code).
    """
    # Input image path and directory of trained models.
    hyperspectral_image = r"D:\resonon\RegressionTutorial\7_minute.bil"
    models_dir = r"E:\code\spectronon\single_classsfication\rgression_method\models"

    # ROI mask (shapefile) and where to write results.
    roi_shapefile = r"E:\code\spectronon\single_classsfication\rgression_method\data\roi.shp"
    results_dir = r"E:\code\spectronon\single_classsfication\output"

    # Assemble the run configuration (mask on, batch mode on,
    # default visualisation settings, per-model result files).
    config = PredictionConfig(
        image_path=hyperspectral_image,
        mask_path=roi_shapefile,
        model_path=models_dir,
        output_dir=results_dir,
        use_mask=True,
        batch_mode=True,
        colormap='viridis',
        dpi=300,
        save_individual=True,
    )

    predictor = RegressionPredictor(config)
    success = predictor.run_prediction()

    return 0 if success else 1


if __name__ == "__main__":
    exit(main())