Building an iOS Application

In this tutorial, we will explore how to build an iOS application that incorporates ONNX Runtime's on-device training solution. On-device training refers to training machine learning models directly on edge devices, without relying on cloud services or external servers.

We will build a simple speaker identification app that learns to recognize a speaker's voice. You will see how to train a model on the device, export the trained model, and run inference with the trained model.

Here is what the application will look like:

application demo, with buttons for voice, train, and infer.

Introduction

We will walk you through building an iOS application that can train a simple audio classification model using on-device training techniques. The tutorial showcases transfer learning, a technique where knowledge gained from training a model on one task is leveraged to improve the model's performance on a different but related task. Instead of learning from scratch, transfer learning allows us to transfer the knowledge or features learned by a pre-trained model to a new task.

In this tutorial, we will use the wav2vec model, which has been pre-trained on a large corpus of celebrity speech data such as VoxCeleb1. We will use the pre-trained model to extract features from the audio data and train a binary classifier to identify the speaker. The initial layers of the model serve as a feature extractor, capturing the important characteristics of the audio data. Only the last layer of the model is trained to perform the classification task.

In this tutorial, we will:

  • Capture audio for training with the iOS audio APIs
  • Train the model on the device
  • Export the trained model
  • Run inference with the exported model


Prerequisites

To follow this tutorial, you should have a basic understanding of machine learning and iOS development. You will also need a Python environment (for generating the training artifacts with the packages used below), Xcode, and CocoaPods installed on your machine.

Note: The complete iOS application is also published in the onnxruntime-training-examples GitHub repository. You can clone the repository and follow along with the tutorial.

Generating the training artifacts

  1. Export the model to ONNX.

    We will start with a pre-trained model from HuggingFace and export it to ONNX. The wav2vec model has been pre-trained on VoxCeleb1, which has more than 1,000 classes. For our task, we only need to classify audio into 2 classes, so we modify the last layer of the model to output 2 classes. We will use the transformers library to load the model and export it to ONNX.

     from transformers import Wav2Vec2ForSequenceClassification, AutoConfig
     import torch
    
     # load config from the pretrained model
     config = AutoConfig.from_pretrained("superb/wav2vec2-base-superb-sid")
     model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-sid")
    
     # modify last layer to output 2 classes
     model.classifier = torch.nn.Linear(256, 2)
    
     # export model to ONNX
     dummy_input = torch.randn(1, 160000, requires_grad=True)
     torch.onnx.export(model, dummy_input, "wav2vec.onnx", input_names=["input"], output_names=["output"],
                       dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
    
  2. Define the trainable and non-trainable parameters.

     import onnx
    
     # load the onnx model
     onnx_model = onnx.load("wav2vec.onnx")
    
     # Define the parameters that require gradients to be computed (trainable parameters) and
     # those that don't (frozen/non-trainable parameters)
     requires_grad = ["classifier.weight", "classifier.bias"]
     frozen_params = [
         param.name
         for param in onnx_model.graph.initializer
         if param.name not in requires_grad
     ]
    
  3. Generate the training artifacts.

    In this tutorial, we will use the CrossEntropyLoss loss function and the AdamW optimizer. For more details on artifact generation, please refer to the documentation here.

    Since the model outputs both logits and hidden states, we use onnxblock to define a custom loss function that extracts the logits from the model outputs and passes them to the CrossEntropyLoss function.

     import onnxruntime.training.onnxblock as onnxblock
     from onnxruntime.training import artifacts
    
     # define the loss function
     class CustomCELoss(onnxblock.Block):
         def __init__(self):
             super().__init__()
             self.celoss = onnxblock.loss.CrossEntropyLoss()
    
    
         def build(self, logits, *args):
             return self.celoss(logits)
    
    
     # Generate the training artifacts
     artifacts.generate_artifacts(
         onnx_model,
         requires_grad=requires_grad,
         frozen_params=frozen_params,
         loss=CustomCELoss(),
         optimizer=artifacts.OptimType.AdamW,
         artifact_directory="artifacts",
     )
    
    

    That's it! The training artifacts have been generated in the artifacts directory. They are ready to be deployed to an iOS device for training.
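    As an optional sanity check (not part of the generation steps above), you can confirm from Python that the directory contains the four files the Swift code later loads by name: training_model.onnx, eval_model.onnx, optimizer_model.onnx, and checkpoint. A minimal sketch:

     import os

     # Check that the expected artifact files were produced. These are the exact
     # resource names the iOS Trainer class looks up in the app bundle.
     expected = {"training_model.onnx", "eval_model.onnx", "optimizer_model.onnx", "checkpoint"}
     produced = set(os.listdir("artifacts"))
     assert expected.issubset(produced), f"Missing artifacts: {expected - produced}"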

Building the iOS application

Xcode Setup

Open Xcode and create a new project. Select iOS as the platform and App as the template. Click Next.

Xcode Setup New Project

Enter the project name. Here, we name the project "MyVoice", but you can name it anything you like. Make sure to select SwiftUI as the interface and Swift as the language. Then, click Next.

Xcode Setup Project Name

Choose the location where you want to save the project and click Create.

Now, we need to add the onnxruntime-training-objc pod to the project. We will use CocoaPods to add the dependency. If you don't have CocoaPods installed, you can find installation instructions here.

Once CocoaPods is installed, navigate to the project directory and run the following command to create a Podfile:

pod init

This creates a Podfile in the project directory. Open the Podfile and add the following line after the use_frameworks! line:

pod 'onnxruntime-training-objc', '~> 1.16.0'

Save the Podfile and run the following command to install the dependency:

pod install

This creates a MyVoice.xcworkspace file in the project directory. Open the xcworkspace file in Xcode, which opens the project with the CocoaPods dependencies available.

Now, right-click on the "MyVoice" group in the project navigator and click "New Group" to create a new group called artifacts. Drag and drop the artifacts generated in the previous step into the artifacts group. Make sure to select the Create folder references and Copy items if needed options. This adds the artifacts to the project.

Next, right-click on the "MyVoice" group and click "New Group" to create a group called recordings. This group will contain the audio recordings used for training. You can generate the recordings by running the recording_gen.py script at the root of the project. Alternatively, you can use voice recordings of any speakers other than the one you plan to train on. Make sure the recordings are single channel, 10 seconds long, in .wav format, and sampled at 16 kHz. Also, make sure to name them other_0.wav, other_1.wav, and so on, and add them to the recordings group.
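If you are preparing your own "other" clips instead of running recording_gen.py, a small helper along the following lines can convert arbitrary speech files into the expected format. This is only a sketch, assuming the librosa and soundfile Python packages are installed; prepare_other_recording and its paths are hypothetical names, not part of the sample project.

import librosa
import soundfile as sf

# Hypothetical helper: convert a speech clip into the format the app expects:
# mono, 16 kHz, exactly 10 seconds, 16-bit PCM WAV, named other_<index>.wav.
def prepare_other_recording(src_path: str, index: int, out_dir: str = "recordings") -> None:
    target_sr = 16000
    duration_s = 10.0
    # Load as mono at 16 kHz, truncated to at most 10 seconds ...
    audio, _ = librosa.load(src_path, sr=target_sr, mono=True, duration=duration_s)
    # ... and pad with silence if the clip is shorter than 10 seconds.
    audio = librosa.util.fix_length(audio, size=int(target_sr * duration_s))
    sf.write(f"{out_dir}/other_{index}.wav", audio, target_sr, subtype="PCM_16")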

The project structure should look like this:

Xcode Project Structure

Application Overview

The application will have two main UI views: TrainView and InferView. TrainView is used to train the model on the device, and InferView is used to run inference with the trained model. Additionally, there is a ContentView, the home view of the application, which contains buttons for navigating to TrainView and InferView.

We will also create an AudioRecorder class that handles recording audio through the microphone. It records 10 seconds of audio and outputs the audio data as a Data object that can be used for training and inference.

We will have a Trainer class that handles training and exporting the model.

Lastly, we will create a VoiceIdentifier class that handles inference with the trained model.

Training the model

First, we will create the Trainer class, which handles training and exporting the model. It loads the training artifacts, trains the model on the given audio data, and exports the trained model using the ONNX Runtime on-device training API. Detailed documentation for the API can be found here.

The Trainer class will have the following public methods:

  • init() - Initializes the training session and loads the training artifacts.
  • train(_ trainingData: [Data]) - Trains the model on the given user audio data. It takes an array of Data objects, where each Data object represents the user's audio, and uses them along with some pre-recorded audio to train the model.
  • exportModelForInference() - Exports the trained model for inference.
  1. Load the training artifacts and initialize the training session

    To train the model, we first need to load the artifacts and create the ORTEnv, ORTTrainingSession, and ORTCheckpoint objects. They will be used to train the model. We create them in the init method of the Trainer class.

     import AVFoundation  // needed for reading the wav recordings used in training
     import Foundation
     import onnxruntime_training_objc
    
     class Trainer {
         private let ortEnv: ORTEnv
         private let trainingSession: ORTTrainingSession
         private let checkpoint: ORTCheckpoint
            
         enum TrainerError: Error {
             case Error(_ message: String)
         }
            
         init() throws {
             ortEnv = try ORTEnv(loggingLevel: ORTLoggingLevel.warning)
                
             // get path for artifacts
             guard let trainingModelPath = Bundle.main.path(forResource: "training_model", ofType: "onnx") else {
                 throw TrainerError.Error("Failed to find training model file.")
             }
                
             guard let evalModelPath = Bundle.main.path(forResource: "eval_model",ofType: "onnx") else {
                 throw TrainerError.Error("Failed to find eval model file.")
             }
                
             guard let optimizerPath = Bundle.main.path(forResource: "optimizer_model", ofType: "onnx") else {
                 throw TrainerError.Error("Failed to find optimizer model file.")
             }
                
             guard let checkpointPath = Bundle.main.path(forResource: "checkpoint", ofType: nil) else {
                 throw TrainerError.Error("Failed to find checkpoint file.")
             }
                
             checkpoint = try ORTCheckpoint(path: checkpointPath)
             trainingSession = try ORTTrainingSession(env: ortEnv, sessionOptions: ORTSessionOptions(), checkpoint: checkpoint, trainModelPath: trainingModelPath, evalModelPath: evalModelPath, optimizerModelPath: optimizerPath)
         }
     }
    
  2. Train the model

    a. Before we can train the model, we first need to extract the data from the wav files we created earlier. Here is a simple function that extracts the data from a wav file.

    private func getDataFromWavFile(fileName: String) throws -> (AVAudioBuffer, Data) {
        guard let fileUrl = Bundle.main.url(forResource: fileName, withExtension:"wav") else {
            throw TrainerError.Error("Failed to find wav file: \(fileName).")
        }
            
        let audioFile = try AVAudioFile(forReading: fileUrl)
        let format = audioFile.processingFormat
        let totalFrames = AVAudioFrameCount(audioFile.length)
    
        guard let buffer = AVAudioPCMBuffer(pcmFormat: format, frameCapacity: totalFrames) else {
            throw TrainerError.Error("Failed to create audio buffer.")
        }
            
        try audioFile.read(into: buffer)
            
        guard let floatChannelData = buffer.floatChannelData else {
            throw TrainerError.Error("Failed to get float channel data.")
        }
            
        let data = Data(
            bytesNoCopy: floatChannelData[0],
            count: Int(buffer.frameLength) * MemoryLayout<Float>.size,
            deallocator: .none
        )
        return (buffer, data)
    }
    

    b. The TrainingSession.trainStep function is responsible for training the model. It takes the input data and the labels, and returns the loss. The inputs are passed to ONNX Runtime as ORTValue objects, so we need to convert the input audio Data objects and the labels into ORTValue.

    private func getORTValue(dataList: [Data]) throws -> ORTValue {
        let tensorData = NSMutableData()
        dataList.forEach {data in tensorData.append(data)}
        let inputShape: [NSNumber] = [dataList.count as NSNumber, dataList[0].count / MemoryLayout<Float>.stride as NSNumber]
            
        return try ORTValue(
            tensorData: tensorData, elementType: ORTTensorElementDataType.float, shape: inputShape
        )
    }
        
    private func getORTValue(labels: [Int64]) throws -> ORTValue {
        let tensorData = NSMutableData(bytes: labels, length: labels.count * MemoryLayout<Int64>.stride)
        let inputShape: [NSNumber] = [labels.count as NSNumber]
            
        return try ORTValue (
            tensorData: tensorData, elementType: ORTTensorElementDataType.int64, shape: inputShape
        )
    }
    

    c. We are now ready to write the trainStep function, which takes a batch of input data and labels and performs one training step on the given batch.

    func trainStep(inputData: [Data], labels: [Int64]) throws  {
        let inputs = [try getORTValue(dataList: inputData), try getORTValue(labels: labels)]
        try trainingSession.trainStep(withInputValues: inputs)
            
        // update the model params
        try trainingSession.optimizerStep()
            
        // reset the gradients
        try trainingSession.lazyResetGrad()
    }
    

    d. Finally, we have everything we need to write the training loop. kNumOtherRecordings represents how many recordings are in the recordings directory we created earlier. kNumEpochs is the number of epochs to train the model on the given data. kUserIndex and kOtherIndex are the labels for the user and the other recordings, respectively.

    We also have a progressCallback that is called after each training step. We will use this callback to update the progress bar in the UI.

    private let kNumOtherRecordings: Int = 20
    private let kNumEpochs: Int = 3
        
    let kUserIndex: Int64 = 1
    let kOtherIndex: Int64 = 0
    
    func train(_ trainingData: [Data], progressCallback: @escaping (Double) -> Void) throws {
        let numRecordings = trainingData.count
        var otherRecordings = Array(0..<kNumOtherRecordings)
        for e in 0..<kNumEpochs {
            print("Epoch: \(e)")
            otherRecordings.shuffle()
            let otherData = otherRecordings.prefix(numRecordings)
                
            for i in 0..<numRecordings {
                let (buffer, wavFileData) = try getDataFromWavFile(fileName: "other_\(otherData[i])")
                try trainStep(inputData: [trainingData[i], wavFileData], labels: [kUserIndex, kOtherIndex])
                print("finished training on recording \(i)")
                    
                let progress = Double((e * numRecordings) + i + 1) / Double(kNumEpochs * numRecordings)
                progressCallback(progress)
            }
        }
            
    }
    
  3. Export the trained model

    We can export the trained model with the exportModelForInference method of the ORTTrainingSession class. The method takes the path where the model should be exported and the output names of the model.

    Here, we export the model to the application's Library directory. The exported model will be used for inference.

    func exportModelForInference() throws {
        guard let libraryDirectory = FileManager.default.urls(for: .libraryDirectory, in: .userDomainMask).first else {
            throw TrainerError.Error("Failed to find library directory ")
        }
            
        let modelPath = libraryDirectory.appendingPathComponent("inference_model.onnx").path
        try trainingSession.exportModelForInference(withOutputPath: modelPath, graphOutputNames: ["output"])
    }
    

You can find the complete implementation of the Trainer class here.

Inference with the trained model

The VoiceIdentifier class handles inference with the trained model. It loads the trained model and runs inference on the given audio data. The class has an evaluate(inputData: Data) -> Result<(Bool, Float), Error> method that takes the audio data and returns the inference result: a tuple of (Bool, Float), where the first element indicates whether the audio was recognized as the user, and the second element is the confidence score of the prediction.

First, we load the trained model with an ORTSession object.

class VoiceIdentifier {
    
    private let ortEnv : ORTEnv
    private let ortSession: ORTSession
    private let kThresholdProbability: Float = 0.80
    
    enum VoiceIdentifierError: Error {
        case Error(_ message: String)
    }
    
    init() throws {
        ortEnv = try ORTEnv(loggingLevel: ORTLoggingLevel.warning)

        guard let libraryDirectory = FileManager.default.urls(for: .libraryDirectory, in: .userDomainMask).first else {
            throw VoiceIdentifierError.Error("Failed to find library directory ")
        }
        let modelPath = libraryDirectory.appendingPathComponent("inference_model.onnx").path

        if !FileManager.default.fileExists(atPath: modelPath) {
            throw VoiceIdentifierError.Error("Failed to find inference model file.")
        }
        ortSession = try ORTSession(env: ortEnv, modelPath: modelPath, sessionOptions: nil)
    }
}

Next, we will write the evaluate method. It first converts the audio data into an ORTValue, then runs inference with the model, and finally extracts the logits from the output and applies softmax to obtain the probabilities.

    private func isUser(logits: [Float]) -> Float {
        // apply softMax
        let maxInput = logits.max() ?? 0.0
        let expValues = logits.map { exp($0 - maxInput) } // Calculate e^(x - maxInput) for each element
        let expSum = expValues.reduce(0, +) // Sum of all e^(x - maxInput) values
        
        return expValues.map { $0 / expSum }[1] // Calculate the softmax probabilities
    }
    
    func evaluate(inputData: Data) -> Result<(Bool, Float), Error> {
        
        return Result<(Bool, Float), Error> { () -> (Bool, Float) in
            
            // convert input data to ORTValue
            let inputShape: [NSNumber] = [1, inputData.count / MemoryLayout<Float>.stride as NSNumber]
            
            let input = try ORTValue(
                tensorData: NSMutableData(data: inputData),
                elementType: ORTTensorElementDataType.float,
                shape: inputShape)
            
            let outputs = try ortSession.run(
                withInputs: ["input": input],
                outputNames: ["output"],
                runOptions: nil)
            
            guard let output = outputs["output"] else {
                throw VoiceIdentifierError.Error("Failed to get model output from inference.")
            }
            
            let outputData = try output.tensorData() as Data
            let probUser = outputData.withUnsafeBytes { (buffer: UnsafeRawBufferPointer) -> Float in
                let floatBuffer = buffer.bindMemory(to: Float.self)
                let logits = Array(UnsafeBufferPointer(start: floatBuffer.baseAddress, count: outputData.count/MemoryLayout<Float>.stride))
                return isUser(logits: logits)
            }
            
            return (probUser >= kThresholdProbability, probUser)
        }
    }

You can find the complete implementation of the VoiceIdentifier class here.

Recording audio

We will use the AudioRecorder class to record audio through the microphone. It records 10 seconds of audio and outputs the audio data as a Data object that can be used for training and inference. We use the AVFoundation framework to access the microphone and record the audio. There is a single public method, record(callback: @escaping RecordingDoneCallback), which records the audio and invokes the callback with the audio data once recording is complete.

import AVFoundation
import Foundation

private let kSampleRate: Int = 16000
private let kRecordingDuration: TimeInterval = 10

class AudioRecorder {
    typealias RecordResult = Result<Data, Error>
    typealias RecordingDoneCallback = (RecordResult) -> Void
    
    enum AudioRecorderError: Error {
        case Error(message: String)
    }
    
    func record(callback: @escaping RecordingDoneCallback) {
        let session = AVAudioSession.sharedInstance()
        session.requestRecordPermission { allowed in
            do {
                guard allowed else {
                    throw AudioRecorderError.Error(message: "Recording permission denied.")
                }
                
                try session.setCategory(.record)
                try session.setActive(true)
                
                let tempDir = FileManager.default.temporaryDirectory
                
                let recordingUrl = tempDir.appendingPathComponent("recording.wav")
                
                let formatSettings: [String: Any] = [
                    AVFormatIDKey: kAudioFormatLinearPCM,
                    AVSampleRateKey: kSampleRate,
                    AVNumberOfChannelsKey: 1,
                    AVLinearPCMBitDepthKey: 16,
                    AVLinearPCMIsBigEndianKey: false,
                    AVLinearPCMIsFloatKey: false,
                    AVEncoderAudioQualityKey: AVAudioQuality.high.rawValue,
                ]
                
                let recorder = try AVAudioRecorder(url: recordingUrl, settings: formatSettings)
                self.recorder = recorder
                
                let delegate = RecorderDelegate(callback: callback)
                recorder.delegate = delegate
                self.recorderDelegate = delegate
                
                guard recorder.record(forDuration: kRecordingDuration) else {
                    throw AudioRecorderError.Error(message: "Failed to record.")
                }
                
                // control should resume in recorder.delegate.audioRecorderDidFinishRecording()
            } catch {
                callback(.failure(error))
            }
        }
    }
    
    private var recorderDelegate: RecorderDelegate?
    private var recorder: AVAudioRecorder?
    
    private class RecorderDelegate: NSObject, AVAudioRecorderDelegate {
        private let callback: RecordingDoneCallback
        
        init(callback: @escaping RecordingDoneCallback) {
            self.callback = callback
        }
        
        func audioRecorderDidFinishRecording(
            _ recorder: AVAudioRecorder,
            successfully flag: Bool
        ) {
            let recordResult = RecordResult { () -> Data in
                guard flag else {
                    throw AudioRecorderError.Error(message: "Recording was unsuccessful.")
                }
                
                let recordingUrl = recorder.url
                let recordingFile = try AVAudioFile(forReading: recordingUrl)
                
                guard
                    let format = AVAudioFormat(
                        commonFormat: .pcmFormatFloat32,
                        sampleRate: recordingFile.fileFormat.sampleRate,
                        channels: 1,
                        interleaved: false)
                else {
                    throw AudioRecorderError.Error(message: "Failed to create audio format.")
                }
                
                guard
                    let recordingBuffer = AVAudioPCMBuffer(
                        pcmFormat: format,
                        frameCapacity: AVAudioFrameCount(recordingFile.length))
                else {
                    throw AudioRecorderError.Error(message: "Failed to create audio buffer.")
                }
                
                try recordingFile.read(into: recordingBuffer)
                
                guard let recordingFloatChannelData = recordingBuffer.floatChannelData else {
                    throw AudioRecorderError.Error(message: "Failed to get float channel data.")
                }
                
                return Data(bytes: recordingFloatChannelData[0], count: Int(recordingBuffer.frameLength) * MemoryLayout<Float>.size)
               
            }
            
            callback(recordResult)
        }
        
        func audioRecorderEncodeErrorDidOccur(
            _ recorder: AVAudioRecorder,
            error: Error?
        ) {
            if let error = error {
                callback(.failure(error))
            } else {
                callback(.failure(AudioRecorderError.Error(message: "Encoding was unsuccessful.")))
            }
        }
    }
}

Train View

The TrainView is used to train the model on the user's voice. First, it prompts the user to record their voice kNumRecordings times. Then, it trains the model on the user's voice along with some pre-recorded voices of other speakers. Finally, it exports the trained model for inference.

import SwiftUI

struct TrainView: View {
    
    enum ViewState {
        case recordingTrainingData, trainingInProgress, trainingComplete
    }
    
    private static let sentences = [
        "In the embrace of nature's beauty, I find peace and tranquility. The gentle rustling of leaves soothes my soul, and the soft sunlight kisses my skin. As I breathe in the fresh air, I am reminded of the interconnectedness of all living things, and I feel a sense of oneness with the world around me.",
        "Under the starlit sky, I gaze in wonder at the vastness of the universe. Each twinkle represents a story yet untold, a dream yet to be realized. With every new dawn, I am filled with hope and excitement for the opportunities that lie ahead. I embrace each day as a chance to grow, to learn, and to create beautiful memories.",
        "A warm hug from a loved one is a precious gift that warms my heart. In that tender embrace, I feel a sense of belonging and security. Laughter and tears shared with dear friends create a bond that withstands the test of time. These connections enrich my life and remind me of the power of human relationships.",
        "Life's journey is like a beautiful melody, with each note representing a unique experience. As I take each step, I harmonize with the rhythm of existence. Challenges may come my way, but I face them with resilience and determination, knowing they are opportunities for growth and self-discovery.",
        "With every page turned in a book, I open the door to new worlds and ideas. The written words carry the wisdom of countless souls, and I am humbled by the knowledge they offer. In stories, I find a mirror to my own experiences and a beacon of hope for a better tomorrow.",
        "Life's trials may bend me, but they will not break me. Through adversity, I discover the strength within my heart. Each obstacle is a chance to learn, to evolve, and to emerge as a better version of myself. I am grateful for every lesson, for they shape me into the person I am meant to be.",
        "The sky above is an ever-changing canvas of colors and clouds. In its vastness, I realize how small I am in the grand scheme of things, and yet, I know my actions can ripple through the universe. As I walk this Earth, I seek to leave behind a positive impact and a legacy of love and compassion.",
        "In the stillness of meditation, I connect with the depth of my soul. The external noise fades away, and I hear the whispers of my inner wisdom. With each breath, I release tension and embrace serenity. Meditation is my sanctuary, a place where I can find clarity and renewed energy.",
        "Kindness is a chain reaction that spreads like wildfire. A simple act of compassion can brighten someone's day and inspire them to pay it forward. Together, we can create a wave of goodness that knows no boundaries, reaching even the farthest corners of the world.",
        "As the sun rises on a new day, I am filled with gratitude for the gift of life. Every moment is a chance to make a difference, to love deeply, and to embrace joy. I welcome the adventures that await me and eagerly embrace the mysteries yet to be uncovered."
    ]

    
    private let kNumRecordings = 5
    private let audioRecorder = AudioRecorder()
    private let trainer = try! Trainer()
    
    @State private var trainingData: [Data] = []
    
    @State private var viewState: ViewState = .recordingTrainingData
    @State private var readyToRecord: Bool = true
    @State private var trainingProgress: Double = 0.0
    
    private func recordVoice() {
        audioRecorder.record { recordResult in
           switch recordResult {
           case .success(let recordingData):
               trainingData.append(recordingData)
               print("Successfully completed Recording")
           case .failure(let error):
               print("Error: \(error)")
            }
            
            readyToRecord = true
            
            if trainingData.count == kNumRecordings  {
                viewState = .trainingInProgress
                trainAndExportModel()
            }
        }
    }
    
    private func updateProgressBar(progress: Double) {
        DispatchQueue.main.async {
            trainingProgress = progress
        }
    }
    
    private func trainAndExportModel() {
        Task {
            do {
                try trainer.train(trainingData, progressCallback: updateProgressBar)
                try trainer.exportModelForInference()
                   
                DispatchQueue.main.async {
                    viewState = .trainingComplete
                    print("Training is complete")
                }
            } catch {
                DispatchQueue.main.async {
                    viewState = .trainingComplete
                    print("Training Failed: \(error)")
                }
            }
        }
    }
    
    
    var body: some View {
        VStack {
           
            switch viewState {
            case .recordingTrainingData:
                Text("\(trainingData.count + 1) of \(kNumRecordings)")
                    .font(.caption)
                    .foregroundColor(.secondary)
                    .padding()
                
                ProgressView(value: Double(trainingData.count),
                             total: Double(kNumRecordings))
                .progressViewStyle(LinearProgressViewStyle(tint: .purple))
                .frame(height: 10)
                .cornerRadius(5)
                
                Spacer()
                
                Text(TrainView.sentences[trainingData.count % TrainView.sentences.count])
                    .font(.body)
                    .padding()
                    .multilineTextAlignment(.center)
                    .fontDesign(.monospaced)
                
                Spacer()
                
                ZStack(alignment: .center) {
                    Image(systemName: "mic.fill")
                        .resizable()
                        .aspectRatio(contentMode: .fit)
                        .frame(width: 100, height: 100)
                        .foregroundColor( readyToRecord ? .gray: .red)
                        .transition(.scale)
                        .animation(.easeIn, value: 1)
                }
                
                Spacer()
                
                Button(action: {
                    readyToRecord = false
                    recordVoice()
                }) {
                    Text(readyToRecord ? "Record" : "Recording ...")
                        .font(.title)
                        .padding()
                        .background(readyToRecord ? .green : .gray)
                        .foregroundColor(.white)
                        .cornerRadius(10)
                }.disabled(!readyToRecord)
                    
            case .trainingInProgress:
                VStack {
                    Spacer()
                    ProgressView(value: trainingProgress,
                                 total: 1.0,
                                 label: {Text("Training")},
                                 currentValueLabel: {Text(String(format: "%.0f%%", trainingProgress * 100))})
                    .padding()
                    Spacer()
                }
                    
            case .trainingComplete:
                Spacer()
                Text("Training successfully finished!")
                    .font(.title)
                    .padding()
                    .multilineTextAlignment(.center)
                    .fontDesign(.monospaced)
                
                Spacer()
                NavigationLink(destination: InferView()) {
                    Text("Infer")
                        .font(.title)
                        .padding()
                        .background(.purple)
                        .foregroundColor(.white)
                        .cornerRadius(10)
                }
                .padding(.leading, 20)
            }
            
            Spacer()
        }
        .padding()
        .navigationTitle("Train")
    }
}

struct TrainView_Previews: PreviewProvider {
    static var previews: some View {
        TrainView()
    }
}

You can find the complete implementation of TrainView here.

Infer View

Finally, we will create the InferView, which is used to run inference with the trained model. It prompts the user to record their voice and runs inference with the trained model. Then, it displays the result of the inference.

import SwiftUI

struct InferView: View {
    
    enum InferResult {
        case user, other, notSet
    }
    
    private let audioRecorder = AudioRecorder()
    
    @State private var voiceIdentifier: VoiceIdentifier? = nil
    @State private var readyToRecord: Bool = true
    
    @State private var inferResult: InferResult = InferResult.notSet
    @State private var probUser: Float = 0.0
    
    @State private var showAlert = false
    @State private var alertMessage = ""

    private func recordVoice() {
        audioRecorder.record { recordResult in
            let recognizeResult = recordResult.flatMap { recordingData in
                return voiceIdentifier!.evaluate(inputData: recordingData)
            }
            endRecord(recognizeResult)
        }
    }
    
    private func endRecord(_ result: Result<(Bool, Float), Error>) {
        DispatchQueue.main.async {
            switch result {
            case .success(let (isMatch, confidence)):
                print("Your Voice with confidence: \(isMatch),  \(confidence)")
                inferResult = isMatch ? .user : .other
                probUser = confidence
            case .failure(let error):
                print("Error: \(error)")
            }
            readyToRecord = true
        }
    }
    
    var body: some View {
        VStack {
            Spacer()
            
            ZStack(alignment: .center) {
                Image(systemName: "mic.fill")
                    .resizable()
                    .aspectRatio(contentMode: .fit)
                    .frame(width: 100, height: 100)
                    .foregroundColor( readyToRecord ? .gray: .red)
                    .transition(.scale)
                    .animation(.easeInOut, value: 1)
            }
            
            Spacer()
            
            Button(action: {
                readyToRecord = false
                recordVoice()
            }) {
                Text(readyToRecord ? "Record" : "Recording ...")
                    .font(.title)
                    .padding()
                    .background(readyToRecord ? .green : .gray)
                    .foregroundColor(.white)
                    .cornerRadius(10)
                
            }.disabled(voiceIdentifier == nil || !readyToRecord)
                .opacity(voiceIdentifier == nil ? 0.5: 1.0)
            
            if  inferResult != .notSet {
                Spacer()
                ZStack (alignment: .center) {
                    Image(systemName: inferResult == .user ? "person.crop.circle.fill.badge.checkmark": "person.crop.circle.fill.badge.xmark")
                        .resizable()
                        .aspectRatio(contentMode: .fit)
                        .frame(width: 100, height: 100)
                        .foregroundColor(inferResult == .user ? .green : .red)
                        .animation(.easeInOut, value: 2)
                    
                }
                
                Text("Probability of User : \(String(format: "%.2f", probUser*100.0))%")
                    .multilineTextAlignment(.center)
                    .fontDesign(.monospaced)
            }
            
            Spacer()
        }
        .padding()
        .navigationTitle("Infer")
        .onAppear {
            do {
                voiceIdentifier = try  VoiceIdentifier()
                
            } catch {
                alertMessage = "Error initializing inference session, make sure that training is completed: \(error)"
                showAlert = true
            }
            
        }
        .alert(isPresented: $showAlert) {
            Alert(title: Text("Error"), message: Text(alertMessage), dismissButton: .default(Text("OK")))
        }
    }
}

struct InferView_Previews: PreviewProvider {
    static var previews: some View {
        InferView()
    }
}

You can find the complete implementation of InferView here.

ContentView

Lastly, we update the default ContentView to include buttons for navigating to TrainView and InferView.

import SwiftUI

struct ContentView: View {
    var body: some View {
        NavigationView {
            VStack {
                
                Text("My Voice")
                    .font(.largeTitle)
                    .padding(.top, 50)
                
                Spacer()
                
                ZStack(alignment: .center) {
                    Image(systemName: "waveform.circle.fill")
                        .resizable()
                        .aspectRatio(contentMode: .fit)
                        .frame(width: 100, height: 100)
                        .foregroundColor(.purple)
                }
                
                Spacer()
                
                HStack {
                    NavigationLink(destination: TrainView()) {
                        Text("Train")
                            .font(.title)
                            .padding()
                            .background(Color.purple)
                            .foregroundColor(.white)
                            .cornerRadius(10)
                    }
                    .padding(.trailing, 20)
                    
                    NavigationLink(destination: InferView()) {
                        Text("Infer")
                            .font(.title)
                            .padding()
                            .background(.purple)
                            .foregroundColor(.white)
                            .cornerRadius(10)
                    }
                    .padding(.leading, 20)
                }
                
                Spacer()
            }
            .padding()
        }
    }
}

struct ContentView_Previews: PreviewProvider {
    static var previews: some View {
        ContentView()
    }
}

You can find the complete implementation of ContentView here.

Running the iOS application

Now, we are ready to run the application. You can run it on the simulator or on a device. You can find more information about running an app on the simulator or a device here.

a. When you run the application, you should see the following screen:

My Voice application with Train and Infer buttons

b. Next, click the Train button to navigate to the TrainView. The TrainView prompts you to record your voice. You will need to record your voice kNumRecordings times.

My Voice application with words to record

c. Once all the recordings are done, the app trains the model on the given data. You will see a progress bar indicating the progress of the training.

Loading bar while the app is training

d. Once training is complete, you will see the following screen:

The app informs you training finished successfully!

e. Now, click the Infer button to navigate to the InferView. The InferView prompts you to record your voice. Once the recording is done, it runs inference with the trained model and displays the result.

My Voice application allows you to record and infer whether it's you or not.

That's it! Hopefully, it identified your voice correctly.

Conclusion

Congratulations! You have successfully built an iOS application that can train a simple audio classification model using on-device training techniques. You can now use the app to train a model on your voice and run inference with the trained model. The application is also available on GitHub at onnxruntime-training-examples.
