|
1 | 1 | from typing import Any, List |
2 | 2 |
|
| 3 | +import psutil |
3 | 4 | import torch |
4 | 5 | import torch.nn as nn |
5 | 6 | import torch.nn.functional as F |
@@ -84,7 +85,9 @@ def forward(self, x): |
84 | 85 | ) |
85 | 86 |
|
86 | 87 | partitioned_module = resource_partition( |
87 | | - partitioned_module, cpu_memory_budget=2 * 1024 * 1024 * 1024 # 2GB, |
| 88 | + partitioned_module, |
| 89 | + cpu_memory_budget=0.89 * 1024 * 1024 * 1024 |
| 90 | + + psutil.Process().memory_info().rss, # 0.89GB + current memory usage, |
88 | 91 | ) |
89 | 92 |
|
90 | 93 | self.assertEqual( |
@@ -166,7 +169,9 @@ def forward(self, x): |
166 | 169 | ) |
167 | 170 |
|
168 | 171 | partitioned_module = resource_partition( |
169 | | - partitioned_module, cpu_memory_budget=1.4 * 1024 * 1024 * 1024 # 1.4GB, |
| 172 | + partitioned_module, |
| 173 | + cpu_memory_budget=0.39 * 1024 * 1024 * 1024 |
| 174 | + + psutil.Process().memory_info().rss, # 0.39GB + current memory usage, |
170 | 175 | ) |
171 | 176 |
|
172 | 177 | assert ( |
@@ -298,7 +303,9 @@ def forward(self, x): |
298 | 303 | ) |
299 | 304 |
|
300 | 305 | partitioned_module = resource_partition( |
301 | | - partitioned_module, cpu_memory_budget=1.4 * 1024 * 1024 * 1024 # 1.4GB, |
| 306 | + partitioned_module, |
| 307 | + cpu_memory_budget=0.39 * 1024 * 1024 * 1024 |
| 308 | + + psutil.Process().memory_info().rss, # 0.89GB + current memory usage, |
302 | 309 | ) |
303 | 310 |
|
304 | 311 | assert ( |
@@ -396,7 +403,9 @@ def forward(self, x): |
396 | 403 | ) |
397 | 404 |
|
398 | 405 | partitioned_module = resource_partition( |
399 | | - partitioned_module, cpu_memory_budget=1.4 * 1024 * 1024 * 1024 # 1.4GB, |
| 406 | + partitioned_module, |
| 407 | + cpu_memory_budget=0.39 * 1024 * 1024 * 1024 |
| 408 | + + psutil.Process().memory_info().rss, # 0.89GB + current memory usage, |
400 | 409 | ) |
401 | 410 |
|
402 | 411 | assert ( |
|
0 commit comments