Protein design tasks almost always involve multiple constraints or properties that must be satisfied or optimized. For instance, in binder design one may want to simultaneously ensure:
- the chance of binding the intended target is high
- the chance of binding to a similar off-target protein is low
- the binder expresses well in bacteria
- the binder is highly soluble.
There has been a recent explosion in the application of machine learning to protein property prediction, resulting in fairly accurate predictors for each of these properties. What is currently lacking is an efficient and flexible method for combining these different predictors into one design/filtering/ranking framework.
We recommend using uv, e.g. run uv sync --group jax-cuda after cloning the repo to install dependencies. To run the example notebook, try source .venv/bin/activate followed by marimo edit examples/example_notebook.py. You may need to add various uv overrides for specific packages and your machine; take a look at pyproject.toml.
You'll need a GPU- or TPU-compatible version of JAX for structure prediction. You might need to install this manually, e.g. uv add jax[cuda12]. To automatically download the AF2 weights you'll need to install aria2: apt-get install aria2.
This project combines two simple components to make a powerful protein design framework:
- Gradient-based optimization over a continuous, relaxed sequence space (as in ColabDesign, RSO, BindCraft, etc)
- A functional, modular interface to easily combine multiple learned or hand-crafted loss terms and optimization algorithms (as in A high-level programming language for generative protein design etc)
The key observation is that it's possible to use this continuous relaxation simultaneously with multiple learned objective terms.[^1]
This allows us to easily construct objective functions that are combinations of multiple learned potentials and optimize them efficiently, like so:
combined_loss = (
    Boltz1Loss(
        model=model,
        name="ART2B",
        loss=4 * sp.BinderTargetContact()
        + sp.RadiusOfGyration(target_radius=15.0)
        + sp.WithinBinderContact()
        + 0.3 * sp.HelixLoss()
        + ProteinMPNNLoss(mpnn, num_samples=8),
        features=boltz_features,
        recycling_steps=0,
    )
    + 0.5 * esm_loss
    + trigram_ll
    + 0.1 * StabilityModel.from_pretrained(esm)
    + 0.5
    * Boltz1Loss(
        model=model,
        name="mono",
        loss=0.2 * sp.PLDDTLoss()
        + sp.RadiusOfGyration(target_radius=15.0)
        + 0.3 * sp.HelixLoss(),
        features=monomer_features,
        recycling_steps=0,
    )
)

_, logits_combined_objective = simplex_APGM(
    loss_function=combined_loss,
    n_steps=150,
    x=np.random.randn(binder_length, 20) * 0.1,
    stepsize=0.1,
)
Here we're using five different models to construct a loss function: the Boltz-1 structure prediction model (which is used twice: once to predict the binder-target complex and once to predict the binder as a monomer), ESM2, ProteinMPNN, an n-gram model, and a stability model trained on the mega-scale dataset.
It's super easy to define additional loss terms, which are JIT-compatible callable pytrees, e.g.
import jax.numpy as jnp
from jaxtyping import Array, Float
# LossTerm and IDX_CYS (the cysteine column index) come from mosaic
class LogPCysteine(LossTerm):
    def __call__(self, soft_sequence: Float[Array, "N 20"], key=None):
        # mean log-probability of cysteine across positions
        mean_log_p = jnp.log(soft_sequence[:, IDX_CYS] + 1e-8).mean()
        return mean_log_p, {"log_p_cys": mean_log_p}
There's no reason custom loss terms can't involve more expensive (differentiable) operations, e.g. running ProteinX, or an EVOLVEpro-style fitness predictor.
The marimo notebook gives a few examples of how this can work.
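For instance, a custom term wrapping an arbitrary differentiable predictor might look like the following sketch (predict_fitness is a hypothetical placeholder for any JAX-differentiable, scalar-valued predictor, not part of mosaic):
class NegativeFitness(LossTerm):
    def __call__(self, soft_sequence: Float[Array, "N 20"], key=None):
        # predict_fitness stands in for any differentiable fitness model
        fitness = predict_fitness(soft_sequence)
        # negate so that minimizing the loss maximizes predicted fitness
        return -fitness, {"fitness": fitness}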
WARNING: ColabDesign, BindCraft, etc. are well-tested and well-tuned methods for very specific problems. mosaic may require substantial hand-holding to work (tuning learning rates, etc.), often produces proteins that fail simple in-silico tests, and must be combined with standard filtering methods. This is not for the faint of heart: the intent is to provide a framework in which to implement custom objective functions and optimization algorithms for your application.
It's very easy to swap in different optimizers. For instance, let's say we really wanted to try projected gradient descent on the hypercube:
def RSO_box(
    *,
    loss_function,
    x: Float[Array, "N 20"],
    n_steps: int,
    optim=optax.chain(optax.clip_by_global_norm(1.0), optax.sgd(1e-1)),
    key=None,
):
    # _eval_loss_and_grad and _print_iter are helpers from optimizers.py
    if key is None:
        key = jax.random.PRNGKey(np.random.randint(0, 10000))
    opt_state = optim.init(x)
    for _iter in range(n_steps):
        (v, aux), g = _eval_loss_and_grad(
            x=x,
            loss_function=loss_function,
            key=key,
        )
        updates, opt_state = optim.update(g, opt_state, x)
        # gradient step followed by projection onto the [0, 1] hypercube
        x = optax.apply_updates(x, updates).clip(0, 1)
        key = jax.random.fold_in(key, 0)
        _print_iter(_iter, aux, v)
    return x
Take a look at optimizers.py for a few examples of different optimizers.
Included models |
---|
Boltz-1 |
Boltz-2 |
AlphaFold2 |
Protenix (mini+tiny) |
ProteinMPNN |
ESM |
stability |
AbLang |
trigram |
We provide a simple interface in mosaic.structure_prediction and mosaic.models.* to five structure prediction models: Boltz1, Boltz2, AF2, ProtenixMini, and ProtenixTiny.
To make a prediction or design a binder, you'll need to make a list of mosaic.structure_prediction.TargetChain objects. These are simple dataclasses that hold a protein (or DNA or RNA) sequence, a flag telling the model whether it should use MSAs (use_msa), and optionally a template structure.
For example, we can make a prediction with Protenix for IL7Ra like so:
import jax
from mosaic.structure_prediction import TargetChain
from mosaic.models.protenix import ProtenixMini
model = ProtenixMini()
target_sequence = "DYSFSCYSQLEVNGSQHSLTCAFEDPDVNTTNLEFEICGALVEVKCLNFRKLQEIYFIETKKFLLIGKSNICVKVGEKSLTCKKIDLTTIVKPEAPFDLSVVYREGANDFVVTFNTSHLQKKYVKVLMHDVAYRQEKDENKWTHVNLSSTKLTLLQRKLQPAAMYEIKVRSIPDHYFKGFWSEWSPSYYFRT"
# generate features and a "writer" object that turns model output into a prediction wrapper
target_only_features, target_only_structure = model.target_only_features(
    [TargetChain(target_sequence)]
)
prediction = model.predict(
    features=target_only_features,
    writer=target_only_structure,
    key=jax.random.key(0),
    recycling_steps=10,
)
# prediction contains useful properties like `prediction.st`, `prediction.pae` etc.
This interface is the same for all structure prediction models, so in theory we should be able to replace ProtenixMini above with Boltz2 by changing only a single line of code!
We also define a collection of (model agnostic!) structure prediction related losses here. It's super easy to define your own using the provided interface.
Internally we distinguish between three classes of losses: those that rely only on the trunk, those that need the structure module, and those that need the confidence module. For computational efficiency we only run the structure module or confidence module when required!
Continuing the example above, we can construct a loss and do design as follows:
import numpy as np
import mosaic.losses.structure_prediction as sp
from mosaic.optimizers import simplex_APGM

binder_length = 80
design_features, design_structure = model.binder_features(
    binder_length=binder_length, chains=[TargetChain(target_sequence)]
)
loss = model.build_loss(
    loss=sp.BinderTargetContact() + sp.WithinBinderContact(),
    features=design_features,
    recycling_steps=3,
)
PSSM = jax.nn.softmax(
    0.5
    * jax.random.gumbel(
        key=jax.random.key(np.random.randint(100000)),
        shape=(binder_length, 20),
    )
)
PSSM, _ = simplex_APGM(
    loss_function=loss,
    x=PSSM,
    n_steps=50,
    stepsize=0.15,
    momentum=0.3,
)
Every structure prediction model also supports a low-level loss/interface if you'd like to do something fancy (e.g. design a protein binder against a small molecule with Boltz or Protenix).
See protenij.py for an example of how to use this family of models. This loss function supports some advanced features to speed up hallucination, namely "pre-cycling" (running multiple recycling iterations on the target alone before design) and "co-cycling" (running recycling and optimization steps in parallel), but can also be used analogously to Boltz or AF2.
Load your preferred ProteinMPNN model (soluble or vanilla) using
from mosaic.proteinmpnn.mpnn import ProteinMPNN
mpnn = ProteinMPNN.from_pretrained()
In the simplest case we have a single-chain structure or complex where the protein we're designing occurs as the first chain (note this can be a prediction). We can then construct the (negative) log-likelihood of the designed sequence under ProteinMPNN as a loss term:
inverse_folding_LL = FixedStructureInverseFoldingLL.from_structure(
    gemmi.read_structure("scaffold.pdb"), mpnn
)
This can then be added to whatever overall loss function you're constructing.
Note that it is often helpful to clip the loss using, e.g., ClippedLoss(inverse_folding_LL, 2, 100): over-optimizing ProteinMPNN likelihoods typically results in homopolymers.
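For instance (a minimal sketch; the 0.5 weight is an arbitrary illustration, and esm_loss and trigram_ll are the terms from the combined example above):
loss = (
    0.5 * ClippedLoss(inverse_folding_LL, 2, 100)
    + esm_loss
    + trigram_ll
)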
ProteinMPNN can also be combined with live structure predictions via ProteinMPNNLoss, which scores the ProteinMPNN log-likelihood of the designed sequence on the structure predicted during optimization.
Another very useful loss term is InverseFoldingSequenceRecovery: a continuous relaxation of sequence recovery after sampling with ProteinMPNN.
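As a sketch, mirroring the combined-loss example at the top (all names are taken from that example), ProteinMPNNLoss is simply added to the structure prediction loss:
complex_loss = Boltz1Loss(
    model=model,
    name="ART2B",
    loss=sp.BinderTargetContact()
    + sp.WithinBinderContact()
    + ProteinMPNNLoss(mpnn, num_samples=8),
    features=boltz_features,
    recycling_steps=0,
)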
Warning: due to conflicting Python package requirements, it's not possible to use both ESM2 and ESMC in the same environment.
Another useful loss term is the pseudolikelihood of the ESM2 protein language model (via esm2quinox), which is correlated with all kinds of useful properties (solubility, expressibility, etc.).
This term can be constructed as follows:
import esm
import esm2quinox
torch_model, _ = esm.pretrained.esm2_t33_650M_UR50D()
ESM2PLL = ESM2PseudoLikelihood(esm2quinox.from_torch(torch_model))
In typical practice this loss should be clipped or squashed to avoid over-optimization (e.g. ClippedLoss(ESM2PLL, 2, 100)).
We also implement the corresponding loss for ESMC (via esmj).
from esmj import from_torch
from esm.models.esmc import ESMC as TORCH_ESMC
esm = from_torch(TORCH_ESMC.from_pretrained("esmc_300m").to("cpu"))
ESMCPLL = ESMCPseudoLikelihood(esm)
A simple delta G predictor trained on the mega-scale dataset. This might be a nice example of how to train and add a simple regression head on a small amount of data: see train.py.
stability_loss = StabilityModel.from_pretrained(esm)
AbLang, a family of antibody-specific language models.
import ablang
import jablang
heavy_ablang = ablang.pretrained("heavy")
heavy_ablang.freeze()
abpll = AbLangPseudoLikelihood(
    model=jablang.from_torch(heavy_ablang.AbLang),
    tokenizer=heavy_ablang.tokenizer,
    stop_grad=True,
)
A trigram language model as in A high-level programming language for generative protein design.
trigram_ll = TrigramLL.from_pkl()
We include some standard optimizers in src/mosaic/optimizers.py.
First, simplex_APGM, which is an accelerated proximal gradient algorithm on the probability simplex. One critical hyperparameter is the stepsize; a reasonable first guess is 0.1 * np.sqrt(binder_length). Another useful keyword argument is scale: larger values encourage sparse solutions, so a typical binder design run might start with scale=1.0 to get an initial, soft solution and then ramp up to something higher to get a discrete solution.
simplex_APGM also accepts a logspace keyword argument; if this is set to true we run the algorithm in log space, which corresponds to an accelerated proximal Bregman method. In this case scale corresponds to the strength of the entropic regularization.
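For example, a two-stage schedule might look like the following sketch (reusing loss, PSSM, and binder_length from the design example above; the step counts, stepsizes, and scale values are illustrative guesses, not tuned recommendations):
x_soft, _ = simplex_APGM(
    loss_function=loss,
    x=PSSM,
    n_steps=100,
    stepsize=0.1 * np.sqrt(binder_length),
    scale=1.0,  # soft, exploratory solution
)
x_sharp, _ = simplex_APGM(
    loss_function=loss,
    x=x_soft,
    n_steps=50,
    stepsize=0.1 * np.sqrt(binder_length),
    scale=3.0,  # a larger scale pushes toward a near-discrete solution
)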
We also include a discrete optimization algorithm, gradient_MCMC, which is a variant of MCMC with a proposal distribution defined using a Taylor approximation to the objective function (see Plug & Play Directed Evolution of Proteins with Gradient-based Discrete MCMC). This algorithm is especially useful for fine-tuning either existing designs or the result of continuous optimization.
We also provide a few common transformations of loss functions. Of note are ClippedLoss, which wraps another loss term and clips its value to a given range; SetPositions and FixedPositionsPenalty, which are useful for fixing certain positions of an existing design; and ClippedGradient and NormedGradient, which respectively clip and normalize the gradients of individual loss terms. The latter two can be useful when combining predictors with very different gradient norms, for example:
loss = (
    ClippedGradient(inverse_folding_LL, 1.0)
    + ClippedGradient(ablang_pll, 1.0)
    + 0.25 * ClippedGradient(ESMCPLL, 1.0)
)
Hallucination-based protein design workflows attempt to solve the following optimization problem: $$\underset{x \in \mathcal{X}}{\textrm{minimize }} f(x)$$
Here $\mathcal{X}$ is the (discrete) set of one-hot encoded sequences of length $N$ and $f$ is a loss function, typically built from structure prediction models and other in-silico predictors.
One challenge with naive approaches is that this is a discrete optimization problem: $\mathcal{X}$ is combinatorially large and gradients of $f$ cannot be used directly.
ColabDesign, RSO, and BindCraft, among others, use the fact that $f$ is differentiable and accepts arbitrary points in the probability simplex, not just one-hot sequences, so a continuous relaxation can be optimized with gradient-based methods.
This is related to the classic optimization trick of optimizing over distributions rather than single points. First,
$\underset{x}{\textrm{minimize }}f(x)$ is relaxed to $\underset{p \in \Delta}{\textrm{minimize }}E_p f(x)$. Next, if it makes sense to take the expectation of $x$ (as in the one-hot sequence case), we can interchange $f$ and $E$ to get the final relaxation: $$\underset{p \in \Delta}{\textrm{minimize }} f( E_p x) = \underset{p \in \Delta}{\textrm{minimize }} f(p).$$
Solutions to this relaxed optimization problem must then be translated into sequences; many different methods work here: RSO uses inverse folding of the predicted structure, BindCraft/ColabDesign uses a softmax with ramping temperature to encourage one-hot solutions, etc.
By default we use a generalized proximal gradient method (mirror descent with entropic regularization) to do optimization over the simplex and to encourage sparse solutions, though it is very easy to swap in other optimization algorithms (e.g. projected gradient descent or composition with a softmax as in ColabDesign).
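As a minimal, self-contained sketch of this relaxation (illustrative only; mosaic's optimizers handle this internally, and loss_fn here stands for any scalar-valued differentiable function of a soft sequence):
import jax

def relaxed_objective(logits, loss_fn):
    # softmax maps unconstrained logits to a point on the simplex, shape (N, 20)
    p = jax.nn.softmax(logits, axis=-1)
    # by the interchange above, evaluating f at E_p x is just evaluating f at p
    return loss_fn(p)

# gradients with respect to the logits, usable with any first-order optimizer
grad_fn = jax.grad(relaxed_objective)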
This kind of modular implementation of loss terms is also useful with modern RL-based approaches to aligning generative models, which can often be seen as amortized optimization: they typically train a generative model to minimize a combination of a KL-divergence term and a loss function, where the loss can itself be a combination of in-silico predictors. Another use case is to provide guidance to discrete diffusion or flow models.
[^1]: This requires us to treat neural networks as simple parametric functions that can be combined programmatically, not as complicated software packages that require large libraries (e.g. PyTorch Lightning), bash scripts, or containers, as is common practice in BioML.