Skip to content

Commit 17b01c6

Browse files
feat: [vertexai] add generateContentAsync methods to GenerativeModel (#10599)
PiperOrigin-RevId: 617951189 Co-authored-by: Jaycee Li <jayceeli@google.com>
1 parent 5c3d93e commit 17b01c6

File tree

5 files changed

+467
-163
lines changed

5 files changed

+467
-163
lines changed

‎java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java‎

Lines changed: 89 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
package com.google.cloud.vertexai;
1818

19+
import static com.google.common.base.Preconditions.checkArgument;
20+
import static com.google.common.base.Preconditions.checkNotNull;
21+
1922
import com.google.api.core.InternalApi;
2023
import com.google.api.gax.core.CredentialsProvider;
2124
import com.google.api.gax.core.FixedCredentialsProvider;
@@ -28,8 +31,10 @@
2831
import com.google.cloud.vertexai.api.LlmUtilityServiceSettings;
2932
import com.google.cloud.vertexai.api.PredictionServiceClient;
3033
import com.google.cloud.vertexai.api.PredictionServiceSettings;
34+
import com.google.common.base.Strings;
3135
import java.io.IOException;
3236
import java.util.List;
37+
import java.util.concurrent.locks.ReentrantLock;
3338
import java.util.logging.Level;
3439
import java.util.logging.Logger;
3540

@@ -56,9 +61,8 @@ public class VertexAI implements AutoCloseable {
5661
private Transport transport = Transport.GRPC;
5762
// The clients will be instantiated lazily
5863
private PredictionServiceClient predictionServiceClient = null;
59-
private PredictionServiceClient predictionServiceRestClient = null;
6064
private LlmUtilityServiceClient llmUtilityClient = null;
61-
private LlmUtilityServiceClient llmUtilityRestClient = null;
65+
private final ReentrantLock lock = new ReentrantLock();
6266

6367
/**
6468
* Construct a VertexAI instance.
@@ -193,32 +197,35 @@ public Credentials getCredentials() throws IOException {
193197

194198
/** Sets the value for {@link #getTransport()}. */
195199
public void setTransport(Transport transport) {
200+
checkNotNull(transport, "Transport can't be null.");
201+
if (this.transport == transport) {
202+
return;
203+
}
204+
196205
this.transport = transport;
206+
resetClients();
197207
}
198208

199209
/** Sets the value for {@link #getApiEndpoint()}. */
200210
public void setApiEndpoint(String apiEndpoint) {
211+
checkArgument(!Strings.isNullOrEmpty(apiEndpoint), "Api endpoint can't be null or empty.");
212+
if (this.apiEndpoint == apiEndpoint) {
213+
return;
214+
}
201215
this.apiEndpoint = apiEndpoint;
216+
resetClients();
217+
}
202218

219+
private void resetClients() {
203220
if (this.predictionServiceClient != null) {
204221
this.predictionServiceClient.close();
205222
this.predictionServiceClient = null;
206223
}
207224

208-
if (this.predictionServiceRestClient != null) {
209-
this.predictionServiceRestClient.close();
210-
this.predictionServiceRestClient = null;
211-
}
212-
213225
if (this.llmUtilityClient != null) {
214226
this.llmUtilityClient.close();
215227
this.llmUtilityClient = null;
216228
}
217-
218-
if (this.llmUtilityRestClient != null) {
219-
this.llmUtilityRestClient.close();
220-
this.llmUtilityRestClient = null;
221-
}
222229
}
223230

224231
/**
@@ -230,78 +237,47 @@ public void setApiEndpoint(String apiEndpoint) {
230237
*/
231238
@InternalApi
232239
public PredictionServiceClient getPredictionServiceClient() throws IOException {
233-
if (this.transport == Transport.GRPC) {
234-
return getPredictionServiceGrpcClient();
235-
} else {
236-
return getPredictionServiceRestClient();
240+
if (predictionServiceClient != null) {
241+
return predictionServiceClient;
237242
}
238-
}
239-
240-
/**
241-
* Returns the {@link PredictionServiceClient} with GRPC. The client will be instantiated when the
242-
* first prediction API call is made.
243-
*
244-
* @return {@link PredictionServiceClient} that send GRPC requests to the backing service through
245-
* method calls that map to the API methods.
246-
*/
247-
private PredictionServiceClient getPredictionServiceGrpcClient() throws IOException {
248-
if (predictionServiceClient == null) {
249-
PredictionServiceSettings.Builder settingsBuilder = PredictionServiceSettings.newBuilder();
250-
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
251-
if (this.credentialsProvider != null) {
252-
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
243+
lock.lock();
244+
try {
245+
if (predictionServiceClient == null) {
246+
PredictionServiceSettings settings = getPredictionServiceSettings();
247+
// Disable the warning message logged in getApplicationDefault
248+
Logger defaultCredentialsProviderLogger =
249+
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
250+
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
251+
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
252+
predictionServiceClient = PredictionServiceClient.create(settings);
253+
defaultCredentialsProviderLogger.setLevel(previousLevel);
253254
}
254-
HeaderProvider headerProvider =
255-
FixedHeaderProvider.create(
256-
"user-agent",
257-
String.format(
258-
"%s/%s",
259-
Constants.USER_AGENT_HEADER,
260-
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
261-
settingsBuilder.setHeaderProvider(headerProvider);
262-
// Disable the warning message logged in getApplicationDefault
263-
Logger defaultCredentialsProviderLogger =
264-
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
265-
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
266-
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
267-
predictionServiceClient = PredictionServiceClient.create(settingsBuilder.build());
268-
defaultCredentialsProviderLogger.setLevel(previousLevel);
255+
return predictionServiceClient;
256+
} finally {
257+
lock.unlock();
269258
}
270-
return predictionServiceClient;
271259
}
272260

273-
/**
274-
* Returns the {@link PredictionServiceClient} with REST. The client will be instantiated when the
275-
* first prediction API call is made.
276-
*
277-
* @return {@link PredictionServiceClient} that send REST requests to the backing service through
278-
* method calls that map to the API methods.
279-
*/
280-
private PredictionServiceClient getPredictionServiceRestClient() throws IOException {
281-
if (predictionServiceRestClient == null) {
282-
PredictionServiceSettings.Builder settingsBuilder =
283-
PredictionServiceSettings.newHttpJsonBuilder();
284-
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
285-
if (this.credentialsProvider != null) {
286-
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
287-
}
288-
HeaderProvider headerProvider =
289-
FixedHeaderProvider.create(
290-
"user-agent",
291-
String.format(
292-
"%s/%s",
293-
Constants.USER_AGENT_HEADER,
294-
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
295-
settingsBuilder.setHeaderProvider(headerProvider);
296-
// Disable the warning message logged in getApplicationDefault
297-
Logger defaultCredentialsProviderLogger =
298-
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
299-
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
300-
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
301-
predictionServiceRestClient = PredictionServiceClient.create(settingsBuilder.build());
302-
defaultCredentialsProviderLogger.setLevel(previousLevel);
261+
private PredictionServiceSettings getPredictionServiceSettings() throws IOException {
262+
PredictionServiceSettings.Builder builder;
263+
if (transport == Transport.REST) {
264+
builder = PredictionServiceSettings.newHttpJsonBuilder();
265+
} else {
266+
builder = PredictionServiceSettings.newBuilder();
267+
}
268+
builder.setEndpoint(String.format("%s:443", this.apiEndpoint));
269+
if (this.credentialsProvider != null) {
270+
builder.setCredentialsProvider(this.credentialsProvider);
303271
}
304-
return predictionServiceRestClient;
272+
HeaderProvider headerProvider =
273+
FixedHeaderProvider.create(
274+
"user-agent",
275+
String.format(
276+
"%s/%s",
277+
Constants.USER_AGENT_HEADER,
278+
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
279+
builder.setHeaderProvider(headerProvider);
280+
return builder.build();
305281
}
306282

307283
/**
@@ -313,78 +289,47 @@ private PredictionServiceClient getPredictionServiceRestClient() throws IOExcept
313289
*/
314290
@InternalApi
315291
public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {
316-
if (this.transport == Transport.GRPC) {
317-
return getLlmUtilityGrpcClient();
318-
} else {
319-
return getLlmUtilityRestClient();
292+
if (llmUtilityClient != null) {
293+
return llmUtilityClient;
320294
}
321-
}
322-
323-
/**
324-
* Returns the {@link LlmUtilityServiceClient} with GRPC. The client will be instantiated when the
325-
* first API call is made.
326-
*
327-
* @return {@link LlmUtilityServiceClient} that makes gRPC calls to the backing service through
328-
* method calls that map to the API methods.
329-
*/
330-
private LlmUtilityServiceClient getLlmUtilityGrpcClient() throws IOException {
331-
if (llmUtilityClient == null) {
332-
LlmUtilityServiceSettings.Builder settingsBuilder = LlmUtilityServiceSettings.newBuilder();
333-
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
334-
if (this.credentialsProvider != null) {
335-
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
295+
lock.lock();
296+
try {
297+
if (llmUtilityClient == null) {
298+
LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings();
299+
// Disable the warning message logged in getApplicationDefault
300+
Logger defaultCredentialsProviderLogger =
301+
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
302+
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
303+
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
304+
llmUtilityClient = LlmUtilityServiceClient.create(settings);
305+
defaultCredentialsProviderLogger.setLevel(previousLevel);
336306
}
337-
HeaderProvider headerProvider =
338-
FixedHeaderProvider.create(
339-
"user-agent",
340-
String.format(
341-
"%s/%s",
342-
Constants.USER_AGENT_HEADER,
343-
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
344-
settingsBuilder.setHeaderProvider(headerProvider);
345-
// Disable the warning message logged in getApplicationDefault
346-
Logger defaultCredentialsProviderLogger =
347-
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
348-
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
349-
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
350-
llmUtilityClient = LlmUtilityServiceClient.create(settingsBuilder.build());
351-
defaultCredentialsProviderLogger.setLevel(previousLevel);
307+
return llmUtilityClient;
308+
} finally {
309+
lock.unlock();
352310
}
353-
return llmUtilityClient;
354311
}
355312

356-
/**
357-
* Returns the {@link LlmUtilityServiceClient} with REST. The client will be instantiated when the
358-
* first API call is made.
359-
*
360-
* @return {@link LlmUtilityServiceClient} that makes REST requests to the backing service through
361-
* method calls that map to the API methods.
362-
*/
363-
private LlmUtilityServiceClient getLlmUtilityRestClient() throws IOException {
364-
if (llmUtilityRestClient == null) {
365-
LlmUtilityServiceSettings.Builder settingsBuilder =
366-
LlmUtilityServiceSettings.newHttpJsonBuilder();
367-
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
368-
if (this.credentialsProvider != null) {
369-
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
370-
}
371-
HeaderProvider headerProvider =
372-
FixedHeaderProvider.create(
373-
"user-agent",
374-
String.format(
375-
"%s/%s",
376-
Constants.USER_AGENT_HEADER,
377-
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
378-
settingsBuilder.setHeaderProvider(headerProvider);
379-
// Disable the warning message logged in getApplicationDefault
380-
Logger defaultCredentialsProviderLogger =
381-
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
382-
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
383-
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
384-
llmUtilityRestClient = LlmUtilityServiceClient.create(settingsBuilder.build());
385-
defaultCredentialsProviderLogger.setLevel(previousLevel);
313+
private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IOException {
314+
LlmUtilityServiceSettings.Builder settingsBuilder;
315+
if (transport == Transport.REST) {
316+
settingsBuilder = LlmUtilityServiceSettings.newHttpJsonBuilder();
317+
} else {
318+
settingsBuilder = LlmUtilityServiceSettings.newBuilder();
319+
}
320+
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
321+
if (this.credentialsProvider != null) {
322+
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
386323
}
387-
return llmUtilityRestClient;
324+
HeaderProvider headerProvider =
325+
FixedHeaderProvider.create(
326+
"user-agent",
327+
String.format(
328+
"%s/%s",
329+
Constants.USER_AGENT_HEADER,
330+
GaxProperties.getLibraryVersion(LlmUtilityServiceSettings.class)));
331+
settingsBuilder.setHeaderProvider(headerProvider);
332+
return settingsBuilder.build();
388333
}
389334

390335
/** Closes the VertexAI instance together with all its instantiated clients. */
@@ -393,14 +338,8 @@ public void close() {
393338
if (predictionServiceClient != null) {
394339
predictionServiceClient.close();
395340
}
396-
if (predictionServiceRestClient != null) {
397-
predictionServiceRestClient.close();
398-
}
399341
if (llmUtilityClient != null) {
400342
llmUtilityClient.close();
401343
}
402-
if (llmUtilityRestClient != null) {
403-
llmUtilityRestClient.close();
404-
}
405344
}
406345
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package com.google.cloud.vertexai.generativeai;
17+
18+
import static com.google.common.base.Preconditions.checkArgument;
19+
import static com.google.common.base.Preconditions.checkNotNull;
20+
21+
import com.google.cloud.vertexai.api.FunctionDeclaration;
22+
import com.google.common.base.Strings;
23+
import com.google.gson.JsonObject;
24+
import com.google.protobuf.InvalidProtocolBufferException;
25+
import com.google.protobuf.util.JsonFormat;
26+
27+
/** Helper class to create {@link com.google.cloud.vertexai.api.FunctionDeclaration} */
28+
public final class FunctionDeclarationMaker {
29+
30+
/**
31+
* Creates a FunctionDeclaration from a JsonString
32+
*
33+
* @param jsonString A valid Json String that can be parsed to a FunctionDeclaration.
34+
* @return a {@link FunctionDeclaration} by parsing the input json String.
35+
* @throws InvalidProtocolBufferException if the String can't be parsed into a FunctionDeclaration
36+
* proto.
37+
*/
38+
public static FunctionDeclaration fromJsonString(String jsonString)
39+
throws InvalidProtocolBufferException {
40+
checkArgument(!Strings.isNullOrEmpty(jsonString), "Input String can't be null or empty.");
41+
FunctionDeclaration.Builder builder = FunctionDeclaration.newBuilder();
42+
JsonFormat.parser().merge(jsonString, builder);
43+
FunctionDeclaration declaration = builder.build();
44+
if (declaration.getName().isEmpty()) {
45+
throw new IllegalArgumentException("name field must be present.");
46+
}
47+
return declaration;
48+
}
49+
50+
/**
51+
* Creates a FunctionDeclaration from a JsonObject
52+
*
53+
* @param jsonObject A valid Json Object that can be parsed to a FunctionDeclaration.
54+
* @return a {@link FunctionDeclaration} by parsing the input json Object.
55+
* @throws InvalidProtocolBufferException if the jsonObject can't be parsed into a
56+
* FunctionDeclaration proto.
57+
*/
58+
public static FunctionDeclaration fromJsonObject(JsonObject jsonObject)
59+
throws InvalidProtocolBufferException {
60+
checkNotNull(jsonObject, "JsonObject can't be null.");
61+
return fromJsonString(jsonObject.toString());
62+
}
63+
}

0 commit comments

Comments
 (0)