When people first learn about Jaxon, a common question is how we are able to train a model to produce data that will train a better model. Isn’t that first model already the model we want if it can label data? Which comes first, the chicken or the egg? (I ordered one of each from Amazon; I’ll let you know…)
Let’s take a look at one of the seminal approaches to training with sparse data. A weak labeling system such as Snorkel combines heuristics via a generative model. This generative model aggregates the predictions (and abstentions) provided by the heuristics and makes a joint decision. Intuitively, it takes a vote among the heuristics and decides what the group’s prediction will be. It may often choose the majority opinion, but like the team captain in Family Feud, it is free to weigh the group’s opinions on a case-by-case basis. In Snorkel’s case, it learns how far to trust each heuristic from the pattern of their agreements and disagreements, without needing ground-truth labels.
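The voting intuition can be made concrete with a minimal Python sketch. This is not Snorkel’s API or its learned label model. The `ABSTAIN = -1` convention, the function names, and the hand-set weights are illustrative assumptions. A plain majority vote treats every heuristic equally. A weighted vote lets one trusted heuristic outvote two weaker ones.

```python
from collections import Counter

ABSTAIN = -1  # hypothetical convention: a heuristic returns -1 when it has no opinion


def majority_vote(votes):
    """Most common non-abstaining label; ties go to the smaller label. ABSTAIN if all abstain."""
    counts = Counter(v for v in votes if v != ABSTAIN)
    if not counts:
        return ABSTAIN
    top = max(counts.values())
    return min(label for label, c in counts.items() if c == top)


def weighted_vote(votes, weights):
    """Like majority_vote, but each heuristic's vote counts for its weight
    (for example, an estimate of how often that heuristic is right)."""
    scores = {}
    for vote, weight in zip(votes, weights):
        if vote != ABSTAIN:
            scores[vote] = scores.get(vote, 0.0) + weight
    if not scores:
        return ABSTAIN
    top = max(scores.values())
    return min(label for label, s in scores.items() if s == top)


# Three heuristics label one example: one trusted heuristic says 1, two weak ones say 0.
votes = [1, 0, 0]
weights = [0.9, 0.3, 0.3]  # hand-set for illustration, not learned
print(majority_vote(votes), weighted_vote(votes, weights))
```

Here the majority vote picks 0 and the weighted vote picks 1. A system like Snorkel estimates such trust levels from the data rather than taking them by hand.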
This generative model then labels a dataset, which in turn is used to train a discriminative model (typically a deep neural net outfitted with a noise-aware loss function) that serves as the final predictor. And herein lies the central theme of this post: if the generative model is producing the training labels, then why use the discriminative model at all?
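One common form of noise-aware loss is the expected loss under the generative model’s probabilistic labels. For cross-entropy, that amounts to training against soft targets rather than hard 0/1 labels. The sketch below makes several assumptions: a binary task, a single numeric feature, and logistic regression as a stand-in for the deep neural net. The function names, hyperparameters, and label values are hypothetical.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def noise_aware_loss(probs, soft_labels):
    """Mean expected binary cross-entropy. Each soft label is P(y = 1) as estimated
    by the generative model, so an uncertain label (near 0.5) pulls less hard."""
    eps = 1e-12  # guards against log(0)
    total = 0.0
    for q, p in zip(probs, soft_labels):
        total -= p * math.log(q + eps) + (1 - p) * math.log(1 - q + eps)
    return total / len(probs)


def train_logistic(xs, soft_labels, lr=0.5, epochs=500):
    """Fit P(y = 1 | x) = sigmoid(w * x + b) by full-batch gradient descent on noise_aware_loss."""
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(epochs):
        grad_w = grad_b = 0.0
        for x, p in zip(xs, soft_labels):
            err = sigmoid(w * x + b) - p  # d(loss)/d(logit) for a soft target
            grad_w += err * x
            grad_b += err
        w -= lr * grad_w / n
        b -= lr * grad_b / n
    return w, b


# Probabilistic labels of the kind a generative model might emit (illustrative values).
xs = [-2.0, -1.0, 1.0, 2.0]
soft_labels = [0.1, 0.3, 0.7, 0.9]

before = noise_aware_loss([sigmoid(0.0)] * len(xs), soft_labels)
w, b = train_logistic(xs, soft_labels)
after = noise_aware_loss([sigmoid(w * x + b) for x in xs], soft_labels)
print(f"loss {before:.3f} -> {after:.3f}, P(y=1 | x=1.5) = {sigmoid(w * 1.5 + b):.2f}")
```

The gradient keeps the familiar form of prediction minus target. The target is simply a probability instead of 0 or 1.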
One reason is compression. All of those weak labelers feeding the generative model may carry I/O overhead, processing inefficiencies, and other runtime costs. The relevant decisions can be encoded into a single, efficient neural network. However, that’s the less interesting motivation.
Generalization is a concern when an unsupervised generative model consumes only heuristics. Hand-coded heuristics are built solely from user-defined primitives and are likely to leave much of the latent feature space uncovered, so overall recall may be poor. A well-designed discriminative neural model learns a representation of the feature space in which similar inputs (even those that do not trigger any of the original heuristics) have similar representations. It can therefore predict them with some level of confidence.
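Here is a toy illustration of the coverage problem in Python. A keyword heuristic abstains on any review that lacks its keywords. A classifier trained on the heuristic’s labels can still score those reviews through other words they share with labeled examples. A bag-of-words logistic regression stands in for the neural model. It has no learned embeddings, so it captures the effect only through word co-occurrence. The data, keywords, and function names are hypothetical.

```python
import math

ABSTAIN = -1  # hypothetical convention: the heuristic has no opinion


def keyword_heuristic(text):
    """A hand-coded rule: it fires only on the exact keywords its author thought of."""
    words = text.split()
    if "great" in words:
        return 1
    if "terrible" in words:
        return 0
    return ABSTAIN


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def featurize(text, vocab):
    """Binary bag-of-words vector; words outside the training vocabulary are dropped."""
    words = set(text.split())
    return [1.0 if w in words else 0.0 for w in vocab]


def train_bow_classifier(texts, labels, epochs=200, lr=0.5):
    """Bag-of-words logistic regression trained by SGD (no bias term, for brevity)."""
    vocab = sorted({w for t in texts for w in t.split()})
    weights = [0.0] * len(vocab)
    for _ in range(epochs):
        for text, y in zip(texts, labels):
            x = featurize(text, vocab)
            err = sigmoid(sum(w * xi for w, xi in zip(weights, x))) - y
            weights = [w - lr * err * xi for w, xi in zip(weights, x)]
    return vocab, weights


def predict_proba(model, text):
    """Estimated P(positive) for a text, with no heuristic involved."""
    vocab, weights = model
    x = featurize(text, vocab)
    return sigmoid(sum(w * xi for w, xi in zip(weights, x)))


corpus = [
    "great battery great screen",
    "great service fast shipping",
    "terrible battery slow shipping",
    "terrible screen broken case",
    "fast friendly service",  # no keyword, so the heuristic abstains
    "slow broken charger",    # no keyword, so the heuristic abstains
]
votes = [keyword_heuristic(t) for t in corpus]
labeled = [(t, v) for t, v in zip(corpus, votes) if v != ABSTAIN]
model = train_bow_classifier([t for t, _ in labeled], [v for _, v in labeled])
for text in ("fast friendly service", "slow broken charger"):
    print(text, keyword_heuristic(text), round(predict_proba(model, text), 2))
```

Neither "fast friendly service" nor "slow broken charger" contains a keyword. Even so, the classifier leans positive on the first (via "fast" and "service") and negative on the second (via "slow" and "broken").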
Jaxon embraces a semi-supervised approach to sparse training, which lets us use machine learning models as weak labelers, supplemented by heuristics. Unlike the heuristics, these models can carry a rich representation, particularly when techniques such as pretraining are used.
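One way to let a model sit alongside heuristics is to give it the same contract: vote when confident, abstain otherwise. The Python sketch below assumes a nearest-centroid classifier over 2-D feature vectors, trained on four hand-labeled points. In practice, the features might come from a pretrained encoder. The `model_labeler` wrapper, the 0.9 threshold, and the data are illustrative assumptions, not Jaxon’s implementation.

```python
import math

ABSTAIN = -1  # hypothetical convention: the labeler has no opinion


def train_centroids(points, labels):
    """Stand-in for a model trained on a few labeled examples: one mean vector per class."""
    sums, counts = {}, {}
    for point, y in zip(points, labels):
        acc = sums.setdefault(y, [0.0] * len(point))
        for i, v in enumerate(point):
            acc[i] += v
        counts[y] = counts.get(y, 0) + 1
    return {y: [v / counts[y] for v in acc] for y, acc in sums.items()}


def class_probs(centroids, point):
    """Softmax over negative squared distances to each class centroid."""
    scores = {y: -sum((a - b) ** 2 for a, b in zip(c, point)) for y, c in centroids.items()}
    top = max(scores.values())
    exps = {y: math.exp(s - top) for y, s in scores.items()}  # shift for numerical stability
    total = sum(exps.values())
    return {y: e / total for y, e in exps.items()}


def model_labeler(centroids, threshold=0.9):
    """Wrap the model with a heuristic's contract: vote if confident, else abstain."""
    def label(point):
        probs = class_probs(centroids, point)
        best = max(probs, key=probs.get)
        return best if probs[best] >= threshold else ABSTAIN
    return label


points = [(0.0, 0.0), (0.0, 1.0), (4.0, 4.0), (4.0, 5.0)]
labels = [0, 0, 1, 1]
labeler = model_labeler(train_centroids(points, labels))
print(labeler((0.5, 0.5)), labeler((3.8, 4.6)), labeler((2.0, 2.5)))
```

The point (2.0, 2.5) is equidistant from both centroids, so the model reports a 50/50 split and abstains. This is exactly how a heuristic behaves on an input outside its scope.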
Further, Jaxon’s training platform can generate synthetically augmented data (more examples, not just synthetic labels) internally during training. This further enriches the learned representation and, with it, the model’s ability to generalize.
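To show what "more examples, not just synthetic labels" means, here is a generic text augmentation sketch in Python. It uses random word deletion and random word swaps, two operations from the "Easy Data Augmentation" (EDA) family. This technique is swapped in for illustration and makes no claim about which augmentations Jaxon uses. The function names and parameters are hypothetical.

```python
import random


def augment(text, rng, p_delete=0.1, n_swaps=1):
    """Return a perturbed copy of `text`: drop each word with probability p_delete,
    then swap n_swaps random pairs of the remaining words."""
    words = text.split()
    if not words:
        return text
    kept = [w for w in words if rng.random() >= p_delete]
    if not kept:  # never produce an empty example
        kept = [rng.choice(words)]
    for _ in range(n_swaps):
        if len(kept) > 1:
            i, j = rng.sample(range(len(kept)), 2)
            kept[i], kept[j] = kept[j], kept[i]
    return " ".join(kept)


def augment_dataset(examples, copies=2, seed=0):
    """Return the original (text, label) pairs plus `copies` augmented variants of each;
    every variant keeps the label of the example it came from."""
    rng = random.Random(seed)  # seeded so the expansion is reproducible
    out = list(examples)
    for text, label in examples:
        for _ in range(copies):
            out.append((augment(text, rng), label))
    return out


data = [("the battery life is great", 1), ("screen cracked on day one", 0)]
for text, label in augment_dataset(data, copies=3):
    print(label, text)
```

Each augmented copy inherits its source example’s label, so a dataset of two examples with `copies=3` grows to eight.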
Ensembling these models with heuristics bears additional fruit. Models-as-labelers provide rich representations across different modalities. Heuristics add domain knowledge that reaches beyond the examples that happen to appear in the sparsely labeled portion of the training data.
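A minimal way to ensemble heterogeneous labelers is to convert every vote into a class distribution and average the results. The Python sketch assumes three things: heuristics emit a hard class index or `ABSTAIN`, models emit a list of class probabilities, and each labeler carries a hand-set trust weight. This weighted average is a simple stand-in, not a learned label model, and all names are hypothetical. Its output is the kind of probabilistic label a noise-aware loss consumes.

```python
ABSTAIN = -1  # hypothetical convention: the labeler has no opinion


def combine(votes, weights, num_classes):
    """Merge heterogeneous votes into one probabilistic label.

    A vote is ABSTAIN, a hard class index (e.g. from a heuristic), or a list of
    class probabilities (e.g. from a model-as-labeler). Hard votes become one-hot
    distributions; the result is their weighted average over non-abstaining labelers.
    """
    total = [0.0] * num_classes
    weight_sum = 0.0
    for vote, weight in zip(votes, weights):
        if isinstance(vote, list):
            dist = vote
        elif vote == ABSTAIN:
            continue
        else:
            dist = [1.0 if k == vote else 0.0 for k in range(num_classes)]
        for k in range(num_classes):
            total[k] += weight * dist[k]
        weight_sum += weight
    if weight_sum == 0.0:
        return None  # everyone abstained: leave the example unlabeled
    return [t / weight_sum for t in total]


# A model that is 80% sure of class 1, a keyword rule voting 1, and a rule that abstains.
print(combine([[0.2, 0.8], 1, ABSTAIN], [1.0, 1.0, 1.0], num_classes=2))
```

Here the model’s soft vote and the keyword rule’s hard vote combine into roughly [0.1, 0.9]. The abstaining rule is simply left out of the average.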
Going one step beyond the generative model and training a single discriminative model consolidates the representations available to the different labelers and unifies their benefits. Augmentations, custom training schedules, confidence-awareness, and other training platform capabilities help make the result greater than the sum of its parts. You can have your egg and your chicken!
– Greg Harman, CTO