escalante-bio/esmj

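A barebones translation of the ESMC (ESM C) protein language model to JAX/Equinox. The example below loads the PyTorch model, converts it with `from_torch`, and compares the logits of the two implementations.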

from esmj import from_torch
import equinox as eqx
import numpy as np

# load the PyTorch ESMC model (requires the `esm` package)
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProtein, LogitsConfig
client = ESMC.from_pretrained("esmc_300m").to("cpu")


# demo
prot_seq = "ESCALANTE"

# reference prediction from the PyTorch model
protein = ESMProtein(sequence=prot_seq)
protein_tensor = client.encode(protein)
torch_output = client.logits(
    protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
)

# translate the PyTorch model into an Equinox module
eqx_model = from_torch(client)
tokens = eqx_model.tokenize(prot_seq)
# JIT-compile the model's forward pass
eqx_model = eqx.filter_jit(eqx_model)
# run it on a batch containing one sequence
output = eqx_model(tokens[None])  # add a leading batch dimension

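# largest absolute difference between the JAX and PyTorch sequence logits
print(np.abs(output.logits - torch_output.logits.sequence.detach().float().numpy()).max())
# should be small, though likely not exactly zero: close enough, maybe!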
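A single maximum-difference number does not say whether the two models actually make the same predictions. One hedged way to check, sketched below with NumPy only, is to combine an elementwise tolerance test with agreement of the per-position top token. The function name `compare_logits` and its default tolerances are illustrative assumptions, not values taken from this project, and the example runs on synthetic arrays rather than real model output.

```python
import numpy as np


def compare_logits(jax_logits, torch_logits, rtol=1e-3, atol=1e-3):
    """Summarize agreement between two (batch, length, vocab) logit arrays.

    Hypothetical helper; the tolerances are illustrative defaults.
    """
    a = np.asarray(jax_logits, dtype=np.float32)
    b = np.asarray(torch_logits, dtype=np.float32)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    return {
        "max_abs_diff": float(np.abs(a - b).max()),
        "allclose": bool(np.allclose(a, b, rtol=rtol, atol=atol)),
        # Fraction of positions where both arrays pick the same top token.
        "argmax_agreement": float((a.argmax(-1) == b.argmax(-1)).mean()),
    }


# Synthetic shapes: a reference array plus small float32-sized noise.
rng = np.random.default_rng(0)
ref = rng.standard_normal((1, 11, 64)).astype(np.float32)
noisy = ref + 1e-5 * rng.standard_normal(ref.shape).astype(np.float32)
report = compare_logits(noisy, ref)
print(report)
```

With the variables from the example above, the call would be `compare_logits(output.logits, torch_output.logits.sequence.detach().float().numpy())`; an argmax agreement of 1.0 alongside a small maximum difference is a stronger sign of a faithful translation than the maximum difference alone.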

This project should be installable with uv (for example, by running `uv sync` in a clone of the repository).
