|
@@ -0,0 +1,219 @@
|
|
|
+import { FmodeChatCompletion } from 'fmode-ng';
|
|
|
+import { FlowTask, FieldSchema, FlowTaskOptions } from '../../flow.task';
|
|
|
+import { Subject } from 'rxjs';
|
|
|
+import { takeUntil } from 'rxjs/operators';
|
|
|
+
|
|
|
+export interface JsonCompletionOptions extends FlowTaskOptions {
|
|
|
+ promptTemplate: string;
|
|
|
+ modelOptions?: Record<string, any>;
|
|
|
+ strictPromptValidation?: boolean;
|
|
|
+}
|
|
|
+
|
|
|
+export class TaskCompletionJson extends FlowTask {
|
|
|
+ promptTemplate: string;
|
|
|
+ modelOptions: Record<string, any>;
|
|
|
+ strictPromptValidation: boolean;
|
|
|
+ destroy$ = new Subject<void>();
|
|
|
+
|
|
|
+ constructor(options: JsonCompletionOptions) {
|
|
|
+ super({
|
|
|
+ title: options.title || 'JSON Completion Task',
|
|
|
+ output: options.output, // Only output schema is needed
|
|
|
+ initialData: options.initialData
|
|
|
+ });
|
|
|
+
|
|
|
+ this.promptTemplate = options.promptTemplate;
|
|
|
+ this.modelOptions = options.modelOptions || {};
|
|
|
+ this.strictPromptValidation = options.strictPromptValidation ?? true;
|
|
|
+ }
|
|
|
+
|
|
|
+ override async handle(): Promise<void> {
|
|
|
+ // 1. Validate all required prompt variables exist in task.data
|
|
|
+ this.validatePromptVariables();
|
|
|
+
|
|
|
+ // 2. Prepare the prompt with variable substitution
|
|
|
+ const fullPrompt = this.renderPromptTemplate();
|
|
|
+
|
|
|
+ // 3. Call the LLM for completion
|
|
|
+ await this.callModelCompletion(fullPrompt);
|
|
|
+ }
|
|
|
+
|
|
|
+ validatePromptVariables(): void {
|
|
|
+ const requiredVariables = this.extractPromptVariables();
|
|
|
+ const missingVariables: string[] = [];
|
|
|
+ const undefinedVariables: string[] = [];
|
|
|
+
|
|
|
+ requiredVariables.forEach(variable => {
|
|
|
+ if (!(variable in this.data)) {
|
|
|
+ missingVariables.push(variable);
|
|
|
+ } else if (this.data[variable] === undefined) {
|
|
|
+ undefinedVariables.push(variable);
|
|
|
+ }
|
|
|
+ });
|
|
|
+
|
|
|
+ const errors: string[] = [];
|
|
|
+
|
|
|
+ if (missingVariables.length > 0) {
|
|
|
+ errors.push(`Missing required variables in task.data: ${missingVariables.join(', ')}`);
|
|
|
+ }
|
|
|
+
|
|
|
+ if (undefinedVariables.length > 0) {
|
|
|
+ errors.push(`Variables with undefined values: ${undefinedVariables.join(', ')}`);
|
|
|
+ }
|
|
|
+
|
|
|
+ if (errors.length > 0 && this.strictPromptValidation) {
|
|
|
+ throw new Error(`Prompt variable validation failed:\n${errors.join('\n')}`);
|
|
|
+ } else if (errors.length > 0) {
|
|
|
+ console.warn(`Prompt variable warnings:\n${errors.join('\n')}`);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ extractPromptVariables(): string[] {
|
|
|
+ const matches = this.promptTemplate.match(/\{\{\w+\}\}/g) || [];
|
|
|
+ const uniqueVariables = new Set<string>();
|
|
|
+
|
|
|
+ matches.forEach(match => {
|
|
|
+ const key = match.replace(/\{\{|\}\}/g, '');
|
|
|
+ uniqueVariables.add(key);
|
|
|
+ });
|
|
|
+
|
|
|
+ return Array.from(uniqueVariables);
|
|
|
+ }
|
|
|
+
|
|
|
+ renderPromptTemplate(): string {
|
|
|
+ let result = this.promptTemplate;
|
|
|
+ const variables = this.extractPromptVariables();
|
|
|
+
|
|
|
+ variables.forEach(variable => {
|
|
|
+ if (this.data[variable] !== undefined) {
|
|
|
+ result = result.replace(new RegExp(`\\{\\{${variable}\\}\\}`, 'g'), this.data[variable]);
|
|
|
+ }
|
|
|
+ });
|
|
|
+
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+
|
|
|
+ async callModelCompletion(prompt: string): Promise<void> {
|
|
|
+ return new Promise((resolve, reject) => {
|
|
|
+ const messages = [{
|
|
|
+ role: "user",
|
|
|
+ content: prompt
|
|
|
+ }];
|
|
|
+
|
|
|
+ const completion = new FmodeChatCompletion(messages);
|
|
|
+
|
|
|
+ let accumulatedContent = '';
|
|
|
+
|
|
|
+ completion.sendCompletion({
|
|
|
+ ...this.modelOptions,
|
|
|
+ onComplete: (message: any) => {
|
|
|
+ console.log("onComplete", message);
|
|
|
+ }
|
|
|
+ })
|
|
|
+ .pipe(takeUntil(this.destroy$))
|
|
|
+ .subscribe({
|
|
|
+ next: (message: any) => {
|
|
|
+ if (message.content && typeof message.content === 'string') {
|
|
|
+ accumulatedContent = message.content;
|
|
|
+ this.setProgress(0.3 + (accumulatedContent.length / 1000) * 0.7);
|
|
|
+ }
|
|
|
+
|
|
|
+ if (message.complete) {
|
|
|
+ try {
|
|
|
+ const parsed = this.parseAndValidateResponse(accumulatedContent);
|
|
|
+ this.updateOutputData(parsed);
|
|
|
+ this.setProgress(1);
|
|
|
+ resolve();
|
|
|
+ } catch (error) {
|
|
|
+ this.handleError(error as Error);
|
|
|
+ reject(error);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ },
|
|
|
+ error: (error) => {
|
|
|
+ this.handleError(error);
|
|
|
+ reject(error);
|
|
|
+ }
|
|
|
+ });
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
+ parseAndValidateResponse(response: string): Record<string, any> {
|
|
|
+ const jsonStart = response.indexOf('{');
|
|
|
+ const jsonEnd = response.lastIndexOf('}') + 1;
|
|
|
+
|
|
|
+ if (jsonStart === -1 || jsonEnd === -1) {
|
|
|
+ throw new Error('Invalid JSON response format');
|
|
|
+ }
|
|
|
+
|
|
|
+ const jsonStr = response.slice(jsonStart, jsonEnd);
|
|
|
+ let parsedData: Record<string, any>;
|
|
|
+
|
|
|
+ try {
|
|
|
+ parsedData = JSON.parse(jsonStr);
|
|
|
+ } catch (e) {
|
|
|
+ throw new Error(`Failed to parse JSON response: ${(e as Error).message}`);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Validate against output schema
|
|
|
+ this.validateOutputData(parsedData);
|
|
|
+
|
|
|
+ return parsedData;
|
|
|
+ }
|
|
|
+
|
|
|
+ validateOutputData(data: Record<string, any>): void {
|
|
|
+ if (!this.outputSchema || this.outputSchema.length === 0) {
|
|
|
+ return; // No validation needed if no output schema defined
|
|
|
+ }
|
|
|
+
|
|
|
+ const errors: string[] = [];
|
|
|
+
|
|
|
+ this.outputSchema.forEach(field => {
|
|
|
+ const value = data[field.name];
|
|
|
+
|
|
|
+ if (field.required && value === undefined) {
|
|
|
+ errors.push(`Missing required field in response: ${field.name}`);
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (value !== undefined && !this.checkType(value, field.type)) {
|
|
|
+ errors.push(`${field.name} has wrong type, expected ${field.type}, got ${this.getType(value)}`);
|
|
|
+ }
|
|
|
+ });
|
|
|
+
|
|
|
+ if (errors.length > 0) {
|
|
|
+ throw new Error(`Output validation failed:\n${errors.join('\n')}`);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ updateOutputData(parsedData: Record<string, any>): void {
|
|
|
+ Object.entries(parsedData).forEach(([key, value]) => {
|
|
|
+ this.updateData(key, value);
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
+ handleError(error: Error): void {
|
|
|
+ this.updateData('error', {
|
|
|
+ message: error.message,
|
|
|
+ stack: error.stack,
|
|
|
+ timestamp: new Date().toISOString()
|
|
|
+ });
|
|
|
+ this._status = 'failed';
|
|
|
+ }
|
|
|
+
|
|
|
+ override checkType(value: any, expected: any): boolean {
|
|
|
+ const actualType = this.getType(value);
|
|
|
+ return expected === 'any' || actualType === expected;
|
|
|
+ }
|
|
|
+
|
|
|
+ override getType(value: any): any {
|
|
|
+ if (Array.isArray(value)) return 'array';
|
|
|
+ if (value === null) return 'object';
|
|
|
+ return typeof value as any;
|
|
|
+ }
|
|
|
+
|
|
|
+ onDestroy(): void {
|
|
|
+ this.destroy$.next();
|
|
|
+ this.destroy$.complete();
|
|
|
+ }
|
|
|
+}
|