-
Notifications
You must be signed in to change notification settings - Fork 4.6k
feat(ml): add qdrant ingestion #38142
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
538ccf8
daceae0
ee63bf9
aa139e8
4a74ef6
d32064f
60856fc
9238f33
952a715
f02c243
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,322 @@ | ||||||||
| # | ||||||||
| # Licensed to the Apache Software Foundation (ASF) under one or more | ||||||||
| # contributor license agreements. See the NOTICE file distributed with | ||||||||
| # this work for additional information regarding copyright ownership. | ||||||||
| # The ASF licenses this file to You under the Apache License, Version 2.0 | ||||||||
| # (the "License"); you may not use this file except in compliance with | ||||||||
| # the License. You may obtain a copy of the License at | ||||||||
| # | ||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||
| # | ||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||
| # See the License for the specific language governing permissions and | ||||||||
| # limitations under the License. | ||||||||
|
|
||||||||
| import logging | ||||||||
| import time | ||||||||
| from collections.abc import Callable | ||||||||
| from dataclasses import dataclass, field | ||||||||
| from typing import Any, Optional | ||||||||
|
|
||||||||
| import grpc | ||||||||
| from objsize import get_deep_size | ||||||||
|
|
||||||||
| try: | ||||||||
| from qdrant_client import QdrantClient, models | ||||||||
| from qdrant_client.common.client_exceptions import ResourceExhaustedResponse | ||||||||
| from qdrant_client.http.exceptions import ResponseHandlingException, UnexpectedResponse | ||||||||
| except ImportError: | ||||||||
| logging.warning("Qdrant client library is not installed.") | ||||||||
|
|
||||||||
| import apache_beam as beam | ||||||||
| from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig | ||||||||
| from apache_beam.ml.rag.types import EmbeddableItem | ||||||||
|
|
||||||||
| DEFAULT_WRITE_BATCH_SIZE = 1000 | ||||||||
| DEFAULT_MAX_BATCH_BYTE_SIZE = 4 << 20 | ||||||||
|
|
||||||||
|
|
||||||||
| @dataclass | ||||||||
| class QdrantConnectionParameters: | ||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add classmethod factories to make it clearer which combinations of parameters are valid? Something like |
||||||||
| """Configuration parameters for connecting to Qdrant service. | ||||||||
|
|
||||||||
| Either `location`, `url`, `host`, or `path` must be provided to establish | ||||||||
| a connection. | ||||||||
|
|
||||||||
| Args: | ||||||||
| location: | ||||||||
| If `str` - use it as a `url` parameter. | ||||||||
| If `None` - use default values for `host` and `port`. | ||||||||
| url: either host or str of "<scheme>//<host>:<port>/<prefix>". | ||||||||
| Default: `None` | ||||||||
| port: Port of the REST API interface. Default: 6333 | ||||||||
| grpc_port: Port of the gRPC interface. Default: 6334 | ||||||||
| prefer_grpc: If `true` - use gPRC interface whenever possible. | ||||||||
| https: If `true` - use HTTPS(SSL) protocol. Default: `None` | ||||||||
| api_key: API key for authentication in Qdrant Cloud. Default: `None` | ||||||||
| prefix: | ||||||||
| If not `None` - add `prefix` to the REST URL path. | ||||||||
| Example: `service/v1` will result in | ||||||||
| `http://localhost:6333/service/v1/{qdrant-endpoint}` for REST API. | ||||||||
| Default: `None` | ||||||||
| timeout: | ||||||||
| Timeout for REST and gRPC API requests. | ||||||||
| Default: 5 seconds for REST and unlimited for gRPC | ||||||||
| host: | ||||||||
| Host name of Qdrant service. | ||||||||
| If url and host are None, set to 'localhost'. | ||||||||
| Default: `None` | ||||||||
| path: Persistence path for QdrantLocal. Default: `None` | ||||||||
| **kwargs: Additional arguments passed directly into client initialization | ||||||||
| """ | ||||||||
|
|
||||||||
| location: Optional[str] = None | ||||||||
| url: Optional[str] = None | ||||||||
| port: Optional[int] = 6333 | ||||||||
| grpc_port: int = 6334 | ||||||||
| prefer_grpc: bool = False | ||||||||
| https: Optional[bool] = None | ||||||||
| api_key: Optional[str] = None | ||||||||
| prefix: Optional[str] = None | ||||||||
| timeout: Optional[int] = None | ||||||||
| host: Optional[str] = None | ||||||||
| path: Optional[str] = None | ||||||||
| kwargs: dict[str, Any] = field(default_factory=dict) | ||||||||
|
|
||||||||
| def __post_init__(self): | ||||||||
| if not (self.location or self.url or self.host or self.path): | ||||||||
| raise ValueError( | ||||||||
| "One of location, url, host, or path must be provided for Qdrant") | ||||||||
|
|
||||||||
| @classmethod | ||||||||
| def for_cloud( | ||||||||
| cls, | ||||||||
| url: str, | ||||||||
| api_key: str, | ||||||||
| *, | ||||||||
| prefer_grpc: bool = False, | ||||||||
| timeout: Optional[int] = None, | ||||||||
| **kwargs: Any, | ||||||||
| ) -> "QdrantConnectionParameters": | ||||||||
| """Connect to Qdrant Cloud. Requires the cluster URL and an API key.""" | ||||||||
| return cls( | ||||||||
| url=url, | ||||||||
| api_key=api_key, | ||||||||
| https=True, | ||||||||
| prefer_grpc=prefer_grpc, | ||||||||
| timeout=timeout, | ||||||||
| kwargs=kwargs, | ||||||||
| ) | ||||||||
|
|
||||||||
| @classmethod | ||||||||
| def for_host( | ||||||||
| cls, | ||||||||
| host: str, | ||||||||
| port: int = 6333, | ||||||||
| *, | ||||||||
| grpc_port: int = 6334, | ||||||||
| prefer_grpc: bool = False, | ||||||||
| https: bool = False, | ||||||||
| api_key: Optional[str] = None, | ||||||||
| timeout: Optional[int] = None, | ||||||||
| **kwargs: Any, | ||||||||
| ) -> "QdrantConnectionParameters": | ||||||||
| """Connect to a self-hosted Qdrant instance by host and port.""" | ||||||||
| return cls( | ||||||||
| host=host, | ||||||||
| port=port, | ||||||||
| grpc_port=grpc_port, | ||||||||
| prefer_grpc=prefer_grpc, | ||||||||
| https=https, | ||||||||
| api_key=api_key, | ||||||||
| timeout=timeout, | ||||||||
| kwargs=kwargs, | ||||||||
| ) | ||||||||
|
|
||||||||
| @classmethod | ||||||||
| def for_url( | ||||||||
| cls, | ||||||||
| url: str, | ||||||||
| *, | ||||||||
| api_key: Optional[str] = None, | ||||||||
| prefer_grpc: bool = False, | ||||||||
| timeout: Optional[int] = None, | ||||||||
| **kwargs: Any, | ||||||||
| ) -> "QdrantConnectionParameters": | ||||||||
| """Connect using a full URL like 'https://my-qdrant.example.com:6333'.""" | ||||||||
| return cls( | ||||||||
| url=url, | ||||||||
| api_key=api_key, | ||||||||
| prefer_grpc=prefer_grpc, | ||||||||
| timeout=timeout, | ||||||||
| kwargs=kwargs) | ||||||||
|
|
||||||||
| @classmethod | ||||||||
| def local(cls, path: str) -> "QdrantConnectionParameters": | ||||||||
| """Use an embedded Qdrant instance persisted to the given path.""" | ||||||||
| return cls(path=path) | ||||||||
|
|
||||||||
| @classmethod | ||||||||
| def in_memory(cls) -> "QdrantConnectionParameters": | ||||||||
| """Use an embedded in-memory Qdrant instance. Useful for tests.""" | ||||||||
| return cls(location=":memory:") | ||||||||
|
|
||||||||
|
|
||||||||
| @dataclass | ||||||||
| class QdrantWriteConfig(VectorDatabaseWriteConfig): | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar thought here, a docstring should be provided here since this is the entrypoint for users to drop the qdrant write into their pipelines |
||||||||
| """Configuration for writing to Qdrant vector database. | ||||||||
|
|
||||||||
| This class defines the parameters needed to write data to a qdrant collection, | ||||||||
| including collection targeting, batching behavior, and operation timeouts. | ||||||||
|
|
||||||||
| Args: | ||||||||
| connection_params: QdrantConnectionParameters with connection settings. | ||||||||
| collection_name: Name of the Qdrant collection to write to. | ||||||||
| timeout: Optional timeout for write operations in seconds. Default is None. | ||||||||
| batch_size: Number of points to write in each batch. Default is 1000. | ||||||||
| kwargs: Additional keyword arguments to pass to the client's upsert method. | ||||||||
| dense_embedding_key: name for the dense vector in the qdrant collection. | ||||||||
| sparse_embedding_key: name for the sparse vector in the qdrant collection. | ||||||||
| """ | ||||||||
|
|
||||||||
| connection_params: QdrantConnectionParameters | ||||||||
| collection_name: str | ||||||||
| timeout: Optional[int] = None | ||||||||
| batch_size: int = DEFAULT_WRITE_BATCH_SIZE | ||||||||
| max_batch_byte_size: int = DEFAULT_MAX_BATCH_BYTE_SIZE | ||||||||
| kwargs: dict[str, Any] = field(default_factory=dict) | ||||||||
| dense_embedding_key: str = "dense" | ||||||||
| sparse_embedding_key: str = "sparse" | ||||||||
|
|
||||||||
| def __post_init__(self): | ||||||||
| if not self.collection_name: | ||||||||
| raise ValueError("Collection name must be provided") | ||||||||
|
MichaelGruschke marked this conversation as resolved.
|
||||||||
| if self.batch_size <= 0: | ||||||||
| raise ValueError("Batch size must be a positive integer") | ||||||||
|
|
||||||||
| def create_write_transform(self) -> beam.PTransform[EmbeddableItem, Any]: | ||||||||
| return _QdrantWriteTransform(self) | ||||||||
|
|
||||||||
| def create_converter( | ||||||||
| self, | ||||||||
| ) -> Callable[[EmbeddableItem], "models.PointStruct"]: | ||||||||
| def convert(item: EmbeddableItem) -> "models.PointStruct": | ||||||||
| if item.dense_embedding is None and item.sparse_embedding is None: | ||||||||
| raise ValueError( | ||||||||
| "EmbeddableItem must have at least one embedding (dense or sparse)") | ||||||||
| vector = {} | ||||||||
| if item.dense_embedding is not None: | ||||||||
| vector[self.dense_embedding_key] = item.dense_embedding | ||||||||
| if item.sparse_embedding is not None: | ||||||||
| sparse_indices, sparse_values = item.sparse_embedding | ||||||||
| vector[self.sparse_embedding_key] = models.SparseVector( | ||||||||
| indices=sparse_indices, | ||||||||
| values=sparse_values, | ||||||||
| ) | ||||||||
| id = ( | ||||||||
| int(item.id) | ||||||||
| if isinstance(item.id, str) and item.id.isdigit() else item.id) | ||||||||
| return models.PointStruct( | ||||||||
| id=id, | ||||||||
| vector=vector, | ||||||||
| payload=item.metadata if item.metadata else None, | ||||||||
| ) | ||||||||
|
|
||||||||
| return convert | ||||||||
|
|
||||||||
|
|
||||||||
| class _QdrantWriteTransform(beam.PTransform): | ||||||||
| def __init__(self, config: QdrantWriteConfig): | ||||||||
| self.config = config | ||||||||
|
|
||||||||
| def expand(self, input_or_inputs: beam.PCollection[EmbeddableItem]): | ||||||||
| return ( | ||||||||
| input_or_inputs | ||||||||
| | "Convert to Records" >> beam.Map(self.config.create_converter()) | ||||||||
| | beam.ParDo(_QdrantWriteFn(self.config))) | ||||||||
|
|
||||||||
|
|
||||||||
| class _QdrantWriteFn(beam.DoFn): | ||||||||
| def __init__(self, config: QdrantWriteConfig): | ||||||||
| self.config = config | ||||||||
| self._client: "Optional[QdrantClient]" = None | ||||||||
|
MichaelGruschke marked this conversation as resolved.
|
||||||||
|
|
||||||||
| def start_bundle(self): | ||||||||
| self._batch = [] | ||||||||
| self._batch_byte_size = 0 | ||||||||
|
|
||||||||
| def process(self, element, *args, **kwargs): | ||||||||
| element_byte_size = get_deep_size(element) | ||||||||
| new_batch_byte_size = self._batch_byte_size + element_byte_size | ||||||||
|
|
||||||||
| is_batch_full = len(self._batch) >= self.config.batch_size | ||||||||
| is_batch_too_large = new_batch_byte_size > self.config.max_batch_byte_size | ||||||||
| if (is_batch_full or is_batch_too_large): | ||||||||
| self._flush() | ||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider adding a byte size limit for individual batches, similar to BigQuery streaming inserts beam/sdks/python/apache_beam/io/gcp/bigquery.py Line 1655 in efe4e94
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added early flush when batch gets too big |
||||||||
| self._batch.append(element) | ||||||||
| self._batch_byte_size += element_byte_size | ||||||||
|
|
||||||||
| def setup(self): | ||||||||
| params = self.config.connection_params | ||||||||
| self._client = QdrantClient( | ||||||||
| location=params.location, | ||||||||
| url=params.url, | ||||||||
| port=params.port, | ||||||||
| grpc_port=params.grpc_port, | ||||||||
| prefer_grpc=params.prefer_grpc, | ||||||||
| https=params.https, | ||||||||
| api_key=params.api_key, | ||||||||
| prefix=params.prefix, | ||||||||
| timeout=params.timeout, | ||||||||
| host=params.host, | ||||||||
| path=params.path, | ||||||||
| check_compatibility=False, | ||||||||
| **params.kwargs, | ||||||||
|
Comment on lines
+275
to
+276
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Disabling the compatibility check (
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah this is a little bit tricky, check_compatibility would make it really restrictive as only two minor versions difference is allowed between client and server |
||||||||
| ) | ||||||||
|
|
||||||||
| def teardown(self): | ||||||||
| if self._client: | ||||||||
| self._client.close() | ||||||||
| self._client = None | ||||||||
|
|
||||||||
| def finish_bundle(self): | ||||||||
| self._flush() | ||||||||
|
|
||||||||
| def _flush(self): | ||||||||
| if not self._batch: | ||||||||
| return | ||||||||
| if not self._client: | ||||||||
| raise RuntimeError("Qdrant client is not initialized") | ||||||||
|
|
||||||||
| max_retries = 3 | ||||||||
| attempt = 1 | ||||||||
| while True: | ||||||||
| try: | ||||||||
| self._client.upsert( | ||||||||
| collection_name=self.config.collection_name, | ||||||||
| points=self._batch, | ||||||||
| timeout=self.config.timeout, | ||||||||
| **self.config.kwargs, | ||||||||
| ) | ||||||||
| break | ||||||||
| except ResourceExhaustedResponse as e: | ||||||||
| time.sleep(e.retry_after_s) | ||||||||
| # don't count rate-limit against max_retries | ||||||||
| continue | ||||||||
| except (UnexpectedResponse, ResponseHandlingException, | ||||||||
| grpc.RpcError) as e: | ||||||||
| if attempt > max_retries: | ||||||||
| raise | ||||||||
| time.sleep(2**attempt) | ||||||||
| attempt += 1 | ||||||||
| self._batch = [] | ||||||||
| self._batch_byte_size = 0 | ||||||||
|
|
||||||||
| def display_data(self): | ||||||||
| res = super().display_data() | ||||||||
| res["collection"] = self.config.collection_name | ||||||||
| res["batch_size"] = self.config.batch_size | ||||||||
| res["max_batch_byte_size"] = self.config.max_batch_byte_size | ||||||||
| return res | ||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A docstring outlining each field and the mandatory information required to create a valid set of parameters will make this much more user friendly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point! added docstring