Molecular Generators
In this tutorial, we will finally focus on the central part of the GenUI platform. However, we recommend that you check the QSAR and Compounds tutorials first, because this tutorial expands upon some of the concepts presented there.
Many of the most promising contemporary generators are based on generative deep neural networks, a modern class of machine learning algorithms. Therefore, this tutorial is mainly concerned with integrating such an approach. However, any generator can be added to the platform, even one that is not strictly based on machine learning. If your generator is not based on a machine learning algorithm or does not require training, you can skip to the chapter Using the Trained Generator. You might still find the chapters Defining URLs, Defining Views and Models & Serializers useful if you are not familiar with Django and the Django REST Framework.
We will not describe the development of a generative extension in exhaustive detail here. Instead, you should regard this tutorial as a case study that highlights important GenUI concepts and features. The subject of our case study is the genui.generators.extensions.genuidrugex package, which integrates a deep learning generative algorithm called DrugEx. This approach trains a recurrent neural network in a reinforcement learning loop where a QSAR model provides the environment for the agent. It is implemented on top of the PyTorch machine learning library, which can take advantage of GPU hardware. Therefore, DrugEx is a good non-trivial example of a contemporary molecular generator that might be of interest to users seeking to use the GenUI platform to discover new chemistry.
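To make the idea of a reinforcement learning loop more concrete, here is a deliberately tiny policy-gradient (REINFORCE) sketch in plain Python. It is not DrugEx code. The recurrent network is replaced by a single categorical choice over a few tokens, and the QSAR-based environment is replaced by a hypothetical reward function `reward_fn`. The names `reinforce` and `softmax` are illustrative only.

```python
import math
import random


def softmax(logits):
    # Numerically stable softmax over a list of floats.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce(reward_fn, n_actions, steps=2000, lr=0.1, seed=0):
    """Toy one-step REINFORCE: tokens with a high reward become more likely.

    reward_fn is a hypothetical stand-in for the QSAR-based environment.
    """
    rng = random.Random(seed)
    logits = [0.0] * n_actions
    baseline = 0.0
    for _ in range(steps):
        probs = softmax(logits)
        action = rng.choices(range(n_actions), weights=probs)[0]
        reward = reward_fn(action)
        advantage = reward - baseline
        baseline = 0.9 * baseline + 0.1 * reward  # running-mean baseline
        # d log p(action) / d logits[j] = 1[j == action] - probs[j]
        for j in range(n_actions):
            grad = (1.0 if j == action else 0.0) - probs[j]
            logits[j] += lr * advantage * grad
    return softmax(logits)


probs = reinforce(lambda a: 1.0 if a == 2 else 0.0, n_actions=4)
```

Here the reward favours token 2, so most of the probability mass should end up on it. Subtracting a running-mean baseline from the reward is a common variance-reduction choice; the update rule used by DrugEx itself may differ.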
Just like any other extension, genuidrugex is also a Django application, so it has all the expected modules and functionality that you are probably already familiar with if you read the previous tutorials. However, to achieve our goals, we will have to do a bit more work and customization than we saw previously. Most of this work will focus on creating a customized model builder (see Implementing a Model Builder), which has a purpose similar to the instances of MolSetInitializer that we saw in Defining Compound Set Initializer. Along the way, we will also learn about other GenUI features that we have not described yet.
Defining URLs
We will take a look at the genuidrugex.urls module first, since it clearly showcases the components of the extension that we will be dealing with in our case study. The file is short, so we can show it here in full:
"""
urls
Created by: Martin Sicho
On: 5/3/20, 6:51 PM
"""
from django.urls import path, include
from rest_framework import routers
from genui.utils.extensions.tasks.views import ModelTasksView
from genui.models.views import ModelFileView, ModelPerformanceListView
from . import models
from . import views
router = routers.DefaultRouter()
router.register(r'drugex/networks', views.DrugExNetViewSet, basename='drugex_net')
router.register(r'drugex/agents', views.DrugExAgentViewSet, basename='drugex_agent')
router.register(r'drugex/environments', views.EnvironmentViewSet, basename='drugex_env')
# scoring methods
router.register(r'drugex/scorers/methods/all', views.ScoringMethodViewSet, basename='drugex_scoremethods_all')
router.register(r'drugex/scorers/methods/genuimodels', views.QSARScorerViewSet, basename='drugex_scoremethods_genuimodels')
router.register(r'drugex/scorers/methods/properties', views.PropertyScorerViewSet, basename='drugex_scoremethods_properties')
# modifiers
router.register(r'drugex/scorers/modifiers/all', views.ModifierViewSet, basename='drugex_modifiers_all')
router.register(r'drugex/scorers/modifiers/clipped', views.ClippedViewSet, basename='drugex_modifiers_clipped')
router.register(r'drugex/scorers/modifiers/hump', views.HumpViewSet, basename='drugex_modifiers_hump')
# scorers
router.register(r'drugex/scorers', views.ScorerViewSet, basename='drugex_scorers')
# generators
router.register(r'drugex/generators', views.GeneratorViewSet, basename='drugex_generators')
routes = [
    # networks
    path('drugex/networks/<int:pk>/tasks/all/', ModelTasksView.as_view(model_class=models.DrugExNet))
    , path('drugex/networks/<int:pk>/tasks/started/', ModelTasksView.as_view(started_only=True, model_class=models.DrugExNet))
    , path('drugex/networks/<int:pk>/performance/', ModelPerformanceListView.as_view(), name="drugex_net_perf_view")
    , path('drugex/networks/<int:pk>/files/', ModelFileView.as_view(model_class=models.DrugExNet), name="drugex-net-model-files-list")
] + [
    # agents
    path('drugex/agents/<int:pk>/tasks/all/', ModelTasksView.as_view(model_class=models.DrugExAgent))
    , path('drugex/agents/<int:pk>/tasks/started/', ModelTasksView.as_view(started_only=True, model_class=models.DrugExAgent))
    , path('drugex/agents/<int:pk>/performance/', ModelPerformanceListView.as_view(), name="drugex_agent_perf_view")
    , path('drugex/agents/<int:pk>/files/', ModelFileView.as_view(model_class=models.DrugExAgent), name="drugex-agent-model-files-list")
]

urlpatterns = [
    path('', include(routes)),
    path('', include(router.urls)),
]
Before training the generator itself (the agent), the DrugEx approach requires that an exploration network and an exploitation network are built first. That is why two viewsets (DrugExNetViewSet and DrugExAgentViewSet) are registered with the router in the code above: DrugExNetViewSet handles the creation and training of the networks, and DrugExAgentViewSet handles the agent itself. The routes list is simply a list of additional API endpoints that we would like to have available under the root URLs /generators/drugex/networks/ and /generators/drugex/agents/.
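The extra endpoints in the routes list follow a fixed pattern for each model instance. The small helper below is hypothetical and not part of GenUI. It expands the patterns into concrete URLs, assuming the extension is mounted under /generators/ as described above, and could be handy when writing API tests or a client.

```python
# URL suffixes registered per model instance in the ``routes`` list above.
EXTRA_ENDPOINTS = ("tasks/all/", "tasks/started/", "performance/", "files/")


def drugex_model_urls(resource, pk, root="/generators/"):
    """Return the extra endpoint URLs of one DrugExNet ("networks") or
    DrugExAgent ("agents") instance. ``root`` is an assumed mount point."""
    if resource not in ("networks", "agents"):
        raise ValueError(f"unknown resource: {resource!r}")
    base = f"{root}drugex/{resource}/{pk}/"
    return [base + suffix for suffix in EXTRA_ENDPOINTS]
```

For example, `drugex_model_urls("agents", 7)[0]` gives `/generators/drugex/agents/7/tasks/all/`.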
You can see that we are using some class views that come from other GenUI applications (ModelTasksView from genui.utils.extensions.tasks, the other two from genui.models):
ModelTasksView: Shows all, or only the started and running, asynchronous Celery tasks attached to an instance of Model (determined by the pk URL argument).
ModelPerformanceListView: Shows the performance of a machine learning model for a specified instance of Model. See Adding Performance Metrics for more info.
ModelFileView: Model instances can have files attached to them. This endpoint handles them.
You do not have to use these views if you do not need them, but it is good to know about them so that you do not have to implement your own.
Defining Views
Let us take a closer look at DrugExNetViewSet and DrugExAgentViewSet in genuidrugex.views:
import traceback

from django.conf import settings
from rest_framework import viewsets, status
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response

from genui import celery_app
from genui.accounts.serializers import FilterToUserMixIn
from genui.projects.serializers import FilterToProjectMixIn
from genui.utils.extensions.tasks.utils import runTask
from . import models
from . import serializers
from .genuimodels import builders
from .tasks import buildDrugExModel, calculateEnvironment
from genui.models.views import ModelViewSet


class DrugExNetViewSet(ModelViewSet):
    queryset = models.DrugExNet.objects.order_by('-created')
    serializer_class = serializers.DrugExNetSerializer
    init_serializer_class = serializers.DrugExNetInitSerializer
    builder_class = builders.DrugExNetBuilder
    build_task = buildDrugExModel

    def get_builder_kwargs(self):
        return {"model_class" : models.DrugExNet.__name__}


class DrugExAgentViewSet(ModelViewSet):
    queryset = models.DrugExAgent.objects.order_by('-created')
    serializer_class = serializers.DrugExAgentSerializer
    init_serializer_class = serializers.DrugExAgentInitSerializer
    builder_class = builders.DrugExAgentBuilder
    build_task = buildDrugExModel

    def get_builder_kwargs(self):
        return {"model_class" : models.DrugExAgent.__name__}


class EnvironmentViewSet(FilterToProjectMixIn, FilterToUserMixIn, viewsets.ModelViewSet):
    queryset = models.DrugExEnvironment.objects.order_by('-created')
    serializer_class = serializers.DrugExEnvironmentSerializer
    calculation_serializer_class = serializers.DrugExEnvironmentCalculationSerializer
    owner_relation = "project__owner"

    @action(detail=True, methods=['post'])
    def calculate(self, request, pk=None):
        environment = self.queryset.get(pk=pk)
        serializer = self.get_serializer(data=request.data)
        if serializer.is_valid():
            data = serializer.validated_data
            task = None
            try:
                task, task_id = runTask(
                    calculateEnvironment,
                    eager=hasattr(settings, 'CELERY_TASK_ALWAYS_EAGER') and settings.CELERY_TASK_ALWAYS_EAGER,
                    args=(
                        environment.pk,
                        data["molsets"],
                        data["useModifiers"]
                    ),
                )
                data["taskID"] = task_id
                return Response(data, status=status.HTTP_201_CREATED)
            except Exception as exp:
                traceback.print_exc()
                if task and task.id:
                    celery_app.control.revoke(task_id=task.id, terminate=True)
                return Response({"error" : repr(exp)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
        else:
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

    def get_serializer_class(self):
        if self.action == 'calculate':
            return self.calculation_serializer_class
        else:
            return self.serializer_class


class ScoringMethodViewSet(FilterToProjectMixIn, FilterToUserMixIn, viewsets.ModelViewSet):
    queryset = models.ScoringMethod.objects.order_by('-created')
    serializer_class = serializers.ScoringFunctionSerializer
    owner_relation = "project__owner"
    # http_method_names = ['get']


class QSARScorerViewSet(ScoringMethodViewSet):
    queryset = models.GenUIModelScorer.objects.order_by('-created')
    serializer_class = serializers.QSARScorerSerializer


class PropertyScorerViewSet(ScoringMethodViewSet):
    queryset = models.PropertyScorer.objects.order_by('-created')
    serializer_class = serializers.PropertyScorerSerializer


class ModifierViewSet(FilterToProjectMixIn, FilterToUserMixIn, viewsets.ModelViewSet):
    queryset = models.ScoreModifier.objects.order_by('-created')
    serializer_class = serializers.ModifierSerializer
    test_serializer_class = serializers.ModifierTestSerializer
    owner_relation = "project__owner"
    permission_classes = [IsAuthenticated]

    @action(detail=False, methods=['post'])
    def test(self, request):
        modifier = self.queryset.model
        serializer = self.get_serializer(data=request.data)
        if serializer.is_valid():
            data = serializer.validated_data
            results = modifier.test(data['inputs'], **(data['params']))
            data["results"] = results
            return Response(data)
        else:
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

    def get_serializer_class(self):
        if self.action == 'test':
            return self.test_serializer_class
        else:
            return self.serializer_class


class ClippedViewSet(ModifierViewSet):
    queryset = models.ClippedScore.objects.order_by('-created')
    serializer_class = serializers.ClippedSerializer


class HumpViewSet(ModifierViewSet):
    queryset = models.SmoothHump.objects.order_by('-created')
    serializer_class = serializers.SmoothHumpSerializer


class ScorerViewSet(FilterToProjectMixIn, FilterToUserMixIn, viewsets.ModelViewSet):
    queryset = models.DrugExScorer.objects.order_by('-created')
    serializer_class = serializers.DrugExScorerSerializer
    owner_relation = "project__owner"


class GeneratorViewSet(FilterToProjectMixIn, FilterToUserMixIn, viewsets.ModelViewSet):
    queryset = models.DrugEx.objects.order_by('-created')
    serializer_class = serializers.DrugExGeneratorSerializer
    owner_relation = "project__owner"
GenUI already has a viewset defined for model creation, ModelViewSet. The same viewset is used for QSAR modelling as well (see genui.qsar.views), and it turns out that we do not need to modify it much. There are just a few class attributes that we have to set:
queryset: Defines the objects shown in the list and their order.
serializer_class and init_serializer_class: The serializer classes used to represent existing instances and to create new ones, respectively. They can be identical, but for models they are often slightly different, hence the two separate attributes.
builder_class: An attribute specific to models, similar to the initializer_class found in genui.compounds (see Defining Compound Set Initializer). We will talk about it in more detail later.
build_task: We have a specific Celery task defined for DrugEx (buildDrugExModel() in genuidrugex.tasks). This is mainly because the extension depends on GPU hardware, so its build tasks should be submitted to the gpu queue, which should be consumed by workers with GPUs.
Models & Serializers
Since we are implementing a completely new kind of model,
we have to define the models and serializers we will need.
This part is very specific to the DrugEx extension,
but there are a few model and serializer classes
in the genui.models
application that you should be
aware of when attempting something similar.
Model Classes
Model
A polymorphic model class used to save data about a machine learning model. In genuidrugex.models, we can store additional data with this class by subclassing it. For example, the final DrugEx agent can be described like so:
class DrugExAgent(Model):
    environment = models.ForeignKey(QSARModel, on_delete=models.CASCADE, null=False, related_name='drugexEnviron')
    explorationNet = models.ForeignKey(DrugExNet, on_delete=models.CASCADE, null=False, related_name='drugexExplore')
    exploitationNet = models.ForeignKey(DrugExNet, on_delete=models.CASCADE, null=False, related_name='drugexExploit')

    def getGenerator(self):
        return self.generator.all().get() if self.generator.all().exists() else None
Here we add two foreign keys, explorationNet and exploitationNet, that point to DrugExNet instances (DrugExNet is also a subclass of Model). We also add the environment, which is used for the policy gradient calculation in the reinforcement learning loop of the DrugEx algorithm. It is simply a given QSAR model, which is also defined as a subclass of Model (QSARModel).
TrainingStrategy and ValidationStrategy
When a model is created in GenUI, several parameters have to be specified, depending on whether the model is to be trained and validated by GenUI or imported from an external source. The two most important Model attributes are trainingStrategy and validationStrategy. These should point to database model instances of TrainingStrategy and ValidationStrategy.
TrainingStrategy contains the data required to train a model. Most importantly, this includes information about the algorithm used and the chosen training parameters, which are tied to this instance through the ModelParameterValue model.
ValidationStrategy, on the other hand, defines how the model should be validated after it is trained. It holds information about the performance metrics that should be calculated during validation (see ModelPerformanceMetric).
Like the Model class, TrainingStrategy and ValidationStrategy are polymorphic, and you can add new information by subclassing them. For example, BasicValidationStrategy saves the size of an external validation set chosen randomly from the training data (defined as a fraction of instances) and the number of folds in cross-validation. It is defined in genui.models.models as follows:
class CV(ValidationStrategy):
    cvFolds = models.IntegerField(blank=False)

    class Meta:
        abstract = True

class ValidationSet(ValidationStrategy):
    validSetSize = models.FloatField(blank=False) # as

    class Meta:
        abstract = True

class BasicValidationStrategy(ValidationSet, CV):
    pass
In the genuidrugex
extension we have the following custom validation
strategy:
class DrugExValidationStrategy(ValidationStrategy):
    validSetSize = models.IntegerField(default=512, null=True)
It simply defines the size of the test set that will be used to measure performance after processing one batch of data.
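Note that the two validSetSize fields mean different things: BasicValidationStrategy stores a fraction of the training data (a FloatField), whereas DrugExValidationStrategy stores an absolute number of instances (an IntegerField). The sketch below illustrates both interpretations, plus a simple cvFolds-style split, in plain Python. The function names are hypothetical, and this is not how GenUI itself performs validation.

```python
import random


def split_validation(items, valid_set_size, seed=42):
    """Hold out a validation set and return (train, valid).

    A float is read as a fraction of the data (cf. BasicValidationStrategy),
    an int as an absolute number of items (cf. DrugExValidationStrategy).
    """
    items = list(items)
    if isinstance(valid_set_size, float):
        if not 0.0 < valid_set_size < 1.0:
            raise ValueError("a fractional size must lie between 0 and 1")
        n_valid = round(len(items) * valid_set_size)
    else:
        n_valid = min(int(valid_set_size), len(items))
    rng = random.Random(seed)  # seeded so that the split is reproducible
    rng.shuffle(items)
    return items[n_valid:], items[:n_valid]


def cv_folds(items, n_folds):
    """Assign items round-robin to n_folds folds of near-equal size (cf. cvFolds)."""
    items = list(items)
    return [items[i::n_folds] for i in range(n_folds)]
```

For example, `split_validation(range(100), 0.2)` holds out 20 items, while `split_validation(smiles, 512)` holds out 512 items (or all of them, if fewer are available).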
Serializer Classes
Classes defined in genuidrugex.serializers describe how the models in genuidrugex.models are transformed to JSON format. This is a bit more involved, and you are encouraged to study the source code of this module more closely. We will just follow up on the example above and show what the serializers for DrugExValidationStrategy look like:
class DrugExValidationStrategySerializer(ValidationStrategySerializer):
    """
    This is used for GET requests to convert a DrugExValidationStrategy model
    to JSON.
    """

    class Meta:
        model = models.DrugExValidationStrategy
        fields = ValidationStrategySerializer.Meta.fields + ("validSetSize",)

class DrugExValidationStrategyInitSerializer(DrugExValidationStrategySerializer):
    """
    This is the serializer used for POST requests when creating a new instance of
    DrugExValidationStrategy.
    """
    metrics = serializers.PrimaryKeyRelatedField(many=True, queryset=ModelPerformanceMetric.objects.all(), required=False)

    class Meta:
        model = models.DrugExValidationStrategy
        fields = DrugExValidationStrategySerializer.Meta.fields + ("validSetSize",)
Many serializers are already defined in genui.models.serializers for Model, TrainingStrategy, ValidationStrategy (ModelSerializer, TrainingStrategySerializer, ValidationStrategySerializer) and other classes related to machine learning models. You should always derive from these when building a customized model.
Implementing a Model Builder
Now it is time to cover the most essential attribute of ModelViewSet, builder_class. Each GenUI model needs a builder, which is initialized when the Celery build task is executed. QSAR models have their own builder (QSARModelBuilder), and chemical space maps are also models with their own MapBuilder. In this chapter, we will cover how a builder can be implemented for a molecular generator.
Let us briefly recall the viewsets defined earlier in genuidrugex.views (imports omitted):
class DrugExNetViewSet(ModelViewSet):
    queryset = models.DrugExNet.objects.order_by('-created')
    serializer_class = serializers.DrugExNetSerializer
    init_serializer_class = serializers.DrugExNetInitSerializer
    builder_class = builders.DrugExNetBuilder
    build_task = buildDrugExModel

    def get_builder_kwargs(self):
        return {"model_class" : models.DrugExNet.__name__}

class DrugExAgentViewSet(ModelViewSet):
    queryset = models.DrugExAgent.objects.order_by('-created')
    serializer_class = serializers.DrugExAgentSerializer
    init_serializer_class = serializers.DrugExAgentInitSerializer
    builder_class = builders.DrugExAgentBuilder
    build_task = buildDrugExModel

    def get_builder_kwargs(self):
        return {"model_class" : models.DrugExAgent.__name__}
By now, everything in this code should be familiar except the builder_class and build_task attributes and the get_builder_kwargs() method.
When the create view of DrugExNetViewSet or DrugExAgentViewSet receives a POST request with data defining a new DrugEx model, the POST data is validated by the appropriate serializer, and the create method of DrugExNetInitSerializer or DrugExAgentInitSerializer is called. This method is responsible for creating a new DrugExNet or DrugExAgent instance from the supplied data. An asynchronous Celery task (buildDrugExModel() defined in genuidrugex.tasks) is then added to the queue and passed the ID of the created instance. The fully qualified name of the builder class and the name of the model class (as returned by get_builder_kwargs(), which both DrugEx viewsets override) are passed as well.
We can look at the code of buildDrugExModel()
to explain what happens once the task is eligible for execution
by a Celery worker:
"""
tasks
Created by: Martin Sicho
On: 28-01-20, 13:52
"""
from celery import shared_task
from genui.compounds.models import MolSet, ActivityTypes, Activity
from genui.utils.extensions.tasks.progress import ProgressRecorder
from genui.utils.inspection import getObjectAndModuleFromFullName
from . import models
from .models import DrugExEnvironment, DrugExEnvironmentScores
from .torchutils import cleanup
@shared_task(name="BuildDrugExModel", bind=True, queue='gpu')
def buildDrugExModel(self, model_id, builder_class, model_class):
# get the builder
model_class = getattr(models, model_class)
instance = model_class.objects.get(pk=model_id)
builder_class = getObjectAndModuleFromFullName(builder_class)[0]
recorder = ProgressRecorder(self)
if hasattr(instance, 'parent'):
builder = builder_class(
instance,
instance.parent,
progress=recorder
)
else:
builder = builder_class(
instance,
progress=recorder
)
# build the model
try:
builder.build()
except Exception as exp:
raise exp
cleanup()
return {
"errors" : [repr(x) for x in builder.errors],
"DrExModelName" : instance.name,
"DrExModelID" : instance.id,
}
@shared_task(name="calculateEnvironment", bind=True)
def calculateEnvironment(self, environment_id, molsets_ids, use_modifiers):
environment = DrugExEnvironment.objects.get(pk=environment_id)
instance = environment.getInstance(use_modifiers=use_modifiers)
activity_sets = []
for molset_id in molsets_ids:
molset = MolSet.objects.get(pk=molset_id)
molecules = molset.molecules.all()
scores = instance(molset.allSmiles)
assert len(scores) == len(molecules)
activity_set = DrugExEnvironmentScores(
molecules=molset,
environment=environment,
modifiersOn=use_modifiers,
project=environment.project,
name=f"{environment.name} ({molset.name})",
description=f"Activity set created from DrugEx environment: {environment.name} for compound set: {molset.name}."
)
activity_set.save()
activity_sets.append(activity_set.id)
drops = ['VALID']
if not use_modifiers:
drops.append('DESIRE')
scores.drop(drops, axis=1, inplace=True)
columns = scores.columns
for idx, row in scores.iterrows():
molecule = molecules[idx]
for col in columns:
value = row[col]
if value:
atype = ActivityTypes.objects.get_or_create(value=f"{environment.name}_{col}{'_MOD' if use_modifiers else ''}")[0]
activity = Activity(
value=value,
type=atype,
source=activity_set,
molecule=molecule
)
activity.save()
return {
"environment" : environment.id,
"activitySets" : activity_sets,
}
Once the task is eligible for execution, buildDrugExModel() is executed on the worker and receives the information outlined above as parameters. You should always make sure that whatever you return from get_builder_kwargs() is serializable as JSON, because that is how this information is passed from the server to the worker.
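Because the task arguments travel from the server to the worker as JSON, a cheap safeguard is to check that the builder kwargs survive a JSON round trip before the task is queued. The helper below is a hypothetical sketch, not part of GenUI. Besides objects that cannot be encoded at all, it also rejects values such as tuples or non-string dictionary keys, which JSON silently turns into lists and strings.

```python
import json


def check_builder_kwargs(kwargs):
    """Return kwargs unchanged if they survive a JSON round trip, else raise.

    Hypothetical helper, not part of GenUI.
    """
    try:
        encoded = json.dumps(kwargs)
    except TypeError as exc:  # e.g. a class object or a model instance
        raise ValueError(f"builder kwargs are not JSON-serializable: {exc}") from exc
    if json.loads(encoded) != kwargs:
        # tuples become lists and non-string keys become strings
        raise ValueError("builder kwargs change when encoded as JSON")
    return kwargs
```

This is why the DrugEx viewsets pass `models.DrugExNet.__name__` (a string) rather than the model class itself: `{"model_class": "DrugExNet"}` passes the check, whereas `{"model_class": DrugExNet}` would not.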
Note
We also specify gpu
as the queue for this task. This indicates
that this task prefers to be executed on a worker node with access to GPUs.
Looking at the code above, this is roughly what happens on the worker:
1. We look up the model class by name in genuidrugex.models and load the instance with the given ID.
2. We use getObjectAndModuleFromFullName() from genui.utils.inspection to import the correct builder class.
3. We create an instance of that builder (DrugExNetBuilder or DrugExAgentBuilder), also passing the parent model if the instance has one.
4. We run the build() method of the builder to build the model.
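The worker receives the builder only as a fully qualified name string. The sketch below shows one way to produce such a name and to resolve it back into a class with importlib. Both helpers are hypothetical stand-ins: getObjectAndModuleFromFullName() in genui.utils.inspection may be implemented differently, and this simple version only handles top-level (non-nested) classes and functions.

```python
import importlib


def full_name(obj):
    """Fully qualified name of a top-level class or function."""
    return f"{obj.__module__}.{obj.__qualname__}"


def import_from_full_name(name):
    """Resolve 'package.module.Name' to a tuple (object, module).

    Hypothetical stand-in for getObjectAndModuleFromFullName().
    """
    module_name, _, attr = name.rpartition(".")
    if not module_name:
        raise ValueError(f"not a fully qualified name: {name!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr), module
```

On the server side, a name like `full_name(builders.DrugExNetBuilder)` is a plain string, so it is JSON-friendly. On the worker, `import_from_full_name(...)[0]` returns the class again.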
To build models, we therefore have to implement a builder, which means subclassing the ModelBuilder abstract class and adding the necessary methods. You can find the DrugEx model builders in genuidrugex.genuimodels.builders, which is a standard location recognized by genuisetup.
You should always define your builders in {your_extension}.genuimodels.builders.
Note that there are a few mix-in classes defined in genui.models.genuimodels.bases that you can take advantage of when creating your own model builders. For example, the DrugEx extension uses ProgressMixIn, which adds methods and attributes for recording task progress stages more easily.
The purpose of a model builder is to prepare the model training data with the getX() and getY() methods. The values returned by these methods are then fed to the fit() method of an Algorithm instance (accessible from the model attribute of ModelBuilder). This is exactly what the reference implementation of build() in ModelBuilder does:
def build(self) -> models.Model:
    """
    Build method of the ModelBuilder abstract class.
    """
    self.model.fit(self.getX(), self.getY())
    self.saveFile() # calls self.model.serialize(path_to_model_snapshot)
    return self.instance
In the case of BasicQSARModelBuilder, getX() returns the matrix of descriptor values for each compound and getY() returns the class labels. In the DrugEx extension, we return the parsed data set corpus from getX().
Also note the saveFile() method of the builder, which ensures that the fitted model is serialized to disk after training. Each subclass of Algorithm is responsible for implementing the correct serialize() method so that a proper save file can be generated.
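The build flow can be mirrored in a few lines of plain Python. In the sketch below, ToyModelBuilder, ListBuilder and MeanAlgorithm are hypothetical stand-ins for ModelBuilder, a concrete builder and an Algorithm subclass. Unlike the real build(), which returns the Django model instance, the toy version returns the fitted algorithm, and it writes the model state to a JSON file.

```python
import json
import os
import tempfile
from abc import ABC, abstractmethod


class MeanAlgorithm:
    """Toy stand-in for an Algorithm: always predicts the training mean."""

    def __init__(self):
        self.mean = None

    def fit(self, X, y):
        self.mean = sum(y) / len(y)

    def serialize(self, path):
        with open(path, "w") as fh:
            json.dump({"mean": self.mean}, fh)


class ToyModelBuilder(ABC):
    """Hypothetical, simplified counterpart of ModelBuilder."""

    def __init__(self, model, file_path):
        self.model = model          # the algorithm instance to train
        self.file_path = file_path  # where the fitted model is saved

    @abstractmethod
    def getX(self):
        """Return the training inputs."""

    @abstractmethod
    def getY(self):
        """Return the training targets."""

    def saveFile(self):
        self.model.serialize(self.file_path)

    def build(self):
        # Same sequence as ModelBuilder.build(): fit, then save to disk.
        self.model.fit(self.getX(), self.getY())
        self.saveFile()
        return self.model


class ListBuilder(ToyModelBuilder):
    """Concrete builder that serves in-memory training data."""

    def __init__(self, model, file_path, X, y):
        super().__init__(model, file_path)
        self._X, self._y = X, y

    def getX(self):
        return self._X

    def getY(self):
        return self._y


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.json")
    fitted = ListBuilder(MeanAlgorithm(), path, X=[[0], [1], [2]], y=[1.0, 2.0, 3.0]).build()
    with open(path) as fh:
        saved = json.load(fh)
```

The design point is that a builder only decides *what* data to feed the algorithm. Training and serialization stay inside the algorithm, which is why getX() can return a descriptor matrix for a QSAR model and a corpus for DrugEx.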
You can see the defined DrugEx algorithms by exploring the genuidrugex.genuimodels.algorithms module. You will notice that the algorithms implement a few more methods in addition to those described in our discussion of QSAR model algorithms (see QSAR). We have to define new file formats for model serialization with the DrugExAlgorithm.getFileFormats() method, make sure that our algorithm is listed only as a generator by providing the correct mode with DrugExAlgorithm.getModes(), and provide a new serializer and deserializer (DrugExAlgorithm.getSerializer() and DrugExAlgorithm.getDeserializer()) for the state of the model (used by Algorithm.serialize()).
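These extra hooks can be sketched as follows. ToyGeneratorAlgorithm is hypothetical: its method names mirror those of DrugExAlgorithm, but the return values (for example, the structure of the file-format description and the "generator" mode string) and the deserialize() classmethod are assumptions made for illustration, not GenUI's actual data structures. The point is that saving and loading must use a matching serializer/deserializer pair, so that a saved state can be restored.

```python
import io
import json


class ToyGeneratorAlgorithm:
    """Hypothetical algorithm whose hook names mirror DrugExAlgorithm.

    The return values below are illustrative assumptions, not GenUI's API.
    """

    def __init__(self, state=None):
        # e.g. the vocabulary and weights of a trained generator
        self.state = state if state is not None else {}

    @staticmethod
    def getModes():
        # list the algorithm only as a generator (assumed mode name)
        return ["generator"]

    @staticmethod
    def getFileFormats():
        # describe the file format written by serialize() (assumed structure)
        return [{"fileExtension": ".json", "description": "Toy generator state"}]

    @staticmethod
    def getSerializer():
        def serializer(state, stream):
            json.dump(state, stream)
        return serializer

    @staticmethod
    def getDeserializer():
        def deserializer(stream):
            return json.load(stream)
        return deserializer

    def serialize(self, stream):
        self.getSerializer()(self.state, stream)

    @classmethod
    def deserialize(cls, stream):
        return cls(cls.getDeserializer()(stream))


# Round trip through an in-memory stream instead of a file on disk.
buffer = io.StringIO()
ToyGeneratorAlgorithm({"vocab": ["C", "O", "N"]}).serialize(buffer)
buffer.seek(0)
restored = ToyGeneratorAlgorithm.deserialize(buffer)
```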
Using the Trained Generator
So far, we have described a possible implementation of a machine-learning-based generator using the DrugEx extension as an example, but we have not yet described how to use the trained generator to create new compound sets. You might have noticed the definition of the DrugEx class in genuidrugex.models:
class DrugEx(Generator):
    agent = models.ForeignKey(Model, on_delete=models.CASCADE, null=False, related_name="generator")

    def get(self, n_samples):
        import genui.generators.extensions.genuidrugex.genuimodels.builders as builders
        builder_class = getattr(builders, self.agent.builder.name)
        builder = builder_class(Model.objects.get(pk=self.agent.id))
        samples, valids = builder.sample(n_samples)
        return [x for idx, x in enumerate(samples) if bool(valids[idx])]
We have not discussed this Django model yet because it is not central to training and saving the DrugEx networks. However, an instance of this class is created whenever an instance of DrugExNet or DrugExAgent is saved, and it is in fact all we need to register a new generator with GenUI and generate new compounds with it.
Looking at the implementation above, we see that the definition of a generator in GenUI is quite general and straightforward. It is simply an instance of the Generator class, which should always implement the get() method. This method takes only one argument, the maximum number of compounds to generate. In the case of DrugEx, the implementation of get() above merely imports the correct DrugEx builder, which we equipped with a sample() method that uses the trained neural network to generate compounds, and then returns only the valid samples.
After implementing your Generator
Django model, you should be able to see
it as an option when creating new compound sets with the
API endpoints of the genui.compounds.extensions.generated
extension.
Note that this is really all you need to define a generator. If you do not need to train a machine learning model, your Django application can be very simple: implement this Django model, create an appropriate serializer and hook it up with a simple Django REST Framework view or viewset.
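To show how small a generator can be, the plain-Python class below follows the same shape as DrugEx.get(): it samples candidates, marks each one as valid or invalid, and returns only the valid ones. This is a hypothetical sketch. In GenUI, the class would subclass the Generator Django model, and the validity check here (balanced parentheses) is only a placeholder for a real SMILES parser such as RDKit.

```python
import random


class ToyFragmentGenerator:
    """Hypothetical generator with the same get() contract as DrugEx."""

    def __init__(self, fragments, seed=0):
        self.fragments = list(fragments)
        self.rng = random.Random(seed)

    @staticmethod
    def is_valid(smiles):
        # Placeholder validity check: parentheses must be balanced.
        depth = 0
        for ch in smiles:
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
                if depth < 0:
                    return False
        return depth == 0

    def sample(self, n_samples):
        # Join three random fragments per candidate and flag its validity.
        samples = ["".join(self.rng.choices(self.fragments, k=3)) for _ in range(n_samples)]
        valids = [self.is_valid(s) for s in samples]
        return samples, valids

    def get(self, n_samples):
        # n_samples is the maximum number of compounds returned.
        samples, valids = self.sample(n_samples)
        return [s for s, ok in zip(samples, valids) if ok]


generator = ToyFragmentGenerator(["C", "O", "N", "C(", ")"], seed=1)
compounds = generator.get(50)
```

Because invalid samples are dropped, get() may return fewer compounds than requested. This matches the description of its argument as the maximum number of compounds to generate.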
Conclusion
This was a short tour of how the genuidrugex extension was created. Hopefully, it is now clearer how you can integrate your own generator into GenUI. The genui.models package has many more interesting features than covered here, so you are encouraged to check its own documentation pages for guidance on implementing the functionality that you want.