
Commit e650025

Handle all truncation cases (#637)
* populate is_truncated if max_tokens or max_model_len is hit
* handle overlong prompt error gracefully (do not propagate as error)
* parse is_truncated from response
* show summary in vf-eval
* also set is_truncated on bad request error
* fix tests
* fix ty
* show type repr
* do not set prompt too long from req id
* also remove in docs
1 parent ba0f2ba commit e650025
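In short, a rollout now counts as truncated in two situations the commit message lists: the completion was cut off because max_tokens or max_model_len was hit, or the prompt itself was too long to run at all. A minimal Python sketch of that rule (the helper names here are hypothetical illustrations, not code from this commit):

def completion_was_cut_off(finish_reason: str | None) -> bool:
    # OpenAI-style responses report finish_reason == "length" when generation
    # stops early because max_tokens or the model's context limit was reached.
    return finish_reason == "length"


def mark_truncation(state: dict, prompt_too_long: bool, finish_reason: str | None) -> None:
    # Hypothetical helper: either condition flags the rollout as truncated
    # instead of being surfaced as an error.
    if prompt_too_long:
        state["prompt_too_long"] = True
    state["is_truncated"] = prompt_too_long or completion_was_cut_off(finish_reason)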

File tree

7 files changed: +58 -6 lines changed

notes/TRAJECTORIES.md

Lines changed: 0 additions & 3 deletions
@@ -350,9 +350,6 @@ async def add_model_response(
         response: ModelResponse,
     ):
         """Add a model response as a trajectory step."""
-        if response is not None and response.id == "overlong-prompt":
-            state["prompt_too_long"] = True
-            return
         completion_messages = await parse_response_messages(response, self.message_type)
         tokens = await parse_response_tokens(response, self.message_type)
         trajectory_step = TrajectoryStep(

tests/test_eval_utils.py

Lines changed: 6 additions & 0 deletions
@@ -64,6 +64,8 @@ def test_print_results_rollout_indexing(capsys):
         example_id=example_ids,
         reward=rewards,
         metrics={"test_metric": metric_values},
+        is_truncated=[False] * 6,
+        stop_conditions=[None] * 6,
         metadata=_make_metadata(num_examples, rollouts_per_example),
     )

@@ -102,6 +104,8 @@ def test_print_results_single_rollout(capsys):
         example_id=example_ids,
         reward=rewards,
         metrics={},
+        is_truncated=[False] * 3,
+        stop_conditions=[None] * 3,
         metadata=_make_metadata(num_examples, rollouts_per_example),
     )

@@ -134,6 +138,8 @@ def test_print_results_three_rollouts(capsys):
         example_id=example_ids,
         reward=rewards,
         metrics={},
+        is_truncated=[False] * 6,
+        stop_conditions=[None] * 6,
         metadata=_make_metadata(num_examples, rollouts_per_example),
     )

verifiers/envs/environment.py

Lines changed: 8 additions & 0 deletions
@@ -565,6 +565,7 @@ async def init_state(
         state["model"] = model
         state["sampling_args"] = sampling_args
         state["is_completed"] = False
+        state["is_truncated"] = False
         state["oai_tools"] = None
         if "info" in state and hasattr(state["info"], "oai_tools"):
             state["oai_tools"] = state["info"]["oai_tools"]

@@ -621,6 +622,9 @@ async def _teardown(self):
     async def _render_stop(self, state: State, condition) -> bool:
         if await condition(state):
             state["is_completed"] = True
+            state["is_truncated"] = state.get("is_truncated", False) or any(
+                step.get("is_truncated", False) for step in state.get("trajectory", [])
+            )
             state["stop_condition"] = condition.__name__
             if state.get("stop_condition") == "has_error":
                 self.logger.error(

@@ -724,6 +728,8 @@ def _prepare_rollout_results(
         infos = [state.get("info", {}) for state in all_states]
         example_ids = [state.get("example_id", 0) for state in all_states]
         rewards = [state.get("reward", 0.0) for state in all_states]
+        stop_conditions = [state.get("stop_condition", None) for state in all_states]
+        is_truncated = [state.get("is_truncated", False) for state in all_states]

         metrics: dict[str, list[float]] = {}
         for state in all_states:

@@ -767,6 +773,8 @@ def _prepare_rollout_results(
             example_id=example_ids,
             reward=rewards,
             metrics=metrics,
+            stop_conditions=stop_conditions,
+            is_truncated=is_truncated,
             metadata=metadata,
         )
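A tiny standalone illustration of the aggregation added to _render_stop above: a rollout is flagged as truncated when any of its trajectory steps was truncated (the state dict is hand-made example data, not real rollout output):

state = {
    "is_truncated": False,
    "trajectory": [
        {"is_truncated": False},
        {"is_truncated": True},  # one step hit the token limit
    ],
}

state["is_truncated"] = state.get("is_truncated", False) or any(
    step.get("is_truncated", False) for step in state.get("trajectory", [])
)
print(state["is_truncated"])  # True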

verifiers/envs/multiturn_env.py

Lines changed: 11 additions & 3 deletions
@@ -14,6 +14,7 @@
 )
 from verifiers.utils.message_utils import concat_messages
 from verifiers.utils.response_utils import (
+    parse_is_truncated,
     parse_response_messages,
     parse_response_tokens,
 )

@@ -68,19 +69,22 @@ async def add_model_response(
         prompt_messages: Messages,
         response: ModelResponse,
     ):
-        if response is not None and response.id == "overlong-prompt":
-            state["prompt_too_long"] = True
         completion_messages = await parse_response_messages(response, self.message_type)
+        response_is_truncated = await parse_is_truncated(response, self.message_type)
         tokens = await parse_response_tokens(
             response, self.message_type, self.max_seq_len
         )
+        is_truncated = response_is_truncated or (
+            tokens is not None and bool(tokens.get("is_truncated"))
+        )
         trajectory_step = TrajectoryStep(
             prompt=prompt_messages,
             completion=completion_messages,
             response=response,
             tokens=tokens,
             reward=None,
             advantage=None,
+            is_truncated=is_truncated,
             extras={},
         )
         trajectory_step["completion"] = completion_messages

@@ -107,5 +111,9 @@ async def rollout(
                 response = await self.get_model_response(state, prompt_messages)
                 await self.add_model_response(state, prompt_messages, response)
             except vf.Error as e:
-                state["error"] = e
+                if isinstance(e, vf.OverlongPromptError):
+                    state["prompt_too_long"] = True
+                    state["is_truncated"] = True
+                else:
+                    state["error"] = e
         return state
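A small standalone sketch of the error-handling pattern introduced in rollout() above: an overlong prompt is recorded on the state as truncation rather than propagated as an error (importing verifiers as vf and raising OverlongPromptError with a plain message are assumptions made for illustration):

import verifiers as vf

state: dict = {}
try:
    # Stand-in for the real get_model_response() call; assumes the exception
    # accepts a plain message argument.
    raise vf.OverlongPromptError("prompt exceeds the model's context window")
except vf.Error as e:
    if isinstance(e, vf.OverlongPromptError):
        state["prompt_too_long"] = True
        state["is_truncated"] = True
    else:
        state["error"] = e

print(state)  # {'prompt_too_long': True, 'is_truncated': True}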

verifiers/types.py

Lines changed: 4 additions & 0 deletions
@@ -67,6 +67,7 @@ class TrajectoryStep(TypedDict):
     tokens: TrajectoryStepTokens | None
     reward: float | None
     advantage: float | None
+    is_truncated: bool
     extras: dict[str, Any]


@@ -99,6 +100,7 @@ class State(dict):
     sampling_args: SamplingArgs | None
     # created during rollout
     is_completed: bool
+    is_truncated: bool
     stop_condition: str | None
     oai_tools: list[ChatCompletionToolParam]
     trajectory: list[TrajectoryStep]

@@ -167,6 +169,8 @@ class GenerateOutputs(TypedDict):
     example_id: list[int]
     reward: list[float]
     metrics: dict[str, list[float]]
+    stop_conditions: list[str | None]
+    is_truncated: list[bool]
     metadata: GenerateMetadata

verifiers/utils/eval_utils.py

Lines changed: 14 additions & 0 deletions
@@ -2,6 +2,7 @@
 import json
 import logging
 import time
+from collections import Counter
 from contextlib import contextmanager
 from pathlib import Path
 from typing import cast

@@ -97,6 +98,19 @@ def print_results(results: GenerateOutputs, num_samples: int = 1):
         out = f"r{i + 1}: {trials}"
         print(out)

+    print("Info:")
+    print(
+        f"is_truncated: avg - {np.mean(results['is_truncated']):.3f}, std - {np.std(results['is_truncated']):.3f}"
+    )
+    print(
+        f"stop_conditions: {', '.join([f'{k}={v}' for k, v in Counter(results['stop_conditions']).items()])}"
+    )
+    errors = [e for e in errors if e is not None]
+    if errors:
+        print(
+            f"errors: {', '.join([f'{k}: {v / len(errors):.3f}' for k, v in Counter([type(e).__name__ for e in errors]).items()])}"
+        )
+

 async def run_evaluation(config: EvalConfig) -> GenerateOutputs:
     # set up AsyncOpenAI client with high limits to prevent timeouts
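For reference, a self-contained sketch of the summary block that print_results now appends (the values and the is_completed stop-condition name are made up for the example; only has_error appears elsewhere in this diff):

from collections import Counter

import numpy as np

# Hand-made stand-ins for results["is_truncated"] and results["stop_conditions"].
is_truncated = [False, True, False, False]
stop_conditions = ["is_completed", "is_completed", "has_error", "is_completed"]

print(f"is_truncated: avg - {np.mean(is_truncated):.3f}, std - {np.std(is_truncated):.3f}")
print(f"stop_conditions: {', '.join(f'{k}={v}' for k, v in Counter(stop_conditions).items())}")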

verifiers/utils/response_utils.py

Lines changed: 15 additions & 0 deletions
@@ -125,3 +125,18 @@ async def parse_response_messages(
         response_text = response.choices[0].text or ""
         completion_messages = str(response_text)
     return completion_messages
+
+
+async def parse_is_truncated(
+    response: ModelResponse, message_type: MessageType
+) -> bool:
+    if message_type == "chat":
+        assert isinstance(response, ChatCompletion)
+        assert len(response.choices) == 1, "Response should always have one choice"
+        return response.choices[0].finish_reason == "length"
+    elif message_type == "completion":
+        assert isinstance(response, Completion)
+        assert len(response.choices) == 1, "Response should always have one choice"
+        return response.choices[0].finish_reason == "length"
+    else:
+        return False
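A quick usage sketch for the new parse_is_truncated helper (the hand-built ChatCompletion below is purely illustrative; the openai response classes are the ones the library already consumes):

import asyncio

from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice

from verifiers.utils.response_utils import parse_is_truncated

# Minimal chat response whose generation stopped because the token limit was hit.
response = ChatCompletion(
    id="example",
    object="chat.completion",
    created=0,
    model="example-model",
    choices=[
        Choice(
            index=0,
            finish_reason="length",  # "length" means max_tokens / max_model_len was reached
            message=ChatCompletionMessage(role="assistant", content="cut-off output ..."),
        )
    ],
)

print(asyncio.run(parse_is_truncated(response, "chat")))  # True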
