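Building an iOS Application

In this tutorial, we will explore how to build an iOS application that incorporates ONNX Runtime's on-device training solution. On-device training refers to training a machine learning model directly on an edge device, without relying on cloud services or external servers.

We will build a simple speaker identification app that learns to identify a speaker's voice. We will look at how to train a model on the device, export the trained model, and use the trained model to perform inference.

Here is what the application looks like: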

application demo, with buttons for voice, train, and infer.

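Introduction

We will guide you through building an iOS application that can train a simple audio classification model using on-device training techniques. The tutorial showcases transfer learning, in which knowledge gained from training a model on one task is used to improve the model's performance on a different but related task. Instead of starting the learning process from scratch, transfer learning lets us transfer the knowledge or features learned by a pretrained model to a new task.

In this tutorial, we will use a wav2vec model that has been trained on large-scale celebrity speech data such as VoxCeleb1. We will use the pretrained model to extract features from the audio data and train a binary classifier to identify the speaker. The initial layers of the model serve as a feature extractor, capturing the important features of the audio data. Only the last layer of the model is trained to perform the classification task.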
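To make the idea of training only the last layer concrete, here is a minimal pure-Python sketch. It is not the wav2vec pipeline used in this tutorial: `frozen_features` is a hypothetical stand-in for the pretrained feature extractor, the toy data is invented for illustration, and the classifier is a plain logistic-regression head. The point is only that the feature extractor is never updated while the head is.

```python
import math

def frozen_features(x):
    # Hypothetical stand-in for a pretrained feature extractor; it is never updated.
    return [x, x * x]

def predict(weights, bias, feats):
    # Logistic-regression head: the only trainable part of this toy model.
    z = sum(w * f for w, f in zip(weights, feats)) + bias
    return 1.0 / (1.0 + math.exp(-z))

def mean_loss(weights, bias, data):
    # Binary cross-entropy averaged over the dataset (probabilities are clamped
    # so that log() never receives zero).
    total = 0.0
    for x, y in data:
        p = predict(weights, bias, frozen_features(x))
        p = min(max(p, 1e-12), 1.0 - 1e-12)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(data)

def train_head(data, epochs=500, lr=0.1):
    weights, bias = [0.0, 0.0], 0.0
    for _ in range(epochs):
        for x, y in data:
            feats = frozen_features(x)
            err = predict(weights, bias, feats) - y  # gradient of the loss w.r.t. z
            weights = [w - lr * err * f for w, f in zip(weights, feats)]
            bias -= lr * err
    return weights, bias

# Toy data: label 1 when |x| is large, 0 otherwise.
data = [(-2.0, 1), (-1.5, 1), (-0.3, 0), (0.0, 0), (0.4, 0), (1.6, 1), (2.1, 1)]
initial_loss = mean_loss([0.0, 0.0], 0.0, data)
weights, bias = train_head(data)
final_loss = mean_loss(weights, bias, data)
```

In the tutorial itself the same split is expressed by listing the classifier parameters as trainable and freezing everything else.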

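In this tutorial, we will:

  • Use iOS audio APIs to capture audio data for training
  • Train a model on the device
  • Export the trained model
  • Use the exported model to perform inference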

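Prerequisites

To follow this tutorial, you should have a basic understanding of machine learning and iOS development. You should also have the tools used in the steps below installed on your machine: a Python environment for generating the training artifacts, Xcode, and CocoaPods.

Note: The complete iOS application is also available in the onnxruntime-training-examples GitHub repository. You can clone the repository and follow along with the tutorial.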

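Generating the training artifacts

  1. Export the model to ONNX.

    We will start with a pretrained model from HuggingFace and export it to ONNX. The wav2vec model has been pretrained on VoxCeleb1, which has more than 1000 categories. For our task, we only need to classify audio into 2 classes, so we replace the last layer of the model to output 2 classes. We will use the transformers library to load the model and export it to ONNX.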

     from transformers import Wav2Vec2ForSequenceClassification, AutoConfig
     import torch
    
     # load config from the pretrained model
     config = AutoConfig.from_pretrained("superb/wav2vec2-base-superb-sid")
     model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-sid")
    
     # modify last layer to output 2 classes
     model.classifier = torch.nn.Linear(256, 2)
    
     # export model to ONNX
     # the dummy input holds 160000 samples, i.e. 10 seconds of audio at 16 kHz
     dummy_input = torch.randn(1, 160000, requires_grad=True)
     torch.onnx.export(model, dummy_input, "wav2vec.onnx", input_names=["input"], output_names=["output"],
                       dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
    
  2. Define the trainable and non-trainable parameters.

     import onnx
    
     # load the onnx model
     onnx_model = onnx.load("wav2vec.onnx")
    
     # Define the parameters that require gradients to be computed (trainable parameters) and
     # those that don't (frozen/non-trainable parameters)
     requires_grad = ["classifier.weight", "classifier.bias"]
     frozen_params = [
         param.name
         for param in onnx_model.graph.initializer
         if param.name not in requires_grad
     ]
    
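  3. Generate the training artifacts.

    In this tutorial, we will use the CrossEntropyLoss loss and the AdamW optimizer. More details on artifact generation can be found in the ONNX Runtime on-device training documentation.

    Since the model also outputs logits and hidden states, we will use onnxblock to define a custom loss function that extracts the logits from the model output and passes them to the CrossEntropyLoss function.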

     import onnxruntime.training.onnxblock as onnxblock
     from onnxruntime.training import artifacts
    
     # define the loss function
     class CustomCELoss(onnxblock.Block):
         def __init__(self):
             super().__init__()
             self.celoss = onnxblock.loss.CrossEntropyLoss()
    
    
         def build(self, logits, *args):
             return self.celoss(logits)
    
    
     # Generate the training artifacts
     artifacts.generate_artifacts(
         onnx_model,
         requires_grad=requires_grad,
         frozen_params=frozen_params,
         loss=CustomCELoss(),
         optimizer=artifacts.OptimType.AdamW,
         artifact_directory="artifacts",
     )
    
    

    That's all! The training artifacts have been generated in the artifacts directory. They are ready to be deployed to the iOS device for training.
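Before copying the artifacts into Xcode, it can help to confirm that the directory is complete. The sketch below is a hypothetical helper, not part of ONNX Runtime; the file names are an assumption taken from the resources that the Swift Trainer class loads later in this tutorial (training_model.onnx, eval_model.onnx, optimizer_model.onnx, and checkpoint).

```python
from pathlib import Path

# File names assumed from the resources the Swift Trainer class loads later on.
EXPECTED_ARTIFACTS = (
    "training_model.onnx",
    "eval_model.onnx",
    "optimizer_model.onnx",
    "checkpoint",
)

def missing_artifacts(directory):
    """Return the expected artifact names that are absent from `directory`."""
    root = Path(directory)
    return [name for name in EXPECTED_ARTIFACTS if not (root / name).exists()]
```

Calling `missing_artifacts("artifacts")` returns an empty list when all four files are present.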

Building the iOS application

Xcode setup

Open Xcode and create a new project. Select iOS as the platform and App as the template. Click Next.

Xcode Setup New Project

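Enter the project name. Here, we name the project "MyVoice", but you can name it anything you like. Make sure to select SwiftUI as the interface and Swift as the language. Then click Next.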

Xcode Setup Project Name

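Select the location where you want to save the project and click Create.

Now we need to add the onnxruntime-training-objc pod to the project. We will use CocoaPods to add the dependency. If you don't have CocoaPods installed, see the CocoaPods installation instructions.

After installing CocoaPods, navigate to the project directory and run the following command to create a Podfile: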

pod init

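This creates a Podfile in the project directory. Open the Podfile and add the following line after the use_frameworks! line:

pod 'onnxruntime-training-objc', '~> 1.16.0'

Save the Podfile and run the following command to install the dependencies: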

pod install

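This creates a MyVoice.xcworkspace file in the project directory. Open the xcworkspace file in Xcode; the project opens with the CocoaPods dependencies available.

Now, right-click the "MyVoice" group in the project navigator and click "New Group" to create a new group named artifacts in the project. Drag and drop the artifacts generated in the previous section into the artifacts group. Make sure to select the Create folder references and Copy items if needed options. This adds the artifacts to the project.

Next, right-click the "MyVoice" group and click "New Group" to create a new group named recordings. This group will contain the audio recordings used for training. You can generate the recordings by running the recording_gen.py script at the root of the project. Alternatively, you can use recordings of any speaker other than the speaker whose voice you plan to train on. Make sure the recordings are mono, 10 seconds long, in .wav format, with a 16 kHz sample rate. Also, make sure to name the recordings other_0.wav, other_1.wav, and so on, and add them to the recordings group.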
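If you bring your own recordings, a quick check of the format requirements above can save a failed training run. The following is a minimal sketch using Python's standard wave module; `check_recording` is a hypothetical helper and is not part of recording_gen.py. Note that the wave module only reads integer PCM WAV files, so this assumes the recordings are stored that way.

```python
import wave

def check_recording(path, sample_rate=16000, seconds=10):
    """Return a list of problems found in a WAV file; an empty list means it matches."""
    problems = []
    with wave.open(path, "rb") as wav:
        if wav.getnchannels() != 1:
            problems.append(f"expected mono, found {wav.getnchannels()} channels")
        if wav.getframerate() != sample_rate:
            problems.append(f"expected {sample_rate} Hz, found {wav.getframerate()} Hz")
        expected_frames = sample_rate * seconds
        if wav.getnframes() != expected_frames:
            problems.append(f"expected {expected_frames} frames, found {wav.getnframes()}")
    return problems
```

A 10-second recording at 16 kHz contains 160000 frames, which matches the length of the dummy input used when the model was exported.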

The project structure should look like this:

Xcode Project Structure

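Application overview

The application will consist of two main UI views: TrainView and InferView. TrainView is used to train the model on the device, and InferView is used to perform inference with the trained model. In addition, there is ContentView, the home view of the application, which contains buttons to navigate to TrainView and InferView.

We will also create an AudioRecorder class to handle recording audio through the microphone. It records 10 seconds of audio and outputs the audio data as a Data object, which can be used for both training and inference.

We will have a Trainer class that handles training and exporting the model.

Finally, we will create a VoiceIdentifier class that handles inference with the trained model.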

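Training the model

First, we will create a Trainer class that handles training and exporting the model. It loads the training artifacts, trains the model on the given audio, and exports the trained model using the ONNX Runtime on-device training APIs. Detailed documentation for the API is available in the ONNX Runtime documentation.

The Trainer class will have the following public methods:

  • init() - Initializes the training session and loads the training artifacts.
  • train(_ trainingData: [Data], progressCallback: @escaping (Double) -> Void) - Trains the model on the given user audio data. It takes an array of Data objects, where each Data object holds one recording of the user, and uses them together with some prerecorded audio data to train the model. The callback reports training progress.
  • exportModelForInference() - Exports the trained model for inference.

  1. Load the training artifacts and initialize the training session

    To train a model, we first need to load the artifacts and create an ORTEnv, an ORTTrainingSession, and an ORTCheckpoint. These objects are used to train the model. We will create them in the init method of the Trainer class.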

     import AVFoundation  // needed by getDataFromWavFile below
     import Foundation
     import onnxruntime_training_objc
    
     class Trainer {
         private let ortEnv: ORTEnv
         private let trainingSession: ORTTrainingSession
         private let checkpoint: ORTCheckpoint
            
         enum TrainerError: Error {
             case Error(_ message: String)
         }
            
         init() throws {
             ortEnv = try ORTEnv(loggingLevel: ORTLoggingLevel.warning)
                
             // get path for artifacts
             guard let trainingModelPath = Bundle.main.path(forResource: "training_model", ofType: "onnx") else {
                 throw TrainerError.Error("Failed to find training model file.")
             }
                
             guard let evalModelPath = Bundle.main.path(forResource: "eval_model",ofType: "onnx") else {
                 throw TrainerError.Error("Failed to find eval model file.")
             }
                
             guard let optimizerPath = Bundle.main.path(forResource: "optimizer_model", ofType: "onnx") else {
                 throw TrainerError.Error("Failed to find optimizer model file.")
             }
                
             guard let checkpointPath = Bundle.main.path(forResource: "checkpoint", ofType: nil) else {
                 throw TrainerError.Error("Failed to find checkpoint file.")
             }
                
             checkpoint = try ORTCheckpoint(path: checkpointPath)
             trainingSession = try ORTTrainingSession(env: ortEnv, sessionOptions: ORTSessionOptions(), checkpoint: checkpoint, trainModelPath: trainingModelPath, evalModelPath: evalModelPath, optimizerModelPath: optimizerPath)
         }
     }
    
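  2. Train the model

    a. Before training the model, we first need to extract the data from the wav files that we created in the earlier section. Here is a simple function that extracts the data from a wav file.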

    private func getDataFromWavFile(fileName: String) throws -> (AVAudioBuffer, Data) {
        guard let fileUrl = Bundle.main.url(forResource: fileName, withExtension:"wav") else {
            throw TrainerError.Error("Failed to find wav file: \(fileName).")
        }
            
        let audioFile = try AVAudioFile(forReading: fileUrl)
        let format = audioFile.processingFormat
        let totalFrames = AVAudioFrameCount(audioFile.length)
    
        guard let buffer = AVAudioPCMBuffer(pcmFormat: format, frameCapacity: totalFrames) else {
            throw TrainerError.Error("Failed to create audio buffer.")
        }
            
        try audioFile.read(into: buffer)
            
        guard let floatChannelData = buffer.floatChannelData else {
            throw TrainerError.Error("Failed to get float channel data.")
        }
            
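        // bytesNoCopy does not copy the samples, so the returned buffer must stay alive
        // for as long as the Data object is in use.
        let data = Data(
            bytesNoCopy: floatChannelData[0],
            count: Int(buffer.frameLength) * MemoryLayout<Float>.size,
            deallocator: .none
        )
        return (buffer, data)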
    }
    

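    b. The TrainingSession.trainStep function is responsible for training the model. It takes the input data and the labels and returns the loss. Inputs are passed to ONNX Runtime as ORTValue objects, so we need to convert the input audio Data objects and the labels to ORTValue.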

    private func getORTValue(dataList: [Data]) throws -> ORTValue {
        let tensorData = NSMutableData()
        dataList.forEach {data in tensorData.append(data)}
        let inputShape: [NSNumber] = [dataList.count as NSNumber, dataList[0].count / MemoryLayout<Float>.stride as NSNumber]
            
        return try ORTValue(
            tensorData: tensorData, elementType: ORTTensorElementDataType.float, shape: inputShape
        )
    }
        
    private func getORTValue(labels: [Int64]) throws -> ORTValue {
        let tensorData = NSMutableData(bytes: labels, length: labels.count * MemoryLayout<Int64>.stride)
        let inputShape: [NSNumber] = [labels.count as NSNumber]
            
        return try ORTValue (
            tensorData: tensorData, elementType: ORTTensorElementDataType.int64, shape: inputShape
        )
    }
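The shape arithmetic in the first getORTValue function can be sketched in Python as follows. This is only an illustration of how a list of raw float32 buffers becomes one [batch, samples] tensor; `batch_buffers` is a hypothetical name, and the length checks are an addition that the Swift code above does not perform.

```python
from array import array

FLOAT_SIZE = array("f").itemsize  # size in bytes of one 32-bit float

def batch_buffers(buffers):
    """Concatenate equal-length float32 byte buffers and return (bytes, shape)."""
    if not buffers:
        raise ValueError("at least one buffer is required")
    first_len = len(buffers[0])
    if any(len(b) != first_len for b in buffers):
        raise ValueError("all buffers must have the same length")
    if first_len % FLOAT_SIZE != 0:
        raise ValueError("buffer length is not a whole number of floats")
    # Shape is [batch size, number of samples per recording].
    return b"".join(buffers), [len(buffers), first_len // FLOAT_SIZE]
```

Because the shape is derived from the first buffer only, every recording in a batch needs the same number of samples, which is why the recordings are all 10 seconds long.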
    

    c. We are now ready to write the trainStep function, which takes a batch of input data and labels and performs one training step on the given batch.

    func trainStep(inputData: [Data], labels: [Int64]) throws  {
        let inputs = [try getORTValue(dataList: inputData), try getORTValue(labels: labels)]
        try trainingSession.trainStep(withInputValues: inputs)
            
        // update the model params
        try trainingSession.optimizerStep()
            
        // reset the gradients
        try trainingSession.lazyResetGrad()
    }
    

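    d. Finally, we have everything we need to write the training loop. Here, kNumOtherRecordings is the number of recordings in the recordings directory that we created earlier, and kNumEpochs is the number of epochs to train the model on the given data. kUserIndex and kOtherIndex are the labels for the user and the other recordings, respectively.

    We also have a progressCallback that is called after each training step. We will use this callback to update the progress bar in the UI.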

    private let kNumOtherRecordings: Int = 20
    private let kNumEpochs: Int = 3
        
    let kUserIndex: Int64 = 1
    let kOtherIndex: Int64 = 0
    
    func train(_ trainingData: [Data], progressCallback: @escaping (Double) -> Void) throws {
        let numRecordings = trainingData.count
        var otherRecordings = Array(0..<kNumOtherRecordings)
        for e in 0..<kNumEpochs {
            print("Epoch: \(e)")
            otherRecordings.shuffle()
            let otherData = otherRecordings.prefix(numRecordings)
                
            for i in 0..<numRecordings {
                let (buffer, wavFileData) = try getDataFromWavFile(fileName: "other_\(otherData[i])")
                try trainStep(inputData: [trainingData[i], wavFileData], labels: [kUserIndex, kOtherIndex])
                print("finished training on recording \(i)")
                    
                let progress = Double((e * numRecordings) + i + 1) / Double(kNumEpochs * numRecordings)
                progressCallback(progress)
            }
        }
            
    }
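The pairing and progress logic of the loop above can be sketched in Python without any audio or ONNX Runtime calls. `training_schedule` is a hypothetical helper written for illustration: in every epoch it shuffles the indices of the other recordings, pairs each user recording with one of them, and computes the same progress fraction that the Swift code passes to progressCallback.

```python
import random

def training_schedule(num_user, num_other, num_epochs, rng):
    """Yield (epoch, user_index, other_index, progress) for each training step."""
    if num_user > num_other:
        raise ValueError("need at least as many other recordings as user recordings")
    others = list(range(num_other))
    total_steps = num_epochs * num_user
    for epoch in range(num_epochs):
        rng.shuffle(others)  # pick a fresh random subset every epoch
        for i in range(num_user):
            progress = (epoch * num_user + i + 1) / total_steps
            yield epoch, i, others[i], progress

# 5 user recordings, 20 other recordings, 3 epochs, as in the tutorial's constants.
steps = list(training_schedule(5, 20, 3, random.Random(0)))
```

Each step corresponds to one trainStep call on a batch of two recordings: one of the user and one of another speaker.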
    
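  3. Export the trained model

    We can use the exportModelForInference method of the ORTTrainingSession class to export the trained model. The method takes the path where the model should be exported and the output names of the model.

    Here, we export the model to the Library directory of the application. The exported model will be used for inference.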

    func exportModelForInference() throws {
        guard let libraryDirectory = FileManager.default.urls(for: .libraryDirectory, in: .userDomainMask).first else {
            throw TrainerError.Error("Failed to find library directory ")
        }
            
        let modelPath = libraryDirectory.appendingPathComponent("inference_model.onnx").path
        try trainingSession.exportModelForInference(withOutputPath: modelPath, graphOutputNames: ["output"])
    }
    

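The complete implementation of the Trainer class is available in the onnxruntime-training-examples repository.

Inference with the trained model

The VoiceIdentifier class handles inference with the trained model. It loads the trained model and performs inference on the given audio data. The class has an evaluate(inputData: Data) -> Result<(Bool, Float), Error> method that takes the audio data and returns the inference result. The result is a (Bool, Float) tuple, where the first element indicates whether the audio was identified as the user and the second element is the confidence score of the prediction.

First, we load the trained model using an ORTSession object.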

class VoiceIdentifier {
    
    private let ortEnv : ORTEnv
    private let ortSession: ORTSession
    private let kThresholdProbability: Float = 0.80
    
    enum VoiceIdentifierError: Error {
        case Error(_ message: String)
    }
    
    init() throws {
        ortEnv = try ORTEnv(loggingLevel: ORTLoggingLevel.warning)

        guard let libraryDirectory = FileManager.default.urls(for: .libraryDirectory, in: .userDomainMask).first else {
            throw VoiceIdentifierError.Error("Failed to find library directory ")
        }
        let modelPath = libraryDirectory.appendingPathComponent("inference_model.onnx").path

        if !FileManager.default.fileExists(atPath: modelPath) {
            throw VoiceIdentifierError.Error("Failed to find inference model file.")
        }
        ortSession = try ORTSession(env: ortEnv, modelPath: modelPath, sessionOptions: nil)
    }
}

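Next, we write the evaluate method. First, it takes the audio data and converts it to an ORTValue. Then it performs inference with the model. Finally, it extracts the logits from the output and applies softmax to obtain probabilities.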

    private func isUser(logits: [Float]) -> Float {
        // apply softMax
        let maxInput = logits.max() ?? 0.0
        let expValues = logits.map { exp($0 - maxInput) } // Calculate e^(x - maxInput) for each element
        let expSum = expValues.reduce(0, +) // Sum of all e^(x - maxInput) values
        
        return expValues.map { $0 / expSum }[1] // Calculate the softmax probabilities
    }
    
    func evaluate(inputData: Data) -> Result<(Bool, Float), Error> {
        
        return Result<(Bool, Float), Error> { () -> (Bool, Float) in
            
            // convert input data to ORTValue
            let inputShape: [NSNumber] = [1, inputData.count / MemoryLayout<Float>.stride as NSNumber]
            
            let input = try ORTValue(
                tensorData: NSMutableData(data: inputData),
                elementType: ORTTensorElementDataType.float,
                shape: inputShape)
            
            let outputs = try ortSession.run(
                withInputs: ["input": input],
                outputNames: ["output"],
                runOptions: nil)
            
            guard let output = outputs["output"] else {
                throw VoiceIdentifierError.Error("Failed to get model output from inference.")
            }
            
            let outputData = try output.tensorData() as Data
            let probUser = outputData.withUnsafeBytes { (buffer: UnsafeRawBufferPointer) -> Float in
                let floatBuffer = buffer.bindMemory(to: Float.self)
                let logits = Array(UnsafeBufferPointer(start: floatBuffer.baseAddress, count: outputData.count/MemoryLayout<Float>.stride))
                return isUser(logits: logits)
            }
            
            return (probUser >= kThresholdProbability, probUser)
        }
    }
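The softmax and threshold step in isUser and evaluate can be sketched in Python as shown below. The function names are illustrative; the 0.80 threshold and the user index of 1 mirror kThresholdProbability and kUserIndex from the Swift code. Subtracting the largest logit before exponentiating keeps the computation stable for large values.

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtract the maximum before exponentiating."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def is_user(logits, threshold=0.80, user_index=1):
    """Return (decision, probability) for the user class."""
    prob = softmax(logits)[user_index]
    return prob >= threshold, prob
```

With logits `[0.0, 2.0]` the user probability is roughly 0.88, so the recording is accepted; with `[1000.0, 1001.0]` it is roughly 0.73, so it is rejected, and no overflow occurs despite the large inputs.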

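The complete implementation of the VoiceIdentifier class is available in the onnxruntime-training-examples repository.

Recording audio

We will use the AudioRecorder class to record audio through the microphone. It records 10 seconds of audio and outputs the audio data as a Data object, which can be used for both training and inference. We use the AVFoundation framework to access the microphone and record the audio. There is one public method, record(callback: @escaping RecordingDoneCallback), which records the audio and calls the callback with the audio data when recording is done.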

import AVFoundation
import Foundation

private let kSampleRate: Int = 16000
private let kRecordingDuration: TimeInterval = 10

class AudioRecorder {
    typealias RecordResult = Result<Data, Error>
    typealias RecordingDoneCallback = (RecordResult) -> Void
    
    enum AudioRecorderError: Error {
        case Error(message: String)
    }
    
    func record(callback: @escaping RecordingDoneCallback) {
        let session = AVAudioSession.sharedInstance()
        session.requestRecordPermission { allowed in
            do {
                guard allowed else {
                    throw AudioRecorderError.Error(message: "Recording permission denied.")
                }
                
                try session.setCategory(.record)
                try session.setActive(true)
                
                let tempDir = FileManager.default.temporaryDirectory
                
                let recordingUrl = tempDir.appendingPathComponent("recording.wav")
                
                let formatSettings: [String: Any] = [
                    AVFormatIDKey: kAudioFormatLinearPCM,
                    AVSampleRateKey: kSampleRate,
                    AVNumberOfChannelsKey: 1,
                    AVLinearPCMBitDepthKey: 16,
                    AVLinearPCMIsBigEndianKey: false,
                    AVLinearPCMIsFloatKey: false,
                    AVEncoderAudioQualityKey: AVAudioQuality.high.rawValue,
                ]
                
                let recorder = try AVAudioRecorder(url: recordingUrl, settings: formatSettings)
                self.recorder = recorder
                
                let delegate = RecorderDelegate(callback: callback)
                recorder.delegate = delegate
                self.recorderDelegate = delegate
                
                guard recorder.record(forDuration: kRecordingDuration) else {
                    throw AudioRecorderError.Error(message: "Failed to record.")
                }
                
                // control should resume in recorder.delegate.audioRecorderDidFinishRecording()
            } catch {
                callback(.failure(error))
            }
        }
    }
    
    private var recorderDelegate: RecorderDelegate?
    private var recorder: AVAudioRecorder?
    
    private class RecorderDelegate: NSObject, AVAudioRecorderDelegate {
        private let callback: RecordingDoneCallback
        
        init(callback: @escaping RecordingDoneCallback) {
            self.callback = callback
        }
        
        func audioRecorderDidFinishRecording(
            _ recorder: AVAudioRecorder,
            successfully flag: Bool
        ) {
            let recordResult = RecordResult { () -> Data in
                guard flag else {
                    throw AudioRecorderError.Error(message: "Recording was unsuccessful.")
                }
                
                let recordingUrl = recorder.url
                let recordingFile = try AVAudioFile(forReading: recordingUrl)
                
                guard
                    let format = AVAudioFormat(
                        commonFormat: .pcmFormatFloat32,
                        sampleRate: recordingFile.fileFormat.sampleRate,
                        channels: 1,
                        interleaved: false)
                else {
                    throw AudioRecorderError.Error(message: "Failed to create audio format.")
                }
                
                guard
                    let recordingBuffer = AVAudioPCMBuffer(
                        pcmFormat: format,
                        frameCapacity: AVAudioFrameCount(recordingFile.length))
                else {
                    throw AudioRecorderError.Error(message: "Failed to create audio buffer.")
                }
                
                try recordingFile.read(into: recordingBuffer)
                
                guard let recordingFloatChannelData = recordingBuffer.floatChannelData else {
                    throw AudioRecorderError.Error(message: "Failed to get float channel data.")
                }
                
                return Data(bytes: recordingFloatChannelData[0], count: Int(recordingBuffer.frameLength) * MemoryLayout<Float>.size)
               
            }
            
            callback(recordResult)
        }
        
        func audioRecorderEncodeErrorDidOccur(
            _ recorder: AVAudioRecorder,
            error: Error?
        ) {
            if let error = error {
                callback(.failure(error))
            } else {
                callback(.failure(AudioRecorderError.Error(message: "Encoding was unsuccessful.")))
            }
        }
    }
}
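The recorder above stores 16-bit integer PCM, while the delegate reads the file back into a 32-bit float buffer before handing the bytes to the model. AVAudioFile performs that conversion itself. The Python sketch below only illustrates what such a conversion looks like; dividing by 32768 is an assumed scaling convention, not something this tutorial states about AVFoundation.

```python
import struct

def pcm16_to_float32_bytes(pcm_bytes):
    """Convert little-endian 16-bit PCM bytes to little-endian float32 bytes in [-1, 1)."""
    if len(pcm_bytes) % 2 != 0:
        raise ValueError("PCM data must contain whole 16-bit samples")
    count = len(pcm_bytes) // 2
    samples = struct.unpack(f"<{count}h", pcm_bytes)
    floats = [s / 32768.0 for s in samples]  # assumed scaling convention
    return struct.pack(f"<{count}f", *floats)
```

The output has four bytes per sample, which is why the Swift code multiplies the frame count by MemoryLayout<Float>.size when building the Data object.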

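Train view

TrainView is used to train the model on the user's voice. First, it prompts the user to record kNumRecordings recordings of their voice. Then it trains the model on the user's voice and on some prerecorded recordings of other speakers' voices. Finally, it exports the trained model for inference.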

import SwiftUI

struct TrainView: View {
    
    enum ViewState {
        case recordingTrainingData, trainingInProgress, trainingComplete
    }
    
    private static let sentences = [
        "In the embrace of nature's beauty, I find peace and tranquility. The gentle rustling of leaves soothes my soul, and the soft sunlight kisses my skin. As I breathe in the fresh air, I am reminded of the interconnectedness of all living things, and I feel a sense of oneness with the world around me.",
        "Under the starlit sky, I gaze in wonder at the vastness of the universe. Each twinkle represents a story yet untold, a dream yet to be realized. With every new dawn, I am filled with hope and excitement for the opportunities that lie ahead. I embrace each day as a chance to grow, to learn, and to create beautiful memories.",
        "A warm hug from a loved one is a precious gift that warms my heart. In that tender embrace, I feel a sense of belonging and security. Laughter and tears shared with dear friends create a bond that withstands the test of time. These connections enrich my life and remind me of the power of human relationships.",
        "Life's journey is like a beautiful melody, with each note representing a unique experience. As I take each step, I harmonize with the rhythm of existence. Challenges may come my way, but I face them with resilience and determination, knowing they are opportunities for growth and self-discovery.",
        "With every page turned in a book, I open the door to new worlds and ideas. The written words carry the wisdom of countless souls, and I am humbled by the knowledge they offer. In stories, I find a mirror to my own experiences and a beacon of hope for a better tomorrow.",
        "Life's trials may bend me, but they will not break me. Through adversity, I discover the strength within my heart. Each obstacle is a chance to learn, to evolve, and to emerge as a better version of myself. I am grateful for every lesson, for they shape me into the person I am meant to be.",
        "The sky above is an ever-changing canvas of colors and clouds. In its vastness, I realize how small I am in the grand scheme of things, and yet, I know my actions can ripple through the universe. As I walk this Earth, I seek to leave behind a positive impact and a legacy of love and compassion.",
        "In the stillness of meditation, I connect with the depth of my soul. The external noise fades away, and I hear the whispers of my inner wisdom. With each breath, I release tension and embrace serenity. Meditation is my sanctuary, a place where I can find clarity and renewed energy.",
        "Kindness is a chain reaction that spreads like wildfire. A simple act of compassion can brighten someone's day and inspire them to pay it forward. Together, we can create a wave of goodness that knows no boundaries, reaching even the farthest corners of the world.",
        "As the sun rises on a new day, I am filled with gratitude for the gift of life. Every moment is a chance to make a difference, to love deeply, and to embrace joy. I welcome the adventures that await me and eagerly embrace the mysteries yet to be uncovered."
    ]

    
    private let kNumRecordings = 5
    private let audioRecorder = AudioRecorder()
    private let trainer = try! Trainer()
    
    @State private var trainingData: [Data] = []
    
    @State private var viewState: ViewState = .recordingTrainingData
    @State private var readyToRecord: Bool = true
    @State private var trainingProgress: Double = 0.0
    
    private func recordVoice() {
        audioRecorder.record { recordResult in
           switch recordResult {
           case .success(let recordingData):
               trainingData.append(recordingData)
               print("Successfully completed Recording")
           case .failure(let error):
               print("Error: \(error)")
            }
            
            readyToRecord = true
            
            if trainingData.count == kNumRecordings  {
                viewState = .trainingInProgress
                trainAndExportModel()
            }
        }
    }
    
    private func updateProgressBar(progress: Double) {
        DispatchQueue.main.async {
            trainingProgress = progress
        }
    }
    
    private func trainAndExportModel() {
        Task {
            do {
                try trainer.train(trainingData, progressCallback: updateProgressBar)
                try trainer.exportModelForInference()
                   
                DispatchQueue.main.async {
                    viewState = .trainingComplete
                    print("Training is complete")
                }
            } catch {
                DispatchQueue.main.async {
                    viewState = .trainingComplete
                    print("Training Failed: \(error)")
                }
            }
        }
    }
    
    
    var body: some View {
        VStack {
           
            switch viewState {
            case .recordingTrainingData:
                Text("\(trainingData.count + 1) of \(kNumRecordings)")
                    .font(.caption)
                    .foregroundColor(.secondary)
                    .padding()
                
                ProgressView(value: Double(trainingData.count),
                             total: Double(kNumRecordings))
                .progressViewStyle(LinearProgressViewStyle(tint: .purple))
                .frame(height: 10)
                .cornerRadius(5)
                
                Spacer()
                
                Text(TrainView.sentences[trainingData.count % TrainView.sentences.count])
                    .font(.body)
                    .padding()
                    .multilineTextAlignment(.center)
                    .fontDesign(.monospaced)
                
                Spacer()
                
                ZStack(alignment: .center) {
                    Image(systemName: "mic.fill")
                        .resizable()
                        .aspectRatio(contentMode: .fit)
                        .frame(width: 100, height: 100)
                        .foregroundColor( readyToRecord ? .gray: .red)
                        .transition(.scale)
                        .animation(.easeIn, value: 1)
                }
                
                Spacer()
                
                Button(action: {
                    readyToRecord = false
                    recordVoice()
                }) {
                    Text(readyToRecord ? "Record" : "Recording ...")
                        .font(.title)
                        .padding()
                        .background(readyToRecord ? .green : .gray)
                        .foregroundColor(.white)
                        .cornerRadius(10)
                }.disabled(!readyToRecord)
                    
            case .trainingInProgress:
                VStack {
                    Spacer()
                    ProgressView(value: trainingProgress,
                                 total: 1.0,
                                 label: {Text("Training")},
                                 currentValueLabel: {Text(String(format: "%.0f%%", trainingProgress * 100))})
                    .padding()
                    Spacer()
                }
                    
            case .trainingComplete:
                Spacer()
                Text("Training successfully finished!")
                    .font(.title)
                    .padding()
                    .multilineTextAlignment(.center)
                    .fontDesign(.monospaced)
                
                Spacer()
                NavigationLink(destination: InferView()) {
                    Text("Infer")
                        .font(.title)
                        .padding()
                        .background(.purple)
                        .foregroundColor(.white)
                        .cornerRadius(10)
                }
                .padding(.leading, 20)
            }
            
            Spacer()
        }
        .padding()
        .navigationTitle("Train")
    }
}

struct TrainView_Previews: PreviewProvider {
    static var previews: some View {
        TrainView()
    }
}

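The complete implementation of TrainView is available in the onnxruntime-training-examples repository.

Infer view

Finally, we create InferView, which is used to perform inference with the trained model. It prompts the user to record their voice, performs inference with the trained model, and then displays the result.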

import SwiftUI

struct InferView: View {
    
    enum InferResult {
        case user, other, notSet
    }
    
    private let audioRecorder = AudioRecorder()
    
    @State private var voiceIdentifier: VoiceIdentifier? = nil
    @State private var readyToRecord: Bool = true
    
    @State private var inferResult: InferResult = InferResult.notSet
    @State private var probUser: Float = 0.0
    
    @State private var showAlert = false
    @State private var alertMessage = ""

    private func recordVoice() {
        audioRecorder.record { recordResult in
            let recognizeResult = recordResult.flatMap { recordingData in
                return voiceIdentifier!.evaluate(inputData: recordingData)
            }
            endRecord(recognizeResult)
        }
    }
    
    private func endRecord(_ result: Result<(Bool, Float), Error>) {
        DispatchQueue.main.async {
            switch result {
            case .success(let (isMatch, confidence)):
                print("Your Voice with confidence: \(isMatch),  \(confidence)")
                inferResult = isMatch ? .user : .other
                probUser = confidence
            case .failure(let error):
                print("Error: \(error)")
            }
            readyToRecord = true
        }
    }
    
    var body: some View {
        VStack {
            Spacer()
            
            ZStack(alignment: .center) {
                Image(systemName: "mic.fill")
                    .resizable()
                    .aspectRatio(contentMode: .fit)
                    .frame(width: 100, height: 100)
                    .foregroundColor( readyToRecord ? .gray: .red)
                    .transition(.scale)
                    .animation(.easeInOut, value: 1)
            }
            
            Spacer()
            
            Button(action: {
                readyToRecord = false
                recordVoice()
            }) {
                Text(readyToRecord ? "Record" : "Recording ...")
                    .font(.title)
                    .padding()
                    .background(readyToRecord ? .green : .gray)
                    .foregroundColor(.white)
                    .cornerRadius(10)
                
            }.disabled(voiceIdentifier == nil || !readyToRecord)
                .opacity(voiceIdentifier == nil ? 0.5: 1.0)
            
            if  inferResult != .notSet {
                Spacer()
                ZStack (alignment: .center) {
                    Image(systemName: inferResult == .user ? "person.crop.circle.fill.badge.checkmark": "person.crop.circle.fill.badge.xmark")
                        .resizable()
                        .aspectRatio(contentMode: .fit)
                        .frame(width: 100, height: 100)
                        .foregroundColor(inferResult == .user ? .green : .red)
                        .animation(.easeInOut, value: 2)
                    
                }
                
                Text("Probability of User : \(String(format: "%.2f", probUser*100.0))%")
                    .multilineTextAlignment(.center)
                    .fontDesign(.monospaced)
            }
            
            Spacer()
        }
        .padding()
        .navigationTitle("Infer")
        .onAppear {
            do {
                voiceIdentifier = try  VoiceIdentifier()
                
            } catch {
                alertMessage = "Error initializing inference session, make sure that training is completed: \(error)"
                showAlert = true
            }
            
        }
        .alert(isPresented: $showAlert) {
            Alert(title: Text("Error"), message: Text(alertMessage), dismissButton: .default(Text("OK")))
        }
    }
}

struct InferView_Previews: PreviewProvider {
    static var previews: some View {
        InferView()
    }
}

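The complete implementation of InferView is available in the onnxruntime-training-examples repository.

Content view

Finally, we update the default ContentView so that it contains buttons to navigate to TrainView and InferView.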

import SwiftUI

struct ContentView: View {
    var body: some View {
        NavigationView {
            VStack {
                
                Text("My Voice")
                    .font(.largeTitle)
                    .padding(.top, 50)
                
                Spacer()
                
                ZStack(alignment: .center) {
                    Image(systemName: "waveform.circle.fill")
                        .resizable()
                        .aspectRatio(contentMode: .fit)
                        .frame(width: 100, height: 100)
                        .foregroundColor(.purple)
                }
                
                Spacer()
                
                HStack {
                    NavigationLink(destination: TrainView()) {
                        Text("Train")
                            .font(.title)
                            .padding()
                            .background(Color.purple)
                            .foregroundColor(.white)
                            .cornerRadius(10)
                    }
                    .padding(.trailing, 20)
                    
                    NavigationLink(destination: InferView()) {
                        Text("Infer")
                            .font(.title)
                            .padding()
                            .background(.purple)
                            .foregroundColor(.white)
                            .cornerRadius(10)
                    }
                    .padding(.leading, 20)
                }
                
                Spacer()
            }
            .padding()
        }
    }
}

struct ContentView_Previews: PreviewProvider {
    static var previews: some View {
        ContentView()
    }
}

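The complete implementation of ContentView is available in the onnxruntime-training-examples repository.

Running the iOS application

Now we are ready to run the application. You can run it on the simulator or on a device. For more information on running an application on the simulator and on a device, see the Xcode documentation.

a. When you run the application, you should see the following screen: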

My Voice application with Train and Infer buttons

b. Next, tap the Train button to navigate to TrainView. TrainView will prompt you to record your voice. You will need to record your voice kNumRecordings times.

My Voice application with words to record

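c. Once all the recordings are complete, the application trains the model on the given data. A progress bar indicates the progress of the training.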

Loading bar while the app is training

d. Once the training is complete, you will see the following screen:

The app informs you training finished successfully!

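e. Now, tap the Infer button to navigate to InferView. InferView will prompt you to record your voice. Once the recording is complete, it performs inference with the trained model and displays the result.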

My Voice application allows you to record and infer whether it's you or not.

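That's it! Hopefully, it identified your voice correctly.

Conclusion

Congratulations! You have successfully built an iOS application that can train a simple audio classification model using on-device training techniques. You can now use the application to train a model on your own voice and perform inference with the trained model. The application is also available on GitHub in the onnxruntime-training-examples repository.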
