/*

Copyright © 2023-25 Sean Holden. All rights reserved.

*/
/*

This file is part of Connect++.

Connect++ is free software: you can redistribute it and/or modify it 
under the terms of the GNU General Public License as published by the 
Free Software Foundation, either version 3 of the License, or (at your 
option) any later version.

Connect++ is distributed in the hope that it will be useful, but WITHOUT 
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 
FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for 
more details.

You should have received a copy of the GNU General Public License along 
with Connect++. If not, see <https://www.gnu.org/licenses/>. 

*/

#include "Unifier.hpp"

// Unify two terms, binding variables in place.
//
// With params::poly_unification the pair is handed to the rule-based
// complete_unification(). Otherwise the recursive algorithm is used via
// unify(), which calls backtrack() when unification fails. Both paths
// therefore leave no partial bindings behind on failure, which matches the
// vector overload of operator().
//
// Returns Succeed, ConflictFail or OccursFail.
UnificationOutcome Unifier::operator()(Term* term1, Term* term2) {
    if (params::poly_unification) {
        to_do.clear();
        to_do.push_back(UPair(term1, term2));
        return complete_unification();
    }
    // unify_terms() on its own never undoes the bindings it made before
    // detecting a failure; unify() wraps it and backtracks in that case.
    return unify(term1, term2);
}
//----------------------------------------------------------------------
// Run the recursive unification of t1 and t2. If it fails, undo every
// binding recorded so far (see backtrack()) so the failure leaves no trace.
// The outcome of unify_terms() is returned unchanged.
UnificationOutcome Unifier::unify(Term* t1, Term* t2) {
    const UnificationOutcome result = unify_terms(t1, t2);
    if (result == UnificationOutcome::Succeed)
        return result;
    backtrack();
    return result;
}
//----------------------------------------------------------------------
// Figure 1, page 454, Handbook of Automated Reasoning Volume 1.
//----------------------------------------------------------------------
// Recursive unification of t1 and t2.
//
// Successful variable bindings are recorded in s and applied with
// Variable::substitute. On failure this function returns immediately
// WITHOUT undoing bindings already made; callers are responsible for
// calling backtrack() (as unify() and the vector operator() do).
UnificationOutcome Unifier::unify_terms(Term* t1, Term* t2) {
    // Skip past any leading variables on each side, collecting whether the
    // resulting term is a variable, which variable/function it is, and its
    // arity.
    bool left_is_var;
    Variable* left_var;
    Function* left_fun;
    size_t left_arity;
    t1 = t1->skip_leading_variables_for_unification(left_is_var, left_var,
                                                    left_fun, left_arity);

    bool right_is_var;
    Variable* right_var;
    Function* right_fun;
    size_t right_arity;
    t2 = t2->skip_leading_variables_for_unification(right_is_var, right_var,
                                                    right_fun, right_arity);

    // The same variable on both sides: trivially unified.
    if (left_is_var && left_var == right_var)
        return UnificationOutcome::Succeed;

    if (!left_is_var) {
        // Non-variable against variable: retry with the variable on the left.
        if (right_is_var)
            return unify_terms(t2, t1);

        // Two function applications: the symbol and arity must match, then
        // the arguments are unified left to right, stopping at the first
        // failure.
        if (left_fun != right_fun || left_arity != right_arity)
            return UnificationOutcome::ConflictFail;
        for (size_t k = 0; k < left_arity; k++) {
            const UnificationOutcome sub = unify_terms((*t1)[k], (*t2)[k]);
            if (sub != UnificationOutcome::Succeed)
                return sub;
        }
        return UnificationOutcome::Succeed;
    }

    // A variable on the left: the occurs check, then record and apply the binding.
    if (t2->contains_variable(left_var))
        return UnificationOutcome::OccursFail;
    s.push_back(left_var, t2);
    left_var->substitute(t2);
    return UnificationOutcome::Succeed;
}
//----------------------------------------------------------------------
// Simultaneously unify two argument lists, pairing t1s[k] with t2s[k].
//
// Lists of different length fail immediately with ConflictFail. Otherwise
// either the rule-based algorithm (params::poly_unification) or the
// recursive algorithm is used. In the recursive case the first failing pair
// triggers backtrack() and its outcome is returned.
UnificationOutcome Unifier::operator()(const vector<Term*>& t1s,
                                       const vector<Term*>& t2s) {
    const size_t n = t1s.size();
    if (n != t2s.size())
        return UnificationOutcome::ConflictFail;

    if (!params::poly_unification) {
        for (size_t k = 0; k < n; k++) {
            const UnificationOutcome result = unify_terms(t1s[k], t2s[k]);
            if (result != UnificationOutcome::Succeed) {
                backtrack();
                return result;
            }
        }
        return UnificationOutcome::Succeed;
    }

    // Load every pair, in argument order, for the rule-based algorithm.
    to_do.clear();
    for (size_t k = 0; k < n; k++)
        to_do.push_back(UPair(t1s[k], t2s[k]));
    return complete_unification();
}
//----------------------------------------------------------------------
// Unify two literals by unifying their argument lists.
//
// If the literals are not compatible (as judged by
// Literal::is_compatible_with), a diagnostic is written to cerr and
// ConflictFail is returned without touching any variables. Previously the
// arguments were unified anyway, which could report Succeed for a pair the
// code itself flags as an error.
UnificationOutcome Unifier::operator()(Literal* l1, Literal* l2) {
    if (!l1->is_compatible_with(l2)) {
        cerr << "ERROR" << ": ";
        cerr << "You're trying to unify non-compatible literals." << endl;
        cerr << "ERROR" << ": " << *l1 << endl;
        cerr << "ERROR" << ": " << *l2 << endl;
        // No bindings have been made yet, so no backtrack() is needed.
        return UnificationOutcome::ConflictFail;
    }
    const vector<Term*>& args1 = l1->get_args();
    const vector<Term*>& args2 = l2->get_args();
    return operator()(args1, args2);
}
//----------------------------------------------------------------------
// Rule-based unification of all pairs currently in to_do.
//
// The substitution s is cleared first. Pairs are then removed from to_do
// and handled by the first applicable rule, in this order:
//   Swap      - non-variable vs variable: re-push with the variable first.
//   Delete    - the two sides are already equal (subbed_equal): drop the pair.
//   Conflict  - two non-variables with different symbol or arity: fail.
//   Decompose - two non-variables with the same symbol and arity: push the
//               argument pairs.
//   Eliminate - variable not occurring in the other side: bind it.
//   Occurs    - variable occurring inside a non-variable term: fail.
// On either failure backtrack() undoes every binding made so far. On success
// the bindings remain applied and are recorded in s.
UnificationOutcome Unifier::complete_unification() {
    s.clear();
    while (to_do.size() > 0) {
        // Take the most recently pushed pair (to_do is used as a LIFO stack:
        // back() then pop_back()) and split it into its two terms.
        UPair upair(to_do.back());
        Term* t1(upair.first);
        Term* t2(upair.second);
        to_do.pop_back();

        bool t1_is_var;
        Variable* t1v;
        Function* t1f;
        size_t t1a;
        t1 = t1->skip_leading_variables_for_unification(t1_is_var, t1v, t1f, t1a);

        bool t2_is_var;
        Variable* t2v;
        Function* t2f;
        size_t t2a;
        t2 = t2->skip_leading_variables_for_unification(t2_is_var, t2v, t2f, t2a);

        // Swap: orient the pair so that a variable is on the left.
        if (!t1_is_var && t2_is_var) {
            to_do.push_back(UPair(t2, t1));
            continue;
        }

        // Delete: both sides already equal under the current bindings.
        if (t1->subbed_equal(t2)) {
            continue;
        }

        // Decompose/Conflict
        if (!t1_is_var && !t2_is_var) {
            // Conflict: different function symbol or arity.
            if ((t1f != t2f) ||
                (t1a != t2a)) {
                backtrack();
                return UnificationOutcome::ConflictFail;
            }
            // Decompose: schedule each pair of corresponding arguments.
            else {
                size_t n_args = t1->arity();
                for (size_t i = 0; i < n_args; i++) {
                    to_do.push_back(UPair((*t1)[i], (*t2)[i]));
                }
                continue;
            }
        }

        // Past the Swap and Decompose/Conflict rules, t1 should be a variable.
        bool contains = t2->contains_variable(t1v);
        
        // Eliminate: bind t1v to t2 and record the binding in s. Later pairs
        // see the binding when they skip leading variables.
        if (t1_is_var && !contains) {
            s.push_back(t1v, t2);
            t1v->substitute(t2);
            continue;
        }

        // Occurs: the variable appears inside the non-variable term.
        // NOTE(review): a variable/variable pair with contains == true falls
        // through both rules and is dropped. Presumably that case is always
        // caught by Delete first; confirm against subbed_equal.
        if (t1_is_var && !t2_is_var && contains) {
            backtrack();
            return UnificationOutcome::OccursFail;
        }
    }
    return UnificationOutcome::Succeed;
}
//----------------------------------------------------------------------
// Undo every binding recorded in the current substitution, then empty the
// substitution so it no longer refers to the undone bindings.
void Unifier::backtrack() {
    auto& recorded = s;
    recorded.backtrack();
    recorded.clear();
}
//----------------------------------------------------------------------
// Write a human-readable description of a unification outcome.
// Unrecognised values write nothing and leave the stream untouched.
ostream& operator<<(ostream& out, const UnificationOutcome& o) {
    const char* message = nullptr;
    if (o == UnificationOutcome::Succeed)
        message = "Unification: success.";
    else if (o == UnificationOutcome::ConflictFail)
        message = "Unification: failed due to conflict.";
    else if (o == UnificationOutcome::OccursFail)
        message = "Unification: failed due to occurs check.";
    if (message != nullptr)
        out << message;
    return out;
}
//----------------------------------------------------------------------
// Write the unifier's textual form (Unifier::to_string) to the stream.
ostream& operator<<(ostream& out, const Unifier& u) {
    return out << u.to_string();
}
