Fine-tuning Pretrained InceptionV3 model
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using a pre-trained model with transfer learning: InceptionV3"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Install the fruit classification data\n",
"\n",
"If you have not already done so, you can use T4 to access the subset of fruit data from Google Open Images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#import Comet before all other imports \n",
"from comet_ml import Experiment \n",
"import t4"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# install the open fruit data to the directory of your choice (set via dest = \"DIRECTORY_HERE\")\n",
"t4.Package.install(\n",
" \"quilt/open_fruit\", \n",
" registry=\"s3://quilt-example\", \n",
" dest=\"./data\") # set your own directory here "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create the Experiment object"
]
},
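{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you do not pass the API key directly, the `Experiment` constructor can pick it up from the `COMET_API_KEY` environment variable. The cell below is a minimal sketch of setting that variable from inside the notebook; `YOUR_API_KEY` is a placeholder for your own key."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: set your Comet API key as an environment variable\n",
"# 'YOUR_API_KEY' is a placeholder - replace it with your own key\n",
"import os\n",
"os.environ['COMET_API_KEY'] = 'YOUR_API_KEY'"
]
},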
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"COMET INFO: old comet version (1.0.45) detected. current: 1.0.46 please update your comet lib with command: `pip install --no-cache-dir --upgrade comet_ml`\n",
"COMET INFO: Experiment is live on comet.ml https://www.comet.ml/ceceshao1/comet-quilt-example/a5ec4a7953a54bcd817e3b7aa9c11a48\n",
"\n"
]
}
],
"source": [
"# Define your Comet Experiment object (you'll need to pass in your API key - here we set it as an environment variable)\n",
"experiment = Experiment(project_name=\"comet-quilt-example\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Confirm GPU usage and Imports "
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"# confirm Keras is running on GPU \n",
"# If you are not using GPUs, skip this cell \n",
"import tensorflow as tf\n",
"sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# list the GPUs visible to the Keras backend (if you are not using GPUs, skip this cell)\n",
"from keras import backend as K\n",
"K.tensorflow_backend._get_available_gpus()"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"from keras import applications\n",
"from keras.applications.inception_v3 import InceptionV3\n",
"from keras.preprocessing import image\n",
"from keras.models import Model, Sequential\n",
"from keras.layers import Dense, GlobalAveragePooling2D, Dropout, Flatten\n",
"from keras import backend as K\n",
"from keras.preprocessing.image import ImageDataGenerator\n",
"from keras import optimizers\n",
"from keras.callbacks import EarlyStopping"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data Pre-processing and Parameters"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# dimensions of our images\n",
"img_width, img_height = 150, 150\n",
"\n",
"# set parameters\n",
"batch_size = 16\n",
"num_classes = 16\n",
"epochs = 50\n",
"activation = 'relu'\n",
"min_delta=0\n",
"patience=4\n",
"dropout=0.2\n",
"lr=0.0001\n",
"\n",
"train_samples = 27593\n",
"validation_samples = 6889"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"params={'batch_size':batch_size,\n",
" 'num_classes':num_classes,\n",
" 'epochs':epochs,\n",
" 'min_delta':min_delta,\n",
" 'patience':patience,\n",
" 'learning_rate':lr,\n",
" 'dropout':dropout\n",
"}\n",
"\n",
"experiment.log_parameters(params) #log these parameters as a dictionary to Comet. Adjust parameters in the cell above"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 27593 images belonging to 16 classes.\n",
"Found 6889 images belonging to 16 classes.\n"
]
}
],
"source": [
"from keras.applications.inception_v3 import preprocess_input\n",
"\n",
"train_datagen = ImageDataGenerator(\n",
" preprocessing_function=preprocess_input,\n",
" rotation_range=40,\n",
" width_shift_range=0.2,\n",
" height_shift_range=0.2,\n",
" shear_range=0.2,\n",
" zoom_range=0.2,\n",
" horizontal_flip=True,\n",
" fill_mode='nearest',\n",
" validation_split=0.2 #set the validation split \n",
")\n",
"\n",
"test_datagen = ImageDataGenerator(\n",
" rescale=1/255\n",
")\n",
"\n",
"train_generator = train_datagen.flow_from_directory(\n",
" './data/quilt/open_fruit/images_cropped',\n",
" target_size=(150, 150),\n",
" shuffle=True,\n",
" seed=20,\n",
" batch_size = batch_size,\n",
" class_mode='categorical',\n",
" subset=\"training\"\n",
")\n",
"\n",
"validation_generator = train_datagen.flow_from_directory(\n",
" './data/quilt/open_fruit/images_cropped',\n",
" target_size=(150, 150),\n",
" seed=20,\n",
" batch_size=batch_size,\n",
" class_mode='categorical',\n",
" subset = \"validation\"\n",
")"
]
},
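{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check, the cell below inspects the class-to-index mapping that `flow_from_directory` inferred from the sub-directory names under `images_cropped`; the exact fruit names depend on the folders in the installed data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sanity check: inspect the label mapping inferred from the sub-directory names\n",
"print(len(train_generator.class_indices), 'classes')\n",
"print(train_generator.class_indices)"
]
},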
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a hash for your training data - https://www.comet.ml/docs/python-sdk/Experiment/#experimentlog_dataset_hash\n",
"experiment.log_dataset_hash(train_generator)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Building the base model\n",
"\n",
"We are using the InceptionV3 model pre-trained on ImageNet weights. Since we are going to adapt the classification head to our fruit classes and fine-tune the model, we set `include_top=False`."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /Users/ceceliashao/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Colocations handled automatically by placer.\n"
]
}
],
"source": [
"# create the base pre-trained model\n",
"base_model = InceptionV3(weights='imagenet', include_top=False,input_shape=(150,150))\n",
"\n",
"# add a global spatial average pooling layer\n",
"x = base_model.output\n",
"x = GlobalAveragePooling2D()(x)\n",
"# let's add a fully-connected layer\n",
"x = Dense(1024, activation='relu')(x)\n",
"# and a logistic layer -- we have 16 classes for the fruits \n",
"predictions = Dense(16, activation='softmax')(x)\n",
"\n",
"# this is the model we will train\n",
"model = Model(inputs=base_model.input, outputs=predictions)\n",
"\n",
"# first: train only the top layers (which were randomly initialized)\n",
"# i.e. freeze all convolutional InceptionV3 layers\n",
"for layer in base_model.layers:\n",
" layer.trainable = False\n",
"\n",
"# compile the model (should be done *after* setting layers to non-trainable)\n",
"model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for i,layer in enumerate(base_model.layers):\n",
" print(i,layer.name)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# count the files under the cropped-image directory (used to size steps_per_epoch below)\n",
"import pathlib\n",
"sample_size = len(list(pathlib.Path('./data/quilt/open_fruit/images_cropped').rglob('*')))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# train the model on the new data for a few epochs\n",
"model.fit_generator(\n",
" train_generator,\n",
" steps_per_epoch=sample_size // batch_size,\n",
" epochs=10,\n",
" validation_data=validation_generator,\n",
" validation_steps=validation_samples // batch_size,\n",
" callbacks=[EarlyStopping(monitor='val_loss', min_delta=min_delta, patience=patience)]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"At this point, the top layers are well trained and we can start fine-tuning convolutional layers from InceptionV3. We will freeze the bottom N layers and train the remaining top layers."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 input_1\n",
"1 conv2d_1\n",
"2 batch_normalization_1\n",
"3 activation_1\n",
"4 conv2d_2\n",
"5 batch_normalization_2\n",
"6 activation_2\n",
"7 conv2d_3\n",
"8 batch_normalization_3\n",
"9 activation_3\n",
"10 max_pooling2d_1\n",
"11 conv2d_4\n",
"12 batch_normalization_4\n",
"13 activation_4\n",
"14 conv2d_5\n",
"15 batch_normalization_5\n",
"16 activation_5\n",
"17 max_pooling2d_2\n",
"18 conv2d_9\n",
"19 batch_normalization_9\n",
"20 activation_9\n",
"21 conv2d_7\n",
"22 conv2d_10\n",
"23 batch_normalization_7\n",
"24 batch_normalization_10\n",
"25 activation_7\n",
"26 activation_10\n",
"27 average_pooling2d_1\n",
"28 conv2d_6\n",
"29 conv2d_8\n",
"30 conv2d_11\n",
"31 conv2d_12\n",
"32 batch_normalization_6\n",
"33 batch_normalization_8\n",
"34 batch_normalization_11\n",
"35 batch_normalization_12\n",
"36 activation_6\n",
"37 activation_8\n",
"38 activation_11\n",
"39 activation_12\n",
"40 mixed0\n",
"41 conv2d_16\n",
"42 batch_normalization_16\n",
"43 activation_16\n",
"44 conv2d_14\n",
"45 conv2d_17\n",
"46 batch_normalization_14\n",
"47 batch_normalization_17\n",
"48 activation_14\n",
"49 activation_17\n",
"50 average_pooling2d_2\n",
"51 conv2d_13\n",
"52 conv2d_15\n",
"53 conv2d_18\n",
"54 conv2d_19\n",
"55 batch_normalization_13\n",
"56 batch_normalization_15\n",
"57 batch_normalization_18\n",
"58 batch_normalization_19\n",
"59 activation_13\n",
"60 activation_15\n",
"61 activation_18\n",
"62 activation_19\n",
"63 mixed1\n",
"64 conv2d_23\n",
"65 batch_normalization_23\n",
"66 activation_23\n",
"67 conv2d_21\n",
"68 conv2d_24\n",
"69 batch_normalization_21\n",
"70 batch_normalization_24\n",
"71 activation_21\n",
"72 activation_24\n",
"73 average_pooling2d_3\n",
"74 conv2d_20\n",
"75 conv2d_22\n",
"76 conv2d_25\n",
"77 conv2d_26\n",
"78 batch_normalization_20\n",
"79 batch_normalization_22\n",
"80 batch_normalization_25\n",
"81 batch_normalization_26\n",
"82 activation_20\n",
"83 activation_22\n",
"84 activation_25\n",
"85 activation_26\n",
"86 mixed2\n",
"87 conv2d_28\n",
"88 batch_normalization_28\n",
"89 activation_28\n",
"90 conv2d_29\n",
"91 batch_normalization_29\n",
"92 activation_29\n",
"93 conv2d_27\n",
"94 conv2d_30\n",
"95 batch_normalization_27\n",
"96 batch_normalization_30\n",
"97 activation_27\n",
"98 activation_30\n",
"99 max_pooling2d_3\n",
"100 mixed3\n",
"101 conv2d_35\n",
"102 batch_normalization_35\n",
"103 activation_35\n",
"104 conv2d_36\n",
"105 batch_normalization_36\n",
"106 activation_36\n",
"107 conv2d_32\n",
"108 conv2d_37\n",
"109 batch_normalization_32\n",
"110 batch_normalization_37\n",
"111 activation_32\n",
"112 activation_37\n",
"113 conv2d_33\n",
"114 conv2d_38\n",
"115 batch_normalization_33\n",
"116 batch_normalization_38\n",
"117 activation_33\n",
"118 activation_38\n",
"119 average_pooling2d_4\n",
"120 conv2d_31\n",
"121 conv2d_34\n",
"122 conv2d_39\n",
"123 conv2d_40\n",
"124 batch_normalization_31\n",
"125 batch_normalization_34\n",
"126 batch_normalization_39\n",
"127 batch_normalization_40\n",
"128 activation_31\n",
"129 activation_34\n",
"130 activation_39\n",
"131 activation_40\n",
"132 mixed4\n",
"133 conv2d_45\n",
"134 batch_normalization_45\n",
"135 activation_45\n",
"136 conv2d_46\n",
"137 batch_normalization_46\n",
"138 activation_46\n",
"139 conv2d_42\n",
"140 conv2d_47\n",
"141 batch_normalization_42\n",
"142 batch_normalization_47\n",
"143 activation_42\n",
"144 activation_47\n",
"145 conv2d_43\n",
"146 conv2d_48\n",
"147 batch_normalization_43\n",
"148 batch_normalization_48\n",
"149 activation_43\n",
"150 activation_48\n",
"151 average_pooling2d_5\n",
"152 conv2d_41\n",
"153 conv2d_44\n",
"154 conv2d_49\n",
"155 conv2d_50\n",
"156 batch_normalization_41\n",
"157 batch_normalization_44\n",
"158 batch_normalization_49\n",
"159 batch_normalization_50\n",
"160 activation_41\n",
"161 activation_44\n",
"162 activation_49\n",
"163 activation_50\n",
"164 mixed5\n",
"165 conv2d_55\n",
"166 batch_normalization_55\n",
"167 activation_55\n",
"168 conv2d_56\n",
"169 batch_normalization_56\n",
"170 activation_56\n",
"171 conv2d_52\n",
"172 conv2d_57\n",
"173 batch_normalization_52\n",
"174 batch_normalization_57\n",
"175 activation_52\n",
"176 activation_57\n",
"177 conv2d_53\n",
"178 conv2d_58\n",
"179 batch_normalization_53\n",
"180 batch_normalization_58\n",
"181 activation_53\n",
"182 activation_58\n",
"183 average_pooling2d_6\n",
"184 conv2d_51\n",
"185 conv2d_54\n",
"186 conv2d_59\n",
"187 conv2d_60\n",
"188 batch_normalization_51\n",
"189 batch_normalization_54\n",
"190 batch_normalization_59\n",
"191 batch_normalization_60\n",
"192 activation_51\n",
"193 activation_54\n",
"194 activation_59\n",
"195 activation_60\n",
"196 mixed6\n",
"197 conv2d_65\n",
"198 batch_normalization_65\n",
"199 activation_65\n",
"200 conv2d_66\n",
"201 batch_normalization_66\n",
"202 activation_66\n",
"203 conv2d_62\n",
"204 conv2d_67\n",
"205 batch_normalization_62\n",
"206 batch_normalization_67\n",
"207 activation_62\n",
"208 activation_67\n",
"209 conv2d_63\n",
"210 conv2d_68\n",
"211 batch_normalization_63\n",
"212 batch_normalization_68\n",
"213 activation_63\n",
"214 activation_68\n",
"215 average_pooling2d_7\n",
"216 conv2d_61\n",
"217 conv2d_64\n",
"218 conv2d_69\n",
"219 conv2d_70\n",
"220 batch_normalization_61\n",
"221 batch_normalization_64\n",
"222 batch_normalization_69\n",
"223 batch_normalization_70\n",
"224 activation_61\n",
"225 activation_64\n",
"226 activation_69\n",
"227 activation_70\n",
"228 mixed7\n",
"229 conv2d_73\n",
"230 batch_normalization_73\n",
"231 activation_73\n",
"232 conv2d_74\n",
"233 batch_normalization_74\n",
"234 activation_74\n",
"235 conv2d_71\n",
"236 conv2d_75\n",
"237 batch_normalization_71\n",
"238 batch_normalization_75\n",
"239 activation_71\n",
"240 activation_75\n",
"241 conv2d_72\n",
"242 conv2d_76\n",
"243 batch_normalization_72\n",
"244 batch_normalization_76\n",
"245 activation_72\n",
"246 activation_76\n",
"247 max_pooling2d_4\n",
"248 mixed8\n",
"249 conv2d_81\n",
"250 batch_normalization_81\n",
"251 activation_81\n",
"252 conv2d_78\n",
"253 conv2d_82\n",
"254 batch_normalization_78\n",
"255 batch_normalization_82\n",
"256 activation_78\n",
"257 activation_82\n",
"258 conv2d_79\n",
"259 conv2d_80\n",
"260 conv2d_83\n",
"261 conv2d_84\n",
"262 average_pooling2d_8\n",
"263 conv2d_77\n",
"264 batch_normalization_79\n",
"265 batch_normalization_80\n",
"266 batch_normalization_83\n",
"267 batch_normalization_84\n",
"268 conv2d_85\n",
"269 batch_normalization_77\n",
"270 activation_79\n",
"271 activation_80\n",
"272 activation_83\n",
"273 activation_84\n",
"274 batch_normalization_85\n",
"275 activation_77\n",
"276 mixed9_0\n",
"277 concatenate_1\n",
"278 activation_85\n",
"279 mixed9\n",
"280 conv2d_90\n",
"281 batch_normalization_90\n",
"282 activation_90\n",
"283 conv2d_87\n",
"284 conv2d_91\n",
"285 batch_normalization_87\n",
"286 batch_normalization_91\n",
"287 activation_87\n",
"288 activation_91\n",
"289 conv2d_88\n",
"290 conv2d_89\n",
"291 conv2d_92\n",
"292 conv2d_93\n",
"293 average_pooling2d_9\n",
"294 conv2d_86\n",
"295 batch_normalization_88\n",
"296 batch_normalization_89\n",
"297 batch_normalization_92\n",
"298 batch_normalization_93\n",
"299 conv2d_94\n",
"300 batch_normalization_86\n",
"301 activation_88\n",
"302 activation_89\n",
"303 activation_92\n",
"304 activation_93\n",
"305 batch_normalization_94\n",
"306 activation_86\n",
"307 mixed9_1\n",
"308 concatenate_2\n",
"309 activation_94\n",
"310 mixed10\n"
]
}
],
"source": [
"# let's visualize layer names and layer indices to see how many layers\n",
"# we should freeze:\n",
"for i, layer in enumerate(base_model.layers):\n",
" print(i, layer.name)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 input_1\n",
"1 conv2d_1\n",
"2 batch_normalization_1\n",
"3 activation_1\n",
"4 conv2d_2\n",
"5 batch_normalization_2\n",
"6 activation_2\n",
"7 conv2d_3\n",
"8 batch_normalization_3\n",
"9 activation_3\n",
"10 max_pooling2d_1\n",
"11 conv2d_4\n",
"12 batch_normalization_4\n",
"13 activation_4\n",
"14 conv2d_5\n",
"15 batch_normalization_5\n",
"16 activation_5\n",
"17 max_pooling2d_2\n",
"18 conv2d_9\n",
"19 batch_normalization_9\n",
"20 activation_9\n",
"21 conv2d_7\n",
"22 conv2d_10\n",
"23 batch_normalization_7\n",
"24 batch_normalization_10\n",
"25 activation_7\n",
"26 activation_10\n",
"27 average_pooling2d_1\n",
"28 conv2d_6\n",
"29 conv2d_8\n",
"30 conv2d_11\n",
"31 conv2d_12\n",
"32 batch_normalization_6\n",
"33 batch_normalization_8\n",
"34 batch_normalization_11\n",
"35 batch_normalization_12\n",
"36 activation_6\n",
"37 activation_8\n",
"38 activation_11\n",
"39 activation_12\n",
"40 mixed0\n",
"41 conv2d_16\n",
"42 batch_normalization_16\n",
"43 activation_16\n",
"44 conv2d_14\n",
"45 conv2d_17\n",
"46 batch_normalization_14\n",
"47 batch_normalization_17\n",
"48 activation_14\n",
"49 activation_17\n",
"50 average_pooling2d_2\n",
"51 conv2d_13\n",
"52 conv2d_15\n",
"53 conv2d_18\n",
"54 conv2d_19\n",
"55 batch_normalization_13\n",
"56 batch_normalization_15\n",
"57 batch_normalization_18\n",
"58 batch_normalization_19\n",
"59 activation_13\n",
"60 activation_15\n",
"61 activation_18\n",
"62 activation_19\n",
"63 mixed1\n",
"64 conv2d_23\n",
"65 batch_normalization_23\n",
"66 activation_23\n",
"67 conv2d_21\n",
"68 conv2d_24\n",
"69 batch_normalization_21\n",
"70 batch_normalization_24\n",
"71 activation_21\n",
"72 activation_24\n",
"73 average_pooling2d_3\n",
"74 conv2d_20\n",
"75 conv2d_22\n",
"76 conv2d_25\n",
"77 conv2d_26\n",
"78 batch_normalization_20\n",
"79 batch_normalization_22\n",
"80 batch_normalization_25\n",
"81 batch_normalization_26\n",
"82 activation_20\n",
"83 activation_22\n",
"84 activation_25\n",
"85 activation_26\n",
"86 mixed2\n",
"87 conv2d_28\n",
"88 batch_normalization_28\n",
"89 activation_28\n",
"90 conv2d_29\n",
"91 batch_normalization_29\n",
"92 activation_29\n",
"93 conv2d_27\n",
"94 conv2d_30\n",
"95 batch_normalization_27\n",
"96 batch_normalization_30\n",
"97 activation_27\n",
"98 activation_30\n",
"99 max_pooling2d_3\n",
"100 mixed3\n",
"101 conv2d_35\n",
"102 batch_normalization_35\n",
"103 activation_35\n",
"104 conv2d_36\n",
"105 batch_normalization_36\n",
"106 activation_36\n",
"107 conv2d_32\n",
"108 conv2d_37\n",
"109 batch_normalization_32\n",
"110 batch_normalization_37\n",
"111 activation_32\n",
"112 activation_37\n",
"113 conv2d_33\n",
"114 conv2d_38\n",
"115 batch_normalization_33\n",
"116 batch_normalization_38\n",
"117 activation_33\n",
"118 activation_38\n",
"119 average_pooling2d_4\n",
"120 conv2d_31\n",
"121 conv2d_34\n",
"122 conv2d_39\n",
"123 conv2d_40\n",
"124 batch_normalization_31\n",
"125 batch_normalization_34\n",
"126 batch_normalization_39\n",
"127 batch_normalization_40\n",
"128 activation_31\n",
"129 activation_34\n",
"130 activation_39\n",
"131 activation_40\n",
"132 mixed4\n",
"133 conv2d_45\n",
"134 batch_normalization_45\n",
"135 activation_45\n",
"136 conv2d_46\n",
"137 batch_normalization_46\n",
"138 activation_46\n",
"139 conv2d_42\n",
"140 conv2d_47\n",
"141 batch_normalization_42\n",
"142 batch_normalization_47\n",
"143 activation_42\n",
"144 activation_47\n",
"145 conv2d_43\n",
"146 conv2d_48\n",
"147 batch_normalization_43\n",
"148 batch_normalization_48\n",
"149 activation_43\n",
"150 activation_48\n",
"151 average_pooling2d_5\n",
"152 conv2d_41\n",
"153 conv2d_44\n",
"154 conv2d_49\n",
"155 conv2d_50\n",
"156 batch_normalization_41\n",
"157 batch_normalization_44\n",
"158 batch_normalization_49\n",
"159 batch_normalization_50\n",
"160 activation_41\n",
"161 activation_44\n",
"162 activation_49\n",
"163 activation_50\n",
"164 mixed5\n",
"165 conv2d_55\n",
"166 batch_normalization_55\n",
"167 activation_55\n",
"168 conv2d_56\n",
"169 batch_normalization_56\n",
"170 activation_56\n",
"171 conv2d_52\n",
"172 conv2d_57\n",
"173 batch_normalization_52\n",
"174 batch_normalization_57\n",
"175 activation_52\n",
"176 activation_57\n",
"177 conv2d_53\n",
"178 conv2d_58\n",
"179 batch_normalization_53\n",
"180 batch_normalization_58\n",
"181 activation_53\n",
"182 activation_58\n",
"183 average_pooling2d_6\n",
"184 conv2d_51\n",
"185 conv2d_54\n",
"186 conv2d_59\n",
"187 conv2d_60\n",
"188 batch_normalization_51\n",
"189 batch_normalization_54\n",
"190 batch_normalization_59\n",
"191 batch_normalization_60\n",
"192 activation_51\n",
"193 activation_54\n",
"194 activation_59\n",
"195 activation_60\n",
"196 mixed6\n",
"197 conv2d_65\n",
"198 batch_normalization_65\n",
"199 activation_65\n",
"200 conv2d_66\n",
"201 batch_normalization_66\n",
"202 activation_66\n",
"203 conv2d_62\n",
"204 conv2d_67\n",
"205 batch_normalization_62\n",
"206 batch_normalization_67\n",
"207 activation_62\n",
"208 activation_67\n",
"209 conv2d_63\n",
"210 conv2d_68\n",
"211 batch_normalization_63\n",
"212 batch_normalization_68\n",
"213 activation_63\n",
"214 activation_68\n",
"215 average_pooling2d_7\n",
"216 conv2d_61\n",
"217 conv2d_64\n",
"218 conv2d_69\n",
"219 conv2d_70\n",
"220 batch_normalization_61\n",
"221 batch_normalization_64\n",
"222 batch_normalization_69\n",
"223 batch_normalization_70\n",
"224 activation_61\n",
"225 activation_64\n",
"226 activation_69\n",
"227 activation_70\n",
"228 mixed7\n",
"229 conv2d_73\n",
"230 batch_normalization_73\n",
"231 activation_73\n",
"232 conv2d_74\n",
"233 batch_normalization_74\n",
"234 activation_74\n",
"235 conv2d_71\n",
"236 conv2d_75\n",
"237 batch_normalization_71\n",
"238 batch_normalization_75\n",
"239 activation_71\n",
"240 activation_75\n",
"241 conv2d_72\n",
"242 conv2d_76\n",
"243 batch_normalization_72\n",
"244 batch_normalization_76\n",
"245 activation_72\n",
"246 activation_76\n",
"247 max_pooling2d_4\n",
"248 mixed8\n",
"249 conv2d_81\n",
"250 batch_normalization_81\n",
"251 activation_81\n",
"252 conv2d_78\n",
"253 conv2d_82\n",
"254 batch_normalization_78\n",
"255 batch_normalization_82\n",
"256 activation_78\n",
"257 activation_82\n",
"258 conv2d_79\n",
"259 conv2d_80\n",
"260 conv2d_83\n",
"261 conv2d_84\n",
"262 average_pooling2d_8\n",
"263 conv2d_77\n",
"264 batch_normalization_79\n",
"265 batch_normalization_80\n",
"266 batch_normalization_83\n",
"267 batch_normalization_84\n",
"268 conv2d_85\n",
"269 batch_normalization_77\n",
"270 activation_79\n",
"271 activation_80\n",
"272 activation_83\n",
"273 activation_84\n",
"274 batch_normalization_85\n",
"275 activation_77\n",
"276 mixed9_0\n",
"277 concatenate_1\n",
"278 activation_85\n",
"279 mixed9\n",
"280 conv2d_90\n",
"281 batch_normalization_90\n",
"282 activation_90\n",
"283 conv2d_87\n",
"284 conv2d_91\n",
"285 batch_normalization_87\n",
"286 batch_normalization_91\n",
"287 activation_87\n",
"288 activation_91\n",
"289 conv2d_88\n",
"290 conv2d_89\n",
"291 conv2d_92\n",
"292 conv2d_93\n",
"293 average_pooling2d_9\n",
"294 conv2d_86\n",
"295 batch_normalization_88\n",
"296 batch_normalization_89\n",
"297 batch_normalization_92\n",
"298 batch_normalization_93\n",
"299 conv2d_94\n",
"300 batch_normalization_86\n",
"301 activation_88\n",
"302 activation_89\n",
"303 activation_92\n",
"304 activation_93\n",
"305 batch_normalization_94\n",
"306 activation_86\n",
"307 mixed9_1\n",
"308 concatenate_2\n",
"309 activation_94\n",
"310 mixed10\n",
"311 global_average_pooling2d_1\n",
"312 dense_1\n",
"313 dense_2\n"
]
}
],
"source": [
"for i, layer in enumerate(model.layers):\n",
" print(i, layer.name)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"# freeze the first 310 layers and unfreeze the rest:\n",
"for layer in model.layers[:310]:\n",
" layer.trainable = False\n",
"for layer in model.layers[310:]:\n",
" layer.trainable = True"
]
},
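{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, confirm how many layers are now trainable before recompiling; the cell below is a simple check based on each layer's `trainable` flag."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# count trainable vs. frozen layers after adjusting the trainable flags\n",
"trainable_count = sum(layer.trainable for layer in model.layers)\n",
"print(trainable_count, 'trainable layers out of', len(model.layers))"
]
},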
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Visualize training process in real-time\n",
"\n",
"Using the `experiment.display()` method, you can view live visualizations of the model's training progress directly in the notebook. Once you run the `model.fit_generator()` cell further below, training metrics will begin reporting and the visualizations in the **Charts** tab will render."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"experiment.display()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# we need to recompile the model for these modifications to take effect\n",
"# we use SGD with a low learning rate so we don't lose the value of the pre-trained model\n",
"from keras.optimizers import SGD\n",
"model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy',metrics=['accuracy'])\n",
"\n",
"# we train our model again, this time fine-tuning the layers we unfroze above\n",
"# alongside the top Dense layers\n",
"model.fit_generator(\n",
" train_generator,\n",
" steps_per_epoch=train_samples // batch_size,\n",
" epochs=epochs,\n",
" validation_data=validation_generator,\n",
" validation_steps=validation_samples// batch_size,\n",
" callbacks=[EarlyStopping(monitor='val_loss', min_delta=min_delta, patience=patience)]\n",
")"
]
},
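{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once fine-tuning finishes, you can evaluate the model on the validation generator and log the results to Comet. The sketch below assumes the generators and parameters defined earlier are still in scope; `evaluate_generator` returns the loss followed by the metrics passed to `compile`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# evaluate the fine-tuned model on the validation split and log the results to Comet\n",
"val_loss, val_acc = model.evaluate_generator(\n",
" validation_generator,\n",
" steps=validation_samples // batch_size)\n",
"\n",
"experiment.log_metric('val_loss', val_loss)\n",
"experiment.log_metric('val_acc', val_acc)"
]
},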
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Save model weights"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'web': 'https://www.comet.ml/api/asset/download?assetId=a6c75ebcfd344c06a4934b97641ea87e&experimentKey=a5ec4a7953a54bcd817e3b7aa9c11a48',\n",
" 'api': 'https://www.comet.ml/api/rest/v1/asset/get-asset?assetId=a6c75ebcfd344c06a4934b97641ea87e&experimentKey=a5ec4a7953a54bcd817e3b7aa9c11a48'}"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#save locally\n",
"model.save_weights('inceptionv3_tuned.h5') \n",
"\n",
"#save to Comet Asset Tab\n",
"# you can retrieve these weights later via the REST API \n",
"experiment.log_asset(file_path='./inceptionv3_tuned.h5', file_name='inceptionv3_tuned.h5') "
]
},
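{
"cell_type": "markdown",
"metadata": {},
"source": [
"To reuse the fine-tuned weights for inference, load the saved file back into the model (in a fresh session you would rebuild the same architecture first). The cell below is a minimal sketch; `path/to/image.jpg` is a placeholder for an image of your own."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# minimal inference sketch: reload the fine-tuned weights and classify a single image\n",
"import numpy as np\n",
"\n",
"model.load_weights('inceptionv3_tuned.h5')\n",
"\n",
"# 'path/to/image.jpg' is a placeholder - point this at an image of your own\n",
"img = image.load_img('path/to/image.jpg', target_size=(img_width, img_height))\n",
"x = preprocess_input(np.expand_dims(image.img_to_array(img), axis=0))\n",
"probs = model.predict(x)[0]\n",
"\n",
"class_names = {v: k for k, v in train_generator.class_indices.items()}\n",
"pred_index = int(np.argmax(probs))\n",
"print(class_names[pred_index], probs[pred_index])"
]
},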
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"COMET INFO: ----------------------------\n",
"COMET INFO: Comet.ml Experiment Summary:\n",
"COMET INFO: Metrics:\n",
"COMET INFO: acc: 0.7464916415897019\n",
"COMET INFO: batch: 1720\n",
"COMET INFO: epoch_end: 49\n",
"COMET INFO: loss: 0.7562066181474528\n",
"COMET INFO: size: 16\n",
"COMET INFO: step: 86200\n",
"COMET INFO: Other:\n",
"COMET INFO: trainable_params: 23917360\n",
"COMET INFO: Uploads:\n",
"COMET INFO: assets: 1\n",
"COMET INFO: figures: 0\n",
"COMET INFO: images: 0\n",
"COMET INFO: ----------------------------\n",
"COMET INFO: Uploading stats to Comet before program termination (may take several seconds)\n",
"COMET INFO: Still uploading\n"
]
}
],
"source": [
"# signal the end of the experiment \n",
"# retrieve summary statistics around metrics, assets, etc...\n",
"experiment.end()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Accessing the full model code from Comet\n",
"\n",
"If you run this experiment from a git directory, you or your teammates can retrieve this notebook using the `Reproduce` button from the Comet experiment view.\n",
"\n",
"When you press the `Reproduce` button (see screenshot below), you will see a prompt to download the notebook. The notebook will contain the cells in the order they were executed.\n",
"\n",
"![](../images/comet-reproduce-notebook.png)\n",
"\n",
"\n",
"**Retrieving a script:**\n",
"If you train a model with a script instead, you will be able to see bash commands to retrieve the code (including untracked changes).\n",
"![](../images/comet reproduce script.png)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}