Example Project Template: Serve a Scikit-learn Model via a Flask API
Just looking for the project? Grab the code here: queirozfcom/flask-sklearn-seed
This is a full template for building a simple Flask-based API and server that serves a trained Scikit-learn model.
It is meant for development purposes only, not for production.
Includes:
- API and code-level tests
- Logging
- Error handling
- CLI for training the model
- Input validation using JSON Schema
Quickstart
Clone the project:
$ git clone git@github.com:queirozfcom/flask-sklearn-seed.git
Cloning into 'flask-sklearn-seed'...
Create a Python 3 virtualenv and activate it:
$ cd flask-sklearn-seed
$ virtualenv -p python3 venv3
$ source venv3/bin/activate
Install the development requirements:
$ pip install -r requirements-dev.txt
Train the model using the dummy data:
$ python -m app.models.train_model data/raw/training.csv v0
Will train model v0 using the file at: /home/felipe/flask-sklearn-seed/data/raw/training.csv
training set has 7500 rows
validation set has 2500 rows
0.985957111012551
Successfully saved model at /home/felipe/flask-sklearn-seed/trained-models/trained-model-v0.p
Start the server:
$ python -m app.app
* Serving Flask app "app" (lazy loading)
* Environment: production
  WARNING: Do not use the development server in a production environment.
  Use a production WSGI server instead.
* Debug mode: off
* Running on http://0.0.0.0:8080/ (Press CTRL+C to quit)
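The training output above reports 7500 training rows and 2500 validation rows, which is consistent with holding out 25% of a 10,000-row file (scikit-learn's `train_test_split` falls back to a 0.25 test size when no size is given). The stdlib-only sketch below illustrates that kind of shuffled holdout split; `holdout_split` and its fixed seed are hypothetical and not taken from `app/models/train_model.py`.

```python
import math
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def holdout_split(rows: Sequence[T], test_fraction: float = 0.25,
                  seed: int = 0) -> Tuple[List[T], List[T]]:
    """Shuffle rows and return (training, validation) lists.

    Hypothetical helper illustrating a 75/25 holdout split.
    """
    if not 0.0 < test_fraction < 1.0:
        raise ValueError("test_fraction must be strictly between 0 and 1")
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)  # local RNG keeps the split reproducible
    n_valid = math.ceil(len(shuffled) * test_fraction)  # round the holdout up
    return shuffled[n_valid:], shuffled[:n_valid]


if __name__ == "__main__":
    train, valid = holdout_split(range(10_000))
    print(f"training set has {len(train)} rows")    # prints: training set has 7500 rows
    print(f"validation set has {len(valid)} rows")  # prints: validation set has 2500 rows
```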
Using the app
Training via the CLI
- To train a model:
$ python -m app.models.train_model <path/to/training_set.csv> <version-number>
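A command with this shape maps naturally onto `argparse` with two positional arguments. The sketch below is an assumption about how such a CLI could be wired up, not a copy of the project's module: `build_parser`, `model_path` and `main` are hypothetical names, while the `trained-models/trained-model-<version>.p` naming follows the quickstart output.

```python
import argparse
import os
from typing import Optional, Sequence


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser for: python -m app.models.train_model <csv> <version>
    parser = argparse.ArgumentParser(description="Train and save a model version.")
    parser.add_argument("training_set", help="path to the training CSV file")
    parser.add_argument("version", help="model version label, e.g. v0")
    return parser


def model_path(version: str, models_dir: str = "trained-models") -> str:
    # Mirrors the file naming seen in the quickstart: trained-model-v0.p
    return os.path.join(models_dir, f"trained-model-{version}.p")


def main(argv: Optional[Sequence[str]] = None) -> str:
    """Parse arguments and return where the trained model would be saved."""
    args = build_parser().parse_args(argv)  # argv=None reads sys.argv[1:]
    print(f"Will train model {args.version} using the file at: "
          f"{os.path.abspath(args.training_set)}")
    # Hypothetical next steps: read the CSV, split it, fit the model and
    # pickle it to the path returned below.
    return model_path(args.version)
```

In a real module, `main()` would usually be called under an `if __name__ == "__main__":` guard so that `python -m app.models.train_model ...` picks up the command-line arguments.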
Tests
- To run utils tests:
$ python -m tests.utils_tests
- To run API tests:
$ python -m tests.web_tests
Starting the server
$ python -m app.app
* Running on http://0.0.0.0:8080/ (Press CTRL+C to quit)
Code Organization
This is how this project's code is structured.
Loosely based on Queirozf.com: How to Structure Software Projects: Python Examples and on Cookie Cutter Data Science.
.
│
├── README.md <----- this file
│
├── app
│ ├── app.py <----- main project file. contains routes and initialization code
│ │
│ ├── settings.py
│ │
│ ├── helpers <----- helpers contain helper code that is SPECIFIC to this application
│ │ ├── features.py they are placed here so as not to overly pollute the business logic
│ │ ├── files.py with scaffolding code.
│ │ └── validation.py
│ │
│ ├── models <----- code for training models
│ │ └── train_model.py
│ │
│ └── utils <----- utils contain helper code that is NOT SPECIFIC to this application,
│ └── files.py i.e. it could be extracted and used elsewhere
│
├── data <----- data files, intermediate representation, if needed.
│ ├── interim
│ ├── processed
│ └── raw
│ └── training_set.csv
│
├── logs <----- logs folder
│ └── application.log
│
├── notebooks <----- jupyter notebooks for data exploration and analyses
│ └── view-data.ipynb
│
├── requirements-dev.txt <----- packages required to DEVELOP this project (train model, notebooks, tests, CLI commands)
├── requirements-prod.txt <----- packages required to DEPLOY this project (only serves the API)
│
├── tests <----- test code
│ ├── utils_tests.py
│ └── web_tests.py
│
├── trained-models <------ trained models (serialized) are kept here
│ ├── trained-model-v0.p
│ ├── trained-model-v1.p
│ └── ...
│
└── venv3 <------ python virtualenv
API Docs
Healthcheck
A simple healthcheck for a given model version, to be used for monitoring (e.g. in AWS Elastic Beanstalk).
Example: Correct Request, valid version
REQUEST
GET /v0/healthcheck
RESPONSE 200
OK
Example: Correct Request, invalid version
REQUEST
GET /v31254/healthcheck
RESPONSE 200
Not OK
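Both examples return HTTP 200 and signal health through the body only. One plausible way to decide between `OK` and `Not OK` is to check whether the serialized model for that version exists on disk. The framework-agnostic sketch below assumes the `trained-models/trained-model-<version>.p` layout from the project tree; `healthcheck_body` is a hypothetical helper, not the project's route.

```python
import os


def healthcheck_body(version: str, models_dir: str = "trained-models") -> str:
    """Return the healthcheck body for a model version: "OK" or "Not OK".

    Hypothetical helper: it only checks that the model file exists,
    so a corrupt file would still report "OK".
    """
    path = os.path.join(models_dir, f"trained-model-{version}.p")
    return "OK" if os.path.isfile(path) else "Not OK"
```

A Flask view for `/<version>/healthcheck` could then return `healthcheck_body(version), 200`, keeping the status at 200 in both cases as the examples do.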
Predict
Returns a prediction calculated by the previously trained model whose version is <version>.
Example: Correct Request
REQUEST
POST /v0/predict
{
"id": "2",
"x_1": -2.0,
"x_2": -0.414120,
"x_3": 0.2131,
"x_4": -1.2
}
RESPONSE 200
{
"id": "2",
"prediction": 0.8077
}
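The Caching section notes that the served model is a logistic regression, so a value like `0.8077` is presumably a class probability. For a binary logistic regression, that probability is the logistic (sigmoid) function applied to a weighted sum of the features plus an intercept. The sketch below spells this out with made-up coefficients (`WEIGHTS` and `BIAS` are hypothetical, so it does not reproduce `0.8077`); with a trained scikit-learn model, `predict_proba(X)[:, 1]` returns the positive-class probability derived from the fitted `coef_` and `intercept_`.

```python
import math
from typing import Any, Dict

FEATURES = ("x_1", "x_2", "x_3", "x_4")
# Hypothetical coefficients; a real model learns these at training time.
WEIGHTS: Dict[str, float] = {"x_1": -1.1, "x_2": 0.4, "x_3": 2.0, "x_4": -0.7}
BIAS = 0.3


def sigmoid(z: float) -> float:
    # Numerically stable logistic function: avoids overflow in math.exp.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Build a /predict-style response body from an already validated payload."""
    z = BIAS + sum(WEIGHTS[name] * float(payload[name]) for name in FEATURES)
    return {"id": payload["id"], "prediction": round(sigmoid(z), 4)}


if __name__ == "__main__":
    body = {"id": "2", "x_1": -2.0, "x_2": -0.414120, "x_3": 0.2131, "x_4": -1.2}
    print(predict(body))
```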
Example: Model version not found
REQUEST
POST /v43287/predict
{
"id": "19826478126",
"x_1": 1.0,
"x_2": -0.414120,
"x_3": 0.2131,
"x_4": -1.2
}
RESPONSE 404
{
"message": "Trained model version 'v43287' was not found."
}
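The 404 body embeds the requested version in its message. The sketch below shows one way to produce it: a hypothetical `load_model` helper that unpickles `trained-model-<version>.p`, keeps loaded models in an in-process dict, and returns a status code plus JSON-ready body when the file is missing. The `v<digits>` version check is an added assumption that stops a crafted version string from reaching outside the models directory; `pickle` should only ever load files you trust.

```python
import os
import pickle
import re
from typing import Any, Dict, Optional, Tuple

_VERSION_RE = re.compile(r"v\d+")  # assumption: versions look like v0, v1, ...
_loaded: Dict[str, Any] = {}  # in-process cache of unpickled models, keyed by path

ErrorResponse = Tuple[int, Dict[str, str]]


def load_model(version: str, models_dir: str = "trained-models"
               ) -> Tuple[Optional[Any], Optional[ErrorResponse]]:
    """Return (model, None) on success or (None, (status, body)) on failure."""
    not_found = (404, {"message": f"Trained model version '{version}' was not found."})
    if not _VERSION_RE.fullmatch(version):
        return None, not_found
    path = os.path.join(models_dir, f"trained-model-{version}.p")
    if path not in _loaded:
        if not os.path.isfile(path):
            return None, not_found
        with open(path, "rb") as f:
            _loaded[path] = pickle.load(f)  # only unpickle files you produced
    return _loaded[path], None
```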
Example: Invalid request arguments
REQUEST
POST /v0/predict
{
"id": "126",
"x_1": 1.0,
"x_2": -2.2
}
RESPONSE 400
{
"message": "Missing keys: 'x_3', 'x_4'"
}
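The project validates input with JSON Schema, where these constraints would be expressed through the schema's `required` and `type` keywords. As a dependency-free illustration of the same checks, the sketch below reproduces the `Missing keys` message format shown above; `validation_error` and the numeric-type message are hypothetical and not the project's actual wording.

```python
from typing import Any, Dict, List, Optional

# Hypothetical field lists matching the /predict examples above.
REQUIRED_KEYS = ("id", "x_1", "x_2", "x_3", "x_4")
NUMERIC_KEYS = ("x_1", "x_2", "x_3", "x_4")


def validation_error(payload: Any) -> Optional[Dict[str, str]]:
    """Return a 400-style error body, or None if the payload looks valid."""
    if not isinstance(payload, dict):
        return {"message": "Request body must be a JSON object"}
    missing: List[str] = [key for key in REQUIRED_KEYS if key not in payload]
    if missing:
        return {"message": "Missing keys: " + ", ".join(f"'{k}'" for k in missing)}
    for key in NUMERIC_KEYS:
        value = payload[key]
        # bool is a subclass of int in Python, so reject it explicitly
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return {"message": f"Key '{key}' must be a number"}
    return None
```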
Logging
Logging is needed to keep track of how people use your app (collect usage metrics) and to help diagnose errors in case something goes wrong.
I've used an external package (concurrent-log-handler) because the default RotatingFileHandler does not compress old log files out of the box. Compressing rotated logs helps make sure logging itself doesn't cause problems due to lack of disk space.
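For comparison, the standard library's `RotatingFileHandler` can compress rotated files through its documented `namer` and `rotator` hooks (Python 3.3+). The sketch below shows that swapped-in alternative, not the project's setup: unlike concurrent-log-handler, it does nothing to coordinate several processes writing to the same file. The logger name `app` and the size limits are arbitrary example values.

```python
import gzip
import logging
import logging.handlers
import os
import shutil


def gzip_namer(default_name: str) -> str:
    # application.log.1 -> application.log.1.gz
    return default_name + ".gz"


def gzip_rotator(source: str, dest: str) -> None:
    # Compress the file being rotated out, then delete the uncompressed copy.
    with open(source, "rb") as f_in, gzip.open(dest, "wb") as f_out:
        shutil.copyfileobj(f_in, f_out)
    os.remove(source)


def build_logger(log_path: str, max_bytes: int = 1_000_000,
                 backup_count: int = 5) -> logging.Logger:
    """Attach a size-based, gzip-compressing rotating handler to the "app" logger.

    Calling this twice adds a second handler, so call it once at startup.
    """
    handler = logging.handlers.RotatingFileHandler(
        log_path, maxBytes=max_bytes, backupCount=backup_count)
    handler.namer = gzip_namer
    handler.rotator = gzip_rotator
    handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
    logger = logging.getLogger("app")
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)
    return logger
```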
Caching
There are a couple of caching mechanisms for Flask (e.g. https://github.com/sh4nks/flask-caching), but since Logistic Regression is an eager learning method (i.e. inference is quite fast because most of the work is done at training time), caching didn't seem worth the extra complexity.
Caching would probably be more useful with lazy learning methods such as k-NN, which defer most of their work to prediction time.
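As a rough illustration of where caching pays off, the sketch below memoizes a deliberately lazy, brute-force 1-nearest-neighbour predictor with `functools.lru_cache`. It is a stand-in rather than flask-caching: the cache lives in a single process, the tiny `TRAINING_POINTS` data set is made up, and features must be passed as a hashable tuple.

```python
import math
from functools import lru_cache
from typing import Tuple

# Made-up training data as (features, label) pairs; a lazy learner keeps all of it.
TRAINING_POINTS = (
    ((0.0, 0.0, 0.0, 0.0), 0),
    ((1.0, 1.0, 1.0, 1.0), 1),
    ((-1.0, -0.5, 0.2, -1.2), 0),
    ((2.0, 0.3, 0.9, 1.5), 1),
)


@lru_cache(maxsize=1024)
def predict_1nn(features: Tuple[float, ...]) -> int:
    # Brute-force scan over every stored point: the work the cache avoids repeating.
    _, label = min(TRAINING_POINTS, key=lambda point: math.dist(point[0], features))
    return label
```

Because the cache key is the exact tuple of floats, requests that differ only in tiny decimal noise will miss the cache; rounding the inputs first is one way to raise the hit rate.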