16
16
17
17
package com .google .cloud .vertexai ;
18
18
19
+ import static com .google .common .base .Preconditions .checkArgument ;
20
+ import static com .google .common .base .Preconditions .checkNotNull ;
21
+
19
22
import com .google .api .core .InternalApi ;
20
23
import com .google .api .gax .core .CredentialsProvider ;
21
24
import com .google .api .gax .core .FixedCredentialsProvider ;
28
31
import com .google .cloud .vertexai .api .LlmUtilityServiceSettings ;
29
32
import com .google .cloud .vertexai .api .PredictionServiceClient ;
30
33
import com .google .cloud .vertexai .api .PredictionServiceSettings ;
34
+ import com .google .common .base .Strings ;
31
35
import java .io .IOException ;
32
36
import java .util .List ;
37
+ import java .util .concurrent .locks .ReentrantLock ;
33
38
import java .util .logging .Level ;
34
39
import java .util .logging .Logger ;
35
40
@@ -56,9 +61,8 @@ public class VertexAI implements AutoCloseable {
56
61
private Transport transport = Transport .GRPC ;
57
62
// The clients will be instantiated lazily
58
63
private PredictionServiceClient predictionServiceClient = null ;
59
- private PredictionServiceClient predictionServiceRestClient = null ;
60
64
private LlmUtilityServiceClient llmUtilityClient = null ;
61
- private LlmUtilityServiceClient llmUtilityRestClient = null ;
65
+ private final ReentrantLock lock = new ReentrantLock () ;
62
66
63
67
/**
64
68
* Construct a VertexAI instance.
@@ -193,32 +197,35 @@ public Credentials getCredentials() throws IOException {
193
197
194
198
/** Sets the value for {@link #getTransport()}. */
195
199
public void setTransport (Transport transport ) {
200
+ checkNotNull (transport , "Transport can't be null." );
201
+ if (this .transport == transport ) {
202
+ return ;
203
+ }
204
+
196
205
this .transport = transport ;
206
+ resetClients ();
197
207
}
198
208
199
209
/** Sets the value for {@link #getApiEndpoint()}. */
200
210
public void setApiEndpoint (String apiEndpoint ) {
211
+ checkArgument (!Strings .isNullOrEmpty (apiEndpoint ), "Api endpoint can't be null or empty." );
212
+ if (this .apiEndpoint == apiEndpoint ) {
213
+ return ;
214
+ }
201
215
this .apiEndpoint = apiEndpoint ;
216
+ resetClients ();
217
+ }
202
218
219
+ private void resetClients () {
203
220
if (this .predictionServiceClient != null ) {
204
221
this .predictionServiceClient .close ();
205
222
this .predictionServiceClient = null ;
206
223
}
207
224
208
- if (this .predictionServiceRestClient != null ) {
209
- this .predictionServiceRestClient .close ();
210
- this .predictionServiceRestClient = null ;
211
- }
212
-
213
225
if (this .llmUtilityClient != null ) {
214
226
this .llmUtilityClient .close ();
215
227
this .llmUtilityClient = null ;
216
228
}
217
-
218
- if (this .llmUtilityRestClient != null ) {
219
- this .llmUtilityRestClient .close ();
220
- this .llmUtilityRestClient = null ;
221
- }
222
229
}
223
230
224
231
/**
@@ -230,78 +237,47 @@ public void setApiEndpoint(String apiEndpoint) {
230
237
*/
231
238
@ InternalApi
232
239
public PredictionServiceClient getPredictionServiceClient () throws IOException {
233
- if (this .transport == Transport .GRPC ) {
234
- return getPredictionServiceGrpcClient ();
235
- } else {
236
- return getPredictionServiceRestClient ();
240
+ if (predictionServiceClient != null ) {
241
+ return predictionServiceClient ;
237
242
}
238
- }
239
-
240
- /**
241
- * Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
242
- * first prediction API call is made.
243
- *
244
- * @return {@link PredictionServiceClient} that send GRPC requests to the backing service through
245
- * method calls that map to the API methods.
246
- */
247
- private PredictionServiceClient getPredictionServiceGrpcClient () throws IOException {
248
- if (predictionServiceClient == null ) {
249
- PredictionServiceSettings .Builder settingsBuilder = PredictionServiceSettings .newBuilder ();
250
- settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
251
- if (this .credentialsProvider != null ) {
252
- settingsBuilder .setCredentialsProvider (this .credentialsProvider );
243
+ lock .lock ();
244
+ try {
245
+ if (predictionServiceClient == null ) {
246
+ PredictionServiceSettings settings = getPredictionServiceSettings ();
247
+ // Disable the warning message logged in getApplicationDefault
248
+ Logger defaultCredentialsProviderLogger =
249
+ Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
250
+ Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
251
+ defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
252
+ predictionServiceClient = PredictionServiceClient .create (settings );
253
+ defaultCredentialsProviderLogger .setLevel (previousLevel );
253
254
}
254
- HeaderProvider headerProvider =
255
- FixedHeaderProvider .create (
256
- "user-agent" ,
257
- String .format (
258
- "%s/%s" ,
259
- Constants .USER_AGENT_HEADER ,
260
- GaxProperties .getLibraryVersion (PredictionServiceSettings .class )));
261
- settingsBuilder .setHeaderProvider (headerProvider );
262
- // Disable the warning message logged in getApplicationDefault
263
- Logger defaultCredentialsProviderLogger =
264
- Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
265
- Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
266
- defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
267
- predictionServiceClient = PredictionServiceClient .create (settingsBuilder .build ());
268
- defaultCredentialsProviderLogger .setLevel (previousLevel );
255
+ return predictionServiceClient ;
256
+ } finally {
257
+ lock .unlock ();
269
258
}
270
- return predictionServiceClient ;
271
259
}
272
260
273
- /**
274
- * Returns the {@link PredictionServiceClient} with REST. The client will be instantiated when the
275
- * first prediction API call is made.
276
- *
277
- * @return {@link PredictionServiceClient} that send REST requests to the backing service through
278
- * method calls that map to the API methods.
279
- */
280
- private PredictionServiceClient getPredictionServiceRestClient () throws IOException {
281
- if (predictionServiceRestClient == null ) {
282
- PredictionServiceSettings .Builder settingsBuilder =
283
- PredictionServiceSettings .newHttpJsonBuilder ();
284
- settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
285
- if (this .credentialsProvider != null ) {
286
- settingsBuilder .setCredentialsProvider (this .credentialsProvider );
287
- }
288
- HeaderProvider headerProvider =
289
- FixedHeaderProvider .create (
290
- "user-agent" ,
291
- String .format (
292
- "%s/%s" ,
293
- Constants .USER_AGENT_HEADER ,
294
- GaxProperties .getLibraryVersion (PredictionServiceSettings .class )));
295
- settingsBuilder .setHeaderProvider (headerProvider );
296
- // Disable the warning message logged in getApplicationDefault
297
- Logger defaultCredentialsProviderLogger =
298
- Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
299
- Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
300
- defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
301
- predictionServiceRestClient = PredictionServiceClient .create (settingsBuilder .build ());
302
- defaultCredentialsProviderLogger .setLevel (previousLevel );
261
+ private PredictionServiceSettings getPredictionServiceSettings () throws IOException {
262
+ PredictionServiceSettings .Builder builder ;
263
+ if (transport == Transport .REST ) {
264
+ builder = PredictionServiceSettings .newHttpJsonBuilder ();
265
+ } else {
266
+ builder = PredictionServiceSettings .newBuilder ();
267
+ }
268
+ builder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
269
+ if (this .credentialsProvider != null ) {
270
+ builder .setCredentialsProvider (this .credentialsProvider );
303
271
}
304
- return predictionServiceRestClient ;
272
+ HeaderProvider headerProvider =
273
+ FixedHeaderProvider .create (
274
+ "user-agent" ,
275
+ String .format (
276
+ "%s/%s" ,
277
+ Constants .USER_AGENT_HEADER ,
278
+ GaxProperties .getLibraryVersion (PredictionServiceSettings .class )));
279
+ builder .setHeaderProvider (headerProvider );
280
+ return builder .build ();
305
281
}
306
282
307
283
/**
@@ -313,78 +289,47 @@ private PredictionServiceClient getPredictionServiceRestClient() throws IOExcept
313
289
*/
314
290
@ InternalApi
315
291
public LlmUtilityServiceClient getLlmUtilityClient () throws IOException {
316
- if (this .transport == Transport .GRPC ) {
317
- return getLlmUtilityGrpcClient ();
318
- } else {
319
- return getLlmUtilityRestClient ();
292
+ if (llmUtilityClient != null ) {
293
+ return llmUtilityClient ;
320
294
}
321
- }
322
-
323
- /**
324
- * Returns the {@link LlmUtilityServiceClient} with GRPC. The client will be instantiated when the
325
- * first API call is made.
326
- *
327
- * @return {@link LlmUtilityServiceClient} that makes gRPC calls to the backing service through
328
- * method calls that map to the API methods.
329
- */
330
- private LlmUtilityServiceClient getLlmUtilityGrpcClient () throws IOException {
331
- if (llmUtilityClient == null ) {
332
- LlmUtilityServiceSettings .Builder settingsBuilder = LlmUtilityServiceSettings .newBuilder ();
333
- settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
334
- if (this .credentialsProvider != null ) {
335
- settingsBuilder .setCredentialsProvider (this .credentialsProvider );
295
+ lock .lock ();
296
+ try {
297
+ if (llmUtilityClient == null ) {
298
+ LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings ();
299
+ // Disable the warning message logged in getApplicationDefault
300
+ Logger defaultCredentialsProviderLogger =
301
+ Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
302
+ Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
303
+ defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
304
+ llmUtilityClient = LlmUtilityServiceClient .create (settings );
305
+ defaultCredentialsProviderLogger .setLevel (previousLevel );
336
306
}
337
- HeaderProvider headerProvider =
338
- FixedHeaderProvider .create (
339
- "user-agent" ,
340
- String .format (
341
- "%s/%s" ,
342
- Constants .USER_AGENT_HEADER ,
343
- GaxProperties .getLibraryVersion (LlmUtilityServiceSettings .class )));
344
- settingsBuilder .setHeaderProvider (headerProvider );
345
- // Disable the warning message logged in getApplicationDefault
346
- Logger defaultCredentialsProviderLogger =
347
- Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
348
- Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
349
- defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
350
- llmUtilityClient = LlmUtilityServiceClient .create (settingsBuilder .build ());
351
- defaultCredentialsProviderLogger .setLevel (previousLevel );
307
+ return llmUtilityClient ;
308
+ } finally {
309
+ lock .unlock ();
352
310
}
353
- return llmUtilityClient ;
354
311
}
355
312
356
- /**
357
- * Returns the {@link LlmUtilityServiceClient} with REST. The client will be instantiated when the
358
- * first API call is made.
359
- *
360
- * @return {@link LlmUtilityServiceClient} that makes REST requests to the backing service through
361
- * method calls that map to the API methods.
362
- */
363
- private LlmUtilityServiceClient getLlmUtilityRestClient () throws IOException {
364
- if (llmUtilityRestClient == null ) {
365
- LlmUtilityServiceSettings .Builder settingsBuilder =
366
- LlmUtilityServiceSettings .newHttpJsonBuilder ();
367
- settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
368
- if (this .credentialsProvider != null ) {
369
- settingsBuilder .setCredentialsProvider (this .credentialsProvider );
370
- }
371
- HeaderProvider headerProvider =
372
- FixedHeaderProvider .create (
373
- "user-agent" ,
374
- String .format (
375
- "%s/%s" ,
376
- Constants .USER_AGENT_HEADER ,
377
- GaxProperties .getLibraryVersion (LlmUtilityServiceSettings .class )));
378
- settingsBuilder .setHeaderProvider (headerProvider );
379
- // Disable the warning message logged in getApplicationDefault
380
- Logger defaultCredentialsProviderLogger =
381
- Logger .getLogger ("com.google.auth.oauth2.DefaultCredentialsProvider" );
382
- Level previousLevel = defaultCredentialsProviderLogger .getLevel ();
383
- defaultCredentialsProviderLogger .setLevel (Level .SEVERE );
384
- llmUtilityRestClient = LlmUtilityServiceClient .create (settingsBuilder .build ());
385
- defaultCredentialsProviderLogger .setLevel (previousLevel );
313
+ private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings () throws IOException {
314
+ LlmUtilityServiceSettings .Builder settingsBuilder ;
315
+ if (transport == Transport .REST ) {
316
+ settingsBuilder = LlmUtilityServiceSettings .newHttpJsonBuilder ();
317
+ } else {
318
+ settingsBuilder = LlmUtilityServiceSettings .newBuilder ();
319
+ }
320
+ settingsBuilder .setEndpoint (String .format ("%s:443" , this .apiEndpoint ));
321
+ if (this .credentialsProvider != null ) {
322
+ settingsBuilder .setCredentialsProvider (this .credentialsProvider );
386
323
}
387
- return llmUtilityRestClient ;
324
+ HeaderProvider headerProvider =
325
+ FixedHeaderProvider .create (
326
+ "user-agent" ,
327
+ String .format (
328
+ "%s/%s" ,
329
+ Constants .USER_AGENT_HEADER ,
330
+ GaxProperties .getLibraryVersion (LlmUtilityServiceSettings .class )));
331
+ settingsBuilder .setHeaderProvider (headerProvider );
332
+ return settingsBuilder .build ();
388
333
}
389
334
390
335
/** Closes the VertexAI instance together with all its instantiated clients. */
@@ -393,14 +338,8 @@ public void close() {
393
338
if (predictionServiceClient != null ) {
394
339
predictionServiceClient .close ();
395
340
}
396
- if (predictionServiceRestClient != null ) {
397
- predictionServiceRestClient .close ();
398
- }
399
341
if (llmUtilityClient != null ) {
400
342
llmUtilityClient .close ();
401
343
}
402
- if (llmUtilityRestClient != null ) {
403
- llmUtilityRestClient .close ();
404
- }
405
344
}
406
345
}
0 commit comments