#!/usr/bin/env python3
# pylint: disable=missing-class-docstring,disable=no-self-use
from __future__ import annotations

# When `__package__` is empty/None (as it is when this file is executed directly as a script), fall back to the
# package this test module belongs to. Presumably this keeps package-relative resolution working in that mode — confirm.
__package__ = __package__ or "tests.knowledge_plugins.functions"  # pylint:disable=redefined-builtin

import os
import unittest

import angr

from tests.common import bin_location


# Directory holding the test binaries (`<bin_location>/tests`); the tests below load `x86_64/fauxware` from it.
test_location = os.path.join(bin_location, "tests")


class TestFunction(unittest.TestCase):
    """Tests for ``angr.knowledge_plugins.Function`` using the x86_64 ``fauxware`` test binary."""

    @staticmethod
    def _load_fauxware():
        """Load the x86_64 ``fauxware`` test binary with ``auto_load_libs=False``.

        Every test loads a fresh project so that state set by one test (e.g. prototypes assigned by
        ``CompleteCallingConventions``) cannot leak into another test that asserts on the untouched state.
        """
        return angr.Project(os.path.join(test_location, "x86_64", "fauxware"), auto_load_libs=False)

    @staticmethod
    def _serialize_roundtrip(p, func):
        """Serialize ``func``, parse it back, and check the fields that must survive the round trip.

        :param p:       The project ``func`` belongs to (used as ``project`` and for its function manager).
        :param func:    The function to serialize.
        :return:        The function reconstructed by ``Function.parse``.
        """
        s = func.serialize()

        assert isinstance(s, bytes)
        # Sanity check that serialization produced a non-trivial payload.
        assert len(s) > 10

        f = angr.knowledge_plugins.Function.parse(s, function_manager=p.kb.functions, project=p)
        assert func.addr == f.addr
        assert func.name == f.name
        assert func.is_prototype_guessed == f.is_prototype_guessed
        assert func.prototype == f.prototype
        return f

    def test_function_serialization(self):
        """A function recovered by CFG alone round-trips, and its (absent) prototype stays absent."""
        p = self._load_fauxware()
        cfg = p.analyses.CFG()

        func_main = cfg.kb.functions["main"]
        f = self._serialize_roundtrip(p, func_main)

        # No calling-convention analysis was run, so no prototype is expected on either copy.
        assert f.prototype is None

    def test_function_serialization_with_prototype(self):
        """A function with a recovered calling convention and prototype round-trips with both intact."""
        p = self._load_fauxware()
        cfg = p.analyses.CFG()
        p.analyses.CompleteCallingConventions()

        func_main = cfg.kb.functions["main"]
        assert func_main.calling_convention is not None
        assert func_main.prototype is not None

        f = self._serialize_roundtrip(p, func_main)
        assert func_main.calling_convention == f.calling_convention
        assert f.prototype is not None

    def test_function_definition_application(self):
        """apply_definition() sets the prototype and default calling convention; re-applying it is stable."""
        p = self._load_fauxware()
        cfg = p.analyses.CFG()
        func_main: angr.knowledge_plugins.Function = cfg.kb.functions["main"]

        definition = "int main(int argc, char** argv)"
        expected_args = (
            angr.types.SimTypeInt().with_arch(p.arch),
            angr.types.SimTypePointer(angr.types.SimTypePointer(angr.types.SimTypeChar()).with_arch(p.arch)).with_arch(
                p.arch
            ),
        )
        # With no explicit calling convention, the architecture's default (Linux) one is expected.
        default_cc = angr.calling_conventions.DefaultCC[p.arch.name]["Linux"]

        # Apply the same definition twice; the function must end up in the same state after each application.
        for _ in range(2):
            func_main.apply_definition(definition)

            assert func_main.prototype is not None
            assert func_main.prototype.args == expected_args
            assert isinstance(func_main.calling_convention, default_cc)

    def test_function_instruction_addr_from_any_addr(self):
        """Any address inside an instruction of main maps back to that instruction's start address."""
        p = self._load_fauxware()
        cfg = p.analyses.CFG()

        func_main = cfg.kb.functions["main"]

        assert func_main.addr_to_instruction_addr(0x400739) == 0x400739
        assert func_main.addr_to_instruction_addr(0x40073A) == 0x400739
        assert func_main.addr_to_instruction_addr(0x40073D) == 0x400739
        assert func_main.addr_to_instruction_addr(0x400742) == 0x400742
        assert func_main.addr_to_instruction_addr(0x400743) == 0x400742

    def test_function_instruction_size(self):
        """instruction_size() reports the byte length of the instruction starting at each given address."""
        p = self._load_fauxware()
        cfg = p.analyses.CFG()

        func_main = cfg.kb.functions["main"]

        assert func_main.instruction_size(0x40071D) == 1
        assert func_main.instruction_size(0x40071E) == 3
        assert func_main.instruction_size(0x400721) == 4
        assert func_main.instruction_size(0x400725) == 3
        assert func_main.instruction_size(0x400728) == 4
        assert func_main.instruction_size(0x400739) == 5
        assert func_main.instruction_size(0x400742) == 5


# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
