Use PyReader to read training and test data

Besides Python Reader, we provide PyReader. PyReader performs better than synchronous data reading, because when PyReader is used, data loading runs asynchronously with model training. PyReader can also work together with double_buffer_reader to further improve data-reading performance. In addition, double_buffer_reader performs the transformation from CPU Tensor to GPU Tensor, which improves reading efficiency to some extent.

Create PyReader Object

You can create a PyReader object as follows:

import paddle.fluid as fluid

py_reader = fluid.layers.py_reader(capacity=64,
                                   shapes=[(-1,784), (-1,1)],
                                   dtypes=['float32', 'int64'],
                                   name='py_reader',
                                   use_double_buffer=True)

In the code above, capacity is the buffer size of PyReader; shapes are the shapes of the data items in a batch (such as image and label in an image classification task); dtypes are the data types of those items; name is the name of the PyReader instance; use_double_buffer is True by default, which means double_buffer_reader is used.

Attention: If you want to create multiple PyReader objects (for example, two different PyReaders for the training and inference phases), you have to give them different names, because PaddlePaddle uses the name to distinguish variables, and Program.clone() (see api_fluid_Program_clone) cannot copy PyReader objects.

import paddle.fluid as fluid

train_py_reader = fluid.layers.py_reader(capacity=64,
                                         shapes=[(-1,784), (-1,1)],
                                         dtypes=['float32', 'int64'],
                                         name='train',
                                         use_double_buffer=True)

test_py_reader = fluid.layers.py_reader(capacity=64,
                                        shapes=[(-1,3,224,224), (-1,1)],
                                        dtypes=['float32', 'int64'],
                                        name='test',
                                        use_double_buffer=True)

When using PyReader, if you need to share model parameters between the training and test phases, you can use fluid.unique_name.guard(). Note: PaddlePaddle distinguishes variables by name, and the names are generated by the counter in the unique_name module; the counter increases by one every time a variable name is generated. fluid.unique_name.guard() resets this counter, so that the variable names generated in repeated calls to fluid.unique_name.guard() are identical and the parameters can be shared.
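For example, the effect of fluid.unique_name.guard() on the counter can be observed with fluid.unique_name.generate() (a minimal sketch; the exact suffixes below are only indicative):

import paddle.fluid as fluid

print(fluid.unique_name.generate('fc'))      # e.g. 'fc_0'
print(fluid.unique_name.generate('fc'))      # e.g. 'fc_1', the counter keeps increasing

with fluid.unique_name.guard():
    print(fluid.unique_name.generate('fc'))  # counter reset: 'fc_0' again
with fluid.unique_name.guard():
    print(fluid.unique_name.generate('fc'))  # 'fc_0' again, so names (and parameters) can match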

An example of configuring the training and test networks with PyReader is as follows:

import paddle
import paddle.fluid as fluid
import paddle.dataset.mnist as mnist

import numpy

def network(is_train):
    # Create py_reader object and give different names
    # when is_train = True and is_train = False
    reader = fluid.layers.py_reader(
        capacity=10,
        shapes=((-1, 784), (-1, 1)),
        dtypes=('float32', 'int64'),
        name="train_reader" if is_train else "test_reader",
        use_double_buffer=True)

    # Use read_file() method to read out the data from py_reader
    img, label = fluid.layers.read_file(reader)
    ...
    # The definition of the model and its loss is omitted here
    return loss, reader

# Create main program and startup program for training
train_prog = fluid.Program()
train_startup = fluid.Program()

with fluid.program_guard(train_prog, train_startup):
    # Use fluid.unique_name.guard() to share parameters with test network
    with fluid.unique_name.guard():
        train_loss, train_reader = network(True)
        adam = fluid.optimizer.Adam(learning_rate=0.01)
        adam.minimize(train_loss)

# Create main program and startup program for testing
test_prog = fluid.Program()
test_startup = fluid.Program()
with fluid.program_guard(test_prog, test_startup):
    # Use fluid.unique_name.guard() to share parameters with train network
    with fluid.unique_name.guard():
        test_loss, test_reader = network(False)

Configure the data source of PyReader objects

A PyReader object sets its data source with decorate_paddle_reader() or decorate_tensor_provider(). Both methods receive a Python generator generator as their parameter, and generator yields one batch of data on every call.

The differences between decorate_paddle_reader() and decorate_tensor_provider() are:

  • The generator passed to decorate_paddle_reader() should yield data of Numpy array type, while the generator passed to decorate_tensor_provider() should yield data of LoDTensor type.

  • decorate_tensor_provider() requires that the data types and shapes of the LoDTensors returned by the generator match the dtypes and shapes specified when configuring py_reader. decorate_paddle_reader() does not have this requirement, because the data types and shapes are converted internally.

Specific usage is as follows:

import paddle
import paddle.fluid as fluid
import numpy as np

BATCH_SIZE = 32

# Case 1: Use decorate_paddle_reader() method to set the data source of py_reader
# The generator yields Numpy-typed sample data, which paddle.batch assembles into batches
def fake_random_numpy_reader():
    image = np.random.random(size=(784, ))
    label = np.random.random_integers(size=(1, ), low=0, high=9)
    yield image, label

py_reader1 = fluid.layers.py_reader(
    capacity=10,
    shapes=((-1, 784), (-1, 1)),
    dtypes=('float32', 'int64'),
    name='py_reader1',
    use_double_buffer=True)

py_reader1.decorate_paddle_reader(paddle.batch(fake_random_numpy_reader, batch_size=BATCH_SIZE))


# Case 2: Use decorate_tensor_provider() method to set the data source of py_reader
# The generator yields LoDTensor-typed batched data
def fake_random_tensor_provider():
    image = np.random.random(size=(BATCH_SIZE, 784)).astype('float32')
    label = np.random.random_integers(size=(BATCH_SIZE, 1), low=0, high=9).astype('int64')

    image_tensor = fluid.LoDTensor()
    image_tensor.set(image, fluid.CPUPlace())

    label_tensor = fluid.LoDTensor()
    label_tensor.set(label, fluid.CPUPlace())
    yield image_tensor, label_tensor

py_reader2 = fluid.layers.py_reader(
    capacity=10,
    shapes=((-1, 784), (-1, 1)),
    dtypes=('float32', 'int64'),
    name='py_reader2',
    use_double_buffer=True)

py_reader2.decorate_tensor_provider(fake_random_tensor_provider)


Train and test the model with PyReader

An example of using PyReader to train and test a model is as follows:

import paddle
import paddle.fluid as fluid
import paddle.dataset.mnist as mnist
import six

def network(is_train):
    # Create py_reader object and give different names
    # when is_train = True and is_train = False
    reader = fluid.layers.py_reader(
        capacity=10,
        shapes=((-1, 784), (-1, 1)),
        dtypes=('float32', 'int64'),
        name="train_reader" if is_train else "test_reader",
        use_double_buffer=True)
    img, label = fluid.layers.read_file(reader)
    ...
    # The definition of the model and its loss is omitted here
    return loss, reader

# Create main program and startup program for training
train_prog = fluid.Program()
train_startup = fluid.Program()

# Define train network
with fluid.program_guard(train_prog, train_startup):
    # Use fluid.unique_name.guard() to share parameters with test network
    with fluid.unique_name.guard():
        train_loss, train_reader = network(True)
        adam = fluid.optimizer.Adam(learning_rate=0.01)
        adam.minimize(train_loss)

# Create main program and startup program for testing
test_prog = fluid.Program()
test_startup = fluid.Program()

# Define test network
with fluid.program_guard(test_prog, test_startup):
    # Use fluid.unique_name.guard() to share parameters with train network
    with fluid.unique_name.guard():
        test_loss, test_reader = network(False)


place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)

# Run startup program
exe.run(train_startup)
exe.run(test_startup)

# Compile programs
train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(loss_name=train_loss.name)
test_prog = fluid.CompiledProgram(test_prog).with_data_parallel(share_vars_from=train_prog)

# Set the data source of py_reader using decorate_paddle_reader() method
train_reader.decorate_paddle_reader(
    paddle.reader.shuffle(paddle.batch(mnist.train(), 512), buf_size=8192))

test_reader.decorate_paddle_reader(paddle.batch(mnist.test(), 512))

for epoch_id in six.moves.range(10):
    train_reader.start()
    try:
        while True:
            loss = exe.run(program=train_prog, fetch_list=[train_loss])
            print('train_loss', loss)
    except fluid.core.EOFException:
        print('End of epoch', epoch_id)
        train_reader.reset()

    test_reader.start()
    try:
        while True:
            loss = exe.run(program=test_prog, fetch_list=[test_loss])
            print('test loss', loss)
    except fluid.core.EOFException:
        print('End of testing')
        test_reader.reset()

Specific steps are as follows:

  1. Before every epoch starts, call start() to start the PyReader;

  2. At the end of every epoch, read_file throws a fluid.core.EOFException. Call reset() after catching the exception to reset the state of PyReader so that the next epoch can start (see the sketch after this list).
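Putting the two steps together, each epoch follows the same pattern, sketched minimally below (py_reader, prog, loss, exe and NUM_EPOCHS are placeholders for any PyReader object, compiled program, loss variable, executor and epoch count built as in the example above):

import paddle.fluid as fluid

for epoch_id in range(NUM_EPOCHS):
    py_reader.start()        # Step 1: start the PyReader before the epoch begins
    try:
        while True:
            exe.run(program=prog, fetch_list=[loss])
    except fluid.core.EOFException:
        # Step 2: thrown when the data source is exhausted
        py_reader.reset()    # reset the PyReader so that start() can be called next epoch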