Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion examples/example_contextvars.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,52 @@
"""Пример перенаправления заголовков запроса в GigaChat"""
# pip install python-dotenv
import asyncio

import gigachat.context
from gigachat import GigaChat

ACCESS_TOKEN = ...

headers = {
"Authorization": f"Bearer {ACCESS_TOKEN}",
# for logging
"X-Session-ID": "8324244b-7133-4d30-a328-31d8466e5502",
"X-Request-ID": "8324244b-7133-4d30-a328-31d8466e5502",
"X-Service-ID": "my_custom_service",
"X-Operation-ID": "my_custom_qna",
"X-Trace-ID": "trace-id-1234567890",
"X-Agent-ID": "agent-id-1234567890",
}

# Установка переменных для клиента
with GigaChat(verify_ssl_certs=False) as giga:
gigachat.context.authorization_cvar.set(headers.get("Authorization"))
gigachat.context.session_id_cvar.set(headers.get("X-Session-ID"))
gigachat.context.request_id_cvar.set(headers.get("X-Request-ID"))
gigachat.context.service_id_cvar.set(headers.get("X-Service-ID"))
gigachat.context.operation_id_cvar.set(headers.get("X-Operation-ID"))
gigachat.context.trace_id_cvar.set(headers.get("X-Trace-ID"))
gigachat.context.agent_id_cvar.set(headers.get("X-Agent-ID"))
gigachat.context.custom_headers_cvar.set({"X-Custom-Header": "CustomValue"})

response = giga.chat("Какие факторы влияют на стоимость страховки на дом?")
print(response.choices[0].message.content)


# Установка переменных для каждого вызова (только async)
async def ask_with_headers(giga, some_custom_header, prompt):
gigachat.context.custom_headers_cvar.set({"X-Custom-Header": some_custom_header})

response = await giga.achat(prompt)
print(response.choices[0].message.content)


async def async_main():
async with GigaChat(verify_ssl_certs=False) as giga:
await asyncio.gather(
ask_with_headers(giga, "CustomValue 1", "Кто тебя сделал?"),
ask_with_headers(giga, "CustomValue 2", "Как тебя зовут?"),
)


if __name__ == "__main__":
asyncio.run(async_main())
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "gigachat"
# NOTE(review): this diff capture shows both the removed (0.1.39post2) and the added (0.1.40)
# `version` lines. TOML forbids duplicate keys, so only `version = "0.1.40"` belongs in the real file.
version = "0.1.39post2" # Includes fix related to h11 upgrade
version = "0.1.40"
description = "GigaChat. Python-library for GigaChain and LangChain"
# NOTE(review): the author e-mails appear as "[email protected]", presumably obfuscated by the page capture; confirm against the repository.
authors = ["Konstantin Krestnikov <[email protected]>", "Sergey Malyshev <[email protected]>"]
license = "MIT"
Expand Down
41 changes: 22 additions & 19 deletions src/gigachat/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
import httpx

from gigachat.context import (
agent_id_cvar,
authorization_cvar,
client_id_cvar,
custom_headers_cvar,
operation_id_cvar,
request_id_cvar,
service_id_cvar,
session_id_cvar,
trace_id_cvar,
)
from gigachat.exceptions import AuthenticationError, ResponseError
from gigachat.pydantic_v1 import BaseModel
Expand All @@ -28,25 +31,25 @@ def build_headers(access_token: Optional[str] = None) -> Dict[str, str]:

headers["User-Agent"] = USER_AGENT

authorization = authorization_cvar.get()
session_id = session_id_cvar.get()
request_id = request_id_cvar.get()
service_id = service_id_cvar.get()
operation_id = operation_id_cvar.get()
client_id = client_id_cvar.get()

if authorization:
headers["Authorization"] = authorization
if session_id:
headers["X-Session-ID"] = session_id
if request_id:
headers["X-Request-ID"] = request_id
if service_id:
headers["X-Service-ID"] = service_id
if operation_id:
headers["X-Operation-ID"] = operation_id
if client_id:
headers["X-Client-ID"] = client_id
context_vars = {
"Authorization": authorization_cvar,
"X-Session-ID": session_id_cvar,
"X-Request-ID": request_id_cvar,
"X-Service-ID": service_id_cvar,
"X-Operation-ID": operation_id_cvar,
"X-Client-ID": client_id_cvar,
"X-Trace-ID": trace_id_cvar,
"X-Agent-ID": agent_id_cvar,
}

for header, cvar in context_vars.items():
value = cvar.get()
if value:
headers[header] = value

custom_headers = custom_headers_cvar.get()
if custom_headers:
headers.update(custom_headers)
return headers


Expand Down
10 changes: 9 additions & 1 deletion src/gigachat/context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from contextvars import ContextVar
from typing import Optional
from typing import Dict, Optional

# JWE authorization value; when set (truthy) it is sent verbatim as the "Authorization" request header.
authorization_cvar: ContextVar[Optional[str]] = ContextVar("authorization_cvar", default=None)
"""Информация об авторизации с помощью JWE"""
Expand All @@ -10,4 +10,12 @@
# Each ID variable below defaults to None; when set to a truthy value it is sent as the
# matching X-* request header (see build_headers in gigachat/api/utils.py).
# Sent as "X-Session-ID".
session_id_cvar: ContextVar[Optional[str]] = ContextVar("session_id_cvar", default=None)
"""Уникальный ID сессии"""
# Sent as "X-Service-ID".
service_id_cvar: ContextVar[Optional[str]] = ContextVar("service_id_cvar", default=None)
"""Уникальный ID сервиса"""
# Sent as "X-Operation-ID". (Its description was previously a copy of the authorization one.)
operation_id_cvar: ContextVar[Optional[str]] = ContextVar("operation_id_cvar", default=None)
"""Уникальный ID операции"""
# Sent as "X-Trace-ID".
trace_id_cvar: ContextVar[Optional[str]] = ContextVar("trace_id_cvar", default=None)
"""Уникальный ID экземпляра процесса (основной операции)"""
# Sent as "X-Agent-ID".
agent_id_cvar: ContextVar[Optional[str]] = ContextVar("agent_id_cvar", default=None)
"""Уникальный ID агента"""
# Extra headers merged into the request last, so they override same-named headers set above.
custom_headers_cvar: ContextVar[Optional[Dict[str, str]]] = ContextVar("custom_headers_cvar", default=None)
"""Дополнительные HTTP-заголовки, которые будут добавлены к запросу"""
17 changes: 16 additions & 1 deletion tests/unit_tests/gigachat/api/test_get_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@
from pytest_httpx import HTTPXMock

from gigachat.api import get_model
from gigachat.context import authorization_cvar, operation_id_cvar, request_id_cvar, service_id_cvar, session_id_cvar
from gigachat.context import (
agent_id_cvar,
authorization_cvar,
custom_headers_cvar,
operation_id_cvar,
request_id_cvar,
service_id_cvar,
session_id_cvar,
trace_id_cvar,
)
from gigachat.exceptions import AuthenticationError, ResponseError
from gigachat.models import Model

Expand All @@ -21,6 +30,9 @@ def test__kwargs_context_vars() -> None:
token_session_id_cvar = session_id_cvar.set("session_id_cvar")
token_service_id_cvar = service_id_cvar.set("service_id_cvar")
token_operation_id_cvar = operation_id_cvar.set("operation_id_cvar")
token_trace_id_cvar = trace_id_cvar.set("trace_id_cvar")
token_agent_id_cvar = agent_id_cvar.set("agent_id_cvar")
token_custom_headers_cvar = custom_headers_cvar.set({"custom_headers_cvar": "val"})

assert get_model._get_kwargs(model="model")

Expand All @@ -29,6 +41,9 @@ def test__kwargs_context_vars() -> None:
session_id_cvar.reset(token_session_id_cvar)
service_id_cvar.reset(token_service_id_cvar)
operation_id_cvar.reset(token_operation_id_cvar)
trace_id_cvar.reset(token_trace_id_cvar)
agent_id_cvar.reset(token_agent_id_cvar)
custom_headers_cvar.reset(token_custom_headers_cvar)


def test_sync(httpx_mock: HTTPXMock) -> None:
Expand Down
17 changes: 16 additions & 1 deletion tests/unit_tests/gigachat/api/test_get_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@
from pytest_httpx import HTTPXMock

from gigachat.api import get_models
from gigachat.context import authorization_cvar, operation_id_cvar, request_id_cvar, service_id_cvar, session_id_cvar
from gigachat.context import (
agent_id_cvar,
authorization_cvar,
custom_headers_cvar,
operation_id_cvar,
request_id_cvar,
service_id_cvar,
session_id_cvar,
trace_id_cvar,
)
from gigachat.exceptions import AuthenticationError, ResponseError
from gigachat.models import Models

Expand All @@ -21,6 +30,9 @@ def test__kwargs_context_vars() -> None:
token_session_id_cvar = session_id_cvar.set("session_id_cvar")
token_service_id_cvar = service_id_cvar.set("service_id_cvar")
token_operation_id_cvar = operation_id_cvar.set("operation_id_cvar")
token_trace_id_cvar = trace_id_cvar.set("trace_id_cvar")
token_agent_id_cvar = agent_id_cvar.set("agent_id_cvar")
token_custom_headers_cvar = custom_headers_cvar.set({"custom_headers_cvar": "val"})

assert get_models._get_kwargs()

Expand All @@ -29,6 +41,9 @@ def test__kwargs_context_vars() -> None:
session_id_cvar.reset(token_session_id_cvar)
service_id_cvar.reset(token_service_id_cvar)
operation_id_cvar.reset(token_operation_id_cvar)
trace_id_cvar.reset(token_trace_id_cvar)
agent_id_cvar.reset(token_agent_id_cvar)
custom_headers_cvar.reset(token_custom_headers_cvar)


def test_sync(httpx_mock: HTTPXMock) -> None:
Expand Down
65 changes: 64 additions & 1 deletion tests/unit_tests/gigachat/api/test_post_chat.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
import asyncio

import httpx
import pytest
from pytest_httpx import HTTPXMock

from gigachat.api import post_chat
from gigachat.context import authorization_cvar, operation_id_cvar, request_id_cvar, service_id_cvar, session_id_cvar
from gigachat.api.utils import USER_AGENT
from gigachat.context import (
agent_id_cvar,
authorization_cvar,
custom_headers_cvar,
operation_id_cvar,
request_id_cvar,
service_id_cvar,
session_id_cvar,
trace_id_cvar,
)
from gigachat.exceptions import AuthenticationError, ResponseError
from gigachat.models import Chat, ChatCompletion

Expand All @@ -15,13 +27,18 @@
# Request fixture parsed into a Chat model, and the raw JSON body returned by mocked responses.
CHAT = Chat.parse_obj(get_json("chat.json"))
CHAT_COMPLETION = get_json("chat_completion.json")

# Header name used by the custom-header propagation tests in this module.
X_CUSTOM_HEADER = "X-Custom-Header"


def test__kwargs_context_vars() -> None:
token_authorization_cvar = authorization_cvar.set("authorization_cvar")
token_request_id_cvar = request_id_cvar.set("request_id_cvar")
token_session_id_cvar = session_id_cvar.set("session_id_cvar")
token_service_id_cvar = service_id_cvar.set("service_id_cvar")
token_operation_id_cvar = operation_id_cvar.set("operation_id_cvar")
token_trace_id_cvar = trace_id_cvar.set("trace_id_cvar")
token_agent_id_cvar = agent_id_cvar.set("agent_id_cvar")
token_custom_headers_cvar = custom_headers_cvar.set({"custom_headers_cvar": "val"})

assert post_chat._get_kwargs(chat=Chat(messages=[]))

Expand All @@ -30,6 +47,9 @@ def test__kwargs_context_vars() -> None:
session_id_cvar.reset(token_session_id_cvar)
service_id_cvar.reset(token_service_id_cvar)
operation_id_cvar.reset(token_operation_id_cvar)
trace_id_cvar.reset(token_trace_id_cvar)
agent_id_cvar.reset(token_agent_id_cvar)
custom_headers_cvar.reset(token_custom_headers_cvar)


def test_sync(httpx_mock: HTTPXMock) -> None:
Expand Down Expand Up @@ -86,3 +106,46 @@ async def test_asyncio(httpx_mock: HTTPXMock) -> None:
response = await post_chat.asyncio(client, chat=CHAT)

assert isinstance(response, ChatCompletion)


def test_headers_in_request(httpx_mock: HTTPXMock) -> None:
    """A header set via ``custom_headers_cvar`` is sent alongside the default User-Agent."""
    httpx_mock.add_response(url=MOCK_URL, json=CHAT_COMPLETION)
    token_custom_headers_cvar = custom_headers_cvar.set({X_CUSTOM_HEADER: "CustomValue"})
    # Reset in ``finally`` so a failing request cannot leak the header into later tests
    # (sync tests share the same context).
    try:
        with httpx.Client(base_url=BASE_URL) as client:
            post_chat.sync(client, chat=CHAT)
    finally:
        custom_headers_cvar.reset(token_custom_headers_cvar)

    headers = httpx_mock.get_requests()[0].headers
    assert headers["User-Agent"] == USER_AGENT
    assert headers[X_CUSTOM_HEADER] == "CustomValue"


@pytest.mark.asyncio()
async def test_headers_in_async_request(httpx_mock: HTTPXMock) -> None:
    """Concurrent requests each send their own custom header value.

    ``asyncio.gather`` wraps every coroutine in its own task, and each task starts
    with a copy of the current context, so a ``custom_headers_cvar.set`` inside one
    call must not be visible to the others.
    """
    expected_values = ["CustomValue1", "CustomValue2", "CustomValue3"]

    async def call_with_headers(client: httpx.AsyncClient, headers: dict) -> None:
        token_custom_headers_cvar = custom_headers_cvar.set(headers)
        try:
            await post_chat.asyncio(client, chat=CHAT)
        finally:
            custom_headers_cvar.reset(token_custom_headers_cvar)

    # Register exactly one response per request instead of relying on pytest_httpx
    # re-serving the last response, which recent releases no longer do by default.
    for _ in expected_values:
        httpx_mock.add_response(url=MOCK_URL, json=CHAT_COMPLETION)

    async with httpx.AsyncClient(base_url=BASE_URL) as client:
        await asyncio.gather(
            *(call_with_headers(client, {X_CUSTOM_HEADER: value}) for value in expected_values)
        )

    # Verify that headers are not mixed up between concurrent requests
    requests = httpx_mock.get_requests()
    assert len(requests) == len(expected_values)

    # Completion order is not guaranteed: compare sorted, so each value must appear exactly once.
    header_values = [req.headers[X_CUSTOM_HEADER] for req in requests]
    assert sorted(header_values) == sorted(expected_values)
16 changes: 15 additions & 1 deletion tests/unit_tests/gigachat/api/test_post_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
from pytest_httpx import HTTPXMock

from gigachat.api import post_token
from gigachat.context import operation_id_cvar, request_id_cvar, service_id_cvar, session_id_cvar
from gigachat.context import (
agent_id_cvar,
custom_headers_cvar,
operation_id_cvar,
request_id_cvar,
service_id_cvar,
session_id_cvar,
trace_id_cvar,
)
from gigachat.exceptions import AuthenticationError, ResponseError
from gigachat.models import Token

Expand All @@ -20,13 +28,19 @@ def test__kwargs_context_vars() -> None:
token_session_id_cvar = session_id_cvar.set("session_id_cvar")
token_service_id_cvar = service_id_cvar.set("service_id_cvar")
token_operation_id_cvar = operation_id_cvar.set("operation_id_cvar")
token_trace_id_cvar = trace_id_cvar.set("trace_id_cvar")
token_agent_id_cvar = agent_id_cvar.set("agent_id_cvar")
token_custom_headers_cvar = custom_headers_cvar.set({"custom_headers_cvar": "val"})

assert post_token._get_kwargs(user="user", password="password")

request_id_cvar.reset(token_request_id_cvar)
session_id_cvar.reset(token_session_id_cvar)
service_id_cvar.reset(token_service_id_cvar)
operation_id_cvar.reset(token_operation_id_cvar)
trace_id_cvar.reset(token_trace_id_cvar)
agent_id_cvar.reset(token_agent_id_cvar)
custom_headers_cvar.reset(token_custom_headers_cvar)


def test_sync(httpx_mock: HTTPXMock) -> None:
Expand Down
17 changes: 16 additions & 1 deletion tests/unit_tests/gigachat/api/test_stream_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@
from pytest_httpx import HTTPXMock

from gigachat.api import stream_chat
from gigachat.context import authorization_cvar, operation_id_cvar, request_id_cvar, service_id_cvar, session_id_cvar
from gigachat.context import (
agent_id_cvar,
authorization_cvar,
custom_headers_cvar,
operation_id_cvar,
request_id_cvar,
service_id_cvar,
session_id_cvar,
trace_id_cvar,
)
from gigachat.exceptions import AuthenticationError, ResponseError
from gigachat.models import Chat, ChatCompletionChunk

Expand All @@ -27,6 +36,9 @@ def test__kwargs_context_vars() -> None:
token_session_id_cvar = session_id_cvar.set("session_id_cvar")
token_service_id_cvar = service_id_cvar.set("service_id_cvar")
token_operation_id_cvar = operation_id_cvar.set("operation_id_cvar")
token_trace_id_cvar = trace_id_cvar.set("trace_id_cvar")
token_agent_id_cvar = agent_id_cvar.set("agent_id_cvar")
token_custom_headers_cvar = custom_headers_cvar.set({"custom_headers_cvar": "val"})

assert stream_chat._get_kwargs(chat=Chat(messages=[]))

Expand All @@ -35,6 +47,9 @@ def test__kwargs_context_vars() -> None:
session_id_cvar.reset(token_session_id_cvar)
service_id_cvar.reset(token_service_id_cvar)
operation_id_cvar.reset(token_operation_id_cvar)
trace_id_cvar.reset(token_trace_id_cvar)
agent_id_cvar.reset(token_agent_id_cvar)
custom_headers_cvar.reset(token_custom_headers_cvar)


def test_sync(httpx_mock: HTTPXMock) -> None:
Expand Down