Image classification web application with Flask and Keras

  • machine learning, keras, classification
  • 30

Below we show how to spin up a web page that classifies images on demand. Users provide the URL of an image, and the application predicts its contents.

Keras

Keras is a high-level machine learning library for building and training deep neural networks; it runs on top of TensorFlow, CNTK, or Theano.

Flask

Flask is a web-application microframework for Python.

Code

import base64
import http.client
import imghdr
import os
import tempfile
import urllib.parse
import urllib.request

import numpy as np

from flask import Flask
from flask import request
from flask import jsonify
from flask import render_template
from flask_cors import CORS

from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions

# The Flask application object; the view functions below register their
# routes on it via @app.route.
app = Flask(__name__)
# Register flask_cors on the whole app. No options are passed, so the
# extension's defaults apply. NOTE(review): per flask_cors docs the defaults
# allow every origin on every route — confirm that is intended.
CORS(app)

# ResNet50 with ImageNet weights, built once at import time and shared by
# every request handled by this process.
# NOTE(review): Keras presumably downloads the weight file on first use if it
# is not already cached — confirm the deployment has network access or a
# pre-populated cache.
model = ResNet50(weights='imagenet')

@app.route('/', methods=['GET'])
def base():
    """Serve the single-page front end of the classifier.

    The page submits the ``imgurl`` form field to ``/predict`` with a jQuery
    AJAX POST and renders every key/value pair of the JSON reply as a row of
    a two-column table.  Values are inserted with ``.text()`` so nothing the
    server returns is ever interpreted as HTML, and a failed request is
    reported in the table instead of being silently dropped.  The second
    column holds per-class model scores, hence "Probability" rather than
    "Accuracy".
    """
    return '''<html>
        <body>
        <form id="img-prediction-form">
            <label>Image URL:</label>
            <input type="text" name="imgurl" />
            <input type="submit" value="Submit" name="submit">
        </form>

        <table>
            <thead>
                <tr>
                    <th>Name</th>
                    <th>Probability</th>
                </tr>
            </thead>
            <tbody id="prediction-results"></tbody>
        </table>

        <script
            src="https://code.jquery.com/jquery-3.2.1.min.js"
            integrity="sha256-hwg4gsxgFZhOsEEamdOYGBf13FyQuiTwlAQgxVSNgt4="
            crossorigin="anonymous">
        </script>
        <script type="text/javascript">
            $(document).on("submit", "#img-prediction-form", function(e) {
                e.preventDefault();

                var $results = $("#prediction-results").empty();

                $.ajax({
                    url: "/predict",
                    method: "POST",
                    data: $(this).serialize(),
                    dataType: "json",
                    success: function(result) {
                        $.each(result, function(name, value) {
                            // .text() inserts the data as text, never as markup
                            $("<tr>")
                                .append($("<td>").text(name))
                                .append($("<td>").text(value))
                                .appendTo($results);
                        });
                    },
                    error: function(xhr) {
                        $("<tr>")
                            .append($("<td>").attr("colspan", 2)
                                .text("Request failed (HTTP " + xhr.status + ")"))
                            .appendTo($results);
                    }
                });
            });
        </script>
        </body>
        </html>'''

# Upper bounds applied when fetching a user-supplied image URL.
_FETCH_TIMEOUT_SECONDS = 10
_MAX_IMAGE_BYTES = 20 * 1024 * 1024


@app.route('/predict', methods=['POST'])
def predit():
    """Classify the image found at the URL posted in the ``imgurl`` form field.

    The image is downloaded (http/https only, with a timeout and a size cap),
    checked to be a recognised image format, and written to a private
    temporary file for ``image.load_img``.  It is then resized to 224x224 and
    classified with the module-level ResNet50 ``model``.  The temporary file
    is always removed.

    Returns:
        A JSON object. On success it maps each of the top-3 predicted class
        labels to its score formatted with ``format(score, "g")``. On failure
        it is ``{"status": false, "message": <reason>}``.
    """
    response = {}

    imgurl = (request.form.get('imgurl') or '').strip()
    # Only fetch remote http(s) resources: urllib would otherwise also open
    # file:// (local server files), ftp:// and data: URLs on our behalf.
    # NOTE(review): this does not stop requests to internal hosts (SSRF);
    # add a host allow/deny list if the service is publicly reachable.
    if urllib.parse.urlsplit(imgurl).scheme.lower() not in ('http', 'https'):
        response['status'] = False
        response['message'] = 'Invalid request'
        return jsonify(response)

    try:
        with urllib.request.urlopen(imgurl, timeout=_FETCH_TIMEOUT_SECONDS) as remote:
            # Read one byte past the cap so an oversize payload is detectable.
            payload = remote.read(_MAX_IMAGE_BYTES + 1)
    except (OSError, ValueError, http.client.HTTPException):
        # URLError/HTTPError/timeouts are OSErrors; malformed URLs raise
        # ValueError; truncated HTTP bodies raise HTTPException.
        payload = b''

    if not payload or len(payload) > _MAX_IMAGE_BYTES:
        response['status'] = False
        response['message'] = 'Could not obtain the specified image.'
        return jsonify(response)

    # Cheap magic-number check before handing the bytes to the decoder.
    # NOTE: imghdr is deprecated (PEP 594) and removed in Python 3.13.
    if imghdr.what(None, h=payload) is None:
        response['status'] = False
        response['message'] = 'Invalid image provided'
        return jsonify(response)

    # A unique, private temp file per request.  The previous
    # "/tmp/<base64(url)>" name embedded a bytes repr (b'...'), could contain
    # "/" (a directory that does not exist) or exceed the file-name length
    # limit, collided between concurrent requests, and was never removed.
    fd, imgpth = tempfile.mkstemp(prefix='imgpred-')
    try:
        with os.fdopen(fd, 'wb') as local:
            local.write(payload)
        try:
            img = image.load_img(imgpth, target_size=(224, 224))
        except OSError:
            # The decoder rejected the data (e.g. a truncated/corrupt file).
            img = None
    finally:
        os.remove(imgpth)

    if img is None:
        response['status'] = False
        response['message'] = 'Invalid image provided'
        return jsonify(response)

    batch = image.img_to_array(img)
    batch = np.expand_dims(batch, axis=0)  # the model expects a batch axis
    batch = preprocess_input(batch)

    preds = model.predict(batch)

    # Each top-k entry is a 3-tuple; the label is at index 1 and the score
    # at index 2.
    for _, label, score in decode_predictions(preds, top=3)[0]:
        response[label] = format(score, "g")

    return jsonify(response)

Execute

$ export FLASK_APP=filename.py
$ python -m flask run --host=[IP_ADDRESS]

Replace [IP_ADDRESS] with your server's IP address, and replace filename.py with the name of the file that contains the code above.


Was this answer helpful?

« Back

Powered by WHMCompleteSolution