PyBrain库的example之NFQ流程图分析

May 27, 2016·
夏 伟
夏 伟

有关PyBrain 库中NFQ算法的流程图分析,包括数据处理和策略的优化pipeline.

PyBrain库的example之NFQ流程图分析

如下是测试程序。主要分析doEpisode和learn两个函数。

__author__ = 'Thomas Rueckstiess, ruecksti@in.tum.de'

from pybrain.rl.environments.cartpole import CartPoleEnvironment, DiscreteBalanceTask, CartPoleRenderer
from pybrain.rl.agents import LearningAgent
from pybrain.rl.experiments import EpisodicExperiment
from pybrain.rl.learners.valuebased import NFQ, ActionValueNetwork
#,ActionValueLSTMNetwork
from pybrain.rl.explorers import BoltzmannExplorer

from numpy import array, arange, meshgrid, pi, zeros, mean
from matplotlib import pyplot as plt

# switch this to True if you want to see the cart balancing the pole (slower)
render = False  #True #

plt.ion()

env = CartPoleEnvironment()
if render:
    renderer = CartPoleRenderer()
    env.setRenderer(renderer)
    renderer.start()


# balancetask. py inside only used 2 sensors, so here can not use(4,3), just use (2,3)
# there is a debug in vesion 0.30, now, new version 0.33 had correct it!!
module = ActionValueNetwork(4,3)  #(4,3) #  0.33 had correct it
#module = ActionValueLSTMNetwork(2,3)

task = DiscreteBalanceTask(env, 100)
learner = NFQ()
learner.explorer.epsilon = 0.4

agent = LearningAgent(module, learner)
testagent = LearningAgent(module, None)
experiment = EpisodicExperiment(task, agent)


def plotPerformance(values, fig):
    plt.figure(fig.number)
    plt.clf()
    plt.plot(values, 'o-')
    plt.gcf().canvas.draw()


performance = []

if not render:
    pf_fig = plt.figure()

#while (True):
for _ in xrange(60): #60
    # one learning step after one episode of world-interaction!!!
    experiment.doEpisodes(1)
    agent.learn(2)  # 5

    # test performance (these real-world experiences are not used for training)
    if render:
        env.delay = True
    experiment.agent = testagent
    #r = mean([sum(x) for x in experiment.doEpisodes(5)])
    env.delay = False
    testagent.reset()
    experiment.agent = agent

    #performance.append(r)
    print "update step", len(performance)

    #print "reward avg", r
    print "explorer epsilon", learner.explorer.epsilon
    print "num episodes", agent.history.getNumSequences()
    print "update step", len(performance)

if not render:
    plotPerformance(performance, pf_fig)

str = raw_input("please input sth to end!")
print "you put :",str

experiment.doEpisodes(1)

agent.learn(2)

图2的注释2部分,可以参考该博文深度强化学习初探 ,但是他文中的公式应该有点问题。应该把$Q_{m+1}$改为$Q_m$,进一步参考维基百科Q-learning ,如下所示。

$$ Q_{m+1}(s_t,a_t)=Q_m(s_t,a_t)+α[r_{t+1}+γQ_m(s_{t+1},a_{t+1})−Q_m(s_t,a_t)] $$

推荐所用的画图软件process on

  • 用起来挺方便的,在线用谷歌浏览器运行,用户体验挺佳,比visio2010快多了;
  • 可以多用户协作;
  • 目前有一个缺点就是一个框里面的字体格式必须是一样的,不可以修改一个框里面部分的文字的格式。有点类似PS的思想。