Skip to content

Commit 8e64e60

Browse files
committed
migrated attack code to basenode
1 parent 8970b3b commit 8e64e60

File tree

3 files changed

+91
-110
lines changed

3 files changed

+91
-110
lines changed

src/algos/base_class.py

Lines changed: 88 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@
3737
)
3838
from utils.types import ConfigType
3939
from utils.dropout_utils import NodeDropout
40+
from utils.gias import gia_main
4041

4142
import torchvision.transforms as T # type: ignore
4243
import os
@@ -382,6 +383,17 @@ def __init__(
382383
super().__init__(config, comm_utils)
383384
self.server_node = 0
384385
self.set_parameters(config)
386+
if "gia" in config:
387+
if int(self.node_id) in self.config["gia_attackers"]:
388+
self.gia_attacker = True
389+
self.params_s = dict()
390+
self.params_t = dict()
391+
# Track neighbor updates with a dictionary mapping neighbor_id to their updates
392+
self.neighbor_updates = defaultdict(list)
393+
# Track which neighbors we've already attacked
394+
self.attacked_neighbors = set()
395+
396+
self.base_params = [key for key, _ in self.model.named_parameters()]
385397

386398
def set_parameters(self, config: Dict[str, Any]) -> None:
387399
"""
@@ -672,6 +684,68 @@ def receive_and_aggregate(self):
672684
assert "model" in repr, "Model not found in the received message"
673685
self.set_model_weights(repr["model"])
674686

687+
def receive_attack_and_aggregate(self, neighbors: List[int], round: int, num_neighbors: int) -> None:
    """Receive neighbor updates, run a GIA attack on each neighbor once, then aggregate.

    For every neighbor not yet in ``self.attacked_neighbors``, the received
    model (filtered to the keys in ``self.base_params``) is buffered in
    ``self.neighbor_updates``. The images and labels sent with it are buffered
    alongside. Once two snapshots are buffered for a neighbor, ``gia_main`` is
    run on (older snapshot, newer snapshot), using the newer snapshot's labels
    and images.

    * On success, the reconstruction and stats are logged, the neighbor is
      added to ``self.attacked_neighbors``, and its buffer is released.
    * On failure, only the older snapshot is dropped. The attack is then
      retried with the next update from that neighbor. (Previously the buffer
      grew to 3 entries and the ``== 2`` check never fired again.)

    Updates from neighbors that were already attacked are not buffered at all.
    Per-neighbor memory therefore stays bounded at two snapshots. All received
    updates are aggregated regardless of the attack outcome.

    Args:
        neighbors: node ids to receive model updates from.
        round: current round index; used only in log labels and messages.
            (Shadows the builtin ``round``; the name is kept for callers that
            pass it by keyword.)
        num_neighbors: number of updates expected from ``comm_utils.receive``.
    """
    print("CLIENT RECEIVING ATTACK AND AGGREGATING")
    if not self.is_working:
        return

    # Receive the model updates from the neighbors
    model_updates = self.comm_utils.receive(node_ids=neighbors)
    assert len(model_updates) == num_neighbors, (
        f"expected {num_neighbors} updates, got {len(model_updates)}"
    )

    # Build the set once: membership is tested for every key of every update.
    allowed_keys = set(self.base_params)

    for neighbor_info in model_updates:
        neighbor_id = neighbor_info["sender"]

        # Each neighbor is attacked at most once. Do not keep buffering its
        # updates afterwards, otherwise the buffer grows every round.
        if neighbor_id in self.attacked_neighbors:
            continue

        neighbor_model = OrderedDict(
            (key, value)
            for key, value in neighbor_info["model"].items()
            if key in allowed_keys
        )
        neighbor_images = neighbor_info["images"]
        neighbor_labels = neighbor_info["labels"]

        # Store this update
        history = self.neighbor_updates[neighbor_id]
        history.append({
            "model": neighbor_model,
            "images": neighbor_images,
            "labels": neighbor_labels,
        })
        if len(history) < 2:
            continue

        print(f"Client {self.node_id} attacking {neighbor_id}!")

        # The two most recent parameter snapshots form the (source, target) pair.
        p_s = history[-2]["model"]
        p_t = history[-1]["model"]

        # Launch the attack
        result = gia_main(p_s,
                          p_t,
                          self.base_params,
                          self.model,
                          neighbor_labels,
                          neighbor_images,
                          self.node_id)
        if result:
            output, stats = result

            # log output and stats as image
            self.log_utils.log_gia_image(output, neighbor_labels, neighbor_id, label=f"round_{round}_reconstruction")
            self.log_utils.log_summary(f"round {round} gia targeting {neighbor_id} stats: {stats}")

            # Mark this neighbor as attacked and release its buffered updates.
            self.attacked_neighbors.add(neighbor_id)
            del self.neighbor_updates[neighbor_id]
        else:
            self.log_utils.log_summary(f"Client {self.node_id} failed to attack {neighbor_id} in round {round}!")
            print(f"Client {self.node_id} failed to attack {neighbor_id}!")
            # Slide the window: keep only the newest snapshot, so the next
            # update from this neighbor forms a fresh pair to retry with.
            del history[:-1]

    self.aggregate(model_updates, keys_to_ignore=self.model_keys_to_ignore)
748+
675749
def run_protocol(self) -> None:
    """Protocol entry point; not implemented at this level.

    Calling it raises NotImplementedError. Concrete node types are expected to
    override it with their own protocol loop.
    """
    raise NotImplementedError()
677751

@@ -762,7 +836,6 @@ def get_model(self, **kwargs: Any) -> Any:
762836
def run_protocol(self) -> None:
    """Placeholder protocol loop.

    There is no implementation here; invoking it raises NotImplementedError,
    so a concrete subclass has to provide the actual protocol.
    """
    raise NotImplementedError()
764838

765-
766839
class CommProtocol(object):
767840
"""
768841
Communication protocol tags for the server and users
@@ -800,6 +873,7 @@ def __init__(
800873
keys = self.model_utils.get_last_layer_keys(self.get_model_weights())
801874
self.model_keys_to_ignore.extend(keys)
802875

876+
803877
def local_test(self, **kwargs: Any) -> Tuple[float, float]:
804878
"""
805879
Test the model locally, not to be used in the traditional FedAvg
@@ -864,13 +938,19 @@ def aggregate(
864938
self.set_model_weights(agg_wts)
865939
return None
866940

867-
def receive_and_aggregate(self, neighbors: List[int]) -> None:
868-
if self.is_working:
869-
# Receive the model updates from the neighbors
870-
model_updates = self.comm_utils.receive(node_ids=neighbors)
871-
# Aggregate the representations
872-
self.aggregate(model_updates, keys_to_ignore=self.model_keys_to_ignore)
873-
941+
def receive_and_aggregate(self, neighbors: List[int], it: int = 0) -> None:
    """Receive model weights from ``neighbors`` and aggregate them.

    Nodes flagged as GIA attackers (a truthy ``gia_attacker`` attribute)
    delegate to ``receive_attack_and_aggregate``, which also launches the GIA
    attack. All other nodes receive the updates, but only while
    ``self.is_working``, and aggregate them while skipping
    ``self.model_keys_to_ignore``.

    Args:
        neighbors: node ids to receive updates from.
        it: current round/iteration index, forwarded to the attack path, where
            it is used in log labels. Callers that omit it get every attack
            logged as round 0.
    """
    # getattr (not hasattr) so that an explicit ``gia_attacker = False`` is
    # honoured instead of being treated as an attacker.
    if getattr(self, "gia_attacker", False):
        self.receive_attack_and_aggregate(neighbors, it, len(neighbors))
        return

    if not self.is_working:
        return

    # Receive the model updates from the neighbors
    model_updates = self.comm_utils.receive(node_ids=neighbors)
    # Aggregate the representations
    self.aggregate(model_updates, keys_to_ignore=self.model_keys_to_ignore)
874954

875955
def get_collaborator_weights(
876956
self, reprs_dict: Dict[int, OrderedDict[int, Tensor]]

src/algos/fl.py

Lines changed: 2 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -113,13 +113,6 @@ def __init__(
113113
self.model_save_path = "{}/saved_models/node_{}.pt".format(
114114
self.config["results_path"], self.node_id
115115
)
116-
if "gia" in self.config:
117-
# to store param differences for GIA attack
118-
self.params_s = [None for i in range(4)]
119-
self.params_t = [None for i in range(4)]
120-
121-
# save randomly initialized parameters
122-
self.random_params = self.model.state_dict()
123116

124117
def fed_avg(self, model_wts: List[OrderedDict[str, Tensor]]):
125118
num_users = len(model_wts)
@@ -174,7 +167,7 @@ def test(self, **kwargs: Any) -> List[float]:
174167
self.model_utils.save_model(self.model, self.model_save_path)
175168
return [test_loss, test_acc, time_taken]
176169

177-
def receive_and_aggregate_gia(self, round: int, attack_start_round: int, attack_end_round: int, dump_file_name: str = ""):
170+
def receive_attack_and_aggregate(self, round: int, attack_start_round: int, attack_end_round: int, dump_file_name: str = ""):
178171
reprs = self.comm_utils.all_gather()
179172

180173
with open(dump_file_name, "wb") as f:
@@ -211,13 +204,6 @@ def receive_and_aggregate_gia(self, round: int, attack_start_round: int, attack_
211204
images = rep["images"]
212205
labels = rep["labels"]
213206

214-
# with open(f"params_t_{client_id}.pkl", "wb") as f:
215-
# pickle.dump(model_params, f)
216-
# with open(f"params_s_{client_id}.pkl", "wb") as f:
217-
# pickle.dump(self.params_s[client_id - 1], f)
218-
# with open(f"random_params_{client_id}.pkl", "wb") as f:
219-
# pickle.dump(random_params, f)
220-
221207
# Launch GIA attack
222208
p_s, p_t = self.params_s[client_id - 1], self.params_t[client_id - 1]
223209
gia_main(p_s, p_t, base_params, self.model, labels, images, client_id)
@@ -245,15 +231,7 @@ def single_round(self, round: int, attack_start_round: int = 0, attack_end_round
245231
if round < attack_start_round or round > attack_end_round:
246232
self.receive_and_aggregate()
247233
else:
248-
# Set file name based on start or end of attack range
249-
dump_file_name = ""
250-
if round == attack_start_round:
251-
dump_file_name = "/u/yshi23/sonar/src/start_reprs"
252-
elif round == attack_end_round:
253-
dump_file_name = "/u/yshi23/sonar/src/end_reprs"
254-
255-
print(f"In round {round}, preparing for GIA with file: {dump_file_name}")
256-
self.receive_and_aggregate_gia(round, attack_start_round, attack_end_round, dump_file_name)
234+
self.receive_attack_and_aggregate(round, attack_start_round, attack_end_round, dump_file_name)
257235

258236

259237
def run_protocol(self):

src/algos/fl_static.py

Lines changed: 1 addition & 78 deletions
Original file line number | Diff line number | Diff line change
@@ -24,86 +24,13 @@ def __init__(
2424
super().__init__(config, comm_utils)
2525
self.topology = select_topology(config, self.node_id)
2626
self.topology.initialize()
27-
if "gia" in config:
28-
if int(self.node_id) in self.config["gia_attackers"]:
29-
self.gia_attacker = True
30-
self.params_s = dict()
31-
self.params_t = dict()
32-
# Track neighbor updates with a dictionary mapping neighbor_id to their updates
33-
self.neighbor_updates = defaultdict(list)
34-
# Track which neighbors we've already attacked
35-
self.attacked_neighbors = set()
36-
37-
self.base_params = [key for key, _ in self.model.named_parameters()]
3827

3928
def get_representation(self, **kwargs: Any) -> OrderedDict[str, torch.Tensor]:
    """Share this node's current model weights as its representation.

    Keyword arguments are accepted for interface compatibility and ignored.
    """
    weights = self.get_model_weights()
    return weights
4433

45-
def receive_attack_and_aggregate(self, neighbors: List[int], round: int, num_neighbors: int) -> None:
46-
"""
47-
Receives updates, launches GIA attack when second update is seen from a neighbor
48-
"""
49-
print("CLIENT RECEIVING ATTACK AND AGGREGATING")
50-
if self.is_working:
51-
# Receive the model updates from the neighbors
52-
model_updates = self.comm_utils.receive(node_ids=neighbors)
53-
assert len(model_updates) == num_neighbors
54-
55-
for neighbor_info in model_updates:
56-
neighbor_id = neighbor_info["sender"]
57-
neighbor_model = neighbor_info["model"]
58-
neighbor_model = OrderedDict(
59-
(key, value) for key, value in neighbor_model.items()
60-
if key in self.base_params
61-
)
62-
63-
neighbor_images = neighbor_info["images"]
64-
neighbor_labels = neighbor_info["labels"]
65-
66-
# Store this update
67-
self.neighbor_updates[neighbor_id].append({
68-
"model": neighbor_model,
69-
"images": neighbor_images,
70-
"labels": neighbor_labels
71-
})
72-
73-
# Check if we have 2 updates from this neighbor and haven't attacked them yet
74-
if len(self.neighbor_updates[neighbor_id]) == 2 and neighbor_id not in self.attacked_neighbors:
75-
print(f"Client {self.node_id} attacking {neighbor_id}!")
76-
77-
# Get the two parameter sets for the attack
78-
p_s = self.neighbor_updates[neighbor_id][0]["model"]
79-
p_t = self.neighbor_updates[neighbor_id][1]["model"]
80-
81-
# Launch the attack
82-
if result := gia_main(p_s,
83-
p_t,
84-
self.base_params,
85-
self.model,
86-
neighbor_labels,
87-
neighbor_images,
88-
self.node_id):
89-
output, stats = result
90-
91-
# log output and stats as image
92-
self.log_utils.log_gia_image(output, neighbor_labels, neighbor_id, label=f"round_{round}_reconstruction")
93-
self.log_utils.log_summary(f"round {round} gia targeting {neighbor_id} stats: {stats}")
94-
else:
95-
self.log_utils.log_summary(f"Client {self.node_id} failed to attack {neighbor_id} in round {round}!")
96-
print(f"Client {self.node_id} failed to attack {neighbor_id}!")
97-
continue
98-
99-
# Mark this neighbor as attacked
100-
self.attacked_neighbors.add(neighbor_id)
101-
102-
# Optionally, clear the stored updates to save memory
103-
del self.neighbor_updates[neighbor_id]
104-
105-
self.aggregate(model_updates, keys_to_ignore=self.model_keys_to_ignore)
106-
10734
def run_protocol(self) -> None:
10835
"""
10936
Runs the federated learning protocol for the client.
@@ -129,11 +56,7 @@ def run_protocol(self) -> None:
12956
neighbors = self.topology.sample_neighbours(self.num_collaborators)
13057
stats["neighbors"] = neighbors
13158

132-
if hasattr(self, "gia_attacker"):
133-
print(f"Client {self.node_id} is a GIA attacker!")
134-
self.receive_attack_and_aggregate(neighbors, it, len(neighbors))
135-
else:
136-
self.receive_and_aggregate(neighbors)
59+
self.receive_and_aggregate(neighbors)
13760

13861
stats["bytes_received"], stats["bytes_sent"] = self.comm_utils.get_comm_cost()
13962
stats["test_loss"], stats["test_acc"] = self.local_test()

0 commit comments

Comments
 (0)