React图片识别App
淘淘是只狗 人气:0先把效果图给大家放上来
个人觉得效果还行。识别不太准确是因为这个 app学习图片的时间太短(电脑太卡)。
(笔者是 window10) 安装运行环境:
npm install --global windows-build-tools
(这个时间很漫长。。。)npm install @tensorflow/tfjs-node
(这个时间很漫长。。。)
项目目录如下
train文件夹 index.js(入口文件)
const tf = require('@tensorflow/tfjs-node') const getData = require('./data') const TRAIN_DIR = '../垃圾分类/train' const OUTPUT_DIR = '../outputDir' const MOBILENET_URL = 'http://ai-sample.oss-cn-hangzhou.aliyuncs.com/pipcook/models/mobilenet/web_model/model.json' const main = async () => { // 加载数据 const { ds, classes} = await getData(TRAIN_DIR, OUTPUT_DIR) // 定义模型 const mobilenet = await tf.loadLayersModel(MOBILENET_URL) mobilenet.summary() // console.log(mobilenet.layers.map((l, i) => [l.name, i])) const model = tf.sequential() for (let i = 0; i <= 86; i += 1) { const layer = mobilenet.layers[i] layer.trainable = false model.add(layer) } model.add(tf.layers.flatten()) model.add(tf.layers.dense({ units: 10, activation: 'relu' })) model.add(tf.layers.dense({ units: classes.length, activation: 'softmax' })) // 训练模型 model.compile({ loss: 'sparseCategoricalCrossentropy', optimizer: tf.train.adam(), metrics: ['acc'] }) await model.fitDataset(ds, { epochs: 20 }) await model.save(`file://${process.cwd()}/${OUTPUT_DIR}`) } main()
data.js(处理数据)
const fs = require('fs') const tf = require('@tensorflow/tfjs-node') const img2x = (imgPath) => { const buffer = fs.readFileSync(imgPath) return tf.tidy(() => { const imgTs = tf.node.decodeImage(new Uint8Array(buffer)) const imgTsResized = tf.image.resizeBilinear(imgTs, [224, 224]) return imgTsResized.toFloat().sub(255/2).div(255/2).reshape([1, 224, 224, 3]) }) } const getData = async (trainDir, outputDir) => { const classes = fs.readdirSync(trainDir) fs.writeFileSync(`${outputDir}/classes.json`, JSON.stringify(classes)) const data = [] classes.forEach((dir, dirIndex) => { fs.readdirSync(`${trainDir}/${dir}`) .filter(n => n.match(/jpg$/)) .slice(0, 10) .forEach(filename => { console.log('读取', dir, filename) const imgPath = `${trainDir}/${dir}/${filename}` data.push({ imgPath, dirIndex }) }) }) tf.util.shuffle(data) const ds = tf.data.generator(function* () { const count = data.length const batchSize = 32 for (let start = 0; start < count; start += batchSize) { const end = Math.min(start + batchSize, count) yield tf.tidy(() => { const inputs = [] const labels = [] for (let j = start; j < end; j += 1) { const { imgPath, dirIndex } = data[j] const x = img2x(imgPath) inputs.push(x) labels.push(dirIndex) } const xs = tf.concat(inputs) const ys = tf.tensor(labels) return { xs, ys } }) } }) return { ds, classes } } module.exports = getData
安装一些运行项目需要的插件
app 文件夹
import React, { PureComponent } from 'react' import { Button, Progress, Spin, Empty } from 'antd' import 'antd/dist/antd.css' import * as tf from '@tensorflow/tfjs' import { file2img, img2x } from './utils' import intro from './intro' const DATA_URL = 'http://127.0.0.1:8080/' class App extends PureComponent { state = {} async componentDidMount() { this.model = await tf.loadLayersModel(DATA_URL + '/model.json') // this.model.summary() this.CLASSES = await fetch(DATA_URL + '/classes.json').then(res => res.json()) } predict = async (file) => { const img = await file2img(file) this.setState({ imgSrc: img.src, isLoading: true }) setTimeout(() => { const pred = tf.tidy(() => { const x = img2x(img) return this.model.predict(x) }) const results = pred.arraySync()[0] .map((score, i) => ({score, label: this.CLASSES[i]})) .sort((a, b) => b.score - a.score) this.setState({ results, isLoading: false }) }, 0) } renderResult = (item) => { const finalScore = Math.round(item.score * 100) return ( <tr key={item.label}> <td style={{ width: 80, padding: '5px 0' }}>{item.label}</td> <td> <Progress percent={finalScore} status={finalScore === 100 ? 'success' : 'normal'} /> </td> </tr> ) } render() { const { imgSrc, results, isLoading } = this.state const finalItem = results && {...results[0], ...intro[results[0].label]} return ( <div style={{padding: 20}}> <span style={{ color: '#cccccc', textAlign: 'center', fontSize: 12, display: 'block' }} >识别可能不准确</span> <Button type="primary" size="large" style={{width: '100%'}} onClick={() => this.upload.click()} > 选择图片识别 </Button> <input type="file" onChange={e => this.predict(e.target.files[0])} ref={el => {this.upload = el}} style={{ display: 'none' }} /> { !results && !imgSrc && <Empty style={{ marginTop: 40 }} /> } {imgSrc && <div style={{ marginTop: 20, textAlign: 'center' }}> <img src={imgSrc} style={{ maxWidth: '100%' }} /> </div>} {finalItem && <div style={{marginTop: 20}}>识别结果: </div>} {finalItem && <div style={{display: 'flex', alignItems: 'flex-start', marginTop: 20}}> <img src={finalItem.icon} width={120} /> <div> <h2 style={{color: finalItem.color}}> {finalItem.label} </h2> <div style={{color: finalItem.color}}> {finalItem.intro} </div> </div> </div>} { isLoading && <Spin size="large" style={{display: 'flex', justifyContent: 'center', alignItems: 'center', marginTop: 40 }} /> } {results && <div style={{ marginTop: 20 }}> <table style={{width: '100%'}}> <tbody> <tr> <td>类别</td> <td>匹配度</td> </tr> {results.map(this.renderResult)} </tbody> </table> </div>} </div> ) } } export default App
index.html
<!DOCTYPE html> <html> <head> <title>垃圾分类</title> <meta name="viewport" content="width=device-width, inital-scale=1"> </head> <body> <div id="app"></div> <script src="./index.js"></script> </body> </html>
index.js
import React from 'react' import ReactDOM from 'react-dom' import App from './App' ReactDOM.render(<App />, document.querySelector('#app'))
intro.js
export default { '可回收物': { icon: 'https://lajifenleiapp.com/static/svg/1_3F6BA8.svg', color: '#3f6ba8', intro: '是指在日常生活中或者为日常生活提供服务的活动中产生的,已经失去原有全部或者部分使用价值,回收后经过再加工可以成为生产原料或者经过整理可以再利用的物品,包括废纸类、塑料类、玻璃类、金属类、织物类等。' }, '有害垃圾': { icon: 'https://lajifenleiapp.com/static/svg/2v_B43953.svg', color: '#b43953', intro: '是指生活垃圾中对人体健康或者自然环境造成直接或者潜在危害的物质,包括废充电电池、废扣式电池、废灯管、弃置药品、废杀虫剂(容器)、废油漆(容器)、废日用化学品、废水银产品、废旧电器以及电子产品等。' }, '厨余垃圾': { icon: 'https://lajifenleiapp.com/static/svg/3v_48925B.svg', color: '#48925b', intro: '是指居民日常生活中产生的有机易腐垃圾,包括菜叶、剩菜、剩饭、果皮、蛋壳、茶渣、骨头等。' }, '其他垃圾': { icon: 'https://lajifenleiapp.com/static/svg/4_89918B.svg', color: '#89918b', intro: '是指除可回收物、有害垃圾和厨余垃圾之外的,混杂、污染、难分类的其他生活垃圾。' } }
utils.js
import * as tf from '@tensorflow/tfjs' export const file2img = async (f) => { return new Promise(reslove => { const reader = new FileReader() reader.readAsDataURL(f) reader.onload = (e) => { const img = document.createElement('img') img.src = e.target.result img.width = 224 img.height = 224 img.onload = () => { reslove(img) } } }) } export function img2x(imgEl) { return tf.tidy(() => { return tf.browser.fromPixels(imgEl) .toFloat().sub(255/2).div(255/2) .reshape([1, 224, 224, 3]) }) }
运行项目代码之前,我们需要先在 train 目录下运行,node index.js,生成 model.json 以供识别系统使用。之后需要在根目录下运行 hs outputDir --cors, 使得生成的 model.json 运行在 http 环境下,之后才可以运行 npm start ,不然项目是会报错的。
主要的代码就是上面这些。前面笔者也说了。自己对这方面完全不懂,所以也无法解说其中的代码。各位感兴趣就自己研究一下。代码地址奉上。
gitee.com/suiboyu/gar…
总结
加载全部内容