Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,11 @@ Condition may compare attributes of `target` and also call one of these methods:

If `target` is a Dataflow, remember you can access `target.source` and/or `target.sink` along with other attributes.

Conditions on assets can analyze their incoming and outgoing Dataflows by inspecting
the `target.inputs` and `target.outputs` attributes (flows marked as responses are not included in these lists). For example, to match a threat only against
servers with incoming traffic, use `any(target.inputs)`. A more advanced example,
matching elements connecting to SQL datastores, would be `any(f.sink.oneOf(Datastore) and f.sink.isSQL for f in target.outputs)`.

## Currently supported threats

```text
Expand Down
66 changes: 50 additions & 16 deletions pytm/pytm.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,19 @@ def __set__(self, instance, value):
super().__set__(instance, value)


class varElements(var):
    """Descriptor for attributes holding a list of ``Element`` objects.

    Any assigned iterable is copied into a new ``list``, and every item is
    checked to be an ``Element`` before the list is passed to ``var.__set__``.
    """

    def __set__(self, instance, value):
        """Validate *value* and store it on *instance* as a list.

        :param instance: object whose attribute is being assigned.
        :param value: iterable of ``Element`` objects.
        :raises ValueError: if any item is not an ``Element``.
        """
        # Materialise once: a generator or other one-shot iterator would
        # otherwise be consumed by the validation loop and stored as [].
        items = list(value)
        for i, e in enumerate(items):
            if not isinstance(e, Element):
                # Report the type of the offending item, not of the container.
                raise ValueError(
                    "expecting a list of Elements, item number {} is a {}".format(
                        i, type(e)
                    )
                )
        super().__set__(instance, items)


class varFindings(var):

def __set__(self, instance, value):
Expand Down Expand Up @@ -200,17 +213,35 @@ def _match_responses(flows):


def _apply_defaults(flows):
    """Fill in unset Dataflow attributes and link each flow to its endpoints.

    For every flow in *flows*:

    * ``data`` defaults to the source's ``data``;
    * a response flow (``isResponse`` true) takes ``protocol``, ``srcPort``
      and ``isEncrypted`` from its source;
    * any other flow takes ``protocol`` and ``dstPort`` from its sink, plus
      ``isEncrypted`` when the sink has that attribute.

    Every default is applied through ``flow._safeset``. Afterwards each
    non-response flow is appended to its sink's ``inputs`` and its source's
    ``outputs``; response flows are left out of both lists.

    :param flows: iterable of Dataflow objects; it is iterated once.
    """
    inputs = defaultdict(list)
    outputs = defaultdict(list)
    for flow in flows:
        flow._safeset("data", flow.source.data)

        if flow.isResponse:
            flow._safeset("protocol", flow.source.protocol)
            flow._safeset("srcPort", flow.source.port)
            flow._safeset("isEncrypted", flow.source.isEncrypted)
            # Responses get defaults but are not recorded as inputs/outputs.
            continue

        flow._safeset("protocol", flow.sink.protocol)
        flow._safeset("dstPort", flow.sink.port)
        # Not every sink type defines isEncrypted.
        if hasattr(flow.sink, "isEncrypted"):
            flow._safeset("isEncrypted", flow.sink.isEncrypted)

        outputs[flow.source].append(flow)
        inputs[flow.sink].append(flow)

    # Distinct loop names so the ``flows`` parameter is not shadowed.
    for attr, grouped in (("inputs", inputs), ("outputs", outputs)):
        for element, element_flows in grouped.items():
            try:
                setattr(element, attr, element_flows)
            except (AttributeError, ValueError):
                # Best-effort: elements that reject the assignment keep their
                # current value. NOTE(review): presumably these are elements
                # without the attribute or with a write-once value already
                # set — confirm.
                pass


def _describe_classes(classes):
Expand Down Expand Up @@ -396,11 +427,14 @@ def check(self):
_apply_defaults(TM._BagOfFlows)
if self.ignoreUnused:
TM._BagOfElements, TM._BagOfBoundaries = _get_elements_and_boundaries(TM._BagOfFlows)
result = True
for e in (TM._BagOfElements):
e.check()
if not e.check():
result = False
if self.ignoreUnused:
# cannot rely on user defined order if assets are re-used in multiple models
TM._BagOfElements = _sort_elem(TM._BagOfElements)
return result

def _check_duplicates(self, flows):
if self.onDuplicates == Action.NO_ACTION:
Expand Down Expand Up @@ -538,10 +572,6 @@ def _uniq_name(self):

def check(self):
    """Validate this element before threat analysis.

    The base implementation accepts every element and always returns
    ``True``; a falsy result from an element makes ``TM.check`` report
    failure.
    """
    return True

def dfd(self, **kwargs):
self._is_drawn = True
Expand Down Expand Up @@ -638,6 +668,8 @@ class Lambda(Element):
environment = varString("")
implementsAPI = varBool(False)
authorizesSource = varBool(False)
inputs = varElements([], doc="incoming Dataflows")
outputs = varElements([], doc="outgoing Dataflows")

def __init__(self, name, **kwargs):
super().__init__(name, **kwargs)
Expand All @@ -659,6 +691,8 @@ class Server(Element):
isEncrypted = varBool(False, doc="Requires incoming data flow to be encrypted")
protocol = varString("", doc="Default network protocol for incoming data flows")
data = varString("", doc="Default type of data in incoming data flows")
inputs = varElements([], doc="incoming Dataflows")
outputs = varElements([], doc="outgoing Dataflows")
providesConfidentiality = varBool(False)
providesIntegrity = varBool(False)
authenticatesSource = varBool(False)
Expand Down Expand Up @@ -721,6 +755,8 @@ class Datastore(Element):
isEncrypted = varBool(False, doc="Requires incoming data flow to be encrypted")
protocol = varString("", doc="Default network protocol for incoming data flows")
data = varString("", doc="Default type of data in incoming data flows")
inputs = varElements([], doc="incoming Dataflows")
outputs = varElements([], doc="outgoing Dataflows")
onRDS = varBool(False)
storesLogData = varBool(False)
storesPII = varBool(False, doc="""Personally Identifiable Information
Expand Down Expand Up @@ -766,6 +802,8 @@ class Actor(Element):
port = varInt(-1, doc="Default TCP port for outgoing data flows")
protocol = varString("", doc="Default network protocol for outgoing data flows")
data = varString("", doc="Default type of data in outgoing data flows")
inputs = varElements([], doc="incoming Dataflows")
outputs = varElements([], doc="outgoing Dataflows")

def __init__(self, name, **kwargs):
super().__init__(name, **kwargs)
Expand All @@ -785,6 +823,8 @@ class Process(Element):
isEncrypted = varBool(False, doc="Requires incoming data flow to be encrypted")
protocol = varString("", doc="Default network protocol for incoming data flows")
data = varString("", doc="Default type of data in incoming data flows")
inputs = varElements([], doc="incoming Dataflows")
outputs = varElements([], doc="outgoing Dataflows")
codeType = varString("Unmanaged")
implementsCommunicationProtocol = varBool(False)
providesConfidentiality = varBool(False)
Expand Down Expand Up @@ -884,12 +924,6 @@ def __init__(self, source, sink, name, **kwargs):
def __set__(self, instance, value):
    """Report an unexpected descriptor assignment; the value is discarded."""
    message = "Should not have gotten here."
    print(message)

def check(self):
    """Validation hook for this flow: it performs no checks and returns None."""
    return None

def dfd(self, mergeResponses=False, **kwargs):
self._is_drawn = True
color = _setColor(self)
Expand Down
46 changes: 34 additions & 12 deletions tests/test_private_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import unittest
import random

from pytm.pytm import Actor, Boundary, Dataflow, Datastore, Server, TM, Threat
from pytm.pytm import Actor, Boundary, Dataflow, Datastore, Process, Server, TM, Threat


class TestUniqueNames(unittest.TestCase):
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_responses(self):
http_resp = Dataflow(web, user, "http resp")
http_resp.responseTo = http_req

tm.check()
self.assertTrue(tm.check())

self.assertEqual(http_req.response, http_resp)
self.assertIs(http_resp.isResponse, True)
Expand All @@ -83,28 +83,32 @@ def test_defaults(self):
isEncrypted=False,
data="SQL resp",
)
worker = Process("Task queue worker")

req_get = Dataflow(user, server, "HTTP GET")
query = Dataflow(server, db, "Query", data="SQL")
server_query = Dataflow(server, db, "Query", data="SQL")
result = Dataflow(db, server, "Results", isResponse=True)
resp_get = Dataflow(server, user, "HTTP Response", isResponse=True)

req_post = Dataflow(user, server, "HTTP POST", data="JSON")
resp_post = Dataflow(server, user, "HTTP Response", isResponse=True)

tm.check()
worker_query = Dataflow(worker, db, "Query", data="SQL")
Dataflow(db, worker, "Results", isResponse=True)

self.assertTrue(tm.check())

self.assertEqual(req_get.srcPort, -1)
self.assertEqual(req_get.dstPort, server.port)
self.assertEqual(req_get.isEncrypted, server.isEncrypted)
self.assertEqual(req_get.protocol, server.protocol)
self.assertEqual(req_get.data, user.data)

self.assertEqual(query.srcPort, -1)
self.assertEqual(query.dstPort, db.port)
self.assertEqual(query.isEncrypted, db.isEncrypted)
self.assertEqual(query.protocol, db.protocol)
self.assertNotEqual(query.data, server.data)
self.assertEqual(server_query.srcPort, -1)
self.assertEqual(server_query.dstPort, db.port)
self.assertEqual(server_query.isEncrypted, db.isEncrypted)
self.assertEqual(server_query.protocol, db.protocol)
self.assertNotEqual(server_query.data, server.data)

self.assertEqual(result.srcPort, db.port)
self.assertEqual(result.dstPort, -1)
Expand All @@ -130,18 +134,27 @@ def test_defaults(self):
self.assertEqual(resp_post.protocol, server.protocol)
self.assertEqual(resp_post.data, server.data)

self.assertListEqual(server.inputs, [req_get, req_post])
self.assertListEqual(server.outputs, [server_query])
self.assertListEqual(worker.inputs, [])
self.assertListEqual(worker.outputs, [worker_query])


class TestMethod(unittest.TestCase):

def test_defaults(self):
tm = TM("my test tm", description="aa", isOrdered=True)

internet = Boundary("Internet")
cloud = Boundary("Cloud")

user = Actor("User", inBoundary=internet)
server = Server("Server")
db = Datastore("DB", inBoundary=cloud)
db = Datastore("DB", inBoundary=cloud, isSQL=True)
func = Datastore("Lambda function", inBoundary=cloud)

request = Dataflow(user, server, "request")
response = Dataflow(server, user, "response")
response = Dataflow(server, user, "response", isResponse=True)
user_query = Dataflow(user, db, "user query")
server_query = Dataflow(server, db, "server query")
func_query = Dataflow(func, db, "func query")
Expand All @@ -161,12 +174,21 @@ def test_defaults(self):
{"target": response, "condition": "target.enters(Boundary)"},
{"target": response, "condition": "not target.exits(Boundary)"},
{"target": user, "condition": "target.inside(Boundary)"},
{"target": func, "condition": "not any(target.inputs)"},
{
"target": server,
"condition": "any(f.sink.oneOf(Datastore) and f.sink.isSQL "
"for f in target.outputs)",
},
]

self.assertTrue(tm.check())

for case in testCases:
t = Threat(SID="", target=default_target, condition=case["condition"])
self.assertTrue(
t.apply(case["target"]),
"Failed to match {} against {}".format(
case["target"], case["condition"]
case["target"], case["condition"],
),
)