|
37 | 37 | ) |
38 | 38 | from utils.types import ConfigType |
39 | 39 | from utils.dropout_utils import NodeDropout |
| 40 | +from utils.gias import gia_main |
40 | 41 |
|
41 | 42 | import torchvision.transforms as T # type: ignore |
42 | 43 | import os |
@@ -382,6 +383,17 @@ def __init__( |
382 | 383 | super().__init__(config, comm_utils) |
383 | 384 | self.server_node = 0 |
384 | 385 | self.set_parameters(config) |
| 386 | + if "gia" in config: |
| 387 | + if int(self.node_id) in self.config["gia_attackers"]: |
| 388 | + self.gia_attacker = True |
| 389 | + self.params_s = dict() |
| 390 | + self.params_t = dict() |
| 391 | + # Track neighbor updates with a dictionary mapping neighbor_id to their updates |
| 392 | + self.neighbor_updates = defaultdict(list) |
| 393 | + # Track which neighbors we've already attacked |
| 394 | + self.attacked_neighbors = set() |
| 395 | + |
| 396 | + self.base_params = [key for key, _ in self.model.named_parameters()] |
385 | 397 |
|
386 | 398 | def set_parameters(self, config: Dict[str, Any]) -> None: |
387 | 399 | """ |
@@ -672,6 +684,68 @@ def receive_and_aggregate(self): |
672 | 684 | assert "model" in repr, "Model not found in the received message" |
673 | 685 | self.set_model_weights(repr["model"]) |
674 | 686 |
|
def receive_attack_and_aggregate(
    self, neighbors: List[int], round: int, num_neighbors: int
) -> None:
    """
    Receive neighbor updates, run a GIA once per neighbor, then aggregate.

    Each received update (filtered model parameters, images, labels) is
    buffered per sender. When two updates from a neighbor that has not yet
    been attacked are buffered, the first and second parameter sets are
    passed to ``gia_main`` together with the latest images/labels.

    - On success (truthy result), the reconstruction and stats are logged,
      the neighbor is marked as attacked and its buffer is freed.
    - On failure, the oldest buffered update is dropped so that the attack is
      retried next round with the newest pair.
    - Updates from already-attacked neighbors are not buffered at all.

    All received updates are then aggregated as usual. Nothing happens when
    ``self.is_working`` is falsy.

    Args:
        neighbors: node ids to receive model updates from.
        round: current round index; only used in log labels/messages.
        num_neighbors: number of updates expected from ``receive``.
    """
    if not self.is_working:
        return

    # Receive the model updates from the neighbors
    model_updates = self.comm_utils.receive(node_ids=neighbors)
    assert len(model_updates) == num_neighbors

    # Set for O(1) key filtering; self.base_params itself stays a list
    # because it is also handed to gia_main below.
    base_param_keys = set(self.base_params)

    for neighbor_info in model_updates:
        neighbor_id = neighbor_info["sender"]

        # A neighbor is attacked at most once, so buffering its later updates
        # would only accumulate models/images in memory every round.
        if neighbor_id in self.attacked_neighbors:
            continue

        neighbor_model = OrderedDict(
            (key, value)
            for key, value in neighbor_info["model"].items()
            if key in base_param_keys
        )
        neighbor_images = neighbor_info["images"]
        neighbor_labels = neighbor_info["labels"]

        updates = self.neighbor_updates[neighbor_id]
        updates.append({
            "model": neighbor_model,
            "images": neighbor_images,
            "labels": neighbor_labels,
        })

        # Need two parameter snapshots from the same neighbor to attack.
        if len(updates) < 2:
            continue

        print(f"Client {self.node_id} attacking {neighbor_id}!")

        p_s = updates[0]["model"]
        p_t = updates[1]["model"]

        result = gia_main(p_s,
                          p_t,
                          self.base_params,
                          self.model,
                          neighbor_labels,
                          neighbor_images,
                          self.node_id)
        if not result:
            self.log_utils.log_summary(f"Client {self.node_id} failed to attack {neighbor_id} in round {round}!")
            print(f"Client {self.node_id} failed to attack {neighbor_id}!")
            # Slide the window: keep only the newest update so the buffer
            # never exceeds two entries and the attack can be retried.
            updates.pop(0)
            continue

        output, stats = result

        # log output and stats as image
        self.log_utils.log_gia_image(output, neighbor_labels, neighbor_id, label=f"round_{round}_reconstruction")
        self.log_utils.log_summary(f"round {round} gia targeting {neighbor_id} stats: {stats}")

        # Mark this neighbor as attacked and free its buffered updates
        self.attacked_neighbors.add(neighbor_id)
        del self.neighbor_updates[neighbor_id]

    self.aggregate(model_updates, keys_to_ignore=self.model_keys_to_ignore)
| 748 | + |
def run_protocol(self) -> None:
    """Not implemented on this class; presumably provided by a subclass."""
    raise NotImplementedError()
677 | 751 |
|
@@ -762,7 +836,6 @@ def get_model(self, **kwargs: Any) -> Any: |
def run_protocol(self) -> None:
    """Not implemented on this class; presumably provided by a subclass."""
    raise NotImplementedError()
764 | 838 |
|
765 | | - |
766 | 839 | class CommProtocol(object): |
767 | 840 | """ |
768 | 841 | Communication protocol tags for the server and users |
@@ -800,6 +873,7 @@ def __init__( |
800 | 873 | keys = self.model_utils.get_last_layer_keys(self.get_model_weights()) |
801 | 874 | self.model_keys_to_ignore.extend(keys) |
802 | 875 |
|
| 876 | + |
803 | 877 | def local_test(self, **kwargs: Any) -> Tuple[float, float]: |
804 | 878 | """ |
805 | 879 | Test the model locally, not to be used in the traditional FedAvg |
@@ -864,13 +938,19 @@ def aggregate( |
864 | 938 | self.set_model_weights(agg_wts) |
865 | 939 | return None |
866 | 940 |
|
def receive_and_aggregate(self, neighbors: List[int], it: int = 0) -> None:
    """
    Receive model weights from the given neighbors and aggregate them.

    Nodes whose ``gia_attacker`` attribute is truthy delegate to
    ``receive_attack_and_aggregate``, which also runs the gradient-inversion
    attack. All other nodes receive and aggregate directly, but only while
    ``self.is_working`` is truthy.

    Args:
        neighbors: node ids to receive model updates from.
        it: current round/iteration index, forwarded to the attack path
            (used there for log labels).
    """
    # getattr rather than hasattr: an attribute explicitly set to False must
    # not route the node into the attack path.
    if getattr(self, "gia_attacker", False):
        self.receive_attack_and_aggregate(neighbors, it, len(neighbors))
        return

    if not self.is_working:
        return

    # Receive the model updates from the neighbors
    model_updates = self.comm_utils.receive(node_ids=neighbors)
    # Aggregate the representations
    self.aggregate(model_updates, keys_to_ignore=self.model_keys_to_ignore)
874 | 954 |
|
875 | 955 | def get_collaborator_weights( |
876 | 956 | self, reprs_dict: Dict[int, OrderedDict[int, Tensor]] |
|
0 commit comments