Data Modeling [04]: Keras and TensorFlow

A walkthrough of the Keras (TensorFlow) workflow, following the Keras tutorial on the Kaggle Cats vs Dogs binary classification dataset.


Datasets Loading

Download and unzip the dataset.

!curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip
!unzip -q kagglecatsanddogs_3367a.zip

!ls
kagglecatsanddogs_3367a.zip
kearas.py
MSR-LA - 3467.docx
PetImages
readme[1].txt

!ls PetImages
Cat
Dog

We now have a PetImages folder containing two subfolders, Cat and Dog. Next, we delete the corrupted images, identified here as files whose header does not contain the JFIF marker:

import os

import tensorflow as tf

num_skipped = 0
for folder_name in ("Cat", "Dog"):
    folder_path = os.path.join("PetImages", folder_name)
    for fname in os.listdir(folder_path):
        fpath = os.path.join(folder_path, fname)
        # Open in binary read mode; the with-block closes the file even if reading fails.
        with open(fpath, "rb") as fobj:
            # JFIF-encoded JPEG files carry the bytes b"JFIF" near the start of the file.
            # peek() returns buffered bytes without advancing the file position;
            # it may return more or fewer than the 10 bytes requested.
            is_jfif = tf.compat.as_bytes("JFIF") in fobj.peek(10)

        if not is_jfif:
            num_skipped += 1
            # Delete corrupted image
            os.remove(fpath)

print("Deleted %d images" % num_skipped)
Deleted 1578 images
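
As a non-destructive complement to the cleanup loop above, the sketch below only reports which files lack the marker, so the list can be reviewed before anything is deleted. The helper names `has_jfif_marker` and `find_suspect_images` are made up for illustration, and the check reads exactly the first 10 bytes, which is an assumption and slightly stricter than the `peek`-based test. The demo runs on a throwaway directory, not on the real dataset.

```python
import os
import tempfile

def has_jfif_marker(path, num_bytes=10):
    """Return True if b"JFIF" appears in the first `num_bytes` bytes of the file."""
    with open(path, "rb") as fobj:
        return b"JFIF" in fobj.read(num_bytes)

def find_suspect_images(root, class_dirs=("Cat", "Dog")):
    """List files under root/<class_dir> that lack the JFIF marker, without deleting them."""
    suspects = []
    for class_dir in class_dirs:
        folder = os.path.join(root, class_dir)
        for fname in sorted(os.listdir(folder)):
            fpath = os.path.join(folder, fname)
            if not has_jfif_marker(fpath):
                suspects.append(fpath)
    return suspects

# Tiny self-contained demo on a temporary directory (hypothetical files).
demo_root = tempfile.mkdtemp()
for class_dir in ("Cat", "Dog"):
    os.makedirs(os.path.join(demo_root, class_dir))
# A minimal JFIF-style header: SOI, APP0 marker, segment length, then "JFIF".
with open(os.path.join(demo_root, "Cat", "good.jpg"), "wb") as f:
    f.write(b"\xff\xd8\xff\xe0\x00\x10JFIF\x00")
with open(os.path.join(demo_root, "Dog", "bad.jpg"), "wb") as f:
    f.write(b"not an image")

suspects = find_suspect_images(demo_root)
print(suspects)  # only the path ending in Dog/bad.jpg
```

Calling `find_suspect_images("PetImages")` would apply the same check to the real folders.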

Training & Testing Dataset Generation

image_size = (180, 180)
batch_size = 32

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "PetImages",
    validation_split=0.2,
    subset="training",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "PetImages",
    validation_split=0.2,
    subset="validation",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)
Found 23422 files belonging to 2 classes.
Using 18738 files for training.

Found 23422 files belonging to 2 classes.
Using 4684 files for validation.
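
The file counts in the log follow from `validation_split=0.2`. A quick check, assuming the validation count is the split fraction times the total, rounded down (an assumption that matches the numbers printed above):

```python
total_files = 23422        # "Found 23422 files" in the log above
validation_split = 0.2

num_val = int(validation_split * total_files)   # 4684.4 rounded down
num_train = total_files - num_val

print(num_train, num_val)  # 18738 4684
```

Because both calls use the same `seed`, the training and validation subsets are drawn from the same shuffled file list and do not overlap.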

train_ds
Out[11]: <BatchDataset element_spec=(TensorSpec(shape=(None, 180, 180, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int32, name=None))>

Each element of the dataset is a tuple of two tensors. The first is a 4D tensor of shape (batch, height, width, channels) holding a batch of color images. The second holds the labels, with 1 for ‘dog’ and 0 for ‘cat’.
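
The 0/1 labels come from the subfolder names. A minimal sketch, assuming (as the Keras documentation describes for inferred labels) that class indices follow the alphanumeric order of the directory names; `subfolders`, `class_names` and `class_to_index` are plain Python stand-ins here, not objects read from `train_ds`:

```python
# Subfolders of PetImages, in arbitrary listing order.
subfolders = ["Dog", "Cat"]

# Position in the sorted list is the integer label.
class_names = sorted(subfolders)
class_to_index = {name: index for index, name in enumerate(class_names)}

print(class_to_index)  # {'Cat': 0, 'Dog': 1}
```

On the real dataset the same list should be available as `train_ds.class_names` right after `image_dataset_from_directory` returns.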

type(next(iter(train_ds)))
Out[16]: tuple

next(iter(train_ds))
Out[15]: 
(<tf.Tensor: shape=(32, 180, 180, 3), dtype=float32, numpy=
 array([[[[160.11111  , 155.11111  , 126.111115 ],
          [159.33333  , 154.33333  , 125.333336 ],
          [158.74692  , 153.74692  , 124.74692  ],
          ...,
          [158.43056  , 150.43056  , 127.43056  ],
          [155.76387  , 148.4305   , 125.09718  ],
          [156.43056  , 150.43056  , 126.43056  ]],
 
         [[161.       , 156.       , 127.       ],
          [160.09721  , 155.09721  , 126.09722  ],
          [159.       , 154.       , 125.       ],
          ...,
          [160.29167  , 152.29167  , 129.29167  ],
          [157.19446  , 149.86108  , 126.52776  ],
          [157.       , 151.       , 127.       ]],
 
         [[162.       , 157.       , 128.       ],
          [161.       , 156.       , 127.       ],
          [160.       , 155.       , 126.       ],
          ...,
          [161.47067  , 153.47067  , 130.47067  ],
          [158.56483  , 151.23146  , 127.89815  ],
          [158.       , 152.       , 128.       ]],
 
         ...,
 
         [[189.73611  , 180.73611  , 149.43057  ],
          [187.05092  , 178.53702  , 153.49077  ],
          [189.58334  , 182.29167  , 165.83492  ],
          ...,
          [241.69446  , 238.       , 239.       ],
          [241.31023  , 237.61577  , 238.61577  ],
          [240.54169  , 236.84723  , 237.84723  ]],
 
         [[187.4722   , 176.34717  , 149.4722   ],
          [190.45836  , 181.08336  , 161.8195   ],
          [205.1992   , 196.88437  , 186.8196   ],
          ...,
          [241.29166  , 237.29166  , 236.87497  ],
          [240.95834  , 236.95834  , 236.54166  ],
          [240.29166  , 236.29166  , 235.87497  ]],
 
         [[188.0278   , 176.0278   , 153.16672  ],
          [197.42134  , 187.18523  , 172.70378  ],
          [224.15906  , 215.79486  , 210.6313   ],
          ...,
          [238.86108  , 234.86108  , 233.86108  ],
          [238.71758  , 234.71758  , 233.71758  ],
          [238.43054  , 234.43054  , 233.43054  ]]],
 
 
        [[[ 12.382964 ,  13.017286 ,  16.294075 ],
          [ 31.420378 ,  32.509266 ,  34.14667  ],
          [ 31.04445  ,  32.04445  ,  30.772844 ],
          ...,
          [ 30.949389 ,  31.99692  ,  30.582722 ],
          [ 30.335556 ,  32.30223  ,  30.213333 ],
          [ 12.977433 ,  14.572387 ,  12.343099 ]],
 
         [[ 82.092964 ,  82.66371  ,  84.87778  ],
          [251.55112  , 253.16333  , 251.84222  ],
          [252.10556  , 253.69629  , 249.37407  ],
          ...,
          [253.50742  , 254.75371  , 250.14075  ],
          [247.13661  , 249.10329  , 245.42772  ],
          [ 81.839966 ,  82.839966 ,  81.51776  ]],
 
         [[ 82.108025 ,  83.108025 ,  82.82716  ],
          [254.       , 255.       , 251.84444  ],
          [143.85802  , 146.70679  , 139.26234  ],
          ...,
          [129.9105   , 132.13272  , 124.688286 ],
          [250.61664  , 251.99257  , 247.05182  ],
          [ 81.839966 ,  83.25417  ,  81.123955 ]],
 
         ...,
 
         [[ 82.44257  ,  82.44257  ,  82.44257  ],
          [253.39626  , 253.39626  , 251.46292  ],
          [110.7682   , 110.7682   , 108.7682   ],
          ...,
          [ 98.0678   ,  99.305466 ,  97.203575 ],
          [252.87965  , 254.87965  , 251.87965  ],
          [ 81.18265  ,  83.18265  ,  80.44624  ]],
 
         [[ 81.81224  ,  81.81224  ,  81.81224  ],
          [252.32114  , 252.32114  , 250.3878   ],
          [254.57224  , 254.57224  , 251.79814  ],
          ...,
          [253.0815   , 254.55186  , 251.63336  ],
          [253.5967   , 254.96335  , 252.6389   ],
          [ 81.36742  ,  82.73407  ,  81.45335  ]],
 
         [[ 14.206032 ,  14.206032 ,  14.206032 ],
          [ 32.62375  ,  32.62375  ,  30.690414 ],
          [ 32.869766 ,  33.406204 ,  31.793259 ],
          ...,
          [ 31.529648 ,  33.1136   ,  30.871002 ],
          [ 34.434906 ,  36.283447 ,  33.472305 ],
          [ 11.786362 ,  13.381325 ,  12.4207   ]]],
 
 
        [[[ 23.       ,  15.       ,  12.       ],
          [ 23.       ,  15.       ,  12.       ],
          [ 23.       ,  15.       ,  12.       ],
          ...,
          [ 94.194496 , 100.194496 ,  98.194496 ],
          [ 88.87489  ,  95.87483  ,  94.208145 ],
          [ 88.23617  ,  97.23617  ,  96.23617  ]],
 
         [[ 23.       ,  15.       ,  12.       ],
          [ 23.       ,  15.       ,  12.       ],
          [ 23.       ,  15.       ,  12.       ],
          ...,
          [102.43056  , 108.43056  , 106.43056  ],
          [ 90.8332   ,  97.83314  ,  96.16645  ],
          [ 91.65283  , 100.65283  ,  99.65283  ]],
 
         [[ 23.       ,  15.       ,  12.       ],
          [ 23.       ,  15.       ,  12.       ],
          [ 23.       ,  15.       ,  12.       ],
          ...,
          [ 85.13889  ,  91.13889  ,  89.13889  ],
          [ 91.20825  ,  98.20819  ,  96.541504 ],
          [ 92.18058  , 101.18058  , 100.18058  ]],
 
         ...,
 
         [[ 22.88889  ,  23.88889  ,  17.88889  ],
          [ 23.666666 ,  24.666666 ,  18.666666 ],
          [ 24.444445 ,  25.444445 ,  19.444445 ],
          ...,
          [ 21.       ,  13.       ,  10.       ],
          [ 20.194443 ,  13.194383 ,  10.527757 ],
          [ 20.       ,  15.       ,  10.416687 ]],
 
         [[ 22.88889  ,  23.88889  ,  17.88889  ],
          [ 23.666666 ,  24.666666 ,  18.666666 ],
          [ 24.444445 ,  25.444445 ,  19.444445 ],
          ...,
          [ 21.375    ,  13.375    ,  10.375    ],
          [ 20.916695 ,  13.916634 ,  10.583321 ],
          [ 20.       ,  15.       ,   9.       ]],
 
         [[ 22.88889  ,  23.88889  ,  17.88889  ],
          [ 23.666666 ,  24.666666 ,  18.666666 ],
          [ 24.444445 ,  25.444445 ,  19.444445 ],
          ...,
          [ 22.       ,  14.       ,  11.       ],
          [ 21.638926 ,  14.6388645,  10.638926 ],
          [ 20.       ,  15.458313 ,   8.083374 ]]],
 
 
        ...,
 
 
        [[[179.45139  , 163.59723  , 147.13426  ],
          [168.74652  , 158.46529  , 143.15973  ],
          [153.13889  , 146.2604   , 131.50462  ],
          ...,
          [169.28354  ,  85.74996  ,  39.51964  ],
          [179.93735  ,  99.152596 ,  50.919983 ],
          [205.23584  , 124.38401  ,  64.75918  ]],
 
         [[174.07985  , 160.04514  , 146.65625  ],
          [163.1875   , 156.08333  , 139.40625  ],
          [149.75     , 144.       , 123.208336 ],
          ...,
          [207.09029  , 129.24649  ,  81.9722   ],
          [206.65623  , 133.60416  ,  87.14579  ],
          [205.53476  , 134.25346  ,  78.75018  ]],
 
         [[163.60301  , 154.29745  , 143.83333  ],
          [151.3646   , 147.00694  , 133.2639   ],
          [148.22223  , 141.32986  , 125.3588   ],
          ...,
          [194.9248   , 120.857666 ,  72.093765 ],
          [192.89922  , 120.54506  ,  75.61791  ],
          [200.35065  , 127.52542  ,  84.99767  ]],
 
         ...,
 
         [[252.2442   , 215.329    , 170.30817  ],
          [254.1771   , 221.55212  , 184.96884  ],
          [253.70262  , 170.11118  , 130.9874   ],
          ...,
          [254.32887  , 220.88106  , 183.56517  ],
          [242.63867  , 182.59323  , 139.37448  ],
          [238.10547  , 185.81163  , 140.59872  ]],
 
         [[251.78473  , 164.71875  , 122.5      ],
          [251.10417  , 165.64583  , 122.010414 ],
          [252.9618   , 173.84027  , 126.802086 ],
          ...,
          [241.7326   , 171.10411  , 122.78468  ],
          [247.79161  , 182.0416   , 133.06238  ],
          [248.2813   , 185.01738  , 140.53123  ]],
 
         [[253.21645  , 181.70467  , 127.34817  ],
          [252.8993   , 179.89238  , 128.44455  ],
          [253.0382   , 177.73972  , 128.71077  ],
          ...,
          [243.54869  , 178.37743  , 122.54984  ],
          [234.27773  , 164.17004  , 115.9755   ],
          [240.52554  , 172.87396  , 129.46648  ]]],
 
 
        [[[ 29.949074 ,  79.94907  , 138.94907  ],
          [ 30.847221 ,  80.84722  , 139.84723  ],
          [ 31.541666 ,  81.541664 , 140.54167  ],
          ...,
          [ 27.       ,  78.       , 141.       ],
          [ 27.208353 ,  76.541725 , 139.20842  ],
          [ 28.       ,  74.       , 136.       ]],
 
         [[ 30.625    ,  80.625    , 139.625    ],
          [ 31.875    ,  81.875    , 140.875    ],
          [ 32.625    ,  82.625    , 141.625    ],
          ...,
          [ 27.791672 ,  78.79167  , 141.79167  ],
          [ 27.875008 ,  77.20838  , 139.87506  ],
          [ 29.555573 ,  75.55557  , 137.55557  ]],
 
         [[ 32.708332 ,  82.708336 , 141.70833  ],
          [ 33.708332 ,  83.708336 , 142.70833  ],
          [ 34.708332 ,  84.708336 , 143.70833  ],
          ...,
          [ 28.837967 ,  79.83797  , 142.83797  ],
          [ 28.76386  ,  78.56947  , 141.23616  ],
          [ 28.708332 ,  76.708336 , 138.70833  ]],
 
         ...,
 
         [[142.96751  , 153.96751  , 156.55083  ],
          [151.70827  , 162.70827  , 167.09714  ],
          [155.49063  , 166.19897  , 173.07394  ],
          ...,
          [126.06967  , 137.75485  , 154.06046  ],
          [121.819405 , 134.15273  , 149.76382  ],
          [111.319305 , 124.319305 , 140.3193   ]],
 
         [[144.95833  , 155.95833  , 160.70833  ],
          [153.83333  , 164.58333  , 170.91667  ],
          [165.25     , 175.25     , 184.25     ],
          ...,
          [139.7777   , 151.98602  , 163.91661  ],
          [126.62484  , 139.08315  , 152.87479  ],
          [117.49997  , 130.49997  , 146.49997  ]],
 
         [[135.6944   , 146.1759   , 153.78233  ],
          [145.43053  , 155.43053  , 164.88884  ],
          [164.25917  , 174.25917  , 183.71748  ],
          ...,
          [133.21252  , 146.21252  , 154.75421  ],
          [128.98642  , 142.31973  , 152.65298  ],
          [111.6572   , 124.6572   , 140.65721  ]]],
 
 
        [[[191.825    , 171.825    , 110.825    ],
          [185.17375  , 164.32376  , 106.72375  ],
          [148.75     , 127.       ,  75.25     ],
          ...,
          [ 28.29375  ,  28.29375  ,  16.79375  ],
          [ 26.025005 ,  26.025005 ,  16.025005 ],
          [ 25.175    ,  25.175    ,  15.175    ]],
 
         [[191.       , 171.       , 110.       ],
          [185.05     , 164.2      , 106.6      ],
          [148.75     , 127.       ,  75.25     ],
          ...,
          [ 28.36875  ,  28.63125  ,  17.       ],
          [ 26.325006 ,  27.296259 ,  17.0075   ],
          [ 25.475    ,  26.       ,  17.05     ]],
 
         [[191.       , 171.       , 110.       ],
          [185.79375  , 165.6875   , 106.6      ],
          [149.625    , 128.09375  ,  75.90625  ],
          ...,
          [ 28.03125  ,  28.3125   ,  18.53125  ],
          [ 24.843761 ,  25.950012 ,  17.63126  ],
          [ 23.25     ,  24.25     ,  16.25     ]],
 
         ...,
 
         [[123.       , 101.875    ,  56.375    ],
          [130.54375  , 109.41875  ,  63.91875  ],
          [128.875    , 107.9375   ,  62.34375  ],
          ...,
          [203.125    , 183.4375   , 152.625    ],
          [199.48752  , 179.98752  , 147.07501  ],
          [198.       , 178.5      , 145.375    ]],
 
         [[130.95001  , 108.95001  ,  59.475006 ],
          [133.81879  , 111.81879  ,  62.34378  ],
          [143.96257  , 123.46257  ,  73.237564 ],
          ...,
          [192.81876  , 167.25     , 136.2375   ],
          [189.36124  , 165.07872  , 132.37497  ],
          [188.14996  , 164.67496  , 130.67496  ]],
 
         [[133.65     , 111.649994 ,  61.649994 ],
          [142.95747  , 120.95747  ,  70.95747  ],
          [159.69371  , 139.19371  ,  88.44371  ],
          ...,
          [197.09373  , 170.34373  , 139.59373  ],
          [189.53003  , 163.83002  , 130.43254  ],
          [184.175    , 160.175    , 124.52501  ]]]], dtype=float32)>,
 <tf.Tensor: shape=(32,), dtype=int32, numpy=
 array([0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 1])>)

Visualization

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1): # take the first batch only
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        label = 'dog' if labels[i] else 'cat'
        plt.title(label)
        plt.axis("off")

Figure: the first nine images of one training batch, each titled with its label.


Dataset Preprocessing

Firstly, it’s a good practice to artificially introduce sample diversity by applying random yet realistic transformations to the training images, such as random horizontal flipping or small random rotations.

data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
    ]
)
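
To make the two layers concrete, the sketch below mimics a horizontal flip on a tiny nested-list "image" and converts the rotation factor to degrees. It is a plain-Python illustration, not the Keras implementation; `flip_horizontal` is a hypothetical helper, and the degree conversion assumes the documented meaning of the factor as a fraction of a full turn.

```python
def flip_horizontal(image):
    """Mirror an image left-to-right; `image` is a list of rows of pixel values."""
    return [row[::-1] for row in image]

tiny_image = [
    [1, 2, 3],
    [4, 5, 6],
]
flipped = flip_horizontal(tiny_image)
print(flipped)  # [[3, 2, 1], [6, 5, 4]]

# RandomRotation(0.1): the factor is a fraction of a full turn,
# so rotations are drawn from [-0.1 * 360, +0.1 * 360] degrees.
rotation_factor = 0.1
max_degrees = round(rotation_factor * 360, 6)
print(max_degrees)  # 36.0
```

In the real layers both transformations are applied randomly per image, so each epoch sees slightly different versions of the same pictures.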

Secondly, we should standardize the pixel values to the [0, 1] range by using a Rescaling layer at the start of the model.
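
What the rescaling does to individual values can be checked by hand. A minimal sketch in plain Python, using the smallest, middle and largest values of an 8-bit pixel; `rescale` is an illustrative helper, not a Keras function:

```python
def rescale(pixels, scale=1.0 / 255):
    """Multiply every pixel value by `scale`, as Rescaling(1./255) does element-wise."""
    return [p * scale for p in pixels]

raw = [0.0, 127.5, 255.0]  # minimum, midpoint, and maximum of the 0-255 range
scaled = rescale(raw)
print(scaled)  # values very close to 0, 0.5 and 1
```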

For GPU

input_shape = image_size + (3,)  # (height, width, channels)
inputs = keras.Input(shape=input_shape)
x = data_augmentation(inputs)
x = layers.Rescaling(1./255)(x)
  • With this option, data augmentation will happen on device, synchronously with the rest of the model execution.
  • Input samples will only be augmented during fit(), not when calling evaluate() or predict().

For CPU

augmented_train_ds = train_ds.map(
  lambda x, y: (data_augmentation(x, training=True), y))
  • With this option, data augmentation will happen on CPU asynchronously, and will be buffered before going into the model.

Model Building

Prefetching configuration:

train_ds = train_ds.prefetch(buffer_size=32)
val_ds = val_ds.prefetch(buffer_size=32)
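
The idea behind prefetching is that upcoming batches are prepared in the background while the model is busy with the current one. The sketch below is only a conceptual analogy built from a background thread and a bounded queue; it is not how `tf.data` implements `prefetch`, and the function name and demo input are illustrative.

```python
import queue
import threading

def prefetch(iterable, buffer_size):
    """Yield items from `iterable`, preparing up to `buffer_size` items ahead
    in a background thread."""
    buf = queue.Queue(maxsize=buffer_size)
    done = object()  # sentinel marking the end of the stream

    def producer():
        for item in iterable:
            buf.put(item)  # blocks while the buffer is full
        buf.put(done)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()  # blocks until the producer has something ready
        if item is done:
            return
        yield item

batches = list(prefetch(range(5), buffer_size=2))
print(batches)  # [0, 1, 2, 3, 4]
```

The order of items is unchanged; only the timing of when they are produced differs. This sketch does not forward exceptions raised in the producer, which a real pipeline would need to handle.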

A small version of the Xception network:

def make_model(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)
    
    # Image augmentation block
    x = data_augmentation(inputs) # flip and rotation
    
    # Normalization
    x = layers.Rescaling(1.0 / 255)(x)
    
    # Entry block - 2 Conv layers
    x = layers.Conv2D(32, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.Conv2D(64, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    # Each Block: 2 Conv + 1 Pool + 1 Res
    for size in [128, 256, 512, 728]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        # Separable convolutions consist of first performing a depthwise spatial convolution
        # followed by a pointwise convolution which mixes the resulting output channels.
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(size, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.GlobalAveragePooling2D()(x)
    if num_classes == 2:
        activation = "sigmoid"
        units = 1
    else:
        activation = "softmax"
        units = num_classes

    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(units, activation=activation)(x)
    
    return keras.Model(inputs, outputs)

model = make_model(input_shape=image_size + (3,), num_classes=2)
dot_img_file = 'model_xception.png'
tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)
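
The `num_classes == 2` branch of `make_model` uses a single sigmoid unit instead of a two-way softmax. The sketch below checks numerically why that loses nothing: a sigmoid over one logit `z` gives the same probability as a softmax over the pair `[0, z]`. The function names and the sample logit are illustrative.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

z = 1.7  # arbitrary logit for the "dog" class
p_dog_sigmoid = sigmoid(z)
p_cat_softmax, p_dog_softmax = softmax([0.0, z])

print(p_dog_sigmoid, p_dog_softmax)  # the two values agree up to rounding
```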

Model summary:

model.summary()
Model: "model_4"
__________________________________________________________________________________________________
 Layer (type)                                 Output Shape           Param #     Connected to                     
==================================================================================================
 input_5 (InputLayer)                         [(None, 180, 180, 3)]  0           []                                                                                                                                                                                                                              
 sequential_4 (Sequential)                    (None, 180, 180, 3)    0           ['input_5[0][0]']                
                                                                                                  
 rescaling_4 (Rescaling)                      (None, 180, 180, 3)    0           ['sequential_4[0][0]']           
                                                                                                  
 conv2d_24 (Conv2D)                           (None, 90, 90, 32)     896         ['rescaling_4[0][0]']            
                                                                                                  
 batch_normalization_44 (BatchNormalization)  (None, 90, 90, 32)     128         ['conv2d_24[0][0]']                                                                                                                                                                                                   
 activation_44 (Activation)                   (None, 90, 90, 32)     0           ['batch_normalization_44[0][0]']                                                                                                
 conv2d_25 (Conv2D)                           (None, 90, 90, 64)     18496       ['activation_44[0][0]']          
                                                                                                  
 batch_normalization_45 (BatchNormalization)  (None, 90, 90, 64)     256         ['conv2d_25[0][0]']                                                                                                                                                                                                   
 activation_45 (Activation)                   (None, 90, 90, 64)     0           ['batch_normalization_45[0][0]'] 
                                                                                                  
 activation_46 (Activation)                   (None, 90, 90, 64)     0           ['activation_45[0][0]']          
                                                                                                  
 separable_conv2d_36 (SeparableConv2D)        (None, 90, 90, 128)    8896        ['activation_46[0][0]']          
                                                                                          
                                                                                                  
 batch_normalization_46 (BatchNormalization)  (None, 90, 90, 128)    512         ['separable_conv2d_36[0][0]']    
                                                                                     
                                                                                                  
 activation_47 (Activation)                   (None, 90, 90, 128)    0           ['batch_normalization_46[0][0]'] 
                                                                                                  
 separable_conv2d_37 (SeparableConv2D)        (None, 90, 90, 128)    17664       ['activation_47[0][0]']                                                                                                                                                                                                     
 batch_normalization_47 (BatchNormalization)  (None, 90, 90, 128)    512         ['separable_conv2d_37[0][0]']                                                                                                                                                                                          
 max_pooling2d_16 (MaxPooling2D)              (None, 45, 45, 128)    0           ['batch_normalization_47[0][0]']                                                                                                                                                                                                 
 conv2d_26 (Conv2D)                           (None, 45, 45, 128)    8320        ['activation_45[0][0]']          
                                                                                                  
 add_16 (Add)                                 (None, 45, 45, 128)    0           ['max_pooling2d_16[0][0]', 'conv2d_26[0][0]']              
                                                                                                  
 activation_48 (Activation)                   (None, 45, 45, 128)    0           ['add_16[0][0]']                 
                                                                                                  
 separable_conv2d_38 (SeparableConv2D)        (None, 45, 45, 256)    34176       ['activation_48[0][0]']                                                                                                                                                                                                    
 batch_normalization_48 (BatchNormalization)  (None, 45, 45, 256)    1024        ['separable_conv2d_38[0][0]']                                                                                                                                                                                           
 activation_49 (Activation)                   (None, 45, 45, 256)    0           ['batch_normalization_48[0][0]'] 
                                                                                                  
 separable_conv2d_39 (SeparableConv2D)        (None, 45, 45, 256)    68096       ['activation_49[0][0]']                                                                                                                                                                                                    
 batch_normalization_49 (BatchNormalization)  (None, 45, 45, 256)    1024        ['separable_conv2d_39[0][0]']                                                                                                                                                                                          
 max_pooling2d_17 (MaxPooling2D)              (None, 23, 23, 256)    0           ['batch_normalization_49[0][0]']                                                                                                                                                                                                  
 conv2d_27 (Conv2D)                           (None, 23, 23, 256)    33024       ['add_16[0][0]']                 
                                                                                                  
 add_17 (Add)                                 (None, 23, 23, 256)    0           ['max_pooling2d_17[0][0]', 'conv2d_27[0][0]']              
                                                                                                  
 activation_50 (Activation)                   (None, 23, 23, 256)    0           ['add_17[0][0]']                 
                                                                                                  
 separable_conv2d_40 (SeparableConv2D)        (None, 23, 23, 512)    133888      ['activation_50[0][0]']                                                                                                                                                                                                      
 batch_normalization_50 (BatchNormalization)  (None, 23, 23, 512)    2048        ['separable_conv2d_40[0][0]']                                                                                                                                                                                          
 activation_51 (Activation)                   (None, 23, 23, 512)    0           ['batch_normalization_50[0][0]'] 
                                                                                                  
 separable_conv2d_41 (SeparableConv2D)        (None, 23, 23, 512)    267264      ['activation_51[0][0]']                                                                                                                                                                                                       
 batch_normalization_51 (BatchNormalization)  (None, 23, 23, 512)    2048        ['separable_conv2d_41[0][0]']                                                                                                                                                                                        
 max_pooling2d_18 (MaxPooling2D)              (None, 12, 12, 512)    0           ['batch_normalization_51[0][0]']                                                                                                                                                                                                   
 conv2d_28 (Conv2D)                           (None, 12, 12, 512)    131584      ['add_17[0][0]']                 
                                                                                                  
 add_18 (Add)                                 (None, 12, 12, 512)    0           ['max_pooling2d_18[0][0]', 'conv2d_28[0][0]']              
                                                                                                  
 activation_52 (Activation)                   (None, 12, 12, 512)    0           ['add_18[0][0]']                 
                                                                                                  
 separable_conv2d_42 (SeparableConv2D)        (None, 12, 12, 728)    378072      ['activation_52[0][0]']                                                                                                                                                                                                       
 batch_normalization_52 (BatchNormalization)  (None, 12, 12, 728)    2912        ['separable_conv2d_42[0][0]']                                                                                                                                                                                        
 activation_53 (Activation)                   (None, 12, 12, 728)    0           ['batch_normalization_52[0][0]'] 
                                                                                                  
 separable_conv2d_43 (SeparableConv2D)        (None, 12, 12, 728)    537264      ['activation_53[0][0]']                                                                                                                                                                                                     
 batch_normalization_53 (BatchNormalization)  (None, 12, 12, 728)    2912        ['separable_conv2d_43[0][0]']                                                                                                                                                                                        
 max_pooling2d_19 (MaxPooling2D)              (None, 6, 6, 728)      0           ['batch_normalization_53[0][0]']                                                                                                                                                                                               
 conv2d_29 (Conv2D)                           (None, 6, 6, 728)      373464      ['add_18[0][0]']                 
                                                                                                  
 add_19 (Add)                                 (None, 6, 6, 728)      0           ['max_pooling2d_19[0][0]', 'conv2d_29[0][0]']              
                                                                                                  
 separable_conv2d_44 (SeparableConv2D)        (None, 6, 6, 1024)     753048      ['add_19[0][0]']                                                                                                                                                                                                             
 batch_normalization_54 (BatchNormalization)  (None, 6, 6, 1024)     4096        ['separable_conv2d_44[0][0]']                                                                                                                                                                                        
 activation_54 (Activation)                   (None, 6, 6, 1024)     0           ['batch_normalization_54[0][0]'] 
                                                                                                  
 global_average_pooling2d_4 (GlobalAveragePooling2D)  (None, 1024)   0           ['activation_54[0][0]']                                                                                                                                                                                       
 dropout_4 (Dropout)                          (None, 1024)           0           ['global_average_pooling2d_4[0][0]']                              
                                                                                                  
 dense_4 (Dense)                              (None, 1)              1025        ['dropout_4[0][0]']              
                                                                                                  
==================================================================================================
Total params: 2,782,649
Trainable params: 2,773,913
Non-trainable params: 8,736
__________________________________________________________________________________________________
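
Several numbers in the summary can be reproduced by hand. The sketch below assumes the usual formulas for `padding="same"` output sizes and for Conv2D and SeparableConv2D parameter counts (with the default depth multiplier of 1); the helper names are illustrative.

```python
import math

def same_padding_output(size, stride):
    """Spatial output size of a conv or pooling layer with padding="same"."""
    return math.ceil(size / stride)

def conv2d_params(kernel, in_ch, out_ch):
    """Weights (kernel * kernel * in_ch * out_ch) plus one bias per output channel."""
    return kernel * kernel * in_ch * out_ch + out_ch

def separable_conv2d_params(kernel, in_ch, out_ch):
    """Depthwise filters (kernel * kernel per input channel), plus the
    pointwise 1x1 weights, plus one bias per output channel."""
    return kernel * kernel * in_ch + in_ch * out_ch + out_ch

# Spatial sizes: one stride-2 conv in the entry block, then four stride-2 poolings.
sizes = [180]
for _ in range(5):
    sizes.append(same_padding_output(sizes[-1], stride=2))
print(sizes)  # [180, 90, 45, 23, 12, 6]

print(conv2d_params(3, 3, 32))              # 896   (conv2d_24)
print(separable_conv2d_params(3, 64, 128))  # 8896  (separable_conv2d_36)
print(conv2d_params(1, 64, 128))            # 8320  (conv2d_26, the 1x1 residual projection)
print(1024 * 1 + 1)                         # 1025  (dense_4)
```

The comparison between 8,896 parameters for the first separable convolution and 18,496 for the regular `conv2d_25` layer shows how much cheaper the depthwise-plus-pointwise factorization is, even with twice as many output channels.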



Training

epochs = 50

callbacks = [
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.h5"),
]
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss="binary_crossentropy",
    metrics=["accuracy"],
)
history = model.fit(
    train_ds, epochs=epochs, callbacks=callbacks, validation_data=val_ds,
)

Training for 50 epochs took about 7645.4 seconds (a little over two hours), and the model reached an accuracy of 96.3% on the validation dataset.

Epoch 50/50
loss: 0.0495 - accuracy: 0.9814 - val_loss: 0.1077 - val_accuracy: 0.9633
Total time cost: 7645.395040399999 s
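
The "Total time cost" line is not printed by the training code shown above. One way it might have been produced is by wrapping the call in a timer, as in this minimal sketch. The names timed and fake_fit are hypothetical; fake_fit only stands in for model.fit so the example runs without TensorFlow.

```python
import time


def timed(fn, *args, **kwargs):
    """Run fn and return (result, elapsed_seconds) using a monotonic clock."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


def fake_fit(epochs):
    # Hypothetical stand-in for model.fit so the sketch runs without TensorFlow.
    return {"epochs": epochs}


history, elapsed = timed(fake_fit, epochs=50)
print("Total time cost: %s s" % elapsed)

# Average time per epoch for the run reported above
print(round(7645.395040399999 / 50, 1))  # 152.9
```

In the real script the call would look like timed(model.fit, train_ds, epochs=epochs, callbacks=callbacks, validation_data=val_ds).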

Plot the results:

import matplotlib.pyplot as plt

# Plot training and validation loss/accuracy per epoch
figure = plt.figure()
ax = figure.add_subplot(1, 1, 1)
ax.plot(history.history['loss'], linewidth=2, label='train loss')
ax.plot(history.history['accuracy'], '--', linewidth=2, label='train accuracy')
ax.plot(history.history['val_loss'], '-.', linewidth=2, label='validation loss')
ax.plot(history.history['val_accuracy'], ':', linewidth=2, label='validation accuracy')
ax.set_xlabel('epochs')
ax.set_ylabel('loss/accuracy')
ax.legend()
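
Because a checkpoint is saved after every epoch, the same history dictionary can also be used to pick which file to load. The sketch below is one possible approach; best_epoch is a hypothetical helper, and the numbers in example_history are made-up placeholders, not results from this run. It assumes the {epoch} placeholder in the checkpoint name counts from 1, in line with the "Epoch 50/50" log.

```python
def best_epoch(history_dict, metric="val_accuracy"):
    """Return (1-based epoch number, value) of the highest metric value."""
    values = history_dict[metric]
    index = max(range(len(values)), key=values.__getitem__)
    return index + 1, values[index]


# Hypothetical values standing in for history.history; not results from this run.
example_history = {
    "val_accuracy": [0.60, 0.81, 0.93, 0.91],
    "val_loss": [0.72, 0.45, 0.20, 0.24],
}

epoch, value = best_epoch(example_history)
checkpoint = "save_at_{epoch}.h5".format(epoch=epoch)
print(epoch, value, checkpoint)  # 3 0.93 save_at_3.h5
```

With the real training run, best_epoch(history.history) would be called instead of using example_history.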



Testing

img = keras.preprocessing.image.load_img(
    "PetImages/Cat/6779.jpg", target_size=image_size
)
plt.imshow(img)
img_array = keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)  # Create batch axis

predictions = model.predict(img_array)
# predictions has shape (1, 1); the sigmoid output is the probability of class 1 ("Dog")
score = float(predictions[0][0])
print(
    "This image is %.2f percent cat and %.2f percent dog."
    % (100 * (1 - score), 100 * score)
)

Result:

This image is 96.00 percent cat and 4.00 percent dog.
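
The model ends in a single sigmoid unit, so its output is read as the probability of class 1. Since image_dataset_from_directory assigns labels in alphabetical folder order, that means Cat is 0 and Dog is 1. The following sketch shows the mapping in isolation; describe_score and predicted_label are hypothetical helpers, and the 0.5 threshold is an assumption rather than something the original code sets.

```python
CLASS_NAMES = ["Cat", "Dog"]  # alphabetical folder order: Cat -> 0, Dog -> 1


def describe_score(score):
    """Format a sigmoid output (probability of class 1) as cat/dog percentages."""
    p_dog = float(score)
    p_cat = 1.0 - p_dog
    return "This image is %.2f percent cat and %.2f percent dog." % (
        100 * p_cat,
        100 * p_dog,
    )


def predicted_label(score, threshold=0.5):
    """Pick the class name by thresholding the probability of class 1."""
    return CLASS_NAMES[int(float(score) >= threshold)]


print(describe_score(0.04))
print(predicted_label(0.04))  # Cat
```

A score of 0.04 reproduces the message shown above.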



Complete Code

import os

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# =============================================================================
# Load Dataset
# =============================================================================
# Notebook shell commands to download and unpack the data (run once):
# !curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip
# !unzip -q kagglecatsanddogs_3367a.zip
# !ls
# !ls PetImages

# Delete corrupted images (files without a JFIF marker in their first bytes)
num_skipped = 0
for folder_name in ("Cat", "Dog"):
    folder_path = os.path.join("PetImages", folder_name)
    for fname in os.listdir(folder_path):
        fpath = os.path.join(folder_path, fname)
        # The context manager closes the file even if reading fails
        with open(fpath, "rb") as fobj:
            is_jfif = tf.compat.as_bytes("JFIF") in fobj.peek(10)

        if not is_jfif:
            num_skipped += 1
            # Delete corrupted image
            os.remove(fpath)

print("Deleted %d images" % num_skipped)

# =============================================================================
# Training and Validation Dataset
# =============================================================================
image_size = (180, 180)
batch_size = 32

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "PetImages",
    validation_split=0.2,
    subset="training",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "PetImages",
    validation_split=0.2,
    subset="validation",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)

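# =============================================================================
# Data Augmentation
# =============================================================================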
data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
    ]
)

# =============================================================================
# Model
# =============================================================================
def make_model(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)
    
    # Image augmentation block
    x = data_augmentation(inputs) # flip and rotation

    # Entry block
    x = layers.Rescaling(1.0 / 255)(x)
    x = layers.Conv2D(32, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.Conv2D(64, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    for size in [128, 256, 512, 728]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(size, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.GlobalAveragePooling2D()(x)
    if num_classes == 2:
        activation = "sigmoid"
        units = 1
    else:
        activation = "softmax"
        units = num_classes

    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(units, activation=activation)(x)
    
    return keras.Model(inputs, outputs)

model = make_model(input_shape=image_size + (3,), num_classes=2)

# =============================================================================
# Train
# =============================================================================
epochs = 50

callbacks = [
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.h5"),
]
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss="binary_crossentropy",
    metrics=["accuracy"],
)
history = model.fit(
    train_ds, epochs=epochs, callbacks=callbacks, validation_data=val_ds,
)

# =============================================================================
# Test
# =============================================================================
img = keras.preprocessing.image.load_img(
    "PetImages/Cat/6779.jpg", target_size=image_size
)
plt.imshow(img)
img_array = keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)  # Create batch axis

predictions = model.predict(img_array)
# predictions has shape (1, 1); the sigmoid output is the probability of class 1 ("Dog")
score = float(predictions[0][0])
print(
    "This image is %.2f percent cat and %.2f percent dog."
    % (100 * (1 - score), 100 * score)
)

CUDA Configuration

With CUDA enabled, a single epoch took about 156 s:

 98/586 [====>.........................] - ETA: 2:01 - loss: 0.7277 - accuracy: 0.5842 
142/586 [======>.......................] - ETA: 1:50 - loss: 0.7136 - accuracy: 0.5900
293/586 [==============>...............] - ETA: 1:13 - loss: 0.6683 - accuracy: 0.6272
325/586 [===============>..............] - ETA: 1:05 - loss: 0.6631 - accuracy: 0.6319
464/586 [======================>.......] - ETA: 30s - loss: 0.6369 - accuracy: 0.6544 
544/586 [==========================>...] - ETA: 10s - loss: 0.6276 - accuracy: 0.6612
586/586 [==============================] - ETA: 0s - loss: 0.6202 - accuracy: 0.6674 

586/586 [==============================] - 156s 264ms/step - loss: 0.6202 - accuracy: 0.6674 - val_loss: 0.7241 - val_accuracy: 0.6008
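
The 586 steps in the progress bar follow from the dataset size and batch size given earlier. The short check below is an illustrative calculation, not part of the original code; the variable names are chosen for this sketch.

```python
import math

train_files = 18738  # "Using 18738 files for training."
batch_size = 32

# The last batch is allowed to be smaller, so the step count rounds up
steps_per_epoch = math.ceil(train_files / batch_size)
last_batch = train_files - (steps_per_epoch - 1) * batch_size
print(steps_per_epoch, last_batch)  # 586 18

# Rough time per step from the 156 s epoch time
print(round(156 / steps_per_epoch * 1000))  # 266
```

The estimate of about 266 ms per step is close to the 264ms/step reported in the log.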
