Skip to content

Commit 3266dfc

Browse files
authored
Refactor Gaussian Naive Bayes functions for clarity
1 parent a23f04d commit 3266dfc

1 file changed

Lines changed: 19 additions & 10 deletions

File tree

machine_learning/gaussian_naive_bayes.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def compute_mean_variance(values: list[float]) -> tuple[float, float]:
111111

112112
n = len(values)
113113
mean = sum(values) / n
114-
variance = sum((x - mean) ** 2 for x in values) / n
114+
variance = sum((value - mean) ** 2 for value in values) / n
115115
return mean, max(variance, 1e-9)
116116

117117

@@ -161,18 +161,20 @@ def train(
161161
return priors, summaries
162162

163163

164-
def gaussian_log_probability(x: float, mean: float, variance: float) -> float:
164+
def gaussian_log_probability(
165+
feature_value: float, mean: float, variance: float
166+
) -> float:
165167
"""
166168
Compute the log of the Gaussian probability density for a single value.
167169
168170
Uses the formula:
169-
log P(x | mean, var) = -0.5 * log(2 * pi * var)
170-
- 0.5 * ((x - mean)^2 / var)
171+
log P(feature_value | mean, var) = -0.5 * log(2 * pi * var)
172+
- 0.5 * ((feature_value - mean)^2 / var)
171173
172174
Args:
173-
x: The observed value.
174-
mean: Mean of the Gaussian distribution.
175-
variance: Variance of the Gaussian distribution (must be > 0).
175+
feature_value: The observed feature value.
176+
mean: Mean of the Gaussian distribution.
177+
variance: Variance of the Gaussian distribution (must be > 0).
176178
177179
Returns:
178180
Log probability density as a float.
@@ -191,7 +193,10 @@ def gaussian_log_probability(x: float, mean: float, variance: float) -> float:
191193
"""
192194
if variance <= 0:
193195
raise ValueError("Variance must be positive.")
194-
return -0.5 * math.log(2 * math.pi * variance) - 0.5 * ((x - mean) ** 2 / variance)
196+
return (
197+
-0.5 * math.log(2 * math.pi * variance)
198+
- 0.5 * ((feature_value - mean) ** 2 / variance)
199+
)
195200

196201

197202
def predict_single(
@@ -223,7 +228,9 @@ def predict_single(
223228

224229
for class_label, feature_summaries in summaries.items():
225230
score = priors[class_label]
226-
for feature_value, (mean, variance) in zip(feature_vector, feature_summaries):
231+
for feature_value, (mean, variance) in zip(
232+
feature_vector, feature_summaries
233+
):
227234
score += gaussian_log_probability(feature_value, mean, variance)
228235
if score > best_score:
229236
best_score = score
@@ -300,7 +307,9 @@ def accuracy(predictions: list[int], actual: list[int]) -> float:
300307
if not predictions:
301308
raise ValueError("Inputs must not be empty.")
302309
if len(predictions) != len(actual):
303-
raise ValueError("Predictions and actual labels must have the same length.")
310+
raise ValueError(
311+
"Predictions and actual labels must have the same length."
312+
)
304313
correct = sum(p == a for p, a in zip(predictions, actual))
305314
return correct / len(actual)
306315

0 commit comments

Comments
 (0)