Skip to content

Commit b5277b3

Browse files
authored
[BUGFIX] Custom endpoint and form state (huggingface#398)
* store endpoint usage flag in DB
* read endpoint usage from DB info
* show selected model properly
* create process with endpoint-related info
* remove extra condition
1 parent e45d52a commit b5277b3

File tree

7 files changed

+40
-27
lines changed

7 files changed

+40
-27
lines changed

src/features/add-column/form/execution-form.tsx

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ export const ExecutionForm = component$<SidebarProps>(
7878

7979
const allModels = useContext<Model[]>(modelsContext);
8080

81-
let {
81+
const {
8282
DEFAULT_MODEL,
8383
DEFAULT_MODEL_PROVIDER,
8484
modelEndpointEnabled,
@@ -103,7 +103,8 @@ export const ExecutionForm = component$<SidebarProps>(
103103
const selectedModelId = useSignal<string>('');
104104
const selectedProvider = useSignal<string>('');
105105

106-
const endpointURLSelected = useSignal(modelEndpointEnabled);
106+
const enableCustomEndpoint = useSignal(modelEndpointEnabled);
107+
const endpointURLSelected = useSignal(false);
107108

108109
const onSelectedVariables = $((variables: { id: string }[]) => {
109110
columnsReferences.value = variables.map((v) => v.id);
@@ -115,11 +116,6 @@ export const ExecutionForm = component$<SidebarProps>(
115116
return column.type === 'image';
116117
});
117118

118-
if (isImageColumn.value) {
119-
// Currently, we custom endpoint only for text models
120-
modelEndpointEnabled = false;
121-
}
122-
123119
const modelProviders = useComputed$(() => {
124120
const model = models.value.find(
125121
(m: Model) =>
@@ -166,19 +162,28 @@ export const ExecutionForm = component$<SidebarProps>(
166162
}
167163
});
168164

169-
useTask$(() => {
165+
useTask$(({ track }) => {
166+
track(column);
167+
170168
variables.value = columns.value
171169
.filter((c) => c.id !== column.id && !hasBlobContent(c))
172170
.map((c) => ({
173171
id: c.id,
174172
name: c.name,
175173
}));
176174

175+
if (isImageColumn.value) {
176+
// Currently, we custom endpoint only for text models
177+
enableCustomEndpoint.value = false;
178+
}
179+
177180
const { process } = column;
178181
if (!process) return;
179182

180183
prompt.value = process.prompt;
181184
searchOnWeb.value = process.searchEnabled || false;
185+
endpointURLSelected.value =
186+
(enableCustomEndpoint.value && process.useEndpointURL) || false;
182187

183188
if (process.modelName) {
184189
// If there's a previously selected model, use that
@@ -385,7 +390,7 @@ export const ExecutionForm = component$<SidebarProps>(
385390
<p class="text-neutral-500 underline">
386391
{selectedModelId.value}
387392
</p>
388-
{modelEndpointEnabled && !endpointURLSelected.value && (
393+
{!endpointURLSelected.value && (
389394
<Tooltip text="Reset default model">
390395
<LuUndo2
391396
class="w-4 h-4 rounded-full gap-2 text-neutral-500 cursor-pointer hover:bg-neutral-200"

src/features/table/components/header/cell-generation.tsx

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
import { component$, useContext } from '@builder.io/qwik';
1+
import { component$ } from '@builder.io/qwik';
22
import { LuEgg } from '@qwikest/icons/lucide';
33
import { Tooltip } from '~/components/ui/tooltip/tooltip';
44
import { useGenerateColumn } from '~/features/execution';
5-
import { configContext } from '~/routes/home/layout';
65
import { type Column, TEMPORAL_ID } from '~/state';
76

87
export const CellGeneration = component$<{ column: Column }>(({ column }) => {
@@ -11,10 +10,6 @@ export const CellGeneration = component$<{ column: Column }>(({ column }) => {
1110
if (column.id === TEMPORAL_ID || column.kind !== 'dynamic') return null;
1211
if (!column.process) return null;
1312

14-
const { modelEndpointEnabled } = useContext(configContext);
15-
16-
column.process.useEndpointURL = modelEndpointEnabled;
17-
1813
return (
1914
<Tooltip text="Regenerate">
2015
<div

src/features/table/table-body.tsx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,6 @@ export const TableBody = component$(() => {
219219
...column,
220220
process: {
221221
...column.process!,
222-
useEndpointURL: modelEndpointEnabled,
223222
offset,
224223
limit,
225224
},

src/services/db/models/process.ts

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@ export class ProcessModel extends Model<
1919
declare id: CreationOptional<string>;
2020
declare prompt: string;
2121
declare modelName: string;
22-
declare searchEnabled: boolean;
22+
2323
declare modelProvider: string;
24+
declare useCustomEndpoint: boolean;
25+
26+
declare searchEnabled: boolean;
2427
declare columnId: ForeignKey<ColumnModel['id']>;
2528

2629
declare referredColumns: NonAttribute<ColumnModel[]>; // This is a virtual attribute
@@ -40,18 +43,22 @@ ProcessModel.init(
4043
defaultValue: DataTypes.UUIDV4,
4144
primaryKey: true,
4245
},
43-
modelName: {
46+
prompt: {
4447
type: DataTypes.STRING,
4548
allowNull: false,
4649
},
47-
modelProvider: {
50+
modelName: {
4851
type: DataTypes.STRING,
4952
allowNull: false,
5053
},
51-
prompt: {
54+
modelProvider: {
5255
type: DataTypes.STRING,
5356
allowNull: false,
5457
},
58+
useCustomEndpoint: {
59+
type: DataTypes.BOOLEAN,
60+
defaultValue: false,
61+
},
5562
searchEnabled: {
5663
type: DataTypes.BOOLEAN,
5764
defaultValue: false,

src/services/repository/columns.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@ export const modelToColumn = (model: ColumnModel): Column => {
2626
columnsReferences: (model.process?.referredColumns ?? []).map(
2727
(columnRef) => columnRef.id,
2828
),
29+
prompt: model.process?.prompt ?? '',
2930
modelName: model.process?.modelName ?? '',
3031
modelProvider: model.process?.modelProvider ?? '',
31-
prompt: model.process?.prompt ?? '',
32+
useEndpointURL: model.process?.useCustomEndpoint ?? false,
3233
searchEnabled: model.process?.searchEnabled,
3334
updatedAt: model.process?.updatedAt,
3435
},

src/services/repository/processes.ts

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ import { ProcessColumnModel, ProcessModel } from '~/services/db/models';
22
import type { Process } from '~/state';
33

44
export interface CreateProcess {
5+
prompt: string;
56
modelName: string;
67
modelProvider: string;
7-
prompt: string;
8+
useEndpointURL?: boolean;
89
searchEnabled: boolean;
910
columnsReferences?: string[];
1011
}
@@ -19,9 +20,10 @@ export const createProcess = async ({
1920
};
2021
}): Promise<Process> => {
2122
const model = await ProcessModel.create({
23+
prompt: process.prompt,
2224
modelName: process.modelName,
2325
modelProvider: process.modelProvider,
24-
prompt: process.prompt,
26+
useCustomEndpoint: process.useEndpointURL ?? false,
2527
searchEnabled: process.searchEnabled,
2628
columnId: column.id,
2729
});
@@ -37,9 +39,10 @@ export const createProcess = async ({
3739

3840
return {
3941
id: model.id,
42+
prompt: model.prompt,
4043
modelName: model.modelName,
4144
modelProvider: model.modelProvider,
42-
prompt: model.prompt,
45+
useEndpointURL: model.useCustomEndpoint,
4346
searchEnabled: model.searchEnabled,
4447
columnsReferences: process?.columnsReferences || [],
4548
updatedAt: model.updatedAt,
@@ -55,9 +58,10 @@ export const updateProcess = async (process: Process): Promise<Process> => {
5558

5659
model.changed('updatedAt', true);
5760
model.set({
61+
prompt: process.prompt,
5862
modelName: process.modelName,
5963
modelProvider: process.modelProvider,
60-
prompt: process.prompt,
64+
useCustomEndpoint: process.useEndpointURL ?? false,
6165
searchEnabled: process.searchEnabled,
6266
});
6367

@@ -74,9 +78,10 @@ export const updateProcess = async (process: Process): Promise<Process> => {
7478

7579
return {
7680
id: model.id,
81+
prompt: model.prompt,
7782
modelName: model.modelName,
7883
modelProvider: model.modelProvider,
79-
prompt: model.prompt,
84+
useEndpointURL: model.useCustomEndpoint,
8085
searchEnabled: model.searchEnabled,
8186
columnsReferences: process.columnsReferences,
8287
updatedAt: model.updatedAt,

src/usecases/run-autodataset.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,9 +385,10 @@ async function createDatasetWithColumns(
385385

386386
const process = await createProcess({
387387
process: {
388+
prompt: column.prompt,
388389
modelName: processModelName,
389390
modelProvider: processModelProvider,
390-
prompt: column.prompt,
391+
useEndpointURL: textGeneration.endpointUrl !== undefined && !isImage,
391392
searchEnabled,
392393
columnsReferences: columnReferences.map((ref) => {
393394
const refIndex = columnNames.indexOf(ref);

0 commit comments

Comments
 (0)