2222
2323
2424class UnitTests (absltest .TestCase ):
25- def test_tuned_models_create (self ):
26- # [START tuned_models_create]
25+ @classmethod
26+ def setUpClass (cls ):
27+ # Code to run once before all tests in the class
28+ # [START tuned_models_create]
2729 import google .generativeai as genai
2830
2931 import time
@@ -53,7 +55,7 @@ def test_tuned_models_create(self):
5355 # You can use a tuned model here too. Set `source_model="tunedModels/..."`
5456 display_name = "increment" ,
5557 source_model = base_model ,
56- epoch_count = 20 ,
58+ epoch_count = 5 ,
5759 batch_size = 4 ,
5860 learning_rate = 0.001 ,
5961 training_data = training_data ,
@@ -62,22 +64,25 @@ def test_tuned_models_create(self):
6264 for status in operation .wait_bar ():
6365 time .sleep (10 )
6466
65- result = operation .result ()
66- print (result )
67+ tuned_model = operation .result ()
68+ print (tuned_model )
6769 # # You can plot the loss curve with:
6870 # snapshots = pd.DataFrame(result.tuning_task.snapshots)
6971 # sns.lineplot(data=snapshots, x='epoch', y='mean_loss')
7072
71- model = genai .GenerativeModel (model_name = result .name )
73+ model = genai .GenerativeModel (model_name = tuned_model .name )
7274 result = model .generate_content ("III" )
7375 print (result .text ) # IV
7476 # [END tuned_models_create]
77+
78+ cls .tuned_model_name = tuned_model_name = tuned_model .name
79+
7580
7681 def test_tuned_models_generate_content (self ):
7782 # [START tuned_models_generate_content]
7883 import google .generativeai as genai
7984
80- model = genai .GenerativeModel (model_name = "tunedModels/my-increment-model" )
85+ model = genai .GenerativeModel (model_name = self . tuned_model_name )
8186 result = model .generate_content ("III" )
8287 print (result .text ) # "IV"
8388 # [END tuned_models_generate_content]
@@ -86,7 +91,7 @@ def test_tuned_models_get(self):
8691 # [START tuned_models_get]
8792 import google .generativeai as genai
8893
89- model_info = genai .get_model ("tunedModels/my-increment-model" )
94+ model_info = genai .get_model (self . tuned_model_name )
9095 print (model_info )
9196 # [END tuned_models_get]
9297
@@ -100,6 +105,7 @@ def test_tuned_models_list(self):
100105
101106 def test_tuned_models_delete (self ):
102107 import time
108+ import google .generativeai as genai
103109
104110 base_model = "models/gemini-1.5-flash-001-tuning"
105111 training_data = samples / "increment_tuning_data.json"
@@ -109,7 +115,7 @@ def test_tuned_models_delete(self):
109115 # You can use a tuned model here too. Set `source_model="tunedModels/..."`
110116 display_name = "increment" ,
111117 source_model = base_model ,
112- epoch_count = 20 ,
118+ epoch_count = 5 ,
113119 batch_size = 4 ,
114120 learning_rate = 0.001 ,
115121 training_data = training_data ,
@@ -135,7 +141,7 @@ def test_tuned_models_permissions_create(self):
135141 # [START tuned_models_permissions_create]
136142 import google .generativeai as genai
137143
138- model_info = genai .get_model ("tunedModels/my-increment-model" )
144+ model_info = genai .get_model (self . tuned_model_name )
139145 # [START_EXCLUDE]
140146 for p in model_info .permissions .list ():
141147 if p .role .name != "OWNER" :
@@ -161,7 +167,7 @@ def test_tuned_models_permissions_list(self):
161167 # [START tuned_models_permissions_list]
162168 import google .generativeai as genai
163169
164- model_info = genai .get_model ("tunedModels/my-increment-model" )
170+ model_info = genai .get_model (self . tuned_model_name )
165171
166172 # [START_EXCLUDE]
167173 for p in model_info .permissions .list ():
@@ -190,7 +196,7 @@ def test_tuned_models_permissions_get(self):
190196 # [START tuned_models_permissions_get]
191197 import google .generativeai as genai
192198
193- model_info = genai .get_model ("tunedModels/my-increment-model" )
199+ model_info = genai .get_model (self . tuned_model_name )
194200
195201 # [START_EXCLUDE]
196202 for p in model_info .permissions .list ():
@@ -214,7 +220,7 @@ def test_tuned_models_permissions_update(self):
214220 # [START tuned_models_permissions_update]
215221 import google .generativeai as genai
216222
217- model_info = genai .get_model ("tunedModels/my-increment-model" )
223+ model_info = genai .get_model (self . tuned_model_name )
218224
219225 # [START_EXCLUDE]
220226 for p in model_info .permissions .list ():
@@ -235,7 +241,7 @@ def test_tuned_models_permission_delete(self):
235241 # [START tuned_models_permissions_delete]
236242 import google .generativeai as genai
237243
238- model_info = genai .get_model ("tunedModels/my-increment-model" )
244+ model_info = genai .get_model (self . tuned_model_name )
239245 # [START_EXCLUDE]
240246 for p in model_info .permissions .list ():
241247 if p .role .name != "OWNER" :
0 commit comments