Skip to content

Commit 3e40b33

Browse files
committed
refactor : precommit 적용
1 parent 4e3c06d commit 3e40b33

8 files changed

Lines changed: 100 additions & 31 deletions

File tree

docs/BaseComponent_ko.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,15 @@ retriever = FunctionalComponent(my_retriever, name="MyRetriever", hook=hook)
190190

191191
```python
192192
from lang2sql.core.hooks import MemoryHook
193+
from lang2sql.flows.baseline import SequentialFlow
194+
193195
hook = MemoryHook()
194196

195-
flow = BaselineFlow(steps=[...], hook=hook) # 또는 컴포넌트마다 hook 주입
196-
out = flow.run_query("지난달 매출")
197+
flow = SequentialFlow(steps=[...], hook=hook) # 또는 컴포넌트마다 hook 주입
198+
out = flow.run("지난달 매출")
197199

198200
# 이벤트 확인
199-
for e in hook.events:
201+
for e in hook.snapshot():
200202
print(e.phase, e.component, e.duration_ms, e.error)
201203
```
202204

docs/Hook_and_exception_ko.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,16 +111,16 @@ class MemoryHook:
111111

112112
#### MemoryHook 사용 예시
113113

114-
```py
114+
```python
115115
from lang2sql.core.hooks import MemoryHook
116-
from lang2sql.flows.baseline import BaselineFlow
116+
from lang2sql.flows.baseline import SequentialFlow
117117

118118
hook = MemoryHook()
119-
flow = BaselineFlow(steps=[...], hook=hook)
119+
flow = SequentialFlow(steps=[...], hook=hook)
120120

121-
out = flow.run_query("지난달 매출")
121+
out = flow.run("지난달 매출")
122122

123-
for e in hook.events:
123+
for e in hook.snapshot():
124124
print(e.name, e.phase, e.component, e.duration_ms, e.error)
125125
```
126126

docs/tutorials/getting-started-without-datahub.md

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,19 +122,53 @@ print(f"FAISS index saved to: {OUTPUT_DIR}/catalog.faiss")
122122

123123
### 4) 실행
124124

125+
v2 CLI는 외부 벡터 인덱스 경로를 인수로 받지 않습니다.
126+
앞서 생성한 FAISS 인덱스를 활용하려면 아래 예시처럼 Python API로 파이프라인을 직접 구성해야 합니다.
127+
128+
```python
129+
# run_query.py
130+
import os
131+
from dotenv import load_dotenv
132+
from lang2sql import CatalogChunker, VectorRetriever
133+
from lang2sql.integrations.db import SQLAlchemyDB
134+
from lang2sql.integrations.embedding import OpenAIEmbedding
135+
from lang2sql.integrations.llm import OpenAILLM
136+
from lang2sql.integrations.vectorstore import FAISSVectorStore
137+
from lang2sql.flows.hybrid import HybridNL2SQL
138+
139+
load_dotenv()
140+
141+
INDEX_DIR = "./dev/table_info_db"
142+
embedding = OpenAIEmbedding(
143+
model=os.getenv("OPEN_AI_EMBEDDING_MODEL", "text-embedding-3-large"),
144+
api_key=os.getenv("OPEN_AI_KEY"),
145+
)
146+
147+
# FAISS 인덱스 로드 후 파이프라인 구성
148+
store = FAISSVectorStore.load(f"{INDEX_DIR}/catalog.faiss")
149+
150+
pipeline = HybridNL2SQL(
151+
catalog=[], # FAISS에 이미 인덱싱돼 있으므로 빈 리스트
152+
llm=OpenAILLM(model=os.getenv("OPEN_AI_LLM_MODEL", "gpt-4o"), api_key=os.getenv("OPEN_AI_KEY")),
153+
db=SQLAlchemyDB(os.getenv("DB_URL", "sqlite:///sample.db")),
154+
embedding=embedding,
155+
db_dialect=os.getenv("DB_TYPE", "sqlite"),
156+
)
157+
158+
rows = pipeline.run("주문 수를 집계하는 SQL을 만들어줘")
159+
print(rows)
160+
```
161+
162+
Streamlit UI:
163+
125164
```bash
126-
# Streamlit UI
127165
lang2sql run-streamlit
166+
```
128167

129-
# CLI 예시 (FAISS 인덱스 사용)
130-
lang2sql query "주문 수를 집계하는 SQL을 만들어줘" \
131-
--vectordb-type faiss \
132-
--vectordb-location ./dev/table_info_db
168+
CLI (카탈로그 없이 실행하는 경우 baseline 플로우만 사용할 수 있습니다):
133169

134-
# CLI 예시 (pgvector)
135-
lang2sql query "주문 수를 집계하는 SQL을 만들어줘" \
136-
--vectordb-type pgvector \
137-
--vectordb-location "postgresql://pgvector:pgvector@localhost:5432/postgres"
170+
```bash
171+
lang2sql query "주문 수를 집계해줘" --flow baseline --dialect sqlite
138172
```
139173

140174
### 5) (선택) pgvector로 적재하기
@@ -229,4 +263,3 @@ VectorRetriever.from_chunks(
229263
print(f"pgvector collection populated: {TABLE}")
230264
```
231265

232-
주의: FAISS 디렉토리 또는 pgvector 컬렉션이 없으면 현재 코드는 DataHub에서 메타데이터를 가져와 인덱스를 생성하려고 시도합니다. DataHub를 사용하지 않는 경우 위 절차로 사전에 VectorDB를 만들어 두세요.

src/lang2sql/__init__.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
from .factory import build_db_from_env, build_embedding_from_env, build_explorer_from_url, build_llm_from_env
1+
from .factory import (
2+
build_db_from_env,
3+
build_embedding_from_env,
4+
build_explorer_from_url,
5+
build_llm_from_env,
6+
)
27
from .components.enrichment.context_enricher import ContextEnricher
38
from .components.enrichment.question_profiler import QuestionProfiler
49
from .components.execution.sql_executor import SQLExecutor
@@ -50,6 +55,7 @@
5055
from .integrations.llm.gemini_ import GeminiLLM
5156
from .integrations.llm.huggingface_ import HuggingFaceLLM
5257
from .integrations.llm.ollama_ import OllamaLLM
58+
5359
__all__ = [
5460
# Data types
5561
"CatalogEntry",
@@ -132,15 +138,16 @@
132138
# ---------------------------------------------------------------------------
133139
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
134140
"DataHubCatalogLoader": (".integrations.catalog.datahub_", "DataHubCatalogLoader"),
135-
"FAISSVectorStore": (".integrations.vectorstore.faiss_", "FAISSVectorStore"),
136-
"PGVectorStore": (".integrations.vectorstore.pgvector_", "PGVectorStore"),
141+
"FAISSVectorStore": (".integrations.vectorstore.faiss_", "FAISSVectorStore"),
142+
"PGVectorStore": (".integrations.vectorstore.pgvector_", "PGVectorStore"),
137143
}
138144

139145

140146
def __getattr__(name: str):
141147
if name in _LAZY_IMPORTS:
142148
module_path, attr = _LAZY_IMPORTS[name]
143149
import importlib
150+
144151
obj = getattr(importlib.import_module(module_path, package=__name__), attr)
145152
# Cache in module globals so subsequent accesses skip __getattr__
146153
globals()[name] = obj

src/lang2sql/core/ports.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ def list_tables(self, schema: str | None = None) -> list[str]: ...
7878

7979
def get_ddl(self, table: str, *, schema: str | None = None) -> str: ...
8080

81-
def sample_data(self, table: str, *, limit: int = 5, schema: str | None = None) -> list[dict]: ...
81+
def sample_data(
82+
self, table: str, *, limit: int = 5, schema: str | None = None
83+
) -> list[dict]: ...
8284

8385
def execute_read_only(self, sql: str) -> list[dict]: ...

src/lang2sql/integrations/db/sqlalchemy_.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,17 @@ def execute(self, sql: str) -> list[dict[str, Any]]:
3232

3333

3434
_WRITE_PREFIXES = frozenset(
35-
{"INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "TRUNCATE", "REPLACE", "MERGE"}
35+
{
36+
"INSERT",
37+
"UPDATE",
38+
"DELETE",
39+
"DROP",
40+
"ALTER",
41+
"CREATE",
42+
"TRUNCATE",
43+
"REPLACE",
44+
"MERGE",
45+
}
3646
)
3747

3848

@@ -51,7 +61,9 @@ def __init__(self, url: str, *, schema: str | None = None) -> None:
5161
self._schema = schema
5262

5363
@classmethod
54-
def from_engine(cls, engine: "Engine", *, schema: str | None = None) -> "SQLAlchemyExplorer":
64+
def from_engine(
65+
cls, engine: "Engine", *, schema: str | None = None
66+
) -> "SQLAlchemyExplorer":
5567
"""기존 engine 공유용. 연결 풀 중복 방지."""
5668
instance = cls.__new__(cls)
5769
instance._engine = engine
@@ -86,7 +98,9 @@ def get_ddl(self, table: str, *, schema: str | None = None) -> str:
8698
t = SATable(table, metadata, autoload_with=self._engine, schema=resolved_schema)
8799
return str(CreateTable(t).compile(self._engine))
88100

89-
def sample_data(self, table: str, *, limit: int = 5, schema: str | None = None) -> list[dict]:
101+
def sample_data(
102+
self, table: str, *, limit: int = 5, schema: str | None = None
103+
) -> list[dict]:
90104
"""실제 샘플 데이터 반환.
91105
92106
f-string SQL 금지 — SQLAlchemy ORM select()로 identifier quoting 위임.

tests/test_components_vector_retriever.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,9 @@ def test_save_and_load_returns_same_results(tmp_path):
535535

536536
store = FAISSVectorStore(index_path=path + ".faiss")
537537
chunks = CatalogChunker().split(CATALOG)
538-
original = VectorRetriever.from_chunks(chunks, embedding=embedding, vectorstore=store)
538+
original = VectorRetriever.from_chunks(
539+
chunks, embedding=embedding, vectorstore=store
540+
)
539541
original.save(path)
540542

541543
loaded_store = FAISSVectorStore.load(path)
@@ -555,7 +557,9 @@ def test_load_registry_intact(tmp_path):
555557

556558
store = FAISSVectorStore(index_path=path + ".faiss")
557559
chunks = CatalogChunker().split(CATALOG)
558-
original = VectorRetriever.from_chunks(chunks, embedding=embedding, vectorstore=store)
560+
original = VectorRetriever.from_chunks(
561+
chunks, embedding=embedding, vectorstore=store
562+
)
559563
original.save(path)
560564

561565
loaded_store = FAISSVectorStore.load(path)
@@ -571,7 +575,9 @@ def test_save_raises_for_inmemory():
571575
"""InMemoryVectorStore는 save()를 지원하지 않아 NotImplementedError가 발생한다."""
572576
embedding = FakeEmbeddingFAISS()
573577
chunks = CatalogChunker().split(CATALOG)
574-
retriever = VectorRetriever.from_chunks(chunks, embedding=embedding) # InMemory 기본값
578+
retriever = VectorRetriever.from_chunks(
579+
chunks, embedding=embedding
580+
) # InMemory 기본값
575581

576582
with pytest.raises(NotImplementedError, match="does not support save"):
577583
retriever.save("/tmp/should_not_exist")

tests/test_integrations_sqlalchemy_explorer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
import pytest
66
from sqlalchemy import create_engine, text
77

8-
98
# ---------------------------------------------------------------------------
109
# Fixture: SQLite in-memory DB with FK schema
1110
# ---------------------------------------------------------------------------
1211

12+
1313
@pytest.fixture()
1414
def engine():
1515
eng = create_engine("sqlite:///:memory:")
@@ -29,7 +29,9 @@ def engine():
2929
status TEXT DEFAULT 'pending'
3030
)
3131
"""))
32-
conn.execute(text("INSERT INTO customers VALUES (1, 'Alice', 'alice@example.com')"))
32+
conn.execute(
33+
text("INSERT INTO customers VALUES (1, 'Alice', 'alice@example.com')")
34+
)
3335
conn.execute(text("INSERT INTO customers VALUES (2, 'Bob', 'bob@example.com')"))
3436
conn.execute(text("INSERT INTO orders VALUES (1, 1, 99.9, 'shipped')"))
3537
conn.execute(text("INSERT INTO orders VALUES (2, 2, 42.0, 'pending')"))
@@ -48,6 +50,7 @@ def explorer(engine):
4850
# Tests
4951
# ---------------------------------------------------------------------------
5052

53+
5154
def test_list_tables(explorer):
5255
tables = explorer.list_tables()
5356
assert set(tables) == {"customers", "orders"}
@@ -98,7 +101,9 @@ def test_execute_read_only_select(explorer):
98101

99102
def test_execute_read_only_rejects_insert(explorer):
100103
with pytest.raises(ValueError, match="Write operations not allowed"):
101-
explorer.execute_read_only("INSERT INTO customers VALUES (3, 'Eve', 'eve@x.com')")
104+
explorer.execute_read_only(
105+
"INSERT INTO customers VALUES (3, 'Eve', 'eve@x.com')"
106+
)
102107

103108

104109
def test_execute_read_only_rejects_drop(explorer):

0 commit comments

Comments
 (0)