class LabelingAgent:
COST_KEY = "Cost in $"
def __init__(
self,
config: Union[AutolabelConfig, str, dict],
cache: Optional[bool] = True,
example_selector: Optional[BaseExampleSelector] = None,
create_task: Optional[bool] = True,
console_output: Optional[bool] = True,
) -> None:
self.create_task = create_task
self.db = StateManager() if self.create_task else None
self.generation_cache = SQLAlchemyGenerationCache() if cache else None
self.transform_cache = SQLAlchemyTransformCache() if cache else None
self.console = Console(quiet=not console_output)
self.config = (
config if isinstance(config, AutolabelConfig) else AutolabelConfig(config)
)
self.task = TaskFactory.from_config(self.config)
self.llm: BaseModel = ModelFactory.from_config(
self.config, cache=self.generation_cache
)
self.confidence = ConfidenceCalculator(
score_type="logprob_average", llm=self.llm
)
self.example_selector = example_selector
# Only used if we don't use task management
self.all_annotations = []
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
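    # Illustrative construction (a sketch, not prescriptive): `config` may be an
    # AutolabelConfig, a dict, or a path to a config file; "config.json" below is
    # a hypothetical file following the AutolabelConfig schema.
    #
    #   agent = LabelingAgent(config="config.json", cache=True, console_output=True)
    #
    # Passing cache=False skips creating the generation and transform caches.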
def run(
self,
dataset: AutolabelDataset,
output_name: Optional[str] = None,
max_items: Optional[int] = None,
start_index: int = 0,
additional_metrics: Optional[List[BaseMetric]] = [],
skip_eval: Optional[bool] = False,
    ) -> AutolabelDataset:
"""Labels data in a given dataset. Output written to new CSV file.
Args:
dataset: path to CSV dataset to be annotated
max_items: maximum items in dataset to be annotated
output_name: custom name of output CSV file
start_index: skips annotating [0, start_index)
"""
dataset = dataset.get_slice(max_items=max_items, start_index=start_index)
if self.create_task:
self.db.initialize()
self.dataset_obj = self.db.initialize_dataset(dataset.df, self.config)
self.task_object = self.db.initialize_task(self.config)
else:
self.all_annotations = []
        # `dataset` is always an AutolabelDataset here (it was produced by get_slice
        # above), so derive the output file name from output_name if provided,
        # otherwise from the task name.
        csv_file_name = (
            output_name if output_name else f"{self.config.task_name()}_labeled.csv"
        )
if self.create_task:
            # Fetch the task run for this task/dataset pair, if one already exists
self.task_run = self.db.get_task_run(
self.task_object.id, self.dataset_obj.id
)
# Resume/Delete the task if it already exists or create a new task run
if self.task_run:
logger.info("Task run already exists.")
self.task_run = self.handle_existing_task_run(
self.task_run,
csv_file_name,
gt_labels=dataset.gt_labels,
additional_metrics=additional_metrics,
)
else:
self.task_run = self.db.create_task_run(
csv_file_name, self.task_object.id, self.dataset_obj.id
)
# Get the seed examples from the dataset config
seed_examples = self.config.few_shot_example_set()
        # If the seed example set is a path string, read the corresponding CSV file
if isinstance(seed_examples, str):
seed_loader = AutolabelDataset(seed_examples, self.config)
seed_examples = seed_loader.inputs
# Check explanations are present in data if explanation_column is passed in
if (
self.config.explanation_column()
and len(seed_examples) > 0
and self.config.explanation_column() not in list(seed_examples[0].keys())
):
raise ValueError(
f"Explanation column {self.config.explanation_column()} not found in dataset.\nMake sure that explanations were generated using labeler.generate_explanations(seed_file)."
)
if self.example_selector is None:
self.example_selector = ExampleSelectorFactory.initialize_selector(
self.config,
seed_examples,
dataset.df.keys().tolist(),
cache=self.generation_cache is not None,
)
current_index = self.task_run.current_index if self.create_task else 0
cost = 0.0
postfix_dict = {}
indices = range(current_index, len(dataset.inputs))
for current_index in track_with_stats(
indices,
postfix_dict,
total=len(dataset.inputs) - current_index,
console=self.console,
):
chunk = dataset.inputs[current_index]
if self.example_selector:
examples = self.example_selector.select_examples(chunk)
else:
examples = []
# Construct Prompt to pass to LLM
final_prompt = self.task.construct_prompt(chunk, examples)
response = self.llm.label([final_prompt])
            for i, (generations, error) in enumerate(
                zip(response.generations, response.errors)
            ):
if error is not None:
annotation = LLMAnnotation(
successfully_labeled=False,
label=self.task.NULL_LABEL_TOKEN,
raw_response="",
curr_sample=pickle.dumps(chunk),
prompt=final_prompt,
confidence_score=0,
error=error,
)
else:
annotations = []
for generation in generations:
annotation = self.task.parse_llm_response(
generation, chunk, final_prompt
)
if self.config.confidence():
annotation.confidence_score = self.confidence.calculate(
model_generation=annotation,
prompt=final_prompt,
)
annotations.append(annotation)
annotation = self.majority_annotation(annotations)
                # Save the annotation (to the DB if task management is enabled, in memory otherwise)
self.save_annotation(annotation, current_index, i)
cost += sum(response.costs)
postfix_dict[self.COST_KEY] = f"{cost:.2f}"
            # Periodically evaluate the task (every 100 labeled examples)
if not skip_eval and (current_index + 1) % 100 == 0:
llm_labels = self.get_all_annotations()
if dataset.gt_labels:
eval_result = self.task.eval(
llm_labels,
dataset.gt_labels[: len(llm_labels)],
additional_metrics=additional_metrics,
)
for m in eval_result:
# This is a row wise metric
if isinstance(m.value, list):
continue
elif m.show_running:
postfix_dict[m.name] = (
f"{m.value:.4f}"
if isinstance(m.value, float)
else m.value
)
if self.create_task:
# Update task run state
                self.task_run = self.save_task_run_state(
                    # chunk is a single example, so advance the saved index by one
                    current_index=current_index + 1
                )
llm_labels = self.get_all_annotations()
eval_result = None
table = {}
        # If ground truth labels are provided and eval is not skipped, evaluate the predictions
if not skip_eval and dataset.gt_labels:
eval_result = self.task.eval(
llm_labels,
dataset.gt_labels[: len(llm_labels)],
additional_metrics=additional_metrics,
)
# TODO: serialize and write to file
for m in eval_result:
if isinstance(m.value, list):
continue
elif m.show_running:
table[m.name] = m.value
else:
self.console.print(f"{m.name}:\n{m.value}")
# print cost
self.console.print(f"Actual Cost: {maybe_round(cost)}")
print_table(table, console=self.console, default_style=METRIC_TABLE_STYLE)
dataset.process_labels(llm_labels, eval_result)
        # Only save to CSV if an output_name was provided
if output_name:
dataset.save(output_file_name=output_name)
return dataset
def plan(
self,
dataset: AutolabelDataset,
max_items: Optional[int] = None,
start_index: int = 0,
) -> None:
"""Calculates and prints the cost of calling autolabel.run() on a given dataset
Args:
dataset: path to a CSV dataset
"""
dataset = dataset.get_slice(max_items=max_items, start_index=start_index)
if self.config.confidence() and "REFUEL_API_KEY" not in os.environ:
raise ValueError(
"REFUEL_API_KEY environment variable must be set to compute confidence scores. You can request an API key at https://refuel-ai.typeform.com/llm-access."
)
prompt_list = []
total_cost = 0
# Get the seed examples from the dataset config
seed_examples = self.config.few_shot_example_set()
        # If the seed example set is a path string, read the corresponding CSV file
if isinstance(seed_examples, str):
seed_loader = AutolabelDataset(seed_examples, self.config)
seed_examples = seed_loader.inputs
# Check explanations are present in data if explanation_column is passed in
if (
self.config.explanation_column()
and len(seed_examples) > 0
and self.config.explanation_column() not in list(seed_examples[0].keys())
):
raise ValueError(
f"Explanation column {self.config.explanation_column()} not found in dataset.\nMake sure that explanations were generated using labeler.generate_explanations(seed_file)."
)
self.example_selector = ExampleSelectorFactory.initialize_selector(
self.config,
seed_examples,
dataset.df.keys().tolist(),
cache=self.generation_cache is not None,
)
input_limit = min(len(dataset.inputs), 100)
for input_i in track(
dataset.inputs[:input_limit],
description="Generating Prompts...",
console=self.console,
):
# TODO: Check if this needs to use the example selector
if self.example_selector:
examples = self.example_selector.select_examples(input_i)
else:
examples = []
final_prompt = self.task.construct_prompt(input_i, examples)
prompt_list.append(final_prompt)
            # Estimate the cost of labeling this prompt (based on token counts)
curr_cost = self.llm.get_cost(prompt=final_prompt, label="")
total_cost += curr_cost
        # Extrapolate the sampled cost to the full dataset
        total_cost = total_cost * (len(dataset.inputs) / input_limit)
table = {
"Total Estimated Cost": f"${maybe_round(total_cost)}",
"Number of Examples": len(dataset.inputs),
"Average cost per example": f"${maybe_round(total_cost / len(dataset.inputs))}",
}
table = {"parameter": list(table.keys()), "value": list(table.values())}
print_table(
table, show_header=False, console=self.console, styles=COST_TABLE_STYLES
)
self.console.rule("Prompt Example")
self.console.print(f"{prompt_list[0]}")
self.console.rule()
async def async_run_transform(
self, transform: BaseTransform, dataset: AutolabelDataset
):
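        """Applies a single transform to every row of the dataset asynchronously and
        returns a new AutolabelDataset with the transform outputs appended as columns.
        """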
transform_outputs = [
transform.apply(input_dict) for input_dict in dataset.inputs
]
outputs = await gather_async_tasks_with_progress(
transform_outputs,
description=f"Running transform {transform.name()}...",
console=self.console,
)
output_df = pd.DataFrame.from_records(outputs)
final_df = pd.concat([dataset.df, output_df], axis=1)
dataset = AutolabelDataset(final_df, self.config)
return dataset
def transform(self, dataset: AutolabelDataset):
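        """Runs every transform listed in the config over the dataset, in order, and
        returns the resulting AutolabelDataset.
        """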
transforms = []
for transform_dict in self.config.transforms():
transforms.append(
TransformFactory.from_dict(transform_dict, cache=self.transform_cache)
)
for transform in transforms:
dataset = asyncio.run(self.async_run_transform(transform, dataset))
return dataset
def handle_existing_task_run(
self,
task_run: TaskRun,
csv_file_name: str,
        gt_labels: Optional[List[str]] = None,
additional_metrics: List[BaseMetric] = [],
) -> TaskRun:
"""
Allows for continuing an existing labeling task. The user will be asked whether they wish to continue from where the run previously left off, or restart from the beginning.
Args:
task_run: TaskRun to retry
csv_file_name: path to the dataset we wish to label (only used if user chooses to restart the task)
gt_labels: If ground truth labels are provided, performance metrics will be displayed, such as label accuracy
"""
self.console.print(
f"There is an existing task with following details: {task_run}"
)
llm_labels = self.get_all_annotations()
if gt_labels and len(llm_labels) > 0:
self.console.print("Evaluating the existing task...")
gt_labels = gt_labels[: len(llm_labels)]
eval_result = self.task.eval(
llm_labels, gt_labels, additional_metrics=additional_metrics
)
table = {}
for m in eval_result:
if isinstance(m.value, list):
continue
elif m.show_running:
table[m.name] = m.value
else:
self.console.print(f"{m.name}:\n{m.value}")
print_table(table, console=self.console, default_style=METRIC_TABLE_STYLE)
self.console.print(f"{task_run.current_index} examples labeled so far.")
if not Confirm.ask("Do you want to resume the task?"):
TaskRunModel.delete_by_id(self.db.session, task_run.id)
self.console.print("Deleted the existing task and starting a new one...")
task_run = self.db.create_task_run(
csv_file_name, self.task_object.id, self.dataset_obj.id
)
return task_run
    def save_task_run_state(
        self,
        current_index: Optional[int] = None,
        status: TaskStatus = "",
        error: str = "",
    ) -> TaskRun:
        """Saves the current state of the task run being performed."""
if error:
self.task_run.error = error
if status:
self.task_run.status = status
if current_index:
self.task_run.current_index = current_index
return TaskRunModel.update(self.db.session, self.task_run)
def majority_annotation(
self, annotation_list: List[LLMAnnotation]
) -> LLMAnnotation:
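        """Returns the annotation whose label occurs most often in annotation_list.

        Ties are broken in favor of the label seen first: e.g. for labels
        ["A", "B", "B", "A"], the first annotation labeled "A" is returned.
        """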
labels = [a.label for a in annotation_list]
counts = {}
for ind, label in enumerate(labels):
# Needed for named entity recognition which outputs lists instead of strings
label = str(label)
if label not in counts:
counts[label] = (1, ind)
else:
counts[label] = (counts[label][0] + 1, counts[label][1])
max_label = max(counts, key=lambda x: counts[x][0])
return annotation_list[counts[max_label][1]]
def generate_explanations(
self,
seed_examples: Union[str, List[Dict]],
) -> List[Dict]:
"""Use LLM to generate explanations for why examples are labeled the way that they are."""
out_file = None
if isinstance(seed_examples, str):
out_file = seed_examples
seed_loader = AutolabelDataset(seed_examples, self.config)
seed_examples = seed_loader.inputs
explanation_column = self.config.explanation_column()
if not explanation_column:
raise ValueError(
"The explanation column needs to be specified in the dataset config."
)
for seed_example in track(
seed_examples,
description="Generating explanations",
console=self.console,
):
explanation_prompt = self.task.get_explanation_prompt(seed_example)
explanation = self.llm.label([explanation_prompt])
explanation = explanation.generations[0][0].text
seed_example["explanation"] = str(explanation) if explanation else ""
if out_file:
df = pd.DataFrame.from_records(seed_examples)
df.to_csv(out_file, index=False)
return seed_examples
def generate_synthetic_dataset(self) -> AutolabelDataset:
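        """Uses the LLM to generate synthetic rows for each label in the config's
        label list and returns them as an AutolabelDataset.
        """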
columns = get_format_variables(self.config.example_template())
df = pd.DataFrame(columns=columns)
for label in track(
self.config.labels_list(),
description="Generating dataset",
console=self.console,
):
prompt = self.task.get_generate_dataset_prompt(label)
result = self.llm.label([prompt])
if result.errors[0] is not None:
self.console.print(
f"Error generating rows for label {label}: {result.errors[0]}"
)
else:
response = result.generations[0][0].text.strip()
response = io.StringIO(response)
label_df = pd.read_csv(response, sep=self.config.delimiter())
label_df[self.config.label_column()] = label
df = pd.concat([df, label_df], axis=0, ignore_index=True)
return AutolabelDataset(df, self.config)
def clear_cache(self, use_ttl: bool = True):
"""
Clears the generation and transformation cache from autolabel.
Args:
use_ttl: If true, only clears the cache if the ttl has expired.
"""
        # The caches are None when the agent was constructed with cache=False
        if self.generation_cache:
            self.generation_cache.clear(use_ttl=use_ttl)
        if self.transform_cache:
            self.transform_cache.clear(use_ttl=use_ttl)
def save_annotation(self, annotation: LLMAnnotation, current_index: int, i: int):
if self.create_task:
# Store the annotation in the database
AnnotationModel.create_from_llm_annotation(
self.db.session,
annotation,
current_index + i,
self.task_run.id,
)
else:
self.all_annotations.append(annotation)
def get_all_annotations(self):
if self.create_task:
db_result = AnnotationModel.get_annotations_by_task_run_id(
self.db.session, self.task_run.id
)
return [pickle.loads(a.llm_annotation) for a in db_result]
else:
return self.all_annotations