Skip to content

Commit b3012ee

Browse files
authored
fix: multiword broken (#317)
* fix: multiword broken
* increase version
1 parent 226316c commit b3012ee

4 files changed

Lines changed: 25 additions & 9 deletions

File tree

model2vec/distill/distillation.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,6 @@ def distill_from_model(
8888

8989
# Create the vocabulary in the new tokenizer.
9090
tokenizer_model = clean_and_create_vocabulary(tokenizer_model, vocabulary, token_remove_regex=token_remove_regex)
91-
# Remove the post processor, this is not necessary.
92-
tokenizer_model.post_processor = None
93-
# Prune again now that the post processor is gone.
94-
# We can't do this before because we need the post processor and associated
95-
# tokens before to add eos/bos.
96-
tokenizer_model = tokenizer_model.prune_added_tokens()
9791

9892
# All tokens in a single list.
9993
all_tokens = tokenizer_model.sorted_vocabulary

model2vec/tokenizer/tokenizer.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,18 @@ def clean_and_create_vocabulary(
1313
vocabulary_to_add: list[str],
1414
token_remove_regex: re.Pattern[str] | None,
1515
) -> TokenizerModel:
16-
"""Clean a vocabulary by removing duplicates and tokens that were already in the vocabulary."""
16+
"""
17+
Clean a vocabulary by removing duplicates and tokens that were already in the vocabulary.
18+
19+
This function removes duplicate tokens and tokens that are already in the model's vocabulary.
20+
It also removes the tokenizer's post-processor, which we do not use anyway.
21+
22+
:param model: The tokenizer model to clean.
23+
:param vocabulary_to_add: The vocabulary to add to the model. Any tokens in this vocabulary that
24+
are split according to the pretokenizer are added as multiword tokens.
25+
:param token_remove_regex: A regex pattern to remove tokens from the vocabulary.
26+
:return: The cleaned tokenizer model.
27+
"""
1728
seen_tokens = set()
1829

1930
n_duplicate = 0
@@ -39,7 +50,9 @@ def clean_and_create_vocabulary(
3950
if len(preprocessed) > 1:
4051
tokens_as_str = [f"'{subword}'" for subword in preprocessed]
4152
split_into = ",".join(tokens_as_str)
42-
logger.warning(f"Token '{token}' was split into multiple tokens after preprocessing: [{split_into}]")
53+
logger.warning(
54+
f"Token '{token}' was split into multiple tokens after preprocessing: [{split_into}], adding it as a multi-word token."
55+
)
4356
added_tokens_to_add.append(token)
4457
continue
4558
token = preprocessed[0]
@@ -54,6 +67,8 @@ def clean_and_create_vocabulary(
5467
seen_tokens.add(token)
5568
tokens_to_add.append(token)
5669

70+
model.post_processor = None
71+
model = model.prune_added_tokens()
5772
model = model.add_tokens_to_vocabulary(tokens_to_add, preprocess_tokens=True)
5873
model = model.add_addedtokens(added_tokens_to_add, is_special=False, single_word=False, normalized=True)
5974

model2vec/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
__version_triple__ = (0, 8, 1)
1+
__version_triple__ = (0, 8, 2)
22
__version__ = ".".join(map(str, __version_triple__))

tests/test_distillation.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
(None, 1024, None), # Subword, PCA set high, SIF off
3838
(None, None, 1e-4), # No PCA, SIF on
3939
(None, 0.9, 1e-4), # PCA as float (variance), SIF on
40+
(["star wars"], 8, None), # Multiword vocabulary
4041
],
4142
)
4243
@patch.object(import_module("model2vec.distill.distillation"), "model_info")
@@ -79,6 +80,12 @@ def test_distill_from_model(
7980
assert json.loads(static_model.tokenizer.to_str()) == json.loads(static_model2.tokenizer.to_str())
8081
assert static_model.base_model_name == static_model2.base_model_name
8182

83+
for token in vocabulary or []:
84+
# Normalized tokens are for single-word tokens.
85+
# Other tokens are added as addedtokens, as is.
86+
normalized = static_model.tokenizer.normalizer.normalize_str(token)
87+
assert token in static_model.tokens or normalized in static_model.tokens
88+
8289

8390
@patch.object(import_module("model2vec.distill.distillation"), "model_info")
8491
@patch("transformers.AutoModel.from_pretrained")

0 commit comments

Comments (0)