import {
  Array,
  Boolean,
  Literal,
  Null,
  Number,
  Optional,
  Record as RRecord,
  Static,
  String,
  Union,
} from "runtypes";

import { getNormalEnum } from "../runtypeEnums";

/**
 * Runtime validator for the built-in OpenAI chat model identifiers.
 *
 * The identifiers are internal names; see `OpenAiChatModelNameMap` for the
 * concrete OpenAI model strings they resolve to.
 */
export const OpenAiChatModelLiteral = Union(
  Literal("GPT_4_OMNI"),
  Literal("GPT_4_OMNI_MINI"),
);

/** Enum-style object derived from `OpenAiChatModelLiteral` via `getNormalEnum`. */
export const OpenAiChatModel = getNormalEnum(OpenAiChatModelLiteral);

/** Union of the built-in chat model identifiers. */
export type OpenAiChatModel = Static<typeof OpenAiChatModelLiteral>;

/**
 * Maps each built-in chat model identifier to the model string sent to
 * OpenAI. Each entry is pinned to a dated snapshot rather than a floating
 * alias.
 */
export const OpenAiChatModelNameMap: Record<OpenAiChatModel, string> = {
  // Dated GPT-4o snapshot.
  [OpenAiChatModel.GPT_4_OMNI]: "gpt-4o-2024-11-20",
  // Dated GPT-4o mini snapshot.
  [OpenAiChatModel.GPT_4_OMNI_MINI]: "gpt-4o-mini-2024-07-18",
};

/** Total context window (input + output tokens) shared by the GPT-4o family. */
const GPT_4O_CONTEXT_WINDOW = 128000;

/**
 * OpenAI's documentation of how tokens are counted seems to be somewhat off,
 * so every input budget is reduced by this many tokens as a safety buffer.
 */
const TOKEN_COUNT_SAFETY_MARGIN = 10;

/**
 * Splits the GPT-4o context window into input and output budgets.
 *
 * The input budget is whatever remains of the window after reserving
 * `maxOutputTokens` for the response and `TOKEN_COUNT_SAFETY_MARGIN` as a
 * buffer.
 *
 * @param maxOutputTokens - Tokens reserved for the model's response.
 * @returns The matching `maxInputTokens` / `maxOutputTokens` pair.
 */
function gpt4oTokenBudget(maxOutputTokens: number): {
  maxInputTokens: number;
  maxOutputTokens: number;
} {
  return {
    maxInputTokens:
      GPT_4O_CONTEXT_WINDOW - maxOutputTokens - TOKEN_COUNT_SAFETY_MARGIN,
    maxOutputTokens,
  };
}

/**
 * Token limits for each built-in chat model.
 *
 * Every current entry uses the split `maxInputTokens` / `maxOutputTokens`
 * shape. The single-`maxTokens` variant is still part of the type, so
 * consumers must narrow before reading either shape.
 *
 * `tokensPerMessage` is presumably the fixed per-message overhead added when
 * approximating a conversation's token count — confirm against the token
 * counting code.
 */
export const OPENAI_CHAT_MODEL_DETAILS: Record<
  OpenAiChatModel,
  | { maxTokens: number; tokensPerMessage: number }
  | {
      maxInputTokens: number;
      maxOutputTokens: number;
      tokensPerMessage: number;
    }
> = {
  [OpenAiChatModel.GPT_4_OMNI]: {
    // NOTE(review): OpenAI's model docs appear to list a 16,384-token output
    // cap for the pinned gpt-4o snapshot. 4096 may be a conservative leftover
    // from an older snapshot. Confirm before raising it, since doing so
    // shrinks maxInputTokens.
    ...gpt4oTokenBudget(4096),
    tokensPerMessage: 3,
  },
  [OpenAiChatModel.GPT_4_OMNI_MINI]: {
    ...gpt4oTokenBudget(16384),
    tokensPerMessage: 3,
  },
};

/**
 * Runtime validator for a caller-supplied OpenAI model that is not one of
 * the built-in `OpenAiChatModel` identifiers.
 */
export const CustomOpenAiModelConfig = RRecord({
  /** Model string to send to OpenAI. */
  model: String,
  /** Maximum number of input tokens for this model. */
  maxInputTokens: Number,
  /** Maximum number of output tokens. Not used by reasoning models. */
  maxOutputTokens: Optional(Number),
});
export type CustomOpenAiModelConfig = Static<typeof CustomOpenAiModelConfig>;

/** Accepted values for `reasoningEffort`. Only used by reasoning models. */
const OpenAiReasoningEffortLiteral = Union(
  Literal("low"),
  Literal("medium"),
  Literal("high"),
);

/**
 * Runtime validator for OpenAI chat model parameters.
 *
 * `model` is either one of the built-in `OpenAiChatModelLiteral` identifiers
 * or a `CustomOpenAiModelConfig` describing an arbitrary model and its token
 * limits.
 */
export const OpenAiModelParams = RRecord({
  frequencyPenalty: Number,
  n: Number,
  presencePenalty: Number,
  stopWords: Array(String),
  /** Not used by reasoning models. */
  temperature: Optional(Number),
  topP: Number,
  /** ONLY used by reasoning models. */
  reasoningEffort: Optional(OpenAiReasoningEffortLiteral),
  model: Union(OpenAiChatModelLiteral, CustomOpenAiModelConfig),
});
export type OpenAiModelParams = Static<typeof OpenAiModelParams>;

/**
 * Runtime validator for chat message roles.
 *
 * Based on `ChatCompletionRequestMessageRoleEnum` from the `openai` package.
 */
export const OpenAiChatMessageRoleLiteral = Union(
  Literal("system"),
  Literal("user"),
  Literal("assistant"),
);

/** Enum-style object derived from `OpenAiChatMessageRoleLiteral` via `getNormalEnum`. */
export const OpenAiChatMessageRole = getNormalEnum(
  OpenAiChatMessageRoleLiteral,
);

/** Union of the supported chat message roles. */
export type OpenAiChatMessageRole = Static<typeof OpenAiChatMessageRoleLiteral>;

/** The `image_url` payload of an image content part: only `url` is modelled. */
const OpenAiChatImageUrl = RRecord({
  url: String,
});

/**
 * Runtime validator for an image content part inside a user message.
 *
 * Based on `ChatCompletionContentPartImage` from the `openai` package.
 * NOTE(review): OpenAI also accepts an optional `image_url.detail` field,
 * which is not modelled here. Confirm whether it needs to be typed.
 */
const OpenAiChatContentPartImage = RRecord({
  type: Literal("image_url"),
  image_url: OpenAiChatImageUrl,
});

/** System message: plain-text content only. */
const OpenAiSystemChatMessage = RRecord({
  content: String,
  role: Literal(OpenAiChatMessageRole.system),
});

/**
 * User message: either plain text or an array of image content parts.
 * NOTE(review): array content accepts image parts only, so arrays that mix
 * text and image parts are rejected by this schema.
 */
const OpenAiUserChatMessage = RRecord({
  content: String.Or(Array(OpenAiChatContentPartImage)),
  role: Literal(OpenAiChatMessageRole.user),
});

/**
 * Assistant message: content may be null (presumably when the reply is a
 * function call), with an optional `function_call` payload.
 */
const OpenAiAssistantChatMessage = RRecord({
  content: String.Or(Null),
  // Snake case is required by the OpenAI API.
  function_call: Optional(RRecord({ arguments: String, name: String })),
  role: Literal(OpenAiChatMessageRole.assistant),
});

/**
 * Runtime validator for a single chat message, discriminated by `role`.
 *
 * Based on `ChatCompletionRequestMessage` from the `openai` package.
 */
export const OpenAiChatMessage = Union(
  OpenAiSystemChatMessage,
  OpenAiUserChatMessage,
  OpenAiAssistantChatMessage,
);
export type OpenAiChatMessage = Static<typeof OpenAiChatMessage>;

/**
 * Runtime validator for the result of conforming a conversation to a model.
 *
 * Presumably produced by the code that fits messages into a model's token
 * budget — confirm against the producer.
 *
 * NOTE(review): this name uses `OpenAI` casing while the rest of the module
 * uses `OpenAi`. It is kept as-is because it is exported.
 */
export const OpenAIConformerResult = RRecord({
  /** The messages to send. */
  messages: Array(OpenAiChatMessage),
  /** Token count attributed to `messages`. */
  totalContextTokens: Number,
  /** Whether `totalContextTokens` came from an approximation rather than an exact count. */
  usedTokenCountApproximation: Boolean,
});
export type OpenAIConformerResult = Static<typeof OpenAIConformerResult>;
