import { IModelParams } from '@ai21/studio-store';
import { api } from '@ai21/studio-api';
import { get } from 'lodash';

/**
 * Runs queued completion requests with at most `maxRequests` in flight at once.
 *
 * Usage: call `addRequest` for each item, then `start(callbacks)`. The queue is
 * drained by a polling loop (`process`). Each result is handed to
 * `onResponse`, progress counts go to the optional `onProgress`, and the promise
 * returned by `start` resolves once the queue is empty and no request is running.
 */
export class CompletionQueue {
  /** Delay between scheduling passes while work is pending or running. */
  private static readonly POLL_INTERVAL_MS = 1000;
  /** Delay before the first scheduling pass after `start()`. */
  private static readonly START_DELAY_MS = 100;
  /**
   * Lodash path of the completion text inside the API response. Typed as
   * plain `string` so `get` uses its untyped-path overload.
   */
  private static readonly TEXT_PATH: string = 'data.completions[0].data.text';

  // Pending `{ id, prompt, params }` items pushed by addRequest.
  private queue: any[] = [];
  private runningRequests = 0;
  private completedRequests = 0;
  // Queue length captured when start() is called; reported by done().
  private total = 0;
  // Resolver of the promise returned by start().
  private resolve: any;
  private onProgress?: OnProgress;
  private onResponse: OnResponse = () => {};

  /**
   * @param maxRequests maximum number of requests allowed in flight at once.
   * @param setId forwarded to the API call as `setId`.
   * @param customModelsMap model-name lookup passed to `prepareRequestParams`.
   */
  constructor(
    private maxRequests: number,
    private setId: string,
    private customModelsMap: Json
  ) {}

  /** Appends one request to the pending queue; nothing is sent until `start()`. */
  addRequest(id: string, prompt: string, params: Partial<IModelParams>) {
    this.queue.push({ id, prompt, params });
  }

  /**
   * Sends one queued request and reports its result through `onResponse`.
   *
   * The text is read from `data.completions[0].data.text` and falls back to ''
   * when that path is missing. If building the params or the API call throws,
   * the error is logged and the item is also reported as ''. The running and
   * completed counters are therefore always updated, so `process()` can finish.
   *
   * @param queueItem an `{ id, prompt, params }` entry taken from the queue.
   */
  async fireRequest(queueItem: any) {
    const { id, prompt, params } = queueItem;
    this.runningRequests++;

    let text = '';
    try {
      const parsedParams = prepareRequestParams(
        prompt,
        params,
        this.customModelsMap
      );

      const response = await api.completion.complete(parsedParams, {
        appFeature: 'generationSet',
        setId: this.setId,
      });

      // NOTE(review): `.replace('\n', '')` used to be applied to the path
      // string (a no-op), so the text has always been returned verbatim. If
      // stripping a newline from the completion was intended, apply it to
      // `text` here. Confirm with the consumers first.
      text = get(response, CompletionQueue.TEXT_PATH, '');
    } catch (error) {
      console.error(`CompletionQueue: request "${id}" failed`, error);
    } finally {
      // Always release the slot. Otherwise process() never sees
      // runningRequests reach 0, and start()'s promise never resolves.
      this.completedRequests++;
      this.runningRequests--;
    }

    this.onResponse(id, text);

    if (this.onProgress) {
      this.onProgress(this.completedRequests);
    }
  }

  /** Reports final progress (the queue size at start) and resolves `start()`'s promise. */
  done() {
    if (this.onProgress) {
      this.onProgress(this.total);
    }

    if (this.resolve) {
      this.resolve();
    }
  }

  /**
   * Runs one scheduling pass. It fills the free concurrency slots from the
   * queue, then either finishes (queue empty and nothing running) or schedules
   * another pass.
   */
  process() {
    const availableSlots = this.maxRequests - this.runningRequests;

    if (availableSlots <= 0) {
      // Every slot is busy; poll again later without firing anything.
      setTimeout(() => this.process(), CompletionQueue.POLL_INTERVAL_MS);
      return;
    }

    for (const queueItem of this.queue.splice(0, availableSlots)) {
      // Fire-and-forget; concurrency is tracked via runningRequests.
      // fireRequest handles request failures itself, so this catch mainly
      // sees errors thrown by the onResponse/onProgress callbacks. Without it,
      // those would become unhandled promise rejections.
      this.fireRequest(queueItem).catch((error) => {
        console.error('CompletionQueue: response callback failed', error);
      });
    }

    if (this.queue.length === 0 && this.runningRequests === 0) {
      this.done();
      return;
    }

    setTimeout(() => this.process(), CompletionQueue.POLL_INTERVAL_MS);
  }

  /**
   * Registers the callbacks and begins draining the queue.
   *
   * @returns a promise that resolves (with no value) once every queued
   *   request has settled.
   */
  start(callbacks: Callbacks) {
    this.onResponse = callbacks.onResponse;
    this.onProgress = callbacks.onProgress;

    return new Promise<Json>((resolve) => {
      this.total = this.queue.length;
      this.resolve = resolve;
      setTimeout(() => this.process(), CompletionQueue.START_DELAY_MS);
    });
  }
}

/**
 * Progress hook. It receives the number of completed requests after each one
 * settles, and the total once the queue finishes.
 */
type OnProgress = (completedCount: number) => void;

/** Result hook: receives the completion text produced for the queued item `requestId`. */
type OnResponse = (requestId: string, text: string) => void;

/** Hooks supplied to `CompletionQueue.start`. */
interface Callbacks {
  onResponse: OnResponse;
  onProgress?: OnProgress;
}

/**
 * Builds the request body for `api.completion.complete` from a prompt and
 * model parameters.
 *
 * Every model parameter is optional. Missing fields fall back to the defaults
 * below (temperature/topP/penalties 0, maxTokens 1, no stop sequences, empty
 * model name), so a `Partial<IModelParams>` is accepted directly.
 *
 * @param prompt the raw prompt text; its length is sent as `promptLength`.
 * @param params model parameters, possibly partial.
 * @param customModelsMap lookup indexed by model name. The entry found, if
 *   any, is sent as `customType`. NOTE(review): presumably `undefined` for
 *   non-custom models. Confirm against the API contract.
 * @returns the flat request-parameter object. It always asks for one result
 *   (`numResults: 1`) and no top-K token data (`topKReturn: 0`).
 */
export const prepareRequestParams = (
  prompt: string,
  params: Partial<IModelParams>,
  customModelsMap: Json
) => {
  const {
    modelName = '',
    temperature = 0,
    topP = 0,
    maxTokens = 1,
    frequencyPenalty = 0,
    countPenalty = 0,
    presencePenalty = 0,
    stopSequences = [],
  } = params;

  const customType = customModelsMap[modelName];

  return {
    prompt,
    model: modelName,
    temperature,
    topP,
    maxTokens,
    stopSequences,
    customType,
    numStopSequences: stopSequences.length,
    promptLength: prompt.length,
    numResults: 1,
    topKReturn: 0,
    frequencyPenalty,
    presencePenalty,
    countPenalty,
  };
};
