An Introduction to CreateML and Using It in iOS

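Create ML is a tool Apple introduced at WWDC 2018 for generating machine learning models. It takes data you provide and produces the kind of machine learning model used in iOS development (a Core ML model).

The aPaaS Growth team focuses on the user-visible, end-to-end process of building aPaaS applications, along with product paths such as tenant and application governance. Our goal is to give the aPaaS platform a smooth "application delivery" process and experience, round out the ecosystem around application building, make building applications more convenient and reliable, and improve overall application performance, thereby helping aPaaS grow its user base and, together with the platform team, driving the adoption and efficiency of aPaaS inside and outside the company.

In the low-code/no-code space, platforms such as Microsoft Power Platform and AWS Amplify offer AI Builder-style products that let users train their own deep learning models with a low barrier to entry. CreateML is the equivalent product in Apple's ecosystem. It ships with Xcode, so anyone who has Xcode installed can open it and try it out (you need to bring your own dataset).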

What is CreateML

[Image]


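In iOS development, there are three main ways to obtain a machine learning model:

  • Download a ready-made model from Apple's official page[1]. There were 4 ready-made models in 2017, 6 in 2018, and 9 in 2019 (8 for images, 1 for text), and this year the count has reached 13. The number is limited and grows slowly, but these models are practical: they run on a phone within what the user experience can tolerate.
  • Generate a model with a third-party machine learning framework, then convert it to a Core ML model with Core ML Tools. In 2017 Apple announced support for 6 frameworks, including Caffe and Keras. In 2018 the number of supported third-party frameworks grew to 11, adding the best-known ones such as TensorFlow, IBM Watson, and MXNet. With that, Core ML covered essentially all of the mainstream frameworks on the market.
  • Train a model directly from your data with Create ML. The first version of Create ML, released in 2018, had three characteristics: it was driven by Swift code, models were trained and generated in a Playground, and all the work happened on macOS.

Today's Create ML goes further on ease of use: no programming is needed, it is a standalone macOS app, and it supports more data types and use cases.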

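CreateML model types

[Image]

1. Image Classification

[Image]

2. Object Detection

[Image]

3. Style Transfer

[Image]

4. Hand Pose & Hand Action

[Image]

5. Action Classification

[Image]

6. Activity Classification

[Image]

[Image]

[Image]

7. Sound Classification

Think of how something like "Hey Siri" might be implemented.

[Image]

8. Text Classification

[Image]

9. Word Tagging

[Image]

10. Tabular Classification & Regression

[Image]

Predict one column from several others, for example inferring someone's income bracket from their gender, age, city, and so on.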

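11. Recommendation

For example, if you bought beer, the model recommends peanuts. Recommendation also predates deep learning: there are classic non-deep-learning algorithms for it, such as Apriori.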
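Apriori-style association mining scores a rule such as beer → peanuts with two numbers: support (the share of baskets that contain all the items) and confidence (the share of beer baskets that also contain peanuts). Below is a minimal TypeScript sketch of just those two metrics on hypothetical baskets. It is not the full Apriori algorithm (which also generates and prunes candidate item sets), and it is not how Create ML's recommender works internally:

```typescript
// Hypothetical shopping baskets, used only to illustrate the metrics.
type Basket = string[];

// Fraction of baskets that contain every item in `items`.
function support(baskets: Basket[], items: string[]): number {
    if (baskets.length === 0) return 0;
    const hits = baskets.filter((b) => items.every((item) => b.indexOf(item) !== -1)).length;
    return hits / baskets.length;
}

// Confidence of the rule `antecedent -> consequent`:
// support(antecedent + consequent) / support(antecedent).
function confidence(baskets: Basket[], antecedent: string[], consequent: string[]): number {
    const base = support(baskets, antecedent);
    return base === 0 ? 0 : support(baskets, antecedent.concat(consequent)) / base;
}

const baskets: Basket[] = [
    ["beer", "peanuts", "chips"],
    ["beer", "peanuts"],
    ["beer", "diapers"],
    ["milk", "bread"],
];

console.log(support(baskets, ["beer", "peanuts"]));      // 0.5
console.log(confidence(baskets, ["beer"], ["peanuts"])); // ≈ 0.667
```

A rule-mining pass would keep only the rules whose support and confidence exceed thresholds you choose.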

Trying out a CreateML model

[Image]

Training a CreateML object detection model

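Data preparation

Some people assume the hard part of training a deep learning model is finding the right algorithm or architecture and running enough iterations on a powerful enough machine. In practice, what matters most for a deep model is having a large and accurately labeled data source; this is also the main reason the AI industry tends to concentrate around a few leading players. If you are building an AI-related application, the thing to focus on most is how to obtain enough accurate data.

Below, I'll use the model from the "Trying out" section above as an example to show how to train a similar model.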

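Data format

The data format for CreateML object detection is shown below:

[Image]

First, there is a file named annotations.json that records, for each image, which objects it contains and the coordinates of each object's bounding box (x and y are the center of the box in pixels, plus its width and height).

[Image]

For example, the bounding boxes for the image above are:

[Image]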
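For reference, here is a sketch of one annotations.json entry as a TypeScript type, assuming Create ML's documented layout: an array of `{ image, annotations }` objects in which each box's x and y are its center in pixels. The file name and label are hypothetical, and `toCorners` is an illustrative helper that converts a center-based box to corner coordinates, which is handy when drawing or bounds-checking boxes:

```typescript
// Shape of one entry in a Create ML object-detection annotations.json file.
// x and y are the center of the bounding box, in pixels of the original image.
interface CreateMLAnnotation {
    image: string;
    annotations: {
        label: string;
        coordinates: { x: number; y: number; width: number; height: number };
    }[];
}

// Convert a center-based box to its top-left / bottom-right corners.
function toCorners(c: { x: number; y: number; width: number; height: number }) {
    return {
        xMin: c.x - c.width / 2,
        yMin: c.y - c.height / 2,
        xMax: c.x + c.width / 2,
        yMax: c.y + c.height / 2,
    };
}

// Hypothetical entry: file name, label and numbers are illustrative only.
const entry: CreateMLAnnotation = {
    image: "sample-0.jpg",
    annotations: [
        { label: "character-a", coordinates: { x: 200, y: 150, width: 100, height: 60 } },
    ],
};

console.log(JSON.stringify([entry], null, 2));                  // the file holds an array of entries
console.log(toCorners(entry.annotations[0].coordinates)); // { xMin: 150, yMin: 120, xMax: 250, yMax: 180 }
```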

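Preparing enough data

The first question is what counts as "enough" data. A few public datasets give a sense of scale:

Stanford Cars Dataset: 934 MB. The Cars dataset contains 16,185 images of 196 classes of cars. The data is split into 8,144 training images and 8,041 testing images.

https://www.kaggle.com/datasets/kmader/food41: labeled food images in 101 categories, from apple pies to waffles, 6 GB.

In the example above there are 40-odd Genshin Impact characters, so about a hundred-odd MB of data is enough to get training started; when accuracy isn't high enough, add more samples to improve it. The problem is where to find that much data. My approach was to synthesize it with a script: since our task is only to locate each character's "ID photo" (portrait card) within an image, I took the ID photos of about 40 characters and wrote the following script (Copilot helped a lot...) to generate roughly 500 MB of training and test data: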

// import sharp from "sharp";

import { createCanvas, Image } from "@napi-rs/canvas";
import { promises } from "fs";
import fs from "fs";
import path from "path";
import Sharp from "sharp";

const IMAGE_GENERATED_COUNT_PER_CLASS = 5;
const MAX_NUMBER_OF_CLASSES_IN_SINGLE_IMAGE = 10;
const CANVAS_WIDTH = 1024;
const CANVAS_HEIGHT = 800;
const CONCURRENT_PROMISE_SIZE = 50;

const CanvasSize = [CANVAS_WIDTH, CANVAS_HEIGHT];

function isNotOverlap(x1: number, y1: number, width1: number, height1: number, x2: number, y2: number, width2: number, height2: number) {
    return x1 >= x2 + width2 || x1 + width1 <= x2 || y1 >= y2 + height2 || y1 + height1 <= y2;
}

const randomColorList: Record<string, string> = {
    "white": "rgb(255, 255, 255)",
    "black": "rgb(0, 0, 0)",
    "red": "rgb(255, 0, 0)",
    "green": "rgb(0, 255, 0)",
    "blue": "rgb(0, 0, 255)",
    "yellow": "rgb(255, 255, 0)",
    "cyan": "rgb(0, 255, 255)",
    "magenta": "rgb(255, 0, 255)",
    "gray": "rgb(128, 128, 128)",
    "grey": "rgb(128, 128, 128)",
    "maroon": "rgb(128, 0, 0)",
    "olive": "rgb(128, 128, 0)",
    "purple": "rgb(128, 0, 128)",
    "teal": "rgb(0, 128, 128)",
    "navy": "rgb(0, 0, 128)",
    "orange": "rgb(255, 165, 0)",
    "aliceblue": "rgb(240, 248, 255)",
    "antiquewhite": "rgb(250, 235, 215)",
    "aquamarine": "rgb(127, 255, 212)",
    "azure": "rgb(240, 255, 255)",
    "beige": "rgb(245, 245, 220)",
    "bisque": "rgb(255, 228, 196)",
    "blanchedalmond": "rgb(255, 235, 205)",
    "blueviolet": "rgb(138, 43, 226)",
    "brown": "rgb(165, 42, 42)",
    "burlywood": "rgb(222, 184, 135)",
    "cadetblue": "rgb(95, 158, 160)",
    "chartreuse": "rgb(127, 255, 0)",
    "chocolate": "rgb(210, 105, 30)",
    "coral": "rgb(255, 127, 80)",
    "cornflowerblue": "rgb(100, 149, 237)",
    "cornsilk": "rgb(255, 248, 220)",
    "crimson": "rgb(220, 20, 60)",
    "darkblue": "rgb(0, 0, 139)",
    "darkcyan": "rgb(0, 139, 139)",
    "darkgoldenrod": "rgb(184, 134, 11)",
    "darkgray": "rgb(169, 169, 169)",
    "darkgreen": "rgb(0, 100, 0)",
    "darkgrey": "rgb(169, 169, 169)",
    "darkkhaki": "rgb(189, 183, 107)",
    "darkmagenta": "rgb(139, 0, 139)",
    "darkolivegreen": "rgb(85, 107, 47)",
    "darkorange": "rgb(255, 140, 0)",
    "darkorchid": "rgb(153, 50, 204)",
    "darkred": "rgb(139, 0, 0)"
}

function generateColor(index: number = -1) {
    if (index < 0 || index >= Object.keys(randomColorList).length) {
        // return random color from list
        let keys = Object.keys(randomColorList);
        let randomKey = keys[Math.floor(Math.random() * keys.length)];
        return randomColorList[randomKey];
    } else {
        // return color by index
        let keys = Object.keys(randomColorList);
        return randomColorList[keys[index]];
    }
}

function randomPlaceImagesInCanvas(canvasWidth: number, canvasHeight: number, images: number[][], overlapping: boolean = true) {
    // One entry per input image: [x, y, placedFlag], where placedFlag is 1 if the image should be drawn.
    let placedImages: number[][] = [];
    // Boxes ([x, y, width, height]) of the images accepted so far, used for the overlap test.
    let acceptedBoxes: number[][] = [];
    for (let image of images) {
        let [width, height] = image;
        let [x, y] = [Math.floor(Math.random() * (canvasWidth - width)), Math.floor(Math.random() * (canvasHeight - height))];
        // Accept the image unless overlapping is disallowed and it intersects an accepted box,
        // so the first image (the target class) is always accepted.
        let placed = overlapping || acceptedBoxes.every(([bx, by, bw, bh]) => isNotOverlap(x, y, width, height, bx, by, bw, bh));
        if (placed) {
            acceptedBoxes.push([x, y, width, height]);
        }
        placedImages.push([x, y, placed ? 1 : 0]);
    }
    return placedImages;
}

function getSizeBasedOnRatio(width: number, height: number, ratio: number) {
    return [width * ratio, height];
}

function cartesianProductOfArray(...arrays: any[][]) {
    return arrays.reduce((a, b) => a.flatMap((d: any) => b.map((e: any) => [d, e].flat())));
}

function rotateRectangleAndGetSize(width: number, height: number, angle: number) {
    let radians = angle * Math.PI / 180;
    let cos = Math.abs(Math.cos(radians));
    let sin = Math.abs(Math.sin(radians));
    let newWidth = Math.ceil(width * cos + height * sin);
    let newHeight = Math.ceil(height * cos + width * sin);
    return [newWidth, newHeight];
}

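// Awaits the given tasks in batches of `size`.
// A Promise starts running as soon as it is created, so only tasks passed as
// functions (e.g. `() => __task(i)`) are actually throttled; promises that were
// already created are simply awaited batch by batch.
async function concurrentlyExecutePromisesWithSize(tasks: Array<Promise<unknown> | (() => Promise<unknown>)>, size: number): Promise<void> {
    for (let start = 0; start < tasks.length; start += size) {
        const batch = tasks.slice(start, start + size);
        await Promise.all(batch.map((task) => (typeof task === "function" ? task() : task)));
    }
}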

function generateRandomRgbColor() {
    return [Math.floor(Math.random() * 256), Math.floor(Math.random() * 256), Math.floor(Math.random() * 256)];
}

function getSizeOfImage(image: Image) {
    return [image.width, image.height];
}

async function makeSureFolderExists(path: string) {
    if (!fs.existsSync(path)) {
        await promises.mkdir(path, { recursive: true });
    }
}

// Randomly pick `count` distinct elements from the array (sampling without replacement).
async function randomSelectFromArray<T>(array: T[], count: number) {
    let copied = array.slice();
    let selected: T[] = [];
    for (let i = 0; i < count; i++) {
        let index = Math.floor(Math.random() * copied.length);
        selected.push(copied[index]);
        copied.splice(index, 1);
    }
    return selected;
}

function getFileNameFromPathWithoutPrefix(path: string) {
    return path.split("/").pop()!.split(".")[0];
}

type Annotion = {
    "image": string,
    "annotions": {
        "label": string,
        "coordinates": {
            "x": number,
            "y": number,
            "width": number,
            "height": number
        }
    }[]
}

async function generateCreateMLFormatOutput(folderPath: string, outputDir: string, imageCountPerFile: number = IMAGE_GENERATED_COUNT_PER_CLASS) {

    if (!fs.existsSync(path.join(folderPath, "real"))) {
        throw new Error("real folder does not exist");
    }

    let realFiles = fs.readdirSync(path.join(folderPath, "real")).map((file) => path.join(folderPath, "real", file));
    let confusionFiles: string[] = [];

    if (fs.existsSync(path.join(folderPath, "confusion"))) {
        confusionFiles = fs.readdirSync(path.join(folderPath, "confusion")).map((file) => path.join(folderPath, "confusion", file));
    }

    // getting files in folder
    let tasks: Promise<void>[] = [];
    let annotions: Annotion[] = [];

    for (let filePath of realFiles) {

        let className = getFileNameFromPathWithoutPrefix(filePath);

        for (let i = 0; i < imageCountPerFile; i++) {

            let annotion: Annotion = {
                "image": `${className}-${i}.jpg`,
                "annotions": []
            };

            async function __task(i: number) {

                let randomCount = Math.random() * MAX_NUMBER_OF_CLASSES_IN_SINGLE_IMAGE;
                randomCount = randomCount > realFiles.length + confusionFiles.length ? realFiles.length + confusionFiles.length : randomCount;
                let selectedFiles = await randomSelectFromArray(realFiles.concat(confusionFiles), randomCount);
                if (selectedFiles.includes(filePath)) {
                    // move filePath to the first
                    selectedFiles.splice(selectedFiles.indexOf(filePath), 1);
                    selectedFiles.unshift(filePath);
                } else {
                    selectedFiles.unshift(filePath);
                }

                console.log(`processing ${filePath} ${i}, selected ${selectedFiles.length} files`);

                let images = await Promise.all(selectedFiles.map(async (filePath) => {
                    let file = await promises.readFile(filePath);
                    let image = new Image();
                    image.src = file;
                    return image;
                }));

                console.log(`processing: ${filePath}, loaded images, start to place images in canvas`);

                let imageSizes = images.map(getSizeOfImage).map( x => {
                    let averageX = CanvasSize[0] / (images.length + 1);
                    let averageY = CanvasSize[1] / (images.length + 1);
                    return [x[0] > averageX ? averageX : x[0], x[1] > averageY ? averageY : x[1]];
                });

                let placedPoints = randomPlaceImagesInCanvas(CANVAS_WIDTH, CANVAS_HEIGHT, imageSizes, false);

                console.log(`processing: ${filePath}, placed images in canvas, start to draw images`);

                let angle = 0;
                let color = generateColor(i);

                let [canvasWidth, canvasHeight] = CanvasSize;
                const canvas = createCanvas(canvasWidth, canvasHeight);
                const ctx = canvas.getContext("2d");

                ctx.fillStyle = color;
                ctx.fillRect(0, 0, canvasWidth, canvasHeight);

                for (let j = 0; j < images.length; j++) {
                    const ctx = canvas.getContext("2d");

                    let ratio = Math.random() * 1.5 + 0.5;

                    let image = images[j];

                    let [_imageWidth, _imageHeight] = imageSizes[j];
                    let [imageWidth, imageHeight] = getSizeBasedOnRatio(_imageWidth, _imageHeight, ratio);

                    let placed = placedPoints[j][2] === 1 ? true : false;
                    if (!placed) {
                        continue;
                    }

                    let targetX = placedPoints[j][0] > imageWidth / 2 ? placedPoints[j][0] : imageWidth / 2;
                    let targetY = placedPoints[j][1] > imageHeight / 2 ? placedPoints[j][1] : imageHeight / 2;

                    let sizeAfterRotatation = rotateRectangleAndGetSize(imageWidth, imageHeight, angle);

                    console.log("final: ", [canvasWidth, canvasHeight], [imageWidth, imageHeight], [targetX, targetY], angle, ratio, color);

                    ctx.translate(targetX, targetY);
                    ctx.rotate(angle * Math.PI / 180);

                    ctx.drawImage(image, -imageWidth / 2, -imageHeight / 2, imageWidth, imageHeight);

                    ctx.rotate(-angle * Math.PI / 180);
                    ctx.translate(-targetX, -targetY);

                    // ctx.fillStyle = "green";
                    // ctx.strokeRect(targetX - sizeAfterRotatation[0] / 2, targetY - sizeAfterRotatation[1] / 2, sizeAfterRotatation[0], sizeAfterRotatation[1]);

                    annotion.annotions.push({
                        "label": getFileNameFromPathWithoutPrefix(selectedFiles[j]),
                        "coordinates": {
                            "x": targetX,
                            "y": targetY,
                            "width": sizeAfterRotatation[0],
                            "height": sizeAfterRotatation[1]
                        }
                    });
                }

                if (!annotion.annotions.length) {
                    return;
                }

                let fileName = path.join(outputDir, `${className}-${i}.jpg`);
                let jpegData = await canvas.encode("jpeg");
                await promises.writeFile(fileName, jpegData);

                annotions.push(annotion);
            }

            tasks.push(__task(i));

        }

    }

    await concurrentlyExecutePromisesWithSize(tasks, CONCURRENT_PROMISE_SIZE);

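    // Create ML's documented format stores each image's objects under the key "annotations",
    // so rename the internal (misspelled) "annotions" field when serializing.
    const createMLAnnotations = annotions.map((a) => ({ image: a.image, annotations: a.annotions }));
    await promises.writeFile(path.join(outputDir, "annotations.json"), JSON.stringify(createMLAnnotations, null, 4));

}

async function generateYoloFormatOutput(folderPath: string) {
    // Read the Create ML file back and map it onto the internal Annotion type.
    const createMLAnnotations = JSON.parse((await promises.readFile(path.join(folderPath, "annotations.json"))).toString("utf-8")) as { image: string, annotations: Annotion["annotions"] }[];
    const annotions: Annotion[] = createMLAnnotations.map((a) => ({ image: a.image, annotions: a.annotations }));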

    // generate data.yml
    let classes: string[] = [];
    for (let annotion of annotions) {
        for (let label of annotion.annotions.map(a => a.label)) {
            if (!classes.includes(label)) {
                classes.push(label);
            }
        }
    }

    let dataYml = `
train: ./train/images
val: ./valid/images
test: ./test/images

nc: ${classes.length}
names: ${JSON.stringify(classes)}
`
    await promises.writeFile(path.join(folderPath, "data.yml"), dataYml);

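    // Cumulative thresholds on a uniform random number: < 0.85 -> train, < 0.90 -> valid, < 0.95 -> test.
    // Values >= 0.95 match no threshold and keep index 0, so they also go to train (about 90% / 5% / 5%).
    const weights = [0.85, 0.90, 0.95];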
    const split = ["train", "valid", "test"];

    let tasks: Promise<void>[] = [];

    async function __task(annotion: Annotion) {
        const randomSeed = Math.random();
        let index = 0;
        for (let i = 0; i < weights.length; i++) {
            if (randomSeed < weights[i]) {
                index = i;
                break;
            }
        }
        let splitFolderName = split[index];
        await makeSureFolderExists(path.join(folderPath, splitFolderName));
        await makeSureFolderExists(path.join(folderPath, splitFolderName, "images"));
        await makeSureFolderExists(path.join(folderPath, splitFolderName, "labels"));

        // get info of image
        let image = await Sharp(path.join(folderPath, annotion.image)).metadata();

        // generate label files
        let line: [number, number, number, number, number][] = []
        for (let i of annotion.annotions) {
            line.push([
                classes.indexOf(i.label),
                i.coordinates.x / image.width!,
                i.coordinates.y / image.height!,
                i.coordinates.width / image.width!,
                i.coordinates.height / image.height!
            ])
        }

        await promises.rename(path.join(folderPath, annotion.image), path.join(folderPath, splitFolderName, "images", annotion.image));
        await promises.writeFile(path.join(folderPath, splitFolderName, "labels", annotion.image.replace(".jpg", ".txt")), line.map(l => l.join(" ")).join("\n"));
    }

    for (let annotion of annotions) {
        tasks.push(__task(annotion));
    }

    await concurrentlyExecutePromisesWithSize(tasks, CONCURRENT_PROMISE_SIZE);

}

(async () => {

    await generateCreateMLFormatOutput("./database", "./output");

    // await generateYoloFormatOutput("./output");

})();
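One caveat about the script: `__task(i)` starts running as soon as it is called, so `tasks` already holds in-flight promises and `concurrentlyExecutePromisesWithSize` cannot actually cap how many images are processed at once. A common fix is to queue task factories instead and let a fixed number of workers pull from them. A minimal sketch; `runWithConcurrencyLimit` is a hypothetical helper, not part of the original script:

```typescript
// Runs task factories with at most `limit` of them in flight at a time.
// Tasks are functions, so nothing starts until a worker calls it.
async function runWithConcurrencyLimit<T>(tasks: (() => Promise<T>)[], limit: number): Promise<T[]> {
    const results = new Array<T>(tasks.length);
    let next = 0;

    // Each worker keeps claiming the next unstarted task until none are left.
    // JavaScript runs this loop on a single thread, so `next++` needs no locking.
    async function worker(): Promise<void> {
        while (next < tasks.length) {
            const index = next++;
            results[index] = await tasks[index]();
        }
    }

    const workerCount = Math.min(Math.max(1, limit), tasks.length);
    await Promise.all(Array.from({ length: workerCount }, () => worker()));
    return results;
}

// Usage sketch: record the peak number of tasks running at once.
let running = 0;
let maxRunning = 0;
const sleepTask = (ms: number) => async (): Promise<number> => {
    running++;
    maxRunning = Math.max(maxRunning, running);
    await new Promise((resolve) => setTimeout(resolve, ms));
    running--;
    return ms;
};

runWithConcurrencyLimit([30, 10, 20, 5, 15].map(sleepTask), 2).then((results) => {
    // results keep the input order; maxRunning never exceeds 2
    console.log(results, maxRunning);
});
```

Applied to the script, this would mean pushing `() => __task(i)` into `tasks` and awaiting `runWithConcurrencyLimit(tasks, CONCURRENT_PROMISE_SIZE)`.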

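The idea behind the script is to stretch these 40-odd portraits into all sorts of shapes (the width is scaled by a random factor between 0.5 and 2), then scatter a random subset of them on a canvas whose background color also varies, to simulate enough different scenes.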
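The script also sizes each annotation box with `rotateRectangleAndGetSize`, which returns the axis-aligned box of a rectangle rotated around its center: width' = w|cos θ| + h|sin θ| and height' = h|cos θ| + w|sin θ|. In the published script `angle` is fixed at 0, so the box simply equals the stretched image size; if you enable rotation, this formula is what keeps the labels covering the whole image. The sketch below checks the closed form against rotating the four corners directly (function names are illustrative):

```typescript
// Axis-aligned bounding box of a w x h rectangle rotated by `angleDeg` around its center,
// computed by rotating the four corners and taking the min/max extents.
function rotatedBoundingBox(w: number, h: number, angleDeg: number): [number, number] {
    const t = (angleDeg * Math.PI) / 180;
    const corners = [[-w / 2, -h / 2], [w / 2, -h / 2], [w / 2, h / 2], [-w / 2, h / 2]];
    const xs = corners.map(([x, y]) => x * Math.cos(t) - y * Math.sin(t));
    const ys = corners.map(([x, y]) => x * Math.sin(t) + y * Math.cos(t));
    return [Math.max(...xs) - Math.min(...xs), Math.max(...ys) - Math.min(...ys)];
}

// Closed form used by the script's rotateRectangleAndGetSize (minus its Math.ceil rounding).
function closedForm(w: number, h: number, angleDeg: number): [number, number] {
    const t = (angleDeg * Math.PI) / 180;
    const c = Math.abs(Math.cos(t));
    const s = Math.abs(Math.sin(t));
    return [w * c + h * s, h * c + w * s];
}

// Both approaches should agree (up to floating-point error) for any angle.
for (const angle of [0, 30, 90, 135]) {
    console.log(angle, rotatedBoundingBox(200, 100, angle), closedForm(200, 100, angle));
}
```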

Incidentally, the 500 MB figure was not settled on up front; it was raised step by step through repeated experiments in pursuit of higher accuracy.
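As the generated set grows, it also helps to hold some images out for evaluation. The script's `generateYoloFormatOutput` does this with cumulative thresholds `[0.85, 0.90, 0.95]` on a random number; values at or above 0.95 match no threshold and fall back to index 0, so the effective split is roughly 90% train, 5% valid, 5% test. If you want explicit fractions instead, a shuffle-then-slice split is one option. A minimal sketch; `splitDataset` and the file names are hypothetical:

```typescript
// Split items into valid/test by the given fractions; everything left over goes to train.
function splitDataset<T>(items: T[], validFraction: number, testFraction: number, random: () => number = Math.random) {
    // Fisher-Yates shuffle on a copy so the input array is left untouched.
    const shuffled = items.slice();
    for (let i = shuffled.length - 1; i > 0; i--) {
        const j = Math.floor(random() * (i + 1));
        [shuffled[i], shuffled[j]] = [shuffled[j], shuffled[i]];
    }
    const validCount = Math.round(items.length * validFraction);
    const testCount = Math.round(items.length * testFraction);
    return {
        valid: shuffled.slice(0, validCount),
        test: shuffled.slice(validCount, validCount + testCount),
        train: shuffled.slice(validCount + testCount),
    };
}

const files = Array.from({ length: 20 }, (_, i) => `image-${i}.jpg`);
const split = splitDataset(files, 0.1, 0.05);
console.log(split.train.length, split.valid.length, split.test.length); // 17 2 1
```

Passing a seeded `random` function instead of `Math.random` makes the split reproducible between runs.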

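Model training

The next step is simple: select your dataset in CreateML and start training:

[Image]

[Image]

As you can see, CreateML's Object Detection is actually based on YOLOv2. At the time of writing, the most advanced YOLO version is probably YOLOv7, but YOLOv5 likely still has the most mature ecosystem.

[Image]

On my M1 Pro machine, training took 10+ hours, and it takes even longer on an Intel laptop. The whole process feels a bit like alchemy: 500-odd MB of input is distilled into an 80 MB file.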
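If you want to compare against a newer YOLO version such as YOLOv5, the same data can be reused. Create ML stores center-based pixel boxes, while YOLO-style trainers expect one text file per image with lines of the form `class cx cy w h`, all normalized to [0, 1] by the image size; the script's `generateYoloFormatOutput` performs this conversion for a whole folder. A standalone sketch of the per-image step; `toYoloLines` and the class names are illustrative:

```typescript
// One object in Create ML's convention: center x/y plus width/height, in pixels.
interface Box { label: string; x: number; y: number; width: number; height: number }

// Convert the boxes of one image into YOLO label lines:
// "<classIndex> <cx> <cy> <w> <h>", each coordinate divided by the image size.
function toYoloLines(boxes: Box[], classes: string[], imageWidth: number, imageHeight: number): string[] {
    return boxes.map((b) => {
        const classIndex = classes.indexOf(b.label);
        if (classIndex < 0) throw new Error(`unknown label: ${b.label}`);
        return [classIndex, b.x / imageWidth, b.y / imageHeight, b.width / imageWidth, b.height / imageHeight].join(" ");
    });
}

const classes = ["character-a", "character-b"]; // hypothetical class list
console.log(toYoloLines([{ label: "character-b", x: 512, y: 400, width: 256, height: 200 }], classes, 1024, 800));
// [ '1 0.5 0.5 0.25 0.25' ]
```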

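Model testing

After training, you get the model file shown in the "Trying out" section above. Open it and drag any image into it to test how well the model works:

[Image]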
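Beyond eyeballing results, detection quality is usually measured by how well a predicted box overlaps the ground-truth box, using intersection over union (IoU); a prediction is commonly counted as a hit when IoU is at least 0.5. A minimal sketch with center-based boxes, the same convention as the annotation file (`iou` is an illustrative helper):

```typescript
// Axis-aligned box given by its center and size, in pixels.
interface CenterBox { x: number; y: number; width: number; height: number }

// Intersection over union: overlap area divided by the area covered by either box, in [0, 1].
function iou(a: CenterBox, b: CenterBox): number {
    const overlapX = Math.min(a.x + a.width / 2, b.x + b.width / 2) - Math.max(a.x - a.width / 2, b.x - b.width / 2);
    const overlapY = Math.min(a.y + a.height / 2, b.y + b.height / 2) - Math.max(a.y - a.height / 2, b.y - b.height / 2);
    const intersection = Math.max(0, overlapX) * Math.max(0, overlapY);
    const union = a.width * a.height + b.width * b.height - intersection;
    return union === 0 ? 0 : intersection / union;
}

const groundTruth = { x: 100, y: 100, width: 100, height: 100 };
const predicted = { x: 150, y: 100, width: 100, height: 100 };
// Overlap is 50 x 100 = 5000; union is 10000 + 10000 - 5000 = 15000.
console.log(iou(groundTruth, predicted)); // ≈ 0.333
```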

Using the model in iOS

For Apple's official demo, see this example:

https://developer.apple.com/documentation/vision/recognizing_objects_in_live_capture

I also wrote a demo with SwiftUI:

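//
//  ContentView.swift
//  DemoProject
//
//  PhotoLibrary (an image picker) and VNObservationConfirmer (a results view) are
//  custom views from the author's project and are not shown here.
//

import SwiftUI
import Vision
import CoreML

// Errors thrown below. The original project defines this type elsewhere;
// this minimal definition is an assumption inferred from how it is used.
enum EvaluationError: Error {
    case resourceNotFound(String)
    case invalidExpression(String)
    case cgImageRetrievalFailure
}
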
class MyVNModel: ObservableObject {
    
    static let shared: MyVNModel = MyVNModel()
        
    @Published var parsedModel: VNCoreMLModel? = .none
    var images: [UIImage]? = .none
    var observationList: [[VNObservation]]? = .none
    
    func applyModelToCgImage(image: CGImage) async throws -> [VNObservation] {
        guard let parsedModel = parsedModel else {
            throw EvaluationError.resourceNotFound("cannot find parsedModel")
        }
        
        let requestHandler = VNImageRequestHandler(cgImage: image)
        let request = VNCoreMLRequest(model: parsedModel)
        #if targetEnvironment(simulator)
            request.usesCPUOnly = true
        #endif
        // perform(_:) runs synchronously: once it returns, the observations are on
        // the request, and any Vision/Core ML error is thrown from here.
        try requestHandler.perform([request])
        guard let results = request.results else {
            throw EvaluationError.invalidExpression("cannot find observations in result")
        }
        return results
    }
    
    init() {
        Task(priority: .background) {
            let urlPath = Bundle.main.url(forResource: "genshin2", withExtension: "mlmodelc")
            guard let urlPath = urlPath else {
                print("cannot find file genshin2.mlmodelc")
                return
            }
            
            let config = MLModelConfiguration()
            let modelResp = await withCheckedContinuation { continuation in
                MLModel.load(contentsOf: urlPath, configuration: config) { result in
                    continuation.resume(returning: result)
                }
            }

            do {
                // Result.get() returns the loaded model or rethrows the loading error.
                let model = try modelResp.get()
                let parsedModel = try VNCoreMLModel(for: model)
                DispatchQueue.main.async {
                    self.parsedModel = parsedModel
                }
            } catch {
                print("failed to load model: \(error)")
            }
        }
    }
    
}

struct ContentView: View {
    
    enum SheetType: Identifiable {
        case photo
        case confirm
        var id: SheetType { self }
    }
    
    @State var showSheet: SheetType? = .none
    
    @ObservedObject var viewModel: MyVNModel = MyVNModel.shared

    var body: some View {
        VStack {
            Button {
                showSheet = .photo
            } label: {
                Text("Choose Photo")
            }
        }
        .sheet(item: $showSheet) { sheetType in
            switch sheetType {
            case .photo:
                PhotoLibrary(handlePickedImage: { images in
                    
                    guard let images = images else {
                        print("no images is selected")
                        return
                    }
                    
                    Task {
                        // Collect the observations for every selected image.
                        var observationList: [[VNObservation]] = []
                        for image in images {

                            guard let cgImage = image.cgImage else {
                                throw EvaluationError.cgImageRetrievalFailure
                            }
                            
                            let result = try await viewModel.applyModelToCgImage(image: cgImage)
                            print("model applied: \(result)")
                            
                            observationList.append(result)
                        }
                        
                        // Hand an immutable copy to the main queue to update the UI state.
                        let finalList = observationList
                        DispatchQueue.main.async {
                            viewModel.images = images
                            viewModel.observationList = finalList
                            self.showSheet = .confirm
                        }
                    }
                    
                }, selectionLimit: 1)
            case .confirm:
                if let images = viewModel.images, let observationList = viewModel.observationList {
                    VNObservationConfirmer(imageList: images, observations: observationList, onSubmit: { _,_  in
                        
                    })
                } else {
                    Text("No Images \(viewModel.images?.count ?? 0) \(viewModel.observationList?.count ?? 0)")
                }
                
            }
            
        }
        .padding()
    }
}

struct ContentView_Previews: PreviewProvider {
    static var previews: some View {
        ContentView()
    }
}
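When drawing the results, note that an object detector's observations are typically `VNRecognizedObjectObservation` values whose `boundingBox` is normalized to [0, 1] with the origin at the lower-left corner of the image, while UIKit and SwiftUI draw from the top-left. How the demo's `VNObservationConfirmer` handles this is not shown, so the conversion arithmetic is sketched here in TypeScript with an illustrative function name:

```typescript
// Normalized rectangle as reported by Vision: values in [0, 1], origin at the lower-left.
interface NormalizedRect { x: number; y: number; width: number; height: number }

// Convert to a pixel rectangle with a top-left origin, for an image of the given pixel size.
function visionRectToTopLeftPixels(r: NormalizedRect, imageWidth: number, imageHeight: number) {
    return {
        x: r.x * imageWidth,
        y: (1 - r.y - r.height) * imageHeight, // flip the y axis
        width: r.width * imageWidth,
        height: r.height * imageHeight,
    };
}

// A box covering the bottom-left quarter of a 1024x800 image.
console.log(visionRectToTopLeftPixels({ x: 0, y: 0, width: 0.5, height: 0.5 }, 1024, 800));
// { x: 0, y: 400, width: 512, height: 400 }
```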

Result

[Image]

Editor: Wu Xiaoyan (武晓燕); Source: ELab Team