Iterating over data sets of different sizes (e.g. for semi-supervised learning)
There are some instances where we wish to draw samples from two data sets of different sizes
simultaneously. One example is semi-supervised learning, where we have a small data set of labeled
samples lab_X with ground truths lab_y and a larger set of unlabeled samples unlab_X. Let's say we
want a single epoch to consist of the entire unlabeled data set, while looping over the labeled data
set repeatedly. The data_source.CompositeDataSource class can help us here.
Without using CompositeDataSource we would need the following:
import numpy as np
from batchup import data_source

rng = np.random.RandomState(12345)

# Construct the data sources; the labeled data source will
# repeat infinitely
ds_lab = data_source.ArrayDataSource([lab_X, lab_y], repeats=-1)
ds_unlab = data_source.ArrayDataSource([unlab_X])

# Construct an iterator to get samples from our labeled data source:
lab_iter = ds_lab.batch_iterator(batch_size=64, shuffle=rng)

# Iterate over the unlabeled data set in the for-loop
for (batch_unlab_X,) in ds_unlab.batch_iterator(
        batch_size=64, shuffle=rng):
    # Extract batches from the labeled iterator ourselves
    # (we could also use `zip`)
    batch_lab_X, batch_lab_y = next(lab_iter)

    # Process batches here...
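The `zip` variant mentioned in the comment above can be sketched without the library. The example below is a minimal illustration using only the Python standard library. The helpers `batches` and `repeating_batches` and the toy lists are hypothetical stand-ins for the data sources, not part of the library's API. Unlike a library data source, this simplified repeating generator may emit a short batch at the end of each pass over the labeled data.

```python
import random


def batches(samples, batch_size, rng):
    """Yield shuffled mini-batches covering `samples` exactly once."""
    order = list(range(len(samples)))
    rng.shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [samples[i] for i in order[start:start + batch_size]]


def repeating_batches(samples, batch_size, rng):
    """Yield shuffled mini-batches forever, reshuffling on every pass."""
    while True:
        yield from batches(samples, batch_size, rng)


rng = random.Random(12345)
labeled = list(range(10))          # stand-in for the small labeled set
unlabeled = list(range(100, 150))  # stand-in for the larger unlabeled set

lab_iter = repeating_batches(labeled, 4, rng)
unlab_iter = batches(unlabeled, 4, rng)

# The finite iterator goes first so that `zip` stops as soon as it is
# exhausted, without pulling an extra batch from the infinite one.
epoch = list(zip(unlab_iter, lab_iter))

for batch_unlab, batch_lab in epoch:
    pass  # Process batches here...
```

The argument order passed to `zip` matters. With the infinite iterator first, `zip` would fetch one more labeled batch before noticing that the unlabeled iterator has run out. That is harmless here, but it advances the labeled iterator by one batch per epoch.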
We can use CompositeDataSource to simplify the above code. It will draw samples from both ds_lab
and ds_unlab and will shuffle the samples from these data sources independently:
# Construct the data sources; the labeled data source will
# repeat infinitely
ds_lab = data_source.ArrayDataSource([lab_X, lab_y], repeats=-1)
ds_unlab = data_source.ArrayDataSource([unlab_X])

# Combine with a `CompositeDataSource`
ds = data_source.CompositeDataSource([ds_lab, ds_unlab])

# Iterate over both the labeled and unlabeled samples:
for (batch_lab_X, batch_lab_y, batch_unlab_X) in ds.batch_iterator(
        batch_size=64, shuffle=rng):
    # Process batches here...
    pass
You can also have CompositeDataSource generate structured mini-batches that reflect the
structure of the component data sources. The batches will have a nested tuple structure:
# Disable flattening with `flatten=False`
ds_struct = data_source.CompositeDataSource(
    [ds_lab, ds_unlab], flatten=False)

# Iterate over both the labeled and unlabeled samples.
for ((batch_lab_X, batch_lab_y), (batch_unlab_X,)) in \
        ds_struct.batch_iterator(batch_size=64, shuffle=rng):
    # Process batches here...
    pass
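To make the difference between the two batch layouts concrete, here is a small sketch of how a nested batch relates to its flattened form. The helper `flatten_batch` is hypothetical and written only for illustration. It is not the library's implementation, and plain lists stand in for the mini-batch arrays.

```python
def flatten_batch(batch):
    """Recursively flatten a nested tuple of batch arrays into one flat tuple."""
    flat = []
    for item in batch:
        if isinstance(item, tuple):
            flat.extend(flatten_batch(item))
        else:
            flat.append(item)
    return tuple(flat)


# Stand-ins for mini-batch arrays (plain lists keep the sketch dependency-free)
batch_lab_X, batch_lab_y, batch_unlab_X = [1.0, 2.0], [0, 1], [3.0, 4.0]

# Structured layout: one inner tuple per component data source
structured = ((batch_lab_X, batch_lab_y), (batch_unlab_X,))

# Flat layout: all arrays in a single tuple, in component order
flat = flatten_batch(structured)
```

The structured form keeps each component's arrays grouped together, which can help when each group is passed on as a unit. The flat form is more convenient for unpacking directly in the `for` statement.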