Example Selector

FixedExampleSelector

Bases: BaseExampleSelector, BaseModel

Example selector to handle the case of a fixed few-shot context, i.e. every input prompt to the labeling model has the same few-shot examples.

Source code in autolabel/src/autolabel/few_shot/fixed_example_selector.py
class FixedExampleSelector(BaseExampleSelector, BaseModel):
    """Example selector to handle the case of fixed few-shot context
    i.e. every input prompt to the labeling model has the same few-shot examples
    """

    examples: List[dict]
    """A list of the examples that the prompt template expects."""

    k: int = 4
    """Number of examples to select"""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def add_example(self, example: Dict[str, str]) -> None:
        self.examples.append(example)

    def select_examples(
        self,
        input_variables: Dict[str, str],
        **kwargs,
    ) -> List[dict]:
        """Select which examples to use based on the input lengths."""
        label_column = kwargs.get("label_column")
        selected_labels = kwargs.get("selected_labels")

        if not selected_labels:
            return self.examples[: self.k]

        if not label_column:
            print("No label column provided, returning all examples")
            return self.examples[: self.k]

        # get the examples where label matches the selected labels
        valid_examples = [
            example
            for example in self.examples
            if example.get(label_column) in selected_labels
        ]
        return valid_examples[: min(self.k, len(valid_examples))]

    @classmethod
    def from_examples(
        cls,
        examples: List,
        k: int = 4,
    ) -> FixedExampleSelector:
        """Create pass-through example selector using example list

        Returns:
            The FixedExampleSelector instantiated
        """

        return cls(examples=examples, k=k)
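
A minimal usage sketch (the example dicts and their keys are illustrative, not part of the API):

examples = [
    {"example": "I loved this film", "label": "positive"},
    {"example": "Worst purchase ever", "label": "negative"},
    {"example": "It was fine, I guess", "label": "neutral"},
]

selector = FixedExampleSelector.from_examples(examples, k=2)

# Every input gets the same (first k) examples
selector.select_examples({"example": "Great value for the price"})
# -> the "positive" and "negative" examples, regardless of the input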

examples: List[dict] instance-attribute

A list of the examples that the prompt template expects.

k: int = 4 class-attribute instance-attribute

Number of examples to select

Config

Configuration for this pydantic object.

Source code in autolabel/src/autolabel/few_shot/fixed_example_selector.py
class Config:
    """Configuration for this pydantic object."""

    extra = Extra.forbid
    arbitrary_types_allowed = True

from_examples(examples, k=4) classmethod

Create pass-through example selector using example list

Returns:

    FixedExampleSelector: The FixedExampleSelector instantiated

Source code in autolabel/src/autolabel/few_shot/fixed_example_selector.py
@classmethod
def from_examples(
    cls,
    examples: List,
    k: int = 4,
) -> FixedExampleSelector:
    """Create pass-through example selector using example list

    Returns:
        The FixedExampleSelector instantiated
    """

    return cls(examples=examples, k=k)

select_examples(input_variables, **kwargs)

Select up to k examples, optionally filtered by the selected labels.

Source code in autolabel/src/autolabel/few_shot/fixed_example_selector.py
def select_examples(
    self,
    input_variables: Dict[str, str],
    **kwargs,
) -> List[dict]:
    """Select which examples to use based on the input lengths."""
    label_column = kwargs.get("label_column")
    selected_labels = kwargs.get("selected_labels")

    if not selected_labels:
        return self.examples[: self.k]

    if not label_column:
        print("No label column provided, returning all examples")
        return self.examples[: self.k]

    # get the examples where label matches the selected labels
    valid_examples = [
        example
        for example in self.examples
        if example.get(label_column) in selected_labels
    ]
    return valid_examples[: min(self.k, len(valid_examples))]
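
Continuing the sketch above, passing selected_labels restricts the pool to examples whose value under label_column matches:

selector = FixedExampleSelector.from_examples(examples, k=4)

selector.select_examples(
    {"example": "Great value for the price"},
    label_column="label",
    selected_labels=["positive", "negative"],
)
# -> only the "positive" and "negative" examples, capped at k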

VectorStoreWrapper

Bases: VectorStore

Source code in autolabel/src/autolabel/few_shot/vector_store.py
class VectorStoreWrapper(VectorStore):
    def __init__(
        self,
        embedding_function: Optional[Embeddings] = None,
        corpus_embeddings: Optional[Tensor] = None,
        texts: Optional[List[str]] = None,
        metadatas: Optional[List[Dict[str, str]]] = None,
        cache: bool = True,
    ) -> None:
        self._embedding_function = embedding_function
        self._corpus_embeddings = corpus_embeddings
        self._texts = texts
        self._metadatas = metadatas
        if cache:
            self._db_engine = create_db_engine()
            with self._db_engine.connect() as conn:
                query = f"CREATE TABLE IF NOT EXISTS {EMBEDDINGS_TABLE} (embedding_function TEXT, text TEXT, embedding BLOB)"
                conn.execute(sql_text(query))
                conn.commit()
        else:
            self._db_engine = None

    def _get_embeddings(self, texts: Iterable[str]) -> List[List[float]]:
        """Get embeddings from the database. If not found, compute them and add them to the database.

        If no database is used, compute the embeddings and return them.

        Args:
            texts (Iterable[str]): Iterable of texts to embed.
        Returns:
            List[List[float]]: List of embeddings.
        """
        if self._db_engine:
            with self._db_engine.connect() as conn:
                embeddings = []
                uncached_texts = []
                uncached_texts_indices = []
                for idx, text in enumerate(texts):
                    query = sql_text(
                        f"SELECT embedding FROM {EMBEDDINGS_TABLE} WHERE embedding_function = :x AND text = :y",
                    )
                    params = {
                        "x": (
                            self._embedding_function.model
                            if self._embedding_function.__class__.__name__
                            not in ["HuggingFaceEmbeddings", "VertexAIEmbeddings"]
                            else self._embedding_function.model_name
                        ),
                        "y": text,
                    }
                    result = conn.execute(query, params).fetchone()
                    if result:
                        embeddings.append(pickle.loads(result[0]))
                    else:
                        embeddings.append(None)
                        uncached_texts.append(text)
                        uncached_texts_indices.append(idx)

                # Only call the embedding model for texts missing from the cache
                uncached_embeddings = (
                    self._embedding_function.embed_documents(uncached_texts)
                    if uncached_texts
                    else []
                )
                self._add_embeddings_to_cache(uncached_texts, uncached_embeddings)
                for idx, embedding in zip(uncached_texts_indices, uncached_embeddings):
                    embeddings[idx] = embedding

                return embeddings
        else:
            return self._embedding_function.embed_documents(list(texts))

    def _add_embeddings_to_cache(
        self, texts: Iterable[str], embeddings: List[List[float]]
    ) -> None:
        """Save embeddings to the database. If self._db_engine is None, do nothing.
        Args:
            texts (Iterable[str]): Iterable of texts.
            embeddings (List[List[float]]): List of embeddings.
        """
        if self._db_engine:
            with self._db_engine.connect() as conn:
                for text, embedding in zip(texts, embeddings):
                    query = sql_text(
                        f"INSERT INTO {EMBEDDINGS_TABLE} (embedding_function, text, embedding) VALUES (:x, :y, :z)"
                    )
                    params = {
                        "x": (
                            self._embedding_function.model
                            if self._embedding_function.__class__.__name__
                            not in ["HuggingFaceEmbeddings", "VertexAIEmbeddings"]
                            else self._embedding_function.model_name
                        ),
                        "y": text,
                        "z": pickle.dumps(embedding),
                    }
                    conn.execute(query, params)
                    conn.commit()

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[Dict[str, str]]] = None,
    ) -> Optional[List[Dict[str, str]]]:
        """Run texts through the embeddings and add them to the vectorstore. The vectorstore is reinitialized each time, because example selection does not require a persistent vector store.
        Args:
            texts (Iterable[str]): Texts to add to the vectorstore.
            metadatas (Optional[List[dict]], optional): Optional list of metadatas.
        Returns:
            The metadatas of the added texts.
        """
        if self._embedding_function is None:
            raise ValueError("An embedding function is required to add texts")

        texts = list(texts)  # materialize so the iterable is not consumed twice
        embeddings = self._get_embeddings(texts)
        self._corpus_embeddings = torch.tensor(embeddings)
        self._texts = texts
        self._metadatas = metadatas
        return metadatas

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Run semantic similarity search.
        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
        Returns:
            List[Document]: List of documents most similar to the query text.
        """
        docs_and_scores = self.similarity_search_with_score(query, k, filter=filter)
        return [doc for doc, _ in docs_and_scores]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Run semantic similarity search and retrieve distances.
        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
        Returns:
            List[Tuple[Document, float]]: List of documents most similar to the query
                text with distance in float.
        """
        query_embeddings = torch.tensor([self._get_embeddings([query])[0]])
        result_ids_and_scores = semantic_search(
            corpus_embeddings=self._corpus_embeddings,
            query_embeddings=query_embeddings,
            top_k=k,
        )
        result_ids = [result["corpus_id"] for result in result_ids_and_scores[0]]
        scores = [result["score"] for result in result_ids_and_scores[0]]
        results = {}
        results["documents"] = [[self._texts[index] for index in result_ids]]
        results["distances"] = [scores]
        results["metadatas"] = [[self._metadatas[index] for index in result_ids]]
        return _results_to_docs_and_scores(results)

    def label_diversity_similarity_search(
        self,
        query: str,
        label_key: str,
        k: int = 4,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Run semantic similarity search.
        Args:
            query (str): Query text to search for.
            label_key (str): Metadata key holding each example's label.
            k (int): Number of results to return per label.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
        Returns:
            List[Document]: List of documents most similar to the query text.
        """
        docs_and_scores = self.label_diversity_similarity_search_with_score(
            query, label_key, k, filter=filter
        )
        return [doc for doc, _ in docs_and_scores]

    def label_diversity_similarity_search_with_score(
        self,
        query: str,
        label_key: str,
        k: int = 4,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Run semantic similarity search and retrieve distances.
        Args:
            query (str): Query text to search for.
            label_key (str): Metadata key holding each example's label.
            k (int): Number of results to return per label. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
        Returns:
            List[Tuple[Document, float]]: List of documents most similar to the query
                text with distance in float.
        """
        query_embeddings = torch.tensor([self._get_embeddings([query])[0]])
        data = zip(self._corpus_embeddings, self._texts, self._metadatas)
        sorted_data = sorted(data, key=lambda item: item[2].get(label_key))

        documents = []
        scores = []
        metadatas = []
        for label, label_examples in groupby(
            sorted_data, key=lambda item: item[2].get(label_key)
        ):
            label_examples_list = list(label_examples)
            label_embeddings = list(
                map(lambda label_example: label_example[0], label_examples_list)
            )
            label_texts = list(
                map(lambda label_example: label_example[1], label_examples_list)
            )
            label_metadatas = list(
                map(lambda label_example: label_example[2], label_examples_list)
            )

            result_ids_and_scores = semantic_search(
                corpus_embeddings=label_embeddings,
                query_embeddings=query_embeddings,
                top_k=k,
            )
            result_ids = [result["corpus_id"] for result in result_ids_and_scores[0]]
            documents.extend([label_texts[index] for index in result_ids])
            metadatas.extend([label_metadatas[index] for index in result_ids])
            scores.extend([result["score"] for result in result_ids_and_scores[0]])
        results = {}

        results["documents"] = [documents]
        results["distances"] = [scores]
        results["metadatas"] = [metadatas]

        return _results_to_docs_and_scores(results)

    def max_marginal_relevance_search_by_vector(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
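        """Run maximal marginal relevance (MMR) search: fetch the fetch_k nearest
        neighbors of the query, then greedily select k of them, trading query
        relevance against diversity via lambda_mult.
        """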
        query_embedding = self._get_embeddings([query])[0]
        query_embeddings = torch.tensor([query_embedding])
        result_ids_and_scores = semantic_search(
            corpus_embeddings=self._corpus_embeddings,
            query_embeddings=query_embeddings,
            top_k=fetch_k,
        )
        result_ids = [result["corpus_id"] for result in result_ids_and_scores[0]]
        scores = [result["score"] for result in result_ids_and_scores[0]]

        fetched_embeddings = torch.index_select(
            input=self._corpus_embeddings, dim=0, index=torch.tensor(result_ids)
        ).tolist()
        mmr_selected = maximal_marginal_relevance(
            np.array([query_embedding], dtype=np.float32),
            fetched_embeddings,
            k=k,
            lambda_mult=lambda_mult,
        )
        selected_result_ids = [result_ids[i] for i in mmr_selected]
        selected_scores = [scores[i] for i in mmr_selected]
        results = {}
        results["documents"] = [[self._texts[index] for index in selected_result_ids]]
        results["distances"] = [selected_scores]
        results["metadatas"] = [
            [self._metadatas[index] for index in selected_result_ids]
        ]

        return _results_to_docs_and_scores(results)

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
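        """Run MMR search and return only the documents, without scores."""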
        docs_and_scores = self.max_marginal_relevance_search_by_vector(
            query, k, fetch_k, lambda_mult=lambda_mult
        )
        return [doc for doc, _ in docs_and_scores]

    @classmethod
    def from_texts(
        cls: Type[VectorStoreWrapper],
        texts: List[str],
        embedding: Optional[Embeddings] = None,
        metadatas: Optional[List[dict]] = None,
        cache: bool = True,
        **kwargs: Any,
    ) -> VectorStoreWrapper:
        """Create a vectorstore from raw text.
        The data will be ephemeral in-memory.
        Args:
            texts (List[str]): List of texts to add to the collection.
            embedding (Optional[Embeddings]): Embedding function. Defaults to None.
            metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
            cache (bool): Whether to cache the embeddings. Defaults to True.
        Returns:
            vector_store: Vectorstore with seedset embeddings
        """
        vector_store = cls(
            embedding_function=embedding,
            corpus_embeddings=None,
            texts=None,
            cache=cache,
            **kwargs,
        )
        vector_store.add_texts(texts=texts, metadatas=metadatas)
        return vector_store
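
As a rough usage sketch: the toy embedding class below is a hypothetical stand-in for a real Embeddings implementation (e.g. an OpenAI or HuggingFace embedder), so the snippet runs offline; cache=False skips the embeddings database, so the embedder needs no model/model_name attribute.

from typing import List


class ToyEmbeddings:
    # Hypothetical stand-in for a langchain Embeddings implementation
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # 2-dim "embedding": (text length, vowel count)
        return [
            [float(len(t)), float(sum(c in "aeiou" for c in t))] for t in texts
        ]


store = VectorStoreWrapper.from_texts(
    texts=["the movie was great", "terrible acting", "loved the plot"],
    embedding=ToyEmbeddings(),
    metadatas=[{"label": "positive"}, {"label": "negative"}, {"label": "positive"}],
    cache=False,
)

docs = store.similarity_search("what a great film", k=2)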

add_texts(texts, metadatas=None)

Run texts through the embeddings and add them to the vectorstore. The vectorstore is reinitialized each time, because example selection does not require a persistent vector store.

Args:
    texts (Iterable[str]): Texts to add to the vectorstore.
    metadatas (Optional[List[dict]], optional): Optional list of metadatas.

Returns:
    The metadatas of the added texts.

Source code in autolabel/src/autolabel/few_shot/vector_store.py
def add_texts(
    self,
    texts: Iterable[str],
    metadatas: Optional[List[Dict[str, str]]] = None,
) -> Optional[List[Dict[str, str]]]:
    """Run texts through the embeddings and add them to the vectorstore. The vectorstore is reinitialized each time, because example selection does not require a persistent vector store.
    Args:
        texts (Iterable[str]): Texts to add to the vectorstore.
        metadatas (Optional[List[dict]], optional): Optional list of metadatas.
    Returns:
        The metadatas of the added texts.
    """
    if self._embedding_function is None:
        raise ValueError("An embedding function is required to add texts")

    texts = list(texts)  # materialize so the iterable is not consumed twice
    embeddings = self._get_embeddings(texts)
    self._corpus_embeddings = torch.tensor(embeddings)
    self._texts = texts
    self._metadatas = metadatas
    return metadatas
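
Note that add_texts replaces the corpus rather than appending to it; a second call discards the earlier texts and embeddings. Continuing the toy store sketch above:

store.add_texts(
    ["a brand new corpus"],
    metadatas=[{"label": "neutral"}],
)
# The previous texts, embeddings and metadatas are overwritten, not extended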

from_texts(texts, embedding=None, metadatas=None, cache=True, **kwargs) classmethod

Create a vectorstore from raw text. The data will be ephemeral in-memory.

Args:
    texts (List[str]): List of texts to add to the collection.
    embedding (Optional[Embeddings]): Embedding function. Defaults to None.
    metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
    cache (bool): Whether to cache the embeddings. Defaults to True.

Returns:
    vector_store: Vectorstore with seedset embeddings

Source code in autolabel/src/autolabel/few_shot/vector_store.py
@classmethod
def from_texts(
    cls: Type[VectorStoreWrapper],
    texts: List[str],
    embedding: Optional[Embeddings] = None,
    metadatas: Optional[List[dict]] = None,
    cache: bool = True,
    **kwargs: Any,
) -> VectorStoreWrapper:
    """Create a vectorstore from raw text.
    The data will be ephemeral in-memory.
    Args:
        texts (List[str]): List of texts to add to the collection.
        embedding (Optional[Embeddings]): Embedding function. Defaults to None.
        metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
        cache (bool): Whether to cache the embeddings. Defaults to True.
    Returns:
        vector_store: Vectorstore with seedset embeddings
    """
    vector_store = cls(
        embedding_function=embedding,
        corpus_embeddings=None,
        texts=None,
        cache=cache,
        **kwargs,
    )
    vector_store.add_texts(texts=texts, metadatas=metadatas)
    return vector_store
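
With cache=True (the default), _get_embeddings keys cached rows on the embedder's model attribute (or model_name for HuggingFace/Vertex embedders), so a cached embedder must expose that attribute. A minimal sketch, reusing the hypothetical ToyEmbeddings class from above:

class CachedToyEmbeddings(ToyEmbeddings):
    # Cache rows in the embeddings table are keyed on this name
    model = "toy-embedder-v1"


store = VectorStoreWrapper.from_texts(
    texts=["the movie was great", "terrible acting"],
    embedding=CachedToyEmbeddings(),
    cache=True,  # embeddings are written to, and re-read from, the local database
)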

label_diversity_similarity_search(query, label_key, k=4, filter=None, **kwargs)

Run semantic similarity search, returning the closest matches per label.

Args:
    query (str): Query text to search for.
    label_key (str): Metadata key holding each example's label.
    k (int): Number of results to return per label. Defaults to 4.
    filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

Returns:
    List[Document]: List of documents most similar to the query text.

Source code in autolabel/src/autolabel/few_shot/vector_store.py
def label_diversity_similarity_search(
    self,
    query: str,
    label_key: str,
    k: int = 4,
    filter: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Run semantic similarity search.
    Args:
        query (str): Query text to search for.
        label_key (str): Metadata key holding each example's label.
        k (int): Number of results to return per label.
        filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
    Returns:
        List[Document]: List of documents most similar to the query text.
    """
    docs_and_scores = self.label_diversity_similarity_search_with_score(
        query, label_key, k, filter=filter
    )
    return [doc for doc, _ in docs_and_scores]
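
Using the labeled toy store from the VectorStoreWrapper sketch above, this retrieves the top k matches within each distinct label rather than overall:

docs = store.label_diversity_similarity_search(
    "what a great film",
    label_key="label",
    k=1,
)
# -> one nearest document per label ("positive" and "negative")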

label_diversity_similarity_search_with_score(query, label_key, k=4, filter=None, **kwargs)

Run semantic similarity search per label and retrieve distances.

Args:
    query (str): Query text to search for.
    label_key (str): Metadata key holding each example's label.
    k (int): Number of results to return per label. Defaults to 4.
    filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

Returns:
    List[Tuple[Document, float]]: List of documents most similar to the query text with distance in float.

Source code in autolabel/src/autolabel/few_shot/vector_store.py
def label_diversity_similarity_search_with_score(
    self,
    query: str,
    label_key: str,
    k: int = 4,
    filter: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Run semantic similarity search and retrieve distances.
    Args:
        query (str): Query text to search for.
        k (int): Number of results to return. Defaults to 4.
        filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
    Returns:
        List[Tuple[Document, float]]: List of documents most similar to the query
            text with distance in float.
    """
    query_embeddings = torch.tensor([self._get_embeddings([query])[0]])
    data = zip(self._corpus_embeddings, self._texts, self._metadatas)
    sorted_data = sorted(data, key=lambda item: item[2].get(label_key))

    documents = []
    scores = []
    metadatas = []
    for label, label_examples in groupby(
        sorted_data, key=lambda item: item[2].get(label_key)
    ):
        label_examples_list = list(label_examples)
        label_embeddings = list(
            map(lambda label_example: label_example[0], label_examples_list)
        )
        label_texts = list(
            map(lambda label_example: label_example[1], label_examples_list)
        )
        label_metadatas = list(
            map(lambda label_example: label_example[2], label_examples_list)
        )

        result_ids_and_scores = semantic_search(
            corpus_embeddings=label_embeddings,
            query_embeddings=query_embeddings,
            top_k=k,
        )
        result_ids = [result["corpus_id"] for result in result_ids_and_scores[0]]
        documents.extend([label_texts[index] for index in result_ids])
        metadatas.extend([label_metadatas[index] for index in result_ids])
        scores.extend([result["score"] for result in result_ids_and_scores[0]])
    results = {}

    results["documents"] = [documents]
    results["distances"] = [scores]
    results["metadatas"] = [metadatas]

    return _results_to_docs_and_scores(results)

similarity_search(query, k=4, filter=None, **kwargs)

Run semantic similarity search.

Args:
    query (str): Query text to search for.
    k (int): Number of results to return. Defaults to 4.
    filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

Returns:
    List[Document]: List of documents most similar to the query text.

Source code in autolabel/src/autolabel/few_shot/vector_store.py
def similarity_search(
    self,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Run semantic similarity search.
    Args:
        query (str): Query text to search for.
        k (int): Number of results to return. Defaults to 4.
        filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
    Returns:
        List[Document]: List of documents most similar to the query text.
    """
    docs_and_scores = self.similarity_search_with_score(query, k, filter=filter)
    return [doc for doc, _ in docs_and_scores]

similarity_search_with_score(query, k=4, filter=None, **kwargs)

Run semantic similarity search and retrieve distances.

Args:
    query (str): Query text to search for.
    k (int): Number of results to return. Defaults to 4.
    filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

Returns:
    List[Tuple[Document, float]]: List of documents most similar to the query text with distance in float.

Source code in autolabel/src/autolabel/few_shot/vector_store.py
def similarity_search_with_score(
    self,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Run semantic similarity search and retrieve distances.
    Args:
        query (str): Query text to search for.
        k (int): Number of results to return. Defaults to 4.
        filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
    Returns:
        List[Tuple[Document, float]]: List of documents most similar to the query
            text with distance in float.
    """
    query_embeddings = torch.tensor([self._get_embeddings([query])[0]])
    result_ids_and_scores = semantic_search(
        corpus_embeddings=self._corpus_embeddings,
        query_embeddings=query_embeddings,
        top_k=k,
    )
    result_ids = [result["corpus_id"] for result in result_ids_and_scores[0]]
    scores = [result["score"] for result in result_ids_and_scores[0]]
    results = {}
    results["documents"] = [[self._texts[index] for index in result_ids]]
    results["distances"] = [scores]
    results["metadatas"] = [[self._metadatas[index] for index in result_ids]]
    return _results_to_docs_and_scores(results)
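
The _with_score variant returns (Document, score) pairs, where the score is the cosine similarity computed by semantic_search. Again using the toy store sketched earlier:

for doc, score in store.similarity_search_with_score("what a great film", k=2):
    print(f"{score:.3f}  {doc.page_content}")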

cos_sim(a, b)

Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.

Returns:
    cos_sim: Matrix res with res[i][j] = cos_sim(a[i], b[j])

Source code in autolabel/src/autolabel/few_shot/vector_store.py
def cos_sim(a: Tensor, b: Tensor) -> Tensor:
    """
    Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
    Returns:
        cos_sim: Matrix res with res[i][j] = cos_sim(a[i], b[j])
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)

    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    if len(a.shape) == 1:
        a = a.unsqueeze(0)

    if len(b.shape) == 1:
        b = b.unsqueeze(0)

    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    return torch.mm(a_norm, b_norm.transpose(0, 1))
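
For instance, with the two unit basis vectors scored against their diagonal:

import torch

a = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
b = torch.tensor([[1.0, 1.0]])

cos_sim(a, b)
# tensor([[0.7071],
#         [0.7071]])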

semantic_search(query_embeddings, corpus_embeddings, query_chunk_size=100, corpus_chunk_size=500000, top_k=10, score_function=cos_sim)

Semantic similarity search based on cosine similarity score. Implementation from this project: https://github.com/UKPLab/sentence-transformers

Source code in autolabel/src/autolabel/few_shot/vector_store.py
def semantic_search(
    query_embeddings: Tensor,
    corpus_embeddings: Tensor,
    query_chunk_size: int = 100,
    corpus_chunk_size: int = 500000,
    top_k: int = 10,
    score_function: Callable[[Tensor, Tensor], Tensor] = cos_sim,
):
    """
    Semantic similarity search based on cosine similarity score. Implementation from this project: https://github.com/UKPLab/sentence-transformers
    """

    if isinstance(query_embeddings, (np.ndarray, np.generic)):
        query_embeddings = torch.from_numpy(query_embeddings)
    elif isinstance(query_embeddings, list):
        query_embeddings = torch.stack(query_embeddings)

    if len(query_embeddings.shape) == 1:
        query_embeddings = query_embeddings.unsqueeze(0)

    if isinstance(corpus_embeddings, (np.ndarray, np.generic)):
        corpus_embeddings = torch.from_numpy(corpus_embeddings)
    elif isinstance(corpus_embeddings, list):
        corpus_embeddings = torch.stack(corpus_embeddings)

    # Check that corpus and queries are on the same device
    if corpus_embeddings.device != query_embeddings.device:
        query_embeddings = query_embeddings.to(corpus_embeddings.device)

    queries_result_list = [[] for _ in range(len(query_embeddings))]

    for query_start_idx in range(0, len(query_embeddings), query_chunk_size):
        # Iterate over chunks of the corpus
        for corpus_start_idx in range(0, len(corpus_embeddings), corpus_chunk_size):
            # Compute cosine similarities
            cos_scores = score_function(
                query_embeddings[query_start_idx : query_start_idx + query_chunk_size],
                corpus_embeddings[
                    corpus_start_idx : corpus_start_idx + corpus_chunk_size
                ],
            )

            # Get top-k scores
            cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(
                cos_scores,
                min(top_k, len(cos_scores[0])),
                dim=1,
                largest=True,
                sorted=False,
            )
            cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist()
            cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist()

            for query_itr in range(len(cos_scores)):
                for sub_corpus_id, score in zip(
                    cos_scores_top_k_idx[query_itr], cos_scores_top_k_values[query_itr]
                ):
                    corpus_id = corpus_start_idx + sub_corpus_id
                    query_id = query_start_idx + query_itr
                    if len(queries_result_list[query_id]) < top_k:
                        heapq.heappush(
                            queries_result_list[query_id], (score, corpus_id)
                        )  # heapq orders entries by the first tuple element (the score)
                    else:
                        heapq.heappushpop(
                            queries_result_list[query_id], (score, corpus_id)
                        )

    # change the data format and sort
    for query_id in range(len(queries_result_list)):
        for doc_itr in range(len(queries_result_list[query_id])):
            score, corpus_id = queries_result_list[query_id][doc_itr]
            queries_result_list[query_id][doc_itr] = {
                "corpus_id": corpus_id,
                "score": score,
            }
        queries_result_list[query_id] = sorted(
            queries_result_list[query_id], key=lambda x: x["score"], reverse=True
        )
    return queries_result_list
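
Called directly, semantic_search returns one ranked hit list per query row, each entry carrying the corpus index and its cosine similarity:

import torch

corpus = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
query = torch.tensor([[1.0, 0.0]])

semantic_search(query_embeddings=query, corpus_embeddings=corpus, top_k=2)
# [[{'corpus_id': 0, 'score': 1.0},
#   {'corpus_id': 2, 'score': 0.7071}]]  # score approximate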