The goal of this blog post is to show how to use the TensorFlow API to create a multi-label logistic classification model that takes multiple inputs. Since we use only a sample dataset, the focus is not on the results but on the API itself. This post builds on a previous blog post that shows how to create a TensorFlow Dataset for the YouTube 8M video-level dataset.

Requirements

This code works with TensorFlow 2.6.0.

import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
print(tf.__version__)
2.6.0

Load parsed dataset

The parsed dataset created in the previous blog post was saved using tf.data.experimental.save:

tf.data.experimental.save(parsed_dataset, os.path.join(data_folder, "dataset"))

Load the parsed dataset:

parsed_dataset = tf.data.experimental.load(os.path.join(os.environ["DATA_FOLDER"], "dataset"))
for parsed_record in parsed_dataset.take(1):
  print(repr(parsed_record))
{'id': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'eACj'], dtype=object)>, 'labels': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([180, 304])>, 'mean_rgb': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([ 0.34214902,  1.0072957 , -0.28980112, ..., -0.38452676,
        0.07256398, -0.9404775 ], dtype=float32)>, 'mean_audio': <tf.Tensor: shape=(128,), dtype=float32, numpy=
array([-1.5312055 , -1.0285152 ,  0.15257615, -1.3953794 , -0.5539142 ,
        1.066028  , -1.8354464 ,  0.3552817 , -0.7087098 ,  0.95269775,
       -0.35108703, -1.0913819 , -0.43328798, -0.13257357,  0.9500226 ,
        1.6974918 ,  1.8891319 , -0.3803924 , -1.9713941 ,  1.7584128 ,
       -0.551239  ,  0.13044512, -0.04392789, -1.3871107 , -1.3588997 ,
       -0.08746034,  0.98711026,  0.00665731, -0.3661653 , -0.92649364,
        0.11269166,  1.5400211 ,  0.5915486 , -1.6733549 , -0.5325128 ,
       -0.9271016 , -1.7089834 ,  0.76628643, -1.054659  ,  0.4481834 ,
       -0.21100494,  0.12168999, -0.22766402, -1.0156257 , -1.2115217 ,
        0.42374197,  0.5706336 ,  0.06538964,  0.33071873, -0.04344149,
        0.15525132, -1.0446879 , -0.78811395, -0.4171153 , -0.52485204,
        0.4324971 , -0.6081474 ,  0.45110175, -0.13913992, -0.4041042 ,
        0.2465722 ,  0.34263542, -1.3624262 , -0.04867025,  0.42751154,
        0.3208692 ,  0.12728354, -1.0325279 ,  1.2633833 , -0.4146833 ,
       -1.2371792 ,  0.17993592, -1.0372703 , -0.7702389 , -1.0303392 ,
       -0.83468634,  1.5831887 , -0.43401757,  0.1333635 ,  1.0003645 ,
        0.72445637,  0.50229496, -0.67599964,  0.96339846, -0.14984064,
        1.27834   ,  1.491503  , -0.4544462 , -0.04380629, -0.7911539 ,
       -0.16260853, -0.12308885,  0.5622433 , -0.909713  ,  0.7098645 ,
       -1.4420735 ,  0.30895248, -1.7302632 , -0.14376068,  1.0689464 ,
       -0.13062799,  0.5123877 ,  0.6601305 ,  1.0648121 ,  0.99878377,
        0.43930665, -0.05961416, -0.0680045 , -0.37261006, -1.1954707 ,
       -0.9128745 ,  0.6335003 , -1.6354159 ,  0.00629251,  1.1883566 ,
        1.6427722 ,  0.48028553,  0.5267363 , -0.80732656, -0.8823532 ,
       -1.0776412 ,  1.2457514 ,  0.06478164,  0.05408093, -0.54844224,
        0.1208388 , -0.78628993,  0.5823071 ], dtype=float32)>}
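In TensorFlow 2.6, tf.data.experimental.load reads the element spec from the snapshot metadata, so it does not need to be passed explicitly (some earlier versions required it). If you want to pin the spec down, here is a sketch, assuming the labels tensor has a variable length per record:

element_spec = {
    "id": tf.TensorSpec(shape=(1,), dtype=tf.string),
    "labels": tf.TensorSpec(shape=(None,), dtype=tf.int64),  # assumes a variable number of labels per video
    "mean_rgb": tf.TensorSpec(shape=(1024,), dtype=tf.float32),
    "mean_audio": tf.TensorSpec(shape=(128,), dtype=tf.float32),
}
parsed_dataset = tf.data.experimental.load(
    os.path.join(os.environ["DATA_FOLDER"], "dataset"), element_spec=element_spec
)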

Logistic regression

We now build a multi-label logistic classification model that takes the image (mean_rgb) and audio (mean_audio) vectors as inputs.

Number of classes

According to the YouTube 8M video-level dataset description, there are 3862 classes. We can check that our sample data has at most 3862 distinct labels, which is a good opportunity to use the tf.data.Dataset.reduce method.
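As a quick refresher, reduce folds a function over the whole dataset, starting from an initial state. For example, summing a range of integers:

# Sum the elements 0..4 -> 10
tf.data.Dataset.range(5).reduce(np.int64(0), lambda state, value: state + value)
# <tf.Tensor: shape=(), dtype=int64, numpy=10>

The reduction below uses the same mechanism, with the running state being the set of unique label values seen so far: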

def tf_reduce_unique_values(old, new):
    # Append this record's labels to the accumulated values and deduplicate
    concat_tensor = tf.concat([old, new["labels"]], axis=0)
    y, _ = tf.unique(concat_tensor)
    return y

unique_labels = tf.sort(parsed_dataset.reduce(
    np.array([], dtype=np.int64), tf_reduce_unique_values
))
unique_labels
<tf.Tensor: shape=(3687,), dtype=int64, numpy=array([   0,    1,    2, ..., 3859, 3860, 3861])>
assert unique_labels[-1] <= 3861 # The dataset has a total of 3862 classes

We can then define the number of classes to be 3862:

number_classes = 3862

Define the model

Use the Keras functional API to define a model with multiple inputs:

mean_rgb = keras.Input(name="mean_rgb", shape=(1024,))
mean_audio = keras.Input(name="mean_audio", shape=(128,))
x = keras.layers.concatenate([mean_rgb, mean_audio])
x = keras.layers.Dense(activation="sigmoid", units=number_classes)(x)
model = keras.Model(inputs=[mean_rgb, mean_audio], outputs=[x])
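A quick way to sanity-check the architecture is model.summary(): the single Dense layer should hold (1024 + 128) * 3862 weights plus 3862 biases, i.e. 4,452,886 trainable parameters.

model.summary()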

Compile the model

Since each video can belong to more than one class, we need to build a multi-label classification model. We therefore use the binary crossentropy loss function and the binary accuracy metric, for reasons discussed in this blog post.

model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
    loss=keras.losses.BinaryCrossentropy(),
    metrics=[keras.metrics.BinaryAccuracy()],
)
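Because the loss is applied element-wise, each of the 3862 outputs is treated as an independent binary problem. A toy example (hypothetical values) illustrates this:

# Mean of the per-class binary crossentropies; roughly 0.164 for these values
y_true = [[0.0, 1.0, 1.0, 0.0]]
y_pred = [[0.1, 0.9, 0.8, 0.2]]
print(keras.losses.BinaryCrossentropy()(y_true, y_pred).numpy())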

Prepare the dataset

The Keras training API accepts a tf.data.Dataset as input, but it expects a tuple (features, labels). We therefore need to preprocess our parsed_dataset to turn it into a train_dataset with the appropriate output format. We also need to transform the labels from a list of integers into a multi-hot encoding, as described in this blog post.
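To make the multi-hot encoding concrete, here is the same transform on a toy example with a depth of 5 instead of 3862:

# Labels [1, 3] with 5 classes -> [0., 1., 0., 1., 0.]
tf.reduce_max(tf.one_hot(indices=[1, 3], depth=5), axis=0)

The function below applies the same transform at full scale: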

def training_preprocessing(data, number_classes):
    features = {"mean_rgb": data["mean_rgb"], "mean_audio": data["mean_audio"]}
    # Turn the list of label indices into a single multi-hot vector
    one_hot = tf.one_hot(indices=data["labels"], depth=number_classes)
    label = tf.reduce_max(one_hot, axis=0)
    return (features, label)
train_dataset = parsed_dataset.map(lambda x: training_preprocessing(x, number_classes=number_classes))
for data in train_dataset.take(1):
  print(repr(data))
({'mean_rgb': <tf.Tensor: shape=(1024,), dtype=float32, numpy=
array([ 0.34214902,  1.0072957 , -0.28980112, ..., -0.38452676,
        0.07256398, -0.9404775 ], dtype=float32)>, 'mean_audio': <tf.Tensor: shape=(128,), dtype=float32, numpy=
array([-1.5312055 , -1.0285152 ,  0.15257615, -1.3953794 , -0.5539142 ,
        1.066028  , -1.8354464 ,  0.3552817 , -0.7087098 ,  0.95269775,
       -0.35108703, -1.0913819 , -0.43328798, -0.13257357,  0.9500226 ,
        1.6974918 ,  1.8891319 , -0.3803924 , -1.9713941 ,  1.7584128 ,
       -0.551239  ,  0.13044512, -0.04392789, -1.3871107 , -1.3588997 ,
       -0.08746034,  0.98711026,  0.00665731, -0.3661653 , -0.92649364,
        0.11269166,  1.5400211 ,  0.5915486 , -1.6733549 , -0.5325128 ,
       -0.9271016 , -1.7089834 ,  0.76628643, -1.054659  ,  0.4481834 ,
       -0.21100494,  0.12168999, -0.22766402, -1.0156257 , -1.2115217 ,
        0.42374197,  0.5706336 ,  0.06538964,  0.33071873, -0.04344149,
        0.15525132, -1.0446879 , -0.78811395, -0.4171153 , -0.52485204,
        0.4324971 , -0.6081474 ,  0.45110175, -0.13913992, -0.4041042 ,
        0.2465722 ,  0.34263542, -1.3624262 , -0.04867025,  0.42751154,
        0.3208692 ,  0.12728354, -1.0325279 ,  1.2633833 , -0.4146833 ,
       -1.2371792 ,  0.17993592, -1.0372703 , -0.7702389 , -1.0303392 ,
       -0.83468634,  1.5831887 , -0.43401757,  0.1333635 ,  1.0003645 ,
        0.72445637,  0.50229496, -0.67599964,  0.96339846, -0.14984064,
        1.27834   ,  1.491503  , -0.4544462 , -0.04380629, -0.7911539 ,
       -0.16260853, -0.12308885,  0.5622433 , -0.909713  ,  0.7098645 ,
       -1.4420735 ,  0.30895248, -1.7302632 , -0.14376068,  1.0689464 ,
       -0.13062799,  0.5123877 ,  0.6601305 ,  1.0648121 ,  0.99878377,
        0.43930665, -0.05961416, -0.0680045 , -0.37261006, -1.1954707 ,
       -0.9128745 ,  0.6335003 , -1.6354159 ,  0.00629251,  1.1883566 ,
        1.6427722 ,  0.48028553,  0.5267363 , -0.80732656, -0.8823532 ,
       -1.0776412 ,  1.2457514 ,  0.06478164,  0.05408093, -0.54844224,
        0.1208388 , -0.78628993,  0.5823071 ], dtype=float32)>}, <tf.Tensor: shape=(3862,), dtype=float32, numpy=array([0., 0., 0., ..., 0., 0., 0.], dtype=float32)>)
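As a sanity check, we can recover the active class indices from the multi-hot label of the first example; they should match the original labels field (here [180, 304]):

for features, label in train_dataset.take(1):
    # Indices where the multi-hot vector is 1 -> the original label list
    print(tf.where(label == 1.0)[:, 0].numpy())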

Fit the model

We can now call the fit method on the train_dataset created above.
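In a more realistic run, we would also shuffle and prefetch the batched dataset; a minimal sketch (the buffer size is an arbitrary choice):

train_dataset_batched = (
    train_dataset
    .shuffle(buffer_size=1024)   # arbitrary; tune to available memory
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)  # overlap input preprocessing with training
)

For brevity, the run below simply batches the dataset: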

model.fit(train_dataset.batch(32), epochs=3)
Epoch 1/3
1294/1294 [==============================] - 55s 42ms/step - loss: 0.4018 - binary_accuracy: 0.8924
Epoch 2/3
1294/1294 [==============================] - 55s 43ms/step - loss: 0.1287 - binary_accuracy: 0.9953
Epoch 3/3
1294/1294 [==============================] - 54s 42ms/step - loss: 0.0399 - binary_accuracy: 0.9994
<keras.callbacks.History at 0x17eb4c640>
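Once trained, the model outputs one sigmoid score per class. A hypothetical snippet that picks the 5 highest-scoring classes for the first video (the value of k is arbitrary):

for features, _ in train_dataset.batch(1).take(1):
    scores = model.predict(features)       # shape (1, 3862), independent sigmoid scores
    top = tf.math.top_k(scores[0], k=5)
    print(top.indices.numpy(), top.values.numpy())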