当前位置: 首页 > news >正文

专业做国际网站网站开发的编程软件

专业做国际网站,网站开发的编程软件,企业查询系统官网入口,如何做英文系统下载网站介绍#xff1a;上一期介绍了如何利用PyTorch Lightning搭建并训练一个模型#xff08;仅使用训练集#xff09;#xff0c;为了保证模型可以泛化到未见过的数据上#xff0c;数据集通常被分为训练和测试两个集合#xff0c;测试集与训练集相互独立#xff0c;用以测试模… 介绍上一期介绍了如何利用PyTorch Lightning搭建并训练一个模型仅使用训练集为了保证模型可以泛化到未见过的数据上数据集通常被分为训练和测试两个集合测试集与训练集相互独立用以测试模型的泛化能力。本期通过增加验证和测试集来达到该目的同时还引入checkpoint和早停策略以得到模型最佳权重。 相关链接https://lightning.ai/docs/pytorch/stable/levels/basic_level_2.html 训练集、验证集、测试集的使用 1.添加依赖获取训练集和测试集 添加相应的依赖同时使用MNIST数据集获取训练和测试集 import torch.utils.data as data from torchvision import datasets import torchvision.transforms as transforms from torch.utils.data import DataLoader# 加载数据测试集trainFalse transform transforms.ToTensor() train_set datasets.MNIST(rootMNIST, downloadTrue, trainTrue, transformtransform) test_set datasets.MNIST(rootMNIST, downloadTrue, trainFalse, transformtransform)2.实现并调用test_step 在定义LightningModule中实现test_step方法在外部调用test方法 class LitAutoEncoder(pl.LightningModule):def training_step(self, batch, batch_idx):...def test_step(self, batch, batch_idx): # 测试该方法与training_step相似x, y batchx x.view(x.size(0), -1)z self.encoder(x)x_hat self.decoder(z)test_loss F.mse_loss(x_hat, x)self.log(test_loss, test_loss)# 初始化Trainer trainer Trainer()# 执行test方法 trainer.test(model, dataloadersDataLoader(test_set))3.实现并调用验证集 通常使用torch.utils.data中的方法将训练集中的一部分数据化为验证集 # 训练集中的20%数据划为验证集 train_set_size int(len(train_set) * 0.8) valid_set_size len(train_set) - train_set_size# 拆分使用data.random_split方法 seed torch.Generator().manual_seed(42) train_set, valid_set data.random_split(train_set, [train_set_size, valid_set_size], generatorseed)与测试集一样需要在定义LightningModule中实现validation_step方法在外部调用fit方法 class LitAutoEncoder(pl.LightningModule):def training_step(self, batch, batch_idx):...def validation_step(self, batch, batch_idx):x, y batchx x.view(x.size(0), -1)z self.encoder(x)x_hat self.decoder(z)val_loss F.mse_loss(x_hat, x)self.log(val_loss, val_loss)def test_step(self, batch, batch_idx):... # 调用torch.utils.data中的DataLoader对训练和测试集进行封装 train_loader DataLoader(train_set) valid_loader DataLoader(valid_set)# 在fit方法中引入valid_loader即验证集 trainer Trainer() trainer.fit(model, train_loader, valid_loader)checkpoint checkpoint有两个作用一是能得到每一次epoch后的模型权重能得到最佳表现的权重二是能够在中断或停止后继续在当前checkpoint处继续训练。在Lightning中的checkpoint包含模型的整个内部状态这与普通的PyTorch不同即使在最复杂的分布式训练环境中Lightning也可以保存恢复模型所需的一切。包含以下状态 16-bit scaling factor (若使用16精度训练)Current epochGlobal stepLightningModule’s state_dictState of all optimizersState of all learning rate schedulersState of all callbacks (for stateful callbacks)State of datamodule (for stateful datamodules)The hyperparameters (init arguments) with which the model was createdThe hyperparameters (init arguments) with which the datamodule was createdState of Loops 保存与调用方法 # 保存方法可自定义default_root_dir路径若不设置路径将会自动保存 trainer Trainer(default_root_dirsome/path/)# 调用方法 model MyLightningModule.load_from_checkpoint(/path/to/checkpoint.ckpt) model.eval() # disable randomness, dropout, etc... y_hat model(x) 调用还可以使用torch的方法 checkpoint torch.load(checkpoint, map_locationlambda storage, loc: storage) print(checkpoint[hyper_parameters]) # {learning_rate: the_value, another_parameter: the_other_value}也可以实现重现例如模型LitModel(in_dim32, out_dim10) # 使用 in_dim32, out_dim10 model LitModel.load_from_checkpoint(PATH) # 使用 in_dim128, out_dim10 model LitModel.load_from_checkpoint(PATH, in_dim128, out_dim10)Lightning和PyTorch完全兼容 checkpoint torch.load(CKPT_PATH) encoder_weights checkpoint[encoder] decoder_weights checkpoint[decoder]设置checkpoint不可见 trainer Trainer(enable_checkpointingFalse)如果想全部重新恢复 model LitModel() trainer Trainer()自动恢复所有相关参数 model, epoch, step, LR schedulers, etc… trainer.fit(model, ckpt_pathsome/path/to/my_checkpoint.ckpt)早停策略 EarlyStopping Callback 在Lightning中早停回调步骤如下 Import EarlyStopping callback. 载入EarlyStopping回调方法Log the metric you want to monitor using log() method. 加载日志方法Init the callback, and set monitor to the logged metric of your choice. 设置monitorSet the mode based on the metric needs to be monitored. 设置modePass the EarlyStopping callback to the Trainer callbacks flag. 调入EarlyStropping from lightning.pytorch.callbacks.early_stopping import EarlyStoppingclass LitModel(LightningModule):def validation_step(self, batch, batch_idx):loss ...self.log(val_loss, loss)model LitModel() trainer Trainer(callbacks[EarlyStopping(monitorval_loss, modemin)]) trainer.fit(model)# 也可以使用自定义的EarlyStopping策略 early_stop_callback EarlyStopping(monitorval_accuracy, min_delta0.00, patience3, verboseFalse, modemax) trainer Trainer(callbacks[early_stop_callback]) # EarlyStopping的文档链接https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping注意 EarlyStopping默认在一次Validation后调用但是Validation可以自定义多少次epoch后进行一次验证例如check_val_every_n_epoch and val_check_interval。 完整代码 # coding:utf-8 import torch import torch.nn as nn import torch.utils.data as data from torchvision import datasets import torchvision.transforms as transforms from torch.utils.data import DataLoader import torch.nn.functional as F import lightning as L# -------------------------------- # Step 1: 定义模型 # -------------------------------- class LitAutoEncoder(L.LightningModule):def __init__(self):super().__init__()self.encoder nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))self.decoder nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))def training_step(self, batch, batch_idx):x, y batchx x.view(x.size(0), -1)z self.encoder(x)x_hat self.decoder(z)loss F.mse_loss(x_hat, x)self.log(train_loss, loss)return lossdef test_step(self, batch, batch_idx): # 测试该方法与training_step相似x, y batchx x.view(x.size(0), -1)z self.encoder(x)x_hat self.decoder(z)test_loss F.mse_loss(x_hat, x)self.log(test_loss, test_loss)def validation_step(self, batch, batch_idx):x, y batchx x.view(x.size(0), -1)z self.encoder(x)x_hat self.decoder(z)val_loss F.mse_loss(x_hat, x)self.log(val_loss, val_loss)def configure_optimizers(self):optimizer torch.optim.Adam(self.parameters(), lr1e-3)return optimizerdef forward(self, x):# forward 定义了一次 预测/推理 行为embedding self.encoder(x)return embedding # -------------------------------- # Step 2: 加载数据模型 # -------------------------------- transform transforms.ToTensor() train_set datasets.MNIST(rootMNIST, downloadTrue, trainTrue, transformtransform) test_set datasets.MNIST(rootMNIST, downloadTrue, trainFalse, transformtransform)# 训练集中的20%数据划为验证集 train_set_size int(len(train_set) * 0.8) valid_set_size len(train_set) - train_set_size# 拆分使用data.random_split方法 seed torch.Generator().manual_seed(42) train_set, valid_set data.random_split(train_set, [train_set_size, valid_set_size], generatorseed) train_loader DataLoader(train_set) valid_loader DataLoader(valid_set)autoencoder LitAutoEncoder() # -------------------------------- # Step 3: 训练验证测试 # -------------------------------- # 训练验证 trainer L.Trainer(default_root_dirsome/path/) # 这里自定义需要保存的路径 trainer.fit(autoencoder, train_loader, valid_loader)# 测试 trainer.test(autoencoder, dataloadersDataLoader(test_set))
http://www.zqtcl.cn/news/303197/

相关文章:

  • 如何运营垂直网站网页工具大全
  • 如何让自己做的网站可以播放歌曲做培训网站
  • 做网站的毕业设计网站没备案怎么做淘宝客
  • 百度申诉网站建设银行住房租赁代表品牌是什么
  • 网站初期推广方案虚拟服务器搭建wordpress
  • jeecms可以做网站卖吗山西网络推广专业
  • 2017 如何做网站优化育儿哪个网站做的好
  • 网站制作容易吗青岛网站建设公司报价
  • 淘宝建设网站的好处网站制作结构
  • 网站开发网站建设公司临沂网站建设找谁
  • 咋么做网站在电脑上潍坊免费模板建站
  • 苏州网站建设推广咨询平台做网站的公司图
  • 北京企业网站怎么建设免费给我推广
  • 网站制作价钱多少专业的咨询行业网站制作
  • 做百度网站每年的费用多少交换友情链接时需要注意的事项
  • 怎么在百度网站上做自己的网站百度开户渠道
  • php技术的网站建设实录方案做二手手机的网站有哪些
  • 做网站店铺装修的软件怎么做淘课网站
  • 百度一下官方网站wordpress连接代码
  • 什么网站详情页做的好仿唧唧帝笑话门户网站源码带多条采集规则 织梦搞笑图片视频模板
  • 平原网站建设费用少儿编程加盟店倒闭
  • 企业网站建设专业公司蜜淘app在那个网站做的
  • 市住房城乡建设部网站大学生课程设计网站
  • 广州大石附近做网站的公司外包服务公司是干什么的
  • 做的新网站网上搜不到做的网站百度搜索不出来的
  • 电商网站后台报价公司如何建站
  • 查网站有没有做推广企业网站建设的目标
  • 北京网站维护公司专业外贸网站建设_诚信_青岛
  • 网站自己做还是用程序制作网站一般使用的软件有哪些
  • 晨雷文化传媒网站建设济南互联网品牌设计