|
1 | 1 | """ |
2 | 2 | Extract Invoke5 metadata from the raw metadata dictionary. |
3 | 3 | """ |
| 4 | + |
| 5 | +import itertools |
4 | 6 | import logging |
5 | 7 |
|
6 | 8 | from .invoke_metadata_abc import ( |
@@ -84,24 +86,32 @@ def _get_reference_images(self) -> list[ReferenceImage]: |
84 | 86 | This is called to get the reference image when the metadata contains a ref_images field. |
85 | 87 | """ |
86 | 88 | reference_images = self.raw_metadata.get("ref_images", []) |
87 | | - return [ |
88 | | - ReferenceImage( |
89 | | - model_name=image.get("config", {}) |
90 | | - .get("model", {}) |
91 | | - .get("name", "Unknown Model"), |
92 | | - image_name=image.get("config", {}) |
93 | | - .get("image", {}) |
94 | | - .get("image_name", "") |
95 | | - or image.get("config", {}) |
96 | | - .get("image", {}) |
97 | | - .get("original", {}) |
98 | | - .get("image", {}) |
99 | | - .get("image_name", "Unknown Image"), |
100 | | - weight=image.get("config", {}).get("weight", 1.0), |
| 89 | + # for some reason, ref_images can be a list of lists |
| 90 | + if any(isinstance(img, list) for img in reference_images): |
| 91 | + reference_images = list(itertools.chain.from_iterable(reference_images)) |
| 92 | + |
| 93 | + reference_image_list = [] |
| 94 | + for image in reference_images: |
| 95 | + if image.get("isEnabled", False) is False: |
| 96 | + continue |
| 97 | + model = image.get("config", {}).get("model", {}) or {} |
| 98 | + model_name = model.get("name", "N/A") |
| 99 | + image_name = image.get("config", {}).get("image", {}).get( |
| 100 | + "image_name", "" |
| 101 | + ) or image.get("config", {}).get("image", {}).get("original", {}).get( |
| 102 | + "image", {} |
| 103 | + ).get( |
| 104 | + "image_name", "Unknown Image" |
101 | 105 | ) |
102 | | - for image in reference_images |
103 | | - if image.get("isEnabled", False) |
104 | | - ] |
| 106 | + weight = image.get("config", {}).get("weight", 1.0) |
| 107 | + reference_image_list.append( |
| 108 | + ReferenceImage( |
| 109 | + model_name=model_name, |
| 110 | + image_name=image_name, |
| 111 | + weight=weight, |
| 112 | + ) |
| 113 | + ) |
| 114 | + return reference_image_list |
105 | 115 |
|
106 | 116 | def _get_reference_images_from_canvas_v2(self) -> list[ReferenceImage]: |
107 | 117 | """ |
|
0 commit comments