Scikit-learn is a powerful library that simplifies the machine learning workflow. It follows a simple and consistent API design. The main steps in using scikit-learn for model building are:
FastAPI is designed to be fast and easy to use. It uses Python type hints to validate input data, generate API documentation automatically, and ensure type safety. Key concepts in FastAPI include:
@app.get(), @app.post(), etc. These decorators map a specific HTTP method and URL path to a Python function.
In e-commerce platforms, real-time prediction can be used for product recommendations. A machine learning model can be trained on historical user behavior data (such as purchase history and browsing history). FastAPI can be used to expose this model as an API. When a user visits the website, the application can send a request to the API with user-related features, and the API can return a list of recommended products in real time.
In the healthcare industry, real-time prediction can assist in early disease detection. For example, a model can be trained on patient medical records, vital signs, and other relevant data. An API built with FastAPI can receive new patient data and predict the likelihood of a patient developing a certain disease.
In finance, real-time fraud detection is a critical application. A machine learning model can be trained on historical transaction data to identify patterns associated with fraudulent transactions. An API can receive new transaction details and quickly determine if the transaction is likely to be fraudulent.
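Both the healthcare and the fraud scenarios ask for a likelihood rather than a hard label. In scikit-learn, classifiers such as LogisticRegression expose this through predict_proba. The sketch below uses synthetic data from make_classification as a stand-in for historical transactions, so the features carry no real meaning; the function names and the 0.5 threshold are likewise assumptions.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic, imbalanced data standing in for historical transactions
# (roughly 5% of rows carry the positive "fraud" label).
X, y = make_classification(
    n_samples=2000, n_features=6, n_informative=4, n_redundant=0,
    weights=[0.95, 0.05], random_state=0,
)

model = LogisticRegression(max_iter=1000)
model.fit(X, y)


def fraud_probability(features):
    """Return the model's estimated probability of the positive class."""
    # predict_proba returns one row per sample: [P(class 0), P(class 1)].
    return float(model.predict_proba([features])[0][1])


def is_suspicious(features, threshold=0.5):
    # The threshold is an assumption; in practice it would be tuned
    # against the cost of false alarms and missed fraud.
    return fraud_probability(features) >= threshold
```

An API endpoint could return the probability itself and leave the decision threshold to the calling application.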
pip install scikit-learn fastapi uvicorn pandas
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
# Load the iris dataset
iris = load_iris()
X = iris.data
y = iris.target
# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Train a logistic regression model
# A higher max_iter gives the solver more room to converge on unscaled features
model = LogisticRegression(max_iter=200)
model.fit(X_train, y_train)
# Evaluate the model
accuracy = model.score(X_test, y_test)
print(f"Model accuracy: {accuracy}")
from fastapi import FastAPI
import pandas as pd
from pydantic import BaseModel
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
import joblib
# Load the iris dataset and train the model (if not already trained)
iris = load_iris()
X = iris.data
y = iris.target
# A higher max_iter gives the solver more room to converge on unscaled features
model = LogisticRegression(max_iter=200)
model.fit(X, y)
# Save the model
joblib.dump(model, 'iris_model.pkl')
# Create a FastAPI app
app = FastAPI()
# Define the request model
class IrisRequest(BaseModel):
    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float
# Load the saved model
loaded_model = joblib.load('iris_model.pkl')
# Define the prediction endpoint
@app.post("/predict/")
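async def predict(request: IrisRequest):
    # Arrange the features in the column order the model was trained on
    data = [[request.sepal_length, request.sepal_width, request.petal_length, request.petal_width]]
    prediction = loaded_model.predict(data)
    # Convert the NumPy integer to a plain int so it can be serialized as JSON
    return {"prediction": int(prediction[0])}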
uvicorn main:app --reload
You can use tools like curl or Postman to test the API. For example, using curl:
curl -X POST "http://127.0.0.1:8000/predict/" -H "Content-Type: application/json" -d '{"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2}'
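The same request can be issued from Python. The sketch below uses only the standard library and assumes the server is running locally on port 8000, as in the curl example; the helper names build_predict_request and predict_remote are illustrative.

```python
import json
from urllib import request

# Assumes the API from this article is running locally on port 8000.
PREDICT_URL = "http://127.0.0.1:8000/predict/"


def build_predict_request(features, url=PREDICT_URL):
    """Build a POST request carrying the features as a JSON body."""
    body = json.dumps(features).encode("utf-8")
    return request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict_remote(features, url=PREDICT_URL, timeout=5.0):
    """Send the request and return the decoded JSON response."""
    req = build_predict_request(features, url)
    with request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


sample = {"sepal_length": 5.1, "sepal_width": 3.5,
          "petal_length": 1.4, "petal_width": 0.2}
req = build_predict_request(sample)  # built here, not sent
```

With the server running, predict_remote(sample) should return a dictionary with a "prediction" key, matching what the endpoint returns.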
When saving and loading models using libraries like joblib or pickle, there can be compatibility issues between different versions of scikit-learn or Python. It’s important to ensure that the same versions are used during model training and deployment.
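One way to catch a version mismatch early is to record the environment next to the model file and compare it before unpickling. The sketch below is one possible approach, not a standard mechanism. The helper names and the JSON sidecar file are assumptions, and requiring an exact version match is a deliberately strict choice. It uses pickle and a plain dictionary as a stand-in for the model so that it runs without extra dependencies.

```python
import json
import pickle
import platform
import tempfile
from importlib import metadata
from pathlib import Path


def current_versions():
    """Collect the versions that matter for loading the model."""
    try:
        sklearn_version = metadata.version("scikit-learn")
    except metadata.PackageNotFoundError:
        sklearn_version = "unknown"
    return {"python": platform.python_version(), "scikit-learn": sklearn_version}


def save_model(model, path):
    path = Path(path)
    with path.open("wb") as fh:
        pickle.dump(model, fh)
    # Sidecar file, e.g. iris_model.json next to iris_model.pkl.
    path.with_suffix(".json").write_text(json.dumps(current_versions()))


def load_model(path):
    path = Path(path)
    saved = json.loads(path.with_suffix(".json").read_text())
    running = current_versions()
    # Compare before unpickling, so a mismatch fails with a clear message.
    if saved != running:
        raise RuntimeError(f"Saved with {saved}, running {running}")
    with path.open("rb") as fh:
        return pickle.load(fh)


# Round trip in a temporary directory; a plain dict stands in for the model.
with tempfile.TemporaryDirectory() as tmp:
    model_file = Path(tmp) / "iris_model.pkl"
    save_model({"stand_in": "model"}, model_file)
    restored = load_model(model_file)
```

Pinning the same versions in the requirements file of both the training and the serving environment addresses the root cause; a check like this only makes a mismatch visible.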
FastAPI uses Pydantic for input validation, but it’s easy to overlook some edge cases. For example, if the input data has a different range or format than what the model expects, it can lead to incorrect predictions.
If the machine learning model is very large or complex, making predictions can be time-consuming. This can cause performance issues in the API, especially if there are a large number of concurrent requests.
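A blocking predict call inside an async def endpoint holds up the event loop while it runs. One way around this, sketched below with the standard library only, is to hand the call to a worker thread with asyncio.to_thread (available from Python 3.9). The slow_predict function and its delay are hypothetical stand-ins for a slow model.

```python
import asyncio
import time


def slow_predict(features):
    # Stand-in for a slow model.predict() call (hypothetical delay).
    time.sleep(0.05)
    return sum(features) > 10


async def predict_handler(features):
    # Run the blocking call in a worker thread so the event loop
    # can keep serving other requests in the meantime.
    return await asyncio.to_thread(slow_predict, features)


async def main():
    # Four concurrent "requests" overlap instead of running back to back.
    batch = [[5.1, 3.5, 1.4, 0.2]] * 4
    return await asyncio.gather(*(predict_handler(f) for f in batch))


results = asyncio.run(main())
```

In FastAPI, declaring the endpoint with a plain def instead of async def has a similar effect, since FastAPI runs such functions in a thread pool.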
Keep track of different versions of the machine learning model. This helps in reproducibility and makes it easier to roll back to a previous version if there are issues with the new model.
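A simple way to do this is to encode the version in the file name. The sketch below assumes integer versions and a model_v<N>.pkl naming scheme; both are illustrative choices, and plain dictionaries stand in for trained models.

```python
import pickle
import tempfile
from pathlib import Path


def save_version(model, directory, version):
    """Write the model to <directory>/model_v<version>.pkl."""
    directory = Path(directory)
    directory.mkdir(parents=True, exist_ok=True)
    path = directory / f"model_v{version}.pkl"
    with path.open("wb") as fh:
        pickle.dump(model, fh)
    return path


def list_versions(directory):
    """Return the stored integer versions in ascending order."""
    return sorted(
        int(p.stem.split("_v")[1]) for p in Path(directory).glob("model_v*.pkl")
    )


def load_version(directory, version=None):
    """Load a specific version, or the newest one when version is None."""
    if version is None:
        version = list_versions(directory)[-1]
    with (Path(directory) / f"model_v{version}.pkl").open("rb") as fh:
        return pickle.load(fh)


# Demo in a temporary directory with stand-in models.
with tempfile.TemporaryDirectory() as tmp:
    save_version({"name": "first model"}, tmp, 1)
    save_version({"name": "second model"}, tmp, 2)
    versions = list_versions(tmp)
    newest = load_version(tmp)
    rolled_back = load_version(tmp, version=1)
```

Rolling back then amounts to loading an earlier version number, as the last line of the demo does.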
Thoroughly validate and sanitize all incoming data. In addition to using Pydantic for basic data type validation, also check for data integrity and range.
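Range checks can be declared directly on the request model with Pydantic's Field. The bounds in this sketch (strictly positive, at most 10 cm) are illustrative assumptions that loosely bracket the iris measurements; real limits should come from the training data. The try_parse helper is hypothetical and only serves to show both outcomes.

```python
from pydantic import BaseModel, Field, ValidationError


class IrisRequest(BaseModel):
    # Bounds are illustrative assumptions, not values from the article.
    sepal_length: float = Field(..., gt=0, le=10)
    sepal_width: float = Field(..., gt=0, le=10)
    petal_length: float = Field(..., gt=0, le=10)
    petal_width: float = Field(..., gt=0, le=10)


def try_parse(payload):
    """Return (request, None) on success or (None, error text) on failure."""
    try:
        return IrisRequest(**payload), None
    except ValidationError as exc:
        return None, str(exc)


ok, ok_error = try_parse({"sepal_length": 5.1, "sepal_width": 3.5,
                          "petal_length": 1.4, "petal_width": 0.2})
bad, bad_error = try_parse({"sepal_length": -1.0, "sepal_width": 3.5,
                            "petal_length": 1.4, "petal_width": 0.2})
```

When a model like this is used as the endpoint's request type, FastAPI rejects out-of-range input with a 422 response before the prediction function runs.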
Use techniques like model compression, caching, and parallel processing to improve the performance of the prediction API. For example, if the same input data is likely to be received multiple times, cache the prediction results.
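Caching repeated inputs can be sketched with functools.lru_cache from the standard library. Here model_predict is a hypothetical stand-in for loaded_model.predict, with a made-up rule and a call counter so the effect of the cache is visible; the cache size of 1024 is also an assumption.

```python
from functools import lru_cache

calls = {"count": 0}


def model_predict(features):
    # Stand-in for loaded_model.predict(); counts how often it runs.
    calls["count"] += 1
    return 0 if features[2] < 2.5 else 1


@lru_cache(maxsize=1024)
def cached_predict(features):
    # lru_cache needs hashable arguments, so features arrive as a tuple.
    return model_predict(features)


first = cached_predict((5.1, 3.5, 1.4, 0.2))
second = cached_predict((5.1, 3.5, 1.4, 0.2))  # served from the cache
```

The cache should be cleared with cached_predict.cache_clear() whenever a new model version is loaded, otherwise stale predictions would still be served.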
Combining scikit-learn and FastAPI provides a powerful solution for building real-time prediction systems. Scikit-learn simplifies the machine learning model building process, while FastAPI makes it easy to expose these models as APIs. By understanding the core concepts, being aware of common pitfalls, and following best practices, developers can effectively build and deploy real-time prediction systems in various real-world scenarios.