//** 引理划分器模块 - 实现文件
//** 提供多种引理划分策略的具体实现

#include "lemma_dividers.h"

//** 轮询分配划分器实现
std::vector<std::vector<Lemma>> RoundRobinDivider::divide(const std::vector<Lemma>& lemmas, int num_sets) {
    if (num_sets <= 0) num_sets = 1;
    std::vector<std::vector<Lemma>> sets(num_sets);
    if (lemmas.empty()) return sets;

    
    for (size_t i = 0; i < lemmas.size(); ++i) {
        sets[i % num_sets].push_back(lemmas[i]);
    }
    
    
    return sets;
}

/// Returns the strategy identifier of this divider ("round-robin").
std::string RoundRobinDivider::getName() const {
    return std::string("round-robin");
}

/// Returns the human-readable (Chinese) description of the round-robin
/// strategy: lemmas are assigned to the sets cyclically, one after another.
std::string RoundRobinDivider::getDescription() const {
    return std::string("轮询分配 - 依次循环分配引理到各个集合");
}

//** 年龄分层划分器实现
std::vector<std::vector<Lemma>> AgeBasedDivider::divide(const std::vector<Lemma>& lemmas, int num_sets) {
    if (num_sets <= 0) num_sets = 1;
    std::vector<std::vector<Lemma>> sets(num_sets);
    if (lemmas.empty()) return sets;

    
    // 首先按子句序号排序（年龄从小到大）
    std::vector<Lemma> sorted_lemmas = lemmas;
    std::sort(sorted_lemmas.begin(), sorted_lemmas.end(), 
              [](const Lemma& a, const Lemma& b) {
                  // 如果clause_number相同，按id排序保证稳定性
                  if (a.clause_number == b.clause_number) {
                      return a.id < b.id;
                  }
                  // 将无效的clause_number(-1)排到最后
                  if (a.clause_number == -1) return false;
                  if (b.clause_number == -1) return true;
                  return a.clause_number < b.clause_number;
              });
    
    // 统计年龄信息
    int min_age = std::numeric_limits<int>::max();
    int max_age = std::numeric_limits<int>::min();
    int valid_age_count = 0;
    
    for (const auto& lemma : sorted_lemmas) {
        if (lemma.clause_number > 0) {
            min_age = std::min(min_age, lemma.clause_number);
            max_age = std::max(max_age, lemma.clause_number);
            valid_age_count++;
        }
    }
    
    // 计算每个年龄段的大小
    size_t lemmas_per_set = sorted_lemmas.size() / num_sets;
    size_t remainder = sorted_lemmas.size() % num_sets;
    
    // 分配引理到各个年龄段
    size_t current_index = 0;
    for (int i = 0; i < num_sets && current_index < sorted_lemmas.size(); ++i) {
        size_t current_set_size = lemmas_per_set + (i < static_cast<int>(remainder) ? 1 : 0);
        
        // 记录当前集合的年龄范围
        int set_min_age = -1, set_max_age = -1;
        
        for (size_t j = 0; j < current_set_size && current_index < sorted_lemmas.size(); ++j) {
            const Lemma& lemma = sorted_lemmas[current_index];
            sets[i].push_back(lemma);
            
            // 更新年龄范围统计
            if (lemma.clause_number > 0) {
                if (set_min_age == -1) set_min_age = lemma.clause_number;
                set_max_age = lemma.clause_number;
            }
            
            current_index++;
        }
        
        // 输出年龄段信息
        std::string age_range = "未知";
        if (set_min_age > 0 && set_max_age > 0) {
            if (set_min_age == set_max_age) {
                age_range = std::to_string(set_min_age);
            } else {
                age_range = std::to_string(set_min_age) + "-" + std::to_string(set_max_age);
            }
        }
        
    }
    
    return sets;
}

/// Returns the strategy identifier of this divider ("age-based").
std::string AgeBasedDivider::getName() const {
    return std::string("age-based");
}

/// Returns the human-readable (Chinese) description of the age-based
/// strategy: lemmas are split into age layers according to clause number.
std::string AgeBasedDivider::getDescription() const {
    return std::string("年龄分层 - 根据子句序号将引理分为不同年龄段");
}



//** 创建指定策略的引理划分器
/// Creates the lemma divider selected by `strategy_name`.
///
/// Recognized names are "round-robin", "age-based", "depth-based" and
/// "depth-balanced". Any other name falls back to the round-robin divider.
///
/// @param strategy_name Strategy identifier to look up.
/// @return A newly created divider; never null.
std::unique_ptr<LemmaDivider> createLemmaDivider(const std::string& strategy_name) {
    if (strategy_name == "age-based") {
        return make_unique<AgeBasedDivider>();
    }
    if (strategy_name == "depth-based") {
        return make_unique<DepthBasedDivider>();
    }
    if (strategy_name == "depth-balanced") {
        return make_unique<DepthBalancedDivider>();
    }

    // "round-robin" and every unrecognized name share the same result.
    return make_unique<RoundRobinDivider>();
}

//** 推导深度分层划分器实现
/// Two-argument overload: forwards to the three-argument `divide` with
/// `unit_clauses_only` set to false.
std::vector<std::vector<Lemma>> DepthBasedDivider::divide(const std::vector<Lemma>& lemmas, int num_sets) {
    const bool unit_clauses_only = false;
    return divide(lemmas, num_sets, unit_clauses_only);
}

std::vector<std::vector<Lemma>> DepthBasedDivider::divide(const std::vector<Lemma>& lemmas, int num_sets, bool unit_clauses_only) {
    // 检查提前停止标志
    if (proof_found.load()) {
        std::vector<std::vector<Lemma>> empty_sets(num_sets);
        return empty_sets;
    }
    
    if (num_sets <= 0) num_sets = 1;
    std::vector<std::vector<Lemma>> sets(num_sets);
    if (lemmas.empty()) return sets;

    
    // 复制引理列表
    std::vector<Lemma> lemmas_with_depths = lemmas;
    
    // 计算推导深度（只在depth-based策略中计算）
    calculateDerivationDepths(lemmas_with_depths);
    
    // 深度计算完成后，对CSE引理进行清理
    for (auto& lemma : lemmas_with_depths) {
        if (lemma.derivation_type == "cse") {
            // 清理CSE引理的证明器特定信息，但保留已计算的深度
            lemma.avatar_id = "";
            // 简化inference_rule
            if (!lemma.inference_rule.empty()) {
                if (lemma.inference_rule.find("resolution") != std::string::npos) {
                    lemma.inference_rule = "resolution";
                } else if (lemma.inference_rule.find("superposition") != std::string::npos) {
                    lemma.inference_rule = "superposition";
                } else {
                    lemma.inference_rule = "other";
                }
            }
        }
    }
    
    // 使用专门的过滤函数过滤引理
    std::vector<Lemma> sorted_lemmas = filter_lemmas_for_depth_based(lemmas_with_depths, unit_clauses_only);
    
    if (sorted_lemmas.empty()) {
        return sets;
    }
    
    // 按推导深度排序（深度从小到大）
    std::sort(sorted_lemmas.begin(), sorted_lemmas.end(), 
              [](const Lemma& a, const Lemma& b) {
                  // 如果深度相同，按子句编号排序保证稳定性
                  if (a.derivation_depth == b.derivation_depth) {
                      return a.clause_number < b.clause_number;
                  }
                  // 将未计算深度的引理(-1)排到最后
                  if (a.derivation_depth == -1) return false;
                  if (b.derivation_depth == -1) return true;
                  return a.derivation_depth < b.derivation_depth;
              });
    
    // 统计深度信息
    std::map<int, int> depth_stats;
    int valid_depth_count = 0;
    int min_depth = std::numeric_limits<int>::max();
    int max_depth = std::numeric_limits<int>::min();
    
    for (const auto& lemma : sorted_lemmas) {
        if (lemma.derivation_depth >= 0) {
            depth_stats[lemma.derivation_depth]++;
            valid_depth_count++;
            min_depth = std::min(min_depth, lemma.derivation_depth);
            max_depth = std::max(max_depth, lemma.derivation_depth);
        }
    }
    
    // 显示深度分布
    
    // depth-based分配策略：均匀数量分配 + 深度排序
    // 每个集合分配相同数量的引理，按深度排序依次分配
    
    size_t total_lemmas = sorted_lemmas.size();
    size_t lemmas_per_set = total_lemmas / num_sets;
    size_t remainder = total_lemmas % num_sets;
    
    
    // 依次分配已排序的引理到各个集合
    size_t current_index = 0;
    for (int i = 0; i < num_sets && current_index < total_lemmas; ++i) {
        // 计算当前集合应该分配的引理数
        size_t current_set_size = lemmas_per_set + (i < static_cast<int>(remainder) ? 1 : 0);
        
        // 分配引理到当前集合
        for (size_t j = 0; j < current_set_size && current_index < total_lemmas; ++j) {
            sets[i].push_back(sorted_lemmas[current_index]);
            current_index++;
        }
        
        // 显示当前集合的深度范围
        if (!sets[i].empty()) {
            int min_depth = std::numeric_limits<int>::max();
            int max_depth = std::numeric_limits<int>::min();
            for (const auto& lemma : sets[i]) {
                if (lemma.derivation_depth >= 0) {
                    min_depth = std::min(min_depth, lemma.derivation_depth);
                    max_depth = std::max(max_depth, lemma.derivation_depth);
                }
            }
            if (min_depth != std::numeric_limits<int>::max()) {
            }
        }
    }
    
    // 输出分配结果统计
    
    return sets;
}

/// Returns the strategy identifier of this divider ("depth-based").
std::string DepthBasedDivider::getName() const {
    return std::string("depth-based");
}

/// Returns the human-readable (Chinese) description of the depth-based
/// strategy: lemmas are layered by derivation depth, i.e. the number of
/// inference steps counted from the input clauses.
std::string DepthBasedDivider::getDescription() const {
    return std::string("推导深度分层 - 根据引理的推导深度（从input子句开始的推理步数）分层");
}

//** 深度均衡划分器实现
/// Two-argument overload: forwards to the three-argument `divide` with
/// `unit_clauses_only` set to false (no unit-clause filtering by default).
std::vector<std::vector<Lemma>> DepthBalancedDivider::divide(const std::vector<Lemma>& lemmas, int num_sets) {
    const bool unit_clauses_only = false;
    return divide(lemmas, num_sets, unit_clauses_only);
}

std::vector<std::vector<Lemma>> DepthBalancedDivider::divide(const std::vector<Lemma>& lemmas, int num_sets, bool unit_clauses_only) {
    // 检查提前停止标志
    if (proof_found.load()) {
        std::vector<std::vector<Lemma>> empty_sets(num_sets);
        return empty_sets;
    }
    
    if (num_sets <= 0) num_sets = 1;
    std::vector<std::vector<Lemma>> sets(num_sets);
    if (lemmas.empty()) return sets;

    
    // 复制引理列表，因为需要修改derivation_depth字段
    std::vector<Lemma> lemmas_with_depths = lemmas;
    
    // 计算推导深度（只在depth-based策略中计算）
    calculateDerivationDepths(lemmas_with_depths);
    
    // 深度计算完成后，对CSE引理进行清理
    for (auto& lemma : lemmas_with_depths) {
        if (lemma.derivation_type == "cse") {
            // 清理CSE引理的证明器特定信息，但保留已计算的深度
            lemma.avatar_id = "";
            // 简化inference_rule
            if (!lemma.inference_rule.empty()) {
                if (lemma.inference_rule.find("resolution") != std::string::npos) {
                    lemma.inference_rule = "resolution";
                } else if (lemma.inference_rule.find("superposition") != std::string::npos) {
                    lemma.inference_rule = "superposition";
                } else {
                    lemma.inference_rule = "other";
                }
            }
        }
    }
    
    // 使用专门的过滤函数过滤引理
    std::vector<Lemma> filtered_lemmas = filter_lemmas_for_depth_based(lemmas_with_depths, unit_clauses_only);
    
    if (filtered_lemmas.empty()) {
        return sets;
    }
    
    // 按推导深度分组
    std::map<int, std::vector<Lemma>> depth_groups;
    for (const auto& lemma : filtered_lemmas) {
        depth_groups[lemma.derivation_depth].push_back(lemma);
    }
    
    // 统计深度信息
    int min_depth = std::numeric_limits<int>::max();
    int max_depth = std::numeric_limits<int>::min();
    int valid_depth_count = 0;
    
    for (const auto& pair : depth_groups) {
        int depth = pair.first;
        int count = pair.second.size();
        
        if (depth >= 0) {
            min_depth = std::min(min_depth, depth);
            max_depth = std::max(max_depth, depth);
            valid_depth_count += count;
        }
        
    }
    
    
    // 深度均衡分配策略：对每个深度层，使用轮询方式分配到各个集合
    std::vector<int> set_counters(num_sets, 0); // 记录每个集合的当前大小
    
    for (const auto& depth_pair : depth_groups) {
        const std::vector<Lemma>& depth_lemmas = depth_pair.second;
        
        
        // 在当前深度层内使用轮询分配
        for (size_t i = 0; i < depth_lemmas.size(); ++i) {
            int target_set = i % num_sets;
            sets[target_set].push_back(depth_lemmas[i]);
            set_counters[target_set]++;
        }
    }
    
    // 输出分配结果统计
    
    // 分析深度覆盖情况
    
    return sets;
}

/// Returns the strategy identifier of this divider ("depth-balanced").
std::string DepthBalancedDivider::getName() const {
    return std::string("depth-balanced");
}

/// Returns the human-readable (Chinese) description of the depth-balanced
/// strategy: every set should contain lemmas of all derivation depths, with
/// an even depth distribution.
std::string DepthBalancedDivider::getDescription() const {
    return std::string("深度均衡 - 确保每个集合都包含各种推导深度的引理，深度分布均匀");
}



//** 显示所有可用的划分策略
/// Intentionally a no-op: the display body has been emptied, so no strategy
/// information is printed. The function is kept so existing callers still link.
void showAvailableDivisionStrategies() {}

//** 轮询分配策略
/// Free-function convenience wrapper around `RoundRobinDivider::divide`.
///
/// @param selected_lemmas Lemmas to distribute; not modified.
/// @param k_sets          Number of sets, passed through unchanged.
/// @return Whatever `RoundRobinDivider::divide` returns for these arguments.
std::vector<std::vector<Lemma>> divide_lemmas_round_robin(const std::vector<Lemma>& selected_lemmas, int k_sets) {
    return RoundRobinDivider().divide(selected_lemmas, k_sets);
}