Land use and land cover classification on the EuroSAT dataset using the fastai library
I am currently following the fastai Practical Deep Learning for Coders course (https://course.fast.ai). As an attempt to use what I learned in the first two lessons, I tried to fit a convolutional neural net on EuroSAT images to classify land usage.
The EuroSAT Sentinel-2 dataset comprises 27,000 images across 10 categories of land use. The full dataset can be found here: https://github.com/phelber/eurosat. A more detailed description can be found here: https://arxiv.org/pdf/1709.00029.pdf.
The first step is to download the images and look into how they are organized. The fastai library provides various utility functions, and download_data is one of them. Simple code looks like below:
path = download_data("http://madm.dfki.de/files/sentinel/EuroSAT.zip", ext='')
The dataset is organized in folders, and each folder name is the class label. There are no separate train and test folders, so that is what we need to derive.
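As a quick sanity check, here is a minimal sketch (my addition, not from the original notebook) of how to extract the archive and list the class folders. Note that download_data only fetches the file; the extraction step and the '2750' folder name the zip unpacks into are assumptions on my part.

import zipfile

# download_data returns a Path to the zip; extract it next to the download
with zipfile.ZipFile(path, 'r') as z:
    z.extractall(path.parent)

path = path.parent/'2750'  # assumed name of the extracted EuroSAT folder
print(sorted(d.name for d in path.iterdir() if d.is_dir()))  # the 10 class names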
The next step is to create a DataBunch and load the data. Below is a simple snippet to do that work:
tfms = get_transforms()
data = (ImageList.from_folder(path)  # path to the downloaded data
        .split_by_rand_pct(0.2)      # split 80:20 train:validation
        .label_from_folder()         # folder name is the class label
        .transform(tfms, size=224)   # data augmentation
        .databunch())
The important parts of the snippet above are split_by_rand_pct, which splits the items randomly and holds out the given percentage as the validation dataset, and transform, which randomly applies image modifications (rotating, flipping, RGB adjustments) to the training images. This helps the model generalize better and is commonly known as data augmentation.
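To verify the DataBunch was built as intended, a small sketch like the one below (my addition, assuming the data object created above) shows a grid of augmented training images with their folder-derived labels and the split sizes:

data.show_batch(rows=3, figsize=(7, 7))        # augmented samples with labels
print(data.classes)                            # the 10 land-use classes
print(len(data.train_ds), len(data.valid_ds))  # 80:20 split sizes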
Now that the data is ready, it is time to create a CNN learner. It is worth noting that preparing the data is the most important step; I have accepted most of the defaults the library provides, but how the data should be prepared really depends on the data itself. To create the CNN I used the ResNet-50 architecture; the paper also provides a comparison of the different models the authors tried, and ResNet-50 seems to be the best fit among them at ~98.5% accuracy. The code below creates a CNN learner with a resnet50 model and also shows the error_rate as we train the learner.
learner = cnn_learner(data, models.resnet50, metrics=error_rate)
To train the model on the dataset, fastai uses the fit_one_cycle method, which implements the 1cycle policy. More about the method can be found at https://arxiv.org/pdf/1803.09820.pdf. In fastai it is used like below:
learner.fit_one_cycle(4)  # 4 is the number of epochs
Here I have taken the pretrained resnet50 model and tried to fit it on the dataset. In this approach only the weights of the last fully connected layers get updated; we rely on the lower layers as they are, since they capture more generic features. After the training, below is the result:
Wow!! ~98% accuracy with all the defaults that the library provides. Below is the confusion matrix:
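The post shows the plot itself, so the exact call is my assumption, but a confusion matrix like the one above can be produced with fastai v1's ClassificationInterpretation:

# build an interpretation object from the trained learner
interp = ClassificationInterpretation.from_learner(learner)
interp.plot_confusion_matrix(figsize=(8, 8), dpi=60)  # per-class confusion
interp.most_confused(min_val=2)  # class pairs the model mixes up most often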
Next up, let's see what happens if, instead of training only the last fully connected layers, the model is trained across all the layers. To do that, the first step is to unfreeze the model like below and train it again:
learner.unfreeze()  # make all layers trainable
learner.fit_one_cycle(4, max_lr=slice(3e-04, 3e-03))  # discriminative rates: lower for early layers, higher for later ones
And now we have the following result:
A slight increase in accuracy, which is in line with the ~98.5% from the paper. But look at how few lines of code fastai needs to train the model.
OK, now that the model is trained, it needs to be persisted somehow so applications can use it. For that:
learner.export('euro-sat-stage1.pkl')  # name with .pkl extension
The export method serializes the model, and it can be loaded back to predict against, like below:
learn = load_learner(path, 'euro-sat-stage1.pkl')  # reload the exported model
img = open_image(<path to image>)
pred_class, pred_idx, outputs = learn.predict(img)
pred_class, outputs
That's it; the model can now be used by applications. The fastai library is built upon PyTorch but abstracts it away and gives a succinct way to create, train, and predict using a CNN. I am an absolute beginner, and it really lowers the barrier to start learning deep learning and using it.
Some notes:
- The learning rate is a major parameter while tuning the model. After training the pretrained model, the snippet below will help in understanding a good learning rate:
learner.lr_find()        # run a mock training sweep across learning rates
learner.recorder.plot()  # plot loss vs. learning rate to pick a good value
- Also, a very high accuracy can mean that the model has overfitted the given training set and may perform poorly on production data. That is why choosing a random validation set, along with proper data augmentation and cleansing, plays a huge role in how well the model generalizes (see the sketch after this list).
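One quick way to check for overfitting, as a small sketch of my own on top of the trained learner, is to compare the loss curves fastai records during training; a validation loss that climbs while the training loss keeps falling is the classic warning sign:

learner.recorder.plot_losses()  # training vs. validation loss over the run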
To wrap it up, it was fun to be able to train a model. In the process I also learned to set up a VM on Google Cloud Platform. The completed Jupyter notebook can be found in the GitHub repo below. So that's it, happy coding :)