当前位置: 首页 > news >正文

使用TensorFlow高级别的API进行编程

  这里涉及到的高级别API主要是使用Estimator类来编写机器学习的程序,此外你还需要用到一些数据导入的知识。

为什么使用Estimator

  Estimator类是定义在tf.estimator.Estimator中的,你可以使用其中已经有的Estimator,叫做预创建的Estimator,也可以自定义Estimator。Estimator已经封装了训练(train),评估(evaluate),预测(predict),导出以供使用等方法。

  此外,Estimator会为我们提供诸如图构建、创建session等管道工作,不用我们再做这些重复的工作。它还提供了安全的分布式训练循环。相比于低级的API,我们可以把大部分的时间和精力放在处理数据、训练模型、调整参数上面,而不是创建张量、构建图、使用session运行张量上面。

使用Estimator的步骤

1:需要编写一个数据输入的函数input_fn

  input_fn是输入函数,这个函数的作用在于对数据进行预处理,并且在模型train,predict,evaluate的时候给模型送进去数据。所以input_fn主要作用的时机在模型训练、预测和评估的时候,在模型定义的时候不需要传入输入函数,而是传入一个预定义的特征列。可以使用系统自带的函数,可以编写自定义的输入函数。

使用系统自带的数据输入函数:

  系统自带的输入函数为tf.estimator.inputs.numpy_input_fn,它的输入参数如下:

def numpy_input_fn(x,
                   y=None,
                   batch_size=128,
                   num_epochs=1,
                   shuffle=None,
                   queue_capacity=1000,
                   num_threads=1)

  x为numpy数组或者numpy数组的字典,当为numpy数组的时候,这个数组被当做单一的特征对待。

  一个例子如下,这个例子是tf.estimator Quickstart tutorial中的一段代码:

import numpy as np

training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TRAINING, target_dtype=np.int, features_dtype=np.float32)

train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": np.array(training_set.data)},
    y=np.array(training_set.target),
    num_epochs=None,
    shuffle=True)

classifier.train(input_fn=train_input_fn, steps=2000)
 

自定义导入数据的函数:

  要自定义导入函数,要知道tensorflow中关于数据的概念,以及知道自定义的函数应该返回的值,下面我将梳理一下这里面的概念:

自定义函数的基本框架以及返回值

def my_input_fn():

    # 在这里进行数据的预处理...

    # ...返回两个值 1) 一个由特征列和包含特征的Tensors组成的映射(字典) 2) 一个包含labels的Tensor
    return feature_cols, labels

  自定义函数需要返回两个值,一个值是feature_cols,是一个字典,其中字典的key为特征的列名称,字典的value为包含特征值的Tensor对象。labels是一个包含标签值的Tensor对象。

tf.data.API对于数据的两个抽象:

  使用tf.data.API来构建数据输入的管道,帮助我们导入数据,无论是图像,文本还是分布式的数据,都可以用它来完成。

  一个抽象的概念是tf.data.Dataset,一个Dataset是一个数据集,它是由一系列的元素组成的,每个元素的类型都是相同的。其中每个元素包含一个或者多个Tensor对象。我们可以以两种方式来创建Dataset对象,一种方式是创建它的来源,比如使用Dataset.from_tensor_slices(),可以使用张量来创建Dataset对象,另外一种方式是运用转换的方式,可以将一个Dataset来变成另外一个Dataset,比如Dataset.batch()。

  另外一个抽象的概念是tf.data.Iterator,它代表的是迭代器。表示的是如何从数据集里面取出元素,最简答的迭代器是单次迭代器,Dataset.make_one_shot_iterator()可以创建单次迭代器。创建迭代器以后,可以使用Iterator.get_next()来获取下一个元素。

其它的创建数据集的方法:

  Dataset.from_tensor()创建一个Dataset,并将传入的Tensor当做一个元素。 Dataset.from_tensor_slices()会创建一个Dataset,并且将传入的Tensor在第0维上面切面,分成一些列的元素。还可以使用TFRecordDataset来获得磁盘上面TFRecord格式的数据。

其它的创建迭代器的方法:

  除了dataset.make_one_shot_iterator()这种单次迭代器以外,你还可以创建可初始化、可重新初始化、可馈送迭代器。

导入数据集的基本的工作机制:

1:创建Dataset对象 –> 2:将Dataset进行转化 –> 3:创建迭代器 –> 4:用迭代器返回下一个元素。

  下面用一个例子来说明一下:

from tensorflow.python.data import Dataset
import numpy as np
def my_input_fn(features, targets, batch_size=1, shuffle=True, num_epochs=None):
    """自定义的输入函数

    Args:
      features: 使用pandas中的DataFrame对象来表示的features
      targets: 使用pandas的taFrame对象表示的targets
      batch_size: 批次的大小
      shuffle: 是否将数据进行重新打乱
      num_epochs: 需要重复的epochs的数量,一个epochs代表一个训练周期. None = repeat indefinitely
    Returns:
      下一批次数据的元组 (features, labels)
    """

    # 将pandas对象转换为字典,其中字典的值为numpy的数组
    features = {key:np.array(value) for key,value in dict(features).items()}

    # 创建一个Dataset,并且设置好批次和重复的次数
    ds = Dataset.from_tensor_slices((features,targets)) # warning: 2GB limit
    ds = ds.batch(batch_size).repeat(num_epochs)

    # 是否进行数据扰动
    if shuffle:
        ds = ds.shuffle(10000)

    # 返回下个批次的数据
    features, labels = ds.make_one_shot_iterator().get_next()
    return features, labels

  上面自定义了数据导入的函数,使用Dataset.from_tensor_slices()来创建Dataset。然后使用batch、repeat、shuffle进行转换。 接着创建迭代器,并且获得下一个元素。

 

 

2:定义特征列

  使用tf.feature_column来标识特征名称、类型和任何输入预处理。

  特征列在原始数据和模型之间起到了连接的作用。在编写模型的时候需要预先确定输入数据的特征列。

  比如包含经度和维度两个特征的特征列,它们都是数值类型,这个特征列在模型定义的时候需要传入:

import tensorflow as tf
longitude = tf.feature_column.numeric_column('longitude')
latitude = tf.feature_column.numeric_column('latitude')
feature_column = [longitude, latitude]

inputs_to_model_bridge

特征列在原始数据与模型所需的数据之间架起了桥梁。

3:实例化相关的预创建的Estimator

  这个步骤就简单了,以深度学习模型为例,运用上面创建的经纬度特征列,使用10*10的隐层创建一个深度神经网络的回归模型:

hidden_units = [10, 10]
dnn_regressor = tf.estimator.DNNRegressor(
    feature_columns=feature_columns,
    hidden_units=hidden_units,
)

4:调用训练、评估或推理方法

  使用上述创建的模型进行train、evaluate、predict操作。首先需要定理训练的输入函数,将训练集的特征和标签都传进去,然后开始训练,例子如下:

training_input_fn = lambda:my_input_fn(train_df, train_target_df)
dnn_regressor.train(
    input_fn=training_input_fn,
    steps=300
)
 
 
 
 
 
 
 
 

参考:

Estimator 高级的API,介绍了创建estimator的流程

导入数据  介绍了数据集,还有迭代器的知识

Building Input Functions with tf.estimator  讲解了如何定义输入函数

特征列  详细介绍了特征列,里面有9中特征列可以学习

google机器学习速成课程的神经网络简介 ,完整的机器学习过程

 

转载于:https://www.cnblogs.com/jiaxin359/p/9092564.html

相关文章:

  • Java知识点总结(Java容器-Vector)
  • mybatis返回部分字段为空的问题
  • Confluence 6 SQL Server 数据库驱动修改
  • Python常见问题系列
  • ES6系列(二)变量的解构赋值
  • 超简单的视频对象提取程序
  • [Java并发编程实战] 共享对象之可见性
  • Java实用类库
  • MySQL常见的两种存储引擎:MyISAM与InnoDB的爱恨情仇
  • 『TensorFlow』线程控制器类变量作用域
  • Git漏洞导致攻击者可在用户电脑上运行任意代码
  • [译] 不用祖传秘方 - 写好代码的几个小技巧
  • el-input获取焦点 input输入框为空时高亮 el-input值非法时
  • 安装Cassandra数据库和访问客户端配置
  • CSS中background-position使用技巧
  • Google 是如何开发 Web 框架的
  • 「面试题」如何实现一个圣杯布局?
  • CSS实用技巧
  • ES6之路之模块详解
  • js面向对象
  • LeetCode刷题——29. Divide Two Integers(Part 1靠自己)
  • Netty 4.1 源代码学习:线程模型
  • Spark VS Hadoop:两大大数据分析系统深度解读
  • SpiderData 2019年2月16日 DApp数据排行榜
  • 设计模式 开闭原则
  • Hibernate主键生成策略及选择
  • 选择阿里云数据库HBase版十大理由
  • ​批处理文件中的errorlevel用法
  • ​香农与信息论三大定律
  • #### go map 底层结构 ####
  • (27)4.8 习题课
  • (C#)一个最简单的链表类
  • (cos^2 X)的定积分,求积分 ∫sin^2(x) dx
  • (day 2)JavaScript学习笔记(基础之变量、常量和注释)
  • (delphi11最新学习资料) Object Pascal 学习笔记---第8章第5节(封闭类和Final方法)
  • (附源码)spring boot火车票售卖系统 毕业设计 211004
  • (一)VirtualBox安装增强功能
  • (转)shell中括号的特殊用法 linux if多条件判断
  • (转)Unity3DUnity3D在android下调试
  • (转载)(官方)UE4--图像编程----着色器开发
  • .NET 3.0 Framework已经被添加到WindowUpdate
  • .Net core 6.0 升8.0
  • .net 逐行读取大文本文件_如何使用 Java 灵活读取 Excel 内容 ?
  • .w文件怎么转成html文件,使用pandoc进行Word与Markdown文件转化
  • ::什么意思
  • [ 攻防演练演示篇 ] 利用通达OA 文件上传漏洞上传webshell获取主机权限
  • [BT]BUUCTF刷题第9天(3.27)
  • [C# 开发技巧]如何使不符合要求的元素等于离它最近的一个元素
  • [CC2642R1][VSCODE+Embedded IDE+IAR Build+Cortex-Debug] TI CC2642R1基于VsCode的开发环境
  • [DM复习]关联规则挖掘(下)
  • [Foreman]解决Unable to find internal system admin account
  • [GN] DP学习笔记板子
  • [Go WebSocket] 多房间的聊天室(五)用多个小锁代替大锁,提高效率
  • [IE技巧] 使IE8以单进程的模式运行
  • [java] 23种设计模式之责任链模式