import { IModelParams, IModels } from '@ai21/studio-store';
import { get } from 'lodash';
import { CustomModelType } from '../../../../../web/src/data-types/Model';

/**
 * Fields read from a Jamba chat-completions payload. Every field is optional
 * because the payload is not validated before it reaches this module; the
 * leaf types are assumptions — TODO confirm against the API schema.
 */
interface JambaChatResponse {
  data?: {
    id?: string;
    choices?: Array<{
      finish_reason?: string | null;
      message?: { content?: string | null; role?: string };
    }>;
    usage?: { prompt_tokens?: number; total_tokens?: number };
  };
}

/**
 * Flattens a Jamba chat-completions response (payload under `response.data`)
 * into the same result shape returned by `getResponseData`.
 *
 * Missing pieces degrade to defaults instead of throwing: an empty completion,
 * zero token counts, and an empty finish reason.
 *
 * @param model - model identifier; currently unused, kept for call compatibility.
 * @param response - response wrapper whose payload lives under `data`.
 * @param showAlternative - currently unused, kept for call compatibility.
 * @returns `completion` (first choice's message content), `completionId`,
 *   `promptTokens`, `tokensCount` (usage `total_tokens`), `finishReason`
 *   (first choice's `finish_reason` string) and `reason` (whether a finish
 *   reason was reported).
 */
export const getResponseDataFromJamba = (
  model: string,
  response: unknown,
  showAlternative: boolean
) => {
  const data = (response as JambaChatResponse | null | undefined)?.data;
  // `?.[0]` so a payload without `choices` yields defaults instead of throwing.
  const firstChoice = data?.choices?.[0];
  // `finish_reason` is read as a plain value; destructuring it as an object
  // would always produce `undefined` for a string.
  const finishReason = firstChoice?.finish_reason ?? '';

  return {
    completion: firstChoice?.message?.content ?? '',
    completionId: data?.id,
    promptTokens: data?.usage?.prompt_tokens ?? 0,
    tokensCount: data?.usage?.total_tokens ?? 0,
    finishReason,
    reason: !!finishReason,
  };
};

/**
 * Flattens a completion response into a single result object.
 *
 * When `model` is one of the Jamba `CustomModelType` values AND the payload
 * actually carries `data.choices[0]`, the work is delegated to
 * `getResponseDataFromJamba`. Every other response is read as a payload with a
 * `data.completions` array.
 *
 * @param model - model identifier compared against the Jamba model types.
 * @param response - response wrapper whose payload lives under `data`.
 * @param showAlternative - when true, `completion` is the first completion's
 *   whole `data` object instead of only its `text`.
 * @returns `completion`, `completionId`, `promptTokens`, `tokensCount`,
 *   `finishReason` and `reason`. Missing fields come back `undefined`
 *   (`tokensCount` falls back to 0) rather than throwing.
 */
export const getResponseData = (
  model: string,
  response: JSON,
  showAlternative: boolean
) => {
  const newMessage = get(response, 'data.choices[0]');
  const isJambaModel =
    model === CustomModelType.JAMBA_INSTRUCT ||
    model === CustomModelType.JAMBA_1_5_LARGE ||
    model === CustomModelType.JAMBA_1_5_MINI;

  if (isJambaModel && newMessage) {
    return getResponseDataFromJamba(model, response, showAlternative);
  }

  const data = response?.data;
  // `?.[0]` so a payload without `completions` (for example a Jamba-model
  // response lacking `choices`, which falls through to here) produces
  // undefined fields instead of a TypeError.
  const firstCompletion = data?.completions?.[0];
  const completionData = firstCompletion?.data;
  const finishReason = firstCompletion?.finishReason;

  return {
    completion: showAlternative ? completionData : completionData?.text,
    completionId: data?.id,
    promptTokens: data?.prompt?.tokens,
    tokensCount: completionData?.tokens?.length || 0,
    finishReason,
    // NOTE(review): true only when `finishReason` has a truthy `length`
    // property — confirm this is the intended meaning of `reason` here.
    reason: !!finishReason?.length,
  };
};

/**
 * Builds the request payload for a completion call.
 *
 * Foundation model ids (the hard-coded list below) are sent as-is and never
 * carry an `epoch`. Any other id is looked up in `models`; the request then
 * uses that model's `name`, reports its `customModelType` as `customType`,
 * and drops an `epoch` equal to 0.
 *
 * Jamba models additionally get chat-style fields: `messages` (the prompt as
 * a single user message), `max_tokens` (default 4096), `top_p` (default 1.0)
 * and `stop`.
 *
 * @param prompt - prompt text.
 * @param params - request parameters; this object is NOT mutated.
 * @param models - custom models keyed by id.
 * @param showAlternativeTokens - when true, `topKReturn` is 10, otherwise 0.
 * @returns the request payload.
 * @throws TypeError if `params.modelId` is neither a foundation model id nor
 *   a key of `models`.
 */
export const prepareRequestParams = (
  prompt: string,
  params: Partial<IModelParams & { modelId: string }>,
  models: IModels,
  showAlternativeTokens?: boolean
) => {
  let { modelId = '' } = params;
  const { stopSequences = [], epoch } = params;

  let customType = '';

  const isFoundationModel = [
    'j2-light',
    'j2-mid',
    'j2-ultra',
    'jamba-instruct-preview',
    'jamba-1.5-large',
    'jamba-1.5-mini',
  ].includes(modelId);

  if (!isFoundationModel) {
    const model = models[modelId];
    if (!model) {
      throw new TypeError(`Unknown model id: "${modelId}"`);
    }
    customType = model.customModelType ?? '';
    modelId = model.name;
  }

  // Work on a copy so the caller's `params` object is left untouched.
  const parsedParams = { ...params };

  // Foundation models never send an epoch; custom models omit an epoch of 0.
  // `==` (not `===`) so a string '0' also counts as zero.
  if (isFoundationModel || epoch == 0) {
    delete parsedParams.epoch;
  }

  // Only own keys of the mapping table count, so stop sequences such as
  // "constructor" or "toString" are passed through unchanged.
  parsedParams.stopSequences = stopSequences.map((seq: string) =>
    Object.prototype.hasOwnProperty.call(mapStopSequences, seq)
      ? mapStopSequences[seq]
      : seq
  );

  const isJambaModel =
    modelId === CustomModelType.JAMBA_INSTRUCT ||
    modelId === CustomModelType.JAMBA_1_5_LARGE ||
    modelId === CustomModelType.JAMBA_1_5_MINI;

  if (isJambaModel) {
    return {
      model: modelId,
      customType,
      numStopSequences: parsedParams.stopSequences.length,
      promptLength: prompt.length,
      topKReturn: showAlternativeTokens ? 10 : 0,
      n: 1,
      ...parsedParams,
      prompt,
      messages: [{ content: prompt, role: 'user' }],
      max_tokens: parsedParams.maxTokens || 4096,
      top_p: parsedParams.topP || 1.0,
      stop: parsedParams.stopSequences,
    };
  }

  return {
    model: modelId,
    customType,
    numStopSequences: parsedParams.stopSequences.length,
    promptLength: prompt.length,
    ...parsedParams,
    topKReturn: showAlternativeTokens ? 10 : 0,
    numResults: 1,
    prompt,
  };
};

/**
 * Display labels for stop sequences that are replaced by another token before
 * being sent (`Enter` becomes `↵`).
 *
 * Built on a null prototype and frozen, so lookups of names such as
 * `constructor` or `toString` return `undefined` instead of inherited
 * `Object.prototype` functions.
 */
const mapStopSequences: Readonly<Record<string, string>> = Object.freeze(
  Object.assign(Object.create(null) as Record<string, string>, {
    Enter: '↵',
  })
);
