Commit 8fef3f6

feat(data-warehouse): New snowflake source (PostHog#29317)
1 parent 93320ba commit 8fef3f6

6 files changed: +161 -8 lines changed

posthog/temporal/data_imports/pipelines/mysql/mysql.py

Lines changed: 0 additions & 1 deletion
@@ -186,7 +186,6 @@ def mysql_source(
     is_incremental: bool,
     logger: FilteringBoundLogger,
     db_incremental_field_last_value: Optional[Any],
-    team_id: Optional[int] = None,
     incremental_field: Optional[str] = None,
     incremental_field_type: Optional[IncrementalFieldType] = None,
 ) -> SourceResponse:

posthog/temporal/data_imports/pipelines/snowflake/snowflake.py

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
+from typing import Any, Optional
+from collections.abc import Iterator
+
+from posthog.temporal.common.logger import FilteringBoundLogger
+from posthog.temporal.data_imports.pipelines.helpers import incremental_type_to_initial_value
+from posthog.temporal.data_imports.pipelines.pipeline.typings import SourceResponse
+from posthog.warehouse.types import IncrementalFieldType
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import serialization
+from dlt.common.normalizers.naming.snake_case import NamingConvention
+import snowflake.connector
+from snowflake.connector.cursor import SnowflakeCursor
+
+
+def _get_connection(
+    account_id: str,
+    user: Optional[str],
+    password: Optional[str],
+    passphrase: Optional[str],
+    private_key: Optional[str],
+    auth_type: str,
+    database: str,
+    warehouse: str,
+    schema: str,
+    role: Optional[str] = None,
+) -> snowflake.connector.SnowflakeConnection:
+    if auth_type == "password" and user is not None and password is not None:
+        return snowflake.connector.connect(
+            account=account_id,
+            user=user,
+            password=password,
+            warehouse=warehouse,
+            database=database,
+            schema=schema,
+            role=role if role else None,
+        )
+
+    if private_key is None:
+        raise ValueError("Private key is missing for snowflake")
+
+    p_key = serialization.load_pem_private_key(
+        private_key.encode("utf-8"),
+        password=passphrase.encode() if passphrase is not None else None,
+        backend=default_backend(),
+    )
+
+    pkb = p_key.private_bytes(
+        encoding=serialization.Encoding.DER,
+        format=serialization.PrivateFormat.PKCS8,
+        encryption_algorithm=serialization.NoEncryption(),
+    )
+
+    return snowflake.connector.connect(
+        account=account_id,
+        user=user,
+        warehouse=warehouse,
+        database=database,
+        schema=schema,
+        role=role if role else None,
+        private_key=pkb,
+    )
+
+
+def _build_query(
+    database: str,
+    schema: str,
+    table_name: str,
+    is_incremental: bool,
+    incremental_field: Optional[str],
+    incremental_field_type: Optional[IncrementalFieldType],
+    db_incremental_field_last_value: Optional[Any],
+) -> tuple[str, tuple[Any, ...]]:
+    if not is_incremental:
+        return "SELECT * FROM IDENTIFIER(%s)", (f"{database}.{schema}.{table_name}",)
+
+    if incremental_field is None or incremental_field_type is None:
+        raise ValueError("incremental_field and incremental_field_type can't be None")
+
+    if db_incremental_field_last_value is None:
+        db_incremental_field_last_value = incremental_type_to_initial_value(incremental_field_type)
+
+    return "SELECT * FROM IDENTIFIER(%s) WHERE IDENTIFIER(%s) >= %s ORDER BY IDENTIFIER(%s) ASC", (
+        f"{database}.{schema}.{table_name}",
+        incremental_field,
+        db_incremental_field_last_value,
+        incremental_field,
+    )
+
+
+def _get_primary_keys(cursor: SnowflakeCursor, database: str, schema: str, table_name: str) -> list[str] | None:
+    cursor.execute("SHOW PRIMARY KEYS IN IDENTIFIER(%s)", (f"{database}.{schema}.{table_name}",))
+
+    column_index = next((i for i, row in enumerate(cursor.description) if row.name == "column_name"), -1)
+
+    if column_index == -1:
+        raise ValueError("column_name not found in Snowflake cursor description")
+
+    keys = [row[column_index] for row in cursor]
+
+    return keys if len(keys) > 0 else None
+
+
+def snowflake_source(
+    account_id: str,
+    user: Optional[str],
+    password: Optional[str],
+    passphrase: Optional[str],
+    private_key: Optional[str],
+    auth_type: str,
+    database: str,
+    warehouse: str,
+    schema: str,
+    table_names: list[str],
+    is_incremental: bool,
+    logger: FilteringBoundLogger,
+    db_incremental_field_last_value: Optional[Any],
+    role: Optional[str] = None,
+    incremental_field: Optional[str] = None,
+    incremental_field_type: Optional[IncrementalFieldType] = None,
+) -> SourceResponse:
+    table_name = table_names[0]
+    if not table_name:
+        raise ValueError("Table name is missing")
+
+    with _get_connection(
+        account_id, user, password, passphrase, private_key, auth_type, database, warehouse, schema, role
+    ) as connection:
+        with connection.cursor() as cursor:
+            primary_keys = _get_primary_keys(cursor, database, schema, table_name)
+
+    def get_rows() -> Iterator[Any]:
+        with _get_connection(
+            account_id, user, password, passphrase, private_key, auth_type, database, warehouse, schema, role
+        ) as connection:
+            with connection.cursor() as cursor:
+                query, params = _build_query(
+                    database,
+                    schema,
+                    table_name,
+                    is_incremental,
+                    incremental_field,
+                    incremental_field_type,
+                    db_incremental_field_last_value,
+                )
+                logger.debug(f"Snowflake query: {query.format(params)}")
+                cursor.execute(query, params)
+
+                # We can't control the batch size from Snowflake when using the arrow function
+                # https://github.com/snowflakedb/snowflake-connector-python/issues/1712
+                yield from cursor.fetch_arrow_batches()
+
+    name = NamingConvention().normalize_identifier(table_name)
+
+    return SourceResponse(name=name, items=get_rows(), primary_keys=primary_keys)
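
Two details of the new module are worth a quick illustration. First, key-pair auth: any auth_type other than "password" takes the key-pair path, where _get_connection expects private_key as a PEM string and re-serializes it to DER/PKCS8 for the connector. A minimal sketch of producing a compatible key with the cryptography package (assumes an unencrypted key; supply passphrase for an encrypted one):

    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import rsa

    # Generate an RSA key and serialize it in the PEM/PKCS8 form that
    # load_pem_private_key() inside _get_connection accepts.
    key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    pem = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode("utf-8")
    # pem is the value to pass as private_key

Second, _build_query binds the fully qualified table name and the incremental column through Snowflake's IDENTIFIER(%s) rather than interpolating them into the SQL string. A worked example with hypothetical names (MY_DB, PUBLIC, ORDERS, UPDATED_AT), assuming IncrementalFieldType exposes a DateTime member:

    query, params = _build_query(
        database="MY_DB",
        schema="PUBLIC",
        table_name="ORDERS",
        is_incremental=True,
        incremental_field="UPDATED_AT",
        incremental_field_type=IncrementalFieldType.DateTime,
        db_incremental_field_last_value=None,  # None falls back to the type's initial value
    )
    # query  == "SELECT * FROM IDENTIFIER(%s) WHERE IDENTIFIER(%s) >= %s ORDER BY IDENTIFIER(%s) ASC"
    # params == ("MY_DB.PUBLIC.ORDERS", "UPDATED_AT", <initial datetime>, "UPDATED_AT")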

posthog/temporal/data_imports/workflow_activities/import_data_sync.py

Lines changed: 3 additions & 3 deletions
@@ -260,7 +260,6 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs):
             db_incremental_field_last_value=processed_incremental_last_value
             if schema.is_incremental
             else None,
-            team_id=inputs.team_id,
         )
     else:
         source = sql_source_for_type(
@@ -335,7 +334,6 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs):
             if schema.is_incremental
             else None,
             db_incremental_field_last_value=processed_incremental_last_value if schema.is_incremental else None,
-            team_id=inputs.team_id,
         )
     else:
         source = sql_source_for_type(
@@ -368,7 +366,7 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs):
             reset_pipeline=reset_pipeline,
         )
     elif model.pipeline.source_type == ExternalDataSource.Type.SNOWFLAKE:
-        from posthog.temporal.data_imports.pipelines.sql_database import (
+        from posthog.temporal.data_imports.pipelines.snowflake.snowflake import (
             snowflake_source,
         )

@@ -396,6 +394,8 @@ def import_data_activity_sync(inputs: ImportDataActivityInputs):
             warehouse=warehouse,
             role=role,
             table_names=endpoints,
+            logger=logger,
+            is_incremental=schema.is_incremental,
             incremental_field=schema.sync_type_config.get("incremental_field") if schema.is_incremental else None,
             incremental_field_type=schema.sync_type_config.get("incremental_field_type")
             if schema.is_incremental
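
Taken together, the SNOWFLAKE branch now imports snowflake_source from the new module and threads logger and is_incremental through to it. A rough sketch of the resulting call shape; arguments outside this diff (account_id, user, password, passphrase, private_key, auth_type, database, schema) are assumed to be wired like the visible ones:

    source = snowflake_source(
        account_id=account_id,
        user=user,
        password=password,
        passphrase=passphrase,
        private_key=private_key,
        auth_type=auth_type,
        database=database,
        warehouse=warehouse,
        schema=schema,
        role=role,
        table_names=endpoints,
        logger=logger,
        is_incremental=schema.is_incremental,
        db_incremental_field_last_value=processed_incremental_last_value if schema.is_incremental else None,
        incremental_field=schema.sync_type_config.get("incremental_field") if schema.is_incremental else None,
        incremental_field_type=schema.sync_type_config.get("incremental_field_type") if schema.is_incremental else None,
    )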

posthog/temporal/tests/data_imports/test_end_to_end.py

Lines changed: 2 additions & 2 deletions
@@ -1390,7 +1390,7 @@ async def test_delta_no_merging_on_first_sync_after_reset(team, postgres_config,
         first_call_args, first_call_kwargs = mock_write.call_args_list[0]
         second_call_args, second_call_kwargs = mock_write.call_args_list[1]

-        # The first call should be an append
+        # The first call should be an overwrite
         assert first_call_kwargs == {
             "mode": "overwrite",
             "schema_mode": "overwrite",
@@ -1400,7 +1400,7 @@ async def test_delta_no_merging_on_first_sync_after_reset(team, postgres_config,
             "engine": "rust",
         }

-        # The last call should be an append
+        # The subsequent call should be an append
         assert second_call_kwargs == {
             "mode": "append",
             "schema_mode": "merge",

requirements.in

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ semantic_version==2.8.5
 simple-salesforce>=1.12.6
 scikit-learn==1.5.0
 slack_sdk==3.17.1
-snowflake-connector-python==3.6.0
+snowflake-connector-python==3.13.2
 snowflake-sqlalchemy==1.7.3
 social-auth-app-django==5.0.0
 social-auth-core==4.3.0

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -798,7 +798,7 @@ sniffio==1.3.1
     #   httpx
     #   openai
     #   trio
-snowflake-connector-python==3.6.0
+snowflake-connector-python==3.13.2
     # via
     #   -r requirements.in
     #   snowflake-sqlalchemy
