Initial commit
This commit is contained in:
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.png
|
||||
.DS_Store
|
||||
9
Dockerfile
Normal file
9
Dockerfile
Normal file
@@ -0,0 +1,9 @@
|
||||
# Runtime image for the shapeAI Flask app, served by gunicorn on port 3000.
FROM python:3.11

COPY ./app /source
RUN pip install --upgrade pip
RUN pip install --no-cache-dir -r /source/requirements.txt

EXPOSE 3000

# Drop root privileges before running the server.
RUN useradd app
USER app

# Use an absolute path: a relative WORKDIR resolves against the previous
# working directory, which is only implicitly "/" here.
WORKDIR /source
CMD ["gunicorn", "-b", "0.0.0.0:3000", "app:app"]
|
||||
52
README.md
Normal file
52
README.md
Normal file
@@ -0,0 +1,52 @@
|
||||
# shapeAI
|
||||
|
||||
_Forked from [@quentinbkk's shape ai](https://github.com/quentinbkk/shapeAI)_
|
||||
|
||||
Shape AI is a web app utilizing a classifier model to identify user drawn geometric shapes. Currently, the model can identify drawn squares, rectangles, circles, and triangles.
|
||||
|
||||
**[Try it live ✏️](https://shapeai.craisin.tech)**
|
||||
|
||||
## Running the App 🏃
|
||||
### Docker 🐋
|
||||
```bash
|
||||
git clone https://github.com/craisined/shapeAI
|
||||
cd shapeAI
|
||||
docker build -t shapeai .
docker run -p 3000:3000 shapeai
|
||||
```
|
||||
### Manually 🖥️
|
||||
```bash
|
||||
git clone https://github.com/craisined/shapeAI
|
||||
cd shapeAI
|
||||
python3 -m venv env
|
||||
source env/bin/activate
|
||||
cd app
|
||||
pip install -r requirements.txt
|
||||
gunicorn -b 0.0.0.0:3000 app:app
|
||||
```
|
||||
|
||||
## Model Training 💪
|
||||
OpenCV is used to synthetically generate training data in ```training/generate_shapes.py```.
|
||||
|
||||
Data is stored in ```training/data``` - add and modify the folder to add training cases.
|
||||
|
||||
Run ```training/train.py``` to train the model - exports to ```shape_model.keras```.
|
||||
|
||||
## Technical Overview 👨💻
|
||||
### Abilities
|
||||
1. High training accuracy - model consistently trains with accuracy > 99%
|
||||
2. Fast speed - model has sub 50ms response times
|
||||
3. Synthetic data and preprocessing - generates training data and sends the user's drawing from the website to the classifier model
|
||||
### Frameworks
|
||||
1. Model built with Tensorflow and Keras
|
||||
2. Image manipulation built using OpenCV and Pillow
|
||||
3. Backend built using Flask
|
||||
4. Frontend built using vanilla HTML, CSS, JS
|
||||
### Changes from original fork
|
||||
1. Web UI and Flask backend added
|
||||
2. Synthetic training data altered to produce a more human friendly model
|
||||
### WIP
|
||||
1. Low accuracy on certain cases - further improve synthetic shape generation
|
||||
2. Add confidence for classification - do not display a result if confidence is low
|
||||
3. Improve mobile UI to further prevent scroll while drawing
|
||||
4. Expand dataset to various alphanumerical characters
|
||||
29
app/app.py
Normal file
29
app/app.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Flask entry point for the shapeAI web app.

Serves the drawing page and a classification endpoint that accepts the
canvas content as a base64 PNG data-URL in the ``img`` query parameter.
"""

from base64 import b64decode
from flask import Flask, render_template, request
import io
from keras.preprocessing.image import img_to_array
import model
import numpy as np
from PIL import Image

app = Flask(__name__)
HOST = "0.0.0.0"
PORT = 3000


@app.route("/")
def index():
    """Serve the drawing page."""
    return render_template("index.html")


@app.route("/shape_model")
def shape_model():
    """Decode the submitted canvas image and return its predicted label.

    The frontend polls this endpoint with ``?img=data:image/png;base64,...``;
    the data-URL prefix is stripped before base64 decoding.
    """
    data_url = request.args["img"]
    raw = b64decode(data_url.replace("data:image/png;base64,", "", 1))
    # Convert to single-channel grayscale, matching the model's input format.
    grayscale = Image.open(io.BytesIO(raw)).convert("L")
    return model.run_model(img_to_array(grayscale))


if __name__ == "__main__":
    app.run(HOST, port=PORT)
|
||||
12
app/model.py
Normal file
12
app/model.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Model loading and inference helpers for shapeAI."""

from keras import models, layers
import numpy as np

# Classifier exported by training/train.py; path is relative to the app/ dir.
model = models.load_model("model/shape_model.keras")
# Label order must match the class order used at training time
# (image_dataset_from_directory infers classes alphabetically).
labels = ["circle ○", "rectangle ▭", "square □", "triangle △"]


def run_model(image):
    """Return the predicted label for one grayscale image array.

    `image` is a single (H, W, 1) array; a leading batch dimension is
    added because the model predicts on batches.
    """
    batch = np.expand_dims(image, axis=0)
    return labels[np.argmax(model.predict(batch))]


if __name__ == "__main__":
    # Bug fix: the CLI previously passed the raw path *string* straight to
    # run_model, which expects an image array and would crash. Load and
    # preprocess the file to the model's expected 64x64 grayscale format.
    from keras.utils import load_img, img_to_array

    path = input("Image path: ")
    img = img_to_array(load_img(path, color_mode="grayscale", target_size=(64, 64)))
    print(run_model(img))
|
||||
BIN
app/model/shape_model.keras
Normal file
BIN
app/model/shape_model.keras
Normal file
Binary file not shown.
43
app/requirements.txt
Normal file
43
app/requirements.txt
Normal file
@@ -0,0 +1,43 @@
|
||||
absl-py==2.3.1
|
||||
astunparse==1.6.3
|
||||
blinker==1.9.0
|
||||
certifi==2025.10.5
|
||||
charset-normalizer==3.4.4
|
||||
click==8.3.0
|
||||
Flask==3.1.2
|
||||
flatbuffers==25.9.23
|
||||
gast==0.6.0
|
||||
google-pasta==0.2.0
|
||||
grpcio==1.76.0
|
||||
gunicorn==23.0.0
|
||||
h5py==3.15.1
|
||||
idna==3.11
|
||||
itsdangerous==2.2.0
|
||||
Jinja2==3.1.6
|
||||
keras==3.11.3
|
||||
libclang==18.1.1
|
||||
Markdown==3.9
|
||||
markdown-it-py==4.0.0
|
||||
MarkupSafe==3.0.3
|
||||
mdurl==0.1.2
|
||||
ml_dtypes==0.5.3
|
||||
namex==0.1.0
|
||||
numpy==2.2.6
|
||||
opencv-python==4.12.0.88
|
||||
opt_einsum==3.4.0
|
||||
optree==0.17.0
|
||||
packaging==25.0
|
||||
pillow==12.0.0
|
||||
protobuf==6.33.0
|
||||
Pygments==2.19.2
|
||||
requests==2.32.5
|
||||
rich==14.2.0
|
||||
six==1.17.0
|
||||
tensorboard==2.20.0
|
||||
tensorboard-data-server==0.7.2
|
||||
tensorflow==2.20.0
|
||||
termcolor==3.1.0
|
||||
typing_extensions==4.15.0
|
||||
urllib3==2.5.0
|
||||
Werkzeug==3.1.3
|
||||
wrapt==2.0.0
|
||||
BIN
app/static/quentin.jpg
Normal file
BIN
app/static/quentin.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 14 KiB |
57
app/static/script.js
Normal file
57
app/static/script.js
Normal file
@@ -0,0 +1,57 @@
|
||||
// Canvas drawing + periodic classification polling for shapeAI.
var c = document.getElementById("canvas");
var ctx = c.getContext("2d");

var aiBox = document.getElementById("shapeBox");

var isDragging = false;

// Map a mouse/touch event to the 64x64 canvas backing store and plot a dot.
function draw(e){
    // Use the canvas's actual rendered size rather than re-deriving it from
    // the viewport width; this stays correct if the CSS sizing changes.
    var rect = c.getBoundingClientRect();
    var scale = 64 / rect.width;

    var posx, posy;
    if (e.type.includes("touch")) {
        const { touches, changedTouches } = e.originalEvent ?? e;
        const touch = touches[0] ?? changedTouches[0];
        // clientX/Y are viewport-relative, matching getBoundingClientRect();
        // the previous pageX/Y were document-relative and drew in the wrong
        // spot whenever the page was scrolled.
        posx = (touch.clientX - rect.left) * scale;
        posy = (touch.clientY - rect.top) * scale;
    } else if (e.type.includes("mouse")) {
        posx = (e.clientX - rect.left) * scale;
        posy = (e.clientY - rect.top) * scale;
    }
    if (isDragging){
        ctx.fillStyle = "#000000";
        ctx.beginPath();
        ctx.arc(posx, posy, 1, 0, 2 * Math.PI);
        ctx.fill();
    }
}

// Fill the whole canvas with white.
function clear_canvas(){
    ctx.fillStyle = "#FFFFFF";
    ctx.beginPath();
    ctx.fillRect(0, 0, 64, 64);
    ctx.fill();
}

// Send the current canvas as a data-URL to the classifier endpoint and
// display the returned label.
function send_image(){
    var img = c.toDataURL();
    const params = new URLSearchParams();
    params.append("img", img);
    fetch(`/shape_model?${params}`).then(
        function (r) {return r.text();}
    ).then(
        function (r) {aiBox.innerHTML = r;}
    );
}

clear_canvas();
setInterval(send_image, 1000);
c.addEventListener("mousemove", draw);
c.addEventListener("touchmove", draw);
c.addEventListener("mousedown", function(e){isDragging = true;});
c.addEventListener("touchstart", function(e){isDragging = true;});
c.addEventListener("mouseup", function(e){isDragging = false;});
c.addEventListener("touchend", function(e){isDragging = false;});
|
||||
60
app/static/style.css
Normal file
60
app/static/style.css
Normal file
@@ -0,0 +1,60 @@
|
||||
/* Base page: Lexend font, light weight, no sideways scroll, and
   overscroll containment so touch-drawing can't trigger pull-to-refresh. */
html, body{
    margin: 0px;
    font-family: "Lexend", sans-serif;
    font-weight: 300;
    overscroll-behavior-y: contain;
    overflow-x: hidden;
}

/* Left panel (desktop): 40% of the viewport, content centered via flexbox. */
#textArea{
    background: #f8f8ff;
    width: 40vw;
    height: 100vh;
    float: left;
    display: flex;
    justify-content: center;
    align-items: center;
}

/* Right panel holding the drawing canvas. */
#drawingArea{
    width: 60vw;
    float: right;
}

/* The 64x64 canvas is scaled up to a 40vw square and vertically centered
   (50vh minus half its own height). */
#canvas{
    border-style: solid;
    border-width: 4px;
    margin-left: 10vw;
    margin-right: 10vw;
    margin-top: calc(50vh - 20vw);
    width:40vw;
    height:40vw;
}

/* Small round avatar shown inline with the credit line. */
#quentinImg{
    display: inline-block;
    height: 1em;
    width: auto;
    border-radius: 30%;
}

h1{
    font-size: 64px;
}

h2{
    font-size: 48px;
}

p{
    font-size: 24px;
}

/* Narrow screens: stack the two panels vertically and enlarge the canvas. */
@media only screen and (max-width:1000px) {
    #textArea{
        width: 100vw;
        height: 20vh;
    }

    #drawingArea{
        width: 100vw;
        height: 80vh;
    }

    #canvas{
        width: 80vw;
        height: 80vw;
        margin-top: calc(40vh - 40vw);
    }
}
|
||||
24
app/templates/index.html
Normal file
24
app/templates/index.html
Normal file
@@ -0,0 +1,24 @@
|
||||
<!DOCTYPE html>
<html>
<head>
    <title>hello @quentinbkk i have found ur github :D</title>
    <link href="/static/style.css" rel="stylesheet">
    <link rel="preconnect" href="https://fonts.googleapis.com">
    <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
    <link href="https://fonts.googleapis.com/css2?family=Lexend:wght@100..900&display=swap" rel="stylesheet">
</head>
<body>
    <div id="textArea">
        <div>
            <h1>Shape AI</h1>
            <h2>I see a <span id="shapeBox">...</span></h2>
            <p>Based on <a href="https://github.com/quentinbkk/shapeAI">ShapeAI</a> by <a href="https://github.com/quentinbkk">@quentinbkk</a> <img id="quentinImg" src="/static/quentin.jpg"></p>
        </div>
    </div>
    <div id="drawingArea">
        <!-- Fixed 64x64 backing store, scaled up via CSS. Fixed: the tag had
             a duplicate id attribute, stray commas between attributes, and
             px-suffixed values (width/height attributes take bare integers). -->
        <canvas id="canvas" width="64" height="64"></canvas><br>
        <center><h3 onclick="clear_canvas()">Reset</h3></center>
    </div>
    <!-- Loaded last, inside <body>: the script queries the canvas immediately. -->
    <script src="/static/script.js"></script>
</body>
</html>
|
||||
65
training/generate_shapes.py
Normal file
65
training/generate_shapes.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Synthetic training-data generator: draws shape outlines with OpenCV.

Writes NUM_IMAGES 64x64 grayscale PNGs per class into data/<class>/.
"""

import os
import cv2
import numpy as np
import random

IMG_SIZE = 64
NUM_IMAGES = 250  # Number of images per class
OUTPUT_DIR = 'data'

# Ensure one output folder exists per class.
shapes = ['circle', 'square', 'rectangle', 'triangle']
for shape in shapes:
    os.makedirs(os.path.join(OUTPUT_DIR, shape), exist_ok=True)


def _blank_canvas():
    """Return a fresh all-white 64x64 grayscale image."""
    return np.ones((IMG_SIZE, IMG_SIZE), dtype=np.uint8) * 255


def draw_circle():
    """Random circle outline; radius is capped so it stays inside the frame."""
    canvas = _blank_canvas()
    center = (random.randint(8, 56), random.randint(8, 56))
    max_radius = min(60 - max(center), min(center))
    radius = random.randint(4, max_radius)
    cv2.circle(canvas, center, radius, (0,), 2)
    return canvas


def draw_square():
    """Random near-square: each side may be stretched by up to 3px."""
    canvas = _blank_canvas()
    corner = (random.randint(4, 56), random.randint(4, 56))
    side = random.randint(4, 60 - max(corner))
    opposite = (corner[0] + side + random.randint(0, 3),
                corner[1] + side + random.randint(0, 3))
    cv2.rectangle(canvas, corner, opposite, (0,), 2)
    return canvas


def draw_rectangle():
    """Random axis-aligned rectangle in a random orientation.

    NOTE(review): the original names (`long`, `short`) were swapped relative
    to their ranges (4-28 vs 33-56); neutral names are used here, with the
    exact same sampling behavior.
    """
    canvas = _blank_canvas()
    vertical = random.randint(0, 1)
    side_a, side_b = (random.randint(4, 28), random.randint(33, 56))
    if vertical:
        corner = (random.randint(4, 60 - side_b), random.randint(4, 60 - side_a))
        width, height = side_b, side_a
    else:
        corner = (random.randint(4, 60 - side_a), random.randint(4, 60 - side_b))
        width, height = side_a, side_b
    cv2.rectangle(canvas, corner, (corner[0] + width, corner[1] + height), (0,), 2)
    return canvas


def draw_triangle():
    """Random triangle from three uniformly sampled vertices.

    NOTE(review): vertices may occasionally be (near-)collinear, yielding a
    degenerate triangle — preserved as-is from the original.
    """
    canvas = _blank_canvas()
    vertices = np.array([(random.randint(4, 60), random.randint(4, 60))
                         for _ in range(3)])
    cv2.drawContours(canvas, [vertices], 0, (0,), 2)
    return canvas


draw_functions = {
    'circle': draw_circle,
    'square': draw_square,
    'rectangle': draw_rectangle,
    'triangle': draw_triangle
}

# ----- Generate images -----
for shape in shapes:
    for i in range(NUM_IMAGES):
        out_path = os.path.join(OUTPUT_DIR, shape, f"{shape}_{i}.png")
        cv2.imwrite(out_path, draw_functions[shape]())

print("Images generated successfully!")
|
||||
BIN
training/shape_model.keras
Normal file
BIN
training/shape_model.keras
Normal file
Binary file not shown.
68
training/train.py
Normal file
68
training/train.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Train the shape classifier on the synthetic images in ./data.

Exports the trained model to shape_model.keras. Fixed: removed unused
imports (keras.preprocessing.image, numpy) and replaced the deprecated
`input_shape=` kwarg on the first layer with an explicit Input layer.
"""

import tensorflow as tf
from keras import layers, models
import pathlib

IMG_SIZE = 64
BATCH_SIZE = 8
NUM_CLASSES = 4
EPOCHS = 10

data_dir = pathlib.Path("data")

# 80/20 train/validation split; the shared seed keeps the subsets disjoint.
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    labels='inferred',
    label_mode='categorical',
    color_mode='grayscale',
    batch_size=BATCH_SIZE,
    image_size=(IMG_SIZE, IMG_SIZE),
    validation_split=0.2,
    subset="training",
    seed=123
)

val_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    labels='inferred',
    label_mode='categorical',
    color_mode='grayscale',
    batch_size=BATCH_SIZE,
    image_size=(IMG_SIZE, IMG_SIZE),
    validation_split=0.2,
    subset="validation",
    seed=123
)

AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(100).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

# Small CNN: two conv/pool stages, then a dense head with a softmax over
# the four shape classes.
model = models.Sequential([
    layers.Input(shape=(IMG_SIZE, IMG_SIZE, 1)),
    layers.Rescaling(1/255),  # normalize 0-255 grayscale to [0, 1]
    layers.Conv2D(32, (3,3), activation='relu'),
    layers.MaxPooling2D(2,2),

    layers.Conv2D(64, (3,3), activation='relu'),
    layers.MaxPooling2D(2,2),

    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(NUM_CLASSES, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS
)

model.save("shape_model.keras")
print("Saved model")
|
||||
Reference in New Issue
Block a user