Initial commit
.gitignore (vendored, Normal file, 4 lines)
@@ -0,0 +1,4 @@
__pycache__/
*.pyc
*.png
.DS_Store
Dockerfile (Normal file, 9 lines)
@@ -0,0 +1,9 @@
FROM python:3.11
COPY ./app /source
RUN pip install --upgrade pip
RUN pip install --no-cache-dir -r /source/requirements.txt
EXPOSE 3000
RUN useradd app
USER app
WORKDIR /source
CMD ["gunicorn", "-b", "0.0.0.0:3000", "app:app"]
README.md (Normal file, 52 lines)
@@ -0,0 +1,52 @@
# shapeAI

_Forked from [@quentinbkk's shapeAI](https://github.com/quentinbkk/shapeAI)_

Shape AI is a web app that uses a classifier model to identify user-drawn geometric shapes. Currently, the model can identify drawn squares, rectangles, circles, and triangles.

**[Try it live ✏️](https://shapeai.craisin.tech)**

## Running the App 🏃

### Docker 🐋

```bash
git clone https://github.com/craisined/shapeAI
cd shapeAI
docker build -t shapeai .
docker run -p 3000:3000 shapeai
```

### Manually 🖥️

```bash
git clone https://github.com/craisined/shapeAI
cd shapeAI
python3 -m venv env
source env/bin/activate
cd app
pip install -r requirements.txt
gunicorn -b 0.0.0.0:3000 app:app
```

## Model Training 💪

OpenCV is used to synthetically generate training data in `training/generate_shapes.py`.

Data is stored in `training/data` - add to or modify that folder to add training cases.

Run `training/train.py` to train the model - it exports to `shape_model.keras`.
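As a quick sanity check of the exported model, the sketch below (not part of the training scripts) loads `shape_model.keras` and classifies one of the generated images. It assumes it is run from the `training` directory after the steps above; `data/circle/circle_0.png` is just one of the files `generate_shapes.py` produces.

```python
# Load the exported classifier and run it on one synthetic image.
# Assumes generate_shapes.py and train.py have already been run in this directory.
import cv2
import numpy as np
from keras import models

model = models.load_model("shape_model.keras")
# Class order follows the alphabetical folder order used during training,
# matching the labels list in app/model.py.
labels = ["circle", "rectangle", "square", "triangle"]

img = cv2.imread("data/circle/circle_0.png", cv2.IMREAD_GRAYSCALE)  # 64x64, black outline on white
img = img.reshape(1, 64, 64, 1).astype("float32")  # batch of one; the model rescales to [0, 1] itself
print(labels[int(np.argmax(model.predict(img)))])
```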
## Technical Overview 👨‍💻

### Abilities

1. High training accuracy - the model consistently trains to > 99% accuracy
2. Fast speed - the model has sub-50 ms response times
3. Synthetic data and preprocessing - training data is generated synthetically, and user drawings are preprocessed before being sent from the website to the model

### Frameworks

1. Model built with TensorFlow and Keras
2. Image manipulation built using OpenCV and Pillow
3. Backend built using Flask
4. Frontend built using vanilla HTML, CSS, JS

### Changes from original fork

1. Web UI and Flask backend added
2. Synthetic training data altered to produce a more human-friendly model

### WIP

1. Low accuracy on certain cases - further improve synthetic shape generation
2. Add confidence for classification - do not display a result if confidence is low (see the sketch after this list)
3. Improve mobile UI to further prevent scrolling while drawing
4. Expand dataset to various alphanumeric characters
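For WIP item 2, one possible direction (a sketch only, not implemented in this commit) is to gate `run_model` in `app/model.py` on the softmax probability and fall back to the frontend's `...` placeholder when the top class is not confident enough; the 0.8 threshold below is purely illustrative.

```python
# Sketch of a confidence gate around the existing classifier (not in this commit).
import numpy as np
from keras import models

model = models.load_model("model/shape_model.keras")
labels = ["circle ○", "rectangle ▭", "square □", "triangle △"]
THRESHOLD = 0.8  # illustrative cut-off, would need tuning


def run_model(image):
    img = np.expand_dims(image, axis=0)
    probs = model.predict(img)[0]      # softmax scores for the four classes
    best = int(np.argmax(probs))
    if probs[best] < THRESHOLD:
        return "..."                   # low confidence: keep the placeholder text
    return labels[best]
```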
app/app.py (Normal file, 29 lines)
@@ -0,0 +1,29 @@
from base64 import b64decode
from flask import Flask, render_template, request
import io
from keras.preprocessing.image import img_to_array
import model
import numpy as np
from PIL import Image

app = Flask(__name__)
HOST = "0.0.0.0"
PORT = 3000

@app.route("/")
def index():
    return render_template("index.html")

@app.route("/shape_model")
def shape_model():
    encoded_img = request.args["img"]
    encoded_img = encoded_img.replace("data:image/png;base64,", "", 1)
    img = b64decode(encoded_img)
    img = Image.open(io.BytesIO(img))
    img = img.convert("L")
    img = img_to_array(img)
    prediction = model.run_model(img)
    return prediction

if __name__ == "__main__":
    app.run(HOST, port=PORT)
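For reference, the `/shape_model` route above can be exercised outside the browser with a short client sketch (not part of the commit); it assumes the app is running locally on port 3000 and that `drawing.png` is a hypothetical 64x64 PNG with a black shape on a white background, mirroring what the canvas produces.

```python
# Call the Flask endpoint the same way static/script.js does:
# a GET request with the image as a base64 data URL in the "img" query parameter.
import base64
import requests  # already pinned in app/requirements.txt

with open("drawing.png", "rb") as f:
    data_url = "data:image/png;base64," + base64.b64encode(f.read()).decode()

r = requests.get("http://localhost:3000/shape_model", params={"img": data_url})
print(r.text)  # e.g. "square □"
```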
app/model.py (Normal file, 12 lines)
@@ -0,0 +1,12 @@
from keras import models
import numpy as np

model = models.load_model("model/shape_model.keras")
labels = ["circle ○", "rectangle ▭", "square □", "triangle △"]

def run_model(image):
    img = np.expand_dims(image, axis=0)
    prediction = np.argmax(model.predict(img))
    return labels[prediction]

if __name__ == "__main__":
    # Load the image from the given path so run_model receives an array, not a path string.
    from keras.utils import img_to_array, load_img
    path = input("Image path: ")
    print(run_model(img_to_array(load_img(path, color_mode="grayscale", target_size=(64, 64)))))
app/model/shape_model.keras (Normal file, BIN)
Binary file not shown.
app/requirements.txt (Normal file, 43 lines)
@@ -0,0 +1,43 @@
absl-py==2.3.1
astunparse==1.6.3
blinker==1.9.0
certifi==2025.10.5
charset-normalizer==3.4.4
click==8.3.0
Flask==3.1.2
flatbuffers==25.9.23
gast==0.6.0
google-pasta==0.2.0
grpcio==1.76.0
gunicorn==23.0.0
h5py==3.15.1
idna==3.11
itsdangerous==2.2.0
Jinja2==3.1.6
keras==3.11.3
libclang==18.1.1
Markdown==3.9
markdown-it-py==4.0.0
MarkupSafe==3.0.3
mdurl==0.1.2
ml_dtypes==0.5.3
namex==0.1.0
numpy==2.2.6
opencv-python==4.12.0.88
opt_einsum==3.4.0
optree==0.17.0
packaging==25.0
pillow==12.0.0
protobuf==6.33.0
Pygments==2.19.2
requests==2.32.5
rich==14.2.0
six==1.17.0
tensorboard==2.20.0
tensorboard-data-server==0.7.2
tensorflow==2.20.0
termcolor==3.1.0
typing_extensions==4.15.0
urllib3==2.5.0
Werkzeug==3.1.3
wrapt==2.0.0
app/static/quentin.jpg (Normal file, BIN, 14 KiB)
Binary file not shown.
app/static/script.js (Normal file, 57 lines)
@@ -0,0 +1,57 @@
var c = document.getElementById("canvas");
var ctx = c.getContext("2d");

var aiBox = document.getElementById("shapeBox");

var isDragging = false;

function draw(e){
    var canvas_width = 0.4 * document.documentElement.clientWidth;
    if (document.documentElement.clientWidth <= 1000){
        canvas_width = 0.8 * document.documentElement.clientWidth;
    }
    var rect = c.getBoundingClientRect();

    // Use viewport (client) coordinates for both input types so they match getBoundingClientRect().
    var posx, posy;
    if (e.type.includes("touch")) {
        const { touches, changedTouches } = e.originalEvent ?? e;
        const touch = touches[0] ?? changedTouches[0];
        posx = (touch.clientX - rect.left) * 64 / canvas_width;
        posy = (touch.clientY - rect.top) * 64 / canvas_width;
    } else if (e.type.includes("mouse")) {
        posx = (e.clientX - rect.left) * 64 / canvas_width;
        posy = (e.clientY - rect.top) * 64 / canvas_width;
    }
    if (isDragging){
        ctx.fillStyle = "#000000";
        ctx.beginPath();
        ctx.arc(posx, posy, 1, 0, 2*Math.PI);
        ctx.fill();
    }
}

function clear_canvas(){
    ctx.fillStyle = "#FFFFFF";
    ctx.beginPath();
    ctx.fillRect(0, 0, 64, 64);
    ctx.fill();
}

function send_image(){
    var img = c.toDataURL();
    const params = new URLSearchParams();
    params.append("img", img);
    fetch(`/shape_model?${params}`).then(
        function (r) {return r.text();}
    ).then(
        function (r) {aiBox.innerHTML = r;}
    );
}

clear_canvas();
setInterval(send_image, 1000);
c.addEventListener("mousemove", draw);
c.addEventListener("touchmove", draw);
c.addEventListener("mousedown", function(e){isDragging = true;});
c.addEventListener("touchstart", function(e){isDragging = true;});
c.addEventListener("mouseup", function(e){isDragging = false;});
c.addEventListener("touchend", function(e){isDragging = false;});
app/static/style.css (Normal file, 60 lines)
@@ -0,0 +1,60 @@
html, body{
    margin: 0px;
    font-family: "Lexend", sans-serif;
    font-weight: 300;
    overscroll-behavior-y: contain;
    overflow-x: hidden;
}

#textArea{
    background: #f8f8ff;
    width: 40vw;
    height: 100vh;
    float: left;
    display: flex;
    justify-content: center;
    align-items: center;
}

#drawingArea{
    width: 60vw;
    float: right;
}

#canvas{
    border-style: solid;
    border-width: 4px;
    margin-left: 10vw;
    margin-right: 10vw;
    margin-top: calc(50vh - 20vw);
    width: 40vw;
    height: 40vw;
}

#quentinImg{
    display: inline-block;
    height: 1em;
    width: auto;
    border-radius: 30%;
}

h1{
    font-size: 64px;
}

h2{
    font-size: 48px;
}

p{
    font-size: 24px;
}

@media only screen and (max-width:1000px) {
    #textArea{
        width: 100vw;
        height: 20vh;
    }

    #drawingArea{
        width: 100vw;
        height: 80vh;
    }

    #canvas{
        width: 80vw;
        height: 80vw;
        margin-top: calc(40vh - 40vw);
    }
}
app/templates/index.html (Normal file, 24 lines)
@@ -0,0 +1,24 @@
<!DOCTYPE html>
<html>
<head>
    <title>hello @quentinbkk i have found ur github :D</title>
    <link href="/static/style.css" rel="stylesheet">
    <link rel="preconnect" href="https://fonts.googleapis.com">
    <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
    <link href="https://fonts.googleapis.com/css2?family=Lexend:wght@100..900&display=swap" rel="stylesheet">
</head>
<body>
    <div id="textArea">
        <div>
            <h1>Shape AI</h1>
            <h2>I see a <span id="shapeBox">...</span></h2>
            <p>Based on <a href="https://github.com/quentinbkk/shapeAI">ShapeAI</a> by <a href="https://github.com/quentinbkk">@quentinbkk</a> <img id="quentinImg" src="/static/quentin.jpg"></p>
        </div>
    </div>
    <div id="drawingArea">
        <canvas id="canvas" width="64" height="64"></canvas><br>
        <center><h3 onclick="clear_canvas()">Reset</h3></center>
    </div>
    <script src="/static/script.js"></script>
</body>
</html>
training/generate_shapes.py (Normal file, 65 lines)
@@ -0,0 +1,65 @@
import os
import cv2
import numpy as np
import random

IMG_SIZE = 64
NUM_IMAGES = 250  # Number of images per class
OUTPUT_DIR = 'data'

# Ensure folders exist
shapes = ['circle', 'square', 'rectangle', 'triangle']
for shape in shapes:
    os.makedirs(os.path.join(OUTPUT_DIR, shape), exist_ok=True)

def draw_circle():
    img = np.ones((IMG_SIZE, IMG_SIZE), dtype=np.uint8) * 255
    center = (random.randint(8, 56), random.randint(8, 56))
    radius = random.randint(4, min(60 - max(center), min(center)))
    cv2.circle(img, center, radius, (0,), 2)
    return img

def draw_square():
    img = np.ones((IMG_SIZE, IMG_SIZE), dtype=np.uint8) * 255
    start = (random.randint(4, 56), random.randint(4, 56))
    size = random.randint(4, 60 - max(start))
    cv2.rectangle(img, start, (start[0]+size + random.randint(0, 3), start[1]+size + random.randint(0, 3)), (0,), 2)
    return img

def draw_rectangle():
    img = np.ones((IMG_SIZE, IMG_SIZE), dtype=np.uint8) * 255
    vertical = random.randint(0, 1)
    long, short = (random.randint(4, 28), random.randint(33, 56))
    if vertical:
        start = (random.randint(4, 60 - short), random.randint(4, 60 - long))
        width, height = short, long
    else:
        start = (random.randint(4, 60 - long), random.randint(4, 60 - short))
        width, height = long, short
    cv2.rectangle(img, start, (start[0]+width, start[1]+height), (0,), 2)
    return img

def draw_triangle():
    img = np.ones((IMG_SIZE, IMG_SIZE), dtype=np.uint8) * 255
    pt1 = (random.randint(4, 60), random.randint(4, 60))
    pt2 = (random.randint(4, 60), random.randint(4, 60))
    pt3 = (random.randint(4, 60), random.randint(4, 60))
    points = np.array([pt1, pt2, pt3])
    cv2.drawContours(img, [points], 0, (0,), 2)
    return img

draw_functions = {
    'circle': draw_circle,
    'square': draw_square,
    'rectangle': draw_rectangle,
    'triangle': draw_triangle
}

# ----- Generate images -----
for shape in shapes:
    for i in range(NUM_IMAGES):
        img = draw_functions[shape]()
        filename = os.path.join(OUTPUT_DIR, shape, f"{shape}_{i}.png")
        cv2.imwrite(filename, img)

print("Images generated successfully!")
training/shape_model.keras (Normal file, BIN)
Binary file not shown.
training/train.py (Normal file, 68 lines)
@@ -0,0 +1,68 @@
import tensorflow as tf
from keras import layers, models
from keras.preprocessing import image
import pathlib
import numpy as np

IMG_SIZE = 64
BATCH_SIZE = 8
NUM_CLASSES = 4
EPOCHS = 10

data_dir = pathlib.Path("data")

train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    labels='inferred',
    label_mode='categorical',
    color_mode='grayscale',
    batch_size=BATCH_SIZE,
    image_size=(IMG_SIZE, IMG_SIZE),
    validation_split=0.2,
    subset="training",
    seed=123
)

val_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    labels='inferred',
    label_mode='categorical',
    color_mode='grayscale',
    batch_size=BATCH_SIZE,
    image_size=(IMG_SIZE, IMG_SIZE),
    validation_split=0.2,
    subset="validation",
    seed=123
)

AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(100).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

model = models.Sequential([
    layers.Rescaling(1/255, input_shape=(IMG_SIZE, IMG_SIZE, 1)),
    layers.Conv2D(32, (3,3), activation='relu'),
    layers.MaxPooling2D(2,2),

    layers.Conv2D(64, (3,3), activation='relu'),
    layers.MaxPooling2D(2,2),

    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(NUM_CLASSES, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS
)

model.save("shape_model.keras")
print("Saved model")