Skip to content

Commit 428813f

Browse files
author
Rishal Hurbans
committed
Chapter 12 fix: Added support for running scripts from different paths and referencing model files correctly.
1 parent b6b4bce commit 428813f

4 files changed

Lines changed: 34 additions & 18 deletions

File tree

ch12-generative_image_models/numpy_u_net/toy_numpy_u_net_generate.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
import numpy as np
2-
import matplotlib.pyplot as plt
31
import math
42
import pickle
3+
from pathlib import Path
4+
5+
import matplotlib.pyplot as plt
6+
import numpy as np
57

68
# Step 0: Redefine the NumPy Layers and UNet Class
79
# We must have the exact same class definitions as the training script
@@ -160,13 +162,14 @@ def set_params(self, params):
160162
# Step 1: Load Model and Define Helpers
161163
print("\nStep 1: Loading Model and Defining Helpers")
162164
model = NumPyUNet()
165+
MODEL_PATH = Path(__file__).resolve().parent / "trained_numpy_unet_model.pkl"
163166
try:
164-
with open('trained_numpy_unet_model.pkl', 'rb') as f:
167+
with MODEL_PATH.open('rb') as f:
165168
params = pickle.load(f)
166169
model.set_params(params)
167170
print("Successfully loaded trained NumPy U-Net model weights.")
168171
except FileNotFoundError:
169-
print("Error: 'trained_numpy_unet_model.pkl' not found.")
172+
print(f"Error: '{MODEL_PATH}' not found.")
170173
print("Please run the NumPy training script first.")
171174
exit()
172175

ch12-generative_image_models/numpy_u_net/toy_numpy_u_net_train.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
import numpy as np
2-
import matplotlib.pyplot as plt
31
import math
42
import pickle
3+
from pathlib import Path
4+
5+
import matplotlib.pyplot as plt
6+
import numpy as np
57

68
# Seed randomness for reproducibility
79
np.random.seed(42)
@@ -288,6 +290,9 @@ def set_params(self, params):
288290
if f'{name}_gamma' in params: layer.gamma = params[f'{name}_gamma']
289291
if f'{name}_beta' in params: layer.beta = params[f'{name}_beta']
290292

293+
# Where to persist the trained model so the script works regardless of CWD
294+
MODEL_PATH = Path(__file__).resolve().parent / "trained_numpy_unet_model.pkl"
295+
291296
# Step 1: Define The Training Data
292297
print("\nStep 1: Defining The Training Data")
293298
training_data = [
@@ -577,6 +582,6 @@ def create_timestep_embedding(t, embedding_dim):
577582
plt.show()
578583

579584
print("\n--- Saving Model Weights ---")
580-
with open('trained_numpy_unet_model.pkl', 'wb') as f:
585+
with MODEL_PATH.open('wb') as f:
581586
pickle.dump(model.get_params(), f)
582-
print("Model weights saved to 'trained_numpy_unet_model.pkl'")
587+
print(f"Model weights saved to '{MODEL_PATH}'")

ch12-generative_image_models/pytorch-u_net/toy_pytorch_u_net_generate.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import math
2+
from pathlib import Path
3+
4+
import matplotlib.pyplot as plt
5+
import numpy as np
16
import torch
27
import torch.nn as nn
3-
import numpy as np
4-
import matplotlib.pyplot as plt
5-
import math
68

79
# Step 1: Setup and Model Definition
810
print("Step 1: Setting up environment and model definition")
@@ -54,11 +56,12 @@ def forward(self, x, t_emb, c_emb):
5456
print("\nStep 2: Loading Model and Defining Helpers")
5557
# Initialize the model and load the saved weights
5658
model = UNet().to(device)
59+
MODEL_PATH = Path(__file__).resolve().parent / "trained_pytorch_unet_model.pth"
5760
try:
58-
model.load_state_dict(torch.load('/trained_pytorch_unet_model.pth', map_location=device))
61+
model.load_state_dict(torch.load(str(MODEL_PATH), map_location=device))
5962
print("Successfully loaded trained U-Net model weights.")
6063
except FileNotFoundError:
61-
print("Error: 'trained_unet_model.pth' not found.")
64+
print(f"Error: '{MODEL_PATH}' not found.")
6265
print("Please run the training script first to create the model file.")
6366
exit()
6467
model.eval() # Set the model to evaluation mode

ch12-generative_image_models/pytorch-u_net/toy_pytorch_u_net_train.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1+
import math
2+
from pathlib import Path
3+
4+
import matplotlib.pyplot as plt
5+
import numpy as np
16
import torch
27
import torch.nn as nn
3-
import numpy as np
4-
import matplotlib.pyplot as plt
5-
import math
68

79
# Use a GPU if available, otherwise use the CPU
810
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
911
print(f"Using device: {device}")
1012

13+
# Where to persist the trained model so the script works regardless of CWD
14+
MODEL_PATH = Path(__file__).resolve().parent / "trained_pytorch_unet_model.pth"
15+
1116
# Step 1: Define the Training Data
1217
print("\nStep 1: Defining The Training Data")
1318
training_data = [
@@ -225,5 +230,5 @@ def forward(self, x, t_emb, c_emb):
225230

226231
# Save the Trained Model Weights
227232
print("\nSaving Model Weights")
228-
torch.save(model.state_dict(), '/trained_pytorch_unet_model.pth')
229-
print("Model weights saved to 'trained_unet_model.pth'")
233+
torch.save(model.state_dict(), str(MODEL_PATH))
234+
print(f"Model weights saved to '{MODEL_PATH}'")

0 commit comments

Comments
 (0)