分布式TensorFlow由高性能gRPC库底层技术支持。Martin Abadi、Ashish Agarwal、Paul Barham论文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》。
分布式原理。分布式集群 由多个服务器进程、客户端进程组成。部署方式,单机多卡、分布式(多机多卡)。多机多卡TensorFlow分布式。
单机多卡,单台服务器多块GPU。训练过程:在单机单GPU训练,数据一个批次(batch)一个批次训练。单机多GPU,一次处理多个批次数据,每个GPU处理一个批次数据计算。变量参数保存在CPU,数据由CPU分发给多个GPU,GPU计算每个批次更新梯度。CPU收集完多个GPU更新梯度,计算平均梯度,更新参数。继续计算更新梯度。处理速度取决最慢GPU速度。
分布式,训练在多个工作节点(worker)。工作节点,实现计算单元。计算服务器单卡,指服务器。计算服务器多卡,多个GPU划分多个工作节点。数据量大,超过一台机器处理能力,须用分布式。
分布式TensorFlow底层通信,gRPC(google remote procedure call)。gRPC,谷歌开源高性能、跨语言RPC框架。RPC协议,远程过程调用协议,网络从远程计算机程度请求服务。
分布式部署方式。分布式运行,多个计算单元(工作节点),后端服务器部署单工作节点、多工作节点。
单工作节点部署。每台服务器运行一个工作节点,服务器多个GPU,一个工作节点可以访问多块GPU卡。代码tf.device()指定运行操作设备。优势,单机多GPU间通信,效率高。劣势,手动代码指定设备。
多工作节点部署。一台服务器运行多个工作节点。
设置CUDA_VISIBLE_DEVICES环境变量,限制各个工作节点只可见一个GPU,启动进程添加环境变量。用tf.device()指定特定GPU。多工作节点部署优势,代码简单,提高GPU使用率。劣势,工作节点通信,需部署多个工作节点。 。
CUDA_VISIBLE_DEVICES='' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=0CUDA_VISIBLE_DEVICES='' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=ps --task_index=1CUDA_VISIBLE_DEVICES='0' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=0CUDA_VISIBLE_DEVICES='1' python ./distributed_supervisor.py --ps_hosts=127.0.0.1:2222,127.0.0.1:2223 --worker_hosts=127.0.0.1:2224,127.0.0.1:2225 --job_name=worker --task_index=1
分布式架构。 。客户端(client)、服务端(server),服务端包括主节点(master)、工作节点(worker)组成。
客户端、主节点、工作节点关系。TensorFlow,客户端会话联系主节点,实际工作由工作节点实现,每个工作节点占一台设备(TensorFlow具体计算硬件抽象,CPU或GPU)。单机模式,客户端、主节点、工作节点在同一台服务器。分布模式,可不同服务器。客户端->主节点->工作节点/job:worker/task:0->/job:ps/task:0。 客户端。建立TensorFlow计算图,建立与集群交互会话层。代码包含Session()。一个客户端可同时与多个服务端相连,一具服务端也可与多个客户端相连。 服务端。运行tf.train.Server实例进程,TensroFlow执行任务集群(cluster)一部分。有主节点服务(Master service)和工作节点服务(Worker service)。运行中,一个主节点进程和数个工作节点进程,主节点进程和工作接点进程通过接口通信。单机多卡和分布式结构相同,只需要更改通信接口实现切换。 主节点服务。实现tensorflow::Session接口。通过RPC服务程序连接工作节点,与工作节点服务进程工作任务通信。TensorFlow服务端,task_index为0作业(job)。 工作节点服务。实现worker_service.proto接口,本地设备计算部分图。TensorFlow服务端,所有工作节点包含工作节点服务逻辑。每个工作节点负责管理一个或多个设备。工作节点可以是本地不同端口不同进程,或多台服务多个进程。运行TensorFlow分布式执行任务集,一个或多个作业(job)。每个作业,一个或多个相同目的任务(task)。每个任务,一个工作进程执行。作业是任务集合,集群是作业集合。 分布式机器学习框架,作业分参数作业(parameter job)和工作节点作业(worker job)。参数作业运行服务器为参数服务器(parameter server,PS),管理参数存储、更新。工作节点作业,管理无状态主要从事计算任务。模型越大,参数越多,模型参数更新超过一台机器性能,需要把参数分开到不同机器存储更新。参数服务,多台机器组成集群,类似分布式存储架构,涉及数据同步、一致性,参数存储为键值对(key-value)。分布式键值内存数据库,加参数更新操作。李沐《Parameter Server for Distributed Machine Learning》 。 参数存储更新在参数作业进行,模型计算在工作节点作业进行。TensorFlow分布式实现作业间数据传输,参数作业到工作节点作业前向传播,工作节点作业到参数作业反向传播。 任务。特定TensorFlow服务器独立进程,在作业中拥有对应序号。一个任务对应一个工作节点。集群->作业->任务->工作节点。
客户端、主节点、工作节点交互过程。单机多卡交互,客户端->会话运行->主节点->执行子图->工作节点->GPU0、GPU1。分布式交互,客户端->会话运行->主节点进程->执行子图1->工作节点进程1->GPU0、GPU1。《TensorFlow:Large-Scale Machine Learning on Heterogeneous distributed Systems》 。
分布式模式。
数据并行。 。CPU负责梯度平均、参数更新,不同GPU训练模型副本(model replica)。基于训练样例子集训练,模型有独立性。 步骤:不同GPU分别定义模型网络结构。单个GPU从数据管道读取不同数据块,前向传播,计算损失,计算当前变量梯度。所有GPU输出梯度数据转移到CPU,梯度求平均操作,模型变量更新。重复,直到模型变量收敛。 数据并行,提高SGD效率。SGD mini-batch样本,切成多份,模型复制多份,在多个模型上同时计算。多个模型计算速度不一致,CPU更新变量有同步、异步两个方案。
同步更新、异步更新。分布式随机梯度下降法,模型参数分布式存储在不同参数服务上,工作节点并行训练数据,和参数服务器通信获取模型参数。 同步随机梯度下降法(Sync-SGD,同步更新、同步训练),训练时,每个节点上工作任务读入共享参数,执行并行梯度计算,同步需要等待所有工作节点把局部梯度处好,将所有共享参数合并、累加,再一次性更新到模型参数,下一批次,所有工作节点用模型更新后参数训练。优势,每个训练批次考虑所有工作节点训练情部,损失下降稳定。劣势,性能瓶颈在最慢工作节点。异楹设备,工作节点性能不同,劣势明显。 异步随机梯度下降法(Async-SGD,异步更新、异步训练),每个工作节点任务独立计算局部梯度,异步更新到模型参数,不需执行协调、等待操作。优势,性能不存在瓶颈。劣势,每个工作节点计算梯度值发磅回参数服务器有参数更新冲突,影响算法收剑速度,损失下降过程抖动较大。 同步更新、异步更新实现区别于更新参数服务器参数策略。数据量小,各节点计算能力较均衡,用同步模型。数据量大,各机器计算性能参差不齐,用异步模式。 带备份的Sync-SGD(Sync-SDG with backup)。Jianmin Chen、Xinghao Pan、Rajat Monga、Aamy Bengio、Rafal Jozefowicz论文《Revisiting Distributed Synchronous SGD》 。增加工作节点,解决部分工作节点计算慢问题。工作节点总数n+n*5%,n为集群工作节点数。异步更新设定接受到n个工作节点参数直接更新参数服务器模型参数,进入下一批次模型训练。计算较慢节点训练参数直接丢弃。 同步更新、异步更新有图内模式(in-graph pattern)和图间模式(between-graph pattern),独立于图内(in-graph)、图间(between-graph)概念。 图内复制(in-grasph replication),所有操作(operation)在同一个图中,用一个客户端来生成图,把所有操作分配到集群所有参数服务器和工作节点上。国内复制和单机多卡类似,扩展到多机多卡,数据分发还是在客户端一个节点上。优势,计算节点只需要调用join()函数等待任务,客户端随时提交数据就可以训练。劣势,训练数据分发在一个节点上,要分发给不同工作节点,严重影响并发训练速度。 图间复制(between-graph replication),每一个工作节点创建一个图,训练参数保存在参数服务器,数据不分发,各个工作节点独立计算,计算完成把要更新参数告诉参数服务器,参数服务器更新参数。优势,不需要数据分发,各个工作节点都创建图和读取数据训练。劣势,工作节点既是图创建者又是计算任务执行者,某个工作节点宕机影响集群工作。大数据相关深度学习推荐使用图间模式。
模型并行。切分模型,模型不同部分执行在不同设备上,一个批次样本可以在不同设备同时执行。TensorFlow尽量让相邻计算在同一台设备上完成节省网络开销。Martin Abadi、Ashish Agarwal、Paul Barham论文《TensorFlow:Large-Scale Machine Learning on Heterogeneous Distributed Systems》 。
模型并行、数据并行,TensorFlow中,计算可以分离,参数可以分离。可以在每个设备上分配计算节点,让对应参数也在该设备上,计算参数放一起。
分布式API。 。 创建集群,每个任务(task)启动一个服务(工作节点服务或主节点服务)。任务可以分布不同机器,可以同一台机器启动多个任务,用不同GPU运行。每个任务完成工作:创建一个tf.train.ClusterSpec,对集群所有任务进行描述,描述内容对所有任务相同。创建一个tf.train.Server,创建一个服务,运行相应作业计算任务。 TensorFlow分布式开发API。tf.train.ClusterSpec({"ps":ps_hosts,"worker":worke_hosts})。创建TensorFlow集群描述信息,ps、worker为作业名称,ps_phsts、worker_hosts为作业任务所在节点地址信息。tf.train.ClusterSpec传入参数,作业和任务间关系映射,映射关系任务通过IP地址、端口号表示。
结构 tf.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})可用任务 /job:local/task:0、/job:local/task:1。结构 tf.train.ClusterSpec({"worker":["worker0.example.com:2222","worker1.example.com:2222","worker2.example.com:2222"],"ps":["ps0.example.com:2222","ps1.example.com:2222"]})可用任务 /job:worker/task:0、 /job:worker/task:1、 /job:worker/task:2、 /job:ps/task:0、 /job:ps/task:1
tf.train.Server(cluster,job_name,task_index)。创建服务(主节点服务或工作节点服务),运行作业计算任务,运行任务在task_index指定机器启动。
#任务0 cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})server = tr.train.Server(cluster,job_name="local",task_index=0) #任务1 cluster = tr.train.ClusterSpec({"local":["localhost:2222","localhost:2223"]})server = tr.train.Server(cluster,job_name="local",task_index=1)。
自动化管理节点、监控节点工具。集群管理工具Kubernetes。 tf.device(device_name_or_function)。设定指定设备执行张量运算,批定代码运行CPU、GPU。
#指定在task0所在机器执行Tensor操作运算 with tf.device("/job:ps/task:0"): weights_1 = tf.Variable(…) biases_1 = tf.Variable(…)
分布式训练代码框架。创建TensorFlow服务器集群,在该集群分布式计算数据流图。 。
import argparseimport sysimport tensorflow as tfFLAGS = Nonedef main(_): # 第1步:命令行参数解析,获取集群信息ps_hosts、worker_hosts # 当前节点角色信息job_name、task_index ps_hosts = FLAGS.ps_hosts.split(",") worker_hosts = FLAGS.worker_hosts.split(",") # 第2步:创建当前任务节点服务器 # Create a cluster from the parameter server and worker hosts. cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts}) # Create and start a server for the local task. server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) # 第3步:如果当前节点是参数服务器,调用server.join()无休止等待;如果是工作节点,执行第4步 if FLAGS.job_name == "ps": server.join() # 第4步:构建要训练模型,构建计算图 elif FLAGS.job_name == "worker": # Assigns ops to the local worker by default. with tf.device(tf.train.replica_device_setter( worker_device="/job:worker/task:%d" % FLAGS.task_index, cluster=cluster)): # Build model... loss = ... global_step = tf.contrib.framework.get_or_create_global_step() train_op = tf.train.AdagradOptimizer(0.01).minimize( loss, global_step=global_step) # The StopAtStepHook handles stopping after running given steps. # 第5步管理模型训练过程 hooks=[tf.train.StopAtStepHook(last_step=1000000)] # The MonitoredTrainingSession takes care of session initialization, # restoring from a checkpoint, saving to a checkpoint, and closing when done # or an error occurs. with tf.train.MonitoredTrainingSession(master=server.target, is_chief=(FLAGS.task_index == 0), checkpoint_dir="/tmp/train_logs", hooks=hooks) as mon_sess: while not mon_sess.should_stop(): # Run a training step asynchronously. # See `tf.train.SyncReplicasOptimizer` for additional details on how to # perform *synchronous* training. # mon_sess.run handles AbortedError in case of preempted PS. # 训练模型 mon_sess.run(train_op)if __name__ == "__main__": parser = argparse.ArgumentParser() parser.register("type", "bool", lambda v: v.lower() == "true") # Flags for defining the tf.train.ClusterSpec parser.add_argument( "--ps_hosts", type=str, default="", help="Comma-separated list of hostname:port pairs" ) parser.add_argument( "--worker_hosts", type=str, default="", help="Comma-separated list of hostname:port pairs" ) parser.add_argument( "--job_name", type=str, default="", help="One of 'ps', 'worker'" ) # Flags for defining the tf.train.Server parser.add_argument( "--task_index", type=int, default=0, help="Index of task within the job" ) FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
分布式最佳实践。 。 MNIST数据集分布式训练。开设3个端口作分布式工作节点部署,2222端口参数服务器,2223端口工作节点0,2224端口工作节点1。参数服务器执行参数更新任务,工作节点0、工作节点1执行图模型训练计算任务。参数服务器/job:ps/task:0 cocalhost:2222,工作节点/job:worker/task:0 cocalhost:2223,工作节点/job:worker/task:1 cocalhost:2224。 运行代码。
python mnist_replica.py --job_name="ps" --task_index=0python mnist_replica.py --job_name="worker" --task_index=0python mnist_replica.py --job_name="worker" --task_index=1from __future__ import absolute_importfrom __future__ import divisionfrom __future__ import print_functionimport mathimport sysimport tempfileimport timeimport tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data# 定义常量,用于创建数据流图flags = tf.app.flagsflags.DEFINE_string("data_dir", "/tmp/mnist-data", "Directory for storing mnist data")# 只下载数据,不做其他操作flags.DEFINE_boolean("download_only", False, "Only perform downloading of data; Do not proceed to " "session preparation, model definition or training")# task_index从0开始。0代表用来初始化变量的第一个任务flags.DEFINE_integer("task_index", None, "Worker task index, should be >= 0. task_index=0 is " "the master worker task the performs the variable " "initialization ")# 每台机器GPU个数,机器没有GPU为0flags.DEFINE_integer("num_gpus", 1, "Total number of gpus for each machine." "If you don't use GPU, please set it to '0'")# 同步训练模型下,设置收集工作节点数量。默认工作节点总数flags.DEFINE_integer("replicas_to_aggregate", None, "Number of replicas to aggregate before parameter update" "is applied (For sync_replicas mode only; default: " "num_workers)")flags.DEFINE_integer("hidden_units", 100, "Number of units in the hidden layer of the NN")# 训练次数flags.DEFINE_integer("train_steps", 200, "Number of (global) training steps to perform")flags.DEFINE_integer("batch_size", 100, "Training batch size")flags.DEFINE_float("learning_rate", 0.01, "Learning rate")# 使用同步训练、异步训练flags.DEFINE_boolean("sync_replicas", False, "Use the sync_replicas (synchronized replicas) mode, " "wherein the parameter updates from workers are aggregated " "before applied to avoid stale gradients")# 如果服务器已经存在,采用gRPC协议通信;如果不存在,采用进程间通信flags.DEFINE_boolean( "existing_servers", False, "Whether servers already exists. If True, " "will use the worker hosts via their GRPC URLs (one client process " "per worker host). Otherwise, will create an in-process TensorFlow " "server.")# 参数服务器主机flags.DEFINE_string("ps_hosts","localhost:2222", "Comma-separated list of hostname:port pairs")# 工作节点主机flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", "Comma-separated list of hostname:port pairs")# 本作业是工作节点还是参数服务器flags.DEFINE_string("job_name", None,"job name: worker or ps")FLAGS = flags.FLAGSIMAGE_PIXELS = 28def main(unused_argv): mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) if FLAGS.download_only: sys.exit(0) if FLAGS.job_name is None or FLAGS.job_name == "": raise ValueError("Must specify an explicit `job_name`") if FLAGS.task_index is None or FLAGS.task_index =="": raise ValueError("Must specify an explicit `task_index`") print("job name = %s" % FLAGS.job_name) print("task index = %d" % FLAGS.task_index) #Construct the cluster and start the server # 读取集群描述信息 ps_spec = FLAGS.ps_hosts.split(",") worker_spec = FLAGS.worker_hosts.split(",") # Get the number of workers. num_workers = len(worker_spec) # 创建TensorFlow集群描述对象 cluster = tf.train.ClusterSpec({ "ps": ps_spec, "worker": worker_spec}) # 为本地执行任务创建TensorFlow Server对象。 if not FLAGS.existing_servers: # Not using existing servers. Create an in-process server. # 创建本地Sever对象,从tf.train.Server这个定义开始,每个节点开始不同 # 根据执行的命令的参数(作业名字)不同,决定这个任务是哪个任务 # 如果作业名字是ps,进程就加入这里,作为参数更新的服务,等待其他工作节点给它提交参数更新的数据 # 如果作业名字是worker,就执行后面的计算任务 server = tf.train.Server( cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) # 如果是参数服务器,直接启动即可。这里,进程就会阻塞在这里 # 下面的tf.train.replica_device_setter代码会将参数批定给ps_server保管 if FLAGS.job_name == "ps": server.join() # 处理工作节点 # 找出worker的主节点,即task_index为0的点 is_chief = (FLAGS.task_index == 0) # 如果使用gpu if FLAGS.num_gpus > 0: # Avoid gpu allocation conflict: now allocate task_num -> #gpu # for each worker in the corresponding machine gpu = (FLAGS.task_index % FLAGS.num_gpus) # 分配worker到指定gpu上运行 worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu) # 如果使用cpu elif FLAGS.num_gpus == 0: # Just allocate the CPU to worker server # 把cpu分配给worker cpu = 0 worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu) # The device setter will automatically place Variables ops on separate # parameter servers (ps). The non-Variable ops will be placed on the workers. # The ps use CPU and workers use corresponding GPU # 用tf.train.replica_device_setter将涉及变量操作分配到参数服务器上,使用CPU。将涉及非变量操作分配到工作节点上,使用上一步worker_device值。 # 在这个with语句之下定义的参数,会自动分配到参数服务器上去定义。如果有多个参数服务器,就轮流循环分配 with tf.device( tf.train.replica_device_setter( worker_device=worker_device, ps_device="/job:ps/cpu:0", cluster=cluster)): # 定义全局步长,默认值为0 global_step = tf.Variable(0, name="global_step", trainable=False) # Variables of the hidden layer # 定义隐藏层参数变量,这里是全连接神经网络隐藏层 hid_w = tf.Variable( tf.truncated_normal( [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units], stddev=1.0 / IMAGE_PIXELS), name="hid_w") hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b") # Variables of the softmax layer # 定义Softmax 回归层参数变量 sm_w = tf.Variable( tf.truncated_normal( [FLAGS.hidden_units, 10], stddev=1.0 / math.sqrt(FLAGS.hidden_units)), name="sm_w") sm_b = tf.Variable(tf.zeros([10]), name="sm_b") # Ops: located on the worker specified with FLAGS.task_index # 定义模型输入数据变量 x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS]) y_ = tf.placeholder(tf.float32, [None, 10]) # 构建隐藏层 hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b) hid = tf.nn.relu(hid_lin) # 构建损失函数和优化器 y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b)) cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))) # 异步训练模式:自己计算完成梯度就去更新参数,不同副本之间不会去协调进度 opt = tf.train.AdamOptimizer(FLAGS.learning_rate) # 同步训练模式 if FLAGS.sync_replicas: if FLAGS.replicas_to_aggregate is None: replicas_to_aggregate = num_workers else: replicas_to_aggregate = FLAGS.replicas_to_aggregate # 使用SyncReplicasOptimizer作优化器,并且是在图间复制情况下 # 在图内复制情况下将所有梯度平均 opt = tf.train.SyncReplicasOptimizer( opt, replicas_to_aggregate=replicas_to_aggregate, total_num_replicas=num_workers, name="mnist_sync_replicas") train_step = opt.minimize(cross_entropy, global_step=global_step) if FLAGS.sync_replicas: local_init_op = opt.local_step_init_op if is_chief: # 所有进行计算工作节点里一个主工作节点(chief) # 主节点负责初始化参数、模型保存、概要保存 local_init_op = opt.chief_init_op ready_for_local_init_op = opt.ready_for_local_init_op # Initial token and chief queue runners required by the sync_replicas mode # 同步训练模式所需初始令牌、主队列 chief_queue_runner = opt.get_chief_queue_runner() sync_init_op = opt.get_init_tokens_op() init_op = tf.global_variables_initializer() train_dir = tempfile.mkdtemp() if FLAGS.sync_replicas: # 创建一个监管程序,用于统计训练模型过程中的信息 # lodger 是保存和加载模型路径 # 启动就会去这个logdir目录看是否有检查点文件,有的话就自动加载 # 没有就用init_op指定初始化参数 # 主工作节点(chief)负责模型参数初始化工作 # 过程中,其他工作节点等待主节眯完成初始化工作,初始化完成后,一起开始训练数据 # global_step值是所有计算节点共享的 # 在执行损失函数最小值时自动加1,通过global_step知道所有计算节点一共计算多少步 sv = tf.train.Supervisor( is_chief=is_chief, logdir=train_dir, init_op=init_op, local_init_op=local_init_op, ready_for_local_init_op=ready_for_local_init_op, recovery_wait_secs=1, global_step=global_step) else: sv = tf.train.Supervisor( is_chief=is_chief, logdir=train_dir, init_op=init_op, recovery_wait_secs=1, global_step=global_step) # 创建会话,设置属性allow_soft_placement为True # 所有操作默认使用被指定设置,如GPU # 如果该操作函数没有GPU实现,自动使用CPU设备 sess_config = tf.ConfigProto( allow_soft_placement=True, log_device_placement=False, device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index]) # The chief worker (task_index==0) session will prepare the session, # while the remaining workers will wait for the preparation to complete. # 主工作节点(chief),task_index为0节点初始化会话 # 其余工作节点等待会话被初始化后进行计算 if is_chief: print("Worker %d: Initializing session..." % FLAGS.task_index) else: print("Worker %d: Waiting for session to be initialized..." % FLAGS.task_index) if FLAGS.existing_servers: server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index] print("Using existing server at: %s" % server_grpc_url) # 创建TensorFlow会话对象,用于执行TensorFlow图计算 # prepare_or_wait_for_session需要参数初始化完成且主节点准备好后,才开始训练 sess = sv.prepare_or_wait_for_session(server_grpc_url, config=sess_config) else: sess = sv.prepare_or_wait_for_session(server.target, config=sess_config) print("Worker %d: Session initialization complete." % FLAGS.task_index) if FLAGS.sync_replicas and is_chief: # Chief worker will start the chief queue runner and call the init op. sess.run(sync_init_op) sv.start_queue_runners(sess, [chief_queue_runner]) # Perform training # 执行分布式模型训练 time_begin = time.time() print("Training begins @ %f" % time_begin) local_step = 0 while True: # Training feed # 读入MNIST训练数据,默认每批次100张图片 batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size) train_feed = {x: batch_xs, y_: batch_ys} _, step = sess.run([train_step, global_step], feed_dict=train_feed) local_step += 1 now = time.time() print("%f: Worker %d: training step %d done (global step: %d)" % (now, FLAGS.task_index, local_step, step)) if step >= FLAGS.train_steps: break time_end = time.time() print("Training ends @ %f" % time_end) training_time = time_end - time_begin print("Training elapsed time: %f s" % training_time) # Validation feed # 读入MNIST验证数据,计算验证的交叉熵 val_feed = {x: mnist.validation.images, y_: mnist.validation.labels} val_xent = sess.run(cross_entropy, feed_dict=val_feed) print("After %d training step(s), validation cross entropy = %g" % (FLAGS.train_steps, val_xent))if __name__ == "__main__": tf.app.run()
参考资料: 《TensorFlow技术解析与实战》
欢迎推荐上海机器学习工作机会,我的微信:qingxingfengzi