博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
学习笔记TF061:分布式TensorFlow,分布式原理、最佳实践
阅读量:5845 次
发布时间:2019-06-18

本文共 19967 字,大约阅读时间需要 66 分钟。

hot3.png

分布式TensorFlow由高性能gRPC库底层技术支持。Martin Abadi、Ashish Agarwal、Paul Barham论文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》。

分布式原理。分布式集群 由多个服务器进程、客户端进程组成。部署方式,单机多卡、分布式(多机多卡)。多机多卡TensorFlow分布式。

单机多卡,单台服务器多块GPU。训练过程:在单机单GPU训练,数据一个批次(batch)一个批次训练。单机多GPU,一次处理多个批次数据,每个GPU处理一个批次数据计算。变量参数保存在CPU,数据由CPU分发给多个GPU,GPU计算每个批次更新梯度。CPU收集完多个GPU更新梯度,计算平均梯度,更新参数。继续计算更新梯度。处理速度取决最慢GPU速度。

分布式,训练在多个工作节点(worker)。工作节点,实现计算单元。计算服务器单卡,指服务器。计算服务器多卡,多个GPU划分多个工作节点。数据量大,超过一台机器处理能力,须用分布式。

分布式TensorFlow底层通信,gRPC(google remote procedure call)。gRPC,谷歌开源高性能、跨语言RPC框架。RPC协议,远程过程调用协议,网络从远程计算机程度请求服务。

分布式部署方式。分布式运行,多个计算单元(工作节点),后端服务器部署单工作节点、多工作节点。

单工作节点部署。每台服务器运行一个工作节点,服务器多个GPU,一个工作节点可以访问多块GPU卡。代码tf.device()指定运行操作设备。优势,单机多GPU间通信,效率高。劣势,手动代码指定设备。

多工作节点部署。一台服务器运行多个工作节点。

设置CUDA_VISIBLE_DEVICES环境变量,限制各个工作节点只可见一个GPU,启动进程添加环境变量。用tf.device()指定特定GPU。多工作节点部署优势,代码简单,提高GPU使用率。劣势,工作节点通信,需部署多个工作节点。 。

CUDA_VISIBLE_DEVICES='' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=0CUDA_VISIBLE_DEVICES='' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=1CUDA_VISIBLE_DEVICES='0' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=0CUDA_VISIBLE_DEVICES='1' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=1

分布式架构。 。客户端(client)、服务端(server),服务端包括主节点(master)、工作节点(worker)组成。

客户端、主节点、工作节点关系。TensorFlow,客户端会话联系主节点,实际工作由工作节点实现,每个工作节点占一台设备(TensorFlow具体计算硬件抽象,CPU或GPU)。单机模式,客户端、主节点、工作节点在同一台服务器。分布模式,可不同服务器。客户端->主节点->工作节点/job:worker/task:0->/job:ps/task:0。 客户端。建立TensorFlow计算图,建立与集群交互会话层。代码包含Session()。一个客户端可同时与多个服务端相连,一具服务端也可与多个客户端相连。 服务端。运行tf.train.Server实例进程,TensroFlow执行任务集群(cluster)一部分。有主节点服务(Master service)和工作节点服务(Worker service)。运行中,一个主节点进程和数个工作节点进程,主节点进程和工作接点进程通过接口通信。单机多卡和分布式结构相同,只需要更改通信接口实现切换。 主节点服务。实现tensorflow::Session接口。通过RPC服务程序连接工作节点,与工作节点服务进程工作任务通信。TensorFlow服务端,task_index为0作业(job)。 工作节点服务。实现worker_service.proto接口,本地设备计算部分图。TensorFlow服务端,所有工作节点包含工作节点服务逻辑。每个工作节点负责管理一个或多个设备。工作节点可以是本地不同端口不同进程,或多台服务多个进程。运行TensorFlow分布式执行任务集,一个或多个作业(job)。每个作业,一个或多个相同目的任务(task)。每个任务,一个工作进程执行。作业是任务集合,集群是作业集合。 分布式机器学习框架,作业分参数作业(parameter job)和工作节点作业(worker job)。参数作业运行服务器为参数服务器(parameter server,PS),管理参数存储、更新。工作节点作业,管理无状态主要从事计算任务。模型越大,参数越多,模型参数更新超过一台机器性能,需要把参数分开到不同机器存储更新。参数服务,多台机器组成集群,类似分布式存储架构,涉及数据同步、一致性,参数存储为键值对(key-value)。分布式键值内存数据库,加参数更新操作。李沐《Parameter Server for Distributed Machine Learning》 。 参数存储更新在参数作业进行,模型计算在工作节点作业进行。TensorFlow分布式实现作业间数据传输,参数作业到工作节点作业前向传播,工作节点作业到参数作业反向传播。 任务。特定TensorFlow服务器独立进程,在作业中拥有对应序号。一个任务对应一个工作节点。集群->作业->任务->工作节点。

客户端、主节点、工作节点交互过程。单机多卡交互,客户端->会话运行->主节点->执行子图->工作节点->GPU0、GPU1。分布式交互,客户端->会话运行->主节点进程->执行子图1->工作节点进程1->GPU0、GPU1。《TensorFlow:Large-Scale Machine Learning on Heterogeneous distributed Systems》 。

分布式模式。

数据并行。 。CPU负责梯度平均、参数更新,不同GPU训练模型副本(model replica)。基于训练样例子集训练,模型有独立性。 步骤:不同GPU分别定义模型网络结构。单个GPU从数据管道读取不同数据块,前向传播,计算损失,计算当前变量梯度。所有GPU输出梯度数据转移到CPU,梯度求平均操作,模型变量更新。重复,直到模型变量收敛。 数据并行,提高SGD效率。SGD mini-batch样本,切成多份,模型复制多份,在多个模型上同时计算。多个模型计算速度不一致,CPU更新变量有同步、异步两个方案。

同步更新、异步更新。分布式随机梯度下降法,模型参数分布式存储在不同参数服务上,工作节点并行训练数据,和参数服务器通信获取模型参数。 同步随机梯度下降法(Sync-SGD,同步更新、同步训练),训练时,每个节点上工作任务读入共享参数,执行并行梯度计算,同步需要等待所有工作节点把局部梯度处好,将所有共享参数合并、累加,再一次性更新到模型参数,下一批次,所有工作节点用模型更新后参数训练。优势,每个训练批次考虑所有工作节点训练情部,损失下降稳定。劣势,性能瓶颈在最慢工作节点。异楹设备,工作节点性能不同,劣势明显。 异步随机梯度下降法(Async-SGD,异步更新、异步训练),每个工作节点任务独立计算局部梯度,异步更新到模型参数,不需执行协调、等待操作。优势,性能不存在瓶颈。劣势,每个工作节点计算梯度值发磅回参数服务器有参数更新冲突,影响算法收剑速度,损失下降过程抖动较大。 同步更新、异步更新实现区别于更新参数服务器参数策略。数据量小,各节点计算能力较均衡,用同步模型。数据量大,各机器计算性能参差不齐,用异步模式。 带备份的Sync-SGD(Sync-SDG with backup)。Jianmin Chen、Xinghao Pan、Rajat Monga、Aamy Bengio、Rafal Jozefowicz论文《Revisiting Distributed Synchronous SGD》 。增加工作节点,解决部分工作节点计算慢问题。工作节点总数n+n*5%,n为集群工作节点数。异步更新设定接受到n个工作节点参数直接更新参数服务器模型参数,进入下一批次模型训练。计算较慢节点训练参数直接丢弃。 同步更新、异步更新有图内模式(in-graph pattern)和图间模式(between-graph pattern),独立于图内(in-graph)、图间(between-graph)概念。 图内复制(in-grasph replication),所有操作(operation)在同一个图中,用一个客户端来生成图,把所有操作分配到集群所有参数服务器和工作节点上。国内复制和单机多卡类似,扩展到多机多卡,数据分发还是在客户端一个节点上。优势,计算节点只需要调用join()函数等待任务,客户端随时提交数据就可以训练。劣势,训练数据分发在一个节点上,要分发给不同工作节点,严重影响并发训练速度。 图间复制(between-graph replication),每一个工作节点创建一个图,训练参数保存在参数服务器,数据不分发,各个工作节点独立计算,计算完成把要更新参数告诉参数服务器,参数服务器更新参数。优势,不需要数据分发,各个工作节点都创建图和读取数据训练。劣势,工作节点既是图创建者又是计算任务执行者,某个工作节点宕机影响集群工作。大数据相关深度学习推荐使用图间模式。

模型并行。切分模型,模型不同部分执行在不同设备上,一个批次样本可以在不同设备同时执行。TensorFlow尽量让相邻计算在同一台设备上完成节省网络开销。Martin Abadi、Ashish Agarwal、Paul Barham论文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》 。

模型并行、数据并行,TensorFlow中,计算可以分离,参数可以分离。可以在每个设备上分配计算节点,让对应参数也在该设备上,计算参数放一起。

分布式API。 。 创建集群,每个任务(task)启动一个服务(工作节点服务或主节点服务)。任务可以分布不同机器,可以同一台机器启动多个任务,用不同GPU运行。每个任务完成工作:创建一个tf.train.ClusterSpec,对集群所有任务进行描述,描述内容对所有任务相同。创建一个tf.train.Server,创建一个服务,运行相应作业计算任务。 TensorFlow分布式开发API。tf.train.ClusterSpec({"ps":ps_hosts,"worker":worke_hosts})。创建TensorFlow集群描述信息,ps、worker为作业名称,ps_phsts、worker_hosts为作业任务所在节点地址信息。tf.train.ClusterSpec传入参数,作业和任务间关系映射,映射关系任务通过IP地址、端口号表示。

结构 tf.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})可用任务 /job:local/task:0、/job:local/task:1。结构 tf.train.ClusterSpec({"worker":["worker0.example.com:2222","worker1.example.com:2222","worker2.example.com:2222"],"ps":["ps0.example.com:2222","ps1.example.com:2222"]})可用任务 /job:worker/task:0、 /job:worker/task:1、 /job:worker/task:2、 /job:ps/task:0、 /job:ps/task:1

tf.train.Server(cluster,job_name,task_index)。创建服务(主节点服务或工作节点服务),运行作业计算任务,运行任务在task_index指定机器启动。

#任务0 cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})server  = tr.train.Server(cluster,job_name="local",task_index=0) #任务1 cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})server  = tr.train.Server(cluster,job_name="local",task_index=1)。

自动化管理节点、监控节点工具。集群管理工具Kubernetes。 tf.device(device_name_or_function)。设定指定设备执行张量运算,批定代码运行CPU、GPU。

#指定在task0所在机器执行Tensor操作运算 with tf.device("/job:ps/task:0"):  weights_1 = tf.Variable(…)  biases_1 = tf.Variable(…)

分布式训练代码框架。创建TensorFlow服务器集群,在该集群分布式计算数据流图。 。

import argparseimport sysimport tensorflow as tfFLAGS = Nonedef main(_):  # 第1步:命令行参数解析,获取集群信息ps_hosts、worker_hosts  # 当前节点角色信息job_name、task_index  ps_hosts = FLAGS.ps_hosts.split(",")  worker_hosts = FLAGS.worker_hosts.split(",")  # 第2步:创建当前任务节点服务器  # Create a cluster from the parameter server and worker hosts.  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})  # Create and start a server for the local task.  server = tf.train.Server(cluster,                           job_name=FLAGS.job_name,                           task_index=FLAGS.task_index)  # 第3步:如果当前节点是参数服务器,调用server.join()无休止等待;如果是工作节点,执行第4步  if FLAGS.job_name == "ps":    server.join()  # 第4步:构建要训练模型,构建计算图  elif FLAGS.job_name == "worker":    # Assigns ops to the local worker by default.    with tf.device(tf.train.replica_device_setter(        worker_device="/job:worker/task:%d" % FLAGS.task_index,        cluster=cluster)):      # Build model...      loss = ...      global_step = tf.contrib.framework.get_or_create_global_step()      train_op = tf.train.AdagradOptimizer(0.01).minimize(          loss, global_step=global_step)    # The StopAtStepHook handles stopping after running given steps.    # 第5步管理模型训练过程    hooks=[tf.train.StopAtStepHook(last_step=1000000)]    # The MonitoredTrainingSession takes care of session initialization,    # restoring from a checkpoint, saving to a checkpoint, and closing when done    # or an error occurs.    with tf.train.MonitoredTrainingSession(master=server.target,                                           is_chief=(FLAGS.task_index == 0),                                           checkpoint_dir="/tmp/train_logs",                                           hooks=hooks) as mon_sess:      while not mon_sess.should_stop():        # Run a training step asynchronously.        # See `tf.train.SyncReplicasOptimizer` for additional details on how to        # perform *synchronous* training.        # mon_sess.run handles AbortedError in case of preempted PS.        # 训练模型        mon_sess.run(train_op)if __name__ == "__main__":  parser = argparse.ArgumentParser()  parser.register("type", "bool", lambda v: v.lower() == "true")  # Flags for defining the tf.train.ClusterSpec  parser.add_argument(      "--ps_hosts",      type=str,      default="",      help="Comma-separated list of hostname:port pairs"  )  parser.add_argument(      "--worker_hosts",      type=str,      default="",      help="Comma-separated list of hostname:port pairs"  )  parser.add_argument(      "--job_name",      type=str,      default="",      help="One of 'ps', 'worker'"  )  # Flags for defining the tf.train.Server  parser.add_argument(      "--task_index",      type=int,      default=0,      help="Index of task within the job"  )  FLAGS, unparsed = parser.parse_known_args()  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

分布式最佳实践。 。 MNIST数据集分布式训练。开设3个端口作分布式工作节点部署,2222端口参数服务器,2223端口工作节点0,2224端口工作节点1。参数服务器执行参数更新任务,工作节点0、工作节点1执行图模型训练计算任务。参数服务器/job:ps/task:0 cocalhost:2222,工作节点/job:worker/task:0 cocalhost:2223,工作节点/job:worker/task:1 cocalhost:2224。 运行代码。

python mnist_replica.py --job_name="ps" --task_index=0python mnist_replica.py --job_name="worker" --task_index=0python mnist_replica.py --job_name="worker" --task_index=1from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport mathimport sysimport tempfileimport timeimport tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data# 定义常量,用于创建数据流图flags = tf.app.flagsflags.DEFINE_string("data_dir", "/tmp/mnist-data",                    "Directory for storing mnist data")# 只下载数据,不做其他操作flags.DEFINE_boolean("download_only", False,                     "Only perform downloading of data; Do not proceed to "                     "session preparation, model definition or training")# task_index从0开始。0代表用来初始化变量的第一个任务flags.DEFINE_integer("task_index", None,                     "Worker task index, should be >= 0. task_index=0 is "                     "the master worker task the performs the variable "                     "initialization ")# 每台机器GPU个数,机器没有GPU为0flags.DEFINE_integer("num_gpus", 1,                     "Total number of gpus for each machine."                     "If you don't use GPU, please set it to '0'")# 同步训练模型下,设置收集工作节点数量。默认工作节点总数flags.DEFINE_integer("replicas_to_aggregate", None,                     "Number of replicas to aggregate before parameter update"                     "is applied (For sync_replicas mode only; default: "                     "num_workers)")flags.DEFINE_integer("hidden_units", 100,                     "Number of units in the hidden layer of the NN")# 训练次数flags.DEFINE_integer("train_steps", 200,                     "Number of (global) training steps to perform")flags.DEFINE_integer("batch_size", 100, "Training batch size")flags.DEFINE_float("learning_rate", 0.01, "Learning rate")# 使用同步训练、异步训练flags.DEFINE_boolean("sync_replicas", False,                     "Use the sync_replicas (synchronized replicas) mode, "                     "wherein the parameter updates from workers are aggregated "                     "before applied to avoid stale gradients")# 如果服务器已经存在,采用gRPC协议通信;如果不存在,采用进程间通信flags.DEFINE_boolean(    "existing_servers", False, "Whether servers already exists. If True, "    "will use the worker hosts via their GRPC URLs (one client process "    "per worker host). Otherwise, will create an in-process TensorFlow "    "server.")# 参数服务器主机flags.DEFINE_string("ps_hosts","localhost:2222",                    "Comma-separated list of hostname:port pairs")# 工作节点主机flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",                    "Comma-separated list of hostname:port pairs")# 本作业是工作节点还是参数服务器flags.DEFINE_string("job_name", None,"job name: worker or ps")FLAGS = flags.FLAGSIMAGE_PIXELS = 28def main(unused_argv):  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)  if FLAGS.download_only:    sys.exit(0)  if FLAGS.job_name is None or FLAGS.job_name == "":    raise ValueError("Must specify an explicit `job_name`")  if FLAGS.task_index is None or FLAGS.task_index =="":    raise ValueError("Must specify an explicit `task_index`")  print("job name = %s" % FLAGS.job_name)  print("task index = %d" % FLAGS.task_index)  #Construct the cluster and start the server  # 读取集群描述信息  ps_spec = FLAGS.ps_hosts.split(",")  worker_spec = FLAGS.worker_hosts.split(",")  # Get the number of workers.  num_workers = len(worker_spec)  # 创建TensorFlow集群描述对象  cluster = tf.train.ClusterSpec({      "ps": ps_spec,      "worker": worker_spec})  # 为本地执行任务创建TensorFlow Server对象。  if not FLAGS.existing_servers:    # Not using existing servers. Create an in-process server.    # 创建本地Sever对象,从tf.train.Server这个定义开始,每个节点开始不同    # 根据执行的命令的参数(作业名字)不同,决定这个任务是哪个任务    # 如果作业名字是ps,进程就加入这里,作为参数更新的服务,等待其他工作节点给它提交参数更新的数据    # 如果作业名字是worker,就执行后面的计算任务    server = tf.train.Server(        cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)    # 如果是参数服务器,直接启动即可。这里,进程就会阻塞在这里    # 下面的tf.train.replica_device_setter代码会将参数批定给ps_server保管    if FLAGS.job_name == "ps":      server.join()  # 处理工作节点  # 找出worker的主节点,即task_index为0的点  is_chief = (FLAGS.task_index == 0)  # 如果使用gpu  if FLAGS.num_gpus > 0:    # Avoid gpu allocation conflict: now allocate task_num -> #gpu    # for each worker in the corresponding machine    gpu = (FLAGS.task_index % FLAGS.num_gpus)    # 分配worker到指定gpu上运行    worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)  # 如果使用cpu  elif FLAGS.num_gpus == 0:    # Just allocate the CPU to worker server    # 把cpu分配给worker    cpu = 0    worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)  # The device setter will automatically place Variables ops on separate  # parameter servers (ps). The non-Variable ops will be placed on the workers.  # The ps use CPU and workers use corresponding GPU  # 用tf.train.replica_device_setter将涉及变量操作分配到参数服务器上,使用CPU。将涉及非变量操作分配到工作节点上,使用上一步worker_device值。  # 在这个with语句之下定义的参数,会自动分配到参数服务器上去定义。如果有多个参数服务器,就轮流循环分配  with tf.device(      tf.train.replica_device_setter(          worker_device=worker_device,          ps_device="/job:ps/cpu:0",          cluster=cluster)):    # 定义全局步长,默认值为0    global_step = tf.Variable(0, name="global_step", trainable=False)    # Variables of the hidden layer    # 定义隐藏层参数变量,这里是全连接神经网络隐藏层    hid_w = tf.Variable(        tf.truncated_normal(            [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],            stddev=1.0 / IMAGE_PIXELS),        name="hid_w")    hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")    # Variables of the softmax layer    # 定义Softmax 回归层参数变量    sm_w = tf.Variable(        tf.truncated_normal(            [FLAGS.hidden_units, 10],            stddev=1.0 / math.sqrt(FLAGS.hidden_units)),        name="sm_w")    sm_b = tf.Variable(tf.zeros([10]), name="sm_b")    # Ops: located on the worker specified with FLAGS.task_index    # 定义模型输入数据变量    x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])    y_ = tf.placeholder(tf.float32, [None, 10])    # 构建隐藏层    hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)    hid = tf.nn.relu(hid_lin)    # 构建损失函数和优化器    y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))    cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))    # 异步训练模式:自己计算完成梯度就去更新参数,不同副本之间不会去协调进度    opt = tf.train.AdamOptimizer(FLAGS.learning_rate)    # 同步训练模式    if FLAGS.sync_replicas:      if FLAGS.replicas_to_aggregate is None:        replicas_to_aggregate = num_workers      else:        replicas_to_aggregate = FLAGS.replicas_to_aggregate      # 使用SyncReplicasOptimizer作优化器,并且是在图间复制情况下      # 在图内复制情况下将所有梯度平均      opt = tf.train.SyncReplicasOptimizer(          opt,          replicas_to_aggregate=replicas_to_aggregate,          total_num_replicas=num_workers,          name="mnist_sync_replicas")    train_step = opt.minimize(cross_entropy, global_step=global_step)    if FLAGS.sync_replicas:      local_init_op = opt.local_step_init_op      if is_chief:        # 所有进行计算工作节点里一个主工作节点(chief)        # 主节点负责初始化参数、模型保存、概要保存        local_init_op = opt.chief_init_op      ready_for_local_init_op = opt.ready_for_local_init_op      # Initial token and chief queue runners required by the sync_replicas mode      # 同步训练模式所需初始令牌、主队列      chief_queue_runner = opt.get_chief_queue_runner()      sync_init_op = opt.get_init_tokens_op()    init_op = tf.global_variables_initializer()    train_dir = tempfile.mkdtemp()    if FLAGS.sync_replicas:      # 创建一个监管程序,用于统计训练模型过程中的信息      # lodger 是保存和加载模型路径      # 启动就会去这个logdir目录看是否有检查点文件,有的话就自动加载      # 没有就用init_op指定初始化参数      # 主工作节点(chief)负责模型参数初始化工作      # 过程中,其他工作节点等待主节眯完成初始化工作,初始化完成后,一起开始训练数据      # global_step值是所有计算节点共享的      # 在执行损失函数最小值时自动加1,通过global_step知道所有计算节点一共计算多少步      sv = tf.train.Supervisor(          is_chief=is_chief,          logdir=train_dir,          init_op=init_op,          local_init_op=local_init_op,          ready_for_local_init_op=ready_for_local_init_op,          recovery_wait_secs=1,          global_step=global_step)    else:      sv = tf.train.Supervisor(          is_chief=is_chief,          logdir=train_dir,          init_op=init_op,          recovery_wait_secs=1,          global_step=global_step)    # 创建会话,设置属性allow_soft_placement为True    # 所有操作默认使用被指定设置,如GPU    # 如果该操作函数没有GPU实现,自动使用CPU设备    sess_config = tf.ConfigProto(        allow_soft_placement=True,        log_device_placement=False,        device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])    # The chief worker (task_index==0) session will prepare the session,    # while the remaining workers will wait for the preparation to complete.    # 主工作节点(chief),task_index为0节点初始化会话    # 其余工作节点等待会话被初始化后进行计算    if is_chief:      print("Worker %d: Initializing session..." % FLAGS.task_index)    else:      print("Worker %d: Waiting for session to be initialized..." %            FLAGS.task_index)    if FLAGS.existing_servers:      server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]      print("Using existing server at: %s" % server_grpc_url)      # 创建TensorFlow会话对象,用于执行TensorFlow图计算      # prepare_or_wait_for_session需要参数初始化完成且主节点准备好后,才开始训练      sess = sv.prepare_or_wait_for_session(server_grpc_url,                                            config=sess_config)    else:      sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)    print("Worker %d: Session initialization complete." % FLAGS.task_index)    if FLAGS.sync_replicas and is_chief:      # Chief worker will start the chief queue runner and call the init op.      sess.run(sync_init_op)      sv.start_queue_runners(sess, [chief_queue_runner])    # Perform training    # 执行分布式模型训练    time_begin = time.time()    print("Training begins @ %f" % time_begin)    local_step = 0    while True:      # Training feed      # 读入MNIST训练数据,默认每批次100张图片      batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)      train_feed = {x: batch_xs, y_: batch_ys}      _, step = sess.run([train_step, global_step], feed_dict=train_feed)      local_step += 1      now = time.time()      print("%f: Worker %d: training step %d done (global step: %d)" %            (now, FLAGS.task_index, local_step, step))      if step >= FLAGS.train_steps:        break    time_end = time.time()    print("Training ends @ %f" % time_end)    training_time = time_end - time_begin    print("Training elapsed time: %f s" % training_time)    # Validation feed    # 读入MNIST验证数据,计算验证的交叉熵    val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}    val_xent = sess.run(cross_entropy, feed_dict=val_feed)    print("After %d training step(s), validation cross entropy = %g" %          (FLAGS.train_steps, val_xent))if __name__ == "__main__":  tf.app.run()

参考资料: 《TensorFlow技术解析与实战》

欢迎推荐上海机器学习工作机会,我的微信:qingxingfengzi

转载于:https://my.oschina.net/u/3482787/blog/1570956

你可能感兴趣的文章
如何成为一个成功的软件工程师
查看>>
基于 Arduino 开发板,这款插座是可编程且开源的
查看>>
CM 之父被踢出局:与乔布斯经历了同样的悲惨境遇
查看>>
《淘宝店铺 大数据营销+SEO+爆款打造 一册通》一一2.7 关注单品分析,打造店铺爆款...
查看>>
《I'm a Mac:雄狮训练手册》——1.11 自动开机
查看>>
《机器学习系统设计:Python语言实现》一1.3 总结
查看>>
(九)万事俱备
查看>>
在 Linux 上配置一个 syslog 服务器
查看>>
如何利用“图计算”实现大规模实时预测分析
查看>>
一次Eclipse插件修改经历
查看>>
《Redis实战》一2.3 网页缓存
查看>>
《JavaScript核心概念及实践》——2.2 变量
查看>>
《JavaScript开发框架权威指南》——2.3 将Grunt添加到项目中
查看>>
《统计会犯错——如何避免数据分析中的统计陷阱》—第2章膨胀的真理
查看>>
15 个‘ls’命令的面试问题(一)
查看>>
暗渡陈仓:用低消耗设备进行破解和渗透测试3.4 创建一个microSD卡
查看>>
如何使用LibreOffice把DOCX,DOC,RTF,ODT转换成PDF
查看>>
《Python高手之路》——2.3 外部库
查看>>
C#实现栈和队列
查看>>
揭秘Oracle数据库truncate原理
查看>>