import os
import cv2
import numpy as np
import random
import shutil
from pathlib import Path

def rotate_image_and_labels(image, labels, angle):
    """
    Rotate an image counter-clockwise by 90, 180 or 270 degrees and remap its YOLO labels.

    The rotation is an exact pixel permutation (``np.rot90``): no interpolation and no
    cropping. The output is ``(w, h)`` for 90/270 degrees and ``(h, w)`` for 180 degrees,
    also for non-square images. A quarter-turn cannot push a box outside the frame, so
    every input label is kept (values are clipped to [0, 1] only to guard against
    slightly out-of-range input).

    :param image: input image array indexed ``[row, col, ...]`` (grayscale or multi-channel)
    :param labels: YOLO labels ``[[class_id, x_center, y_center, width, height], ...]``
                   with coordinates normalised to [0, 1]
    :param angle: counter-clockwise rotation in degrees: 90, 180 or 270
                  (taken modulo 360, so -90 behaves like 270)
    :return: ``(rotated_image, rotated_labels)``; the image is a C-contiguous array
    :raises ValueError: if *angle* is not one of the supported quarter-turns
    """
    quarter_turns = {90: 1, 180: 2, 270: 3}.get(angle % 360)
    if quarter_turns is None:
        raise ValueError("只支持90, 180, 270度旋转")

    # np.rot90 rotates from the row axis towards the column axis, i.e. counter-clockwise
    # as displayed -- the same direction as a positive angle in cv2.getRotationMatrix2D.
    # It returns a strided view; make a plain contiguous copy for downstream OpenCV calls.
    rotated = np.ascontiguousarray(np.rot90(image, quarter_turns))

    rotated_labels = []
    for class_id, x, y, bw, bh in labels:
        if quarter_turns == 1:
            # 90 deg CCW: (x, y) -> (y, 1 - x); box width and height swap.
            x, y, bw, bh = y, 1.0 - x, bh, bw
        elif quarter_turns == 2:
            # 180 deg: (x, y) -> (1 - x, 1 - y); box size unchanged.
            x, y = 1.0 - x, 1.0 - y
        else:
            # 270 deg CCW (= 90 deg CW): (x, y) -> (1 - y, x); box width and height swap.
            x, y, bw, bh = 1.0 - y, x, bh, bw
        rotated_labels.append(
            [class_id] + [float(np.clip(v, 0.0, 1.0)) for v in (x, y, bw, bh)]
        )

    return rotated, rotated_labels

def crop_image_and_labels(image, labels, grid_size=(4, 4)):
    """
    Split an image into a rows x cols grid of equal tiles and re-express labels per tile.

    Tiles measure ``(h // rows, w // cols)`` pixels. Leftover pixels on the right and
    bottom edges belong to no tile. Every box is intersected with every tile. Each
    non-empty intersection is re-normalised to the tile and kept only when both its
    normalised width and height exceed 0.01.

    :param image: input image array indexed ``[row, col, ...]``
    :param labels: YOLO labels ``[[class_id, x_center, y_center, width, height], ...]``
    :param grid_size: ``(rows, cols)`` of the grid
    :return: list of ``(tile_image, tile_labels, (row_index, col_index))`` in row-major
             order; ``tile_image`` is a slice of *image*
    """
    img_h, img_w = image.shape[:2]
    n_rows, n_cols = grid_size
    tile_h = img_h // n_rows
    tile_w = img_w // n_cols

    # Pixel-space corners of each box do not depend on the tile, so compute them once.
    boxes = []
    for cls, cx, cy, bw, bh in labels:
        cx_px, cy_px = cx * img_w, cy * img_h
        half_w, half_h = bw * img_w / 2, bh * img_h / 2
        boxes.append((cls, cx_px - half_w, cy_px - half_h, cx_px + half_w, cy_px + half_h))

    tiles = []
    for row in range(n_rows):
        top = row * tile_h
        bottom = top + tile_h
        for col in range(n_cols):
            left = col * tile_w
            right = left + tile_w

            tile_labels = []
            for cls, bx1, by1, bx2, by2 in boxes:
                # Overlap of the box with this tile, in full-image pixels.
                ix1 = max(bx1, left)
                iy1 = max(by1, top)
                iw = min(bx2, right) - ix1
                ih = min(by2, bottom) - iy1
                if iw <= 0 or ih <= 0:
                    continue

                nx = np.clip((ix1 + iw/2 - left) / tile_w, 0.0, 1.0)
                ny = np.clip((iy1 + ih/2 - top) / tile_h, 0.0, 1.0)
                nw = np.clip(iw / tile_w, 0.0, 1.0)
                nh = np.clip(ih / tile_h, 0.0, 1.0)
                if nw > 0.01 and nh > 0.01:
                    tile_labels.append([cls, nx, ny, nw, nh])

            tiles.append((image[top:bottom, left:right], tile_labels, (row, col)))

    return tiles

def adjust_rgb(image, r_factor, g_factor, b_factor):
    """Apply an independent multiplicative gain to each colour channel of a BGR image.

    Each scaled channel is clipped to [0, 255] and truncated to uint8.

    :param image: 3-channel image in OpenCV's B, G, R channel order
    :param r_factor: gain for the red channel
    :param g_factor: gain for the green channel
    :param b_factor: gain for the blue channel
    :return: the adjusted 3-channel uint8 image
    """
    blue, green, red = cv2.split(image)  # OpenCV channel order is B, G, R
    channel_gains = ((blue, b_factor), (green, g_factor), (red, r_factor))
    scaled = [
        np.clip(channel * gain, 0, 255).astype(np.uint8)
        for channel, gain in channel_gains
    ]
    return cv2.merge(scaled)

def adjust_contrast(image, factor):
    """Scale each pixel's deviation from its channel mean by *factor*.

    ``factor > 1`` increases contrast and ``factor < 1`` reduces it. The result is
    clipped to [0, 255] and truncated to uint8.

    :param image: image array; the mean is taken per channel over the two spatial axes
    :param factor: contrast multiplier
    :return: the adjusted uint8 image
    """
    centre = np.mean(image, axis=(0, 1))
    deviation = image - centre
    stretched = deviation * factor + centre
    return np.clip(stretched, 0, 255).astype(np.uint8)

def adjust_grayscale(image, grayscale_factor=1.0):
    """Blend a BGR image with its own grayscale version.

    :param image: 3-channel BGR image
    :param grayscale_factor: weight of the grayscale copy (0 leaves the image unchanged,
                             1 makes it fully gray); the original gets ``1 - grayscale_factor``
    :return: the blended image from ``cv2.addWeighted``
    """
    weight_gray = grayscale_factor
    weight_orig = 1 - weight_gray
    # Convert back to 3 channels so the gray copy can be blended with the colour image.
    gray_3ch = cv2.cvtColor(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY), cv2.COLOR_GRAY2BGR)
    return cv2.addWeighted(image, weight_orig, gray_3ch, weight_gray, 0)

def parse_yolo_label(label_path):
    """Read a YOLO detection label file.

    :param label_path: path (str or Path) of the ``.txt`` label file
    :return: list of ``[class_id, x_center, y_center, width, height]`` float lists;
             an empty list if the file does not exist. Lines without exactly five
             whitespace-separated fields (e.g. blank lines) are skipped.
    :raises ValueError: if a five-field line contains a non-numeric field
    """
    labels = []
    if not os.path.exists(label_path):
        return labels
    # 'utf-8-sig' drops a leading UTF-8 BOM (written by some Windows editors). With the
    # platform default codec the BOM stays attached to the first class id and float()
    # raises. errors='replace' keeps stray non-UTF-8 bytes from aborting the whole read.
    with open(label_path, 'r', encoding='utf-8-sig', errors='replace') as f:
        for line in f:
            parts = line.split()
            if len(parts) == 5:
                labels.append([float(p) for p in parts])
    return labels

def save_yolo_label(label_path, labels):
    """Write *labels* to *label_path* in YOLO text format, one object per line.

    Each label is ``[class_id, x_center, y_center, width, height]``. The class id is
    written via ``int()`` and the four coordinates with six decimals. Any existing
    file is overwritten; an empty *labels* produces an empty file.
    """
    with open(label_path, 'w') as out:
        out.writelines(
            f"{int(cls)} {cx:.6f} {cy:.6f} {bw:.6f} {bh:.6f}\n"
            for cls, cx, cy, bw, bh in labels
        )

def _imread_any_path(path):
    """Read an image file as a BGR array; return None if it cannot be read or decoded.

    The bytes are loaded with NumPy and decoded in memory. On Windows, ``cv2.imread``
    cannot open paths containing non-ASCII (e.g. Chinese) characters and just returns None.
    """
    try:
        data = np.fromfile(str(path), dtype=np.uint8)
    except OSError:
        return None
    if data.size == 0:
        return None  # cv2.imdecode rejects an empty buffer
    return cv2.imdecode(data, cv2.IMREAD_COLOR)


def _imwrite_any_path(path, image):
    """Encode *image* using the extension of *path* and write it; return True on success.

    Counterpart of ``_imread_any_path``. Encoding in memory avoids the Windows
    non-ASCII path problem of ``cv2.imwrite`` and makes write failures detectable.
    """
    ok, encoded = cv2.imencode(Path(path).suffix, image)
    if not ok:
        return False
    try:
        encoded.tofile(str(path))
    except OSError:
        return False
    return True


def _save_sample(image_path, image, label_path, labels):
    """Write one image and its YOLO label file, warning instead of failing silently."""
    if not _imwrite_any_path(image_path, image):
        print(f"警告: 图像写入失败 {image_path}")
    save_yolo_label(label_path, labels)


def main():
    """Augment every image in ``input_dir`` and write the results to ``output_dir``.

    For each readable image this writes the original plus 22 augmented samples
    (3 rotations, 16 crops from a 4x4 grid, 3 colour variants), each with a matching
    YOLO label file. ``classes.txt`` is copied unchanged if present.
    """
    input_dir = Path(r"F:\Cells_Identification\Selected_Data")
    output_dir = Path(r"F:\Cells_Identification\Selected_Data_Enhance")

    if not input_dir.is_dir():
        print(f"输入目录不存在: {input_dir}")
        return

    output_dir.mkdir(parents=True, exist_ok=True)

    # Supported image formats (compared lower-case, so ".JPG" etc. also match).
    image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}

    # Sorted so the processing order -- and hence the counter-based output
    # names -- is the same on every run.
    image_files = sorted(
        p for p in input_dir.iterdir()
        if p.is_file() and p.suffix.lower() in image_exts
    )
    print(f"找到 {len(image_files)} 张原始图像")

    classes_file = input_dir / "classes.txt"
    if classes_file.exists():
        shutil.copy(classes_file, output_dir / "classes.txt")

    # Global running index appended to every augmented file name.
    augment_counter = 1
    skipped = 0

    for img_path in image_files:
        image = _imread_any_path(img_path)
        if image is None:
            print(f"警告: 无法读取图像, 已跳过 {img_path}")
            skipped += 1
            continue

        labels = parse_yolo_label(img_path.with_suffix('.txt'))
        base_name = img_path.stem
        ext = img_path.suffix

        # 1. Original image and labels.
        _save_sample(output_dir / f"{base_name}{ext}", image,
                     output_dir / f"{base_name}.txt", labels)

        # 2. Rotations (90, 180, 270 degrees).
        for angle in [90, 180, 270]:
            rotated_img, rotated_labels = rotate_image_and_labels(image.copy(), labels, angle)
            stem = f"{base_name}_rot{angle}_{augment_counter}"
            _save_sample(output_dir / f"{stem}{ext}", rotated_img,
                         output_dir / f"{stem}.txt", rotated_labels)
            augment_counter += 1

        # 3. Crops from a 4x4 grid.
        crops = crop_image_and_labels(image.copy(), labels, grid_size=(4, 4))
        for crop_img, crop_labels, (i, j) in crops:
            stem = f"{base_name}_crop{i}_{j}_{augment_counter}"
            _save_sample(output_dir / f"{stem}{ext}", crop_img,
                         output_dir / f"{stem}.txt", crop_labels)
            augment_counter += 1

        # 4. Colour augmentations; geometry is unchanged, so labels are reused as-is.
        rgb_img = adjust_rgb(image.copy(),
                             r_factor=random.uniform(0.7, 1.3),
                             g_factor=random.uniform(0.7, 1.3),
                             b_factor=random.uniform(0.7, 1.3))
        contrast_img = adjust_contrast(image.copy(), factor=random.uniform(0.5, 1.5))
        gray_img = adjust_grayscale(image.copy(), grayscale_factor=random.uniform(0.3, 0.8))

        color_types = [
            (rgb_img, "rgb"),
            (contrast_img, "contrast"),
            (gray_img, "gray")
        ]
        for color_img, color_type in color_types:
            stem = f"{base_name}_{color_type}_{augment_counter}"
            _save_sample(output_dir / f"{stem}{ext}", color_img,
                         output_dir / f"{stem}.txt", labels)
            augment_counter += 1

    print(f"数据增强完成! 共生成 {augment_counter-1} 个增强样本")
    if skipped:
        print(f"有 {skipped} 张图像无法读取, 已跳过")
    print(f"结果保存在: {output_dir}")


if __name__ == "__main__":
    main()