Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add unit tests for prompts
Signed-off-by: Henry Lindeman <[email protected]>
  • Loading branch information
HenryL27 committed Jan 21, 2025
commit b73c1624951f39799d20548df332d37cbeb915b1
238 changes: 238 additions & 0 deletions lib/sycamore/sycamore/tests/unit/llms/prompts/test_prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
from sycamore.data.element import Element
from sycamore.llms.prompts.prompts import (
RenderedPrompt,
RenderedMessage,
StaticPrompt,
SycamorePrompt,
ElementPrompt,
ElementListPrompt,
)
from sycamore.data import Document
from sycamore.tests.config import TEST_DIR
from pyarrow.fs import LocalFileSystem
import pytest


@pytest.fixture(scope="module")
def dummy_document():
    """Build the PDF ``Document`` shared by the tests in this module.

    The binary content is read from ``resources/data/pdfs/ntsb-report.pdf``
    under ``TEST_DIR``. The document gets six synthetic ``Text`` elements,
    "Element 1" .. "Element 6" with ids "e1" .. "e6", on pages 1, 2, 3, 3, 3
    and 4 respectively.
    """
    path = str(TEST_DIR / "resources/data/pdfs/ntsb-report.pdf")

    document = Document()
    # Close the stream as soon as its bytes are read rather than leaving the
    # open handle to be reclaimed by garbage collection.
    with LocalFileSystem().open_input_stream(path) as input_stream:
        document.binary_representation = input_stream.readall()
    document.type = "pdf"
    document.properties["path"] = path
    document.properties["pages"] = 6

    # (page_number, bbox) for elements 1..6, in document order.
    layout = [
        (1, (0.1, 0.1, 0.4, 0.4)),
        (2, (0.1, 0.1, 0.4, 0.4)),
        (3, (0.1, 0.1, 0.4, 0.4)),
        (3, (0.4, 0.1, 0.8, 0.4)),
        (3, (0.1, 0.4, 0.8, 0.8)),
        (4, (0.1, 0.1, 0.4, 0.4)),
    ]
    document.elements = [
        Element(
            text_representation=f"Element {number}",
            type="Text",
            element_id=f"e{number}",
            properties={"page_number": page_number},
            bbox=bbox,
        )
        for number, (page_number, bbox) in enumerate(layout, start=1)
    ]
    return document


class TestRenderedPrompt:
    """Intentionally empty placeholder.

    RenderedPrompt and RenderedMessage are dataclasses, so there is
    nothing to test for them here.
    """


class TestSycamorePrompt:
    """Tests for the base ``SycamorePrompt`` class."""

    def test_instead_is_cow(self):
        """``instead`` yields a prompt with the override; the source keeps its value."""
        original = SycamorePrompt()
        vars(original)["key"] = "value"

        derived = original.instead(key="other value")

        assert original.key == "value"
        assert derived.key == "other value"


class TestStaticPrompt:
    """Tests for ``StaticPrompt`` rendering."""

    def test_static_rd(self, dummy_document):
        """A missing template value raises KeyError; once supplied, every render method agrees."""
        unfilled = StaticPrompt(system="system {x}", user="computers")
        with pytest.raises(KeyError):
            unfilled.render_document(dummy_document)

        filled = unfilled.instead(x=76)
        system_msg = RenderedMessage(role="system", content="system 76")
        user_msg = RenderedMessage(role="user", content="computers")
        expected = RenderedPrompt(messages=[system_msg, user_msg])

        first_element = dummy_document.elements[0]
        assert filled.render_document(dummy_document) == expected
        assert filled.render_element(first_element, dummy_document) == expected
        assert filled.render_multiple_documents([dummy_document]) == expected


class TestElementPrompt:
    """Tests for ``ElementPrompt``: element rendering, parent context and images."""

    def test_basic(self, dummy_document):
        """Element fields and extra kwargs fill the templates; document-level renders raise."""
        fourth_element = dummy_document.elements[3]
        prompt = ElementPrompt(
            system="You know everything there is to know about jazz, {name}",
            user="Summarize the information on page {elt_property_page_number}.\nTEXT: {elt_text}",
            name="Frank Sinatra",
        )
        system_msg = RenderedMessage(
            role="system",
            content="You know everything there is to know about jazz, Frank Sinatra",
        )
        user_msg = RenderedMessage(
            role="user",
            content="Summarize the information on page 3.\nTEXT: Element 4",
        )
        expected = RenderedPrompt(messages=[system_msg, user_msg])

        assert prompt.render_element(fourth_element, dummy_document) == expected

        with pytest.raises(NotImplementedError):
            prompt.render_document(dummy_document)
        with pytest.raises(NotImplementedError):
            prompt.render_multiple_documents([dummy_document])

    def test_get_parent_context(self, dummy_document):
        """Values returned by ``capture_parent_context`` are available to the templates."""

        def page_count_context(doc, elt):
            return {"custom_property": doc.properties["pages"]}

        prompt = ElementPrompt(
            system="You know everything there is to know about {custom_property}, {name}",
            user="Summarize the information on page {elt_property_page_number}.\nTEXT: {elt_text}",
            name="Frank Sinatra",
            capture_parent_context=page_count_context,
        )
        system_msg = RenderedMessage(
            role="system",
            content="You know everything there is to know about 6, Frank Sinatra",
        )
        user_msg = RenderedMessage(
            role="user",
            content="Summarize the information on page 3.\nTEXT: Element 4",
        )
        expected = RenderedPrompt(messages=[system_msg, user_msg])

        assert prompt.render_element(dummy_document.elements[3], dummy_document) == expected

    def test_include_image(self, dummy_document):
        """With ``include_element_image`` one image is attached to the last rendered message."""

        def page_count_context(doc, elt):
            return {"custom_property": doc.properties["pages"]}

        prompt = ElementPrompt(
            system="You know everything there is to know about {custom_property}, {name}",
            user="Summarize the information on page {elt_property_page_number}.\nTEXT: {elt_text}",
            name="Frank Sinatra",
            capture_parent_context=page_count_context,
            include_element_image=True,
        )

        # System + user prompt: the image is on the user message only.
        rendered = prompt.render_element(dummy_document.elements[3], dummy_document)
        user_images = rendered.messages[1].images
        assert user_images is not None
        assert len(user_images) == 1
        assert rendered.messages[1].role == "user"
        assert rendered.messages[0].images is None

        # Without a user template the single system message carries the image.
        system_only = prompt.instead(user=None)
        rendered_system_only = system_only.render_element(dummy_document.elements[1], dummy_document)
        assert len(rendered_system_only.messages) == 1
        only_msg = rendered_system_only.messages[0]
        assert only_msg.role == "system"
        assert only_msg.images is not None and len(only_msg.images) == 1


class TestElementListPrompt:
    """Tests for ``ElementListPrompt``: default listing, limiting, selecting, ordering, formatting."""

    @staticmethod
    def _sys_plus_user(user_content: str):
        """Return the expected prompt: a "sys" system message followed by the given user message."""
        return RenderedPrompt(
            messages=[
                RenderedMessage(role="system", content="sys"),
                RenderedMessage(role="user", content=user_content),
            ]
        )

    def test_basic(self, dummy_document):
        """By default all six elements are listed, in document order."""
        prompt = ElementListPrompt(system="sys", user="usr: {elements}")
        expected = self._sys_plus_user(
            "usr: ELEMENT 0: Element 1\n"
            "ELEMENT 1: Element 2\n"
            "ELEMENT 2: Element 3\n"
            "ELEMENT 3: Element 4\n"
            "ELEMENT 4: Element 5\n"
            "ELEMENT 5: Element 6"
        )
        assert prompt.render_document(dummy_document) == expected

    def test_limit_elements(self, dummy_document):
        """``num_elements=3`` keeps only the first three elements."""
        prompt = ElementListPrompt(system="sys", user="usr: {elements}", num_elements=3)
        expected = self._sys_plus_user(
            "usr: ELEMENT 0: Element 1\n"
            "ELEMENT 1: Element 2\n"
            "ELEMENT 2: Element 3"
        )
        assert prompt.render_document(dummy_document) == expected

    def test_select_odd_elements(self, dummy_document):
        """``element_select`` filters the elements; the survivors are renumbered from 0."""

        def odd_indexed(elts):
            return [elt for idx, elt in enumerate(elts) if idx % 2 == 1]

        prompt = ElementListPrompt(
            system="sys",
            user="usr: {elements}",
            element_select=odd_indexed,
        )
        expected = self._sys_plus_user(
            "usr: ELEMENT 0: Element 2\n"
            "ELEMENT 1: Element 4\n"
            "ELEMENT 2: Element 6"
        )
        assert prompt.render_document(dummy_document) == expected

    def test_order_elements(self, dummy_document):
        """``element_order`` controls the order in which elements are listed."""

        def back_to_front(e):
            return list(reversed(e))

        prompt = ElementListPrompt(system="sys", user="usr: {elements}", element_order=back_to_front)
        expected = self._sys_plus_user(
            "usr: ELEMENT 0: Element 6\n"
            "ELEMENT 1: Element 5\n"
            "ELEMENT 2: Element 4\n"
            "ELEMENT 3: Element 3\n"
            "ELEMENT 4: Element 2\n"
            "ELEMENT 5: Element 1"
        )
        assert prompt.render_document(dummy_document) == expected

    def test_construct_element_list(self, dummy_document):
        """``element_list_constructor`` replaces the default element formatting."""

        def list_constructor(elts: list[Element]) -> str:
            joined = "</><>".join(f"{idx}-{elt.type}" for idx, elt in enumerate(elts))
            return f"<>{joined}</>"

        prompt = ElementListPrompt(system="sys", user="usr: {elements}", element_list_constructor=list_constructor)
        expected = self._sys_plus_user(
            "usr: <>0-Text</><>1-Text</><>2-Text</><>3-Text</><>4-Text</><>5-Text</>"
        )
        assert prompt.render_document(dummy_document) == expected
18 changes: 17 additions & 1 deletion lib/sycamore/sycamore/utils/pdf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from PIL import Image

from pypdf import PdfReader, PdfWriter
import pdf2image

from sycamore import DocSet
from sycamore.functions.document import DrawBoxes, split_and_convert_to_image
from sycamore.utils.image_utils import show_images
from sycamore.utils.image_utils import show_images, crop_to_bbox
from sycamore.data import Document, Element
import json

Expand Down Expand Up @@ -180,3 +181,18 @@ def promote_title(elements: list[Element], title_candidate_elements=["Section-he
if section_header:
section_header.type = "Title"
return elements


def get_element_image(element: Element, document: Document) -> Image.Image:
    """Return an image of ``element`` cropped out of its page in ``document``.

    Args:
        element: Element to picture; needs a bbox and an integer
            ``page_number`` property.
        document: Document the element belongs to; its type must be "pdf"
            and its binary representation must be set.

    Returns:
        The result of ``crop_to_bbox`` applied to the first image that
        ``pdf2image`` produces from the selected page.

    Raises:
        AssertionError: If any of the preconditions above does not hold.
    """
    assert document.type == "pdf", "Cannot get picture of element from non-pdf"
    assert document.binary_representation is not None, "Cannot get image since there is not binary representation"
    assert element.bbox is not None, "Cannot get picture of element if it has no BBox"
    page_number = element.properties.get("page_number")
    assert isinstance(page_number, int), "Cannot get picture of element without known page number"

    # NOTE(review): select_pdf_pages presumably writes only the requested page
    # into single_page_pdf, so only that page gets rasterized — confirm.
    single_page_pdf = BytesIO()
    select_pdf_pages(BytesIO(document.binary_representation), single_page_pdf, [page_number])
    page_images = pdf2image.convert_from_bytes(single_page_pdf.getvalue())
    return crop_to_bbox(page_images[0], element.bbox)
Loading