On-Device Training: Building an iOS Application
In this tutorial, we will explore how to build an iOS application that incorporates ONNX Runtime's on-device training solution. On-device training refers to training machine learning models directly on edge devices, without relying on cloud services or external servers.
In this tutorial, we will build a simple speaker identification application that learns to recognize a speaker's voice. We will see how to train a model on the device, export the trained model, and run inference with it.
Here is what the application will look like:
Introduction
We will guide you through the process of building an iOS application that can train a simple audio classification model using on-device training techniques. The tutorial showcases the transfer learning technique, where knowledge gained from training a model on one task is leveraged to improve the model's performance on a different but related task. Instead of starting the learning process from scratch, transfer learning allows us to transfer the knowledge or features learned by a pre-trained model to a new task.
In this tutorial, we will leverage the wav2vec model, which has been trained on large amounts of celebrity speech data such as VoxCeleb1. We will use the pre-trained model to extract features from the audio data and train a binary classifier to identify the speaker. The initial layers of the model act as a feature extractor, capturing important features of the audio data. Only the last layer of the model is trained to perform the classification task.
In this tutorial, we will:
- Capture audio data for training with the iOS audio APIs
- Train the model on the device
- Export the trained model
- Run inference with the exported model
Prerequisites
To follow this tutorial, you should have a basic understanding of machine learning and iOS development. You should also have the following installed on your machine:
Note: The complete iOS application is also available in the onnxruntime-training-examples GitHub repository. You can clone the repository and follow along with the tutorial.
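The artifact-generation step below runs in Python and uses the torch, transformers, onnx, and ONNX Runtime training packages. A minimal, unpinned install command (an assumption on our part; this tutorial does not prescribe exact versions) would be:

pip install torch transformers onnx onnxruntime-training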
Generating the training artifacts
- Export the model to ONNX.

We will start with a pre-trained model from HuggingFace and export it to ONNX. The wav2vec model has been pre-trained on VoxCeleb1, which has more than 1,000 categories. For our task, we only need to classify the audio into 2 categories, so we change the last layer of the model to output 2 classes. We will use the transformers library to load the model and export it to ONNX (a quick sanity check of the exported model is sketched at the end of this section).

from transformers import Wav2Vec2ForSequenceClassification, AutoConfig
import torch

# load config from the pretrained model
config = AutoConfig.from_pretrained("superb/wav2vec2-base-superb-sid")
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-sid")

# modify last layer to output 2 classes
model.classifier = torch.nn.Linear(256, 2)

# export model to ONNX
dummy_input = torch.randn(1, 160000, requires_grad=True)
torch.onnx.export(model, dummy_input, "wav2vec.onnx",
                  input_names=["input"], output_names=["output"],
                  dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}})
- Define the trainable and non-trainable parameters.

import onnx

# load the onnx model
onnx_model = onnx.load("wav2vec.onnx")

# Define the parameters that require gradients to be computed (trainable parameters) and
# those that don't (frozen/non-trainable parameters)
requires_grad = ["classifier.weight", "classifier.bias"]
frozen_params = [
    param.name
    for param in onnx_model.graph.initializer
    if param.name not in requires_grad
]
- Generate the training artifacts.

In this tutorial, we will use the CrossEntropyLoss loss and the AdamW optimizer. More details about artifact generation can be found here. Since the model also outputs logits and hidden states, we will use onnxblock to define a custom loss function that extracts the logits from the model output and passes them to the CrossEntropyLoss function.

import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training import artifacts

# define the loss function
class CustomCELoss(onnxblock.Block):
    def __init__(self):
        super().__init__()
        self.celoss = onnxblock.loss.CrossEntropyLoss()

    def build(self, logits, *args):
        return self.celoss(logits)

# Generate the training artifacts
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=CustomCELoss(),
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="artifacts",
)
That's it! The training artifacts have been generated in the artifacts directory: the training model (training_model.onnx), the eval model (eval_model.onnx), the optimizer model (optimizer_model.onnx), and the checkpoint state, which the app's Trainer class will later load by these names. These artifacts are ready to be deployed to the iOS device for training.
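Before wiring the artifacts into the app, it can be useful to sanity-check the exported base model on the desktop. The sketch below is optional and is an assumption on our part (it requires the onnxruntime and numpy Python packages); it runs a random 10-second clip (160,000 samples at 16 kHz) through wav2vec.onnx and confirms that the modified classifier head now produces two logits per example:

import numpy as np
import onnxruntime as ort

# load the exported model and run a dummy clip through it
session = ort.InferenceSession("wav2vec.onnx")
dummy = np.random.randn(1, 160000).astype(np.float32)
logits = session.run(["output"], {"input": dummy})[0]
print(logits.shape)  # expected: (1, 2) - one logit per class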
Building the iOS application
Xcode setup
Open Xcode and create a new project. Select iOS as the platform and App as the template. Click "Next".
Enter the project name. Here, we call the project "MyVoice", but you can name it anything you like. Make sure to select SwiftUI as the interface and Swift as the language. Then click "Next".
Select the location where you want to save the project and click Create.
Now, we need to add the onnxruntime-training-objc pod to the project. We will use CocoaPods to add the dependency. If you don't have CocoaPods installed, you can find the installation instructions here.
Once CocoaPods is installed, navigate to the project directory and run the following command to create a Podfile:

pod init
This will create a Podfile in the project directory. Open the Podfile and add the following line after the use_frameworks! line:

pod 'onnxruntime-training-objc', '~> 1.16.0'
Save the Podfile and run the following command to install the dependency:

pod install
This will create a MyVoice.xcworkspace file in the project directory. Open the xcworkspace file in Xcode. This opens the project in Xcode with the CocoaPods dependencies available.
Now, right-click on the "MyVoice" group in the project navigator and click "New Group" to create a new group in the project called artifacts. Drag and drop the artifacts generated in the previous section into the artifacts group. Make sure to select the Create folder references and Copy items if needed options. This adds the artifacts to the project.
Next, right-click on the "MyVoice" group and click "New Group" to create a new group in the project called recordings. This group will contain the pre-recorded audio used for training. You can generate the recordings by running the recording_gen.py script found at the root of the project repository. Alternatively, you can use recordings of any speakers other than the one you plan to train on. Make sure the recordings are mono, 10 seconds long, in .wav format, and sampled at 16 kHz. Also, make sure to name the recordings other_0.wav, other_1.wav, and so on, and add them to the recordings group.
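If your source clips do not already meet these requirements, you can convert them on the desktop first. The snippet below is only a sketch of one way to do this and is not part of the original project; it assumes the librosa, soundfile, and numpy Python packages and a hypothetical input file name. It resamples a clip to 16 kHz mono, trims or zero-pads it to exactly 10 seconds, and writes it as a 16-bit PCM wav file:

import numpy as np
import librosa
import soundfile as sf

def convert_recording(src_path: str, dst_path: str, sr: int = 16000, seconds: int = 10) -> None:
    # load as mono and resample to the target rate
    audio, _ = librosa.load(src_path, sr=sr, mono=True)
    target_len = sr * seconds
    # trim or zero-pad to exactly 10 seconds
    if len(audio) >= target_len:
        audio = audio[:target_len]
    else:
        audio = np.pad(audio, (0, target_len - len(audio)))
    # write a mono, 16-bit PCM wav file
    sf.write(dst_path, audio, sr, subtype="PCM_16")

# "some_clip.mp3" is a placeholder for your own source recording
convert_recording("some_clip.mp3", "other_0.wav")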
The project structure should look like this:
Application overview
The application will have two main UI views: TrainView and InferView. TrainView is used to train the model on the device, and InferView is used to run inference with the trained model. Additionally, there is ContentView, the main view of the application, which contains buttons to navigate to TrainView and InferView.
We will also create an AudioRecorder class to handle recording audio through the microphone. It will record 10 seconds of audio and output the audio data as a Data object, which can be used for both training and inference.
We will have a Trainer class that handles training and exporting the model.
Lastly, we will create a VoiceIdentifier class that handles inference with the trained model.
Training the model
First, we will create a Trainer class that handles training and exporting the model. It loads the training artifacts, trains the model on the given audio data, and exports the trained model using the ONNX Runtime on-device training APIs. Detailed documentation of these APIs can be found here.
The Trainer class will have the following public methods:
- init() - initializes the training session and loads the training artifacts.
- train(_ trainingData: [Data]) - trains the model on the given user audio data. It takes an array of Data objects, where each Data object represents one recording of the user's voice, and uses them along with some pre-recorded audio data to train the model.
- exportModelForInference() - exports the trained model for inference.
- Load the training artifacts and initialize the training session.

To train the model, we first need to load the artifacts and create the ORTEnv, ORTTrainingSession, and ORTCheckpoint objects, which will be used to train the model. We create these objects in the init method of the Trainer class.

import Foundation
import AVFoundation // needed for AVAudioFile / AVAudioPCMBuffer used when reading the wav files below
import onnxruntime_training_objc

class Trainer {
    private let ortEnv: ORTEnv
    private let trainingSession: ORTTrainingSession
    private let checkpoint: ORTCheckpoint

    enum TrainerError: Error {
        case Error(_ message: String)
    }

    init() throws {
        ortEnv = try ORTEnv(loggingLevel: ORTLoggingLevel.warning)

        // get path for artifacts
        guard let trainingModelPath = Bundle.main.path(forResource: "training_model", ofType: "onnx") else {
            throw TrainerError.Error("Failed to find training model file.")
        }

        guard let evalModelPath = Bundle.main.path(forResource: "eval_model", ofType: "onnx") else {
            throw TrainerError.Error("Failed to find eval model file.")
        }

        guard let optimizerPath = Bundle.main.path(forResource: "optimizer_model", ofType: "onnx") else {
            throw TrainerError.Error("Failed to find optimizer model file.")
        }

        guard let checkpointPath = Bundle.main.path(forResource: "checkpoint", ofType: nil) else {
            throw TrainerError.Error("Failed to find checkpoint file.")
        }

        checkpoint = try ORTCheckpoint(path: checkpointPath)

        trainingSession = try ORTTrainingSession(
            env: ortEnv,
            sessionOptions: ORTSessionOptions(),
            checkpoint: checkpoint,
            trainModelPath: trainingModelPath,
            evalModelPath: evalModelPath,
            optimizerModelPath: optimizerPath)
    }
}
- Train the model.

a. Before training the model, we first need to extract the data from the wav files we created in the earlier section. Here is a simple function that extracts the data from a wav file.

private func getDataFromWavFile(fileName: String) throws -> (AVAudioBuffer, Data) {
    guard let fileUrl = Bundle.main.url(forResource: fileName, withExtension: "wav") else {
        throw TrainerError.Error("Failed to find wav file: \(fileName).")
    }

    let audioFile = try AVAudioFile(forReading: fileUrl)
    let format = audioFile.processingFormat
    let totalFrames = AVAudioFrameCount(audioFile.length)

    guard let buffer = AVAudioPCMBuffer(pcmFormat: format, frameCapacity: totalFrames) else {
        throw TrainerError.Error("Failed to create audio buffer.")
    }

    try audioFile.read(into: buffer)

    guard let floatChannelData = buffer.floatChannelData else {
        throw TrainerError.Error("Failed to get float channel data.")
    }

    // the Data does not copy the samples (bytesNoCopy), so the buffer is returned
    // alongside it to keep the underlying memory alive
    let data = Data(
        bytesNoCopy: floatChannelData[0],
        count: Int(buffer.frameLength) * MemoryLayout<Float>.size,
        deallocator: .none
    )
    return (buffer, data)
}
b. The TrainingSession.trainStep function is responsible for training the model. It takes the input data and the labels and returns the loss. The inputs are passed to ONNX Runtime as ORTValue objects, so we need to convert the input audio Data objects and the labels to ORTValue.

private func getORTValue(dataList: [Data]) throws -> ORTValue {
    let tensorData = NSMutableData()
    dataList.forEach { data in tensorData.append(data) }
    let inputShape: [NSNumber] = [dataList.count as NSNumber,
                                  dataList[0].count / MemoryLayout<Float>.stride as NSNumber]

    return try ORTValue(
        tensorData: tensorData,
        elementType: ORTTensorElementDataType.float,
        shape: inputShape
    )
}

private func getORTValue(labels: [Int64]) throws -> ORTValue {
    let tensorData = NSMutableData(bytes: labels,
                                   length: labels.count * MemoryLayout<Int64>.stride)
    let inputShape: [NSNumber] = [labels.count as NSNumber]

    return try ORTValue(
        tensorData: tensorData,
        elementType: ORTTensorElementDataType.int64,
        shape: inputShape
    )
}
c. We are now ready to write the trainStep function, which takes a batch of input data and labels and performs one training step on the given batch.

func trainStep(inputData: [Data], labels: [Int64]) throws {
    let inputs = [try getORTValue(dataList: inputData), try getORTValue(labels: labels)]
    try trainingSession.trainStep(withInputValues: inputs)

    // update the model params
    try trainingSession.optimizerStep()

    // reset the gradients
    try trainingSession.lazyResetGrad()
}
d. Finally, we have everything we need to write the training loop. Here, kNumOtherRecordings represents how many recordings we have in the recordings directory created earlier, and kNumEpochs represents the number of epochs we want to train the model on the given data. kUserIndex and kOtherIndex represent the labels for the user and the other recordings, respectively.

We also have a progressCallback that is called after each training step. We will use this callback to update the progress bar in the UI.

private let kNumOtherRecordings: Int = 20
private let kNumEpochs: Int = 3

let kUserIndex: Int64 = 1
let kOtherIndex: Int64 = 0

func train(_ trainingData: [Data], progressCallback: @escaping (Double) -> Void) throws {
    let numRecordings = trainingData.count
    var otherRecordings = Array(0..<kNumOtherRecordings)

    for e in 0..<kNumEpochs {
        print("Epoch: \(e)")
        otherRecordings.shuffle()
        let otherData = otherRecordings.prefix(numRecordings)

        for i in 0..<numRecordings {
            // keep `buffer` bound: `wavFileData` references its memory (created with bytesNoCopy)
            let (buffer, wavFileData) = try getDataFromWavFile(fileName: "other_\(otherData[i])")
            try trainStep(inputData: [trainingData[i], wavFileData], labels: [kUserIndex, kOtherIndex])
            print("finished training on recording \(i)")

            let progress = Double((e * numRecordings) + i + 1) / Double(kNumEpochs * numRecordings)
            progressCallback(progress)
        }
    }
}
- Export the trained model.

We can export the trained model with the exportModelForInference method of the ORTTrainingSession class. The method takes the path where the model should be exported and the output names of the model.

Here, we export the model to the application's Library directory. The exported model will be used for inference.

func exportModelForInference() throws {
    guard let libraryDirectory = FileManager.default.urls(for: .libraryDirectory, in: .userDomainMask).first else {
        throw TrainerError.Error("Failed to find library directory ")
    }

    let modelPath = libraryDirectory.appendingPathComponent("inference_model.onnx").path
    try trainingSession.exportModelForInference(withOutputPath: modelPath, graphOutputNames: ["output"])
}
You can find the complete implementation of the Trainer class here.
Inference with the trained model
The VoiceIdentifier class handles inference with the trained model. It loads the trained model and performs inference on the given audio data. The class has an evaluate(inputData: Data) -> Result<(Bool, Float), Error> method, which takes the audio data and returns the inference result. The result is a (Bool, Float) tuple, where the first element indicates whether the audio is recognized as the user and the second element is the confidence score of the prediction.
First, we load the trained model using an ORTSession object.
class VoiceIdentifier {
private let ortEnv : ORTEnv
private let ortSession: ORTSession
private let kThresholdProbability: Float = 0.80
enum VoiceIdentifierError: Error {
case Error(_ message: String)
}
init() throws {
ortEnv = try ORTEnv(loggingLevel: ORTLoggingLevel.warning)
guard let libraryDirectory = FileManager.default.urls(for: .libraryDirectory, in: .userDomainMask).first else {
throw VoiceIdentifierError.Error("Failed to find library directory ")
}
let modelPath = libraryDirectory.appendingPathComponent("inference_model.onnx").path
if !FileManager.default.fileExists(atPath: modelPath) {
throw VoiceIdentifierError.Error("Failed to find inference model file.")
}
ortSession = try ORTSession(env: ortEnv, modelPath: modelPath, sessionOptions: nil)
}
}
Next, we write the evaluate method. First, it takes the audio data and converts it to an ORTValue. It then performs inference with the model. Finally, it extracts the logits from the output and applies softmax to get the probabilities.
private func isUser(logits: [Float]) -> Float {
// apply softMax
let maxInput = logits.max() ?? 0.0
let expValues = logits.map { exp($0 - maxInput) } // Calculate e^(x - maxInput) for each element
let expSum = expValues.reduce(0, +) // Sum of all e^(x - maxInput) values
return expValues.map { $0 / expSum }[1] // Calculate the softmax probabilities
}
func evaluate(inputData: Data) -> Result<(Bool, Float), Error> {
return Result<(Bool, Float), Error> { () -> (Bool, Float) in
// convert input data to ORTValue
let inputShape: [NSNumber] = [1, inputData.count / MemoryLayout<Float>.stride as NSNumber]
let input = try ORTValue(
tensorData: NSMutableData(data: inputData),
elementType: ORTTensorElementDataType.float,
shape: inputShape)
let outputs = try ortSession.run(
withInputs: ["input": input],
outputNames: ["output"],
runOptions: nil)
guard let output = outputs["output"] else {
throw VoiceIdentifierError.Error("Failed to get model output from inference.")
}
let outputData = try output.tensorData() as Data
let probUser = outputData.withUnsafeBytes { (buffer: UnsafeRawBufferPointer) -> Float in
let floatBuffer = buffer.bindMemory(to: Float.self)
let logits = Array(UnsafeBufferPointer(start: floatBuffer.baseAddress, count: outputData.count/MemoryLayout<Float>.stride))
return isUser(logits: logits)
}
return (probUser >= kThresholdProbability, probUser)
}
}
You can find the complete implementation of the VoiceIdentifier class here.
Recording audio
We will use the AudioRecorder class to record audio through the microphone. It records 10 seconds of audio and outputs the audio data as a Data object that can be used for both training and inference. We use the AVFoundation framework to access the microphone and record the audio. There is a single public method, record(callback: @escaping RecordingDoneCallback), which records the audio and, once recording is finished, invokes the callback with the audio data. (Note: for the recording permission prompt to appear, the app's Info.plist needs a microphone usage description entry, NSMicrophoneUsageDescription; that step is not shown in this tutorial.)
import AVFoundation
import Foundation
private let kSampleRate: Int = 16000
private let kRecordingDuration: TimeInterval = 10
class AudioRecorder {
typealias RecordResult = Result<Data, Error>
typealias RecordingDoneCallback = (RecordResult) -> Void
enum AudioRecorderError: Error {
case Error(message: String)
}
func record(callback: @escaping RecordingDoneCallback) {
let session = AVAudioSession.sharedInstance()
session.requestRecordPermission { allowed in
do {
guard allowed else {
throw AudioRecorderError.Error(message: "Recording permission denied.")
}
try session.setCategory(.record)
try session.setActive(true)
let tempDir = FileManager.default.temporaryDirectory
let recordingUrl = tempDir.appendingPathComponent("recording.wav")
let formatSettings: [String: Any] = [
AVFormatIDKey: kAudioFormatLinearPCM,
AVSampleRateKey: kSampleRate,
AVNumberOfChannelsKey: 1,
AVLinearPCMBitDepthKey: 16,
AVLinearPCMIsBigEndianKey: false,
AVLinearPCMIsFloatKey: false,
AVEncoderAudioQualityKey: AVAudioQuality.high.rawValue,
]
let recorder = try AVAudioRecorder(url: recordingUrl, settings: formatSettings)
self.recorder = recorder
let delegate = RecorderDelegate(callback: callback)
recorder.delegate = delegate
self.recorderDelegate = delegate
guard recorder.record(forDuration: kRecordingDuration) else {
throw AudioRecorderError.Error(message: "Failed to record.")
}
// control should resume in recorder.delegate.audioRecorderDidFinishRecording()
} catch {
callback(.failure(error))
}
}
}
private var recorderDelegate: RecorderDelegate?
private var recorder: AVAudioRecorder?
private class RecorderDelegate: NSObject, AVAudioRecorderDelegate {
private let callback: RecordingDoneCallback
init(callback: @escaping RecordingDoneCallback) {
self.callback = callback
}
func audioRecorderDidFinishRecording(
_ recorder: AVAudioRecorder,
successfully flag: Bool
) {
let recordResult = RecordResult { () -> Data in
guard flag else {
throw AudioRecorderError.Error(message: "Recording was unsuccessful.")
}
let recordingUrl = recorder.url
let recordingFile = try AVAudioFile(forReading: recordingUrl)
guard
let format = AVAudioFormat(
commonFormat: .pcmFormatFloat32,
sampleRate: recordingFile.fileFormat.sampleRate,
channels: 1,
interleaved: false)
else {
throw AudioRecorderError.Error(message: "Failed to create audio format.")
}
guard
let recordingBuffer = AVAudioPCMBuffer(
pcmFormat: format,
frameCapacity: AVAudioFrameCount(recordingFile.length))
else {
throw AudioRecorderError.Error(message: "Failed to create audio buffer.")
}
try recordingFile.read(into: recordingBuffer)
guard let recordingFloatChannelData = recordingBuffer.floatChannelData else {
throw AudioRecorderError.Error(message: "Failed to get float channel data.")
}
return Data(bytes: recordingFloatChannelData[0], count: Int(recordingBuffer.frameLength) * MemoryLayout<Float>.size)
}
callback(recordResult)
}
func audioRecorderEncodeErrorDidOccur(
_ recorder: AVAudioRecorder,
error: Error?
) {
if let error = error {
callback(.failure(error))
} else {
callback(.failure(AudioRecorderError.Error(message: "Encoding was unsuccessful.")))
}
}
}
}
Train view
The TrainView is used to train the model on the user's voice. First, it prompts the user to record their voice kNumRecordings times. Then, it trains the model on the user's voice along with some pre-recorded recordings of other speakers. Finally, it exports the trained model for inference.
import SwiftUI
struct TrainView: View {
enum ViewState {
case recordingTrainingData, trainingInProgress, trainingComplete
}
private static let sentences = [
"In the embrace of nature's beauty, I find peace and tranquility. The gentle rustling of leaves soothes my soul, and the soft sunlight kisses my skin. As I breathe in the fresh air, I am reminded of the interconnectedness of all living things, and I feel a sense of oneness with the world around me.",
"Under the starlit sky, I gaze in wonder at the vastness of the universe. Each twinkle represents a story yet untold, a dream yet to be realized. With every new dawn, I am filled with hope and excitement for the opportunities that lie ahead. I embrace each day as a chance to grow, to learn, and to create beautiful memories.",
"A warm hug from a loved one is a precious gift that warms my heart. In that tender embrace, I feel a sense of belonging and security. Laughter and tears shared with dear friends create a bond that withstands the test of time. These connections enrich my life and remind me of the power of human relationships.",
"Life's journey is like a beautiful melody, with each note representing a unique experience. As I take each step, I harmonize with the rhythm of existence. Challenges may come my way, but I face them with resilience and determination, knowing they are opportunities for growth and self-discovery.",
"With every page turned in a book, I open the door to new worlds and ideas. The written words carry the wisdom of countless souls, and I am humbled by the knowledge they offer. In stories, I find a mirror to my own experiences and a beacon of hope for a better tomorrow.",
"Life's trials may bend me, but they will not break me. Through adversity, I discover the strength within my heart. Each obstacle is a chance to learn, to evolve, and to emerge as a better version of myself. I am grateful for every lesson, for they shape me into the person I am meant to be.",
"The sky above is an ever-changing canvas of colors and clouds. In its vastness, I realize how small I am in the grand scheme of things, and yet, I know my actions can ripple through the universe. As I walk this Earth, I seek to leave behind a positive impact and a legacy of love and compassion.",
"In the stillness of meditation, I connect with the depth of my soul. The external noise fades away, and I hear the whispers of my inner wisdom. With each breath, I release tension and embrace serenity. Meditation is my sanctuary, a place where I can find clarity and renewed energy.",
"Kindness is a chain reaction that spreads like wildfire. A simple act of compassion can brighten someone's day and inspire them to pay it forward. Together, we can create a wave of goodness that knows no boundaries, reaching even the farthest corners of the world.",
"As the sun rises on a new day, I am filled with gratitude for the gift of life. Every moment is a chance to make a difference, to love deeply, and to embrace joy. I welcome the adventures that await me and eagerly embrace the mysteries yet to be uncovered."
]
private let kNumRecordings = 5
private let audioRecorder = AudioRecorder()
private let trainer = try! Trainer()
@State private var trainingData: [Data] = []
@State private var viewState: ViewState = .recordingTrainingData
@State private var readyToRecord: Bool = true
@State private var trainingProgress: Double = 0.0
private func recordVoice() {
audioRecorder.record { recordResult in
switch recordResult {
case .success(let recordingData):
trainingData.append(recordingData)
print("Successfully completed Recording")
case .failure(let error):
print("Error: \(error)")
}
readyToRecord = true
if trainingData.count == kNumRecordings {
viewState = .trainingInProgress
trainAndExportModel()
}
}
}
private func updateProgressBar(progress: Double) {
DispatchQueue.main.async {
trainingProgress = progress
}
}
private func trainAndExportModel() {
Task {
do {
try trainer.train(trainingData, progressCallback: updateProgressBar)
try trainer.exportModelForInference()
DispatchQueue.main.async {
viewState = .trainingComplete
print("Training is complete")
}
} catch {
DispatchQueue.main.async {
viewState = .trainingComplete
print("Training Failed: \(error)")
}
}
}
}
var body: some View {
VStack {
switch viewState {
case .recordingTrainingData:
Text("\(trainingData.count + 1) of \(kNumRecordings)")
.font(.caption)
.foregroundColor(.secondary)
.padding()
ProgressView(value: Double(trainingData.count),
total: Double(kNumRecordings))
.progressViewStyle(LinearProgressViewStyle(tint: .purple))
.frame(height: 10)
.cornerRadius(5)
Spacer()
Text(TrainView.sentences[trainingData.count % TrainView.sentences.count])
.font(.body)
.padding()
.multilineTextAlignment(.center)
.fontDesign(.monospaced)
Spacer()
ZStack(alignment: .center) {
Image(systemName: "mic.fill")
.resizable()
.aspectRatio(contentMode: .fit)
.frame(width: 100, height: 100)
.foregroundColor( readyToRecord ? .gray: .red)
.transition(.scale)
.animation(.easeIn, value: 1)
}
Spacer()
Button(action: {
readyToRecord = false
recordVoice()
}) {
Text(readyToRecord ? "Record" : "Recording ...")
.font(.title)
.padding()
.background(readyToRecord ? .green : .gray)
.foregroundColor(.white)
.cornerRadius(10)
}.disabled(!readyToRecord)
case .trainingInProgress:
VStack {
Spacer()
ProgressView(value: trainingProgress,
total: 1.0,
label: {Text("Training")},
currentValueLabel: {Text(String(format: "%.0f%%", trainingProgress * 100))})
.padding()
Spacer()
}
case .trainingComplete:
Spacer()
Text("Training successfully finished!")
.font(.title)
.padding()
.multilineTextAlignment(.center)
.fontDesign(.monospaced)
Spacer()
NavigationLink(destination: InferView()) {
Text("Infer")
.font(.title)
.padding()
.background(.purple)
.foregroundColor(.white)
.cornerRadius(10)
}
.padding(.leading, 20)
}
Spacer()
}
.padding()
.navigationTitle("Train")
}
}
struct TrainView_Previews: PreviewProvider {
static var previews: some View {
TrainView()
}
}
You can find the complete implementation of TrainView here.
Infer view
Finally, we create the InferView, which is used to run inference with the trained model. It prompts the user to record their voice, performs inference with the trained model, and then displays the inference result.
import SwiftUI
struct InferView: View {
enum InferResult {
case user, other, notSet
}
private let audioRecorder = AudioRecorder()
@State private var voiceIdentifier: VoiceIdentifier? = nil
@State private var readyToRecord: Bool = true
@State private var inferResult: InferResult = InferResult.notSet
@State private var probUser: Float = 0.0
@State private var showAlert = false
@State private var alertMessage = ""
private func recordVoice() {
audioRecorder.record { recordResult in
let recognizeResult = recordResult.flatMap { recordingData in
return voiceIdentifier!.evaluate(inputData: recordingData)
}
endRecord(recognizeResult)
}
}
private func endRecord(_ result: Result<(Bool, Float), Error>) {
DispatchQueue.main.async {
switch result {
case .success(let (isMatch, confidence)):
print("Your Voice with confidence: \(isMatch), \(confidence)")
inferResult = isMatch ? .user : .other
probUser = confidence
case .failure(let error):
print("Error: \(error)")
}
readyToRecord = true
}
}
var body: some View {
VStack {
Spacer()
ZStack(alignment: .center) {
Image(systemName: "mic.fill")
.resizable()
.aspectRatio(contentMode: .fit)
.frame(width: 100, height: 100)
.foregroundColor( readyToRecord ? .gray: .red)
.transition(.scale)
.animation(.easeInOut, value: 1)
}
Spacer()
Button(action: {
readyToRecord = false
recordVoice()
}) {
Text(readyToRecord ? "Record" : "Recording ...")
.font(.title)
.padding()
.background(readyToRecord ? .green : .gray)
.foregroundColor(.white)
.cornerRadius(10)
}.disabled(voiceIdentifier == nil || !readyToRecord)
.opacity(voiceIdentifier == nil ? 0.5: 1.0)
if inferResult != .notSet {
Spacer()
ZStack (alignment: .center) {
Image(systemName: inferResult == .user ? "person.crop.circle.fill.badge.checkmark": "person.crop.circle.fill.badge.xmark")
.resizable()
.aspectRatio(contentMode: .fit)
.frame(width: 100, height: 100)
.foregroundColor(inferResult == .user ? .green : .red)
.animation(.easeInOut, value: 2)
}
Text("Probability of User : \(String(format: "%.2f", probUser*100.0))%")
.multilineTextAlignment(.center)
.fontDesign(.monospaced)
}
Spacer()
}
.padding()
.navigationTitle("Infer")
.onAppear {
do {
voiceIdentifier = try VoiceIdentifier()
} catch {
alertMessage = "Error initializing inference session, make sure that training is completed: \(error)"
showAlert = true
}
}
.alert(isPresented: $showAlert) {
Alert(title: Text("Error"), message: Text(alertMessage), dismissButton: .default(Text("OK")))
}
}
}
struct InferView_Previews: PreviewProvider {
static var previews: some View {
InferView()
}
}
You can find the complete implementation of InferView here.
Content view
Lastly, we update the default ContentView so that it contains buttons for navigating to TrainView and InferView.
import SwiftUI
struct ContentView: View {
var body: some View {
NavigationView {
VStack {
Text("My Voice")
.font(.largeTitle)
.padding(.top, 50)
Spacer()
ZStack(alignment: .center) {
Image(systemName: "waveform.circle.fill")
.resizable()
.aspectRatio(contentMode: .fit)
.frame(width: 100, height: 100)
.foregroundColor(.purple)
}
Spacer()
HStack {
NavigationLink(destination: TrainView()) {
Text("Train")
.font(.title)
.padding()
.background(Color.purple)
.foregroundColor(.white)
.cornerRadius(10)
}
.padding(.trailing, 20)
NavigationLink(destination: InferView()) {
Text("Infer")
.font(.title)
.padding()
.background(.purple)
.foregroundColor(.white)
.cornerRadius(10)
}
.padding(.leading, 20)
}
Spacer()
}
.padding()
}
}
}
struct ContentView_Previews: PreviewProvider {
static var previews: some View {
ContentView()
}
}
You can find the complete implementation of ContentView here.
Running the iOS application
Now, we are ready to run the application. You can run the app on the simulator or on a device. You can find more information about running the app on the simulator and on a device here.
a. When you run the app, you should see the following screen:
b. Next, tap the Train button to navigate to the TrainView. The TrainView will prompt you to record your voice. You will need to record your voice kNumRecordings times.
c. Once all the recordings are done, the app will train the model on the given data. You will see a progress bar indicating the training progress.
d. Once training is complete, you will see the following screen:
e. Now, tap the Infer button to navigate to the InferView. The InferView will prompt you to record your voice. Once the recording is done, it runs inference with the trained model and displays the inference result.
That's it! Hopefully, it identified your voice correctly.
Conclusion
Congratulations! You have successfully built an iOS application that can train a simple audio classification model using on-device training techniques. You can now use the application to train a model on your own voice and run inference with the trained model. The complete application is also available on GitHub in the onnxruntime-training-examples repository.