在浏览器中用KerasTensorflow.js进行图片分类( 三 )


3.1 对输入图像进行预处理正如我已经提到的 , 输入到mobilenet的图像大小是[224, 224] , 并且特征在[-1 , 1]之间缩放 。 在使用模型进行预测之前 , 您需要执行这两个步骤(将输入图像的大小转化为[224, 224] , 并且将特征大小转为[-1, 1]之间) 。 为此 , 我们使用preprocessImage()函数 , 该函数接受image和modelName两个参数 。
使用 tf.fromPixels() 可以很方便的将输入图片加载进来 , 使用 resizeNearestNeighbor() 调整大小 , 并使用toFloat()将图片中的所有值转化为浮点类型 。
然后 , 我们使用127.5的标量值来缩放图像张量中的值 , 127.5是图像像素范围[0, 255]的中间值 。 对于图像中的每个像素值 , 我们减去该偏移值并除以该偏移值以达到在[-1 , 1]之间缩放的效果 。 然后使用expandDims()来扩展维度 。
mobile-net.js
// 对图像做预处理 , 以使得其对mobilenet模型友好function preprocessImage(image, modelName) {// resize the input image to mobilenet's target size of (224, 224)let tensor = tf.browser.fromPixels(image).resizeNearestNeighbor([224, 224]).toFloat();// if model is not available, send the tensor with expanded dimensionsif (modelName === undefined) {return tensor.expandDims();}// if model is mobilenet, feature scale tensor image to range [-1, 1]else if (modelName === "mobilenet") {let offset = tf.scalar(127.5);return tensor.sub(offset).div(offset).expandDims();}// else throw an errorelse {alert("Unknown model name..")}}3.2 使用Tf.js模型进行预测在对图像进行预处理后 , 我为Predict按钮做了一个处理程序 。 同样 , 这也是一个async函数 , 它使用await关键字 , 直到模型进行成功的预测 。
使用Tf.js的模型(使用model.predict(tensor)方法)进行预测与Keras一样简单明了 。 为了得到预测结果 , 我们对model.predict(tensor)执行了data()方法 。
预测的结果被映射到了一个名为results的数组中 , 这个数组使用了我们在本教程开始时加载的IMAGENET_CLASSES 。 我们还使用sort()根据概率从高到低对数组进行排序 , 并使用slice()只获取前5个概率 。
mobile-net.js
// If "Predict Button" is clicked, preprocess the image and// make predictions using mobilenet$("#predict-button").click(async function () {// check if model loadedif (model == undefined) {alert("Please load the model first..")}// check if image loadedif (document.getElementById("predict-box").style.display == "none") {alert("Please load an image using 'Demo Image' or 'Upload Image' button..")}// html-image element can be given to tf.fromPixelslet image= document.getElementById("test-image");let tensor = preprocessImage(image, modelName);// make predictions on the preprocessed image tensorlet predictions = await model.predict(tensor).data();// get the model's prediction resultslet results = Array.from(predictions).map(function (p, i) {return {probability: p,className: IMAGENET_CLASSES[i]};}).sort(function (a, b) {return b.probability - a.probability;}).slice(0, 5);// display the top-1 prediction of the modeldocument.getElementById("results-box").style.display = "block";document.getElementById("prediction").innerHTML = "MobileNet prediction - " + results[0].className + "";// display top-5 predictions of the modelvar ul = document.getElementById("predict-list");ul.innerHTML = "";results.forEach(function (p) {console.log(p.className + " " + p.probability.toFixed(6));var li = document.createElement("LI");li.innerHTML = p.className + " " + p.probability.toFixed(6);ul.appendChild(li);});});至此 , 整个演示就结束了 , 你可以根据自己的情况操作起来了 。 我们现在在客户端浏览器中拥有了最先进的Keras预训练模型MobileNet的能力 , 它能够对属于ImageNet类别的图像进行预测 。
请注意 , mobilenet模型在浏览器中的加载速度非常快 , 预测速度也非常快 。
参考文档1. TensorFlow.js - Official Documentation
2. Keras - Official Documentation
3. Importing a Keras model into TensorFlow.js
4. Introduction to TensorFlow.js - Intelligence and Learning
5. TensorFlow.js: Tensors - Intelligence and Learning
6. TensorFlow.js Quick Start 7. Session 6 - TensorFlow.js - Intelligence and Learning 8. Session
7 - TensorFlow.js Color Classifier - Intelligence and Learning
9. Tensorflow.js Explained
10. Webcam Tracking with Tensorflow.js
11. Try TensorFlow.js in your browser
来源:微信公众号:跨端与全栈
作者:噶牛
出处:;mid=2247483830&idx=1&sn=6b5d08c45ca83d352f99454599c8f0b6