Covariate Drift & SmartSplit Technology

As we go through the process of training machine learning models and start thinking about how to prepare our training data, one of the first questions that always comes to mind is: “How am I going to split off a test set so that I can measure the models that I produce?” There’s a rule of thumb that most of us use most of the time without thinking too hard about it: take the data and randomly split off 20% (that’s the most common choice; with a big dataset we’ll sometimes split off only 10%).

One of the first questions we have to ask when we do this split: is a rule of thumb like that always the right answer? “Always” should be a bit of a hair-raising word to anybody, because there’s just about nothing that always works. On the one hand, if you make your test set too big, you’ve wasted examples that could have gone into your training set and made the model perform a little bit better. Conversely, if you have too small a test set, it’s going to be a poor measurement of how well your model will actually perform once it rolls out into the field. You simply haven’t covered all the potential cases. Any model metric you compute (accuracy, F1, and so forth) is a single measurement drawn from a distribution you can’t see. You really have to ask whether that measurement lands near the middle of that distribution (and how wide that distribution is), or whether you’ve drawn something unrepresentative of what you’re likely to see in the field. And that’s really what we worry about.
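One common way to get a look at that hidden distribution (this isn't from the webinar, just a standard statistical tool) is to bootstrap-resample the test set and recompute the metric many times:

```python
import numpy as np

def bootstrap_metric(y_true, y_pred, metric, n_boot=1000, seed=0):
    """Approximate the sampling distribution of a test metric by
    resampling the test set with replacement, instead of trusting a
    single point measurement."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    rng = np.random.default_rng(seed)
    n = len(y_true)
    scores = [
        metric(y_true[idx], y_pred[idx])
        for idx in (rng.integers(0, n, size=n) for _ in range(n_boot))
    ]
    # 95% interval and median of the metric's estimated distribution
    return np.percentile(scores, [2.5, 50.0, 97.5])
```

If the resulting interval is wide, a single accuracy or F1 number deserves much less trust than it might appear to.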

Before we even talk about how to split it (if we still assume random sampling), how much should we split off? I think the answer, if you want to be a little more certain, is to take inspiration from survey statistics, where you establish a confidence bound. You can’t have a magic number that’s always right, but you can say that you’re willing to tolerate a given measurement error within a particular confidence bound. A typical choice might be accepting a 5% error at a confidence of, say, 95% (i.e., 95% of the time, the answer will be within that error).

When you sit down with a given dataset and a given model, you first ask yourself: “How many examples do I need in my test set to be sure I meet those confidence bounds?” It turns out there’s one more dependent variable: the predicted accuracy of your model. While you won’t know that until you’ve trained and measured it, you’ll generally have at least some idea of the neighborhood it’s in. Intuitively, the more accurate your model is, the fewer examples you need. If you have a model that’s supposed to be 99.999% accurate, one wrong answer pretty much throws that out the window, whereas if your model is at 50% (a coin flip), you’re going to have to measure an awful lot of examples before you’re statistically sure it is or is not at that level of accuracy.
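To make that concrete, here's a minimal sketch using the standard normal-approximation sample-size formula from survey statistics (the formula itself isn't spelled out in the webinar, but it's the usual way to turn error tolerance, confidence, and expected accuracy into a test-set size):

```python
import math

def required_test_size(expected_acc, margin=0.05, z=1.96):
    """Normal-approximation sample size for estimating a proportion
    (here, model accuracy) to within +/- margin. z = 1.96 corresponds
    to ~95% confidence."""
    p = expected_acc
    return math.ceil(z ** 2 * p * (1 - p) / margin ** 2)

print(required_test_size(0.50))  # ~385 examples: a coin-flip model needs the most
print(required_test_size(0.90))  # ~139 examples
print(required_test_size(0.99))  # ~16 examples (in practice you'd tighten the
                                 # margin well below 5% for a model this accurate)
```

This matches the intuition above: the closer the expected accuracy is to 50/50, the more test examples you need before the measurement means anything.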

Alright so, we’ve talked about how to split off your data and how big your test set needs to be, but under- or over-sizing the split is not the only thing that can go wrong when it comes to splitting off your data and making sure your model evaluations are accurate. Other things happen: data changes. There’s all sorts of drift. There’s the notion of the environments of your data, which is a subtle way to describe the nature of your data (e.g., what are the examples like? Are they changing over time? Are they changing with context as you measure them in different systems? Are the predicted values changing?). Sometimes it’s the label set, the range of values you’re predicting, that changes. And sometimes it’s simply the distribution: you’re just naturally starting to see a few more dogs than cats. Maybe that’s a drift error. Maybe there actually are more dogs than cats in the particular pile of images you happen to be looking at. Or maybe dogs are just outperforming cats in the world because, as we all know, dogs are better. And if you have a very, very small sample, you can get bias – particularly in test sets, but that’s also true for training sets and other samples.

So one of the notions of the way that data may change over time is the idea of concept drift. I’m going to tell a story originally told by the researchers at Facebook AI Research in a paper called Invariant Risk Minimization. The key notion was that some researchers might create an image classifier to separate cows from camels. They build a great classifier, roll it out to the field, and it starts misclassifying as soon as it sees live pictures. The question becomes: why?

Well, it turned out in this particular case that the training data contained a bias nobody had thought of – and no matter what, there are going to be hidden biases; you’re not going to be able to think of them all or weed them all out manually. All the pictures of camels they had to train on were taken in a desert on a brown background, and all the pictures of cows were in pastures with a green background.

And so as soon as the classifier ran across a picture of a cow standing on dirt or a camel standing on grass, it would misclassify. Essentially, what they had actually trained was a background classifier. One of the primary ways to deal with this kind of drift is the invariant risk minimization technique. The idea is to introduce an intermediate stage into your data: instead of predicting your outputs (cow vs. camel) directly from the raw inputs, you find what’s called an invariant representation. You’re searching for a transformation of the inputs such that a predictor built on top of it still makes the predictions you want, and stays valid even as the environment changes.
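To make that concrete, here's a minimal sketch of the IRMv1 penalty from the Invariant Risk Minimization paper (Arjovsky et al., 2019), written in PyTorch; this illustrates the published technique, not Jaxon's code:

```python
import torch
import torch.nn.functional as F

def irm_penalty(logits, labels):
    """IRMv1 penalty: the squared gradient norm of the per-environment
    risk with respect to a fixed scalar 'dummy' classifier. A small
    penalty means the shared representation is already (locally) optimal
    in that environment, so it can't be leaning on environment-specific
    shortcuts like background color."""
    scale = torch.ones(1, requires_grad=True)
    loss = F.binary_cross_entropy_with_logits(logits * scale, labels)
    grad = torch.autograd.grad(loss, [scale], create_graph=True)[0]
    return (grad ** 2).sum()

# Training objective (sketch): for each environment e,
#   total_loss += risk_e + lam * irm_penalty(model(x_e), y_e)
```

The penalty is accumulated across environments (desert photos, pasture photos, and so on), pushing the model toward features that predict cow vs. camel in all of them.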

So another type of drift that we think about is the notion of probability drift. Here we’re looking at the same variables, but the distributions change. Maybe cows are getting more popular and camels less popular. Another example would be modeling housing prices in a particular town over time. You could build a very accurate and effective model for housing prices in 1950. Fast-forward 70 years to today, and that model isn’t going to help you at all. The houses may be the same – some houses in Santa Clara haven’t changed since the ’50s, which is probably true in many areas of the country – but the prices have gone up, and they haven’t gone up proportionately, because there are other latent factors like land or proximity to this and that in the neighborhood. Here, time is the primary dimension along which things drift even though the features themselves haven’t changed. The primary way to deal with this is to simply monitor and adjust. There’s no such thing as a model you can roll into production that stays as accurate over time as the day it shipped. These are living things, so to speak; they need to be continually monitored, fed with new, updated examples – perhaps samples out of the actual live data stream – and adjusted as the distributions shift over time. Eventually the models need to be updated or simply replaced with newer models trained on newer, reweighted data.
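As one simple way to implement that monitoring (a generic sketch, not a description of Jaxon's tooling), you can periodically compare the model's output distribution on recent live data against a historical reference window:

```python
from scipy import stats

def output_distribution_drifted(reference_preds, live_preds, alpha=0.01):
    """Two-sample Kolmogorov-Smirnov test comparing the model's predicted
    scores on a reference window vs. a recent live window. A p-value
    below alpha suggests the output distribution has shifted and the
    model may be due for retraining."""
    _, p_value = stats.ks_2samp(reference_preds, live_preds)
    return p_value < alpha
```

When the check fires, that's the cue to sample fresh examples from the live stream, re-label, and retrain or replace the model.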

Alright, so this takes us to covariate drift. Now we get into the part of the world that Jaxon’s SmartSplit technology is really designed to address. In the case of covariate drift, the distribution of the target variable – whether it’s a cow or a camel, or median house prices – isn’t going to change, but the distribution of what we see in the actual input data does. Think about language usage: we haven’t used ‘thou’ much since about the 17th century. Let’s say you want to figure out whether Shakespeare could have written a certain passage. Does it have the word ‘thou’ in it? Then maybe. Does it not have ‘thou’ in it? Then probably not, given that he was doing his thing back when ‘thou’ was in everyday use. This sort of covariate drift, where the distribution of the latent features is changing over time but the distribution of our output is not (and if it is, we bring in the other techniques above), is really the core of where the SmartSplit technology helps us.
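To pin down the distinction in the usual notation (this framing isn’t from the webinar): the probability drift above is the target distribution P(y) shifting (sometimes called prior probability or label shift), while covariate shift means the input distribution P(x) changes while the input-to-label relationship P(y | x) stays fixed. A test set split off before the shift therefore measures the model against the wrong P(x), even though the task itself hasn’t changed.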

So SmartSplit at a high level has three steps. Just to frame this, our stated task is that we have a dataset we want to divide up by some proportion. Hopefully we’ve used survey statistics to make a wise decision about what percentage goes into the training set and what percentage goes into the testing set (and if we’re doing a third validation set, the same holds). But here we want to figure out which specific examples from the overall dataset should go into the training and test sets, and whether we can do better than random sampling. The reason random sampling fails is that not all examples are created equal. Some will be duplicates. Some will expose more than one latent feature the system might be able to find signal in. And some aren’t that useful at all. We want to be very choosy about exactly which examples go into which bucket, so that the buckets are as aligned as absolutely possible – even along the dimensions that we, as data scientists and analysts eyeballing the data, haven’t been able to see.

So the three high-level steps of SmartSplit are:

  • First, create a representation. This could be something inspired by the invariant risk minimization that we talked about, but it doesn’t have to be. I don’t want to conflate these two areas of research and technology, although there is an opportunity to marry them up.
  • Once we’ve projected our dataset into this representation space, we create a strategic clustering over it, further subdividing the examples into a lot of very deliberately crafted clusters, each of which gets at a particular topical area or latent-feature subarea.
  • And once we’ve done that, when we do our split we sample from those clusters, and we sample them such that everything comes back out with the expected distribution. For each cluster, if we said we wanted an 80/20 split, we take four examples per cluster for the training set, put the fifth into the test set, and keep doing that. (A sketch of this cluster-then-sample idea follows this list.)
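Here's a minimal sketch of that cluster-then-sample idea under simple assumptions (k-means over some representation space, the same split ratio applied to every cluster); this illustrates the concept, not Jaxon's actual implementation:

```python
import numpy as np
from sklearn.cluster import KMeans

def cluster_stratified_split(X, n_clusters=50, test_frac=0.2, seed=0):
    """Cluster the representation space, then split each cluster at the
    target ratio, so that train and test cover the same latent regions
    in the same proportions."""
    rng = np.random.default_rng(seed)
    labels = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10).fit_predict(X)
    train_idx, test_idx = [], []
    for c in np.unique(labels):
        members = np.flatnonzero(labels == c)
        rng.shuffle(members)
        n_test = max(1, round(test_frac * len(members)))
        test_idx.extend(members[:n_test])
        train_idx.extend(members[n_test:])
    return np.array(train_idx), np.array(test_idx)
```

Because every cluster contributes to both sides at roughly the target ratio, the test set ends up seeing the same mix of latent regions as the training set, instead of whatever mix a random draw happens to produce.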

So that’s how SmartSplit works at a high level. Now let’s dig in a little deeper and see the results. There are two primary benefits of using SmartSplit:

  • The first one, and the reason we started looking into this, is test variance. When you take one test set split off from the training data to evaluate a model, you’re getting a single measurement. There’s actually a distribution over what your model’s accuracy would measure out to across all the documents in the wild it might have to predict on in production, and one test set gives you one observation from that distribution. So if we can reduce the variance of that distribution, the measurement we make is more likely to be close to the model’s true performance.
  • The second is: how can we divide things up such that the measurement we make on a model is more likely to reflect what we’ll see in production on a brand-new example, rather than the traditional disappointment where what looks great in the lab immediately falls off in performance when it hits the production realm?

So we had some very encouraging success with this. I have two graphs below from a couple of different benchmarks we ran. Each represents a different dataset and modeling problem, and for each we ran the same problem using 100 different splits – random splits on the left side and SmartSplits on the right side – so we could get a proper distribution of how the same model performs as we give it different permutations of the data used for training and testing. You’ll notice that in both cases the SmartSplit distribution on the right is narrower than the random-split distribution on the left, and empirically, when we add up the numbers, we seem to be getting about a 3x reduction in variance when we use SmartSplit. That seems to hold up fairly well across datasets, and we consider that to be generally a win.
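For reference, here's a sketch of that benchmark protocol, with hypothetical stand-ins model_factory and splitter for whatever model and splitting strategy you're comparing:

```python
import numpy as np

def score_spread(model_factory, splitter, X, y, n_trials=100):
    """Re-run train/test over many different splits and report the spread
    of test scores. A narrower spread means any single split yields a
    more trustworthy measurement of the model."""
    scores = []
    for seed in range(n_trials):
        train_idx, test_idx = splitter(X, seed=seed)
        model = model_factory().fit(X[train_idx], y[train_idx])
        scores.append(model.score(X[test_idx], y[test_idx]))
    return np.mean(scores), np.var(scores)

# Compare the variance under plain random splits vs. a cluster-stratified
# splitter like the cluster_stratified_split sketch above.
```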

But wait, there’s more.

So it turns out the models are better, too. If you look at the image below, not only is the distribution band much smaller for SmartSplit on the right compared to the random split on the left, but the mean/median output of the model (represented by the red lines) is also higher. So not only are we getting models whose measurements are more robust; the models themselves are better. We’re seeing an average 15% reduction in model error, which we hypothesize has to do with the model seeing accurate proportions of the different types of latent information – the latent feature spaces – available in the training data. And if that weren’t enough, things seem to converge a little faster as well, so our models are achieving convergence faster.

– Greg Harman, CTO (from Jaxon’s SmartSplit Webinar)