Merged
Changes from 1 commit (49 commits total)
c2a8cfa
add prompt base classes and ElementListPrompt
HenryL27 Jan 17, 2025
21a115a
override .instead in ElementListPrompt to store net-new keys in self.…
HenryL27 Jan 17, 2025
f94da80
add ElementPrompt and StaticPrompt
HenryL27 Jan 17, 2025
b73c162
add unit tests for prompts
HenryL27 Jan 21, 2025
17b2163
forgot to commit this
HenryL27 Jan 21, 2025
5d145d5
address pr comments; flatten properties with flatten_data
HenryL27 Jan 21, 2025
7fa2ff1
support multiple user prompts
HenryL27 Jan 21, 2025
abf9b0b
rename instead to set
HenryL27 Jan 22, 2025
9909c7e
Merge branch 'main' of github.com:aryn-ai/sycamore into hml-llm-unify
HenryL27 Jan 22, 2025
2d1315b
add LLMMap and LLMMapElements transforms
HenryL27 Jan 22, 2025
1853d51
Merge branch 'main' of github.com:aryn-ai/sycamore into hml-llm-unify
HenryL27 Jan 22, 2025
5e86e56
move llm implementations to use RenderedPrompts
HenryL27 Jan 22, 2025
27581ef
also this guy
HenryL27 Jan 22, 2025
739b672
add docset methods
HenryL27 Jan 23, 2025
73d9bdd
docstrings
HenryL27 Jan 23, 2025
ed8785e
add llm_map unit tests
HenryL27 Jan 23, 2025
523d6e3
fix bedrock tests and caching
HenryL27 Jan 23, 2025
e1b3206
fix anthropic and bedrock ITs
HenryL27 Jan 23, 2025
6500e1c
adjust caching to handle pydantic class response format properly
HenryL27 Jan 23, 2025
f50032d
fix base llm unit tests
HenryL27 Jan 23, 2025
c3c7ea8
adjust all testing mock llms to updated llm interface
HenryL27 Jan 23, 2025
ffaaf0f
deprecate extract entity and implement it with llm_map
HenryL27 Jan 24, 2025
d71cf1a
add context_params decorator to llm_map
HenryL27 Jan 24, 2025
4225e11
revert extract_entity docset method re-implementation
HenryL27 Jan 24, 2025
0d39b27
add initial support for prompts that generate a sequence of rendered …
HenryL27 Jan 25, 2025
0b5ded4
add stuff to EntityExtractor/OpenAIEntityExtractor to convert to LLMMap
HenryL27 Jan 25, 2025
a52f7c2
make docset.extract_entity construct an LLMMap from its entity_extractor
HenryL27 Jan 25, 2025
3a9ac3c
get extract entity working with tokenizer and token limit
HenryL27 Jan 28, 2025
befc3d0
get all extract_entity unit tests passing
HenryL27 Jan 28, 2025
8bf42d5
fix llm_map_elements to deal with postprocess index
HenryL27 Jan 28, 2025
d7ff1eb
add postprocess_fn unit tests for llm_map
HenryL27 Jan 28, 2025
a7a2cc0
ruff complaint
HenryL27 Jan 28, 2025
ebf721e
fix docset unittests
HenryL27 Jan 28, 2025
0bd2a45
move a bunch of stuff back to llm.generate_old. This includes the act…
HenryL27 Jan 28, 2025
95cbaaf
move more stuff back to llm.generate_old
HenryL27 Jan 28, 2025
ea7f0e6
fix the last few mocks
HenryL27 Jan 28, 2025
2e51ee1
Merge branch 'main' of github.com:aryn-ai/sycamore into hml-llm-unify
HenryL27 Jan 28, 2025
57a4e4b
ruff linelength
HenryL27 Jan 28, 2025
a312ba3
mypy!!!
HenryL27 Jan 28, 2025
ebde879
type: ignore + line length is tricky
HenryL27 Jan 28, 2025
ff5efdc
fix generate_old with SimplePrompts
HenryL27 Jan 28, 2025
370e2b7
set openai system role name to system instead of developer like their…
HenryL27 Jan 28, 2025
98ce6a0
address simple pr comments
HenryL27 Jan 29, 2025
1789409
pickle stuff in llm caching path bc not everything is jsonifiable
HenryL27 Jan 30, 2025
8b6f085
rewrite llm_map to deal with iterative prompting better
HenryL27 Jan 30, 2025
763acc5
add a b64encode-to-str to cache bc you can't put bytes in json either
HenryL27 Jan 30, 2025
0331866
fix llm its to mimic the _llm_cache_set/get pickle/unpickle operations
HenryL27 Jan 30, 2025
dfb7540
fix docstrings
HenryL27 Jan 30, 2025
f7c06e7
oops bad type signature
HenryL27 Jan 30, 2025
rewrite llm_map to deal with iterative prompting better
Signed-off-by: Henry Lindeman <[email protected]>
HenryL27 committed Jan 30, 2025
commit 8b6f085eb14e31e7a4bef675d46cf3b15148dab1
51 changes: 24 additions & 27 deletions lib/sycamore/sycamore/llms/prompts/prompts.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Union, Optional, Callable, Sequence
+from typing import Any, Union, Optional, Callable
 import copy
 
 import pydantic
@@ -42,7 +42,7 @@ class SycamorePrompt:
     convert sycamore objects (``Document``s, ``Element``s) into ``RenderedPrompts``
     """
 
-    def render_document(self, doc: Document) -> Union[RenderedPrompt, Sequence[RenderedPrompt]]:
+    def render_document(self, doc: Document) -> RenderedPrompt:
         """Render this prompt, given this document as context.
         Used in llm_map
 
@@ -54,7 +54,7 @@ def render_document(self, doc: Document) -> Union[RenderedPrompt, Sequence[RenderedPrompt]]:
         """
         raise NotImplementedError(f"render_document is not implemented for {self.__class__.__name__}")
 
-    def render_element(self, elt: Element, doc: Document) -> Union[RenderedPrompt, Sequence[RenderedPrompt]]:
+    def render_element(self, elt: Element, doc: Document) -> RenderedPrompt:
         """Render this prompt, given this element and its parent document as context.
         Used in llm_map_elements
 
@@ -66,7 +66,7 @@ def render_element(self, elt: Element, doc: Document) -> Union[RenderedPrompt, Sequence[RenderedPrompt]]:
         """
         raise NotImplementedError(f"render_element is not implemented for {self.__class__.__name__}")
 
-    def render_multiple_documents(self, docs: list[Document]) -> Union[RenderedPrompt, Sequence[RenderedPrompt]]:
+    def render_multiple_documents(self, docs: list[Document]) -> RenderedPrompt:
         """Render this prompt, given a list of documents as context.
         Used in llm_reduce
 
@@ -113,19 +113,6 @@ def set(self, **kwargs) -> "SycamorePrompt":
             new.__dict__[k] = v
         return new
 
-    def is_done(self, s: str) -> bool:
-        """Decide whether a given response is sufficient. Used when rendering
-        the prompt generates a sequence of prompts rather than a single prompt.
-        The default implementation always returns True
-
-        Args:
-            s: a string response from the LLM
-
-        Returns:
-            Whether to continue making LLM calls
-        """
-        return True
-
 
 def _build_format_str(
     system: Optional[str], user: Union[None, str, list[str]], format_args: dict[str, Any]
@@ -201,7 +188,7 @@ def _render_element_list_to_string(self, doc: Document):
         elts = self.element_select(doc.elements)
         return self.element_list_constructor(elts)
 
-    def render_document(self, doc: Document) -> Union[RenderedPrompt, Sequence[RenderedPrompt]]:
+    def render_document(self, doc: Document) -> RenderedPrompt:
         """Render this prompt, given this document as context, using python's
         ``str.format()`` method. The keys passed into ``format()`` are as follows:
 
@@ -280,11 +267,18 @@ class ElementListIterPrompt(ElementListPrompt):
     # ]
     """
 
-    def __init__(self, *, element_batcher: Optional[Callable[[list[Element]], list[list[Element]]]] = None, **kwargs):
+    def __init__(
+        self,
+        *,
+        element_batcher: Optional[Callable[[list[Element]], list[list[Element]]]] = None,
+        iteration_var_name: str = "i",
+        **kwargs,
+    ):
         self.element_batcher = element_batcher or (lambda e: [e])
+        self.iteration_var_name = iteration_var_name
         super().__init__(**kwargs)
 
-    def render_document(self, doc: Document) -> Sequence[RenderedPrompt]:
+    def render_document(self, doc: Document) -> RenderedPrompt:
         """Render this prompt, given this document as context, using python's
         ``str.format()`` method. The keys passed into ``format()`` are as follows:
 
@@ -304,19 +298,22 @@ def render_document(self, doc: Document) -> Sequence[RenderedPrompt]:
         ``self.user.format()`` using the format keys as specified above. Each instance
         is rendered from a batch of elements generated by ``self.element_batcher``
         """
+        i = doc.properties.get(self.iteration_var_name, 0)
+
         format_args = self.kwargs
         format_args["doc_text"] = doc.text_representation
         flat_props = flatten_data(doc.properties, prefix="doc_property", separator="_")
         format_args.update(flat_props)
 
-        prompts = []
-        for elt_batch in self.element_batcher(doc.elements):
-            elements = self.element_select(elt_batch)
-            elementstr = self.element_list_constructor(elements)
-            messages = _build_format_str(self.system, self.user, {"elements": elementstr, **format_args})
-            prompts.append(RenderedPrompt(messages=messages))
-        return prompts
+        for j, elt_batch in enumerate(self.element_batcher(doc.elements)):
+            if j < i:
+                continue
+            else:
+                elements = self.element_select(elt_batch)
+                elementstr = self.element_list_constructor(elements)
+                messages = _build_format_str(self.system, self.user, {"elements": elementstr, **format_args})
+                return RenderedPrompt(messages=messages)
+        return RenderedPrompt(messages=[])
 
 
 class ElementPrompt(SycamorePrompt):
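To make the new single-render contract concrete: render_document now returns exactly one RenderedPrompt — the batch selected by the document's iteration counter — and an empty prompt once the batches are exhausted. A minimal self-contained sketch; Doc and ToyIterPrompt are illustrative stand-ins, not the real sycamore classes:

    # Toy sketch of the ElementListIterPrompt.render_document contract above.
    from dataclasses import dataclass, field
    from typing import Callable

    @dataclass
    class RenderedPrompt:
        messages: list

    @dataclass
    class Doc:
        elements: list[str]
        properties: dict = field(default_factory=dict)

    class ToyIterPrompt:
        def __init__(self, element_batcher: Callable[[list], list[list]], iteration_var_name: str = "i"):
            self.element_batcher = element_batcher
            self.iteration_var_name = iteration_var_name

        def render_document(self, doc: Doc) -> RenderedPrompt:
            i = doc.properties.get(self.iteration_var_name, 0)
            for j, batch in enumerate(self.element_batcher(doc.elements)):
                if j < i:
                    continue  # batches already consumed on earlier iterations
                return RenderedPrompt(messages=[f"elements: {batch}"])
            return RenderedPrompt(messages=[])  # exhausted -> signals the caller to stop

    prompt = ToyIterPrompt(element_batcher=lambda es: [es[k:k + 2] for k in range(0, len(es), 2)])
    doc = Doc(elements=["a", "b", "c"])
    print(prompt.render_document(doc).messages)  # ["elements: ['a', 'b']"]
    doc.properties["i"] = 1  # llm_map bumps this when validate fails
    print(prompt.render_document(doc).messages)  # ["elements: ['c']"]
    doc.properties["i"] = 2
    print(prompt.render_document(doc).messages)  # []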
3 changes: 2 additions & 1 deletion lib/sycamore/sycamore/tests/unit/test_docset.py
@@ -44,7 +44,6 @@ def __init__(self):
         super().__init__(model_name="mock_model")
 
     def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
-        print(prompt)
         if llm_kwargs is None:
             llm_kwargs = {}
         if prompt.messages[-1].content.endswith("Element_index: 1\nText: third element\n"):
@@ -98,6 +97,8 @@ def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None)
                 return "group2"
             elif value == "3" or value == "three":
                 return "group3"
+            else:
+                return ""
         else:
             return prompt.messages[-1].content
 
14 changes: 7 additions & 7 deletions lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py
@@ -54,19 +54,19 @@ def test_happy_path(self):
         assert outdocs[1].text_representation == "booga"
         assert outdocs[1].properties["out"] == "booga"
 
-    def test_postprocess(self):
+    def test_validate(self):
         prompt = FakeDocPrompt()
         llm = FakeLLM()
         doc1 = Document({"text_representation": "ooga"})
         doc2 = Document({"text_representation": "booga"})
         count = 0
 
-        def ppfn(d: Document, i: int) -> Document:
+        def valfn(d: Document) -> bool:
             nonlocal count
             count += 1
-            return d
+            return count > 1
 
-        map = LLMMap(None, prompt, "out", llm, postprocess_fn=ppfn)
+        map = LLMMap(None, prompt, "out", llm, validate=valfn)
         _ = map.llm_map([doc1, doc2])
 
         assert count == 2
@@ -112,12 +112,12 @@ def test_postprocess(self):
         doc2 = Document({"doc_id": "2", "elements": [{"text_representation": "booga"}, {}]})
         count = 0
 
-        def ppfn(e: Element, i: int) -> Element:
+        def valfn(e: Element) -> bool:
             nonlocal count
             count += 1
-            return e
+            return count > 1
 
-        map = LLMMapElements(None, prompt, "out", llm, postprocess_fn=ppfn)
+        map = LLMMapElements(None, prompt, "out", llm, validate=valfn)
         _ = map.llm_map_elements([doc1, doc2])
 
         assert count == 4
@@ -202,6 +202,7 @@ def test_extract_entity_with_tokenizer(self, mocker):
             entity_extractor=entity_extractor,
         )
         taken = entity_docset.take()
+
         assert taken[0].properties[f"{new_field}_source_element_index"] == {0, 1, 2}
         assert taken[1].properties[f"{new_field}_source_element_index"] == {2}
         assert taken[0].properties[new_field] == "4"
103 changes: 71 additions & 32 deletions lib/sycamore/sycamore/transforms/base_llm.py
@@ -8,22 +8,15 @@
 
 
 def _infer_prompts(
Contributor:
I confess I'm having a hard time following all of the prompt sequence stuff. I understand the basic motivation for the token stuff, but I can't help but wonder if there is a cleaner way to do it.

Collaborator (author):
One thought I had, which I think is beyond the scope of this, is some sort of conditional branching/looping logic in a sycamore pipeline. Like

docset.do()
    .llm_map(<extract entity prompt>)
    .map(if not json set back to 'None')
    .while(lambda d: d.properties.get('entity', 'None') == 'None')
    .continue_processing()

I'm sure that creates all sorts of theoretical problems. But then prompts can go back to being a single prompt, and we can increment a counter on the document object and so on.

I guess we can do the to-render-object counter thing and keep prompts single-render in a for loop in llm_map, which is probably cleaner. I'll try it and see.

-    prompts: list[Sequence[RenderedPrompt]],
+    prompts: list[RenderedPrompt],
     llm: LLM,
     llm_mode: LLMMode,
-    is_done: Callable[[str], bool] = lambda s: True,
 ) -> list[tuple[str, int]]:
     if llm_mode == LLMMode.SYNC:
         res = []
-        for piter in prompts:
-            s = ""
-            i = -1
-            for p in piter:
-                i += 1
-                s = llm.generate(prompt=p)
-                if is_done(s):
-                    break
-            res.append((s, i))
+        for p in prompts:
+            s = llm.generate(prompt=p)
+            res.append(s)
         return res
     elif llm_mode == LLMMode.ASYNC:
         raise NotImplementedError("Haven't done async yet")
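The inner loop that used to live in _infer_prompts (iterating a sequence of prompts per document until is_done) moves into llm_map below, driven by a counter stored on the document. A runnable toy of that shape, per the review thread above; every name here (Doc, render, fake_llm) is an illustrative stand-in, not the sycamore API:

    # Toy of the counter-on-the-document retry loop implemented below.
    from dataclasses import dataclass, field

    @dataclass
    class Doc:
        elements: list[str]
        properties: dict = field(default_factory=dict)

    def render(d: Doc) -> str:
        # Render only the batch selected by the per-document counter;
        # "" plays the role of an empty RenderedPrompt (batches exhausted).
        i = d.properties["i"]
        return d.elements[i] if i < len(d.elements) else ""

    def fake_llm(prompt: str) -> str:
        # Pretend the model only finds the entity in a batch containing "answer".
        return "found" if "answer" in prompt else "None"

    def llm_map(docs: list[Doc], max_tries: int = 5) -> list[Doc]:
        for d in docs:
            d.properties["i"] = 0
        valid = [False] * len(docs)
        tries = 0
        while not all(valid) and tries < max_tries:
            tries += 1
            rendered = [render(d) for v, d in zip(valid, docs) if not v]
            if all(r == "" for r in rendered):
                break  # every pending doc has run out of batches
            results = [fake_llm(r) for r in rendered]
            ri = 0
            for i, d in enumerate(docs):
                if valid[i]:
                    continue
                d.properties["entity"] = results[ri]
                valid[i] = results[ri] != "None"  # the validate step
                if not valid[i]:
                    d.properties["i"] += 1  # advance to the next batch
                ri += 1
        return docs

    docs = llm_map([Doc(["noise", "answer here"]), Doc(["answer up front"])])
    print([(d.properties["entity"], d.properties["i"]) for d in docs])
    # -> [('found', 1), ('found', 0)]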
@@ -73,27 +66,47 @@ def __init__(
         output_field: str,
         llm: LLM,
         llm_mode: LLMMode = LLMMode.SYNC,
-        postprocess_fn: Callable[[Document, int], Document] = lambda d, i: d,
+        iteration_var: Optional[str] = None,
+        validate: Callable[[Document], bool] = lambda d: True,
+        max_tries: int = 5,
         **kwargs,
     ):
         self._prompt = prompt
         self._validate_prompt()
         self._output_field = output_field
         self._llm = llm
         self._llm_mode = llm_mode
-        self._postprocess_fn = postprocess_fn
+        self._iteration_var = iteration_var
+        self._validate = validate
+        self._max_tries = max_tries
         super().__init__(child, f=self.llm_map, **kwargs)
 
     def llm_map(self, documents: list[Document]) -> list[Document]:
-        rendered_inc = [self._prompt.render_document(d) for d in documents]
-        rendered = _as_sequences(rendered_inc)
-        results = _infer_prompts(rendered, self._llm, self._llm_mode, self._prompt.is_done)
-        postprocessed = []
-        for d, (r, i) in zip(documents, results):
-            d.properties[self._output_field] = r
-            new_d = self._postprocess_fn(d, i)
-            postprocessed.append(new_d)
-        return postprocessed
+        if self._iteration_var is not None:
+            for d in documents:
+                d.properties[self._iteration_var] = 0
+
+        valid = [False] * len(documents)
+        tries = 0
+        while not all(valid) and tries < self._max_tries:
+            tries += 1
+            rendered = [self._prompt.render_document(d) for v, d in zip(valid, documents) if not v]
+            if sum([0, *(len(r.messages) for r in rendered)]) == 0:
+                break
+            results = _infer_prompts(rendered, self._llm, self._llm_mode)
+            ri = 0
+            for i in range(len(documents)):
+                if valid[i]:
+                    continue
+                documents[i].properties[self._output_field] = results[ri]
+                valid[i] = self._validate(documents[i])
+                ri += 1
+                if self._iteration_var is not None and not valid[i]:
+                    documents[i].properties[self._iteration_var] += 1
+            if self._iteration_var is None:
+                break
+
+        return documents
 
     def _validate_prompt(self):
         doc = Document()
@@ -143,30 +156,56 @@ def __init__(
         output_field: str,
         llm: LLM,
         llm_mode: LLMMode = LLMMode.SYNC,
-        postprocess_fn: Callable[[Element, int], Element] = lambda e, i: e,
+        iteration_var: Optional[str] = None,
+        validate: Callable[[Element], bool] = lambda d: True,
+        max_tries: int = 5,
         **kwargs,
     ):
         self._prompt = prompt
         self._validate_prompt()
         self._output_field = output_field
         self._llm = llm
         self._llm_mode = llm_mode
-        self._postprocess_fn = postprocess_fn
+        self._iteration_var = iteration_var
+        self._validate = validate
+        self._max_tries = max_tries
         super().__init__(child, f=self.llm_map_elements, **kwargs)
 
     def llm_map_elements(self, documents: list[Document]) -> list[Document]:
-        rendered = [(d, e, self._prompt.render_element(e, d)) for d in documents for e in d.elements]
-        results = _infer_prompts(
-            _as_sequences([p for _, _, p in rendered]), self._llm, self._llm_mode, self._prompt.is_done
-        )
-        new_elts = []
+        elt_doc_pairs = [(e, d) for d in documents for e in d.elements]
+        if self._iteration_var is not None:
+            for e, _ in elt_doc_pairs:
+                e.properties[self._iteration_var] = 0
+
+        valid = [False] * len(elt_doc_pairs)
+        tries = 0
+        while not all(valid) and tries < self._max_tries:
+            tries += 1
+            rendered = [self._prompt.render_element(e, d) for v, (e, d) in zip(valid, elt_doc_pairs) if not v]
+            if sum([0, *(len(r.messages) for r in rendered)]) == 0:
+                break
+            results = _infer_prompts(rendered, self._llm, self._llm_mode)
+            ri = 0
+            for i in range(len(elt_doc_pairs)):
+                if valid[i]:
+                    continue
+                print(ri)
+                elt, doc = elt_doc_pairs[i]
+                elt.properties[self._output_field] = results[ri]
+                valid[i] = self._validate(elt)
+                ri += 1
+                if self._iteration_var is not None:
+                    elt.properties[self._iteration_var] += 1
+            if self._iteration_var is None:
+                break
 
         last_doc = None
-        for (r, i), (d, e, _) in zip(results, rendered):
+        new_elts = []
+        for e, d in elt_doc_pairs:
             if last_doc is not None and last_doc.doc_id != d.doc_id:
                 last_doc.elements = new_elts
                 new_elts = []
-            e.properties[self._output_field] = r
-            new_elts.append(self._postprocess_fn(e, i))
+            new_elts.append(e)
             last_doc = d
         if last_doc is not None:
             last_doc.elements = new_elts
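For reference, a hypothetical companion test in the style of test_validate above, showing how iteration_var turns validate into a real retry. FakeDocPrompt and FakeLLM are the existing test doubles in test_base_llm.py (their definitions are not shown in this diff), and it assumes FakeDocPrompt always renders a non-empty prompt — not part of this PR:

    # Hypothetical addition to test_base_llm.py; assumes FakeDocPrompt/FakeLLM
    # from that file and a prompt that always renders non-empty messages.
    def test_validate_with_iteration_var(self):
        prompt = FakeDocPrompt()
        llm = FakeLLM()
        doc = Document({"text_representation": "ooga"})
        count = 0

        def valfn(d: Document) -> bool:
            nonlocal count
            count += 1
            return count > 1  # fail the first response, accept the second

        # Without iteration_var, llm_map makes exactly one pass, so a failing
        # validate is never retried; setting it enables the retry loop.
        map = LLMMap(None, prompt, "out", llm, iteration_var="i", validate=valfn, max_tries=3)
        _ = map.llm_map([doc])

        assert count == 2
        assert doc.properties["i"] == 1  # bumped once by the failed first try

Note the asymmetry in the new code: when iteration_var is None, llm_map breaks out of the retry loop after a single pass, so retries only happen for prompts that consume the iteration counter (e.g. ElementListIterPrompt).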