/*++
Copyright (c) 2012 Microsoft Corporation

Module Name:

    dl_mk_unfold.cpp

Abstract:

    Unfold rules once, return the unfolded set of rules.

Author:

    Nikolaj Bjorner (nbjorner) 2012-10-15

Revision History:

--*/
#include "dl_mk_unfold.h"

namespace datalog {

    // Build the unfolding transformer on top of the given context.
    // The AST manager (m) and the rule manager (rm) are both fetched from ctx,
    // and the unifier (m_unify) is constructed from ctx as well.
    // NOTE(review): the base class is initialized with (100, false); presumably
    // a plugin priority and a boolean option -- confirm against
    // rule_transformer::plugin.
    mk_unfold::mk_unfold(context& ctx)
        : rule_transformer::plugin(100, false)
        , m_ctx(ctx)
        , m(ctx.get_manager())
        , rm(ctx.get_rule_manager())
        , m_unify(ctx) {
    }

    // Expand the uninterpreted tail positions of r, starting at tail_idx, against
    // the rules of src, adding the resulting rules to dst.
    //
    //   r        - rule being expanded.
    //   tail_idx - next uninterpreted tail position to process; must not exceed
    //              r.get_uninterpreted_tail_size().
    //   src      - rule set that supplies the rules for the predicate at tail_idx.
    //   dst      - rule set that receives r once no tail position is left.
    //
    // When tail_idx has reached the number of uninterpreted tail atoms, r itself
    // is added to dst. Otherwise each rule that src lists for the predicate at
    // tail_idx is tried in order; for every one where both unify_rules and apply
    // succeed, expansion recurses on the rule produced by apply. If no candidate
    // succeeds, nothing is added to dst for r.
    void mk_unfold::expand_tail(rule& r, unsigned tail_idx, rule_set const& src, rule_set& dst) {
        unsigned const num_tails = r.get_uninterpreted_tail_size();
        SASSERT(tail_idx <= num_tails);

        // Nothing left to process: r is final.
        if (tail_idx == num_tails) {
            dst.add_rule(&r);
            return;
        }

        rule_vector const& candidates = src.get_predicate_rules(r.get_decl(tail_idx));
        rule_ref resolvent(rm);
        for (unsigned k = 0; k < candidates.size(); ++k) {
            rule const& other = *candidates[k];
            // Same short-circuit order as a conjunction: apply is only attempted
            // after unify_rules has succeeded.
            if (!m_unify.unify_rules(r, tail_idx, other)) {
                continue;
            }
            if (!m_unify.apply(r, tail_idx, other, resolvent)) {
                continue;
            }
            expr_ref_vector subst_r     = m_unify.get_rule_subst(r, true);
            expr_ref_vector subst_other = m_unify.get_rule_subst(other, false);
            // NOTE(review): resolve_rule receives m_pc, which may be null;
            // presumably it records the resolution step for proof
            // reconstruction -- confirm against its definition.
            resolve_rule(m_pc, r, other, tail_idx, subst_r, subst_other, *resolvent.get());
            // Continue on the new rule, advancing the index by the number of
            // uninterpreted tail atoms of the rule just resolved against.
            expand_tail(*resolvent.get(), tail_idx + other.get_uninterpreted_tail_size(), src, dst);
        }
    }
        
    // Apply the transformation: run expand_tail on every rule of source and
    // return the rules it produced in a freshly allocated rule_set.
    //
    //   source - input rules; also used as the rule set that tails are expanded
    //            against.
    //   mc     - not referenced by this function.
    //   pc     - when non-null on entry, a replace_proof_converter is allocated,
    //            exposed to expand_tail through m_pc, and finally concatenated
    //            onto pc. When null, m_pc is 0 during expansion and pc is left
    //            untouched.
    //
    // The result is allocated with alloc(); no deallocation happens here.
    rule_set * mk_unfold::operator()(rule_set const & source, model_converter_ref& mc, proof_converter_ref& pc) {
        ref<replace_proof_converter> replace_pc;
        m_pc = 0;
        if (pc) {
            replace_pc = alloc(replace_proof_converter, m);
            m_pc = replace_pc.get();
        }

        rule_set* result = alloc(rule_set, m_ctx);
        rule_set::iterator cur  = source.begin();
        rule_set::iterator stop = source.end();
        while (cur != stop) {
            expand_tail(**cur, 0, source, *result);
            ++cur;
        }

        if (pc) {
            pc = concat(pc.get(), replace_pc.get());
        }
        return result;
    }

};

