escalante-bio/mosaic

Functional, multi-objective protein design using continuous relaxation.

Protein design tasks almost always involve multiple constraints or properties that must be satisfied or optimized. For instance, in binder design one may want to simultaneously ensure:

  • the chance of binding the intended target is high
  • the chance of binding to a similar off-target protein is low
  • the binder expresses well in bacteria
  • the binder is highly soluble.

There has been a recent explosion in the application of machine learning to protein property prediction, resulting in fairly accurate predictors for each of these properties. What is currently lacking is an efficient and flexible method for combining these different predictors into one design/filtering/ranking framework.


Installation

We recommend using uv, e.g. run uv sync --group jax-cuda after cloning the repo to install dependencies.

To run the example notebook, activate the environment with source .venv/bin/activate and then run marimo edit examples/example_notebook.py.

You may need to add uv overrides for specific packages depending on your machine; take a look at pyproject.toml.

You'll need a GPU- or TPU-compatible version of JAX for structure prediction. You might need to install this manually, e.g. uv add jax[cuda12].

To automatically download the AF2 weights you'll need to install aria2: apt-get install aria2.

Introduction

This project combines two simple components to make a powerful protein design framework: the continuous relaxation of sequence design popularized by ColabDesign, RSO, and BindCraft, and a library of learned objective terms implemented as composable functions.

The key observation is that it's possible to use this continuous relaxation simultaneously with multiple learned objective terms [1].

This allows us to easily construct objective functions that are combinations of multiple learned potentials and optimize them efficiently, like so:

combined_loss = (
    Boltz1Loss(
        model=model,
        name="ART2B",
        loss=4 * sp.BinderTargetContact()
        + sp.RadiusOfGyration(target_radius=15.0)
        + sp.WithinBinderContact()
        + 0.3 * sp.HelixLoss()
        + ProteinMPNNLoss(mpnn, num_samples = 8),
        features=boltz_features,
        recycling_steps=0,
    )
    + 0.5 * esm_loss
    + trigram_ll
    + 0.1 * StabilityModel.from_pretrained(esm)
    + 0.5
    * Boltz1Loss(
        model=model,
        name="mono",
        loss=0.2 * sp.PLDDTLoss()
        + sp.RadiusOfGyration(target_radius=15.0)
        + 0.3 * sp.HelixLoss(),
        features=monomer_features,
        recycling_steps=0,
    )
)

_, logits_combined_objective = simplex_APGM(
    loss_function=combined_loss,
    n_steps=150,
    x=np.random.randn(binder_length, 20) * 0.1,
    stepsize=0.1,
)

Here we're using five different models to construct a loss function: the Boltz-1 structure prediction model (used twice: once to predict the binder-target complex and once to predict the binder as a monomer), ESM2, ProteinMPNN, an n-gram model, and a stability model trained on the mega-scale dataset.

It's super easy to define additional loss terms, which are JIT-compatible callable pytrees, e.g.

import jax.numpy as jnp
from jaxtyping import Array, Float

class LogPCysteine(LossTerm):
    def __call__(self, soft_sequence: Float[Array, "N 20"], key=None):
        # mean log-probability of cysteine across all positions of the soft sequence
        mean_log_p = jnp.log(soft_sequence[:, IDX_CYS] + 1e-8).mean()
        # return the scalar loss plus a dict of auxiliary values for logging
        return mean_log_p, {"log_p_cys": mean_log_p}
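Once defined, such a term can be weighted and added to any existing loss exactly like the built-in terms above (a hedged sketch, assuming the class can be constructed with no arguments as written):

# hypothetical usage: weight the custom term and add it to the combined loss from above
loss_with_cys_term = combined_loss + 0.05 * LogPCysteine()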

There's no reason custom loss terms can't involve more expensive (differentiable) operations, e.g. running Protenix or an EVOLVEpro-style fitness predictor.

The marimo notebook gives a few examples of how this can work.

WARNING: ColabDesign, BindCraft, etc. are well-tested and well-tuned methods for very specific problems. mosaic may require substantial hand-holding to work (tuning learning rates, etc.), often produces proteins that fail simple in-silico tests, and must be combined with standard filtering methods. This is not for the faint of heart: the intent is to provide a framework in which to implement custom objective functions and optimization algorithms for your application.

It's very easy to swap in different optimizers. For instance, let's say we really wanted to try projected gradient descent on the hypercube $[0,1]^N$. We can implement that in a few lines of code:

import jax
import numpy as np
import optax
from jaxtyping import Array, Float


def RSO_box(
    *,
    loss_function,
    x: Float[Array, "N 20"],
    n_steps: int,
    optim=optax.chain(optax.clip_by_global_norm(1.0), optax.sgd(1e-1)),
    key=None,
):
    if key is None:
        key = jax.random.PRNGKey(np.random.randint(0, 10000))

    opt_state = optim.init(x)

    for _iter in range(n_steps):
        # _eval_loss_and_grad and _print_iter are helpers; see src/mosaic/optimizers.py
        (v, aux), g = _eval_loss_and_grad(
            x=x,
            loss_function=loss_function,
            key=key,
        )
        updates, opt_state = optim.update(g, opt_state, x)
        # project back onto the hypercube [0, 1]^N after each gradient step
        x = optax.apply_updates(x, updates).clip(0, 1)
        key = jax.random.fold_in(key, 0)
        _print_iter(_iter, aux, v)

    return x
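A hypothetical call, reusing combined_loss, binder_length, and the imports from the examples above:

x_final = RSO_box(
    loss_function=combined_loss,
    x=np.random.randn(binder_length, 20) * 0.1,
    n_steps=100,
)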

Take a look at optimizers.py for a few examples of different optimizers.

Models and losses

Included models

  • Boltz-1
  • Boltz-2
  • AlphaFold2
  • Protenix (mini + tiny)
  • ProteinMPNN
  • ESM
  • stability
  • AbLang
  • trigram

Structure Prediction


We provide a simple interface in mosaic.structure_prediction and mosaic.models.* to five structure prediction models: Boltz1, Boltz2, AF2, ProtenixMini, and ProtenixTiny.

To make a prediction or design a binder, you'll need to make a list of mosaic.structure_prediction.TargetChain objects. This is a simple dataclass that holds a protein (or DNA or RNA) sequence, a flag telling the model whether to use an MSA (use_msa), and, optionally, a template structure.
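For instance, a hedged sketch of a chain that skips MSA generation (the sequence is just a placeholder):

from mosaic.structure_prediction import TargetChain

# placeholder sequence; use_msa=False disables MSA features for this chain
peptide = TargetChain("PEPTIDESEQ", use_msa=False)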

For example, we can make a prediction with Protenix for IL7Ra like so:

import jax
from mosaic.structure_prediction import TargetChain
from mosaic.models.protenix import ProtenixMini


model = ProtenixMini()

target_sequence = "DYSFSCYSQLEVNGSQHSLTCAFEDPDVNTTNLEFEICGALVEVKCLNFRKLQEIYFIETKKFLLIGKSNICVKVGEKSLTCKKIDLTTIVKPEAPFDLSVVYREGANDFVVTFNTSHLQKKYVKVLMHDVAYRQEKDENKWTHVNLSSTKLTLLQRKLQPAAMYEIKVRSIPDHYFKGFWSEWSPSYYFRT"


# generate features and a "writer" object that turns model output into a prediction wrapper
target_only_features, target_only_structure = model.target_only_features(
    [TargetChain(target_sequence)]
)

prediction = model.predict(
    features=target_only_features,
    writer=target_only_structure,
    key=jax.random.key(0),
    recycling_steps=10,
)

# prediction contains useful properties like `prediction.st`, `prediction.pae` etc.

This interface is the same for all structure prediction models, so in theory we should be able to replace ProtenixMini above with Boltz2 by changing only a single line of code!
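A hedged sketch of that swap (the Boltz2 import path is an assumption, following the mosaic.models.* pattern mentioned above):

from mosaic.models.boltz2 import Boltz2  # assumed module path

model = Boltz2()  # everything downstream of this line stays the same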

We also define a collection of (model agnostic!) structure-prediction-related losses in mosaic.losses.structure_prediction. It's super easy to define your own using the provided interface.

Internally we distinguish between three classes of losses: those that rely only on the trunk, those that also need the structure module, and those that need the confidence module. For computational efficiency we only run the structure module or confidence module if required!

Continuing the example above, we can construct a loss and do design as follows:

import numpy as np

import mosaic.losses.structure_prediction as sp
from mosaic.optimizers import simplex_APGM

binder_length = 80

# reuse the ProtenixMini model constructed above
design_features, design_structure = model.binder_features(
    binder_length=binder_length, chains=[TargetChain(target_sequence)]
)

loss = model.build_loss(
    loss=sp.BinderTargetContact() + sp.WithinBinderContact(),
    features=design_features,
    recycling_steps=3,
)

PSSM = jax.nn.softmax(
    0.5
    * jax.random.gumbel(
        key=jax.random.key(np.random.randint(100000)),
        shape=(binder_length, 20),
    )
)


PSSM,_ = simplex_APGM(
    loss_function=loss,
    x=PSSM,
    n_steps=50,
    stepsize=0.15,
    momentum=0.3,
)
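After optimization you'll usually want a discrete sequence; one simple option is a per-position argmax (a hedged sketch: AA_ORDER below is a hypothetical alphabet, substitute whatever ordering mosaic's tokenization actually uses):

AA_ORDER = "ARNDCQEGHILKMFPWYVST"  # hypothetical ordering, for illustration only
binder_sequence = "".join(AA_ORDER[int(i)] for i in PSSM.argmax(axis=-1))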

Every structure prediction model also supports a low-level loss/interface if you'd like to do something fancy (e.g. design a protein binder against a small molecule with Boltz or Protenix).

Protenix


See protenix.py for an example of how to use this family of models. The Protenix loss function supports some advanced features to speed up hallucination, namely "pre-cycling" (running multiple recycling iterations on the target alone before design) and "co-cycling" (running recycling and optimization steps in parallel), but it can also be used analogously to Boltz or AF2.

ProteinMPNN


Load your preferred ProteinMPNN (soluble or vanilla) model using

from mosaic.proteinmpnn.mpnn import ProteinMPNN

mpnn = ProteinMPNN.from_pretrained()

In the simplest case we have a single-chain structure or complex where the protein we're designing occurs as the first chain (note this can be a prediction). We can then construct the (negative) log-likelihood of the designed sequence under ProteinMPNN as a loss term:

import gemmi

inverse_folding_LL = FixedStructureInverseFoldingLL.from_structure(
    gemmi.read_structure("scaffold.pdb"), mpnn
)

This can then be added to whatever overall loss function you're constructing.

Note that it is often helpful to clip the loss using, e.g., ClippedLoss(inverse_folding_LL, 2, 100): over-optimizing ProteinMPNN likelihoods typically results in homopolymers.

ProteinMPNN + structure prediction


ProteinMPNN can also be combined with live structure predictions. Mathematically this is $-\log P_\theta(s | AF2(s))$, the negative log-likelihood of the sequence $s$ under inverse folding of the predicted structure for that sequence. This loss term is ProteinMPNNLoss.

Another very useful loss term is InverseFoldingSequenceRecovery: a continuous relaxation of sequence recovery after sampling with ProteinMPNN (roughly $\langle s, -E_{z \sim p_\theta(\cdot | AF2(s))} [z] \rangle$). We've found this term often speeds up design and increases filter pass rates.
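As a hedged sketch (reusing the Boltz-1 model, mpnn, boltz_features, and sp from the examples above; the InverseFoldingSequenceRecovery constructor is not shown because its arguments aren't covered here), these terms slot into a structure prediction loss like any other:

complex_loss = Boltz1Loss(
    model=model,
    name="complex",
    loss=sp.BinderTargetContact()
    + sp.WithinBinderContact()
    # negative log-likelihood of the designed sequence under inverse folding
    # of the live predicted structure
    + ProteinMPNNLoss(mpnn, num_samples=8),
    features=boltz_features,
    recycling_steps=0,
)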

ESM


Warning: due to Python packaging conflicts, it's impossible to use both ESM2 and ESMC in the same environment.

Another useful loss term is the pseudolikelihood of the ESM2 protein language model (via esm2quinox), which is correlated with all kinds of useful properties (solubility, expressibility, etc.).

This term can be constructed as follows:

import esm
import esm2quinox
torch_model, _ = esm.pretrained.esm2_t33_650M_UR50D()
ESM2PLL = ESM2PseudoLikelihood(esm2quinox.from_torch(torch_model))

In typical practice this loss should be clipped or squashed to avoid over-optimization (e.g. ClippedLoss(ESM2PLL, 2, 100)).

We also implement the corresponding loss for ESMC (via esmj).

from esmj import from_torch
from esm.models.esmc import ESMC as TORCH_ESMC

esm = from_torch(TORCH_ESMC.from_pretrained("esmc_300m").to("cpu"))
ESMCPLL = ESMCPseudoLikelihood(esm)

Stability


A simple ΔG predictor trained on the mega-scale dataset. This might be a nice example of how to train and add a simple regression head on a small amount of data: see train.py.

stability_loss = StabilityModel.from_pretrained(esm)

AbLang


AbLang, a family of antibody-specific language models.

import ablang
import jablang

heavy_ablang = ablang.pretrained("heavy")
heavy_ablang.freeze()

abpll = AbLangPseudoLikelihood(
    model=jablang.from_torch(heavy_ablang.AbLang),
    tokenizer=heavy_ablang.tokenizer,
    stop_grad=True,
)

Trigram


A trigram language model, as in "A high-level programming language for generative protein design".

trigram_ll = TrigramLL.from_pkl()

Optimizers and loss transformations


We include some standard optimizers in src/mosaic/optimizers.py.

First, simplex_APGM, which is an accelerated proximal gradient algorithm on the probability simplex. One critical hyperparameter is the stepsize; a reasonable first guess is 0.1 * np.sqrt(binder_length). Another useful keyword argument is scale, which corresponds to $\ell_2$ regularization. Values larger than 1.0 encourage sparse solutions; a typical binder design run might start with scale=1.0 to get an initial, soft solution and then ramp up to something higher to get a discrete solution.
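For example, a hedged sketch of that two-stage schedule, reusing loss, PSSM, and binder_length from the design example above (the exact scale values and step counts are illustrative):

import numpy as np

from mosaic.optimizers import simplex_APGM

PSSM_soft, _ = simplex_APGM(
    loss_function=loss,
    x=PSSM,
    n_steps=75,
    stepsize=0.1 * np.sqrt(binder_length),
    scale=1.0,  # soft, exploratory phase
)
PSSM_sharp, _ = simplex_APGM(
    loss_function=loss,
    x=PSSM_soft,
    n_steps=75,
    stepsize=0.1 * np.sqrt(binder_length),
    scale=1.5,  # larger scale encourages a (near-)discrete solution
)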

simplex_APGM also accepts a logspace keyword argument; if this is set to True we run the algorithm in logspace, which corresponds to an accelerated proximal Bregman method. In this case scale corresponds to entropic regularization.

We also include a discrete optimization algorithm, gradient_MCMC, which is a variant of MCMC with a proposal distribution defined using a Taylor approximation to the objective function (see Plug & Play Directed Evolution of Proteins with Gradient-based Discrete MCMC). This algorithm is especially useful for fine-tuning either existing designs or the result of continuous optimization.

Loss transformations

We also provide a few common transformations of loss functions. Of note is ClippedLoss, which wraps another loss term and clips its value to a specified range.

SetPositions and FixedPositionsPenalty are useful for fixing certain positions of an existing design.

ClippedGradient and NormedGradient respectively clip and normalize the gradients of individual loss terms. This can be useful when combining predictors with very different gradient norms, for example:

loss = (
    ClippedGradient(inverse_folding_LL, 1.0)
    + ClippedGradient(ablang_pll, 1.0)
    + 0.25 * ClippedGradient(ESMCPLL, 1.0)
)

Extensive theoretical discussion

Hallucination-based protein design workflows attempt to solve the following optimization problem:

$$\underset{s \in A^n}{\textrm{minimize}}~\ell(s).$$

Here $A$ is the set of amino acids, so the decision variable $s$ ranges over all protein sequences of length $n$, and $\ell: A^n \rightarrow \mathbf{R}$ is a loss functional that evaluates the quality of the protein $s$. In typical practice $\ell$ is some function of the output of a neural network; e.g. in ColabDesign $\ell$ might be the (negative) average pLDDT from AlphaFold.

One challenge with naive approaches is that $A^n$ is extremely large and discrete optimization is difficult; while MCMC and other discrete algorithms have been used (see, e.g., Rives et al.), they are often very slow.

ColabDesign, RSO, and BindCraft, among others, use the fact that $\ell$ has a particular structure that allows for a continuous relaxation of the original problem: almost every neural network first encodes the sequence $s$ as a one-hot matrix $P \in \mathbf{R}^{(n, c)}$. If we consider $\ell$ as a functional on $\mathbf{R}^{(n, c)}$ we can use automatic differentiation to do continuous optimization on either $\mathbf{R}^{(n, c)}$ or $\Delta_c^n$ (the $n$-fold product of the probability simplex).

This is related to the classic optimization trick of optimizing over distributions rather than single points. First, $\underset{x}{\textrm{minimize }}f(x)$ is relaxed to $\underset{p \in \Delta}{\textrm{minimize }}E_p f(x)$. Next, if it makes sense to take the expectation of $x$ (as in the one-hot sequence case), we can interchange $f$ and $E$ to get the final relaxation: $$\underset{p \in \Delta}{\textrm{minimize }} f( E_p x) = \underset{p \in \Delta}{\textrm{minimize }} f(p).$$
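To make the interchange concrete, here is a minimal, hypothetical JAX sketch: a toy differentiable f stands in for a real network, and the relaxation is just f evaluated at a point on the product of simplices (parametrized here with a softmax, as in ColabDesign):

import jax
import jax.numpy as jnp


def f(one_hot):
    # toy stand-in for a network loss on an (n, 20) one-hot encoding:
    # a linear score that rewards the first amino-acid column (illustrative only)
    return -jnp.sum(one_hot[:, 0])


logits = jnp.zeros((10, 20))
# softmax maps unconstrained logits onto the product of simplices, so f(E_p x) = f(p)
relaxed = lambda l: f(jax.nn.softmax(l, axis=-1))
value, grad = jax.value_and_grad(relaxed)(logits)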

Solutions to this relaxed optimization problem must then be translated into sequences; many different methods work here: RSO uses inverse folding of the predicted structure, BindCraft/ColabDesign use a softmax with ramping temperature to encourage one-hot solutions, etc.

By default we use a generalized proximal gradient method (mirror descent with entropic regularization) to do optimization over the simplex and to encourage sparse solutions, though it is very easy to swap in other optimization algorithms (e.g. projected gradient descent or composition with a softmax as in ColabDesign).
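For intuition, a single plain (unaccelerated, unregularized) entropic mirror descent step has the closed form

$$p^{k+1}_i \propto p^k_i \exp\left(-t\, \nabla \ell(p^k)_i\right),$$

applied position-wise and renormalized on each simplex; the default optimizer described above is an accelerated, regularized relative of this update.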

Typically $\ell$ is formed by a single neural network (or an ensemble of the same architecture), but in practice we're interested in simultaneously optimizing different properties predicted by different neural networks. This has the added benefit of reducing the chance of finding so-called adversarial sequences.

This kind of modular implementation of loss terms is also useful with modern RL-based approaches to aligning generative models: these forms of alignment can often be seen as amortized optimization. Typically, they train a generative model to minimize a loss function (which can be a combination of in-silico predictors) plus a KL-divergence penalty. Another use case is to provide guidance to discrete diffusion or flow models.

Footnotes

  1. This requires us to treat neural networks as simple parametric functions that can be combined programmatically, not as complicated software packages that require large libraries (e.g. PyTorch Lightning), bash scripts, or containers, as is common practice in BioML.
