Models

BaseModel

Bases: ABC

Source code in autolabel/src/autolabel/models/base.py
class BaseModel(ABC):
    TTL_MS = 60 * 60 * 24 * 7 * 1000  # 1 week
    DEFAULT_CONTEXT_LENGTH = None

    def __init__(self, config: AutolabelConfig, cache: BaseCache) -> None:
        self.config = config
        self.cache = cache
        self.model_params = config.model_params()
        # Specific classes that implement this interface should run initialization steps here
        # E.g. initializing the LLM model with required parameters from the AutolabelConfig

    async def label(self, prompts: List[str]) -> RefuelLLMResult:
        """Label a list of prompts."""
        existing_prompts = {}
        missing_prompt_idxs = list(range(len(prompts)))
        missing_prompts = prompts
        costs = []
        errors = [None for i in range(len(prompts))]
        latencies = [0 for i in range(len(prompts))]
        if self.cache:
            (
                existing_prompts,
                missing_prompt_idxs,
                missing_prompts,
            ) = self.get_cached_prompts(prompts)
        # label missing prompts
        if len(missing_prompts) > 0:
            if hasattr(self, "_alabel"):
                new_results = await self._alabel(missing_prompts)
            else:
                new_results = self._label(missing_prompts)
            for ind, prompt in enumerate(missing_prompts):
                costs.append(
                    self.get_cost(prompt, label=new_results.generations[ind][0].text)
                )

            # Set the existing prompts to the new results
            for i, result, error, latency in zip(
                missing_prompt_idxs,
                new_results.generations,
                new_results.errors,
                new_results.latencies,
            ):
                existing_prompts[i] = result
                errors[i] = error
                latencies[i] = latency

            if self.cache:
                self.update_cache(missing_prompt_idxs, new_results, prompts)
        generations = [existing_prompts[i] for i in range(len(prompts))]
        return RefuelLLMResult(
            generations=generations, costs=costs, errors=errors, latencies=latencies
        )

    async def _alabel_individually(self, prompts: List[str]) -> RefuelLLMResult:
        """Label each prompt individually. Should be used only after trying as a batch first.

        Args:
            prompts (List[str]): List of prompts to label

        Returns:
            RefuelLLMResult: result object with generations, errors and latencies
        """
        generations = []
        errors = []
        latencies = []
        for prompt in prompts:
            try:
                start_time = time()
                response = await self.llm.agenerate([prompt])
                generations.append(response.generations[0])
                errors.append(None)
                latencies.append(time() - start_time)
            except Exception as e:
                print(f"Error generating from LLM: {e}")
                generations.append([Generation(text="")])
                errors.append(
                    LabelingError(
                        error_type=ErrorType.LLM_PROVIDER_ERROR, error_message=str(e)
                    )
                )
                latencies.append(0)

        return RefuelLLMResult(
            generations=generations, errors=errors, latencies=latencies
        )

    def _label_individually(self, prompts: List[str]) -> RefuelLLMResult:
        """Label each prompt individually. Should be used only after trying as a batch first.

        Args:
            prompts (List[str]): List of prompts to label

        Returns:
            RefuelLLMResult: result object with generations, errors and latencies
        """
        generations = []
        errors = []
        latencies = []
        for prompt in prompts:
            try:
                start_time = time()
                response = self.llm.generate([prompt])
                generations.append(response.generations[0])
                errors.append(None)
                latencies.append(time() - start_time)
            except Exception as e:
                print(f"Error generating from LLM: {e}")
                generations.append([Generation(text="")])
                errors.append(
                    LabelingError(
                        error_type=ErrorType.LLM_PROVIDER_ERROR, error_message=str(e)
                    )
                )
                latencies.append(0)

        return RefuelLLMResult(
            generations=generations, errors=errors, latencies=latencies
        )

    @abstractmethod
    def _label(self, prompts: List[str]) -> RefuelLLMResult:
        # TODO: change return type to do parsing in the Model class
        pass

    @abstractmethod
    def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
        pass

    def get_cached_prompts(self, prompts: List[str]) -> Tuple[Dict, List[int], List[str]]:
        """Get prompts that are already cached."""
        model_params_string = str(
            sorted([(k, v) for k, v in self.model_params.items()])
        )
        missing_prompts = []
        missing_prompt_idxs = []
        existing_prompts = {}
        for i, prompt in enumerate(prompts):
            cache_entry = GenerationCacheEntry(
                prompt=prompt,
                model_name=self.model_name,
                model_params=model_params_string,
            )
            cache_val = self.cache.lookup(cache_entry)
            if cache_val:
                existing_prompts[i] = cache_val
            else:
                missing_prompts.append(prompt)
                missing_prompt_idxs.append(i)
        return (
            existing_prompts,
            missing_prompt_idxs,
            missing_prompts,
        )

    def update_cache(self, missing_prompt_idxs, new_results, prompts):
        """Update the cache with new results."""
        model_params_string = str(
            sorted([(k, v) for k, v in self.model_params.items()])
        )

        for i, result, error in zip(
            missing_prompt_idxs, new_results.generations, new_results.errors
        ):
            # If there was an error, don't cache the result
            if error is not None:
                continue

            cache_entry = GenerationCacheEntry(
                prompt=prompts[i],
                model_name=self.model_name,
                model_params=model_params_string,
                generations=result,
                ttl_ms=self.TTL_MS,
            )
            self.cache.update(cache_entry)

    @abstractmethod
    def returns_token_probs(self) -> bool:
        """Whether the LLM supports returning logprobs of generated tokens

        Returns:
            bool: whether the LLM supports returning logprobs of generated tokens
        """
        pass

    @abstractmethod
    def get_num_tokens(self, prompt: str) -> int:
        """
        Get the number of tokens in the prompt"""
        pass

get_cached_prompts(prompts)

Get prompts that are already cached.

Source code in autolabel/src/autolabel/models/base.py
def get_cached_prompts(self, prompts: List[str]) -> Tuple[Dict, List[int], List[str]]:
    """Get prompts that are already cached."""
    model_params_string = str(
        sorted([(k, v) for k, v in self.model_params.items()])
    )
    missing_prompts = []
    missing_prompt_idxs = []
    existing_prompts = {}
    for i, prompt in enumerate(prompts):
        cache_entry = GenerationCacheEntry(
            prompt=prompt,
            model_name=self.model_name,
            model_params=model_params_string,
        )
        cache_val = self.cache.lookup(cache_entry)
        if cache_val:
            existing_prompts[i] = cache_val
        else:
            missing_prompts.append(prompt)
            missing_prompt_idxs.append(i)
    return (
        existing_prompts,
        missing_prompt_idxs,
        missing_prompts,
    )

get_num_tokens(prompt) abstractmethod

Get the number of tokens in the prompt

Source code in autolabel/src/autolabel/models/base.py
@abstractmethod
def get_num_tokens(self, prompt: str) -> int:
    """
    Get the number of tokens in the prompt"""
    pass

label(prompts) async

Label a list of prompts.

Source code in autolabel/src/autolabel/models/base.py
async def label(self, prompts: List[str]) -> RefuelLLMResult:
    """Label a list of prompts."""
    existing_prompts = {}
    missing_prompt_idxs = list(range(len(prompts)))
    missing_prompts = prompts
    costs = []
    errors = [None for i in range(len(prompts))]
    latencies = [0 for i in range(len(prompts))]
    if self.cache:
        (
            existing_prompts,
            missing_prompt_idxs,
            missing_prompts,
        ) = self.get_cached_prompts(prompts)
    # label missing prompts
    if len(missing_prompts) > 0:
        if hasattr(self, "_alabel"):
            new_results = await self._alabel(missing_prompts)
        else:
            new_results = self._label(missing_prompts)
        for ind, prompt in enumerate(missing_prompts):
            costs.append(
                self.get_cost(prompt, label=new_results.generations[ind][0].text)
            )

        # Set the existing prompts to the new results
        for i, result, error, latency in zip(
            missing_prompt_idxs,
            new_results.generations,
            new_results.errors,
            new_results.latencies,
        ):
            existing_prompts[i] = result
            errors[i] = error
            latencies[i] = latency

        if self.cache:
            self.update_cache(missing_prompt_idxs, new_results, prompts)
    generations = [existing_prompts[i] for i in range(len(prompts))]
    return RefuelLLMResult(
        generations=generations, costs=costs, errors=errors, latencies=latencies
    )
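
For illustration, here is a minimal sketch of driving label from application code. model stands for any already-constructed BaseModel subclass and the prompt strings are placeholders; the call must run inside an async context.

import asyncio

async def run_labeling(model, prompts):
    # label() consults the cache (if one was configured), queries the provider
    # only for cache misses, and returns a RefuelLLMResult with aligned lists.
    result = await model.label(prompts)
    for prompt, generation, error in zip(prompts, result.generations, result.errors):
        if error is not None:
            print(f"Failed on {prompt!r}: {error.error_message}")
        else:
            print(f"{prompt!r} -> {generation[0].text}")

# asyncio.run(run_labeling(model, ["Classify: ...", "Classify: ..."]))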

returns_token_probs() abstractmethod

Whether the LLM supports returning logprobs of generated tokens

Returns:

bool: whether the LLM supports returning logprobs of generated tokens

Source code in autolabel/src/autolabel/models/base.py
@abstractmethod
def returns_token_probs(self) -> bool:
    """Whether the LLM supports returning logprobs of generated tokens

    Returns:
        bool: whether the LLM supports returning logprobs of generated tokens
    """
    pass

update_cache(missing_prompt_idxs, new_results, prompts)

Update the cache with new results.

Source code in autolabel/src/autolabel/models/base.py
def update_cache(self, missing_prompt_idxs, new_results, prompts):
    """Update the cache with new results."""
    model_params_string = str(
        sorted([(k, v) for k, v in self.model_params.items()])
    )

    for i, result, error in zip(
        missing_prompt_idxs, new_results.generations, new_results.errors
    ):
        # If there was an error, don't cache the result
        if error is not None:
            continue

        cache_entry = GenerationCacheEntry(
            prompt=prompts[i],
            model_name=self.model_name,
            model_params=model_params_string,
            generations=result,
            ttl_ms=self.TTL_MS,
        )
        self.cache.update(cache_entry)
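
To make the abstract interface above concrete, here is a minimal sketch of a toy subclass. It is illustrative only: EchoLLM is hypothetical, and the import paths for AutolabelConfig and RefuelLLMResult are assumptions inferred from the type hints in this module.

from typing import List, Optional

from langchain.schema import Generation

from autolabel.configs import AutolabelConfig   # assumed import path
from autolabel.models import BaseModel
from autolabel.schema import RefuelLLMResult    # assumed import path


class EchoLLM(BaseModel):
    """Toy model that answers every prompt with a fixed label. Hypothetical."""

    def __init__(self, config: AutolabelConfig, cache=None) -> None:
        super().__init__(config, cache)
        self.model_name = config.model_name() or "echo"

    def _label(self, prompts: List[str]) -> RefuelLLMResult:
        generations = [[Generation(text="dummy-label")] for _ in prompts]
        return RefuelLLMResult(
            generations=generations,
            errors=[None] * len(prompts),
            latencies=[0] * len(prompts),
        )

    def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
        return 0.0  # runs locally, so no provider cost

    def returns_token_probs(self) -> bool:
        return False

    def get_num_tokens(self, prompt: str) -> int:
        return len(prompt.split())  # crude whitespace token count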

ModelFactory

The ModelFactory class is used to create a BaseModel object from the given AutolabelConfig configuration.

Source code in autolabel/src/autolabel/models/__init__.py
class ModelFactory:
    """The ModelFactory class is used to create a BaseModel object from the given AutoLabelConfig configuration."""

    @staticmethod
    def from_config(config: AutolabelConfig, cache: BaseCache = None) -> BaseModel:
        """
        Returns a BaseModel object configured with the settings found in the provided AutolabelConfig.
        Args:
            config: AutolabelConfig object containing project settings
            cache: cache allows for saving results in between labeling runs for future use
        Returns:
            model: a fully configured BaseModel object
        """
        provider = ModelProvider(config.provider())
        try:
            model_cls = MODEL_REGISTRY[provider]
            model_obj = model_cls(config=config, cache=cache)
            # The assertion below ensures that user-defined custom models
            # inherit from BaseModel when they are created/registered.
            assert isinstance(
                model_obj, BaseModel
            ), f"{model_obj} should inherit from autolabel.models.BaseModel"
        except KeyError as e:
            # We should never get here as the config should have already
            # been validated by the pydantic model.
            logger.error(
                f"{config.provider()} is not in the list of supported providers: \
                {list(ModelProvider.__members__.keys())}"
            )
            raise e

        return model_obj

from_config(config, cache=None) staticmethod

Returns a BaseModel object configured with the settings found in the provided AutolabelConfig.

Args:

    config: AutolabelConfig object containing project settings
    cache: cache allows for saving results in between labeling runs for future use

Returns:

    model: a fully configured BaseModel object

Source code in autolabel/src/autolabel/models/__init__.py
@staticmethod
def from_config(config: AutolabelConfig, cache: BaseCache = None) -> BaseModel:
    """
    Returns a BaseModel object configured with the settings found in the provided AutolabelConfig.
    Args:
        config: AutolabelConfig object containing project settings
        cache: cache allows for saving results in between labeling runs for future use
    Returns:
        model: a fully configured BaseModel object
    """
    provider = ModelProvider(config.provider())
    try:
        model_cls = MODEL_REGISTRY[provider]
        model_obj = model_cls(config=config, cache=cache)
        # The assertion below ensures that user-defined custom models
        # inherit from BaseModel when they are created/registered.
        assert isinstance(
            model_obj, BaseModel
        ), f"{model_obj} should inherit from autolabel.models.BaseModel"
    except KeyError as e:
        # We should never get here as the config should have already
        # been validated by the pydantic model.
        logger.error(
            f"{config.provider()} is not in the list of supported providers: \
            {list(ModelProvider.__members__.keys())}"
        )
        raise e

    return model_obj
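
A hedged usage sketch of from_config. The configuration dictionary below is an assumed example of a classification config, not an authoritative schema, and the import path for AutolabelConfig is likewise assumed; pass a real cache implementation instead of None to reuse results across runs.

from autolabel.configs import AutolabelConfig   # assumed import path
from autolabel.models import ModelFactory

# Illustrative config values; the OpenAI provider also requires the
# OPENAI_API_KEY environment variable to be set.
config = AutolabelConfig(
    {
        "task_name": "ToxicCommentClassification",
        "task_type": "classification",
        "model": {"provider": "openai", "name": "gpt-3.5-turbo"},
        "prompt": {
            "task_guidelines": "Classify the comment as toxic or not toxic.",
            "labels": ["toxic", "not toxic"],
            "example_template": "Comment: {example}\nLabel: {label}",
        },
    }
)

model = ModelFactory.from_config(config, cache=None)
print(model.model_name, model.returns_token_probs())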

register_model(name, model_cls)

Register Model class

Source code in autolabel/src/autolabel/models/__init__.py
def register_model(name, model_cls):
    """Register Model class"""
    MODEL_REGISTRY[name] = model_cls
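
A short sketch of registering a custom model, reusing the hypothetical EchoLLM from the BaseModel section above. Because from_config resolves config.provider() to a ModelProvider member before the registry lookup, the key must be a member of that enum; the override below and the ModelProvider import path are assumptions for illustration only.

from autolabel.models import register_model
from autolabel.schema import ModelProvider   # assumed import path

# Hypothetical: override the class used for an existing provider value.
# Registering a brand-new provider would require extending the ModelProvider enum.
register_model(ModelProvider.OPENAI, EchoLLM)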

AnthropicLLM

Bases: BaseModel

Source code in autolabel/src/autolabel/models/anthropic.py
class AnthropicLLM(BaseModel):
    DEFAULT_MODEL = "claude-instant-v1"
    DEFAULT_PARAMS = {
        "max_tokens_to_sample": 1000,
        "temperature": 0.0,
    }

    # Reference: https://cdn2.assets-servd.host/anthropic-website/production/images/apr-pricing-tokens.pdf
    COST_PER_PROMPT_TOKEN = {
        # $11.02 per million tokens
        "claude-v1": (11.02 / 1000000),
        "claude-instant-v1": (1.63 / 1000000),
    }
    COST_PER_COMPLETION_TOKEN = {
        # $32.68 per million tokens
        "claude-v1": (32.68 / 1000000),
        "claude-instant-v1": (5.51 / 1000000),
    }

    def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
        super().__init__(config, cache)

        try:
            from langchain.chat_models import ChatAnthropic
            from anthropic._tokenizers import sync_get_tokenizer
        except ImportError:
            raise ImportError(
                "anthropic is required to use the anthropic LLM. Please install it with the following command: pip install 'refuel-autolabel[anthropic]'"
            )

        # populate model name
        self.model_name = config.model_name() or self.DEFAULT_MODEL
        # populate model params
        model_params = config.model_params()
        self.model_params = {**self.DEFAULT_PARAMS, **model_params}
        # initialize LLM
        self.llm = ChatAnthropic(model=self.model_name, **self.model_params)

        self.tokenizer = sync_get_tokenizer()

    def _label(self, prompts: List[str]) -> RefuelLLMResult:
        prompts = [[HumanMessage(content=prompt)] for prompt in prompts]
        try:
            start_time = time()
            result = self.llm.generate(prompts)
            end_time = time()
            return RefuelLLMResult(
                generations=result.generations,
                errors=[None] * len(result.generations),
                latencies=[end_time - start_time] * len(result.generations),
            )
        except Exception as e:
            return self._label_individually(prompts)

    def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
        num_prompt_toks = len(self.tokenizer.encode(prompt).ids)
        if label:
            num_label_toks = len(self.tokenizer.encode(label).ids)
        else:
            # get an upper bound
            num_label_toks = self.model_params["max_tokens_to_sample"]

        cost_per_prompt_token = self.COST_PER_PROMPT_TOKEN[self.model_name]
        cost_per_completion_token = self.COST_PER_COMPLETION_TOKEN[self.model_name]
        return (num_prompt_toks * cost_per_prompt_token) + (
            num_label_toks * cost_per_completion_token
        )

    def returns_token_probs(self) -> bool:
        return False

    def get_num_tokens(self, prompt: str) -> int:
        return len(self.tokenizer.encode(prompt).ids)
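
As a worked example of the get_cost arithmetic above (the token counts are made up): a 1,000-token prompt with a 100-token completion on claude-instant-v1 comes out to roughly a fifth of a cent.

prompt_tokens, completion_tokens = 1000, 100          # assumed counts
cost = prompt_tokens * (1.63 / 1_000_000) + completion_tokens * (5.51 / 1_000_000)
print(round(cost, 6))  # 0.002181 (USD)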

HFPipelineLLM

Bases: BaseModel

Source code in autolabel/src/autolabel/models/hf_pipeline.py
class HFPipelineLLM(BaseModel):
    DEFAULT_MODEL = "google/flan-t5-xxl"
    DEFAULT_PARAMS = {"temperature": 0.0, "quantize": 8}

    def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
        super().__init__(config, cache)

        from langchain.llms import HuggingFacePipeline

        try:
            from transformers import (
                AutoConfig,
                AutoModelForSeq2SeqLM,
                AutoModelForCausalLM,
                AutoTokenizer,
                pipeline,
            )
            from transformers.models.auto.modeling_auto import (
                MODEL_FOR_CAUSAL_LM_MAPPING,
                MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
            )
        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please it install it with `pip install transformers`."
            )

        try:
            import torch
        except ImportError:
            raise ValueError(
                "Could not import torch package. "
                "Please it install it with `pip install torch`."
            )
        # populate model name
        self.model_name = config.model_name() or self.DEFAULT_MODEL

        # populate model params
        model_params = config.model_params()
        self.model_params = {**self.DEFAULT_PARAMS, **model_params}
        if config.logit_bias() != 0:
            self.model_params = {
                **self._generate_sequence_bias(),
                **self.model_params,
            }

        # initialize HF pipeline
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, use_fast=False, add_prefix_space=True
        )
        quantize_bits = self.model_params["quantize"]
        model_config = AutoConfig.from_pretrained(self.model_name)
        if isinstance(model_config, tuple(MODEL_FOR_CAUSAL_LM_MAPPING)):
            AutoModel = AutoModelForCausalLM
        elif isinstance(model_config, tuple(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING)):
            AutoModel = AutoModelForSeq2SeqLM
        else:
            raise ValueError(
                "model_name is neither a causal LM nor a seq2seq LM. Please check the model_name."
            )

        if not torch.cuda.is_available():
            model = AutoModel.from_pretrained(self.model_name)
        elif quantize_bits == 8:
            model = AutoModel.from_pretrained(
                self.model_name, load_in_8bit=True, device_map="auto"
            )
        elif quantize_bits == "16":
            model = AutoModel.from_pretrained(
                self.model_name, torch_dtype=torch.float16, device_map="auto"
            )
        else:
            model = AutoModel.from_pretrained(self.model_name, device_map="auto")

        model_kwargs = dict(self.model_params)  # make a copy of the model params
        model_kwargs.pop("quantize", None)  # remove quantize from the model params
        pipe = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=self.tokenizer,
            **model_kwargs,
        )

        # initialize LLM
        self.llm = HuggingFacePipeline(pipeline=pipe, model_kwargs=model_kwargs)

    def _generate_sequence_bias(self) -> Dict:
        """Generates sequence bias dict to add to the config for the labels specified

        Returns:
            Dict: sequence bias, max new tokens, and num beams
        """
        if len(self.config.labels_list()) == 0:
            logger.warning(
                "No labels specified in the config. Skipping logit bias generation."
            )
            return {}
        try:
            from transformers import AutoTokenizer
        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please it install it with `pip install transformers`."
            )
        sequence_bias = {tuple([self.tokenizer.eos_token_id]): self.config.logit_bias()}
        max_new_tokens = 0
        for label in self.config.labels_list():
            tokens = tuple(
                self.tokenizer([label], add_special_tokens=False).input_ids[0]
            )
            for token in tokens:
                sequence_bias[tuple([token])] = self.config.logit_bias()
            max_new_tokens = max(max_new_tokens, len(tokens))

        return {
            "sequence_bias": sequence_bias,
            "max_new_tokens": max_new_tokens,
        }

    def _label(self, prompts: List[str]) -> RefuelLLMResult:
        try:
            start_time = time()
            result = self.llm.generate(prompts)
            end_time = time()
            return RefuelLLMResult(
                generations=result.generations,
                errors=[None] * len(result.generations),
                latencies=[end_time - start_time] * len(result.generations),
            )
        except Exception as e:
            return self._label_individually(prompts)

    def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
        # Model inference for this model is being run locally
        # Revisit this in the future when we support HF inference endpoints
        return 0.0

    def returns_token_probs(self) -> bool:
        return False

    def get_num_tokens(self, prompt: str) -> int:
        return len(self.tokenizer.encode(prompt))
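
One detail worth calling out from the constructor above: the quantize model parameter selects the load precision when a GPU is available. A value of 8 loads the model in 8-bit, the string "16" loads it in float16, and any other value loads full precision with device_map="auto"; without a GPU the model is always loaded unquantized. A hedged config fragment showing where that parameter would live; the field names and provider string are assumptions based on how AutolabelConfig is used elsewhere on this page.

hf_config_fragment = {
    "model": {
        "provider": "huggingface_pipeline",   # assumed provider string
        "name": "google/flan-t5-xxl",
        "params": {"quantize": "16"},         # fp16; the default is 8-bit
    }
}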

OpenAILLM

Bases: BaseModel

Source code in autolabel/src/autolabel/models/openai.py
class OpenAILLM(BaseModel):
    CHAT_ENGINE_MODELS = [
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-0301",
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k",
        "gpt-3.5-turbo-16k-0613",
        "gpt-4",
        "gpt-4-0314",
        "gpt-4-32k-0314",
        "gpt-4-0613",
        "gpt-4-32k",
        "gpt-4-32k-0613",
        "gpt-4-1106-preview",
        "gpt-4-0125-preview",
    ]
    MODELS_WITH_TOKEN_PROBS = [
        "text-curie-001",
        "text-davinci-003",
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-0301",
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k",
        "gpt-3.5-turbo-16k-0613",
        "gpt-4",
        "gpt-4-0314",
        "gpt-4-32k-0314",
        "gpt-4-0613",
        "gpt-4-32k",
        "gpt-4-32k-0613",
        "gpt-4-1106-preview",
        "gpt-4-0125-preview",
    ]
    JSON_MODE_MODELS = [
        "gpt-3.5-turbo-0125",
        "gpt-3.5-turbo",
        "gpt-4-0125-preview",
        "gpt-4-1106-preview",
        "gpt-4-turbo-preview",
    ]

    # Default parameters for OpenAILLM
    DEFAULT_MODEL = "gpt-3.5-turbo"
    DEFAULT_PARAMS_COMPLETION_ENGINE = {
        "max_tokens": 1000,
        "temperature": 0.0,
        "model_kwargs": {"logprobs": 1},
        "request_timeout": 30,
    }
    DEFAULT_PARAMS_CHAT_ENGINE = {
        "max_tokens": 1000,
        "temperature": 0.0,
        "request_timeout": 30,
    }
    DEFAULT_QUERY_PARAMS_CHAT_ENGINE = {"logprobs": True, "top_logprobs": 1}

    # Reference: https://openai.com/pricing
    COST_PER_PROMPT_TOKEN = {
        "text-davinci-003": 0.02 / 1000,
        "text-curie-001": 0.002 / 1000,
        "gpt-3.5-turbo": 0.0015 / 1000,
        "gpt-3.5-turbo-0301": 0.0015 / 1000,
        "gpt-3.5-turbo-0613": 0.0015 / 1000,
        "gpt-3.5-turbo-16k": 0.003 / 1000,
        "gpt-3.5-turbo-16k-0613": 0.003 / 1000,
        "gpt-4": 0.03 / 1000,
        "gpt-4-0613": 0.03 / 1000,
        "gpt-4-32k": 0.06 / 1000,
        "gpt-4-32k-0613": 0.06 / 1000,
        "gpt-4-0314": 0.03 / 1000,
        "gpt-4-32k-0314": 0.06 / 1000,
        "gpt-4-1106-preview": 0.01 / 1000,
        "gpt-4-0125-preview": 0.01 / 1000,
    }
    COST_PER_COMPLETION_TOKEN = {
        "text-davinci-003": 0.02 / 1000,
        "text-curie-001": 0.002 / 1000,
        "gpt-3.5-turbo": 0.002 / 1000,
        "gpt-3.5-turbo-0301": 0.002 / 1000,
        "gpt-3.5-turbo-0613": 0.002 / 1000,
        "gpt-3.5-turbo-16k": 0.004 / 1000,
        "gpt-3.5-turbo-16k-0613": 0.004 / 1000,
        "gpt-4": 0.06 / 1000,
        "gpt-4-0613": 0.06 / 1000,
        "gpt-4-32k": 0.12 / 1000,
        "gpt-4-32k-0613": 0.12 / 1000,
        "gpt-4-0314": 0.06 / 1000,
        "gpt-4-32k-0314": 0.12 / 1000,
        "gpt-4-1106-preview": 0.03 / 1000,
        "gpt-4-0125-preview": 0.03 / 1000,
    }

    @cached_property
    def _engine(self) -> str:
        if self.model_name is not None and self.model_name in self.CHAT_ENGINE_MODELS:
            return "chat"
        else:
            return "completion"

    def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
        super().__init__(config, cache)
        try:
            import tiktoken
            from langchain.chat_models import ChatOpenAI
            from langchain.llms import OpenAI
        except ImportError:
            raise ImportError(
                "openai is required to use the OpenAILLM. Please install it with the following command: pip install 'refuel-autolabel[openai]'"
            )
        self.tiktoken = tiktoken
        # populate model name
        self.model_name = config.model_name() or self.DEFAULT_MODEL

        if os.getenv("OPENAI_API_KEY") is None:
            raise ValueError("OPENAI_API_KEY environment variable not set")

        # populate model params and initialize the LLM
        model_params = config.model_params()
        if config.logit_bias() != 0:
            model_params = {
                **model_params,
                **self._generate_logit_bias(),
            }

        if self._engine == "chat":
            self.model_params = {**self.DEFAULT_PARAMS_CHAT_ENGINE, **model_params}
            self.query_params = self.DEFAULT_QUERY_PARAMS_CHAT_ENGINE
            self.llm = ChatOpenAI(
                model_name=self.model_name, verbose=False, **self.model_params
            )
            if config.json_mode():
                if self.model_name not in self.JSON_MODE_MODELS:
                    logger.warning(
                        f"json_mode is not supported for model {self.model_name}. Disabling json_mode."
                    )
                else:
                    self.query_params["response_format"] = {"type": "json_object"}
        else:
            self.model_params = {
                **self.DEFAULT_PARAMS_COMPLETION_ENGINE,
                **model_params,
            }
            self.llm = OpenAI(
                model_name=self.model_name, verbose=False, **self.model_params
            )

    def _generate_logit_bias(self) -> Dict:
        """Generates logit bias for the labels specified in the config

        Returns:
            Dict: logit bias and max tokens
        """
        if len(self.config.labels_list()) == 0:
            logger.warning(
                "No labels specified in the config. Skipping logit bias generation."
            )
            return {}
        encoding = self.tiktoken.encoding_for_model(self.model_name)
        logit_bias = {}
        max_tokens = 0
        for label in self.config.labels_list():
            if label not in logit_bias:
                tokens = encoding.encode(label)
                for token in tokens:
                    logit_bias[token] = self.config.logit_bias()
                max_tokens = max(max_tokens, len(tokens))
        logit_bias[encoding.eot_token] = self.config.logit_bias()
        return {"logit_bias": logit_bias, "max_tokens": max_tokens}

    def _chat_backward_compatibility(
        self, generations: List[LLMResult]
    ) -> List[LLMResult]:
        for generation_options in generations:
            for curr_generation in generation_options:
                generation_info = curr_generation.generation_info
                new_logprobs = {"top_logprobs": []}
                for curr_token in generation_info["logprobs"]["content"]:
                    new_logprobs["top_logprobs"].append(
                        {curr_token["token"]: curr_token["logprob"]}
                    )
                curr_generation.generation_info["logprobs"] = new_logprobs
        return generations

    async def _alabel(self, prompts: List[str]) -> RefuelLLMResult:
        try:
            start_time = time()
            if self._engine == "chat":
                prompts = [[HumanMessage(content=prompt)] for prompt in prompts]
                result = await self.llm.agenerate(prompts, **self.query_params)
                generations = self._chat_backward_compatibility(result.generations)
            else:
                result = await self.llm.agenerate(prompts)
                generations = result.generations
            end_time = time()
            return RefuelLLMResult(
                generations=generations,
                errors=[None] * len(generations),
                latencies=[end_time - start_time] * len(generations),
            )
        except Exception as e:
            return await self._alabel_individually(prompts)

    def _label(self, prompts: List[str]) -> RefuelLLMResult:
        try:
            start_time = time()
            if self._engine == "chat":
                prompts = [[HumanMessage(content=prompt)] for prompt in prompts]
                result = self.llm.generate(prompts, **self.query_params)
                generations = self._chat_backward_compatibility(result.generations)
            else:
                result = self.llm.generate(prompts)
                generations = result.generations
            end_time = time()
            return RefuelLLMResult(
                generations=generations,
                errors=[None] * len(generations),
                latencies=[end_time - start_time] * len(generations),
            )
        except Exception as e:
            return self._label_individually(prompts)

    def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
        encoding = self.tiktoken.encoding_for_model(self.model_name)
        num_prompt_toks = len(encoding.encode(prompt))
        if label:
            num_label_toks = len(encoding.encode(label))
        else:
            # get an upper bound
            num_label_toks = self.model_params["max_tokens"]

        cost_per_prompt_token = self.COST_PER_PROMPT_TOKEN[self.model_name]
        cost_per_completion_token = self.COST_PER_COMPLETION_TOKEN[self.model_name]
        return (num_prompt_toks * cost_per_prompt_token) + (
            num_label_toks * cost_per_completion_token
        )

    def returns_token_probs(self) -> bool:
        return (
            self.model_name is not None
            and self.model_name in self.MODELS_WITH_TOKEN_PROBS
        )

    def get_num_tokens(self, prompt: str) -> int:
        encoding = self.tiktoken.encoding_for_model(self.model_name)
        return len(encoding.encode(prompt))
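
A worked example of the get_cost arithmetic above, using the gpt-3.5-turbo rates from COST_PER_PROMPT_TOKEN and COST_PER_COMPLETION_TOKEN (the token counts are made up):

prompt_tokens, completion_tokens = 1200, 50            # assumed counts
cost = prompt_tokens * (0.0015 / 1000) + completion_tokens * (0.002 / 1000)
print(round(cost, 6))  # 0.0019 (USD)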

PaLMLLM

Bases: BaseModel

Source code in autolabel/src/autolabel/models/palm.py
class PaLMLLM(BaseModel):
    SEP_REPLACEMENT_TOKEN = "@@"
    CHAT_ENGINE_MODELS = ["chat-bison@001"]

    DEFAULT_MODEL = "text-bison@001"
    # Reference: https://developers.generativeai.google/guide/concepts#model_parameters for "A token is approximately 4 characters"
    DEFAULT_PARAMS = {"temperature": 0, "max_output_tokens": 1000}

    # Reference: https://cloud.google.com/vertex-ai/pricing
    COST_PER_CHARACTER = {
        "text-bison@001": 0.001 / 1000,
        "chat-bison@001": 0.0005 / 1000,
        "textembedding-gecko@001": 0.0001 / 1000,
    }

    @cached_property
    def _engine(self) -> str:
        if self.model_name is not None and self.model_name in self.CHAT_ENGINE_MODELS:
            return "chat"
        else:
            return "completion"

    def __init__(
        self,
        config: AutolabelConfig,
        cache: BaseCache = None,
    ) -> None:
        super().__init__(config, cache)
        try:
            from langchain.chat_models import ChatVertexAI
            from langchain.llms import VertexAI
            import tiktoken
        except ImportError:
            raise ImportError(
                "palm is required to use the Palm LLM. Please install it with the following command: pip install 'refuel-autolabel[google]'"
            )

        # populate model name
        self.model_name = config.model_name() or self.DEFAULT_MODEL

        # populate model params and initialize the LLM
        model_params = config.model_params()
        self.model_params = {
            **self.DEFAULT_PARAMS,
            **model_params,
        }
        if self._engine == "chat":
            self.llm = ChatVertexAI(model_name=self.model_name, **self.model_params)
        else:
            self.llm = VertexAI(model_name=self.model_name, **self.model_params)
        self.tiktoken = tiktoken

    @retry(
        reraise=True,
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
    def _label_with_retry(self, prompts: List[str]) -> Tuple[LLMResult, float]:
        start_time = time()
        response = self.llm.generate(prompts)
        return response, time() - start_time

    def _label_individually(self, prompts: List[str]) -> RefuelLLMResult:
        """Label each prompt individually. Should be used only after trying as a batch first.

        Args:
            prompts (List[str]): List of prompts to label

        Returns:
            RefuelLLMResult: RefuelLLMResult object
        """
        generations = []
        errors = []
        latencies = []
        for i, prompt in enumerate(prompts):
            try:
                response, latency = self._label_with_retry([prompt])
                for generation in response.generations[0]:
                    generation.text = generation.text.replace(
                        self.SEP_REPLACEMENT_TOKEN, "\n"
                    )
                generations.append(response.generations[0])
                errors.append(None)
                latencies.append(latency)
            except Exception as e:
                print(f"Error generating from LLM: {e}, returning empty generation")
                generations.append([Generation(text="")])
                errors.append(
                    LabelingError(
                        error_type=ErrorType.LLM_PROVIDER_ERROR, error_message=str(e)
                    )
                )
                latencies.append(0)

        return RefuelLLMResult(
            generations=generations, errors=errors, latencies=latencies
        )

    def _label(self, prompts: List[str]) -> RefuelLLMResult:
        for prompt in prompts:
            if self.SEP_REPLACEMENT_TOKEN in prompt:
                logger.warning(
                    f"""Current prompt contains {self.SEP_REPLACEMENT_TOKEN} 
                                which is currently used as a separator token by refuel
                                llm. It is highly recommended to avoid having any
                                occurrences of this substring in the prompt.
                            """
                )
        prompts = [
            prompt.replace("\n", self.SEP_REPLACEMENT_TOKEN) for prompt in prompts
        ]
        if self._engine == "chat":
            # Need to convert list[prompts] -> list[messages]
            # Currently the entire prompt is stuck into the "human message"
            # We might consider breaking this up into human vs system message in future
            prompts = [[HumanMessage(content=prompt)] for prompt in prompts]

        try:
            start_time = time()
            result, _ = self._label_with_retry(prompts)
            end_time = time()
            for generations in result.generations:
                for generation in generations:
                    generation.text = generation.text.replace(
                        self.SEP_REPLACEMENT_TOKEN, "\n"
                    )
            return RefuelLLMResult(
                generations=result.generations,
                errors=[None] * len(result.generations),
                latencies=[end_time - start_time] * len(result.generations),
            )
        except Exception as e:
            return self._label_individually(prompts)

    def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
        if self.model_name is None:
            return 0.0
        cost_per_char = self.COST_PER_CHARACTER.get(self.model_name, 0.0)
        return cost_per_char * len(prompt) + cost_per_char * (
            len(label) if label else 4 * self.model_params["max_output_tokens"]
        )

    def returns_token_probs(self) -> bool:
        return False

    def get_num_tokens(self, prompt: str) -> int:
        # TODO(dhruva): Replace with actual tokenizer once that is available
        encoding = self.tiktoken.encoding_for_model("gpt2")
        return len(encoding.encode(prompt))
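
Because Vertex AI pricing is per character, get_cost above multiplies the rate by len(prompt), and when no label is available it falls back to 4 * max_output_tokens characters as an upper bound (per the "a token is approximately 4 characters" guidance referenced in the source). A worked example with made-up sizes for text-bison@001:

prompt_chars = 2000                                    # assumed prompt length
label_chars_upper_bound = 4 * 1000                     # default max_output_tokens is 1000
cost = (0.001 / 1000) * prompt_chars + (0.001 / 1000) * label_chars_upper_bound
print(round(cost, 6))  # 0.006 (USD)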

RefuelLLM

Bases: BaseModel

Source code in autolabel/src/autolabel/models/refuel.py
class RefuelLLM(BaseModel):
    DEFAULT_TOKENIZATION_MODEL = "NousResearch/Llama-2-13b-chat-hf"
    DEFAULT_CONTEXT_LENGTH = 3250
    DEFAULT_CONNECT_TIMEOUT = 10
    DEFAULT_READ_TIMEOUT = 120
    DEFAULT_PARAMS = {
        "max_new_tokens": 128,
    }

    def __init__(
        self,
        config: AutolabelConfig,
        cache: BaseCache = None,
    ) -> None:
        super().__init__(config, cache)
        try:
            from transformers import AutoTokenizer
        except Exception as e:
            raise Exception(
                "Unable to import transformers. Please install transformers to use RefuelLLM"
            )

        # populate model name
        # This is unused today, but in the future could
        # be used to decide which refuel model is queried
        self.model_name = config.model_name()
        model_params = config.model_params()
        self.model_params = {**self.DEFAULT_PARAMS, **model_params}
        self.tokenizer = AutoTokenizer.from_pretrained(self.DEFAULT_TOKENIZATION_MODEL)

        # initialize runtime
        self.BASE_API = f"https://llm.refuel.ai/models/{self.model_name}/generate"
        self.REFUEL_API_ENV = "REFUEL_API_KEY"
        if self.REFUEL_API_ENV in os.environ and os.environ[self.REFUEL_API_ENV]:
            self.REFUEL_API_KEY = os.environ[self.REFUEL_API_ENV]
        else:
            raise ValueError(
                f"Did not find {self.REFUEL_API_ENV}, please add an environment variable"
                f" `{self.REFUEL_API_ENV}` which contains it"
            )

    @retry(
        reraise=True,
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        retry=retry_if_not_exception_type(UnretryableError),
    )
    def _label_with_retry(self, prompt: str) -> Tuple[requests.Response, float]:
        payload = {
            "input": prompt,
            "params": {**self.model_params},
            "confidence": self.config.confidence(),
        }
        headers = {"refuel_api_key": self.REFUEL_API_KEY}
        start_time = time()
        response = requests.post(
            self.BASE_API,
            json=payload,
            headers=headers,
            timeout=(self.DEFAULT_CONNECT_TIMEOUT, self.DEFAULT_READ_TIMEOUT),
        )
        end_time = time()
        # raise Exception if status != 200
        if response.status_code != 200:
            if response.status_code in UNRETRYABLE_ERROR_CODES:
                # This is a bad request, and we should not retry
                raise UnretryableError(
                    f"NonRetryable Error: Received status code {response.status_code} from Refuel API. Response: {response.text}"
                )

            logger.warning(
                f"Received status code {response.status_code} from Refuel API. Response: {response.text}"
            )
            response.raise_for_status()
        return response, end_time - start_time

    @retry(
        reraise=True,
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        retry=retry_if_not_exception_type(UnretryableError),
    )
    async def _alabel_with_retry(self, prompt: str) -> Tuple[requests.Response, float]:
        payload = {
            "input": prompt,
            "params": {**self.model_params},
            "confidence": self.config.confidence(),
        }
        headers = {"refuel_api_key": self.REFUEL_API_KEY}
        async with httpx.AsyncClient() as client:
            timeout = httpx.Timeout(
                self.DEFAULT_CONNECT_TIMEOUT, read=self.DEFAULT_READ_TIMEOUT
            )
            start_time = time()
            response = await client.post(
                self.BASE_API, json=payload, headers=headers, timeout=timeout
            )
            end_time = time()
            # raise Exception if status != 200
            if response.status_code != 200:
                if response.status_code in UNRETRYABLE_ERROR_CODES:
                    # This is a bad request, and we should not retry
                    raise UnretryableError(
                        f"NonRetryable Error: Received status code {response.status_code} from Refuel API. Response: {response.text}"
                    )

                logger.warning(
                    f"Received status code {response.status_code} from Refuel API. Response: {response.text}"
                )
                response.raise_for_status()
            return response, end_time - start_time

    def _label(self, prompts: List[str]) -> RefuelLLMResult:
        generations = []
        errors = []
        latencies = []
        for prompt in prompts:
            try:
                response, latency = self._label_with_retry(prompt)
                response = json.loads(response.json())
                generations.append(
                    [
                        Generation(
                            text=response["generated_text"],
                            generation_info=(
                                {"logprobs": {"top_logprobs": response["logprobs"]}}
                                if self.config.confidence()
                                else None
                            ),
                        )
                    ]
                )
                errors.append(None)
                latencies.append(latency)
            except Exception as e:
                # This signifies an error in generating the response using RefuelLLM
                logger.error(
                    f"Unable to generate prediction: {e}",
                )
                generations.append([Generation(text="")])
                errors.append(
                    LabelingError(
                        error_type=ErrorType.LLM_PROVIDER_ERROR, error_message=str(e)
                    )
                )
                latencies.append(0)
        return RefuelLLMResult(
            generations=generations, errors=errors, latencies=latencies
        )

    async def _alabel(self, prompts: List[str]) -> RefuelLLMResult:
        generations = []
        errors = []
        latencies = []
        try:
            tasks = [self._alabel_with_retry(prompt) for prompt in prompts]
            responses = await asyncio.gather(*tasks)
            for response, latency in responses:
                response = json.loads(response.json())
                generations.append(
                    [
                        Generation(
                            text=response["generated_text"],
                            generation_info=(
                                {"logprobs": {"top_logprobs": response["logprobs"]}}
                                if self.config.confidence()
                                else None
                            ),
                        )
                    ]
                )
                errors.append(None)
                latencies.append(latency)
        except Exception as e:
            # This signifies an error in generating the response using RefuelLLM
            logger.error(
                f"Unable to generate prediction: {e}",
            )
            generations.append([Generation(text="")])
            errors.append(
                LabelingError(
                    error_type=ErrorType.LLM_PROVIDER_ERROR, error_message=str(e)
                )
            )
            latencies.append(0)
        return RefuelLLMResult(
            generations=generations, errors=errors, latencies=latencies
        )

    def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
        return 0

    def returns_token_probs(self) -> bool:
        return True

    def get_num_tokens(self, prompt: str) -> int:
        return len(self.tokenizer.encode(prompt))
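
RefuelLLM reads its API key from the REFUEL_API_KEY environment variable at construction time and posts each prompt to https://llm.refuel.ai/models/{model_name}/generate. A minimal setup sketch; the key value is a placeholder, and exporting the variable in the shell works just as well.

import os

# Placeholder value: the constructor raises ValueError if REFUEL_API_KEY is unset.
os.environ["REFUEL_API_KEY"] = "<your-refuel-api-key>"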