class varElements(var):
    """Descriptor for attributes holding a list of ``Element`` instances.

    Declared on assets for their ``inputs`` / ``outputs`` attributes
    (documented there as incoming / outgoing Dataflows).
    """

    def __set__(self, instance, value):
        """Validate every item of *value* and store a list copy of it.

        Args:
            instance: object owning the attribute.
            value: iterable of ``Element`` instances.

        Raises:
            ValueError: if any item is not an ``Element``; the message
                reports the position and the type of the offending item.
        """
        # Materialize once: iterating ``value`` twice would leave a
        # generator exhausted after validation and store an empty list.
        elements = list(value)
        for i, e in enumerate(elements):
            if not isinstance(e, Element):
                raise ValueError(
                    "expecting a list of Elements, item number {} is a {}".format(
                        # report the bad item's type, not the container's
                        i, type(e)
                    )
                )
        super().__set__(instance, elements)
def check(self):
    """Return True unconditionally; no validation is performed here."""
    return True
Datastore(Element): isEncrypted = varBool(False, doc="Requires incoming data flow to be encrypted") protocol = varString("", doc="Default network protocol for incoming data flows") data = varString("", doc="Default type of data in incoming data flows") + inputs = varElements([], doc="incoming Dataflows") + outputs = varElements([], doc="outgoing Dataflows") onRDS = varBool(False) storesLogData = varBool(False) storesPII = varBool(False, doc="""Personally Identifiable Information @@ -766,6 +802,8 @@ class Actor(Element): port = varInt(-1, doc="Default TCP port for outgoing data flows") protocol = varString("", doc="Default network protocol for outgoing data flows") data = varString("", doc="Default type of data in outgoing data flows") + inputs = varElements([], doc="incoming Dataflows") + outputs = varElements([], doc="outgoing Dataflows") def __init__(self, name, **kwargs): super().__init__(name, **kwargs) @@ -785,6 +823,8 @@ class Process(Element): isEncrypted = varBool(False, doc="Requires incoming data flow to be encrypted") protocol = varString("", doc="Default network protocol for incoming data flows") data = varString("", doc="Default type of data in incoming data flows") + inputs = varElements([], doc="incoming Dataflows") + outputs = varElements([], doc="outgoing Dataflows") codeType = varString("Unmanaged") implementsCommunicationProtocol = varBool(False) providesConfidentiality = varBool(False) @@ -884,12 +924,6 @@ def __init__(self, source, sink, name, **kwargs): def __set__(self, instance, value): print("Should not have gotten here.") - def check(self): - ''' makes sure it is good to go ''' - # all minimum annotations are in place - # then add itself to _BagOfFlows - pass - def dfd(self, mergeResponses=False, **kwargs): self._is_drawn = True color = _setColor(self) diff --git a/tests/test_private_func.py b/tests/test_private_func.py index 1aabfeb..3b7c58e 100644 --- a/tests/test_private_func.py +++ b/tests/test_private_func.py @@ -3,7 +3,7 @@ import 
unittest import random -from pytm.pytm import Actor, Boundary, Dataflow, Datastore, Server, TM, Threat +from pytm.pytm import Actor, Boundary, Dataflow, Datastore, Process, Server, TM, Threat class TestUniqueNames(unittest.TestCase): @@ -57,7 +57,7 @@ def test_responses(self): http_resp = Dataflow(web, user, "http resp") http_resp.responseTo = http_req - tm.check() + self.assertTrue(tm.check()) self.assertEqual(http_req.response, http_resp) self.assertIs(http_resp.isResponse, True) @@ -83,16 +83,20 @@ def test_defaults(self): isEncrypted=False, data="SQL resp", ) + worker = Process("Task queue worker") req_get = Dataflow(user, server, "HTTP GET") - query = Dataflow(server, db, "Query", data="SQL") + server_query = Dataflow(server, db, "Query", data="SQL") result = Dataflow(db, server, "Results", isResponse=True) resp_get = Dataflow(server, user, "HTTP Response", isResponse=True) req_post = Dataflow(user, server, "HTTP POST", data="JSON") resp_post = Dataflow(server, user, "HTTP Response", isResponse=True) - tm.check() + worker_query = Dataflow(worker, db, "Query", data="SQL") + Dataflow(db, worker, "Results", isResponse=True) + + self.assertTrue(tm.check()) self.assertEqual(req_get.srcPort, -1) self.assertEqual(req_get.dstPort, server.port) @@ -100,11 +104,11 @@ def test_defaults(self): self.assertEqual(req_get.protocol, server.protocol) self.assertEqual(req_get.data, user.data) - self.assertEqual(query.srcPort, -1) - self.assertEqual(query.dstPort, db.port) - self.assertEqual(query.isEncrypted, db.isEncrypted) - self.assertEqual(query.protocol, db.protocol) - self.assertNotEqual(query.data, server.data) + self.assertEqual(server_query.srcPort, -1) + self.assertEqual(server_query.dstPort, db.port) + self.assertEqual(server_query.isEncrypted, db.isEncrypted) + self.assertEqual(server_query.protocol, db.protocol) + self.assertNotEqual(server_query.data, server.data) self.assertEqual(result.srcPort, db.port) self.assertEqual(result.dstPort, -1) @@ -130,18 +134,27 @@ 
def test_defaults(self): self.assertEqual(resp_post.protocol, server.protocol) self.assertEqual(resp_post.data, server.data) + self.assertListEqual(server.inputs, [req_get, req_post]) + self.assertListEqual(server.outputs, [server_query]) + self.assertListEqual(worker.inputs, []) + self.assertListEqual(worker.outputs, [worker_query]) + class TestMethod(unittest.TestCase): def test_defaults(self): + tm = TM("my test tm", description="aa", isOrdered=True) + internet = Boundary("Internet") cloud = Boundary("Cloud") + user = Actor("User", inBoundary=internet) server = Server("Server") - db = Datastore("DB", inBoundary=cloud) + db = Datastore("DB", inBoundary=cloud, isSQL=True) func = Datastore("Lambda function", inBoundary=cloud) + request = Dataflow(user, server, "request") - response = Dataflow(server, user, "response") + response = Dataflow(server, user, "response", isResponse=True) user_query = Dataflow(user, db, "user query") server_query = Dataflow(server, db, "server query") func_query = Dataflow(func, db, "func query") @@ -161,12 +174,21 @@ def test_defaults(self): {"target": response, "condition": "target.enters(Boundary)"}, {"target": response, "condition": "not target.exits(Boundary)"}, {"target": user, "condition": "target.inside(Boundary)"}, + {"target": func, "condition": "not any(target.inputs)"}, + { + "target": server, + "condition": "any(f.sink.oneOf(Datastore) and f.sink.isSQL " + "for f in target.outputs)", + }, ] + + self.assertTrue(tm.check()) + for case in testCases: t = Threat(SID="", target=default_target, condition=case["condition"]) self.assertTrue( t.apply(case["target"]), "Failed to match {} against {}".format( - case["target"], case["condition"] + case["target"], case["condition"], ), )