
Models

BaseModel

Bases: ABC

Source code in src/autolabel/models/base.py
class BaseModel(ABC):
    TTL_MS = 60 * 60 * 24 * 365 * 3 * 1000  # 3 years

    def __init__(self, config: AutolabelConfig, cache: BaseCache) -> None:
        self.config = config
        self.cache = cache
        self.model_params = config.model_params()
        # Specific classes that implement this interface should run initialization steps here
        # E.g. initializing the LLM model with required parameters from ModelConfig

    def label(self, prompts: List[str]) -> RefuelLLMResult:
        """Label a list of prompts."""
        existing_prompts = {}
        missing_prompt_idxs = list(range(len(prompts)))
        missing_prompts = prompts
        costs = []
        errors = [None for i in range(len(prompts))]
        if self.cache:
            (
                existing_prompts,
                missing_prompt_idxs,
                missing_prompts,
            ) = self.get_cached_prompts(prompts)

        # label missing prompts
        if len(missing_prompts) > 0:
            new_results = self._label(missing_prompts)
            for ind, prompt in enumerate(missing_prompts):
                costs.append(
                    self.get_cost(prompt, label=new_results.generations[ind][0].text)
                )

            # Set the existing prompts to the new results
            for i, result, error in zip(
                missing_prompt_idxs, new_results.generations, new_results.errors
            ):
                existing_prompts[i] = result
                errors[i] = error

            if self.cache:
                self.update_cache(missing_prompt_idxs, new_results, prompts)

        generations = [existing_prompts[i] for i in range(len(prompts))]
        return RefuelLLMResult(generations=generations, costs=costs, errors=errors)

    def _label_individually(self, prompts: List[str]) -> RefuelLLMResult:
        """Label each prompt individually. Should be used only after trying as a batch first.

        Args:
            prompts (List[str]): List of prompts to label

        Returns:
            RefuelLLMResult: result object containing the generations and the
                list of LabelingErrors encountered while labeling
        """
        generations = []
        errors = []
        for prompt in prompts:
            try:
                response = self.llm.generate([prompt])
                generations.append(response.generations[0])
                errors.append(None)
            except Exception as e:
                print(f"Error generating from LLM: {e}")
                generations.append([Generation(text="")])
                errors.append(
                    LabelingError(
                        error_type=ErrorType.LLM_PROVIDER_ERROR, error_message=str(e)
                    )
                )

        return RefuelLLMResult(generations=generations, errors=errors)

    @abstractmethod
    def _label(self, prompts: List[str]) -> RefuelLLMResult:
        # TODO: change return type to do parsing in the Model class
        pass

    @abstractmethod
    def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
        pass

    def get_cached_prompts(self, prompts: List[str]) -> Tuple[Dict, List, List]:
        """Get prompts that are already cached."""
        model_params_string = str(
            sorted([(k, v) for k, v in self.model_params.items()])
        )
        missing_prompts = []
        missing_prompt_idxs = []
        existing_prompts = {}
        for i, prompt in enumerate(prompts):
            cache_entry = GenerationCacheEntry(
                prompt=prompt,
                model_name=self.model_name,
                model_params=model_params_string,
            )
            cache_val = self.cache.lookup(cache_entry)
            if cache_val:
                existing_prompts[i] = cache_val
            else:
                missing_prompts.append(prompt)
                missing_prompt_idxs.append(i)
        return (
            existing_prompts,
            missing_prompt_idxs,
            missing_prompts,
        )

    def update_cache(self, missing_prompt_idxs, new_results, prompts):
        """Update the cache with new results."""
        model_params_string = str(
            sorted([(k, v) for k, v in self.model_params.items()])
        )

        for i, result, error in zip(
            missing_prompt_idxs, new_results.generations, new_results.errors
        ):
            # If there was an error, don't cache the result
            if error is not None:
                continue

            cache_entry = GenerationCacheEntry(
                prompt=prompts[i],
                model_name=self.model_name,
                model_params=model_params_string,
                generations=result,
                ttl_ms=self.TTL_MS,
            )
            self.cache.update(cache_entry)

    @abstractmethod
    def returns_token_probs(self) -> bool:
        """Whether the LLM supports returning logprobs of generated tokens

        Returns:
            bool: whether the LLM supports returning logprobs of generated tokens
        """
        pass
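
The label() flow and the caching helpers above are inherited as-is, so a concrete model only needs to provide _label, get_cost, and returns_token_probs, plus a model_name for the cache key. A minimal subclass sketch (the class name and its echo behavior are illustrative, not part of autolabel; imports are omitted as in the excerpt above):

class EchoModel(BaseModel):
    """Toy sketch: returns each prompt back as its own label."""

    def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
        super().__init__(config, cache)
        # get_cached_prompts/update_cache build cache keys from self.model_name
        self.model_name = config.model_name() or "echo-model"

    def _label(self, prompts: List[str]) -> RefuelLLMResult:
        generations = [[Generation(text=prompt)] for prompt in prompts]
        return RefuelLLMResult(generations=generations, errors=[None] * len(prompts))

    def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
        return 0.0  # runs locally, so no per-token cost

    def returns_token_probs(self) -> bool:
        return False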

get_cached_prompts(prompts)

Get prompts that are already cached.

Source code in src/autolabel/models/base.py
def get_cached_prompts(self, prompts: List[str]) -> Tuple[Dict, List, List]:
    """Get prompts that are already cached."""
    model_params_string = str(
        sorted([(k, v) for k, v in self.model_params.items()])
    )
    missing_prompts = []
    missing_prompt_idxs = []
    existing_prompts = {}
    for i, prompt in enumerate(prompts):
        cache_entry = GenerationCacheEntry(
            prompt=prompt,
            model_name=self.model_name,
            model_params=model_params_string,
        )
        cache_val = self.cache.lookup(cache_entry)
        if cache_val:
            existing_prompts[i] = cache_val
        else:
            missing_prompts.append(prompt)
            missing_prompt_idxs.append(i)
    return (
        existing_prompts,
        missing_prompt_idxs,
        missing_prompts,
    )
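
For illustration, given any concrete model instance, if only the first of three prompts has a cache hit the returned triple splits like this (the cached value shown is a placeholder for whatever generations were stored):

existing_prompts, missing_prompt_idxs, missing_prompts = model.get_cached_prompts(
    ["p0", "p1", "p2"]
)
# existing_prompts    -> {0: <cached generations for "p0">}
# missing_prompt_idxs -> [1, 2]
# missing_prompts     -> ["p1", "p2"]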

label(prompts)

Label a list of prompts.

Source code in src/autolabel/models/base.py
def label(self, prompts: List[str]) -> RefuelLLMResult:
    """Label a list of prompts."""
    existing_prompts = {}
    missing_prompt_idxs = list(range(len(prompts)))
    missing_prompts = prompts
    costs = []
    errors = [None for i in range(len(prompts))]
    if self.cache:
        (
            existing_prompts,
            missing_prompt_idxs,
            missing_prompts,
        ) = self.get_cached_prompts(prompts)

    # label missing prompts
    if len(missing_prompts) > 0:
        new_results = self._label(missing_prompts)
        for ind, prompt in enumerate(missing_prompts):
            costs.append(
                self.get_cost(prompt, label=new_results.generations[ind][0].text)
            )

        # Set the existing prompts to the new results
        for i, result, error in zip(
            missing_prompt_idxs, new_results.generations, new_results.errors
        ):
            existing_prompts[i] = result
            errors[i] = error

        if self.cache:
            self.update_cache(missing_prompt_idxs, new_results, prompts)

    generations = [existing_prompts[i] for i in range(len(prompts))]
    return RefuelLLMResult(generations=generations, costs=costs, errors=errors)
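
A usage sketch, assuming model was built via ModelFactory.from_config (see below) and prompts is a list of fully rendered prompt strings:

result = model.label(prompts)
for generation, error in zip(result.generations, result.errors):
    if error is not None:
        print(f"Labeling failed: {error.error_message}")
    else:
        print(generation[0].text)

# Costs are only computed for prompts that were not served from the cache on
# this call, so result.costs can be shorter than result.generations.
print(f"Estimated cost of this call: {sum(result.costs)}")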

returns_token_probs() abstractmethod

Whether the LLM supports returning logprobs of generated tokens

Returns:

    bool: whether the LLM supports returning logprobs of generated tokens

Source code in src/autolabel/models/base.py
@abstractmethod
def returns_token_probs(self) -> bool:
    """Whether the LLM supports returning logprobs of generated tokens

    Returns:
        bool: whether the LLM supports returning logprobs of generated tokens
    """
    pass

update_cache(missing_prompt_idxs, new_results, prompts)

Update the cache with new results.

Source code in src/autolabel/models/base.py
def update_cache(self, missing_prompt_idxs, new_results, prompts):
    """Update the cache with new results."""
    model_params_string = str(
        sorted([(k, v) for k, v in self.model_params.items()])
    )

    for i, result, error in zip(
        missing_prompt_idxs, new_results.generations, new_results.errors
    ):
        # If there was an error, don't cache the result
        if error is not None:
            continue

        cache_entry = GenerationCacheEntry(
            prompt=prompts[i],
            model_name=self.model_name,
            model_params=model_params_string,
            generations=result,
            ttl_ms=self.TTL_MS,
        )
        self.cache.update(cache_entry)

ModelFactory

The ModelFactory class is used to create a BaseModel object from the given AutolabelConfig configuration.

Source code in src/autolabel/models/__init__.py
class ModelFactory:
    """The ModelFactory class is used to create a BaseModel object from the given AutoLabelConfig configuration."""

    @staticmethod
    def from_config(config: AutolabelConfig, cache: BaseCache = None) -> BaseModel:
        """
        Returns a BaseModel object configured with the settings found in the provided AutolabelConfig.
        Args:
            config: AutolabelConfig object containing project settings
            cache: cache allows for saving results in between labeling runs for future use
        Returns:
            model: a fully configured BaseModel object
        """
        provider = ModelProvider(config.provider())
        try:
            model_cls = MODEL_REGISTRY[provider]
            model_obj = model_cls(config=config, cache=cache)
            # The below ensures that custom models created/registered by users
            # inherit from BaseModel.
            assert isinstance(
                model_obj, BaseModel
            ), f"{model_obj} should inherit from autolabel.models.BaseModel"
        except KeyError as e:
            # We should never get here as the config should have already
            # been validated by the pydantic model.
            logger.error(
                f"{config.provider()} is not in the list of supported providers: \
                {list(ModelProvider.__members__.keys())}"
            )
            raise e

        return model_obj

from_config(config, cache=None) staticmethod

Returns a BaseModel object configured with the settings found in the provided AutolabelConfig.

Parameters:

    config (AutolabelConfig): AutolabelConfig object containing project settings. Required.
    cache (BaseCache): cache allows for saving results in between labeling runs for future use. Defaults to None.

Returns:

    model (BaseModel): a fully configured BaseModel object

Source code in src/autolabel/models/__init__.py
@staticmethod
def from_config(config: AutolabelConfig, cache: BaseCache = None) -> BaseModel:
    """
    Returns a BaseModel object configured with the settings found in the provided AutolabelConfig.
    Args:
        config: AutolabelConfig object containing project settings
        cache: cache allows for saving results in between labeling runs for future use
    Returns:
        model: a fully configured BaseModel object
    """
    provider = ModelProvider(config.provider())
    try:
        model_cls = MODEL_REGISTRY[provider]
        model_obj = model_cls(config=config, cache=cache)
        # The below ensures that custom models created/registered by users
        # inherit from BaseModel.
        assert isinstance(
            model_obj, BaseModel
        ), f"{model_obj} should inherit from autolabel.models.BaseModel"
    except KeyError as e:
        # We should never get here as the config should have already
        # been validated by the pydantic model.
        logger.error(
            f"{config.provider()} is not in the list of supported providers: \
            {list(ModelProvider.__members__.keys())}"
        )
        raise e

    return model_obj
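
A construction sketch. The config file name is illustrative and the AutolabelConfig import path is an assumption; ModelFactory itself lives in autolabel.models, as shown above. Passing cache=None simply disables prompt caching:

from autolabel.configs import AutolabelConfig  # assumed import path
from autolabel.models import ModelFactory

config = AutolabelConfig("example_config.json")  # illustrative config file
model = ModelFactory.from_config(config, cache=None)
result = model.label(["A fully rendered prompt..."])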

register_model(name, model_cls)

Register Model class

Source code in src/autolabel/models/__init__.py
def register_model(name, model_cls):
    """Register Model class"""
    MODEL_REGISTRY[name] = model_cls
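
A registration sketch using the EchoModel subclass from the BaseModel section above. The registry key must match what from_config later looks up, i.e. a ModelProvider member, so deriving it from the config keeps the two in sync (the ModelProvider import is omitted because its module is not shown on this page):

from autolabel.models import ModelFactory, register_model

# config is the AutolabelConfig built in the from_config sketch above
provider = ModelProvider(config.provider())  # same key that from_config looks up
register_model(provider, EchoModel)
model = ModelFactory.from_config(config, cache=None)  # now constructs an EchoModel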

AnthropicLLM

Bases: BaseModel

Source code in src/autolabel/models/anthropic.py
class AnthropicLLM(BaseModel):
    DEFAULT_MODEL = "claude-instant-v1"
    DEFAULT_PARAMS = {
        "max_tokens_to_sample": 1000,
        "temperature": 0.0,
    }

    # Reference: https://cdn2.assets-servd.host/anthropic-website/production/images/apr-pricing-tokens.pdf
    COST_PER_PROMPT_TOKEN = {
        # $11.02 per million tokens
        "claude-v1": (11.02 / 1000000),
        "claude-instant-v1": (1.63 / 1000000),
    }
    COST_PER_COMPLETION_TOKEN = {
        # $32.68 per million tokens
        "claude-v1": (32.68 / 1000000),
        "claude-instant-v1": (5.51 / 1000000),
    }

    def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
        super().__init__(config, cache)

        try:
            from langchain.chat_models import ChatAnthropic
            from anthropic import tokenizer
        except ImportError:
            raise ImportError(
                "anthropic is required to use the anthropic LLM. Please install it with the following command: pip install 'refuel-autolabel[anthropic]'"
            )

        # populate model name
        self.model_name = config.model_name() or self.DEFAULT_MODEL
        # populate model params
        model_params = config.model_params()
        self.model_params = {**self.DEFAULT_PARAMS, **model_params}
        # initialize LLM
        self.llm = ChatAnthropic(model=self.model_name, **self.model_params)

        self.tokenizer = tokenizer

    def _label(self, prompts: List[str]) -> RefuelLLMResult:
        prompts = [[HumanMessage(content=prompt)] for prompt in prompts]
        try:
            result = self.llm.generate(prompts)
            return RefuelLLMResult(
                generations=result.generations, errors=[None] * len(result.generations)
            )
        except Exception as e:
            return self._label_individually(prompts)

    def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
        num_prompt_toks = self.tokenizer.count_tokens(prompt)
        if label:
            num_label_toks = self.tokenizer.count_tokens(label)
        else:
            # get an upper bound
            num_label_toks = self.model_params["max_tokens_to_sample"]

        cost_per_prompt_token = self.COST_PER_PROMPT_TOKEN[self.model_name]
        cost_per_completion_token = self.COST_PER_COMPLETION_TOKEN[self.model_name]
        return (num_prompt_toks * cost_per_prompt_token) + (
            num_label_toks * cost_per_completion_token
        )

    def returns_token_probs(self) -> bool:
        return False
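
A config sketch selecting this model. The nested key names mirror the accessors used above (config.provider(), config.model_name(), config.model_params()) but should be treated as an assumption about the AutolabelConfig schema; task-level keys are omitted:

config = AutolabelConfig({
    # ... task configuration omitted ...
    "model": {
        "provider": "anthropic",
        "name": "claude-instant-v1",
        "params": {"max_tokens_to_sample": 1000, "temperature": 0.0},
    },
})
model = ModelFactory.from_config(config)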

HFPipelineLLM

Bases: BaseModel

Source code in src/autolabel/models/hf_pipeline.py
class HFPipelineLLM(BaseModel):
    DEFAULT_MODEL = "google/flan-t5-xxl"
    DEFAULT_PARAMS = {"temperature": 0.0, "quantize": 8}

    def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
        super().__init__(config, cache)

        from langchain.llms import HuggingFacePipeline

        try:
            from transformers import (
                AutoConfig,
                AutoModelForSeq2SeqLM,
                AutoModelForCausalLM,
                AutoTokenizer,
                pipeline,
            )
            from transformers.models.auto.modeling_auto import (
                MODEL_FOR_CAUSAL_LM_MAPPING,
                MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
            )
        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please it install it with `pip install transformers`."
            )

        try:
            import torch
        except ImportError:
            raise ValueError(
                "Could not import torch package. "
                "Please it install it with `pip install torch`."
            )
        # populate model name
        self.model_name = config.model_name() or self.DEFAULT_MODEL

        # populate model params
        model_params = config.model_params()
        self.model_params = {**self.DEFAULT_PARAMS, **model_params}
        if config.logit_bias() != 0:
            self.model_params = {
                **self._generate_sequence_bias(),
                **self.model_params,
            }

        # initialize HF pipeline
        tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, use_fast=False, add_prefix_space=True
        )
        quantize_bits = self.model_params["quantize"]
        model_config = AutoConfig.from_pretrained(self.model_name)
        if isinstance(model_config, tuple(MODEL_FOR_CAUSAL_LM_MAPPING)):
            AutoModel = AutoModelForCausalLM
        elif isinstance(model_config, tuple(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING)):
            AutoModel = AutoModelForSeq2SeqLM
        else:
            raise ValueError(
                "model_name is neither a causal LM nor a seq2seq LM. Please check the model_name."
            )

        if not torch.cuda.is_available():
            model = AutoModel.from_pretrained(self.model_name)
        elif quantize_bits == 8:
            model = AutoModel.from_pretrained(
                self.model_name, load_in_8bit=True, device_map="auto"
            )
        elif quantize_bits == "16":
            model = AutoModel.from_pretrained(
                self.model_name, torch_dtype=torch.float16, device_map="auto"
            )
        else:
            model = AutoModel.from_pretrained(self.model_name, device_map="auto")

        model_kwargs = dict(self.model_params)  # make a copy of the model params
        model_kwargs.pop("quantize", None)  # remove quantize from the model params
        pipe = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            **model_kwargs,
        )

        # initialize LLM
        self.llm = HuggingFacePipeline(pipeline=pipe, model_kwargs=model_kwargs)

    def _generate_sequence_bias(self) -> Dict:
        """Generates sequence bias dict to add to the config for the labels specified

        Returns:
            Dict: sequence bias, max new tokens, and num beams
        """
        if len(self.config.labels_list()) == 0:
            logger.warning(
                "No labels specified in the config. Skipping logit bias generation."
            )
            return {}
        try:
            from transformers import AutoTokenizer
        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please it install it with `pip install transformers`."
            )
        tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, use_fast=False, add_prefix_space=True
        )
        sequence_bias = {tuple([tokenizer.eos_token_id]): self.config.logit_bias()}
        max_new_tokens = 0
        for label in self.config.labels_list():
            tokens = tuple(tokenizer([label], add_special_tokens=False).input_ids[0])
            for token in tokens:
                sequence_bias[tuple([token])] = self.config.logit_bias()
            max_new_tokens = max(max_new_tokens, len(tokens))

        return {
            "sequence_bias": sequence_bias,
            "max_new_tokens": max_new_tokens,
        }

    def _label(self, prompts: List[str]) -> RefuelLLMResult:
        try:
            result = self.llm.generate(prompts)
            return RefuelLLMResult(
                generations=result.generations, errors=[None] * len(result.generations)
            )
        except Exception as e:
            return self._label_individually(prompts)

    def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
        # Model inference for this model is being run locally
        # Revisit this in the future when we support HF inference endpoints
        return 0.0

    def returns_token_probs(self) -> bool:
        return False
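
A config sketch for a local HuggingFace pipeline model, with the key names and provider string again assumed rather than taken from this page. Two details visible in the source above are worth noting: inference runs locally, so get_cost always returns 0.0, and 16-bit loading is selected by the string "16" while 8-bit quantization uses the integer 8:

config = AutolabelConfig({
    # ... task configuration omitted ...
    "model": {
        "provider": "huggingface_pipeline",  # provider string is an assumption
        "name": "google/flan-t5-xxl",
        "params": {"quantize": 8},  # or "16" for float16, per the source above
    },
})
model = ModelFactory.from_config(config)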

OpenAILLM

Bases: BaseModel

Source code in src/autolabel/models/openai.py
class OpenAILLM(BaseModel):
    CHAT_ENGINE_MODELS = [
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-0301",
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k",
        "gpt-3.5-turbo-16k-0613",
        "gpt-4",
        "gpt-4-0314",
        "gpt-4-32k-0314",
        "gpt-4-0613",
        "gpt-4-32k",
        "gpt-4-32k-0613",
    ]
    MODELS_WITH_TOKEN_PROBS = ["text-curie-001", "text-davinci-003"]

    # Default parameters for OpenAILLM
    DEFAULT_MODEL = "gpt-3.5-turbo"
    DEFAULT_PARAMS_COMPLETION_ENGINE = {
        "max_tokens": 1000,
        "temperature": 0.0,
        "model_kwargs": {"logprobs": 1},
    }
    DEFAULT_PARAMS_CHAT_ENGINE = {
        "max_tokens": 1000,
        "temperature": 0.0,
    }

    # Reference: https://openai.com/pricing
    COST_PER_PROMPT_TOKEN = {
        "text-davinci-003": 0.02 / 1000,
        "text-curie-001": 0.002 / 1000,
        "gpt-3.5-turbo": 0.0015 / 1000,
        "gpt-3.5-turbo-0301": 0.0015 / 1000,
        "gpt-3.5-turbo-0613": 0.0015 / 1000,
        "gpt-3.5-turbo-16k": 0.003 / 1000,
        "gpt-3.5-turbo-16k-0613": 0.003 / 1000,
        "gpt-4": 0.03 / 1000,
        "gpt-4-0613": 0.03 / 1000,
        "gpt-4-32k": 0.06 / 1000,
        "gpt-4-32k-0613": 0.06 / 1000,
        "gpt-4-0314": 0.03 / 1000,
        "gpt-4-32k-0314": 0.06 / 1000,
    }
    COST_PER_COMPLETION_TOKEN = {
        "text-davinci-003": 0.02 / 1000,
        "text-curie-001": 0.002 / 1000,
        "gpt-3.5-turbo": 0.002 / 1000,
        "gpt-3.5-turbo-0301": 0.002 / 1000,
        "gpt-3.5-turbo-0613": 0.002 / 1000,
        "gpt-3.5-turbo-16k": 0.004 / 1000,
        "gpt-3.5-turbo-16k-0613": 0.004 / 1000,
        "gpt-4": 0.06 / 1000,
        "gpt-4-0613": 0.06 / 1000,
        "gpt-4-32k": 0.12 / 1000,
        "gpt-4-32k-0613": 0.12 / 1000,
        "gpt-4-0314": 0.06 / 1000,
        "gpt-4-32k-0314": 0.12 / 1000,
    }

    @cached_property
    def _engine(self) -> str:
        if self.model_name is not None and self.model_name in self.CHAT_ENGINE_MODELS:
            return "chat"
        else:
            return "completion"

    def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
        super().__init__(config, cache)
        try:
            from langchain.chat_models import ChatOpenAI
            from langchain.llms import OpenAI
            import tiktoken
        except ImportError:
            raise ImportError(
                "anthropic is required to use the anthropic LLM. Please install it with the following command: pip install 'refuel-autolabel[openai]'"
            )

        # populate model name
        self.model_name = config.model_name() or self.DEFAULT_MODEL

        if os.getenv("OPENAI_API_KEY") is None:
            raise ValueError("OPENAI_API_KEY environment variable not set")

        # populate model params and initialize the LLM
        model_params = config.model_params()
        if config.logit_bias() != 0:
            model_params = {
                **self._generate_logit_bias(),
                **model_params,
            }

        if self._engine == "chat":
            self.model_params = {**self.DEFAULT_PARAMS_CHAT_ENGINE, **model_params}
            self.llm = ChatOpenAI(model_name=self.model_name, **self.model_params)
        else:
            self.model_params = {
                **self.DEFAULT_PARAMS_COMPLETION_ENGINE,
                **model_params,
            }
            self.llm = OpenAI(model_name=self.model_name, **self.model_params)

        self.tiktoken = tiktoken

    def _generate_logit_bias(self) -> Dict:
        """Generates logit bias for the labels specified in the config

        Returns:
            Dict: logit bias and max tokens
        """
        if len(self.config.labels_list()) == 0:
            logger.warning(
                "No labels specified in the config. Skipping logit bias generation."
            )
            return {}
        encoding = self.tiktoken.encoding_for_model(self.model_name)
        logit_bias = {}
        max_tokens = 0
        for label in self.config.labels_list():
            if label not in logit_bias:
                tokens = encoding.encode(label)
                for token in tokens:
                    logit_bias[token] = self.config.logit_bias()
                max_tokens = max(max_tokens, len(tokens))

        return {"logit_bias": logit_bias, "max_tokens": max_tokens}

    def _label(self, prompts: List[str]) -> RefuelLLMResult:
        if self._engine == "chat":
            # Need to convert list[prompts] -> list[messages]
            # Currently the entire prompt is stuck into the "human message"
            # We might consider breaking this up into human vs system message in future
            prompts = [[HumanMessage(content=prompt)] for prompt in prompts]
        try:
            result = self.llm.generate(prompts)
            return RefuelLLMResult(
                generations=result.generations, errors=[None] * len(result.generations)
            )
        except Exception as e:
            return self._label_individually(prompts)

    def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
        encoding = self.tiktoken.encoding_for_model(self.model_name)
        num_prompt_toks = len(encoding.encode(prompt))
        if label:
            num_label_toks = len(encoding.encode(label))
        else:
            # get an upper bound
            num_label_toks = self.model_params["max_tokens"]

        cost_per_prompt_token = self.COST_PER_PROMPT_TOKEN[self.model_name]
        cost_per_completion_token = self.COST_PER_COMPLETION_TOKEN[self.model_name]
        return (num_prompt_toks * cost_per_prompt_token) + (
            num_label_toks * cost_per_completion_token
        )

    def returns_token_probs(self) -> bool:
        return (
            self.model_name is not None
            and self.model_name in self.MODELS_WITH_TOKEN_PROBS
        )
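
A config sketch for an OpenAI chat model (key names assumed, as above). OPENAI_API_KEY must be set in the environment or the constructor raises; the engine is chosen from CHAT_ENGINE_MODELS, and token logprobs are only reported for the completion models in MODELS_WITH_TOKEN_PROBS:

# OPENAI_API_KEY must already be set in the environment, per the check above
config = AutolabelConfig({
    # ... task configuration omitted ...
    "model": {
        "provider": "openai",
        "name": "gpt-3.5-turbo",  # listed in CHAT_ENGINE_MODELS, so the chat engine is used
        "params": {"max_tokens": 1000, "temperature": 0.0},
    },
})
model = ModelFactory.from_config(config)
print(model.returns_token_probs())  # False: gpt-3.5-turbo is not in MODELS_WITH_TOKEN_PROBS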

PaLMLLM

Bases: BaseModel

Source code in src/autolabel/models/palm.py
class PaLMLLM(BaseModel):
    SEP_REPLACEMENT_TOKEN = "@@"
    CHAT_ENGINE_MODELS = ["chat-bison@001"]

    DEFAULT_MODEL = "text-bison@001"
    # Reference: https://developers.generativeai.google/guide/concepts#model_parameters for "A token is approximately 4 characters"
    DEFAULT_PARAMS = {"temperature": 0, "max_output_tokens": 1000}

    # Reference: https://cloud.google.com/vertex-ai/pricing
    COST_PER_CHARACTER = {
        "text-bison@001": 0.001 / 1000,
        "chat-bison@001": 0.0005 / 1000,
        "textembedding-gecko@001": 0.0001 / 1000,
    }

    @cached_property
    def _engine(self) -> str:
        if self.model_name is not None and self.model_name in self.CHAT_ENGINE_MODELS:
            return "chat"
        else:
            return "completion"

    def __init__(
        self,
        config: AutolabelConfig,
        cache: BaseCache = None,
    ) -> None:
        super().__init__(config, cache)
        try:
            from langchain.chat_models import ChatVertexAI
            from langchain.llms import VertexAI
        except ImportError:
            raise ImportError(
                "palm is required to use the Palm LLM. Please install it with the following command: pip install 'refuel-autolabel[google]'"
            )

        # populate model name
        self.model_name = config.model_name() or self.DEFAULT_MODEL

        # populate model params and initialize the LLM
        model_params = config.model_params()
        self.model_params = {
            **self.DEFAULT_PARAMS,
            **model_params,
        }
        if self._engine == "chat":
            self.llm = ChatVertexAI(model_name=self.model_name, **self.model_params)
        else:
            self.llm = VertexAI(model_name=self.model_name, **self.model_params)

    @retry(
        reraise=True,
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
    def _label_with_retry(self, prompts: List[str]) -> LLMResult:
        return self.llm.generate(prompts)

    def _label_individually(self, prompts: List[str]) -> LLMResult:
        """Label each prompt individually. Should be used only after trying as a batch first.

        Args:
            prompts (List[str]): List of prompts to label

        Returns:
            LLMResult: LLMResult object with generations
        """
        generations = []
        for i, prompt in enumerate(prompts):
            try:
                response = self._label_with_retry([prompt])
                for generation in response.generations[0]:
                    generation.text = generation.text.replace(
                        self.SEP_REPLACEMENT_TOKEN, "\n"
                    )
                generations.append(response.generations[0])
            except Exception as e:
                print(f"Error generating from LLM: {e}, returning empty generation")
                generations.append([Generation(text="")])

        return LLMResult(generations=generations)

    def _label(self, prompts: List[str]) -> RefuelLLMResult:
        for prompt in prompts:
            if self.SEP_REPLACEMENT_TOKEN in prompt:
                logger.warning(
                    f"""Current prompt contains {self.SEP_REPLACEMENT_TOKEN} 
                                which is currently used as a separator token by refuel
                                llm. It is highly recommended to avoid having any
                                occurences of this substring in the prompt.
                            """
                )
        prompts = [
            prompt.replace("\n", self.SEP_REPLACEMENT_TOKEN) for prompt in prompts
        ]
        if self._engine == "chat":
            # Need to convert list[prompts] -> list[messages]
            # Currently the entire prompt is stuck into the "human message"
            # We might consider breaking this up into human vs system message in future
            prompts = [[HumanMessage(content=prompt)] for prompt in prompts]

        try:
            result = self._label_with_retry(prompts)
            for generations in result.generations:
                for generation in generations:
                    generation.text = generation.text.replace(
                        self.SEP_REPLACEMENT_TOKEN, "\n"
                    )
            return RefuelLLMResult(
                generations=result.generations, errors=[None] * len(result.generations)
            )
        except Exception as e:
            return self._label_individually(prompts)

    def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
        if self.model_name is None:
            return 0.0
        cost_per_char = self.COST_PER_CHARACTER.get(self.model_name, 0.0)
        return cost_per_char * len(prompt) + cost_per_char * (
            len(label) if label else 4 * self.model_params["max_output_tokens"]
        )

    def returns_token_probs(self) -> bool:
        return False
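
Because Vertex AI pricing is per character, get_cost above is character-based rather than token-based. A worked sketch for text-bison@001 at $0.001 per 1,000 characters (palm_config is a hypothetical AutolabelConfig selecting that model):

model = ModelFactory.from_config(palm_config)
prompt = "x" * 2000  # 2,000 characters
label = "y" * 100    # 100 characters
cost = model.get_cost(prompt, label=label)
# (2000 + 100) characters * $0.001 / 1000 characters = $0.0021
print(round(cost, 6))  # 0.0021
# With no label supplied, the completion side is instead bounded above by
# 4 * max_output_tokens characters.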

RefuelLLM

Bases: BaseModel

Source code in src/autolabel/models/refuel.py
class RefuelLLM(BaseModel):
    DEFAULT_PARAMS = {
        "max_new_tokens": 128,
        "temperature": 0.0,
    }

    def __init__(
        self,
        config: AutolabelConfig,
        cache: BaseCache = None,
    ) -> None:
        super().__init__(config, cache)

        # populate model name
        # This is unused today, but in the future could
        # be used to decide which refuel model is queried
        self.model_name = config.model_name()
        model_params = config.model_params()
        self.model_params = {**self.DEFAULT_PARAMS, **model_params}

        # initialize runtime
        self.BASE_API = "https://refuel-llm.refuel.ai/"
        self.SEP_REPLACEMENT_TOKEN = "@@"
        self.REFUEL_API_ENV = "REFUEL_API_KEY"
        if self.REFUEL_API_ENV in os.environ and os.environ[self.REFUEL_API_ENV]:
            self.REFUEL_API_KEY = os.environ[self.REFUEL_API_ENV]
        else:
            raise ValueError(
                f"Did not find {self.REFUEL_API_ENV}, please add an environment variable"
                f" `{self.REFUEL_API_ENV}` which contains it"
            )

    @retry(
        reraise=True,
        stop=stop_after_attempt(5),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
    def _label_with_retry(self, prompt: str) -> requests.Response:
        payload = {
            "data": {"model_input": prompt, "model_params": {**self.model_params}},
            "task": "generate",
        }
        headers = {"refuel_api_key": self.REFUEL_API_KEY}
        response = requests.post(self.BASE_API, json=payload, headers=headers)
        # raise Exception if status != 200
        response.raise_for_status()
        return response

    def _label(self, prompts: List[str]) -> RefuelLLMResult:
        generations = []
        errors = []
        for prompt in prompts:
            try:
                if self.SEP_REPLACEMENT_TOKEN in prompt:
                    logger.warning(
                        f"""Current prompt contains {self.SEP_REPLACEMENT_TOKEN} 
                            which is currently used as a separator token by refuel
                            llm. It is highly recommended to avoid having any
                            occurences of this substring in the prompt.
                        """
                    )
                separated_prompt = prompt.replace("\n", self.SEP_REPLACEMENT_TOKEN)
                response = self._label_with_retry(separated_prompt)
                response = json.loads(response.json()["body"]).replace(
                    self.SEP_REPLACEMENT_TOKEN, "\n"
                )
                generations.append([Generation(text=response)])
                errors.append(None)
            except Exception as e:
                # This signifies an error in generating the response using RefuelLLM
                logger.error(
                    f"Unable to generate prediction: {e}",
                )
                generations.append([Generation(text="")])
                errors.append(
                    LabelingError(error_type=ErrorType.LLM_PROVIDER_ERROR, error_message=str(e))
                )
        return RefuelLLMResult(generations=generations, errors=errors)

    def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
        return 0

    def returns_token_probs(self) -> bool:
        return False
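
RefuelLLM calls a hosted endpoint, so the only setup is the REFUEL_API_KEY environment variable (the constructor above raises without it); generation parameters such as max_new_tokens and temperature are merged over DEFAULT_PARAMS. A minimal sketch, with the config key names again assumed:

import os

assert os.environ.get("REFUEL_API_KEY"), "set REFUEL_API_KEY before constructing the model"

config = AutolabelConfig({
    # ... task configuration omitted ...
    "model": {
        "provider": "refuel",  # provider string is an assumption
        "params": {"max_new_tokens": 128, "temperature": 0.0},
    },
})
model = ModelFactory.from_config(config)
result = model.label(["An example prompt"])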