Skip to content

Data Models

The Data Model classes are used to save the progress of AutoLabel jobs in an SQL database.

Saved data is stored in the `.autolabel.db` file.

Every Data Model class implements its own methods (such as "get", "create", or "insert") for accessing this saved data.

Bases: Base

Source code in autolabel/src/autolabel/data_models/annotation.py
class AnnotationModel(Base):
    """ORM row holding one pickled LLM annotation for a row index of a task run."""

    __tablename__ = "annotations"

    id = Column(Integer, primary_key=True, autoincrement=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    # Row index (within the dataset) that this annotation belongs to.
    index = Column(Integer)
    # Holds pickle.dumps(...) output written by create_from_llm_annotation.
    llm_annotation = Column(TEXT)
    task_run_id = Column(Integer, ForeignKey("task_runs.id"))
    task_runs = relationship("TaskRunModel", back_populates="annotations")

    def __repr__(self):
        return f"<AnnotationModel(id={self.id}, index={self.index}, annotation={self.llm_annotation})>"

    @classmethod
    def create_from_llm_annotation(
        cls, db, llm_annotation: LLMAnnotation, index: int, task_run_id: int
    ):
        """Pickle and persist an LLM annotation, committing immediately.

        Args:
            db: SQLAlchemy session used for the insert.
            llm_annotation: annotation object; it is stored pickled.
            index: row index the annotation belongs to.
            task_run_id: id of the owning task run.

        Returns:
            The newly inserted AnnotationModel instance, refreshed from the database.
        """
        db_object = cls(
            llm_annotation=pickle.dumps(llm_annotation),
            index=index,
            task_run_id=task_run_id,
        )
        db.add(db_object)
        db.commit()
        # Reload *this* row (including the server-generated created_at) instead
        # of re-querying "the row with the highest id", which could return a row
        # inserted by another session between the commit and the query.
        db.refresh(db_object)
        logger.debug(f"created new annotation: {db_object}")
        return db_object

    @classmethod
    def get_annotations_by_task_run_id(cls, db, task_run_id: int):
        """Return the annotations of a task run, ordered by index, one per index.

        When several rows share an index, only the first one returned by the
        query is kept. NOTE(review): the query orders only by ``index``, so
        which duplicate survives is not pinned down; confirm that is intended.
        """
        annotations = (
            db.query(cls)
            .filter(cls.task_run_id == task_run_id)
            .order_by(cls.index)
            .all()
        )
        seen_indices = set()
        filtered_annotations = []
        for annotation in annotations:
            if annotation.index not in seen_indices:
                seen_indices.add(annotation.index)
                filtered_annotations.append(annotation)
        return filtered_annotations

    @classmethod
    def from_pydantic(cls, annotation: BaseModel):
        """Build an (unsaved) AnnotationModel from a pydantic model via its JSON form."""
        return cls(**json.loads(annotation.json()))

    def delete(self, db):
        """Delete this row and commit."""
        db.delete(self)
        db.commit()

rendering: show_root_heading: yes show_root_full_path: no

Bases: Base

An SQLAlchemy-based cache system for storing and retrieving cache entries.

Source code in autolabel/src/autolabel/data_models/generation_cache.py
class GenerationCacheEntryModel(Base):
    """An SQLAlchemy-based cache for storing and retrieving generation cache entries.

    Rows are matched on (model_name, prompt, model_params). The ``generations``
    column receives a JSON-encoded string of the form ``{"generations": [...]}``.
    """

    __tablename__ = "generation_cache"

    id = Column(Integer, primary_key=True)
    model_name = Column(String(50))
    prompt = Column(Text)
    model_params = Column(Text)
    generations = Column(JSON)
    creation_time_ms = Column(Integer)
    ttl_ms = Column(Integer)

    def __repr__(self):
        return f"<Cache(model_name={self.model_name},prompt={self.prompt},model_params={self.model_params},generations={self.generations})>"

    @classmethod
    def get(cls, db, cache_entry: GenerationCacheEntry):
        """Look up the cached generations for the entry's model, prompt and params.

        Returns:
            A GenerationCacheEntry rebuilt from the first matching row, or None
            when nothing matches. TTL is not checked here; see ``clear``.
        """
        row = (
            db.query(cls)
            .filter(
                cls.model_name == cache_entry.model_name,
                cls.prompt == cache_entry.prompt,
                cls.model_params == cache_entry.model_params,
            )
            .first()
        )
        if not row:
            return None

        def _revive(payload):
            # Each stored dict carries a "type" tag selecting its class.
            target = Generation if payload["type"] == "Generation" else ChatGeneration
            return target(**payload)

        stored_payloads = json.loads(row.generations)["generations"]
        revived = [_revive(payload) for payload in stored_payloads]

        return GenerationCacheEntry(
            model_name=row.model_name,
            prompt=row.prompt,
            model_params=row.model_params,
            generations=revived,
            creation_time_ms=row.creation_time_ms,
            ttl_ms=row.ttl_ms,
        )

    @classmethod
    def insert(cls, db, cache_entry: BaseModel):
        """Persist ``cache_entry`` stamped with the current time, and return it unchanged."""
        serialized = json.dumps(
            {"generations": [generation.dict() for generation in cache_entry.generations]}
        )
        row = cls(
            model_name=cache_entry.model_name,
            prompt=cache_entry.prompt,
            model_params=cache_entry.model_params,
            generations=serialized,
            creation_time_ms=int(time.time() * 1000),
            ttl_ms=cache_entry.ttl_ms,
        )
        db.add(row)
        db.commit()
        return cache_entry

    @classmethod
    def clear(cls, db, use_ttl: bool = True) -> None:
        """Delete cache rows and commit.

        Args:
            use_ttl: when True, only rows older than their ``ttl_ms`` are
                deleted; when False, every row is deleted.
        """
        target_rows = db.query(cls)
        if use_ttl:
            now_ms = int(time.time() * 1000)
            target_rows = target_rows.filter(now_ms - cls.creation_time_ms > cls.ttl_ms)
        target_rows.delete()
        db.commit()

rendering: show_root_heading: yes show_root_full_path: no

Bases: Base

An SQLAlchemy-based cache system for storing and retrieving cache entries.

Source code in autolabel/src/autolabel/data_models/transform_cache.py
class TransformCacheEntryModel(Base):
    """An SQLAlchemy-based cache for storing and retrieving transform cache entries.

    Rows are keyed by ``TransformCacheEntry.get_id()``. The transform params,
    input and output are stored as pickled bytes.
    """

    __tablename__ = "transform_cache"

    id = Column(String, primary_key=True)
    transform_name = Column(String(50))
    transform_params = Column(TEXT)
    input = Column(TEXT)
    output = Column(TEXT)
    creation_time_ms = Column(Integer)
    ttl_ms = Column(Integer)

    def __repr__(self):
        return f"<TransformCache(id={self.id},transform_name={self.transform_name},transform_params={self.transform_params},input={self.input},output={self.output})>"

    @classmethod
    def get(cls, db, cache_entry: TransformCacheEntry) -> "TransformCacheEntry | None":
        """Return the cached entry with the same id as ``cache_entry``.

        Returns:
            A TransformCacheEntry with unpickled params/input/output, or None if
            no row has that id. TTL is not checked here; see ``clear``.
        """
        entry_id = cache_entry.get_id()
        looked_up_entry = db.query(cls).filter(cls.id == entry_id).first()

        if not looked_up_entry:
            return None

        # SECURITY NOTE: pickle.loads can execute arbitrary code embedded in the
        # stored bytes. This is only safe while the cache database is trusted.
        return TransformCacheEntry(
            transform_name=looked_up_entry.transform_name,
            transform_params=pickle.loads(looked_up_entry.transform_params),
            input=pickle.loads(looked_up_entry.input),
            output=pickle.loads(looked_up_entry.output),
            creation_time_ms=looked_up_entry.creation_time_ms,
            ttl_ms=looked_up_entry.ttl_ms,
        )

    @classmethod
    def insert(cls, db, cache_entry: TransformCacheEntry) -> "TransformCacheEntryModel":
        """Persist ``cache_entry`` stamped with the current time and commit.

        Returns:
            The newly added ORM row. (The previous ``-> None`` annotation did
            not match this return value.)
        """
        db_object = cls(
            id=cache_entry.get_id(),
            transform_name=cache_entry.transform_name,
            transform_params=pickle.dumps(cache_entry.transform_params),
            input=pickle.dumps(cache_entry.input),
            output=pickle.dumps(cache_entry.output),
            creation_time_ms=int(time.time() * 1000),
            ttl_ms=cache_entry.ttl_ms,
        )
        db.add(db_object)
        db.commit()
        return db_object

    @classmethod
    def clear(cls, db, use_ttl: bool = True) -> None:
        """Delete cache rows and commit.

        Args:
            use_ttl: when True, only rows older than their ``ttl_ms`` are
                deleted; when False, every row is deleted.
        """
        if use_ttl:
            current_time_ms = int(time.time() * 1000)
            db.query(cls).filter(
                current_time_ms - cls.creation_time_ms > cls.ttl_ms
            ).delete()
        else:
            db.query(cls).delete()
        db.commit()

rendering: show_root_heading: yes show_root_full_path: no

Bases: Base

Source code in autolabel/src/autolabel/data_models/dataset.py
class DatasetModel(Base):
    """ORM row describing a dataset slice (input file plus start/end indices)."""

    __tablename__ = "datasets"

    id = Column(String(32), primary_key=True)
    input_file = Column(String(50))
    start_index = Column(Integer)
    end_index = Column(Integer)
    task_runs = relationship("TaskRunModel", back_populates="dataset")

    def __repr__(self):
        return f"<DatasetModel(id={self.id}, input_file={self.input_file}, start_index={self.start_index}, end_index={self.end_index})>"

    @classmethod
    def create(cls, db, dataset: BaseModel):
        """Insert a row built from a pydantic model's JSON fields, commit, and return it."""
        db_object = cls(**json.loads(dataset.json()))
        db.add(db_object)
        db.commit()
        return db_object

    @classmethod
    def get_by_id(cls, db, id: str):
        """Return the dataset with this id, or None.

        The primary key is a ``String(32)``, so ``id`` is a string (previously
        mis-annotated as ``int``).
        """
        return db.query(cls).filter(cls.id == id).first()

    @classmethod
    def get_by_input_file(cls, db, input_file: str):
        """Return the first dataset recorded for ``input_file``, or None."""
        return db.query(cls).filter(cls.input_file == input_file).first()

    def delete(self, db):
        """Delete this row and commit."""
        db.delete(self)
        db.commit()

rendering: show_root_heading: yes show_root_full_path: no

Bases: Base

Source code in autolabel/src/autolabel/data_models/task.py
class TaskModel(Base):
    """ORM row describing a labeling task: type, provider, model name and serialized config."""

    __tablename__ = "tasks"

    id = Column(String(32), primary_key=True)
    task_type = Column(String(50))
    provider = Column(String(50))
    model_name = Column(String(50))
    config = Column(Text)
    task_runs = relationship("TaskRunModel", back_populates="task")

    def __repr__(self):
        return f"<TaskModel(id={self.id}, task_type={self.task_type}, provider={self.provider}, model_name={self.model_name})>"

    @classmethod
    def create(cls, db, task: BaseModel):
        """Insert a row built from a pydantic model's JSON fields, commit, and return it."""
        db_object = cls(**json.loads(task.json()))
        db.add(db_object)
        db.commit()
        return db_object

    @classmethod
    def get_by_id(cls, db, id: str):
        """Return the task with this id, or None.

        The primary key is a ``String(32)``, so ``id`` is a string (previously
        mis-annotated as ``int``).
        """
        return db.query(cls).filter(cls.id == id).first()

    def delete(self, db):
        """Delete this row and commit."""
        db.delete(self)
        db.commit()

rendering: show_root_heading: yes show_root_full_path: no

Bases: Base

Source code in autolabel/src/autolabel/data_models/task_run.py
class TaskRunModel(Base):
    """ORM row tracking one run of a task over a dataset: progress, status, error and metrics."""

    __tablename__ = "task_runs"

    id = Column(
        Integer,
        # Random id: the top 32 bits of a UUID4's 128-bit integer value.
        default=lambda: uuid.uuid4().int >> (128 - 32),
        primary_key=True,
    )
    task_id = Column(String(32), ForeignKey("tasks.id"))
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    dataset_id = Column(String(32), ForeignKey("datasets.id"))
    current_index = Column(Integer)
    error = Column(String(256))
    metrics = Column(Text)
    output_file = Column(String(50))
    status = Column(String(50))
    task = relationship("TaskModel", back_populates="task_runs")
    dataset = relationship("DatasetModel", back_populates="task_runs")
    annotations = relationship("AnnotationModel", back_populates="task_runs")

    def __repr__(self):
        return f"<TaskRunModel(id={self.id}, created_at={str(self.created_at)}, task_id={self.task_id}, dataset_id={self.dataset_id}, output_file={self.output_file}, current_index={self.current_index}, status={self.status}, error={self.error}, metrics={self.metrics})>"

    @classmethod
    def create(cls, db, task_run: BaseModel):
        """Insert a row from a pydantic model's ``dict()``, commit, refresh, and return it."""
        logger.debug(f"creating new task: {task_run}")
        db_object = cls(**task_run.dict())
        db.add(db_object)
        db.commit()
        db.refresh(db_object)
        logger.debug(f"created new task: {db_object}")
        return db_object

    @classmethod
    def get(cls, db, task_id: str, dataset_id: str):
        """Return the first run for this (task, dataset) pair, or None."""
        return (
            db.query(cls)
            .filter(cls.task_id == task_id, cls.dataset_id == dataset_id)
            .first()
        )

    @classmethod
    def from_pydantic(cls, task_run: BaseModel):
        """Build an (unsaved) TaskRunModel from a pydantic model via its JSON form."""
        return cls(**json.loads(task_run.json()))

    @classmethod
    def update(cls, db, task_run: BaseModel):
        """Copy every field of ``task_run`` onto the stored row with the same id and commit.

        Returns:
            A TaskRun built from the updated ORM row.

        NOTE(review): if no row has ``task_run.id``, ``setattr`` is called on
        None and raises AttributeError.
        """
        task_run_id = task_run.id
        task_run_orm = db.query(cls).filter(cls.id == task_run_id).first()
        logger.debug(f"updating task_run: {task_run}")
        for key, value in task_run.dict().items():
            setattr(task_run_orm, key, value)
        db.commit()
        logger.debug(f"task_run updated: {task_run}")
        return TaskRun.from_orm(task_run_orm)

    @classmethod
    def delete_by_id(cls, db, id: int):
        """Delete the run with this id and commit.

        Previously this issued the DELETE without committing, unlike every
        other delete helper in these models.
        """
        db.query(cls).filter(cls.id == id).delete()
        db.commit()

    def delete(self, db):
        """Delete this row and commit."""
        db.delete(self)
        db.commit()

rendering: show_root_heading: yes show_root_full_path: no

Source code in autolabel/src/autolabel/database/state_manager.py
class StateManager:
    """Create and look up the datasets, tasks and task runs stored in the SQL database."""

    def __init__(self):
        self.engine = create_db_engine()
        self.base = Base
        # Set by initialize(); the other methods pass it to the model helpers.
        self.session = None

    def initialize(self):
        """Create any missing tables and open a session bound to the engine."""
        self.base.metadata.create_all(self.engine)
        self.session = sessionmaker(bind=self.engine)()

    def initialize_dataset(
        self,
        dataset: Union[str, pd.DataFrame],
        config: AutolabelConfig,
        start_index: int = 0,
        max_items: Optional[int] = None,
    ):
        """Return the stored Dataset for these arguments, creating it if absent.

        Args:
            dataset: path of the input file, or a DataFrame. A DataFrame is
                recorded with an empty ``input_file``.
            config: task configuration, used to derive the dataset id.
            start_index: first row index of the slice.
            max_items: number of rows in the slice. A falsy value (None or 0)
                is stored as ``end_index = -1``.

        Returns:
            A Dataset built from the existing or newly created row.
        """
        # TODO: Check if this works for max_items = None

        dataset_id = Dataset.create_id(dataset, config, start_index, max_items)
        dataset_orm = DatasetModel.get_by_id(self.session, dataset_id)
        if dataset_orm:
            return Dataset.from_orm(dataset_orm)

        # Use a new name instead of rebinding the `dataset` parameter to an
        # object of a different type.
        new_dataset = Dataset(
            id=dataset_id,
            input_file=dataset if isinstance(dataset, str) else "",
            start_index=start_index,
            end_index=start_index + max_items if max_items else -1,
        )
        return Dataset.from_orm(DatasetModel.create(self.session, new_dataset))

    def initialize_task(self, config: AutolabelConfig):
        """Return the stored Task for ``config``, creating it if absent."""
        task_id = Task.create_id(config)
        task_orm = TaskModel.get_by_id(self.session, task_id)
        if task_orm:
            return Task.from_orm(task_orm)

        task = Task(
            id=task_id,
            config=config.to_json(),
            task_type=config.task_type(),
            provider=config.provider(),
            model_name=config.model_name(),
        )
        return Task.from_orm(TaskModel.create(self.session, task))

    def get_task_run(self, task_id: str, dataset_id: str):
        """Return the TaskRun for this (task, dataset) pair, or None if there is none."""
        task_run_orm = TaskRunModel.get(self.session, task_id, dataset_id)
        return TaskRun.from_orm(task_run_orm) if task_run_orm else None

    def create_task_run(
        self, output_file: str, task_id: str, dataset_id: str
    ) -> TaskRun:
        """Create and persist an ACTIVE task run starting at index 0, and return it."""
        logger.debug("creating new task_run")
        new_task_run = TaskRun(
            task_id=task_id,
            dataset_id=dataset_id,
            status=TaskStatus.ACTIVE,
            current_index=0,
            output_file=output_file,
            created_at=datetime.now(),
        )
        task_run_orm = TaskRunModel.create(self.session, new_task_run)
        return TaskRun.from_orm(task_run_orm)

rendering: show_root_heading: yes show_root_full_path: no

Source code in autolabel/src/autolabel/database/engine.py
def create_db_engine(db_path: Optional[str] = DB_PATH) -> Engine:
    """Return the process-wide SQLite engine, creating it on the first call.

    The engine is cached in the module-level ``DB_ENGINE``. Once it exists,
    later calls return it unchanged and their ``db_path`` argument is ignored.
    """
    global DB_ENGINE
    if DB_ENGINE is not None:
        return DB_ENGINE
    DB_ENGINE = create_engine(f"sqlite:///{db_path}", pool_size=0)
    return DB_ENGINE

rendering: show_root_heading: yes show_root_full_path: no