Konva.js is a 2D canvas library for the web. We'll draw an image using Konva, extract the image data, and call an API to run an ML model on the image.
2024-03-17
7 min read
By: Abhilaksh Singh Reen
Today, we'll create a website using React that uses Konva.js to let us draw an image. We'll send this image to an API that's running an ML model for inference, get the prediction back, and display it to the user.
The code for the React application can be found here, and the backend code is available here.
If you would like to explore the API we'll be using in more detail, you can check out this tutorial, in which we build the API from the ground up in Express.js.
Open up a new terminal in your project's directory and run the following command to create a new React application.
npx create-react-app drawn-digit-api-inference
Move into the newly created folder and open it up in the code editor of your choice.
create-react-app provides us with a lot of boilerplate code. I will leave most files in the project untouched and only talk about the files that I create or make changes to.
In the src folder, let's create a new folder called pages, and inside it create a new file called Home.jsx. We'll add a simple React component here.
export default function Home() {
  return (
    <div>Home</div>
  );
}
Now, open up App.js in the src folder, and replace the entire file with the following code.
import Home from "./pages/Home";

function App() {
  return (
    <div className="App">
      <Home />
    </div>
  );
}

export default App;
Here, we render just a single component: the Home component.
Run the application using
npm start
and head to localhost:3000 in your web browser. You should see a page that says "Home".
Let's head back to pages/Home.jsx. We'll add a Konva Stage and set up some logic that will allow us to draw and erase on the Stage.
import { useRef, useState } from "react";
import { Stage, Layer, Line } from "react-konva";

export default function Home() {
  const stageSize = Math.min(window.innerWidth * 0.9, window.innerHeight * 0.5);

  const [tool, setTool] = useState("pen");
  const [lines, setLines] = useState([]);
  const [strokeWidth, setStrokeWidth] = useState(15);

  const stageRef = useRef(null);
  const isDrawing = useRef(false);

  const handleMouseDown = (e) => {
    isDrawing.current = true;
    const pos = e.target.getStage().getPointerPosition();
    setLines([...lines, { tool, strokeWidth, points: [pos.x, pos.y] }]);
  };

  const handleMouseMove = (e) => {
    if (!isDrawing.current) {
      return;
    }

    const stage = e.target.getStage();
    const point = stage.getPointerPosition();

    let lastLine = lines[lines.length - 1];
    lastLine.points = lastLine.points.concat([point.x, point.y]);

    lines.splice(lines.length - 1, 1, lastLine);
    setLines(lines.concat());
  };

  const handleMouseUp = () => {
    isDrawing.current = false;
  };

  return (
    <div
      style={{
        width: "100%",
        display: "flex",
        flexDirection: "column",
        justifyContent: "flex-start",
        alignItems: "center",
      }}
    >
      <h1>Drawn Digit Prediction</h1>
      <Stage
        ref={stageRef}
        width={stageSize}
        height={stageSize}
        onMouseDown={handleMouseDown}
        onMousemove={handleMouseMove}
        onMouseup={handleMouseUp}
        style={{
          border: "1px solid black",
        }}
      >
        <Layer>
          {lines.map((line, i) => (
            <Line
              key={i}
              points={line.points}
              stroke="#000000"
              strokeWidth={line.strokeWidth}
              tension={0.5}
              lineCap="round"
              lineJoin="round"
              globalCompositeOperation={line.tool === "eraser" ? "destination-out" : "source-over"}
            />
          ))}
        </Layer>
      </Stage>
      <div
        style={{
          width: stageSize,
          marginTop: 10,
          display: "flex",
          flexDirection: "row",
          justifyContent: "space-evenly",
          alignItems: "center",
        }}
      >
        <select
          value={tool}
          onChange={(e) => {
            setTool(e.target.value);
          }}
          style={{
            padding: "8px",
            borderRadius: "4px",
            border: "1px solid #ccc",
            fontSize: "16px",
          }}
        >
          <option value="pen">Pen</option>
          <option value="eraser">Eraser</option>
        </select>
        <select
          value={strokeWidth}
          onChange={(e) => {
            setStrokeWidth(parseInt(e.target.value));
          }}
          style={{
            padding: "8px",
            borderRadius: "4px",
            border: "1px solid #ccc",
            fontSize: "16px",
          }}
        >
          <option value="1">1</option>
          <option value="3">3</option>
          <option value="5">5</option>
          <option value="10">10</option>
          <option value="15">15</option>
          <option value="20">20</option>
          <option value="30">30</option>
          <option value="40">40</option>
          <option value="50">50</option>
        </select>
        <button
          onClick={() => setLines([])}
          style={{
            padding: "8px 12px",
            borderRadius: "4px",
            border: "1px solid #ccc",
            background: "#ffffff",
            color: "#333",
            fontSize: "16px",
            cursor: "pointer",
          }}
        >
          Clear
        </button>
      </div>
      <button
        style={{
          padding: "8px 12px",
          borderRadius: "4px",
          border: "1px solid #444444",
          background: "#eeeeee",
          color: "#333",
          fontSize: "16px",
          cursor: "pointer",
          marginTop: 10,
          marginBottom: 5,
        }}
      >
        Predict
      </button>
    </div>
  );
}
Let me explain what's going on here. We have a Stage with mouse-down, mouse-move, and mouse-up event listeners. When the user presses the mouse down on the Stage, we set isDrawing to true and add a new line to the lines array with the current tool, the stroke width, and the current pointer position (the point at which the mouse was pressed). Then, in the handleMouseMove listener, we check whether we are currently drawing, and if so, we append the current point to the last element of the lines array, i.e. to the line currently being drawn. In the handleMouseUp listener, we simply set isDrawing back to false.
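For reference, here's roughly what the lines state looks like after one pen stroke and one eraser stroke. The coordinates below are made up for illustration, but the shape is exactly what the handlers above build.
// Shape of the `lines` state after two strokes (coordinates are illustrative).
// Konva's Line expects points as a flat [x1, y1, x2, y2, ...] array.
const lines = [
  { tool: "pen", strokeWidth: 15, points: [42, 60, 45, 63, 50, 70] },
  { tool: "eraser", strokeWidth: 20, points: [120, 80, 118, 85] },
];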
Now, let's look inside the Stage tag. We have a Layer tag that holds all our lines. For each line, we render a Line tag, passing it the points and the strokeWidth. We also set the globalCompositeOperation based on the tool that drew the line; you can read more about this here.
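If you haven't used composite operations before, here's a minimal plain Canvas 2D sketch of the two modes in play; Konva applies the same operation when it renders each Line onto the layer.
// A minimal plain Canvas 2D illustration of the two composite operations used above.
const ctx = document.createElement("canvas").getContext("2d");

// "source-over" is the default: new pixels are drawn on top of existing ones.
ctx.globalCompositeOperation = "source-over";
ctx.fillRect(10, 10, 50, 50); // a pen stroke layers on top like this

// "destination-out" keeps existing pixels only where the new shape does NOT overlap,
// so eraser strokes punch transparent holes in whatever was drawn before.
ctx.globalCompositeOperation = "destination-out";
ctx.fillRect(30, 30, 20, 20);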
Finally, below the Stage, we have select tags for choosing the tool and the stroke width, a button to clear the Stage, and a button to call the API and get a prediction. Speaking of which, let's do that.
First, we create a state to store the latest prediction:
const [prediction, setPrediction] = useState(null);
Next, we'll work on the handlePredictButtonClick function, which is the onClick listener for the Predict button.
const handlePredictButtonClick = async (e) => {
  const stageImageDataUri = stageRef.current.toDataURL();

  setPrediction(0);
};
From the Konva Stage, we can extract the image as a data URI: a string like data:image/png;base64,iVBORw0... (toDataURL produces a PNG by default). But our API expects the image to be a Blob, so let's create a function that handles this conversion for us.
Inside the src folder, create a new folder called utils, and inside it create a file called fileUtils.js.
function dataURIToBlob(dataURI) {
  const splitDataURI = dataURI.split(",");
  const byteString = splitDataURI[0].indexOf("base64") >= 0 ? atob(splitDataURI[1]) : decodeURI(splitDataURI[1]);
  const mimeString = splitDataURI[0].split(":")[1].split(";")[0];

  const ia = new Uint8Array(byteString.length);
  for (let i = 0; i < byteString.length; i++) ia[i] = byteString.charCodeAt(i);

  return new Blob([ia], { type: mimeString });
}

export { dataURIToBlob };
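As a quick sanity check, here's the function applied to a tiny hand-written data URI. The base64 payload "YWJj" encodes the three bytes "abc" and just stands in for real PNG data; the import path assumes a file directly inside src.
import { dataURIToBlob } from "./utils/fileUtils";

// "YWJj" is base64 for "abc": three bytes standing in for real image data.
const blob = dataURIToBlob("data:image/png;base64,YWJj");

console.log(blob.type); // "image/png"
console.log(blob.size); // 3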
We can now import this function in pages/Home.jsx.
import { dataURIToBlob } from "../utils/fileUtils";
and complete the handlePredictButtonClick function.
const handlePredictButtonClick = async (e) => {
  const stageImageDataUri = stageRef.current.toDataURL();
  const stageImageBlob = dataURIToBlob(stageImageDataUri);

  const formData = new FormData();
  formData.append("file", stageImageBlob, "image.png"); // toDataURL gives us a PNG by default

  let responseData;
  try {
    const response = await fetch("http://localhost:8000/api/run-inference?image_provider=konva", {
      method: "POST",
      body: formData,
    });

    if (!response.ok) {
      window.alert("Server failed to predict.");
      return;
    }

    responseData = await response.json();
  } catch (error) {
    window.alert("Server failed to predict.");
    return;
  }

  setPrediction(responseData.predicted_label);
};
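As the last line suggests, the API is expected to respond with a JSON body containing a predicted_label field. A successful response is shaped something like this, with the digit of course depending on what you drew:
// Illustrative successful response body from /api/run-inference:
// { "predicted_label": 4 }
// responseData.predicted_label is then stored in the prediction state.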
And finally, add the onClick listener to the Predict button and display the prediction right below it.
<button
  onClick={handlePredictButtonClick}
  style={{
    padding: "8px 12px",
    borderRadius: "4px",
    border: "1px solid #444444",
    background: "#eeeeee",
    color: "#333",
    fontSize: "16px",
    cursor: "pointer",
    marginTop: 10,
    marginBottom: 5,
  }}
>
  Predict
</button>
{prediction !== null && (
  <h4
    style={{
      margin: 0,
    }}
  >
    Probably a {prediction}
  </h4>
)}
Make sure that the backend server is running, then try drawing a digit and hitting the Predict button; you should see text displaying the predicted value.
Before we wrap this up, I want to make a couple of changes for a better user experience. Once we have a prediction and the user then changes the image, the prediction is no longer valid. So, in the handleMouseDown function, we'll set prediction to null so that it disappears from the screen.
const handleMouseDown = (e) => {
  isDrawing.current = true;
  const pos = e.target.getStage().getPointerPosition();
  setLines([...lines, { tool, strokeWidth, points: [pos.x, pos.y] }]);

  setPrediction(null);
};
The other change is for the Clear button. We want to ask the user for confirmation before clearing the Stage. I cleared my drawing by mistake once while working on the "Hello There" image above, which is what gave me the idea.
<button
  onClick={() => {
    if (!window.confirm("Are you sure you want to clear your drawing? This operation cannot be undone.")) {
      return;
    }

    setLines([]);
  }}
  style={{
    padding: "8px 12px",
    borderRadius: "4px",
    border: "1px solid #ccc",
    background: "#ffffff",
    color: "#333",
    fontSize: "16px",
    cursor: "pointer",
  }}
>
  Clear
</button>
Here's the complete pages/Home.jsx file.
import { useRef, useState } from "react";
import { Stage, Layer, Line } from "react-konva";

import { dataURIToBlob } from "../utils/fileUtils";

export default function Home() {
  const stageSize = Math.min(window.innerWidth * 0.9, window.innerHeight * 0.5);

  const [tool, setTool] = useState("pen");
  const [lines, setLines] = useState([]);
  const [strokeWidth, setStrokeWidth] = useState(15);
  const [prediction, setPrediction] = useState(null);

  const stageRef = useRef(null);
  const isDrawing = useRef(false);

  const handleMouseDown = (e) => {
    isDrawing.current = true;
    const pos = e.target.getStage().getPointerPosition();
    setLines([...lines, { tool, strokeWidth, points: [pos.x, pos.y] }]);

    setPrediction(null);
  };

  const handleMouseMove = (e) => {
    if (!isDrawing.current) {
      return;
    }

    const stage = e.target.getStage();
    const point = stage.getPointerPosition();

    let lastLine = lines[lines.length - 1];
    lastLine.points = lastLine.points.concat([point.x, point.y]);

    lines.splice(lines.length - 1, 1, lastLine);
    setLines(lines.concat());
  };

  const handleMouseUp = () => {
    isDrawing.current = false;
  };

  const handlePredictButtonClick = async (e) => {
    const stageImageDataUri = stageRef.current.toDataURL();
    const stageImageBlob = dataURIToBlob(stageImageDataUri);

    const formData = new FormData();
    formData.append("file", stageImageBlob, "image.png"); // toDataURL gives us a PNG by default

    let responseData;
    try {
      const response = await fetch("http://localhost:8000/api/run-inference?image_provider=konva", {
        method: "POST",
        body: formData,
      });

      if (!response.ok) {
        window.alert("Server failed to predict.");
        return;
      }

      responseData = await response.json();
    } catch (error) {
      window.alert("Server failed to predict.");
      return;
    }

    setPrediction(responseData.predicted_label);
  };

  return (
    <div
      style={{
        width: "100%",
        display: "flex",
        flexDirection: "column",
        justifyContent: "flex-start",
        alignItems: "center",
      }}
    >
      <h1>Drawn Digit Prediction</h1>
      <Stage
        ref={stageRef}
        width={stageSize}
        height={stageSize}
        onMouseDown={handleMouseDown}
        onMousemove={handleMouseMove}
        onMouseup={handleMouseUp}
        style={{
          border: "1px solid black",
        }}
      >
        <Layer>
          {lines.map((line, i) => (
            <Line
              key={i}
              points={line.points}
              stroke="#000000"
              strokeWidth={line.strokeWidth}
              tension={0.5}
              lineCap="round"
              lineJoin="round"
              globalCompositeOperation={line.tool === "eraser" ? "destination-out" : "source-over"}
            />
          ))}
        </Layer>
      </Stage>
      <div
        style={{
          width: stageSize,
          marginTop: 10,
          display: "flex",
          flexDirection: "row",
          justifyContent: "space-evenly",
          alignItems: "center",
        }}
      >
        <select
          value={tool}
          onChange={(e) => {
            setTool(e.target.value);
          }}
          style={{
            padding: "8px",
            borderRadius: "4px",
            border: "1px solid #ccc",
            fontSize: "16px",
          }}
        >
          <option value="pen">Pen</option>
          <option value="eraser">Eraser</option>
        </select>
        <select
          value={strokeWidth}
          onChange={(e) => {
            setStrokeWidth(parseInt(e.target.value));
          }}
          style={{
            padding: "8px",
            borderRadius: "4px",
            border: "1px solid #ccc",
            fontSize: "16px",
          }}
        >
          <option value="1">1</option>
          <option value="3">3</option>
          <option value="5">5</option>
          <option value="10">10</option>
          <option value="15">15</option>
          <option value="20">20</option>
          <option value="30">30</option>
          <option value="40">40</option>
          <option value="50">50</option>
        </select>
        <button
          onClick={() => {
            if (!window.confirm("Are you sure you want to clear your drawing? This operation cannot be undone.")) {
              return;
            }

            setLines([]);
          }}
          style={{
            padding: "8px 12px",
            borderRadius: "4px",
            border: "1px solid #ccc",
            background: "#ffffff",
            color: "#333",
            fontSize: "16px",
            cursor: "pointer",
          }}
        >
          Clear
        </button>
      </div>
      <button
        onClick={handlePredictButtonClick}
        style={{
          padding: "8px 12px",
          borderRadius: "4px",
          border: "1px solid #444444",
          background: "#eeeeee",
          color: "#333",
          fontSize: "16px",
          cursor: "pointer",
          marginTop: 10,
          marginBottom: 5,
        }}
      >
        Predict
      </button>
      {prediction !== null && (
        <h4
          style={{
            margin: 0,
          }}
        >
          Probably a {prediction}
        </h4>
      )}
    </div>
  );
}
And there we have it: from our React app to our ML API. Accessing the API from a frontend is only one part of the entire system; the other half is the API itself. If you wish to find out more about how the API we used today was built, head to this link.
You might be wondering whether you should deploy your models on the client side (using something like onnxruntime-web) or on the server side, like our current API. Both approaches have their own advantages and disadvantages, and there are use cases in which one is clearly preferable over the other. To learn more about the merits and demerits, as well as the specific use cases, you can check out this article.
However, if you've already decided to put your model on the frontend, you can do so with an ONNX model as described here, or with a complete ML framework like TensorFlow.js as I have done in this blog post.
See you next time :)