#!/usr/bin/env python3
# pylint:disable=missing-class-docstring,no-self-use,wrong-import-order
from __future__ import annotations

# If this module is executed without a package context (e.g. run directly as a script), __package__ is empty/None;
# fall back to the dotted package path. Presumably needed so package-aware imports/tooling resolve — confirm.
__package__ = __package__ or "tests.analyses.cfg"  # pylint:disable=redefined-builtin

import os
import logging
import unittest

import archinfo
import angr
from angr.knowledge_plugins.cfg import CFGNode, CFGModel, MemoryDataSort
from angr.analyses.cfg.indirect_jump_resolvers import mips_elf_fast

from tests.common import bin_location, slow_test

# Directory holding the test binaries; per-architecture subdirectories are joined onto it by the tests below.
test_location = os.path.join(bin_location, "tests")

# Logger for this test module.
l = logging.getLogger("angr.tests.test_cfgfast")


def cstring_to_unicode_string(cstr: bytes) -> bytes:
    """
    Interleave a NUL byte after every byte of `cstr`.

    For a latin-1/ASCII byte string this yields its UTF-16LE encoding, e.g. b"AB" -> b"A\\x00B\\x00".

    :param cstr: the narrow byte string to widen
    :return: a new bytes object twice as long as `cstr`
    """
    widened = bytearray()
    for byte in cstr:
        widened.append(byte)
        widened.append(0)
    return bytes(widened)


class TestCfgfast(unittest.TestCase):
    """Tests for CFGFast: function/edge recovery, serialization, copying, indirect jumps and assorted regressions."""

    def cfg_fast_functions_check(self, arch, binary_path, func_addrs, func_features):
        """
        Generate fast CFGs on the given binary and test if all specified functions are found.

        The CFG is generated three times: with default options, with `force_segment=True`, and with both
        `force_segment=True` and `normalize=True`. Every configuration must satisfy the same expectations.

        :param str arch: the architecture, will be prepended to `binary_path`
        :param str binary_path: path to the binary under the architecture directory
        :param func_addrs: A collection (e.g. a set) of function addresses that should be recovered
        :param dict func_features: Maps function addresses to dicts of expected features. Only the "returning" key
                                   is checked; entries without it are ignored.
        :return: None
        """

        path = os.path.join(test_location, arch, binary_path)
        proj = angr.Project(path, load_options={"auto_load_libs": False})

        # default options, then segment-only, then segment-only with normalization enabled
        for cfg_kwargs in ({}, {"force_segment": True}, {"force_segment": True, "normalize": True}):
            cfg = proj.analyses.CFGFast(**cfg_kwargs)
            self._assert_functions_and_features(cfg, func_addrs, func_features)

    def _assert_functions_and_features(self, cfg, func_addrs, func_features):
        """
        Assert that `cfg` recovered every address in `func_addrs`, and that each function in `func_features` whose
        feature dict has a "returning" key has exactly that `returning` value (compared with `is`).
        """
        assert set(cfg.kb.functions.keys()).issuperset(func_addrs)

        for func_addr, feature_dict in func_features.items():
            if "returning" in feature_dict:
                assert cfg.kb.functions.function(addr=func_addr).returning is feature_dict["returning"]

    def cfg_fast_edges_check(self, arch, binary_path, edges):
        """
        Generate a fast CFG on the given binary, and test if all edges are found.

        :param str arch: the architecture, will be prepended to `binary_path`
        :param str binary_path: path to the binary under the architecture directory
        :param edges: an iterable of (source address, destination address) pairs
        :return: None
        """

        path = os.path.join(test_location, arch, binary_path)
        proj = angr.Project(path, load_options={"auto_load_libs": False})

        cfg = proj.analyses.CFGFast()

        for src, dst in edges:
            src_node = cfg.model.get_any_node(src)
            dst_node = cfg.model.get_any_node(dst)
            assert src_node is not None, f"CFG node 0x{src:x} is not found."
            assert dst_node is not None, f"CFG node 0x{dst:x} is not found."
            assert dst_node in src_node.successors, f"CFG edge {src_node}-{dst_node} is not found."

    def test_cfg_0(self):
        functions = {
            0x400410,
            0x400420,
            0x400430,
            0x400440,
            0x400470,
            0x40052C,
            0x40053C,
        }

        function_features = {}

        self.cfg_fast_functions_check("x86_64", "cfg_0", functions, function_features)

    def test_cfg_0_pe(self):
        functions = {
            # 0x40150a,  # currently angr identifies 0x40150e due to the way _func_addrs_from_prologues() is
            # implemented. this issue can be resolved with a properly implemented approach like Byte-Weight
            0x4014F0,
        }

        function_features = {}

        self.cfg_fast_functions_check("x86_64", "cfg_0_pe", functions, function_features)

    def test_arm_function_merge(self):
        # function 0x7bb88 is created due to a data hint in another block. this function should be merged with the
        # previous function 0x7ba84

        path = os.path.join(test_location, "armel", "tenda-httpd")
        proj = angr.Project(path, auto_load_libs=False)

        cfg = proj.analyses.CFGFast()

        node_7bb88 = cfg.model.get_any_node(0x7BB88)
        assert node_7bb88 is not None
        assert node_7bb88.function_address == 0x7BA84

    @slow_test
    def test_busybox(self):
        edges = {
            (0x4091EC, 0x408DE0),
            # call to putenv. address of putenv may change in the future
            (0x449ACC, 0x5003B8),
            # call to free. address of free may change in the future
            (0x467CFC, 0x500014),
        }

        self.cfg_fast_edges_check("mipsel", "busybox", edges)

    @slow_test
    @unittest.skipUnless(
        os.path.isfile("C:\\Windows\\System32\\ntoskrnl.exe"),
        "ntoskrnl.exe does not exist on this system.",
    )
    def test_ntoskrnl(self):
        # we cannot distribute ntoskrnl.exe. as a result, this test case is manual
        path = "C:\\Windows\\System32\\ntoskrnl.exe"
        proj = angr.Project(path, auto_load_libs=False)
        _ = proj.analyses.CFG(data_references=True, normalize=True, show_progressbar=True)

        # nothing should prevent us from finishing creating the CFG

    def test_fauxware_function_feauters_x86_64(self):
        functions = {
            0x4004E0,
            0x400510,
            0x400520,
            0x400530,
            0x400540,
            0x400550,
            0x400560,
            0x400570,  # .plt._exit
            0x400580,  # _start
            0x4005AC,
            0x4005D0,
            0x400640,
            0x400664,
            0x4006ED,
            0x4006FD,
            0x40071D,  # main
            0x4007E0,
            0x400870,
            0x400880,
            0x4008B8,
        }

        function_features = {
            0x400570: {"returning": False},  # plt.exit
            0x4006FD: {"returning": False},  # rejected
        }

        return_edges = {
            (0x4006FB, 0x4007C7),
        }  # return from accepted to main

        self.cfg_fast_functions_check("x86_64", "fauxware", functions, function_features)
        self.cfg_fast_edges_check("x86_64", "fauxware", return_edges)

    def test_fauxware_function_features_mips(self):
        functions = {
            0x400534,  # _init
            0x400574,
            0x400598,
            0x4005D0,  # _ftext
            0x4005DC,
            0x400630,  # __do_global_dtors_aux
            0x4006D4,  # frame_dummy
            0x400708,
            0x400710,  # authenticate
            0x400814,  # accepted
            0x400868,  # rejected
            0x4008C0,  # main
            0x400A34,
            0x400A48,  # __libc_csu_init
            0x400AF8,
            0x400B00,  # __do_global_ctors_aux
            0x400B58,
            ### plt entries
            0x400B60,  # strcmp
            0x400B70,  # read
            0x400B80,  # printf
            0x400B90,  # puts
            0x400BA0,  # exit
            0x400BB0,  # open
            0x400BC0,  # __libc_start_main
        }

        function_features = {
            0x400868: {  # rejected
                "returning": False,
            }
        }

        return_edges = {
            (0x40084C, 0x400A04),
        }  # returning edge from accepted to main

        self.cfg_fast_functions_check("mips", "fauxware", functions, function_features)
        self.cfg_fast_edges_check("mips", "fauxware", return_edges)

    def test_mips_elf_fast_indirect_jump_resolver(self):
        bin_path = os.path.join(test_location, "mips", "fauxware")
        proj = angr.Project(bin_path, auto_load_libs=False)
        # enable profiling for MipsElfFast
        # FIXME: The result might be different if other test cases that run in parallel mess with the profiling setting
        mips_elf_fast.enable_profiling()
        try:
            _ = proj.analyses.CFG()
        finally:
            # always switch the module-level profiling flag back off, even if CFG generation raises, so that
            # subsequent tests are not affected
            mips_elf_fast.disable_profiling()
        assert mips_elf_fast.HITS_CASE_1 >= 10

    def test_cfg_loop_unrolling(self):
        edges = {
            (0x400658, 0x400636),
            (0x400658, 0x400661),
            (0x400651, 0x400636),
            (0x400651, 0x400661),
        }

        self.cfg_fast_edges_check("x86_64", "cfg_loop_unrolling", edges)

    def test_cfg_switches_x86_64(self):
        edges = {
            # jump table 0 in func_0
            (0x40053A, 0x400547),
            (0x40053A, 0x400552),
            (0x40053A, 0x40055D),
            (0x40053A, 0x400568),
            (0x40053A, 0x400573),
            (0x40053A, 0x400580),
            (0x40053A, 0x40058D),
            # jump table 0 in func_1
            (0x4005BC, 0x4005C9),
            (0x4005BC, 0x4005D8),
            (0x4005BC, 0x4005E7),
            (0x4005BC, 0x4005F6),
            (0x4005BC, 0x400605),
            (0x4005BC, 0x400614),
            (0x4005BC, 0x400623),
            (0x4005BC, 0x400632),
            (0x4005BC, 0x40063E),
            (0x4005BC, 0x40064A),
            (0x4005BC, 0x4006B0),
            # jump table 1 in func_1
            (0x40065A, 0x400667),
            (0x40065A, 0x400673),
            (0x40065A, 0x40067F),
            (0x40065A, 0x40068B),
            (0x40065A, 0x400697),
            (0x40065A, 0x4006A3),
            # jump table 0 in main
            (0x4006E1, 0x4006EE),
            (0x4006E1, 0x4006FA),
            (0x4006E1, 0x40070B),
            (0x4006E1, 0x40071C),
            (0x4006E1, 0x40072D),
            (0x4006E1, 0x40073E),
            (0x4006E1, 0x40074F),
            (0x4006E1, 0x40075B),
        }

        self.cfg_fast_edges_check("x86_64", "cfg_switches", edges)

    def test_cfg_switches_armel(self):
        edges = {
            # jump table 0 in func_0
            (0x10434, 0x10488),
            (0x10434, 0x104E8),
            (0x10434, 0x10498),
            (0x10434, 0x104A8),
            (0x10434, 0x104B8),
            (0x10434, 0x104C8),
            (0x10434, 0x104D8),
            (0x10454, 0x104E8),  # default case
            # jump table 0 in func_1
            (0x10524, 0x105CC),
            (0x10524, 0x106B4),
            (0x10524, 0x105D8),
            (0x10524, 0x105E4),
            (0x10524, 0x105F0),
            (0x10524, 0x105FC),
            (0x10524, 0x10608),
            (0x10524, 0x10614),
            (0x10524, 0x10620),
            (0x10524, 0x1062C),
            (0x10524, 0x10638),
            (0x10534, 0x106B4),  # default case
            # jump table 1 in func_1
            (0x10650, 0x106A4),  # default case
            (0x10640, 0x10668),
            (0x10640, 0x10674),
            (0x10640, 0x10680),
            (0x10640, 0x1068C),
            (0x10640, 0x10698),
            # jump table 0 in main
            (0x10734, 0x107FC),
            (0x10734, 0x10808),
            (0x10734, 0x10818),
            (0x10734, 0x10828),
            (0x10734, 0x10838),
            (0x10734, 0x10848),
            (0x10734, 0x10858),
            (0x10734, 0x10864),
            (0x10744, 0x10864),  # default case
        }

        self.cfg_fast_edges_check("armel", "cfg_switches", edges)

    def test_cfg_switches_s390x(self):
        edges = {
            # jump table 0 in func_0
            (0x4007D4, 0x4007EA),  # case 1
            (0x4007D4, 0x4007F4),  # case 3
            (0x4007D4, 0x4007FE),  # case 5
            (0x4007D4, 0x400808),  # case 7
            (0x4007D4, 0x400812),  # case 9
            (0x4007D4, 0x40081C),  # case 12
            (0x4007C0, 0x4007CA),  # default case
            # jump table 0 in func_1
            (0x400872, 0x4008AE),  # case 2
            (0x400872, 0x4008BE),  # case 10
            (0x400872, 0x4008CE),  # case 12
            (0x400872, 0x4008DE),  # case 14
            (0x400872, 0x4008EE),  # case 15
            (0x400872, 0x4008FE),  # case 16
            (0x400872, 0x40090E),  # case 22
            (0x400872, 0x40091E),  # case 24
            (0x400872, 0x40092E),  # case 28
            (0x400872, 0x400888),  # case 38
            (0x400848, 0x400854),  # default case (1)
            (0x400872, 0x400854),  # default case (2)
            # jump table 1 in func_1
            (0x40093E, 0x400984),  # case 1
            (0x40093E, 0x400974),  # case 2
            (0x40093E, 0x400964),  # case 3
            (0x40093E, 0x400954),  # case 4
            (0x40093E, 0x400994),  # case 5
            (0x400898, 0x40089E),  # default case (1)
            # jump table 0 in main
            # case 1, 3, 5, 7, 9: optimized out
            (0x400638, 0x40064E),  # case 2
            (0x400638, 0x400692),  # case 4
            (0x400638, 0x4006A4),  # case 6
            (0x400638, 0x40066E),  # case 8
            (0x400638, 0x400680),  # case 10
            # case 45: optimized out
            (0x40062C, 0x40065C),  # default case
        }

        self.cfg_fast_edges_check("s390x", "cfg_switches", edges)

    def test_cfg_about_time(self):
        # This is to test the correctness of the PLT stub removal in CFGBase
        proj = angr.Project(os.path.join(test_location, "x86_64", "about_time"), auto_load_libs=False)
        cfg = proj.analyses.CFG()

        # a PLT stub that should be removed
        assert 0x401026 not in cfg.kb.functions
        # a PLT stub that should be removed
        assert 0x4010A6 not in cfg.kb.functions
        # a PLT stub that should be removed
        assert 0x40115E not in cfg.kb.functions
        # the start function that should not be removed
        assert proj.entry in cfg.kb.functions

    def test_cfg_function_stubs_with_single_jumpouts(self):
        proj = angr.Project(os.path.join(test_location, "x86_64", "printenv-rust-stripped"), auto_load_libs=False)
        cfg = proj.analyses.CFG()

        # the function at 0x4864f0 is a function stub that jumps directly to function at 0x486500. ensure that CFGFast
        # discovers both functions correctly instead of merging them together
        assert cfg.kb.functions.contains_addr(0x4864F0)
        assert cfg.kb.functions.contains_addr(0x486500)
        func_jump_stub = cfg.kb.functions.get_by_addr(0x4864F0)
        assert len(func_jump_stub.block_addrs_set) == 1
        assert len(func_jump_stub.jumpout_sites) == 1

    #
    # Serialization
    #

    def test_serialization_cfgnode(self):
        """A CFGNode should survive a serialize()/parse() round trip with addr, size and block_id intact."""
        path = os.path.join(test_location, "x86_64", "fauxware")
        proj = angr.Project(path, auto_load_libs=False)

        cfg = proj.analyses.CFGFast()
        # the first node
        node = cfg.model.get_any_node(proj.entry)
        assert node is not None

        b = node.serialize()
        assert len(b) > 0
        new_node = CFGNode.parse(b)
        assert new_node.addr == node.addr
        assert new_node.size == node.size
        assert new_node.block_id == node.block_id

    def test_serialization_cfgfast(self):
        """A serialized CFGModel parsed into another project should have the same node and edge counts."""
        path = os.path.join(test_location, "x86_64", "fauxware")
        proj1 = angr.Project(path, auto_load_libs=False)
        proj2 = angr.Project(path, auto_load_libs=False)

        cfg = proj1.analyses.CFGFast()
        # parse the entire graph
        b = cfg.model.serialize()
        assert len(b) > 0

        # simulate importing a cfg from another tool
        cfg_model = CFGModel.parse(b, cfg_manager=proj2.kb.cfgs)

        assert len(cfg_model.graph.nodes) == len(cfg.graph.nodes)
        assert len(cfg_model.graph.edges) == len(cfg.graph.edges)

        n1 = cfg.model.get_any_node(proj1.entry)
        n2 = cfg_model.get_any_node(proj1.entry)
        assert n1 == n2

    #
    # CFG instance copy
    #

    def test_cfg_copy(self):
        """cfg.copy() should copy every attribute by value, with fresh model, graph and segment-list objects."""
        path = os.path.join(test_location, "cgc", "CADET_00002")
        proj = angr.Project(path, auto_load_libs=False)

        cfg = proj.analyses.CFGFast()
        cfg_copy = cfg.copy()
        for attribute in cfg_copy.__dict__:
            if attribute in ["_graph", "_seg_list", "_model"]:
                continue
            assert getattr(cfg, attribute) == getattr(cfg_copy, attribute)

        assert id(cfg.model) != id(cfg_copy.model)
        assert id(cfg.model.graph) != id(cfg_copy.model.graph)
        assert id(cfg._seg_list) != id(cfg_copy._seg_list)

    #
    # Alignment bytes
    #

    def test_cfg_0_pe_msvc_debug_nocc(self):
        filename = os.path.join("windows", "msvc_cfg_0_debug.exe")
        proj = angr.Project(os.path.join(test_location, "x86_64", filename), auto_load_libs=False)
        cfg = proj.analyses.CFGFast()

        # make sure 0x140016583 is marked as alignment bytes
        sort = cfg._seg_list.occupied_by_sort(0x140016583)
        assert sort == "alignment"

        # NOTE(review): this address (0x140015683) has two digits transposed relative to the one checked above
        # (0x140016583); confirm which address is intended.
        assert 0x140015683 not in cfg.kb.functions

    #
    # Indirect jump resolvers
    #

    # For test cases for jump table resolver, please refer to test_jumptables.py

    def test_resolve_x86_elf_pic_plt(self):
        path = os.path.join(test_location, "i386", "fauxware_pie")
        proj = angr.Project(path, load_options={"auto_load_libs": False})

        cfg = proj.analyses.CFGFast()

        # puts
        puts_node = cfg.model.get_any_node(0x4005B0)
        assert puts_node is not None

        # there should be only one successor, which jumps to SimProcedure puts
        assert len(puts_node.successors) == 1
        puts_successor = puts_node.successors[0]
        assert puts_successor.addr == proj.loader.find_symbol("puts").rebased_addr

        # the SimProcedure puts should have more than one successor, which are all return targets
        assert len(puts_successor.successors) == 3
        simputs_successor = puts_successor.successors
        return_targets = {a.addr for a in simputs_successor}
        assert return_targets == {0x400800, 0x40087E, 0x4008B6}

    #
    # Function names
    #

    def test_function_names_for_unloaded_libraries(self):
        """Both the PLT stubs (prefixed "plt_" here) and the unloaded library functions should keep their names."""
        path = os.path.join(test_location, "i386", "fauxware_pie")
        proj = angr.Project(path, load_options={"auto_load_libs": False})

        cfg = proj.analyses.CFGFast()

        function_names = [f.name if not f.is_plt else "plt_" + f.name for f in cfg.functions.values()]

        assert "plt_puts" in function_names
        assert "plt_read" in function_names
        assert "plt___stack_chk_fail" in function_names
        assert "plt_exit" in function_names
        assert "puts" in function_names
        assert "read" in function_names
        assert "__stack_chk_fail" in function_names
        assert "exit" in function_names

    #
    # Basic blocks
    #

    def test_block_instruction_addresses_armhf(self):
        path = os.path.join(test_location, "armhf", "fauxware")
        proj = angr.Project(path, auto_load_libs=False)

        cfg = proj.analyses.CFGFast()

        main_func = cfg.kb.functions["main"]

        # all instruction addresses of the block must be odd
        block = next(b for b in main_func.blocks if b.addr == main_func.addr)

        assert len(block.instruction_addrs) == 12
        for instr_addr in block.instruction_addrs:
            assert instr_addr % 2 == 1

        main_node = cfg.model.get_any_node(main_func.addr)
        assert main_node is not None
        assert len(main_node.instruction_addrs) == 12
        for instr_addr in main_node.instruction_addrs:
            assert instr_addr % 2 == 1

    #
    # Tail-call optimization detection
    #

    def test_tail_call_optimization_detection_armel(self):
        # GitHub issue #1286

        path = os.path.join(test_location, "armel", "Nucleo_read_hyperterminal-stripped.elf")
        proj = angr.Project(path, auto_load_libs=False)

        cfg = proj.analyses.CFGFast(
            resolve_indirect_jumps=True,
            force_complete_scan=False,
            normalize=True,
            symbols=False,
            detect_tail_calls=True,
            data_references=True,
        )

        all_func_addrs = set(cfg.functions.keys())
        assert 0x80010B5 not in all_func_addrs
        assert 0x8003EF9 not in all_func_addrs
        assert 0x8008419 not in all_func_addrs

        # Functions that are jumped to from tail-calls
        tail_call_funcs = [
            0x8002BC1,
            0x80046C1,
            0x8000281,
            0x8001BDB,
            0x8002839,
            0x80037AD,
            0x8002C09,
            0x8004165,
            0x8004BE1,
            0x8002EB1,
        ]
        for member in tail_call_funcs:
            assert member in all_func_addrs

        # also test for tailcall return addresses

        # mapping of return blocks to return addrs that are the actual callers of certain tail-calls endpoints
        tail_call_return_addrs = {
            0x8002BD9: [0x800275F],  # 0x8002bc1
            0x80046D7: [0x800275F],  # 0x80046c1
            0x80046ED: [0x800275F],  # 0x80046c1
            0x8001BE7: [0x800068D, 0x8000695],  # 0x8001bdb ??
            0x800284D: [0x800028B, 0x80006E1, 0x80006E7],  # 0x8002839
            0x80037F5: [0x800270B, 0x8002733, 0x8002759, 0x800098F, 0x8000997],  # 0x80037ad
            0x80037EF: [0x800270B, 0x8002733, 0x8002759, 0x800098F, 0x8000997],  # 0x80037ad
            0x8002CC9: [
                0x8002D3B,
                0x8002B99,
                0x8002E9F,
                0x80041AD,
                0x8004C87,
                0x8004D35,
                0x8002EFB,
                0x8002BE9,
                0x80046EB,
                0x800464F,
                0x8002A09,
                0x800325F,
                0x80047C1,
            ],  # 0x8002c09
            0x8004183: [0x8002713],  # 0x8004165
            0x8004C31: [0x8002713],  # 0x8004be1
            0x8004C69: [0x8002713],  # 0x8004be1
            0x8002EF1: [0x800273B],  # 0x8002eb1
        }

        # check all expected return addrs are present
        for returning_block_addr, expected_return_addrs in tail_call_return_addrs.items():
            returning_block = cfg.model.get_any_node(returning_block_addr)
            # fail with a readable message instead of an obscure error inside get_successors()
            assert returning_block is not None, f"returning block {returning_block_addr:x} is not found"
            return_block_addrs = [rb.addr for rb in cfg.model.get_successors(returning_block)]
            msg = (
                f"{returning_block_addr:x}: unequal sizes of expected_addrs "
                f"[{len(expected_return_addrs)}] and return_block_addrs "
                f"[{len(return_block_addrs)}]"
            )
            assert len(return_block_addrs) == len(expected_return_addrs), msg
            for expected_addr in expected_return_addrs:
                msg = f"expected retaddr {expected_addr:x} not found for returning_block {returning_block_addr:x}"
                assert expected_addr in return_block_addrs, msg

    #
    # Incorrect function-leading blocks merging
    #

    def test_function_leading_blocks_merging(self):
        # GitHub issue #1312

        path = os.path.join(test_location, "armel", "Nucleo_read_hyperterminal-stripped.elf")
        proj = angr.Project(path, arch=archinfo.ArchARMCortexM(), auto_load_libs=False)

        cfg = proj.analyses.CFGFast(
            resolve_indirect_jumps=True,
            force_complete_scan=True,
            normalize=True,
            symbols=False,
            detect_tail_calls=True,
        )

        assert 0x8000799 in cfg.kb.functions
        assert 0x800079B not in cfg.kb.functions
        assert 0x800079B not in cfg.kb.functions[0x8000799].block_addrs_set
        assert 0x8000799 in cfg.kb.functions[0x8000799].block_addrs_set
        assert next(b for b in cfg.kb.functions[0x8000799].blocks if b.addr == 0x8000799).size == 6

    #
    # Blanket
    #

    def test_blanket_fauxware(self):
        path = os.path.join(test_location, "x86_64", "fauxware")
        proj = angr.Project(path, auto_load_libs=False)

        cfg = proj.analyses.CFGFast()

        cfb = proj.analyses.CFBlanket(kb=cfg.kb)

        # it should raise a key error when calling floor_addr on address 0 because nothing is mapped there
        with self.assertRaises(KeyError):
            cfb.floor_addr(0)
        # an instruction (or a block) starts at 0x400580
        assert cfb.floor_addr(0x400581) == 0x400580
        # a block ends at 0x4005a9 (exclusive)
        assert cfb.ceiling_addr(0x400581) == 0x4005A9

    #
    # CFG with patches
    #

    def test_unresolvable_targets(self):
        path = os.path.join(test_location, "cgc", "CADET_00002")
        proj = angr.Project(path, auto_load_libs=False)

        proj.analyses.CFGFast(normalize=True)
        func = proj.kb.functions[0x080489E0]

        true_endpoint_addrs = {0x8048BBC, 0x8048AF5, 0x8048B5C, 0x8048A41, 0x8048AA8}
        endpoint_addrs = {node.addr for node in func.endpoints}
        # direct set equality: same condition as an empty symmetric difference, but with a readable diff on failure
        assert endpoint_addrs == true_endpoint_addrs

    def test_indirect_jump_to_outside(self):
        # an indirect jump might be jumping to outside as well
        path = os.path.join(test_location, "mipsel", "libndpi.so.4.0.0")
        proj = angr.Project(path, auto_load_libs=False)

        cfg = proj.analyses.CFGFast()

        assert len(list(cfg.functions[0x404EE4].blocks)) == 3
        assert {ep.addr for ep in cfg.functions[0x404EE4].endpoints} == {
            0x404F00,
            0x404F08,
        }

    def test_plt_stub_has_one_jumpout_site(self):
        # each PLT stub must have exactly one jumpout site
        path = os.path.join(test_location, "x86_64", "1after909")
        proj = angr.Project(path, auto_load_libs=False)
        cfg = proj.analyses.CFGFast()

        for func in cfg.kb.functions.values():
            if func.is_plt:
                assert len(func.jumpout_sites) == 1

    def test_generate_special_info(self):
        """Some function should carry `info`, and main's `gp` value on mipsel fauxware should be 0x418CA0."""
        path = os.path.join(test_location, "mipsel", "fauxware")
        proj = angr.Project(path, auto_load_libs=False)

        cfg = proj.analyses.CFGFast()

        assert any(func.info for func in cfg.functions.values())
        assert cfg.functions["main"].info["gp"] == 0x418CA0

    def test_load_from_shellcode(self):
        """A CFG built from a tiny x86 shellcode loop should contain exactly two nodes."""
        proj = angr.load_shellcode("loop: dec ecx; jnz loop; ret", "x86")
        cfg = proj.analyses.CFGFast()

        assert len(cfg.model.nodes()) == 2

    def test_starting_point_ordering(self):
        # project entry should always be first
        # so edge/path to unlabeled main function from _start
        # is correctly generated

        path = os.path.join(test_location, "armel", "start_ordering")
        proj = angr.Project(path, auto_load_libs=False)
        cfg = proj.analyses.CFGFast()

        # if ordering is incorrect, edge to function 0x103D4 will not exist
        n = cfg.model.get_any_node(proj.entry)
        assert n is not None
        assert len(n.successors) > 0
        assert len(n.successors[0].successors) > 0
        assert len(n.successors[0].successors[0].successors) == 3

        # now checking if path to the "real main" exists
        assert len(n.successors[0].successors[0].successors[1].successors) > 0
        n = n.successors[0].successors[0].successors[1].successors[0]

        assert len(n.successors) > 0
        assert len(n.successors[0].successors) > 0
        assert len(n.successors[0].successors[0].successors) > 0
        assert n.successors[0].successors[0].successors[0].addr == 0x103D4

    def test_error_returning(self):
        # error() is a great function: its returning depends on the value of the first argument...
        path = os.path.join(test_location, "x86_64", "mv_-O2")
        proj = angr.Project(path, auto_load_libs=False)
        cfg = proj.analyses.CFGFast()

        error_not_returning = [
            0x4030D4,
            0x403100,
            0x40313C,
            0x4031F5,
            0x40348A,
        ]

        error_returning = [0x403179, 0x4031A2, 0x403981, 0x403E30, 0x40403B]

        for error_site in error_not_returning:
            node = cfg.model.get_any_node(error_site)
            assert node is not None, f"CFG node 0x{error_site:x} is not found."
            assert len(list(cfg.model.get_successors(node, excluding_fakeret=False))) == 1  # only the call successor

        for error_site in error_returning:
            node = cfg.model.get_any_node(error_site)
            assert node is not None, f"CFG node 0x{error_site:x} is not found."
            assert len(list(cfg.model.get_successors(node, excluding_fakeret=False))) == 2  # both a call and a fakeret

    def test_kepler_server_armhf(self):
        """With indirect_calls_always_return=False, the listed functions should get the expected `returning` flags."""
        binary_path = os.path.join(test_location, "armhf", "kepler_server")
        proj = angr.Project(binary_path, auto_load_libs=False)
        cfg = proj.analyses.CFG(
            normalize=True,
            indirect_calls_always_return=False,
        )

        func_main = cfg.kb.functions[0x10329]
        assert func_main.returning is False

        func_0 = cfg.kb.functions[0x15EE9]
        assert func_0.returning is False
        assert len(func_0.block_addrs_set) == 1

        func_1 = cfg.kb.functions[0x15D2D]
        assert func_1.returning is False

        func_2 = cfg.kb.functions[0x228C5]
        assert func_2.returning is False

        func_3 = cfg.kb.functions[0x12631]
        assert func_3.returning is True

    def test_func_in_added_segment_by_patcherex_arm(self):
        path = os.path.join(test_location, "armel", "patcherex", "replace_function_patch_with_function_reference")
        proj = angr.Project(path, auto_load_libs=False)
        cfg = proj.analyses.CFGFast(
            normalize=True,
            function_starts={0xA00081},
            regions=[
                (4195232, 4195244),
                (4195244, 4195324),
                (4195324, 4196016),
                (4196016, 4196024),
                (10485888, 10485950),
            ],
        )

        # Check whether the target function is in the functions list
        assert 0xA00081 in cfg.kb.functions
        # Check the number of basic blocks
        assert len(list(cfg.functions[0xA00081].blocks)) == 8

    def test_func_in_added_segment_by_patcherex_x64(self):
        path = os.path.join(test_location, "x86_64", "patchrex", "replace_function_patch_with_function_reference")
        proj = angr.Project(path, auto_load_libs=False)
        cfg = proj.analyses.CFGFast(
            normalize=True,
            function_starts={0xA0013D},
            regions=[
                (4195568, 4195591),
                (4195600, 4195632),
                (4195632, 4195640),
                (4195648, 4196418),
                (4196420, 4196429),
                (10486064, 10486213),
            ],
        )

        # Check whether the target function is in the functions list
        assert 0xA0013D in cfg.kb.functions
        # Check the number of basic blocks
        assert len(list(cfg.functions[0xA0013D].blocks)) == 7

    def test_indirect_calls_always_return_overly_aggressive(self):
        """The node at 0x404DB4 in ls should be attributed to function 0x40F770."""
        path = os.path.join(test_location, "x86_64", "ls_ubuntu_2004")
        proj = angr.Project(path, auto_load_libs=False)
        cfg = proj.analyses.CFGFast(normalize=True)
        node = cfg.model.get_any_node(0x404DB4)
        assert node is not None
        assert node.function_address == 0x40F770

    def test_removing_lock_edges(self):
        """The node at 0x1400061C2 should have exactly the successors 0x1400060DC and 0x1400061D4."""
        path = os.path.join(
            test_location, "x86_64", "windows", "6f289eb8c8cd826525d79b195b1cf187df509d56120427b10ea3fb1b4db1b7b5.sys"
        )
        proj = angr.Project(path, auto_load_libs=False)
        cfg = proj.analyses.CFGFast(normalize=True)
        node = cfg.model.get_any_node(0x1400061C2)
        assert node is not None
        assert {n.addr for n in cfg.model.graph.successors(node)} == {0x1400060DC, 0x1400061D4}

    def test_security_init_cookie_identification(self):
        path = os.path.join(test_location, "x86_64", "windows", "3ware.sys")
        proj = angr.Project(path, auto_load_libs=False)
        cfg = proj.analyses.CFGFast()
        assert cfg.kb.functions[0x1C001A018].name == "_security_init_cookie"
        assert cfg.kb.functions[0x1C0010100].name == "_security_check_cookie"

    def test_security_init_cookie_identification_a(self):
        path = os.path.join(
            test_location, "x86_64", "windows", "1817a5bf9c01035bcf8a975c9f1d94b0ce7f6a200339485d8f93859f8f6d730c.exe"
        )
        proj = angr.Project(path, auto_load_libs=False)
        cfg = proj.analyses.CFGFast()
        assert cfg.kb.functions[0x21514B5600].name == "_security_init_cookie"

    def test_security_check_cookie_identification_unknown_cookie_location(self):
        path = os.path.join(
            test_location, "x86_64", "windows", "03fb29dab8ab848f15852a37a1c04aa65289c0160d9200dceff64d890b3290dd"
        )
        proj = angr.Project(path, auto_load_libs=False)
        cfg = proj.analyses.CFGFast()
        assert cfg.kb.functions[0x14710].name == "_security_check_cookie"
        assert cfg.kb.labels[0x17108] == "_security_cookie"

    def test_pe_unmapped_section_data(self):
        """With force_smart_scan=False, no block of function 0x42CDD0 may start at or beyond 0x42CE00."""
        path = os.path.join(
            test_location, "i386", "windows", "0b6e56e2325f8e34fc07669414f6b6fdd45b0de37937947c77c7b81c1fed4329"
        )
        proj = angr.Project(path, auto_load_libs=False)
        cfg = proj.analyses.CFGFast(force_smart_scan=False)
        for block in cfg.kb.functions[0x42CDD0].blocks:
            assert block.addr < 0x42CE00

    def test_windows_x86_driver_entry_hotpatch_points(self):
        # a hot-patch instruction at the beginning of a function of a Windows x86 driver should be considered as part
        # of the function instead of creating more functions.
        path = os.path.join(test_location, "x86", "windows", "CorsairLLAccess32.sys")
        proj = angr.Project(path, auto_load_libs=False)
        cfg = proj.analyses.CFGFast(normalize=True)
        # make sure it is merged properly
        func = cfg.kb.functions["_start"]
        assert len(func.block_addrs_set) == 2
        assert len(func.endpoints) == 1
        assert func.endpoints[0].addr == 0x40400A


class TestCfgfastDataReferences(unittest.TestCase):
    """
    Tests for the memory data recovered by CFGFast: plain and UTF-16 strings, pointer arrays, cross-references to
    recovered strings, function hints derived from data references, and syscall resolution.
    """

    @staticmethod
    def _assert_memory_data(memory_data, addr, sort, content):
        """
        Assert that a memory-data item with the given sort and content was recovered at ``addr``.

        :param memory_data:   Mapping from address to recovered memory data (e.g. ``cfg.memory_data``).
        :param int addr:      Address at which the item is expected to start.
        :param sort:          Expected ``MemoryDataSort`` of the item.
        :param bytes content: Expected content of the item.
        :return: The matching memory-data item, so callers can check further attributes (e.g. ``size``).
        """
        assert addr in memory_data
        item = memory_data[addr]
        assert item.sort == sort
        assert item.content == content
        return item

    def test_data_references_x86_64(self):
        path = os.path.join(test_location, "x86_64", "fauxware")
        proj = angr.Project(path, auto_load_libs=False)

        cfg = proj.analyses.CFGFast(data_references=True)

        memory_data = cfg.memory_data
        # NOTE(review): ``>= 0`` always holds, so this only exercises the CodeReference filter and does NOT enforce
        # "no code reference". Confirm against the binary before tightening it to ``== 0``.
        code_ref_count = sum(1 for d in memory_data.values() if d.sort == MemoryDataSort.CodeReference)
        assert code_ref_count >= 0, "There should be no code reference."

        # There are more than 2 (i.e. at least 3) pointer arrays
        ptr_array_count = sum(1 for d in memory_data.values() if d.sort == MemoryDataSort.PointerArray)
        assert ptr_array_count > 2, "Missing some pointer arrays."

        self._assert_memory_data(memory_data, 0x4008D0, MemoryDataSort.String, b"SOSNEAKY")

    def test_data_references_mipsel(self):
        path = os.path.join(test_location, "mipsel", "fauxware")
        proj = angr.Project(path, auto_load_libs=False)

        cfg = proj.analyses.CFGFast(data_references=True)

        memory_data = cfg.memory_data
        # NOTE(review): ``>= 0`` always holds, so this does NOT enforce "no code reference"; it only exercises the
        # CodeReference filter. Confirm against the binary before tightening it to ``== 0``.
        code_ref_count = sum(1 for d in memory_data.values() if d.sort == MemoryDataSort.CodeReference)
        assert code_ref_count >= 0, "There should be no code reference."

        # There is at least 1 pointer array
        ptr_array_count = sum(1 for d in memory_data.values() if d.sort == MemoryDataSort.PointerArray)
        assert ptr_array_count >= 1, "Missing some pointer arrays."

        # string literals expected at fixed addresses in this binary
        expected_strings = {
            0x400C00: b"SOSNEAKY",
            0x400C0C: b"Welcome to the admin console, trusted user!",
            0x400C38: b"Go away!",
            0x400C44: b"Username: ",
            0x400C50: b"Password: ",
        }
        for addr, content in expected_strings.items():
            self._assert_memory_data(memory_data, addr, MemoryDataSort.String, content)

    def test_data_references_mips64(self):
        path = os.path.join(test_location, "mips64", "true")
        proj = angr.Project(path, auto_load_libs=False)

        cfg = proj.analyses.CFGFast(data_references=True, cross_references=True)
        self._assert_memory_data(cfg.memory_data, 0x120007DD8, MemoryDataSort.String, b"coreutils")

        # the string must be referenced from exactly these two instructions
        refs = list(proj.kb.xrefs.get_xrefs_by_dst(0x120007DD8))
        assert len(refs) == 2
        assert {x.ins_addr for x in refs} == {0x1200020E8, 0x120002108}

    def test_data_references_i386_gcc_pie(self):
        path = os.path.join(test_location, "i386", "nl")
        proj = angr.Project(path, auto_load_libs=False)

        cfg = proj.analyses.CFGFast(data_references=True, cross_references=True)
        self._assert_memory_data(cfg.memory_data, 0x405BB0, MemoryDataSort.String, b"/usr/local/share/locale")

        # the string must be referenced from exactly one instruction
        refs = list(proj.kb.xrefs.get_xrefs_by_dst(0x405BB0))
        assert len(refs) == 1
        assert {x.ins_addr for x in refs} == {0x4011DD}

    def test_data_references_wide_string(self):
        path = os.path.join(test_location, "x86_64", "windows", "fauxware-wide.exe")
        proj = angr.Project(path, auto_load_libs=False)

        cfg = proj.analyses.CFGFast(data_references=True)
        recovered_strings = {d.content for d in cfg.memory_data.values() if d.sort == MemoryDataSort.UnicodeString}

        for testme in ("SOSNEAKY", "Welcome to the admin console, trusted user!\n", "Go away!\n", "Username: \n"):
            assert testme.encode("utf-16-le") in recovered_strings

    def test_data_references_lea_string_addr(self):
        path = os.path.join(test_location, "x86_64", "windows", "3ware.sys")
        proj = angr.Project(path, auto_load_libs=False)

        cfg = proj.analyses.CFGFast(data_references=True)
        self._assert_memory_data(cfg.memory_data, 0x1C0010A20, MemoryDataSort.String, b"Initialize> %s")

    def test_arm_function_hints_from_data_references(self):
        path = os.path.join(test_location, "armel", "sha224sum")
        proj = angr.Project(path, auto_load_libs=False)

        proj.analyses.CFGFast(data_references=True)
        funcs = proj.kb.functions
        # the function at 0x129C4 must be discovered and consist of a single 16-byte block
        assert funcs.contains_addr(0x129C4)
        func = funcs[0x129C4]
        assert len(list(func.blocks)) == 1
        assert next(iter(func.blocks)).size == 16

    def test_data_references_windows_driver_utf16_strings(self):
        path = os.path.join(
            test_location, "x86_64", "windows", "aaba7db353eb9400e3471eaaa1cf0105f6d1fab0ce63f1a2665c8ba0e8963a05.bin"
        )
        proj = angr.Project(path, auto_load_libs=False)

        cfg = proj.analyses.CFGFast()
        memory_data = cfg.model.memory_data

        # each size is two bytes larger than the UTF-16 content, presumably the NUL terminator
        registry_path = self._assert_memory_data(
            memory_data,
            0x1DCE0,
            MemoryDataSort.UnicodeString,
            cstring_to_unicode_string(b"\\Registry\\Machine\\SYSTEM\\CurrentControlSet\\Control\\WinApi"),
        )
        assert registry_path.size == 116
        ntdll = self._assert_memory_data(
            memory_data, 0x1DD90, MemoryDataSort.UnicodeString, cstring_to_unicode_string(b"ntdll.dll")
        )
        assert ntdll.size == 20

    def test_pe_32bit_pointer_array_detection(self):
        path = os.path.join(
            test_location, "i386", "windows", "53575875777863a69a573be858e75ceea834ea54c844bb528128a4ad16879d45"
        )
        proj = angr.Project(path, auto_load_libs=False)

        cfg = proj.analyses.CFGFast(show_progressbar=True)
        cfg_model = cfg.model
        # both the segment list and the model's memory data must agree on the pointer arrays and their sizes
        assert cfg._seg_list.is_occupied(0x100018BC) is True
        assert cfg._seg_list.occupied_by_sort(0x100018BC) == "pointer-array"
        assert cfg_model.memory_data[0x100018BC].size == 4
        assert cfg_model.memory_data[0x100018BC].sort == MemoryDataSort.PointerArray
        assert cfg._seg_list.is_occupied(0x10001004) is True
        assert cfg._seg_list.occupied_by_sort(0x10001004) == "pointer-array"
        assert cfg_model.memory_data[0x10001004].size == 228
        assert cfg_model.memory_data[0x10001004].sort == MemoryDataSort.PointerArray

    def test_syscalls_resolved_with_constant_propagation(self):
        for arch in ["x86", "x86_64"]:
            with self.subTest(arch=arch):
                path = os.path.join(test_location, arch, "hello_syscalls")
                proj = angr.Project(path, auto_load_libs=False)
                proj.analyses.CFGFast()
                main = proj.kb.functions["main"]
                write = proj.kb.functions["write"]
                read = proj.kb.functions["read"]
                # main's transition graph must reach write from 3 distinct predecessors and read from 1
                assert len(set(main.transition_graph.predecessors(write))) == 3
                assert len(set(main.transition_graph.predecessors(read))) == 1


if __name__ == "__main__":
    # Inside this guard __name__ is "__main__", so this is the same as calling unittest.main() with its defaults.
    unittest.main(module=__name__)
