首页 >> 大全

FATE —— 二.2.1 Homo-NN自定义数据集

2023-12-19 大全 25 作者:考证青年

前言

FATE系统主要支持表格数据作为其标准数据格式。然而,通过使用NN模块的数据集特性,可以在神经网络中使用非表格数据,例如图像、文本、混合数据或关系数据。NN模块中的数据集模块允许定制数据集,以用于更复杂的数据场景。本教程将介绍Homo NN模块中数据集功能的使用,并提供如何自定义数据集的指导。我们将使用MNIST手写识别任务作为示例来说明这些概念。

准备MNIST数据

请从以下链接下载MNIST数据集,并将其放在项目示例/数据文件夹中:MNIST

这是MNIST数据集的简化版本,共有十个类别,根据标签分为0-9 10个文件夹。我们对数据集进行采样以减少样本数量。

MNIST数据集的来源是:

数据集代码介绍

在FATE-1.10版本中,FATE为数据集引入了一个新的基类,称为,它基于的类。此类允许用户根据其特定需求创建自定义数据集。其用法与的类类似,在使用FATE-NN进行数据读取和训练时,需要实现两个额外的接口:load()和()。

要在Homo NN中创建自定义数据集,用户需要:

对于不熟悉的数据集类的人,可以在文档中找到更多信息:数据集文档

load()

所需的第一个附加接口是load()。此接口接收文件路径,并允许用户直接从本地文件系统读取数据。提交任务时,可以通过读取器组件指定数据路径。Homo NN将使用用户指定的类,利用load()接口从指定路径读取数据,并完成数据集的加载以进行训练。有关更多信息,请参阅//nn//base.py中的源代码。

()

第二个附加接口是()。此接口应返回一个样本ID列表,该列表可以是整数或字符串,并且长度应与数据集相同。实际上,当使用Homo NN时,您可以跳过实现这个接口,因为Homo NN组件将自动为样本生成ID。

示例:实现一个简单的图像数据集

为了更好地理解数据集的定制,我们在这里实现了一个简单的图像数据集来读取MNIST图像,然后在横向场景中完成联合图像分类任务。为了方便起见,我们使用的接口来更新代码以.nn.,名为.py,当然,您可以手动将代码文件复制到目录中

: ()

from pipeline.component.nn import save_to_fate

MNIST数据集

_定义集合数据对象_定义集合

这里我们实现了数据集,并使用()保存它。

%%save_to_fate dataset mnist_dataset.py
import numpy as np
from federatedml.nn.dataset.base import Dataset
from torchvision.datasets import ImageFolder
from torchvision import transformsclass MNISTDataset(Dataset):def __init__(self, flatten_feature=False): # flatten feature or not super(MNISTDataset, self).__init__()self.image_folder = Noneself.ids = Noneself.flatten_feature = flatten_featuredef load(self, path):  # read data from path, and set sample ids# read using ImageFolderself.image_folder = ImageFolder(root=path, transform=transforms.Compose([transforms.ToTensor()]))# filename as the image idids = []for image_name in self.image_folder.imgs:ids.append(image_name[0].split('/')[-1].replace('.jpg', ''))self.ids = idsreturn selfdef get_sample_ids(self):  # implement the get sample id interface, simply return idsreturn self.idsdef __len__(self,):  # return the length of the datasetreturn len(self.image_folder)def __getitem__(self, idx): # get itemret = self.image_folder[idx]if self.flatten_feature:img = ret[0][0].flatten() # return flatten tensor 784-dimreturn img, ret[1] # return tensor and labelelse:return ret

在我们实现数据集之后,我们可以在本地测试它:

from federatedml.nn.dataset.mnist_dataset import MNISTDatasetds = MNISTDataset(flatten_feature=True)

# load MNIST data and check 
ds.load('/mnt/hgfs/YOLOV5/mnist/')  # 切换成自己下载上文中minist文件夹的地址
print(len(ds))
print(ds[0])
print(ds.get_sample_ids()[0])

测试数据集

在提交任务之前,可以在本地进行测试。正如我们在2.1 Homo NN 二进制分类任务中提到的,在Homo NN中,FATE默认使用。自定义数据集、模型和训练器可用于本地调试,以测试程序是否正确运行。请注意,在本地测试期间,将跳过所有联合过程,并且模型不会执行联合平均。

from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer
trainer = FedAVGTrainer(epochs=3, batch_size=256, shuffle=True, data_loader_worker=8, pin_memory=False) # set parameter

trainer.local_mode() # !! Be sure to enable local_mode to skip the federation process !!

import torch as t
from pipeline import fate_torch_hook
fate_torch_hook(t)
# our simple classification model:
model = t.nn.Sequential(t.nn.Linear(784, 32),t.nn.ReLU(),t.nn.Linear(32, 10),t.nn.Softmax(dim=1)
)trainer.set_model(model) # set model

optimizer = t.optim.Adam(model.parameters(), lr=0.01)  # optimizer
loss = t.nn.CrossEntropyLoss()  # loss function
trainer.train(train_set=ds, optimizer=optimizer, loss=loss)  # use dataset we just developed

在的train()函数中,将使用 迭代数据集。程序可以正确运行!现在我们可以提交联合任务了。

使用数据集提交任务 导入组件

import torch as t
from torch import nn
from pipeline import fate_torch_hook
from pipeline.component import HomoNN
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader, Evaluation, DataTransform
from pipeline.interface import Data, Modelt = fate_torch_hook(t)

将数据路径绑定到名称和命名空间

这里,我们使用将路径绑定到名称和命名空间。然后,我们可以使用读取器组件将此路径传递到数据集的“加载”接口。培训师将在train()中获取此数据集,并使用 对其进行迭代。请注意,在本教程中,我们使用的是独立版本,如果您使用的是集群版本,则需要将数据与每台计算机上的相应名称和命名空间绑定。

import os
# bind data path to name & namespace
fate_project_path = os.path.abspath('../')
host = 10000
guest = 9999
arbiter = 10000
pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host,arbiter=arbiter)data_0 = {"name": "mnist_guest", "namespace": "experiment"}
data_1 = {"name": "mnist_host", "namespace": "experiment"}# 这里需要根据自己得版本作出调整,否则文件参数上传失败会报错
data_path_0 = fate_project_path + '/examples/data/mnist_train'
data_path_1 = fate_project_path + '/examples/data/mnist_train'
pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path_0)
pipeline.bind_table(name=data_1['name'], namespace=data_1['namespace'], path=data_path_1)

{'': '', '': ''}

reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=data_1)

数据集参数

使用指定数据集的模块名称,并在后面填写其参数,这些参数将传递给数据集的接口。请注意,数据集参数需要是JSON可序列化的,否则无法解析它们。

from pipeline.component.nn import DatasetParamdataset_param = DatasetParam(dataset_name='mnist_dataset', flatten_feature=True)  # specify dataset, and its init parameters

from pipeline.component.homo_nn import TrainerParam  # Interface# our simple classification model:
model = t.nn.Sequential(t.nn.Linear(784, 32),t.nn.ReLU(),t.nn.Linear(32, 10),t.nn.Softmax(dim=1)
)nn_component = HomoNN(name='nn_0',model=model, # modelloss=t.nn.CrossEntropyLoss(),  # lossoptimizer=t.optim.Adam(model.parameters(), lr=0.01), # optimizerdataset=dataset_param,  # datasettrainer=TrainerParam(trainer_name='fedavg_trainer', epochs=2, batch_size=1024, validation_freqs=1),torch_seed=100 # random seed)

pipeline.add_component(reader_0)
pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))
pipeline.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=nn_component.output.data))

pipeline.compile()
pipeline.fit()

pipeline.get_component('nn_0').get_output_data()

pipeline.get_component('nn_0').get_summary()

{'': 1,

'': [3., 3.],

'': {'train': {'': [0.,

0.49586],

'': [0.37323, 0.],

'': [0., 0.]}},

'': False}

关于我们

最火推荐

小编推荐

联系我们


版权声明:本站内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 88@qq.com 举报,一经查实,本站将立刻删除。备案号:桂ICP备2021009421号
Powered By Z-BlogPHP.
复制成功
微信号:
我知道了