def download(directory, filename):
    """Download a file from the MNIST dataset if not already done.

    The remote `<filename>.gz` archive is fetched to a temporary file and then
    copied as-is (still gzip-compressed) to `directory/filename`.

    Args:
      directory: Destination directory; created if it does not exist.
      filename: Name of the MNIST file, without the `.gz` suffix.

    Returns:
      The path `directory/filename`.
    """
    filepath = os.path.join(directory, filename)
    if tf.gfile.Exists(filepath):
        return filepath
    if not tf.gfile.Exists(directory):
        tf.gfile.MakeDirs(directory)
    # CVDF mirror of http://yann.lecun.com/exdb/mnist/
    url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
    fd, zipped_filepath = tempfile.mkstemp(suffix='.gz')
    # mkstemp hands back an already-open OS-level descriptor; only the path is
    # needed here, so close it right away instead of leaking it.
    os.close(fd)
    try:
        print('Downloading %s to %s' % (url, zipped_filepath))
        urllib.request.urlretrieve(url, zipped_filepath)
        tf.gfile.Copy(zipped_filepath, filepath)
    finally:
        # Remove the temporary file even when the download or the copy fails.
        os.remove(zipped_filepath)
    return filepath
def _read32(bytestream):
    """Read the next four bytes of `bytestream` as one big-endian uint32."""
    raw = bytestream.read(4)
    big_endian_u32 = numpy.dtype('>u4')
    return numpy.frombuffer(raw, dtype=big_endian_u32)[0]
def extract_images(f):
    """Extract the images into a 4D uint8 numpy array.

    Args:
      f: A binary file object with a `name` attribute, holding a
        gzip-compressed MNIST image file.

    Returns:
      A uint8 numpy array of shape [num_images, rows, cols, 1].

    Raises:
      ValueError: If the magic number is not 2051, or if the file holds fewer
        pixel bytes than its header declares.
    """
    print('Extracting', f.name)
    with gzip.GzipFile(fileobj=f) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError('Invalid magic number %d in MNIST image file: %s' %
                             (magic, f.name))
        # _read32 yields fixed-width numpy scalars; promote them to Python ints
        # so the byte-count product below cannot wrap around at 32 bits.
        num_images = int(_read32(bytestream))
        rows = int(_read32(bytestream))
        cols = int(_read32(bytestream))
        expected_bytes = rows * cols * num_images
        buf = bytestream.read(expected_bytes)
        if len(buf) != expected_bytes:
            # Report truncation explicitly rather than via a reshape failure.
            raise ValueError(
                'Expected %d bytes of pixel data but read %d in MNIST image '
                'file: %s' % (expected_bytes, len(buf), f.name))
        data = numpy.frombuffer(buf, dtype=numpy.uint8)
        return data.reshape(num_images, rows, cols, 1)
def extract_labels(f, one_hot=False, num_classes=10):
    """Extract the labels into a 1D uint8 numpy array.

    Args:
      f: A binary file object with a `name` attribute, holding a
        gzip-compressed MNIST label file.
      one_hot: If True, return one-hot encoded rows instead of class indices.
      num_classes: Number of columns in the one-hot encoding; only used when
        `one_hot` is True.

    Returns:
      A uint8 numpy array of shape [num_items] holding the class indices, or,
      when `one_hot` is True, a float array of shape [num_items, num_classes]
      with a 1 in each row at the column of that item's label.

    Raises:
      ValueError: If the magic number is not 2049, or if the file holds fewer
        label bytes than its header declares.
    """
    print('Extracting', f.name)
    with gzip.GzipFile(fileobj=f) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError('Invalid magic number %d in MNIST label file: %s' %
                             (magic, f.name))
        # Use a plain int for the byte count and the shape arithmetic below.
        num_items = int(_read32(bytestream))
        buf = bytestream.read(num_items)
        if len(buf) != num_items:
            raise ValueError(
                'Expected %d label bytes but read %d in MNIST label file: %s' %
                (num_items, len(buf), f.name))
        labels = numpy.frombuffer(buf, dtype=numpy.uint8)
    if one_hot:
        # Honour the one_hot/num_classes arguments: one row per label, with a
        # single 1 at the label's column.
        encoded = numpy.zeros((num_items, num_classes))
        encoded[numpy.arange(num_items), labels] = 1
        return encoded
    return labels
def load_dataset(directory, images_file, labels_file):
    """Fetch the MNIST image and label files and parse them into arrays.

    Args:
      directory: Directory the files are downloaded into (or reused from).
      images_file: Name of the MNIST image file.
      labels_file: Name of the MNIST label file.

    Returns:
      A pair `(images, labels)`: `images` is a float32 array with one
      flattened row per image, scaled by 1/255; `labels` is whatever
      `extract_labels` returns for the label file.
    """
    images_path = download(directory, images_file)
    labels_path = download(directory, labels_file)

    with tf.gfile.Open(images_path, 'rb') as image_stream:
        raw_images = extract_images(image_stream)
    # Flatten each image to a single row, then rescale the pixels as float32.
    count, height, width = raw_images.shape[:3]
    flat_images = raw_images.reshape(count, height * width)
    images = numpy.multiply(flat_images.astype(numpy.float32), 1.0 / 255.0)

    with tf.gfile.Open(labels_path, 'rb') as label_stream:
        labels = extract_labels(label_stream)

    return images, labels
Preparing MNIST data for Distributed DL
This notebook uses MNIST as an example to show how to load image data into Spark TensorFrames and save it to TFRecord files.
To run this notebook, you may need to edit the data download location and the TFRecord save location.