Create TensorFlow Dataset from TFRecord files
Prototyping with YouTube 8M video-level features
Play around with the YouTube 8M video-level dataset. The goal of this section is to create a tf.data.Dataset from a set of .tfrecord files.
This code works with TensorFlow 2.6.0.
import tensorflow as tf
print(tf.__version__)
The sample data were downloaded with
curl data.yt8m.org/download.py | shard=1,100 partition=2/video/train mirror=us python
following the instructions on the YouTube 8M dataset download page (shard=1,100 fetches 1/100th of the video-level training partition).
Import libraries and specify data_folder.
import os
import glob

# Folder holding the downloaded video-level .tfrecord shards.
data_folder = "/home/default/video"
List the .tfrecord files to be loaded.
# Sort so the file order is reproducible (glob returns files in arbitrary order).
filenames = sorted(glob.glob(os.path.join(data_folder, "*.tfrecord")))
print(filenames[0])
print(filenames[-1])
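As an alternative to globbing in Python, tf.data can list the shards itself and read several of them in parallel with interleave. The sketch below is self-contained and rests on a few assumptions. It writes three tiny hypothetical shards to a temporary folder that stands in for data_folder. The shards are named trainNNNN.tfrecord and hold plain byte strings rather than real tf.train.Example protos. The cycle_length value is an arbitrary choice for illustration.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical shard layout standing in for the real data_folder:
# three small files, each holding two dummy serialized strings.
data_folder = tempfile.mkdtemp()
for shard in range(3):
    path = os.path.join(data_folder, f"train{shard:04d}.tfrecord")
    with tf.io.TFRecordWriter(path) as writer:
        for i in range(2):
            writer.write(f"shard{shard}-record{i}".encode())

# list_files expands the glob pattern inside tf.data; shuffle=False makes
# the file order deterministic for reproducible runs.
files = tf.data.Dataset.list_files(
    os.path.join(data_folder, "*.tfrecord"), shuffle=False)

# interleave keeps cycle_length files open at once and alternates between them.
raw_dataset = files.interleave(
    lambda filename: tf.data.TFRecordDataset(filename),
    cycle_length=2,
    num_parallel_calls=tf.data.AUTOTUNE,
    deterministic=True,
)

records = [r.numpy() for r in raw_dataset]
print(len(records))  # 6
```

With the default shuffle=True, list_files returns the file names in a new random order on each pass over the dataset, which is usually what you want for training.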
Load the .tfrecord files into a raw (not yet parsed) dataset.
raw_dataset = tf.data.TFRecordDataset(filenames)
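Each element of the raw dataset is a scalar tf.string tensor holding one serialized record. Before writing a parser, it can help to decode a single record back into a tf.train.Example and check which feature keys it carries. The sketch below is self-contained. It writes one hypothetical record (made-up id and labels, only two of the four features) to a temporary file instead of reading the real YouTube 8M shards.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical stand-in for a real shard: one record with two features.
path = os.path.join(tempfile.mkdtemp(), "sample.tfrecord")
example = tf.train.Example(features=tf.train.Features(feature={
    "id": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"abcd"])),
    "labels": tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 522])),
}))
with tf.io.TFRecordWriter(path) as writer:
    writer.write(example.SerializeToString())

raw_dataset = tf.data.TFRecordDataset([path])
for raw_record in raw_dataset.take(1):
    # raw_record is a scalar tf.string tensor; decode its bytes into a proto.
    decoded = tf.train.Example.FromString(raw_record.numpy())
    feature_keys = sorted(decoded.features.feature.keys())
    print(feature_keys)  # ['id', 'labels']
```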
Create a function to parse the raw data. According to the download section of the YouTube 8M dataset page, the video-level data are stored as tf.train.Example protocol buffers with the following text format:
features: {
  feature: {
    key  : "id"
    value: {
      bytes_list: {
        value: (Video id)
      }
    }
  }
  feature: {
    key  : "labels"
    value: {
      int64_list: {
        value: [1, 522, 11, 172]  # label list
      }
    }
  }
  feature: {
    # Average of all 'rgb' features for the video
    key  : "mean_rgb"
    value: {
      float_list: {
        value: [1024 float features]
      }
    }
  }
  feature: {
    # Average of all 'audio' features for the video
    key  : "mean_audio"
    value: {
      float_list: {
        value: [128 float features]
      }
    }
  }
}
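To test a parser without downloading any shards, you can build records that follow the layout above. The sketch below is self-contained. The helper make_video_example, the video ids and the random embedding values are all hypothetical. Only the feature keys, types and sizes (1024 rgb floats, 128 audio floats) come from the format description.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def make_video_example(video_id, labels, mean_rgb, mean_audio):
    """Hypothetical helper: build a tf.train.Example with the video-level layout."""
    feature = {
        "id": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[video_id.encode("utf-8")])),
        "labels": tf.train.Feature(
            int64_list=tf.train.Int64List(value=[int(x) for x in labels])),
        "mean_rgb": tf.train.Feature(
            float_list=tf.train.FloatList(value=[float(x) for x in mean_rgb])),
        "mean_audio": tf.train.Feature(
            float_list=tf.train.FloatList(value=[float(x) for x in mean_audio])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


# Write three synthetic records with random embeddings to a temporary file.
rng = np.random.default_rng(0)
path = os.path.join(tempfile.mkdtemp(), "synthetic.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    for i, labels in enumerate([[1, 522, 11, 172], [3], [7, 8]]):
        example = make_video_example(
            video_id=f"vid{i}",
            labels=labels,
            mean_rgb=rng.random(1024, dtype=np.float32),
            mean_audio=rng.random(128, dtype=np.float32),
        )
        writer.write(example.SerializeToString())

# Round-trip check: count the records and decode the first one.
num_records = sum(1 for _ in tf.data.TFRecordDataset([path]))
first = next(iter(tf.data.TFRecordDataset([path])))
decoded = tf.train.Example.FromString(first.numpy())
rgb_len = len(decoded.features.feature["mean_rgb"].float_list.value)
print(num_records, rgb_len)  # 3 1024
```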
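# Create a description of the features.
feature_description = {
    # Video id: one byte string per record.
    'id': tf.io.FixedLenFeature([1], tf.string, default_value=''),
    # Variable-length label list; allow_missing=True lets a record without
    # labels parse to an empty list, and 0 pads shorter lists in batched parsing.
    'labels': tf.io.FixedLenSequenceFeature([], tf.int64, default_value=0, allow_missing=True),
    # Fixed-size video-level embeddings.
    'mean_audio': tf.io.FixedLenFeature([128], tf.float32, default_value=[0.0] * 128),
    'mean_rgb': tf.io.FixedLenFeature([1024], tf.float32, default_value=[0.0] * 1024),
}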
def _parse_function(example_proto):
    # Parse the input `tf.train.Example` proto using the dictionary above.
    return tf.io.parse_single_example(example_proto, feature_description)
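tf.io.parse_single_example parses one record per call. A variant is to batch the serialized strings first and then parse a whole batch with tf.io.parse_example, which usually reduces per-record overhead. Within a batch, FixedLenSequenceFeature pads shorter label lists with its default_value (0 here). The sketch below is self-contained. The helper write_synthetic_records and its dummy records are hypothetical, and the batch size of 2 is chosen only to make the padding visible.

```python
import os
import tempfile

import tensorflow as tf

feature_description = {
    "id": tf.io.FixedLenFeature([1], tf.string, default_value=""),
    "labels": tf.io.FixedLenSequenceFeature(
        [], tf.int64, default_value=0, allow_missing=True),
    "mean_audio": tf.io.FixedLenFeature([128], tf.float32, default_value=[0.0] * 128),
    "mean_rgb": tf.io.FixedLenFeature([1024], tf.float32, default_value=[0.0] * 1024),
}


def write_synthetic_records(path, label_lists):
    """Hypothetical helper: dummy embeddings plus the given label lists."""
    with tf.io.TFRecordWriter(path) as writer:
        for i, labels in enumerate(label_lists):
            feature = {
                "id": tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[f"vid{i}".encode()])),
                "labels": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=labels)),
                "mean_audio": tf.train.Feature(
                    float_list=tf.train.FloatList(value=[0.5] * 128)),
                "mean_rgb": tf.train.Feature(
                    float_list=tf.train.FloatList(value=[0.5] * 1024)),
            }
            proto = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(proto.SerializeToString())


path = os.path.join(tempfile.mkdtemp(), "synthetic.tfrecord")
write_synthetic_records(path, [[1, 522, 11, 172], [3], [7, 8]])

# Batch the serialized strings first, then parse a whole batch per call.
batched_dataset = (
    tf.data.TFRecordDataset([path])
    .batch(2)
    .map(lambda batch: tf.io.parse_example(batch, feature_description),
         num_parallel_calls=tf.data.AUTOTUNE)
)

first_batch = next(iter(batched_dataset))
print(first_batch["labels"].numpy().tolist())  # [[1, 522, 11, 172], [3, 0, 0, 0]]
```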
parsed_dataset = raw_dataset.map(_parse_function)
parsed_dataset
for parsed_record in parsed_dataset.take(1):
    print(repr(parsed_record))
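Each parsed record is a dictionary of tensors. A common next step before training is to map each one to a (features, labels) pair. The sketch below is self-contained and rests on several assumptions:

- NUM_CLASSES = 3862 is taken as the label vocabulary size of the 2018 YouTube 8M release. Check it against the vocabulary of the version you downloaded.
- Concatenating mean_rgb and mean_audio into one 1152-dimensional input is one simple design choice, not the only one.
- The synthetic records and the shuffle/batch sizes are hypothetical.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: vocabulary size of the 2018 YouTube 8M release.
NUM_CLASSES = 3862

feature_description = {
    "id": tf.io.FixedLenFeature([1], tf.string, default_value=""),
    "labels": tf.io.FixedLenSequenceFeature(
        [], tf.int64, default_value=0, allow_missing=True),
    "mean_audio": tf.io.FixedLenFeature([128], tf.float32, default_value=[0.0] * 128),
    "mean_rgb": tf.io.FixedLenFeature([1024], tf.float32, default_value=[0.0] * 1024),
}


def _to_training_pair(parsed):
    # Concatenate the two video-level embeddings into one 1152-dim vector.
    features = tf.concat([parsed["mean_rgb"], parsed["mean_audio"]], axis=0)
    # Multi-hot encode the variable-length label list: one_hot gives
    # [num_labels, NUM_CLASSES]; summing over axis 0 and clipping at 1 yields
    # [NUM_CLASSES] (and all zeros if the list is empty).
    one_hot = tf.one_hot(parsed["labels"], NUM_CLASSES, dtype=tf.float32)
    labels = tf.minimum(tf.reduce_sum(one_hot, axis=0), 1.0)
    return features, labels


# Hypothetical synthetic records standing in for the real shards.
path = os.path.join(tempfile.mkdtemp(), "synthetic.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    for i, labels in enumerate([[1, 522, 11, 172], [3], [7, 8]]):
        feature = {
            "id": tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[f"vid{i}".encode()])),
            "labels": tf.train.Feature(int64_list=tf.train.Int64List(value=labels)),
            "mean_audio": tf.train.Feature(
                float_list=tf.train.FloatList(value=[0.5] * 128)),
            "mean_rgb": tf.train.Feature(
                float_list=tf.train.FloatList(value=[0.5] * 1024)),
        }
        proto = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(proto.SerializeToString())

dataset = (
    tf.data.TFRecordDataset([path])
    .map(lambda proto: tf.io.parse_single_example(proto, feature_description))
    .map(_to_training_pair)
    .shuffle(buffer_size=1024)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

x, y = next(iter(dataset))
print(x.shape, y.shape)
```

The resulting dataset yields batches that can be passed directly to a Keras model's fit method. A multi-label setup like this one is typically trained with a sigmoid output layer and binary cross-entropy.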