0%

Python

1.4节中,我们介绍了几个数据库的安装方式,但这仅仅是用来存储数据的数据库,它们提供了存储服务,但如果想要和Python交互的话,还需要安装一些Python存储库,如MySQL需要安装PyMySQL,MongoDB需要安装PyMongo等。本节中,我们来说明一下这些存储库的安装方式。

Python

Redis是一个基于内存的高效的非关系型数据库,本节中我们来了解一下它在各个平台的安装过程。

1. 相关链接

  • 官方网站:https://redis.io
  • 官方文档:https://redis.io/documentation
  • 中文官网:http://www.redis.cn
  • GitHub:https://github.com/antirez/redis
  • 中文教程:http://www.runoob.com/redis/redis-tutorial.html
  • Redis Desktop Manager:https://redisdesktop.com
  • Redis Desktop Manager GitHub:https://github.com/uglide/RedisDesktopManager

2. Windows下的安装

在Windows下,Redis可以直接到GitHub的发行版本里面下载,具体下载地址是https://github.com/MSOpenTech/redis/releases

打开下载页面后,会发现有许多发行版本及其安装包,如图1-39所示。

图1-39 下载页面

可以下载Redis-x64-3.2.100.msi安装即可。

安装过程比较简单,直接点击Next按钮安装即可。安装完成后,Redis便会启动。

在系统服务页面里,可以观察到多了一个正在运行到Redis服务,如图1-40所示。

图1-40 系统服务页面

另外,推荐下载Redis Desktop Manager可视化管理工具,来管理Redis。这既可以到官方网站(链接为:https://redisdesktop.com/download)下载,也可以到GitHub(链接为:https://github.com/uglide/RedisDesktopManager/releases)下载最新发行版本。

安装后,直接连接本地Redis即可。

3. Linux下的安装

这里依然分为两类平台来介绍。

Ubuntu、Debian和Deepin

在Ubuntu、Debian和Deepin系统下,使用apt-get命令安装Redis:

1
sudo apt-get -y install redis-server

然后输入redis-cli进入Redis命令行模式:

1
2
3
4
5
$ redis-cli
127.0.0.1:6379> set 'name' 'Germey'
OK
127.0.0.1:6379> get 'name'
"Germey"

这样就证明Redis成功安装了,但是现在Redis还是无法远程连接的,依然需要修改配置文件,配置文件的路径为/etc/redis/redis.conf。

首先,注释这一行:

1
bind 127.0.0.1

另外,推荐给Redis设置密码,取消注释这一行:

1
requirepass foobared

foobared即当前密码,可以自行修改。

然后重启Redis服务,使用的命令如下:

1
sudo /etc/init.d/redis-server restart

现在就可以使用密码远程连接Redis了。

另外,停止和启动Redis服务的命令分别如下:

1
2
sudo /etc/init.d/redis-server stop
sudo /etc/init.d/redis-server start

CentOS和Red Hat

在CentOS和Red Hat系统中,首先添加EPEL仓库,然后更新yum源:

1
2
sudo yum install epel-release
sudo yum update

然后安装Redis数据库:

1
sudo yum -y install redis

安装好后启动Redis服务即可:

1
sudo systemctl start redis

这里同样可以使用redis-cli进入Redis命令行模式操作。

另外,为了可以使Redis能被远程连接,需要修改配置文件,路径为/etc/redis.conf。

参见上文来修改配置文件实现远程连接和密码配置。

修改完成之后保存。

然后重启Redis服务即可,命令如下:

1
sudo systemctl restart redis

4. Mac下的安装

这里推荐使用Homebrew安装,直接执行如下命令即可:

1
brew install redis

启动Redis服务的命令如下:

1
2
brew services start redis
redis-server /usr/local/etc/redis.conf

这里同样可以使用redis-cli进入Redis命令行模式。

在Mac下Redis的配置文件路径是/usr/local/etc/redis.conf,可以通过修改它来配置访问密码。

修改配置文件后,需要重启Redis服务。停止和重启Redis服务的命令分别如下:

1
2
brew services stop redis
brew services restart redis

另外,在Mac下也可以安装Redis Desktop Manager可视化管理工具来管理Redis。

Python

更新 2020/3/8

MongoDB 现在已经出到了 4.x 版本,下面的安装教程是基于 3.x 版本,可能已经过期。

关于 4.x 的安装教程,可以参考如下内容:

  • https://juejin.im/post/5d525b1af265da03b31bc2d5
  • https://www.cnblogs.com/TM0831/p/10606624.html
  • https://www.cnblogs.com/georgeleoo/p/11479409.html

以下为原文:

MongoDB是由C++语言编写的非关系型数据库,是一个基于分布式文件存储的开源数据库系统,其内容存储形式类似JSON对象,它的字段值可以包含其他文档、数组及文档数组,非常灵活。

MongoDB支持多种平台,包括Windows、Linux、Mac OS、Solaris等,在其官方网站(https://www.mongodb.com/download-center)均可找到对应的安装包。

本节中,我们来看下它的安装过程。

1. 相关链接

  • 官方网站:https://www.mongodb.com
  • 官方文档:https://docs.mongodb.com
  • GitHub:https://github.com/mongodb
  • 中文教程:http://www.runoob.com/mongodb/mongodb-tutorial.html

2. Windows下的安装

这里直接在官网(如图1-29所示)点击DOWNLOAD按钮下载msi安装包即可。

图1-29 MongoDB官网

下载完成后,双击它开始安装,指定MongoDB的安装路径,例如此处我指定的安装路径为C:\MongoDB\Server\3.4,如图1-30所示。当然,这里也可以自行选择路径。

图1-30 指定安装路径

点击Next按钮执行安装即可。

安装成功之后,进入MongoDB的安装目录,此处是C:\MongoDB\Server\3.4,在bin目录下新建同级目录data,如图1-31所示。

图1-31 新建data目录

然后进入data文件夹,新建子文件夹db来存储数据目录,如图1-32所示。

图1-32 新建db目录

之后打开命令行,进入MongoDB安装目录的bin目录下,运行MongoDB服务:

1
mongod --dbpath "C:\MongoDB\Server\3.4\data\db"

请记得将此处的路径替换成你的主机MongoDB安装路径。

运行之后,会出现一些输出信息,如图1-33所示。

图1-33 运行结果

这样我们就启动MongoDB服务了。

但是如果我们想一直使用MongoDB,就不能关闭此命令行了。如果意外关闭或重启,MongoDB服务就不能使用了。这显然不是我们想要的。所以,接下来还需将MongoDB配置成系统服务。

首先,以管理员模式运行命令行。注意,此处一定要以管理员身份运行,否则可能配置失败,如图1-34所示。

图1-34 以管理员身份运行

在“开始”菜单中搜索cmd,找到命令行,然后右击它以管理员身份运行即可。

随后新建一个日志文件,在bin目录新建logs同级目录,进入之后新建一个mongodb.log文件,用于保存MongoDB的运行日志,如图1-35所示。

图1-35 新建mongodb.log文件

在命令行下输入如下内容:

1
mongod --bind_ip 0.0.0.0 --logpath "C:\MongoDB\Server\3.4\logs\mongodb.log" --logappend --dbpath "C:\MongoDB\Server\3.4\data\db" --port 27017 --serviceName "MongoDB" --serviceDisplayName "MongoDB" --install

这里的意思是绑定IP为0.0.0.0(即任意IP均可访问),指定日志路径、数据库路径和端口,指定服务名称。需要注意的是,这里依然需要把路径替换成你的MongoDB安装路径,运行此命令后即可安装服务,运行结果如图1-36所示。图1-36 运行结果

如果没有出现错误提示,则证明MongoDB服务已经安装成功。

可以在服务管理页面查看到系统服务,如图1-37所示。

图1-37 系统服务页面

然后就可以设置它的开机启动方式了,如自动启动或手动启动等,这样我们就可以非常方便地管理MongoDB服务了。

启动服务后,在命令行下就可以利用mongo命令进入MongoDB命令交互环境了,如图1-38所示。

图1-38 命令行模式

这样,Windows下的MongoDB配置就完成了。

3. Linux下的安装

这里以MongoDB 3.4为例说明MongoDB的安装过程。

Ubuntu

首先,导入MongoDB的GPG key:

1
sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 0C49F3730359A14518585931BC711F9BA15703C6

随后创建apt-get源列表,各个系统版本对应的命令分别如下。

  • Ubuntu 12.04对应的命令如下:

    1
    echo "deb [ arch=amd64 ] http://repo.mongodb.org/apt/ubuntu precise/mongodb-org/3.4 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-3.4.list
  • Ubuntu 14.04对应的命令如下:

    1
    echo "deb [ arch=amd64 ] http://repo.mongodb.org/apt/ubuntu trusty/mongodb-org/3.4 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-3.4.list
  • Ubuntu 16.04对应的命令如下:

    1
    echo "deb [ arch=amd64,arm64 ] http://repo.mongodb.org/apt/ubuntu xenial/mongodb-org/3.4 multiverse" | sudo tee /etc/apt/sources.list.d/mongodb-org-3.4.list

随后更新apt-get源:

1
sudo apt-get update

之后安装MongoDB即可:

1
sudo apt-get install -y mongodb-org

安装完成后运行MongoDB,命令如下:

1
mongod --port 27017 --dbpath /data/db

运行命令之后,MongoDB就在27017端口上运行了,数据文件会保存在/data/db路径下。

一般情况下,我们在Linux上配置MongoDB都是为了远程连接使用的,所以这里还需要配置一下MongoDB的远程连接以及用户名和密码。

接着,进入MongoDB命令行:

1
mongo --port 27017

现在我们就已经进入到MongoDB的命令行交互模式下了,在此模式下运行如下命令:

1
2
3
4
5
6
7
8
9
10
11
12
> use admin
switched to db admin
> db.createUser({user: 'admin', pwd: 'admin123', roles: [{role: 'root', db: 'admin'}]})
Successfully added user: {
"user" : "admin",
"roles" : [
{
"role" : "root",
"db" : "admin"
}
]
}

这样我们就创建了一个用户名为admin,密码为admin123的用户,赋予最高权限。

随后需要修改MongoDB的配置文件,此时执行如下命令:

1
sudo vi /etc/mongod.conf

然后修改net部分为:

1
2
3
net:
port: 27017
bindIp: 0.0.0.0

这样配置后,MongoDB可被远程访问。

另外,还需要添加如下的权限认证配置,此时直接添加如下内容到配置文件即可:

1
2
security:
authorization: enabled

配置完成之后,我们需要重新启动MongoDB服务,命令如下:

1
sudo service mongod restart

这样远程连接和权限认证就配置完成了。

CentOS和Red Hat

首先,添加MongoDB源:

1
sudo vi /etc/yum.repos.d/mongodb-org.repo

接着修改如下内容并保存:

1
2
3
4
5
6
[mongodb-org-3.4]
name=MongoDB Repository
baseurl=https://repo.mongodb.org/yum/redhat/$releasever/mongodb-org/3.4/x86_64/
gpgcheck=1
enabled=1
gpgkey=https://www.mongodb.org/static/pgp/server-3.4.asc

然后执行yum命令安装:

1
sudo yum install mongodb-org

这里启动MongoDB服务的命令如下:

1
sudo systemctl start mongod

停止和重新加载MongoDB服务的命令如下:

1
2
sudo systemctl stop mongod
sudo systemctl reload mongod

有关远程连接和认证配置,可以参考前面,方式是相同的。

更多Linux发行版的MongoDB安装方式可以参考官方文档:https://docs.mongodb.com/manual/administration/install-on-linux/

4. Mac下的安装

这里推荐使用Homebrew安装,直接执行brew命令即可:

1
brew install mongodb

然后创建一个新文件夹/data/db,用于存放MongoDB数据。

这里启动MongoDB服务的命令如下:

1
2
brew services start mongodb
sudo mongod

停止和重启MongoDB服务的命令分别是:

1
2
brew services stop mongodb
brew services restart mongodb

5. 可视化工具

这里推荐一个可视化工具RoboMongo/Robo 3T,它使用简单,功能强大,官方网站为https://robomongo.org/,三大平台都支持,下载链接为https://robomongo.org/download

另外,还有一个简单易用的可视化工具——Studio 3T,它同样具有方便的图形化管理界面,官方网站为https://studio3t.com,同样支持三大平台,下载链接为https://studio3t.com/download/

Python

MySQL是一个轻量级的关系型数据库,本节中我们来了解下它的安装方式。

1. 相关链接

  • 官方网站:https://www.mysql.com/cn
  • 下载地址:https://www.mysql.com/cn/downloads
  • 中文教程:http://www.runoob.com/mysql/mysql-tutorial.html

2. Windows下的安装

对于Windows来说,可以直接在百度软件中心搜索MySQL,下载其提供的MySQL安装包,速度还是比较快的。

当然,最安全稳妥的方式是直接到官网下载安装包进行安装,但是这样做有个缺点,那就是需要登录才可以下载,而且速度不快。

下载完成后,双击安装包即可安装,这里直接选择默认选项,点击Next按钮安装即可。这里需要记住图1-27所设置的密码。

图1-27 设置密码页面

安装完成后,我们可以在“计算机”→“管理”→“服务”页面开启和关闭MySQL服务,如图1-28所示。

图1-28 系统服务页面

如果启动了MySQL服务,就可以使用它来存储数据了。

3. Linux下的安装

下面我们仍然分平台来介绍。

Ubuntu、Debian和Deepin

在Ubuntu、Debian和Deepin系统中,我们直接使用apt-get命令即可安装MySQL:

1
2
sudo apt-get update
sudo apt-get install -y mysql-server mysql-client

在安装过程中,会提示输入用户名和密码,输入后等待片刻即可完成安装。

启动、关闭和重启MySQL服务的命令如下:

1
2
3
sudo service mysql start
sudo service mysql stop
sudo service mysql restart

CentOS和Red Hat

这里以MySQL 5.6的Yum源为例来说明(如果需要更高版本,可以另寻),安装命令如下:

1
2
3
wget http://repo.mysql.com/mysql-community-release-el7-5.noarch.rpm
sudo rpm -ivh mysql-community-release-el7-5.noarch.rpm
yum install -y mysql mysql-server

运行如上命令即可完成安装,初始密码为空。接下来,需要启动MySQL服务。

启动MySQL服务的命令如下:

1
sudo systemctl start mysqld

停止、重启MySQL服务的命令如下:

1
2
sudo systemctl stop mysqld
sudo systemctl restart mysqld

上面我们完成了Linux下MySQL的安装,之后可以修改密码,此时可以执行如下命令:

1
mysql -uroot -p

输入密码后,进入MySQL命令行模式,接着输入如下命令:

1
2
3
use mysql;
UPDATE user SET Password = PASSWORD('newpass') WHERE user = 'root';
FLUSH PRIVILEGES;

其中newpass为修改的新的MySQL密码,请自行替换。

由于Linux一般会作为服务器使用,为了使MySQL可以被远程访问,我们需要修改MySQL的配置文件,配置文件的路径一般为/etc/mysql/my.cnf。

比如,使用vi进行修改的命令如下:

1
vi /etc/mysql/my.cnf

取消此行的注释如下:

1
bind-address = 127.0.0.1

此行限制了MySQL只能本地访问而不能远程访问,取消注释即可解除此限制。

修改完成后重启MySQL服务,此时MySQL就可以被远程访问了。

到此为止,在Linux下安装MySQL的过程就结束了。

4. Mac下的安装

这里推荐使用Homebrew安装,直接执行brew命令即可:

1
brew install mysql

启动、停止和重启MySQL服务的命令如下:

1
2
3
sudo mysql.server start
sudo mysql.server stop
sudo mysql.server restart

Mac一般不会作为服务器使用,如果想取消本地host绑定,那么需要修改my.cnf 文件,然后重启服务。

Python

作为数据存储的重要部分,数据库同样是必不可少的,数据库可以分为关系型数据库和非关系型数据库。

关系型数据库如SQLite、MySQL、Oracle、SQL Server、DB2等,其数据库是以表的形式存储,非关系型数据库如MongoDB、Redis,它们的存储形式是键值对,存储形式更加灵活。

本书用到的数据库主要有关系型数据库MySQL及非关系型数据库MongoDB、Redis。

本节中,我们来了解一下它们的安装方式。

Python

在爬虫过程中,难免会遇到各种各样的验证码,而大多数验证码还是图形验证码,这时候我们可以直接用OCR来识别。

1. OCR

OCR,即Optical Character Recognition,光学字符识别,是指通过扫描字符,然后通过其形状将其翻译成电子文本的过程。对于图形验证码来说,它们都是一些不规则的字符,这些字符确实是由字符稍加扭曲变换得到的内容。

例如,对于如图1-22和图1-23所示的验证码,我们可以使用OCR技术来将其转化为电子文本,然后爬虫将识别结果提交给服务器,便可以达到自动识别验证码的过程。

图1-22 验证码

图1-23 验证码

tesserocr是Python的一个OCR识别库,但其实是对tesseract做的一层Python API封装,所以它的核心是tesseract。因此,在安装tesserocr之前,我们需要先安装tesseract。

2. 相关链接

  • tesserocr GitHub:https://github.com/sirfz/tesserocr
  • tesserocr PyPI:https://pypi.python.org/pypi/tesserocr
  • tesseract下载地址:http://digi.bib.uni-mannheim.de/tesseract
  • tesseract GitHub:https://github.com/tesseract-ocr/tesseract
  • tesseract语言包:https://github.com/tesseract-ocr/tessdata
  • tesseract文档:https://github.com/tesseract-ocr/tesseract/wiki/Documentation

3. Windows下的安装

在Windows下,首先需要下载tesseract,它为tesserocr提供了支持。

进入下载页面,可以看到有各种.exe文件的下载列表,这里可以选择下载3.0版本。图1-24所示为3.05版本。

图1-24 下载页面

其中文件名中带有dev的为开发版本,不带dev的为稳定版本,可以选择下载不带dev的版本,例如可以选择下载tesseract-ocr-setup-3.05.01.exe。

下载完成后双击,此时会出现如图1-25所示的页面。

图1-25 安装页面

此时可以勾选Additional language data(download)选项来安装OCR识别支持的语言包,这样OCR便可以识别多国语言。然后一路点击Next按钮即可。

接下来,再安装tesserocr即可,此时直接使用pip安装:

1
pip3 install tesserocr pillow

4. Linux下的安装

对于Linux来说,不同系统已经有了不同的发行包了,它可能叫作tesseract-ocr或者tesseract,直接用对应的命令安装即可。

Ubuntu、Debian和Deepin

在Ubuntu、Debian和Deepin系统下,安装命令如下:

1
sudo apt-get install -y tesseract-ocr libtesseract-dev libleptonica-dev

CentOS、Red Hat

在CentOS和Red Hat系统下,安装命令如下:

1
yum install -y tesseract

在不同发行版本运行如上命令,即可完成tesseract的安装。

安装完成后,便可以调用tesseract命令了。

接着,我们查看一下其支持的语言:

1
tesseract --list-langs

运行结果示例:

1
2
3
4
List of available languages (3):
eng
osd
equ

结果显示它只支持几种语言,如果想要安装多国语言,还需要安装语言包,官方叫作tessdata(其下载链接为:https://github.com/tesseract-ocr/tessdata)。

利用Git命令将其下载下来并迁移到相关目录即可,不同版本的迁移命令如下所示。

在Ubuntu、Debian和Deepin系统下的迁移命令如下:

1
2
git clone https://github.com/tesseract-ocr/tessdata.git
sudo mv tessdata/* /usr/share/tesseract-ocr/tessdata

在CentOS和Red Hat系统下的迁移命令如下:

1
2
git clone https://github.com/tesseract-ocr/tessdata.git
sudo mv tessdata/* /usr/share/tesseract/tessdata

这样就可以将下载下来的语言包全部安装了。

这时我们重新运行列出所有语言的命令:

1
tesseract --list-langs

结果如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
List of available languages (107):
afr
amh
ara
asm
aze
aze_cyrl
bel
ben
bod
bos
bul
cat
ceb
ces
chi_sim
chi_tra
...

可以发现,这里列出的语言就多了很多,比如chi_sim就代表简体中文,这就证明语言包安装成功了。

接下来再安装tesserocr即可,这里直接使用pip安装:

1
pip3 install tesserocr pillow

5. Mac下的安装

在Mac下,我们首先使用Homebrew安装ImageMagick和tesseract库:

1
2
brew install imagemagick 
brew install tesseract --all-languages

接下来再安装tesserocr即可:

1
pip3 install tesserocr pillow

这样我们便完成了tesserocr的安装。

6. 验证安装

接下来,我们可以使用tesseract和tesserocr来分别进行测试。

下面我们以如图1-26所示的图片为样例进行测试。

图1-26 测试样例

该图片的链接为https://raw.githubusercontent.com/Python3WebSpider/TestTess/master/image.png,可以直接保存或下载。

首先用命令行进行测试,将图片下载下来并保存为image.png,然后用tesseract命令测试:

1
tesseract image.png result -l eng && cat result.txt

运行结果如下:

1
2
Tesseract Open Source OCR Engine v3.05.01 with Leptonica
Python3WebSpider

这里我们调用了tesseract命令,其中第一个参数为图片名称,第二个参数result为结果保存的目标文件名称,\-l指定使用的语言包,在此使用英文(eng)。然后,再用cat命令将结果输出。

运行结果便是图片的识别结果:Python3WebSpider。可以看到,这时已经成功将图片文字转为电子文本了。

然后还可以利用Python代码来测试,这里就需要借助于tesserocr库了,测试代码如下:

1
2
3
4
import tesserocr
from PIL import Image
image = Image.open('image.png')
print(tesserocr.image_to_text(image))

我们首先利用Image读取了图片文件,然后调用了tesserocrimage_to_text()方法,再将其识别结果输出。

运行结果如下:

1
Python3WebSpider

另外,我们还可以直接调用file_to_text()方法,这可以达到同样的效果:

1
2
import tesserocr
print(tesserocr.file_to_text('image.png'))

运行结果:

1
Python3WebSpider

如果成功输出结果,则证明tesseract和tesserocr都已经安装成功。

Python

pyquery同样是一个强大的网页解析工具,它提供了和jQuery类似的语法来解析HTML文档,支持CSS选择器,使用非常方便。本节中,我们就来了解一下它的安装方式。

1. 相关链接

  • GitHub:https://github.com/gawel/pyquery
  • PyPI:https://pypi.python.org/pypi/pyquery
  • 官方文档:http://pyquery.readthedocs.io

2. pip安装

这里推荐使用pip安装,命令如下:

1
pip3 install pyquery

命令执行完毕之后即可完成安装。

3. wheel安装

当然,我们也可以到PyPI(https://pypi.python.org/pypi/pyquery/#downloads)下载对应的wheel文件安装。比如如果当前版本为1.2.17,则下载的文件名称为pyquery-1.2.17-py2.py3-none-any.whl,此时下载到本地再进行pip安装即可,命令如下:

1
pip3 install pyquery-1.2.17-py2.py3-none-any.whl

4. 验证安装

安装完成之后,可以在Python命令行下测试:

1
2
$ python3
>>> import pyquery

如果没有错误报出,则证明库已经安装好了。

Python

Beautiful Soup是Python的一个HTML或XML的解析库,我们可以用它来方便地从网页中提取数据。它拥有强大的API和多样的解析方式,本节就来了解下它的安装方式。

1. 相关链接

  • 官方文档:https://www.crummy.com/software/BeautifulSoup/bs4/doc
  • 中文文档:https://www.crummy.com/software/BeautifulSoup/bs4/doc.zh
  • PyPI:https://pypi.python.org/pypi/beautifulsoup4

2. 准备工作

Beautiful Soup的HTML和XML解析器是依赖于lxml库的,所以在此之前请确保已经成功安装好了lxml库,具体的安装方式参见上节。

3. pip安装

目前,Beautiful Soup的最新版本是4.x版本,之前的版本已经停止开发了。这里推荐使用pip来安装,安装命令如下:

1
pip3 install beautifulsoup4

命令执行完毕之后即可完成安装。

4. wheel安装

当然,我们也可以从PyPI下载wheel文件安装,链接如下:https://pypi.python.org/pypi/beautifulsoup4

然后使用pip安装wheel文件即可。

5. 验证安装

安装完成之后,可以运行下面的代码验证一下:

1
2
3
from bs4 import BeautifulSoup
soup = BeautifulSoup('<p>Hello</p>', 'lxml')
print(soup.p.string)

运行结果如下:

1
Hello

如果运行结果一致,则证明安装成功。

注意,这里我们虽然安装的是beautifulsoup4这个包,但是在引入的时候却是bs4。这是因为这个包源代码本身的库文件夹名称就是bs4,所以安装完成之后,这个库文件夹就被移入到本机Python3的lib库里,所以识别到的库文件名就叫作bs4。

因此,包本身的名称和我们使用时导入的包的名称并不一定是一致的。

Python

lxml是Python的一个解析库,支持HTML和XML的解析,支持XPath解析方式,而且解析效率非常高。本节中,我们了解一下lxml的安装方式,这主要从Windows、Linux和Mac三大平台来介绍。

1. 相关链接

  • 官方网站:http://lxml.de
  • GitHub:https://github.com/lxml/lxml
  • PyPI:https://pypi.python.org/pypi/lxml

2. Windows下的安装

在Windows下,可以先尝试利用pip安装,此时直接执行如下命令即可:

1
pip3 install lxml

如果没有任何报错,则证明安装成功。

如果出现报错,比如提示缺少libxml2库等信息,可以采用wheel方式安装。

推荐直接到这里(链接为:http://www.lfd.uci.edu/~gohlke/pythonlibs/#lxml)下载对应的wheel文件,找到本地安装Python版本和系统对应的lxml版本,例如Windows 64位、Python 3.6,就选择lxml‑3.8.0‑cp36‑cp36m‑win_amd64.whl,将其下载到本地。

然后利用pip安装即可,命令如下:

1
pip3 install lxml‑3.8.0cp36cp36mwin_amd64.whl

这样我们就可以成功安装lxml了。

3. Linux下的安装

在Linux平台下安装问题不大,同样可以先尝试pip安装,命令如下:

1
pip3 install lxml

如果报错,可以尝试下面的解决方案。

CentOS、Red Hat

对于此类系统,报错主要是因为缺少必要的库。

执行如下命令安装所需的库即可:

1
2
sudo yum groupinstall -y development tools
sudo yum install -y epel-release libxslt-devel libxml2-devel openssl-devel

主要是libxslt-devel和libxml2-devel这两个库,lxml依赖它们。安装好之后,重新尝试pip安装即可。

Ubuntu、Debian和Deepin

在这些系统下,报错的原因同样可能是缺少了必要的类库,执行如下命令安装:

1
sudo apt-get install -y python3-dev build-essential libssl-dev libffi-dev libxml2 libxml2-dev libxslt1-dev zlib1g-dev

安装好之后,重新尝试pip安装即可。

4. Mac下的安装

在Mac平台下,仍然可以首先尝试pip安装,命令如下:

1
pip3 install lxml

如果产生错误,可以执行如下命令将必要的类库安装:

1
xcode-select --install

之后再重新尝试pip安装,就没有问题了。

lxml是一个非常重要的库,后面的Beautiful Soup、Scrapy框架都需要用到此库,所以请一定安装成功。

5. 验证安装

安装完成之后,可以在Python命令行下测试:

1
2
$ python3
>>> import lxml

如果没有错误报出,则证明库已经安装好了。

Python

抓取网页代码之后,下一步就是从网页中提取信息。提取信息的方式有多种多样,可以使用正则来提取,但是写起来相对比较烦琐。这里还有许多强大的解析库,如lxml、Beautiful Soup、pyquery等。此外,还提供了非常强大的解析方法,如XPath解析和CSS选择器解析等,利用它们,我们可以高效便捷地从网页中提取有效信息。

本节中,我们就来介绍一下这些库的安装过程。

Python

之前介绍的Requests库是一个阻塞式HTTP请求库,当我们发出一个请求后,程序会一直等待服务器响应,直到得到响应后,程序才会进行下一步处理。其实,这个过程比较耗费资源。如果程序可以在这个等待过程中做一些其他的事情,如进行请求的调度、响应的处理等,那么爬取效率一定会大大提高。

aiohttp就是这样一个提供异步Web服务的库,从Python 3.5版本开始,Python中加入了async/await关键字,使得回调的写法更加直观和人性化。aiohttp的异步操作借助于async/await关键字的写法变得更加简洁,架构更加清晰。使用异步请求库进行数据抓取时,会大大提高效率,下面我们来看一下这个库的安装方法。

1. 相关链接

  • 官方文档:http://aiohttp.readthedocs.io/en/stable
  • GitHub:https://github.com/aio-libs/aiohttp
  • PyPI:https://pypi.python.org/pypi/aiohttp

2. pip安装

这里推荐使用pip安装,命令如下:

1
pip3 install aiohttp

另外,官方还推荐安装如下两个库:一个是字符编码检测库cchardet,另一个是加速DNS的解析库aiodns。安装命令如下:

1
pip3 install cchardet aiodns

3. 测试安装

安装完成之后,可以在Python命令行下测试:

1
2
$ python3
>>> import aiohttp

如果没有错误报出,则证明库已经安装好了。

4. 结语

我们会在后面的实例中用到这个库,比如维护一个代理池时,利用异步方式检测大量代理的运行状况,会极大地提升效率。

Python

PhantomJS是一个无界面的、可脚本编程的WebKit浏览器引擎,它原生支持多种Web标准:DOM操作、CSS选择器、JSON、Canvas以及SVG。

Selenium支持PhantomJS,这样在运行的时候就不会再弹出一个浏览器了。而且PhantomJS的运行效率也很高,还支持各种参数配置,使用非常方便。下面我们就来了解一下PhantomJS的安装过程。

1. 相关链接

  • 官方网站:http://phantomjs.org
  • 官方文档:http://phantomjs.org/quick-start.html
  • 下载地址:http://phantomjs.org/download.html
  • API接口说明:http://phantomjs.org/api/command-line.html

2. 下载PhantomJS

我们需要在官方网站下载对应的安装包,PhantomJS支持多种操作系统,比如Windows、Linux、Mac、FreeBSD等,我们可以选择对应的平台并将安装包下载下来。

下载完成后,将PhantomJS可执行文件所在的路径配置到环境变量里。比如在Windows下,将下载的文件解压之后并打开,会看到一个bin文件夹,里面包括一个可执行文件phantomjs.exe,我们需要将它直接放在配置好环境变量的路径下或者将它所在的路径配置到环境变量里。比如,我们既可以将它直接复制到Python的Scripts文件夹,也可以将它所在的bin目录加入到环境变量。

Windows下环境变量的配置可以参见1.1节,Linux及Mac环境变量的配置可以参见1.2.3节,在此不再赘述,关键在于将PhantomJS的可执行文件所在路径配置到环境变量里。

配置成功后,可以在命令行下测试一下,输入:

1
phantomjs

如果可以进入到PhantomJS的命令行,那就证明配置完成了,如图1-21所示。

图1-21 控制台

3. 验证安装

在Selenium中使用的话,我们只需要将Chrome切换为PhantomJS即可:

1
2
3
4
from selenium import webdriver
browser = webdriver.PhantomJS()
browser.get('https://www.baidu.com')
print(browser.current_url)

运行之后,我们就不会发现有浏览器弹出了,但实际上PhantomJS已经运行起来了。这里我们访问了百度,然后将当前的URL打印出来。

控制台的输出如下:

1
https://www.baidu.com/

如此一来,我们便完成了PhantomJS的配置,后面可以利用它来完成一些页面的抓取。

这里我们介绍了Selenium对应的三大主流浏览器的对接方式,后面我们会对Selenium及各个浏览器的对接方法进行更加深入的探究。

Python

上一节中,我们了解了ChromeDriver的配置方法,配置完成之后便可以用Selenium驱动Chrome浏览器来做相应网页的抓取。

那么对于Firefox来说,也可以使用同样的方式完成Selenium的对接,这时需要安装另一个驱动GeckoDriver。

本节中,我们来介绍一下GeckoDriver的安装过程。

1. 相关链接

  • GitHub:https://github.com/mozilla/geckodriver
  • 下载地址:https://github.com/mozilla/geckodriver/releases

2. 准备工作

在这之前请确保已经正确安装好了Firefox浏览器并可以正常运行,安装过程不再赘述。

3. 下载GeckoDriver

我们可以在GitHub上找到GeckoDriver的发行版本,当前最新版本为0.18,下载页面如图1-18所示。图1-18 GeckoDriver下载页面

这里可以在不同的平台上下载,如Windows、Mac、Linux、ARM等平台,我们可以根据自己的系统和位数选择对应的驱动下载,若是Windows 64位,就下载geckodriver-v0.18.0-win64.zip。

4. 环境变量配置

在Windows下,可以直接将geckodriver.exe文件拖到Python的Scripts目录下,如图1-19所示。

图1-19 将geckodriver.exe文件拖到Python Scripts目录

此外,也可以单独将其所在路径配置到环境变量,具体的配置方法请参1.1节。

在Linux和Mac下,需要将可执行文件配置到环境变量或将文件移动到属于环境变量的目录里。

例如,要移动文件到/usr/bin目录。首先在命令行模式下进入其所在路径,然后将其移动到/usr/bin:

1
sudo mv geckodriver /usr/bin

当然,也可以将GeckoDriver配置到$PATH。首先,可以将可执行文件放到某一目录,目录可以任意选择,例如将当前可执行文件放在/usr/local/geckodriver目录下。接下来可以修改~/.profile文件,命令如下:

1
vi ~/.profile

然后添加如下一句配置:

1
export PATH="$PATH:/usr/local/geckodriver"

保存后执行如下命令即可完成配置:

1
source ~/.profile

5. 验证安装

配置完成后,就可以在命令行下直接执行geckodriver命令测试:

1
geckodriver

这时如果控制台有类似图1-20所示的输出,则证明GeckoDriver的环境变量配置好了。

图1-20 控制台输出

随后执行如下Python代码,在程序中测试一下:

1
2
from selenium import webdriver
browser = webdriver.Firefox()

运行之后,若弹出一个空白的Firefox浏览器,则证明所有的配置都没有问题;如果没有弹出,请检查之前的每一步配置。

如果没有问题,接下来就可以利用Firefox配合Selenium来做网页抓取了。

6. 结语

现在我们就可以使用Chrome或Firefox进行网页抓取了,但是这样可能有个不方便之处:因为程序运行过程中需要一直开着浏览器,在爬取网页的过程中浏览器可能一直动来动去。目前最新的Chrome浏览器版本已经支持无界面模式了,但如果版本较旧的话,就不支持。所以这里还有另一种选择,那就是安装一个无界面浏览器PhantomJS,此时抓取过程会在后台运行,不会再有窗口出现。在下一节中,我们就来了解一下PhantomJS的相关安装方法。

Python

前面我们成功安装好了Selenium库,但是它是一个自动化测试工具,需要浏览器来配合使用,本节中我们就介绍一下Chrome浏览器及ChromeDriver驱动的配置。

首先,下载Chrome浏览器,方法有很多,在此不再赘述。

随后安装ChromeDriver。因为只有安装ChromeDriver,才能驱动Chrome浏览器完成相应的操作。下面我们来介绍下怎样安装ChromeDriver。

1. 相关链接

  • 官方网站:https://sites.google.com/a/chromium.org/chromedriver
  • 下载地址:https://chromedriver.storage.googleapis.com/index.html

2. 准备工作

在这之前请确保已经正确安装好了Chrome浏览器并可以正常运行,安装过程不再赘述。

3. 查看版本

点击Chrome菜单“帮助”→“关于Google Chrome”,即可查看Chrome的版本号,如图1-14所示。

图1-14 Chrome版本号

这里我的Chrome版本是58.0。

请记住Chrome版本号,因为选择ChromeDriver版本时需要用到。

4. 下载ChromeDriver

打开ChromeDriver的官方网站,可以看到最新版本为2.31,其支持的Chrome浏览器版本为58~60,官网页面如图1-15所示。

更新:现在 2020 年,Chrome 版本已经更新到 80+,请以最新的 ChromeDriver 为准!https://chromedriver.chromium.org/downloads

图1-15 官网页面

如果你的Chrome版本号是58~60,那么可以选择此版本下载。

如果你的Chrome版本号不在此范围,可以继续查看之前的ChromeDriver版本。每个版本都有相应的支持Chrome版本的介绍,请找好自己的Chrome浏览器版本对应的ChromeDriver版本再下载,否则可能无法正常工作。

找好对应的版本号后,随后到ChromeDriver镜像站下载对应的安装包即可:https://chromedriver.storage.googleapis.com/index.html。在不同平台下,可以下载不同的安装包。

5. 环境变量配置

下载完成后,将ChromeDriver的可执行文件配置到环境变量下。

在Windows下,建议直接将chromedriver.exe文件拖到Python的Scripts目录下,如图1-16所示。

图1-16 Python Scripts目录

此外,也可以单独将其所在路径配置到环境变量,具体的配置方法请参见1.1节。

在Linux和Mac下,需要将可执行文件配置到环境变量或将文件移动到属于环境变量的目录里。

例如,要移动文件到/usr/bin目录。首先,需要在命令行模式下进入其所在路径,然后将其移动到/usr/bin:

1
sudo mv chromedriver /usr/bin

当然,也可以将ChromeDriver配置到$PATH。首先,可以将可执行文件放到某一目录,目录可以任意选择,例如将当前可执行文件放在/usr/local/chromedriver目录下,接下来可以修改~/.profile文件,相关命令如下:

1
export PATH="$PATH:/usr/local/chromedriver"

保存后执行如下命令:

1
source ~/.profile

即可完成环境变量的添加。

6. 验证安装

配置完成后,就可以在命令行下直接执行chromedriver命令了:

1
chromedriver

如果输入控制台有类似图1-17所示的输出,则证明ChromeDriver的环境变量配置好了。

图1-17 控制台输出

随后再在程序中测试,执行如下Python代码:

1
2
from selenium import webdriver
browser = webdriver.Chrome()

运行之后,如果弹出一个空白的Chrome浏览器,则证明所有的配置都没有问题。如果没有弹出,请检查之前的每一步配置。

如果弹出后闪退,则可能是ChromeDriver版本和Chrome版本不兼容,请更换ChromeDriver版本。

如果没有问题,接下来就可以利用Chrome来做网页抓取了。

Selenium是一个自动化测试工具,利用它我们可以驱动浏览器执行特定的动作,如点击、下拉等操作。对于一些JavaScript渲染的页面来说,这种抓取方式非常有效。下面我们来看看Selenium的安装过程。

1. 相关链接

  • 官方网站:http://www.seleniumhq.org
  • GitHub:https://github.com/SeleniumHQ/selenium/tree/master/py
  • PyPI:https://pypi.python.org/pypi/selenium
  • 官方文档:http://selenium-python.readthedocs.io
  • 中文文档:http://selenium-python-zh.readthedocs.io

2. pip安装

这里推荐直接使用pip安装,执行如下命令即可:

1
pip3 install selenium

3. wheel安装

此外,也可以到PyPI下载对应的wheel文件进行安装(下载地址:https://pypi.python.org/pypi/selenium/#downloads),如最新版本为3.4.3,则下载selenium-3.4.3-py2.py3-none-any.whl即可。

然后进入wheel文件目录,使用pip安装:

1
pip3 install selenium-3.4.3-py2.py3-none-any.whl

4. 验证安装

进入Python命令行交互模式,导入Selenium包,如果没有报错,则证明安装成功:

1
2
$ python3
>>> import selenium

但这样做还不够,因为我们还需要用浏览器(如Chrome、Firefox等)来配合Selenium工作。

后面我们会介绍Chrome、Firefox、PhantomJS三种浏览器的配置方式。有了浏览器,我们才可以配合Selenium进行页面的抓取。

Python

由于Requests属于第三方库,也就是Python默认不会自带这个库,所以需要我们手动安装。下面我们首先看一下它的安装过程。

1. 相关链接

  • GitHub:https://github.com/requests/requests
  • PyPI:https://pypi.python.org/pypi/requests
  • 官方文档:http://www.python-requests.org
  • 中文文档:http://docs.python-requests.org/zh_CN/latest

2. pip安装

无论是Windows、Linux还是Mac,都可以通过pip这个包管理工具来安装。

在命令行界面中运行如下命令,即可完成Requests库的安装:

1
pip3 install requests

这是最简单的安装方式,推荐使用这种方法安装。

3. wheel安装

wheel是Python的一种安装包,其后缀为.whl,在网速较差的情况下可以选择下载wheel文件再安装,然后直接用pip3命令加文件名安装即可。

不过在这之前需要先安装wheel库,安装命令如下:

1
pip3 install wheel

然后到PyPI上下载对应的wheel文件,如最新版本为2.17.3,则打开https://pypi.python.org/pypi/requests/2.17.3#downloads,下载requests-2.17.3-py2.py3-none-any.whl到本地。

随后在命令行界面进入wheel文件目录,利用pip安装即可:

1
pip3 install requests-2.17.3-py2.py3-none-any.whl

这样我们也可以完成Requests的安装。

4. 源码安装

如果你不想用pip来安装,或者想获取某一特定版本,可以选择下载源码安装。

此种方式需要先找到此库的源码地址,然后下载下来再用命令安装。

Requests项目的地址是:https://github.com/kennethreitz/requests

可以通过Git来下载源代码:

1
git clone git://github.com/kennethreitz/requests.git

或通过curl下载:

1
curl -OL https://github.com/kennethreitz/requests/tarball/master

下载下来之后,进入目录,执行如下命令即可安装:

1
2
cd requests
python3 setup.py install

命令执行结束后即可完成Requests的安装。由于这种安装方式比较烦琐,后面不再赘述。

5. 验证安装

为了验证库是否已经安装成功,可以在命令行模式测试一下:

1
2
$ python3
>>> import requests

首先输入python3,进入命令行模式,然后输入上述内容,如果什么错误提示也没有,就证明已经成功安装了Requests。

Python

爬虫可以简单分为几步:抓取页面、分析页面和存储数据。

在抓取页面的过程中,我们需要模拟浏览器向服务器发出请求,所以需要用到一些Python库来实现HTTP请求操作。在本书中,我们用到的第三方库有Requests、Selenium和aiohttp等。

在本节中,我们介绍一下这些请求库的安装方法。

Python

既然要用Python 3开发爬虫,那么第一步一定是安装Python 3。这里会介绍Windows、Linux和Mac三大平台下的安装过程。

1. 相关链接

  • 官方网站:http://python.org
  • 下载地址:https://www.python.org/downloads
  • 第三方库:https://pypi.python.org/pypi
  • 官方文档:https://docs.python.org/3
  • 中文教程:http://www.runoob.com/python3/python3-tutorial.html
  • Awesome Python:https://github.com/vinta/awesome-python
  • Awesome Python中文版:https://github.com/jobbole/awesome-python-cn

2. Windows下的安装

在Windows下安装Python 3的方式有两种。

  • 一种是通过Anaconda安装,它提供了Python的科学计算环境,里面自带了Python以及常用的库。如果选用了这种方式,后面的环境配置方式会更加简便。
  • 另一种是直接下载安装包安装,即标准的安装方式。

下面我们依次介绍这两种安装方式,任选其一即可。

(1) Anaconda安装

Anaconda的官方下载链接为https://www.continuum.io/downloads,选择Python 3版本的安装包下载即可,如图1-1所示。

图像说明文字

图1-1 Anaconda Windows下载页面

如果下载速度过慢,可以选择使用清华大学镜像,下载列表链接为https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/,使用说明链接为https://mirrors.tuna.tsinghua.edu.cn/help/anaconda/

下载完成之后,直接双击安装包安装即可。安装完成之后,Python 3的环境就配置好了。

(2) 安装包安装

我们推荐直接下载安装包来安装,此时可以直接到官方网站下载Python 3的安装包:https://www.python.org/downloads/

写书时,Python的最新版本1是3.6.2,其下载链接为https://www.python.org/downloads/release/python-362/,下载页面如图1-2所示。需要说明的是,实际的Python最新版本以官网为准。

图像说明文字

图1-2 Python下载页面

  1. 若无特别说明,书中的最新版本均为作者写书时的情况,后面不再一一说明。

64位系统可以下载Windows x86-64 executable installer,32位系统可以下载Windows x86 executable installer。

下载完成之后,直接双击Python安装包,然后通过图形界面安装,接着设置Python的安装路径,完成后将Python 3和Python 3的Scripts目录配置到环境变量即可。

关于环境变量的配置,此处以Windows 10系统为例进行演示。

假如安装后的Python 3路径为C:\Python36,从资源管理器中打开该路径,如图1-3所示。

图像说明文字

图1-3 Python安装目录

将该路径复制下来。

随后,右击“计算机”,从中选择“属性”,此时将打开系统属性窗口,如图1-4所示。

图像说明文字

图1-4 系统属性

点击左侧的“高级系统设置”,即可看到在弹出的对话框下方看到“环境变量”按钮,如图1-5所示。

图像说明文字

图1-5 高级系统设置

点击“环境变量”按钮,找到系统变量下的Path变量,随后点击“编辑”按钮,如图1-6所示。

图像说明文字

图1-6 环境变量

随后点击“新建”,新建一个条目,将刚才复制的C:\Python36复制进去。这里需要说明的是,此处的路径就是你的Python 3安装目录,请自行替换。然后,再把C:\Python36\Scripts路径复制进去,如图1-7所示。

图像说明文字

图1-7 编辑环境变量

最后,点击“确定”按钮即可完成环境变量的配置。

配置好环境变量后,我们就可以在命令行中直接执行环境变量路径下的可执行文件了,如pythonpip等命令。

(3) 添加别名

上面这两种安装方式任选其一即可完成安装,但如果之前安装过Python 2的话,可能会导致版本冲突问题,比如在命令行下输入python就不知道是调用的Python 2还是Python 3了。为了解决这个问题,建议将安装目录中的python.exe复制一份,命名为python3.exe,这样便可以调用python3命令了。实际上,它和python命令是完全一致的,这样只是为了可以更好地区分Python版本。当然,如果没有安装过Python 2的话,也建议添加此别名,添加完毕之后的效果如图1-8所示。

图像说明文字

图1-8 添加别名

对于pip来说,安装包中自带了pip3.exe可执行文件,我们也可以直接使用pip3命令,无需额外配置。

(4) 测试验证

安装完成后,可以通过命令行测试一下安装是否成功。在“开始”菜单中搜索cmd,找到命令提示符,此时就进入命令行模式了。输入python,测试一下能否成功调用Python。如果添加了别名的话,可以输入python3测试,这里输入的是python3,测试结果如图1-9所示。

图像说明文字

图1-9 测试验证页面

输出结果类似如下:

1
2
3
4
5
6
7
8
$ python3
Python 3.6.1 (v3.6.1:69c0db5, Mar 21 2017, 17:54:52) [MSC v.1900 32 bit (Intel)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> print('Hello World')
Hello World
>>> exit()
$ pip3 -V
pip 9.0.1 from c:\python36\lib\site-packages (python 3.6)

如果出现了类似上面的提示,则证明Python 3和pip 3均安装成功;如果提示命令不存在,那么请检查下环境变量的配置情况。

3. Linux下的安装

Linux下的安装方式有多种:命令安装、源码安装和Anaconda安装。

使用源码安装需要自行编译,时间较长。推荐使用系统自带的命令或Anaconda安装,简单、高效。这里分别讲解这3种安装方式。

(1) 命令行安装

不同的Linux发行版本的安装方式又有不同,在此分别予以介绍。

CentOS、Red Hat

如果是CentOS或Red Hat版本,则使用yum命令安装即可。

下面列出了Python 3.5和Python 3.4两个版本的安装方法,可以自行选择。

Python 3.5版本:

1
2
3
sudo yum install -y https://centos7.iuscommunity.org/ius-release.rpm
sudo yum update
sudo yum install -y python35u python35u-libs python35u-devel python35u-pip

执行完毕后,便可以成功安装Python 3.5及pip 3了。

Python 3.4版本:

1
2
3
4
5
sudo yum groupinstall -y development tools
sudo yum install -y epel-release python34-devel libxslt-devel libxml2-devel openssl-devel
sudo yum install -y python34
sudo yum install -y python34-setuptools
sudo easy_install-3.4 pip

执行完毕后,便可以成功安装Python 3.4及pip 3了。

Ubuntu、Debian和Deepin

首先安装Python 3,这里使用apt-get安装即可。在安装前,还需安装一些基础库,相关命令如下:

1
2
sudo apt-get install -y python3-dev build-essential libssl-dev libffi-dev libxml2 libxml2-dev libxslt1-dev zlib1g-dev libcurl4-openssl-dev
sudo apt-get install -y python3

执行完上述命令后,就可以成功安装Python 3了。

然后还需要安装pip 3,这里仍然使用apt-get安装即可,相关命令如下:

1
sudo apt-get install -y python3-pip

执行完毕后,便可以成功安装Python 3及pip 3了。

(2) 源码安装

如果命令行的安装方式有问题,还可以下载Python 3源码进行安装。

源码下载地址为https://www.python.org/ftp/python/,可以自行选用想要的版本进行安装。这里以Python 3.6.2为例进行说明,安装路径设置为/usr/local/python3。

首先,创建安装目录,相关命令如下:

1
sudo mkdir /usr/local/python3

随后下载安装包并解压进入,相关命令如下:

1
2
3
wget --no-check-certificate https://www.python.org/ftp/python/3.6.2/Python-3.6.2.tgz
tar -xzvf Python-3.6.2.tgz
cd Python-3.6.2

接下来,编译安装。所需的时间可能较长,请耐心等待,命令如下:

1
2
3
sudo ./configure --prefix=/usr/local/python3
sudo make
sudo make install

安装完成之后,创建Python 3链接,相关命令如下:

1
sudo ln -s /usr/local/python3/bin/python3 /usr/bin/python3

随后下载pip安装包并安装,命令如下:

1
2
3
4
wget --no-check-certificate https://github.com/pypa/pip/archive/9.0.1.tar.gz
tar -xzvf 9.0.1.tar.gz
cd pip-9.0.1
python3 setup.py install

安装完成后再创建pip 3链接,相关命令如下:

1
sudo ln -s /usr/local/python3/bin/pip /usr/bin/pip3

这样就成功安装好了Python 3及pip 3。

(3) Anaconda安装

Anaconda同样支持Linux,其官方下载链接为https://www.continuum.io/downloads,选择Python 3版本的安装包下载即可,如图1-10所示。

图像说明文字

图1-10 Anaconda Linux下载页面

如果下载速度过慢,同样可以使用清华镜像,具体可参考Windows部分的介绍,在此不再赘述。

(4) 测试验证

在命令行界面下测试Python 3和pip 3是否安装成功:

1
2
3
4
5
6
$ python3
Python 3.5.2 (default, Nov 17 2016, 17:05:23)
Type "help", "copyright", "credits" or "license" for more information.
>>> exit()
$ pip3 -V
pip 8.1.1 from /usr/lib/python3/dist-packages (python 3.5)

若出现类似上面的提示,则证明Python 3和pip 3安装成功。

4. Mac下的安装

在Mac下同样有多种安装方式,如Homebrew、安装包安装、Anaconda安装等,这里推荐使用Homebrew安装。

(1) Homebrew安装

Homebrew是Mac平台下强大的包管理工具,其官方网站是https://brew.sh/

执行如下命令,即可安装Homebrew:

1
ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)"

安装完成后,便可以使用brew命令安装Python 3和pip 3了:

1
brew install python3

命令执行完成后,我们发现Python 3和pip 3均已成功安装。

(2) 安装包安装

可以到官方网站下载Python 3安装包。链接为https://www.python.org/downloads/,页面如图1-2所示。

在Mac平台下,可以选择下载Mac OS X 64-bit/32-bit installer,下载完成后,打开安装包按照提示安装即可。

(3) Anaconda安装

Anaconda同样支持Mac,其官方下载链接为:https://www.continuum.io/downloads,选择Python 3版本的安装包下载即可,如图1-11所示。

图像说明文字

图1-11 Anaconda Mac下载页面

如果下载速度过慢,同样可以使用清华镜像,具体可参考Windows部分的介绍,在此不再赘述。

(4) 测试验证

打开终端,在命令行界面中测试Python 3和pip 3是否成功安装,如图1-12所示。

图像说明文字

图1-12 测试验证页面

若出现上面的提示,则证明Python 3和pip 3安装成功。

本节中,我们介绍了3大平台Windows、Linux和Mac下Python 3的安装方式。安装完成后,我们便可以开启Python爬虫的征程了。

Python

工欲善其事,必先利其器!

编写和运行程序之前我们必须要先把开发环境配置好,只有配置好了环境并且有了更方便的开发工具我们才能更加高效地用程序实现相应的功能,然而很多情况下我们可能在最开始就卡在环境配置上,如果这个过程花费了太多时间,想必学习的兴趣就下降了大半,所以本章专门来对本书中所有的环境配置做一下说明。

本章是本书使用的所有库及工具的安装过程讲解,为了使书的条理更加清晰,本书将环境配置的过程统一合并为一章,本章不必逐节阅读,可以在需要的时候进行查阅。

文中在介绍安装过程的时候会尽量兼顾各个平台,另外会将一些安装常见的错误指出,以便快速高效地搭建好编程环境。

Python

2022 年 Python3 网络爬虫教程

大家好,我是崔庆才,由于爬虫技术不断迭代升级,一些旧的教程已经过时、案例已经过期,最前沿的爬虫技术比如异步、JavaScript 逆向、安卓逆向、智能解析、WebAssembly、大规模分布式、Kubernetes 等技术层出不穷,我最近新出了一套最新最全面的 Python3 网络爬虫系列教程。

博主自荐:截止 2022 年,可以将最前沿最全面的爬虫技术都涵盖的教程,如异步、JavaScript 逆向、安卓逆向、智能解析、WebAssembly、大规模分布式、Kubernetes 等,市面上目前就这一套了。

最新教程对旧的爬虫技术内容进行了全面更新,搭建了全新的案例平台进行全面讲解,保证案例稳定有效不过期。

教程请移步:

【2022 版】Python3 网络爬虫学习教程

2018 年 Python3 爬虫系列教程

以下为 2018 年版 Python3 网络爬虫系列教程

本内容来自于《Python3 网络爬虫开发实战》一书。 书籍购买地址: https://item.jd.com/12333540.html

本书通过多个实战案例详细介绍了 Python3 网络爬虫的知识,本书由图灵教育-人民邮电出版社出版发行,版权所有,禁止转载。

Python

前几天,大才发了一个自己写的框架,介绍地址在这里, GIT地址在这里

今天在阿里云上试用了一下,在这里做一个简单的说明。

1、配置环境

阿里云的版本是2.7.5,所以用pyenv新安装了一个3.6.4的环境,安装后使用pyenv global 3.6.4即可使用3.6.4的环境,我个人比较喜欢这样,切换自如,互不影响。 如下图: 接下来按照大才的文章,pip install gerapy即可,这一步没有遇到什么问题。有问题的同学可以向大才提issue。

2. 开启服务

首先去阿里云的后台设置安全组 ,我的是这样: 然后到命令窗口对8000和6800端口放行即可。 接着执行

gerapy init cd gerapy gerapy migrate **    # 注意下一步 **  **gerapy runserver  0.0.0.0:8000 【如果你是在本地,执行 gerapy runserver即可,如果你是在阿里云上,你就要改成前面这样来执行】**

现在在浏览器里访问:ip:8000应该就可以看到主界面了 里面的各个的含义见大才的文章。

3.创建项目

在gerapy下的projects里面新建一个scrapy爬虫,在这里我搞的是最简单的:

scrapy startproject gerapy_test cd gerapy_test scrapy genspider baidu www.baidu.com

这样就是一个最简单的爬虫了,修改一个settings.py中的ROBOTSTXT_OBEY=False, 然后修改一个spiders下面的baidu.py, 这里随意,我这里设置的是输出返回的 response.url

4.安装scrapyd

pip install scrapyd

安装好以后,命令行执行

scrapyd

然后浏览器中打开 ip:6800,如果你没有修改配置,应该这里会打不开,clients那里配置的时候,也应该会显示为error,就像这样: 后来找了一下原因发现scrapyd默认打开的也是127.0.0.1 所以这个时候就要改一下配置,具体可以参考这里, 我是这么修改:

vim ~/.scrapyd.conf [scrapyd] bind_address = 0.0.0.0

在刷新一下,就会看到前面error变成了normal

5. 打包,部署,调度

这几步大才的文章里都有详细说明,打包完,部署,在进入clients的调度界面,点击run按钮即可跑爬虫了 可以看到输出的结果了。

6.结语

建议大家可以试着用一下,很方便,我这里只是很简单的使用了一下。

Python

本节我们来尝试使用 TensorFlow 搭建一个双向 LSTM (Bi-LSTM) 深度学习模型来处理序列标注问题,主要目的是学习 Bi-LSTM 的用法。

Bi-LSTM

我们知道 RNN 是可以学习到文本上下文之间的联系的,输入是上文,输出是下文,但这样的结果是模型可以根据上文推出下文,而如果输入下文,想要推出上文就没有那么简单了,为了弥补这个缺陷,我们可以让模型从两个方向来学习,这就构成了双向 RNN。在某些任务中,双向 RNN 的表现比单向 RNN 要好,本文要实现的文本分词就是其中之一。不过本文使用的模型不是简单的双向 RNN,而是 RNN 的变种 — LSTM。 如图所示为 Bi-LSTM 的基本原理,输入层的数据会经过向前和向后两个方向推算,最后输出的隐含状态再进行 concat,再作为下一层的输入,原理其实和 LSTM 是类似的,就是多了双向计算和 concat 过程。

数据处理

本文的训练和测试数据使用的是已经做好序列标注的中文文本数据。序列标注,就是给一个汉语句子作为输入,以“BEMS”组成的序列串作为输出,然后再进行切词,进而得到输入句子的划分。其中,B 代表该字是词语中的起始字,M 代表是词语中的中间字,E 代表是词语中的结束字,S 则代表是单字成词。 这里的原始数据样例如下:

1
/b/e/s/s/b/e/s/s/s/b/m/e

这里一个字对应一个标注,我们首先需要对数据进行预处理,预处理的流程如下:

  • 将句子切分
  • 将句子的的标点符号去掉
  • 将每个字及对应的标注切分
  • 去掉长度为 0 的无效句子

首先我们将句子切分开来并去掉标点符号,代码实现如下:

1
2
3
4
5
6
7
8
# Read origin data
text = open('data/data.txt', encoding='utf-8').read()
# Get split sentences
sentences = re.split('[,。!?、‘’“”]/[bems]', text)
# Filter sentences whose length is 0
sentences = list(filter(lambda x: x.strip(), sentences))
# Strip sentences
sentences = list(map(lambda x: x.strip(), sentences))

这样我们就可以将句子切分开来并做好了清洗,接下来我们还需要把每个句子中的字及标注转为 Numpy 数组,便于下一步制作词表和数据集,代码实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
import re
# To numpy array
words, labels = [], []
print('Start creating words and labels...')
for sentence in sentences:
groups = re.findall('(.)/(.)', sentence)
arrays = np.asarray(groups)
words.append(arrays[:, 0])
labels.append(arrays[:, 1])
print('Words Length', len(words), 'Labels Length', len(labels))
print('Words Example', words[0])
print('Labels Example', labels[0])

这里我们利用正则 re 库的 findall() 方法将字及标注分开,并分别添加到 words 和 labels 数组中,运行效果如下:

1
2
3
Words Length 321533 Labels Length 321533
Words Example ['人' '们' '常' '说' '生' '活' '是' '一' '部' '教' '科' '书']
Labels Example ['b' 'e' 's' 's' 'b' 'e' 's' 's' 's' 'b' 'm' 'e']

接下来我们有了这些数据就要开始制作词表了,词表制作起来无非就是输入词表和输出词表的不重复的正逆对应,制作词表的目的就是将输入的文字或标注转为 index,同时还能反向根据 index 获取对应的文字或标注,所以我们这里需要制作 word2id、id2word、tag2id、id2tag 四个字典。 为了解决 OOV 问题,我们还需要将无效字符也进行标注,这里我们统一取 0。制作时我们借助于 pandas 库的 Series 进行了去重和转换,另外还限制了每一句的最大长度,这里设置为 32,如果大于32,则截断,否则进行 padding,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from itertools import chain
import pandas as pd
import numpy as np
# Merge all words
all_words = list(chain(*words))
# All words to Series
all_words_sr = pd.Series(all_words)
# Get value count, index changed to set
all_words_counts = all_words_sr.value_counts()
# Get words set
all_words_set = all_words_counts.index
# Get words ids
all_words_ids = range(1, len(all_words_set) + 1)

# Dict to transform
word2id = pd.Series(all_words_ids, index=all_words_set)
id2word = pd.Series(all_words_set, index=all_words_ids)

# Tag set and ids
tags_set = ['x', 's', 'b', 'm', 'e']
tags_ids = range(len(tags_set))

# Dict to transform
tag2id = pd.Series(tags_ids, index=tags_set)
id2tag = pd.Series(tags_set, index=tag2id)

max_length = 32

def x_transform(words):
ids = list(word2id[words])
if len(ids) >= max_length:
ids = ids[:max_length]
ids.extend([0] * (max_length - len(ids)))
return ids

def y_transform(tags):
ids = list(tag2id[tags])
if len(ids) >= max_length:
ids = ids[:max_length]
ids.extend([0] * (max_length - len(ids)))
return ids

print('Starting transform...')
data_x = list(map(lambda x: x_transform(x), words))
data_y = list(map(lambda y: y_transform(y), labels))
data_x = np.asarray(data_x)
data_y = np.asarray(data_y)

这样我们就完成了 word2id、id2word、tag2id、id2tag 四个字典的制作,并制作好了 Numpy 数组类型的 data_x 和 data_y,这里 data_x 和 data_y 单句示例如下:

1
2
Data X Example: [8, 43, 320, 88, 36, 198, 7, 2, 41, 163, 124, 245, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Data Y Example: [2, 4, 1, 1, 2, 4, 1, 1, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

可以看到数据的 x 部分,原始文字和标注结果都转化成了词表中的 index,同时不够 32 个字符就以 0 补全。 接下来我们将其保存成 pickle 文件,以备训练和测试使用:

1
2
3
4
5
6
7
8
9
print('Starting pickle to file...')
with open(join(path, 'data.pkl'), 'wb') as f:
pickle.dump(data_x, f)
pickle.dump(data_y, f)
pickle.dump(word2id, f)
pickle.dump(id2word, f)
pickle.dump(tag2id, f)
pickle.dump(id2tag, f)
print('Pickle finished')

好,现在数据预处理部分就完成了。

构造模型

接下来我们就需要利用 pickle 文件中的数据来构建模型了,首先进行 pickle 文件的读取,然后将数据分为训练集、开发集、测试集,详细流程不再赘述,赋值为如下变量:

1
2
3
4
# Load data
data_x, data_y, word2id, id2word, tag2id, id2tag = load_data()
# Split data
train_x, train_y, dev_x, dev_y, test_x, test_y = get_data(data_x, data_y)

接下来我们使用 TensorFlow 自带的 Dataset 数据结构构造输入输出,利用 Dataset 我们可以构造一个 iterator 迭代器,每调用一次 get_next() 方法,我们就可以得到一个 batch,这里 Dataset 的初始化我们使用 from_tensor_slices() 方法,然后调用其 batch() 方法来初始化每个数据集的 batch_size,接着初始化同一个 iterator,并绑定到三个数据集上声明为三个 initializer,这样每调用 initializer,就会将 iterator 切换到对应的数据集上,代码实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Train and dev dataset
train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
train_dataset = train_dataset.batch(FLAGS.train_batch_size)

dev_dataset = tf.data.Dataset.from_tensor_slices((dev_x, dev_y))
dev_dataset = dev_dataset.batch(FLAGS.dev_batch_size)

test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y))
test_dataset = test_dataset.batch(FLAGS.test_batch_size)

# A reinitializable iterator
iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)

train_initializer = iterator.make_initializer(train_dataset)
dev_initializer = iterator.make_initializer(dev_dataset)
test_initializer = iterator.make_initializer(test_dataset)

有了 Dataset 的 iterator,我们只需要调用一次 get_next() 方法即可得到 x 和 y_label 了,就不需要使用 placeholder 来声明了,代码如下:

1
2
3
# Input Layer
with tf.variable_scope('inputs'):
x, y_label = iterator.get_next()

接下来我们需要实现 embedding 层,调用 TensorFlow 的 embedding_lookup 即可实现,这里没有使用 Pre Train 的 embedding,代码实现如下:

1
2
3
4
# Embedding Layer
with tf.variable_scope('embedding'):
embedding = tf.Variable(tf.random_normal([vocab_size, FLAGS.embedding_size]), dtype=tf.float32)
inputs = tf.nn.embedding_lookup(embedding, x)

接下来我们就需要实现双向 LSTM 了,这里我们要构造一个 2 层的 Bi-LSTM 网络,实现的时候我们首先需要声明 LSTM Cell 的列表,然后调用 stack_bidirectional_rnn() 方法即可:

1
2
3
4
cell_fw = [lstm_cell(FLAGS.num_units, keep_prob) for _ in range(FLAGS.num_layer)]
cell_bw = [lstm_cell(FLAGS.num_units, keep_prob) for _ in range(FLAGS.num_layer)]
inputs = tf.unstack(inputs, FLAGS.time_step, axis=1)
output, _, _ = tf.contrib.rnn.stack_bidirectional_rnn(cell_fw, cell_bw, inputs=inputs, dtype=tf.float32)

这个方法内部是首先对每一层的 LSTM 进行正反向计算,然后对输出隐层进行 concat,然后输入下一层再进行计算,这里值得注意的地方是,我们不能把 LSTM Cell 提前组合成 MultiRNNCell 再调用 bidirectional_dynamic_rnn() 进行计算,这样相当于只有最后一层才进行 concat,是错误的。 现在我们得到的 output 就是 Bi-LSTM 的最后输出结果了。 接下来我们需要对输出结果进行一下 stack() 操作转化为一个 Tensor,然后将其 reshape() 一下,转化为 [-1, num_units * 2] 的 shape:

1
2
output = tf.stack(output, axis=1)
output = tf.reshape(output, [-1, FLAGS.num_units * 2])

这样我们再经过一层全连接网络将维度进行转换:

1
2
3
4
5
6
7
# Output Layer
with tf.variable_scope('outputs'):
w = weight([FLAGS.num_units * 2, FLAGS.category_num])
b = bias([FLAGS.category_num])
y = tf.matmul(output, w) + b
y_predict = tf.cast(tf.argmax(y, axis=1), tf.int32)
print('Output Y', y_predict)

这样得到的最后的 y_predict 即为预测结果,shape 为 [batch_size],即每一句都得到了一个最可能的结果标注。 接下来我们需要计算一下准确率和 Loss,准确率其实就是比较 y_predict 和 y_label 的相似度,Loss 即为二者交叉熵:

1
2
3
4
5
6
7
8
9
# Reshape y_label
y_label_reshape = tf.cast(tf.reshape(y_label, [-1]), tf.int32)
# Prediction
correct_prediction = tf.equal(y_predict, y_label_reshape)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# Loss
cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_label_reshape, logits=tf.cast(y, tf.float32)))
# Train
train = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(cross_entropy, global_step=global_step)

这里计算交叉熵使用的是 sparse_softmax_cross_entropy_with_logits() 方法,Optimizer 使用的是 Adam。 最后指定训练过程和测试过程即可,训练过程如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
for epoch in range(FLAGS.epoch_num):
tf.train.global_step(sess, global_step_tensor=global_step)
# Train
sess.run(train_initializer)
for step in range(int(train_steps)):
smrs, loss, acc, gstep, _ = sess.run([summaries, cross_entropy, accuracy, global_step, train], feed_dict={keep_prob: FLAGS.keep_prob})
# Print log
if step % FLAGS.steps_per_print == 0:
print('Global Step', gstep, 'Step', step, 'Train Loss', loss, 'Accuracy', acc)

if epoch % FLAGS.epochs_per_dev == 0:
# Dev
sess.run(dev_initializer)
for step in range(int(dev_steps)):
if step % FLAGS.steps_per_print == 0:
print('Dev Accuracy', sess.run(accuracy, feed_dict={keep_prob: 1}), 'Step', step)

这里训练时首先调用了 train_initializer,将 iterator 指向训练数据,这样每调用一次 get_next(),x 和 y_label 就会被赋值为训练数据的一个 batch,接下来打印输出了 Loss,Accuracy 等内容。另外对于开发集来说,每次进行验证的时候也需要重新调用 dev_initializer,这样 iterator 会再次指向开发集,这样每调用一次 get_next(),x 和 y_label 就会被赋值为开发集的一个 batch,然后进行验证。 对于测试来说,我们可以计算其准确率,然后将测试的结果输出出来,代码实现如下:

1
2
3
4
5
6
7
8
9
10
sess.run(test_initializer)
for step in range(int(test_steps)):
x_results, y_predict_results, acc = sess.run([x, y_predict, accuracy], feed_dict={keep_prob: 1})
print('Test step', step, 'Accuracy', acc)
y_predict_results = np.reshape(y_predict_results, x_results.shape)
for i in range(len(x_results)):
x_result, y_predict_result = list(filter(lambda x: x, x_results[i])), list(
filter(lambda x: x, y_predict_results[i]))
x_text, y_predict_text = ''.join(id2word[x_result].values), ''.join(id2tag[y_predict_result].values)
print(x_text, y_predict_text)

这里打印输出了当前测试的准确率,然后得到了测试结果,然后再结合词表将测试的真正结果打印出来即可。

运行结果

在训练过程中,我们需要构建模型图,然后调用训练部分的代码进行训练,输出结果类似如下:

1
2
3
4
5
6
7
8
9
Global Step 0 Step 0 Train Loss 1.67181 Accuracy 0.1475
Global Step 100 Step 100 Train Loss 0.210423 Accuracy 0.928125
Global Step 200 Step 200 Train Loss 0.208561 Accuracy 0.920625
Global Step 300 Step 300 Train Loss 0.185281 Accuracy 0.939375
Global Step 400 Step 400 Train Loss 0.186069 Accuracy 0.938125
Global Step 500 Step 500 Train Loss 0.165667 Accuracy 0.94375
Global Step 600 Step 600 Train Loss 0.201692 Accuracy 0.9275
Global Step 700 Step 700 Train Loss 0.13299 Accuracy 0.954375
...

随着训练的进行,准确率可以达到 96% 左右。 在测试阶段,输出了当前模型的准确率及真实测试输出结果,输出结果类似如下:

1
2
3
4
Test step 0 Accuracy 0.946125
据新华社北京7月9日电连日来 sbmebebmmesbes
董新辉为自己此生不能侍奉母亲而难过 bmesbebebebmmesbe
...

可见测试准确率在 95% 左右,对于测试数据,此处还输出了每句话的序列标注结果,如第一行结果中,“据”字对应的标注就是 s,代表单字成词,“新”字对应的标注是 b,代表词的起始,“华”字对应标注是 m,代表词的中间,“社”字对应的标注是 e,代表结束,这样 “据”、“新华社” 就可以被分成两个词了,可见还是有一定效果的。

结语

本节通过搭建一个 Bi-LSTM 网络实现了序列标注,并可实现分词,准确率可达到 95% 左右,但是最主要的还是学习 Bi-LSTM 的用法,本实例代码较多,部分代码已经省略,完整代码见:https://github.com/AIDeepLearning/BiLSTMWordBreaker

参考来源

  • TensorFlow入门 双端 LSTM 实现序列标注
  • 基于双向LSTM的seq2seq字标注
  • TensorFlow全新的数据读取方式:Dataset API入门教程
  • TensorFlow Importing Data

Python

Ansible简介

Ansible是由Python开发的一个运维工具,因为工作需要接触到Ansible,经常会集成一些东西到Ansible,所以对Ansible的了解越来越多。 那Ansible到底是什么呢?在我的理解中,原来需要登录到服务器上,然后执行一堆命令才能完成一些操作。而Ansible就是来代替我们去执行那些命令。并且可以通过Ansible控制多台机器,在机器上进行任务的编排和执行,在Ansible中称为playbook。 那Ansible是如何做到的呢?简单点说,就是Ansible将我们要执行的命令生成一个脚本,然后通过sftp将脚本上传到要执行命令的服务器上,然后在通过ssh协议,执行这个脚本并将执行结果返回。 那Ansible具体是怎么做到的呢?下面从模块和插件来看一下Ansible是如何完成一个模块的执行 PS:下面的分析都是在对Ansible有一些具体使用经验之后,通过阅读源代码进一步得出的执行结论,所以希望在看本文时,是建立在对Ansible有一定了解的基础上,最起码对于Ansible的一些概念有了解,例如inventory,module,playbooks等

Ansible模块

模块是Ansible执行的最小单位,可以是由Python编写,也可以是Shell编写,也可以是由其他语言编写。模块中定义了具体的操作步骤以及实际使用过程中所需要的参数 执行的脚本就是根据模块生成一个可执行的脚本。 那Ansible是怎么样将这个脚本上传到服务器上,然后执行获取结果的呢?

Ansible插件

connection插件

连接插件,根据指定的ssh参数连接指定的服务器,并切提供实际执行命令的接口

shell插件

命令插件,根据sh类型,来生成用于connection时要执行的命令

strategy插件

执行策略插件,默认情况下是线性插件,就是一个任务接着一个任务的向下执行,此插件将任务丢到执行器去执行。

action插件

动作插件,实质就是任务模块的所有动作,如果ansible的模块没有特别编写的action插件,默认情况下是normal或者async(这两个根据模块是否async来选择),normal和async中定义的就是模块的执行步骤。例如,本地创建临时文件,上传临时文件,执行脚本,删除脚本等等,如果想在所有的模块中增加一些特殊步骤,可以通过增加action插件的方式来扩展。

Ansible执行模块流程

  1. ansible命令实质是通过ansible/cli/adhoc.py来运行,同时会收集参数信息
    1. 设置Play信息,然后通过TaskQueueManager进行run,
    2. TaskQueueManager需要Inventory(节点仓库),variable_manager(收集变量),options(命令行中指定的参数),stdout_callback(回调函数)
  2. 在task_queue_manager.py中找到run中
    1. 初始化时会设置队列
    2. 会根据options,,variable_manager,passwords等信息设置成一个PlayContext信息(playbooks/playcontext.py)
    3. 设置插件(plugins)信息callback_loader(回调), strategy_loader(执行策略), module_loader(任务模块)
    4. 通过strategy_loader(strategy插件)的run(默认的strategy类型是linear,线性执行),去按照顺序执行所有的任务(执行一个模块,可能会执行多个任务)
    5. 在strategy_loader插件run之后,会判断action类型。如果是meta类型的话会单独执行(不是具体的ansible模块时),而其他模块时,会加载到队列_queue_task
    6. 在队列中会调用WorkerProcess去处理,在workerproces实际的run之后,会使用TaskExecutor进行执行
    7. 在TaskExecutor中会设置connection插件,并且根据task的类型(模块。或是include等)获取action插件,就是对应的模块,如果模块有自定义的执行,则会执行自定义的action,如果没有的会使用normal或者async,这个是根据是否是任务的async属性来决定
    8. 在Action插件中定义着执行的顺序,及具体操作,例如生成临时目录,生成临时脚本,所以要在统一的模式下,集成一些额外的处理时,可以重写Action的方法
    9. 通过Connection插件来执行Action的各个操作步骤

扩展Ansible实例

执行节点Python环境扩展

实际需求中,我们扩展的一些Ansible模块需要使用三方库,但每个节点中安装这些库有些不易于管理。ansible执行模块的实质就是在节点的python环境下执行生成的脚本,所以我们采取的方案是,指定节点上的Python环境,将局域网内一个python环境作为nfs共享。通过扩展Action插件,增加节点上挂载nfs,待执行结束后再将节点上的nfs卸载。具体实施步骤如下: 扩展代码:

重写ActionBase的execute_module方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# execute_module

from __future__ import (absolute_import, division, print_function)
__metaclass__ = type

import json
import pipes

from ansible.compat.six import text_type, iteritems

from ansible import constants as C
from ansible.errors import AnsibleError
from ansible.release import __version__

try:
from __main__ import display
except ImportError:
from ansible.utils.display import Display
display = Display()


class MagicStackBase(object):

def _mount_nfs(self, ansible_nfs_src, ansible_nfs_dest):
cmd = ['mount',ansible_nfs_src, ansible_nfs_dest]
cmd = [pipes.quote(c) for c in cmd]
cmd = ' '.join(cmd)
result = self._low_level_execute_command(cmd=cmd, sudoable=True)
return result

def _umount_nfs(self, ansible_nfs_dest):
cmd = ['umount', ansible_nfs_dest]
cmd = [pipes.quote(c) for c in cmd]
cmd = ' '.join(cmd)
result = self._low_level_execute_command(cmd=cmd, sudoable=True)
return result

def _execute_module(self, module_name=None, module_args=None, tmp=None, task_vars=None, persist_files=False, delete_remote_tmp=True):
'''
Transfer and run a module along with its arguments.
'''

# display.v(task_vars)

if task_vars is None:
task_vars = dict()

# if a module name was not specified for this execution, use
# the action from the task
if module_name is None:
module_name = self._task.action
if module_args is None:
module_args = self._task.args

# set check mode in the module arguments, if required
if self._play_context.check_mode:
if not self._supports_check_mode:
raise AnsibleError("check mode is not supported for this operation")
module_args['_ansible_check_mode'] = True
else:
module_args['_ansible_check_mode'] = False

# Get the connection user for permission checks
remote_user = task_vars.get('ansible_ssh_user') or self._play_context.remote_user

# set no log in the module arguments, if required
module_args['_ansible_no_log'] = self._play_context.no_log or C.DEFAULT_NO_TARGET_SYSLOG

# set debug in the module arguments, if required
module_args['_ansible_debug'] = C.DEFAULT_DEBUG

# let module know we are in diff mode
module_args['_ansible_diff'] = self._play_context.diff

# let module know our verbosity
module_args['_ansible_verbosity'] = display.verbosity

# give the module information about the ansible version
module_args['_ansible_version'] = __version__

# set the syslog facility to be used in the module
module_args['_ansible_syslog_facility'] = task_vars.get('ansible_syslog_facility', C.DEFAULT_SYSLOG_FACILITY)

# let module know about filesystems that selinux treats specially
module_args['_ansible_selinux_special_fs'] = C.DEFAULT_SELINUX_SPECIAL_FS

(module_style, shebang, module_data) = self._configure_module(module_name=module_name, module_args=module_args, task_vars=task_vars)
if not shebang:
raise AnsibleError("module (%s) is missing interpreter line" % module_name)

# get nfs info for mount python packages
ansible_nfs_src = task_vars.get("ansible_nfs_src", None)
ansible_nfs_dest = task_vars.get("ansible_nfs_dest", None)

# a remote tmp path may be necessary and not already created
remote_module_path = None
args_file_path = None
if not tmp and self._late_needs_tmp_path(tmp, module_style):
tmp = self._make_tmp_path(remote_user)

if tmp:
remote_module_filename = self._connection._shell.get_remote_filename(module_name)
remote_module_path = self._connection._shell.join_path(tmp, remote_module_filename)
if module_style in ['old', 'non_native_want_json']:
# we'll also need a temp file to hold our module arguments
args_file_path = self._connection._shell.join_path(tmp, 'args')

if remote_module_path or module_style != 'new':
display.debug("transferring module to remote")
self._transfer_data(remote_module_path, module_data)
if module_style == 'old':
# we need to dump the module args to a k=v string in a file on
# the remote system, which can be read and parsed by the module
args_data = ""
for k,v in iteritems(module_args):
args_data += '%s=%s ' % (k, pipes.quote(text_type(v)))
self._transfer_data(args_file_path, args_data)
elif module_style == 'non_native_want_json':
self._transfer_data(args_file_path, json.dumps(module_args))
display.debug("done transferring module to remote")

environment_string = self._compute_environment_string()

remote_files = None

if args_file_path:
remote_files = tmp, remote_module_path, args_file_path
elif remote_module_path:
remote_files = tmp, remote_module_path

# Fix permissions of the tmp path and tmp files. This should be
# called after all files have been transferred.
if remote_files:
self._fixup_perms2(remote_files, remote_user)


# mount nfs
if ansible_nfs_src and ansible_nfs_dest:
result = self._mount_nfs(ansible_nfs_src, ansible_nfs_dest)
if result['rc'] != 0:
raise AnsibleError("mount nfs failed!!! {0}".format(result['stderr']))

cmd = ""
in_data = None

if self._connection.has_pipelining and self._play_context.pipelining and not C.DEFAULT_KEEP_REMOTE_FILES and module_style == 'new':
in_data = module_data
else:
if remote_module_path:
cmd = remote_module_path

rm_tmp = None
if tmp and "tmp" in tmp and not C.DEFAULT_KEEP_REMOTE_FILES and not persist_files and delete_remote_tmp:
if not self._play_context.become or self._play_context.become_user == 'root':
# not sudoing or sudoing to root, so can cleanup files in the same step
rm_tmp = tmp

cmd = self._connection._shell.build_module_command(environment_string, shebang, cmd, arg_path=args_file_path, rm_tmp=rm_tmp)
cmd = cmd.strip()
sudoable = True
if module_name == "accelerate":
# always run the accelerate module as the user
# specified in the play, not the sudo_user
sudoable = False


res = self._low_level_execute_command(cmd, sudoable=sudoable, in_data=in_data)

# umount nfs
if ansible_nfs_src and ansible_nfs_dest:
result = self._umount_nfs(ansible_nfs_dest)
if result['rc'] != 0:
raise AnsibleError("umount nfs failed!!! {0}".format(result['stderr']))

if tmp and "tmp" in tmp and not C.DEFAULT_KEEP_REMOTE_FILES and not persist_files and delete_remote_tmp:
if self._play_context.become and self._play_context.become_user != 'root':
# not sudoing to root, so maybe can't delete files as that other user
# have to clean up temp files as original user in a second step
tmp_rm_cmd = self._connection._shell.remove(tmp, recurse=True)
tmp_rm_res = self._low_level_execute_command(tmp_rm_cmd, sudoable=False)
tmp_rm_data = self._parse_returned_data(tmp_rm_res)
if tmp_rm_data.get('rc', 0) != 0:
display.warning('Error deleting remote temporary files (rc: {0}, stderr: {1})'.format(tmp_rm_res.get('rc'), tmp_rm_res.get('stderr', 'No error string available.')))

# parse the main result
data = self._parse_returned_data(res)

# pre-split stdout into lines, if stdout is in the data and there
# isn't already a stdout_lines value there
if 'stdout' in data and 'stdout_lines' not in data:
data['stdout_lines'] = data.get('stdout', u'').splitlines()

display.debug("done with _execute_module (%s, %s)" % (module_name, module_args))
return data

集成到normal.py和async.py中,记住要将这两个插件在ansible.cfg中进行配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type

from ansible.plugins.action import ActionBase
from ansible.utils.vars import merge_hash

from common.ansible_plugins import MagicStackBase


class ActionModule(MagicStackBase, ActionBase):

def run(self, tmp=None, task_vars=None):
if task_vars is None:
task_vars = dict()

results = super(ActionModule, self).run(tmp, task_vars)
# remove as modules might hide due to nolog
del results['invocation']['module_args']
results = merge_hash(results, self._execute_module(tmp=tmp, task_vars=task_vars))
# Remove special fields from the result, which can only be set
# internally by the executor engine. We do this only here in
# the 'normal' action, as other action plugins may set this.
#
# We don't want modules to determine that running the module fires
# notify handlers. That's for the playbook to decide.
for field in ('_ansible_notify',):
if field in results:
results.pop(field)

return results
  • 配置ansible.cfg,将扩展的插件指定为ansible需要的action插件
  • 重写插件方法,重点是execute_module
  • 执行命令中需要指定Python环境,将需要的参数添加进去nfs挂载和卸载的参数
1
ansible 51 -m mysql_db -a "state=dump name=all target=/tmp/test.sql" -i hosts -u root -v -e "ansible_nfs_src=172.16.30.170:/web/proxy_env/lib64/python2.7/site-packages ansible_nfs_dest=/root/.pyenv/versions/2.7.10/lib/python2.7/site-packages ansible_python_interpreter=/root/.pyenv/versions/2.7.10/bin/python"

Python

背景

用 Python 做过爬虫的小伙伴可能接触过 Scrapy,GitHub:https://github.com/scrapy/scrapy。Scrapy 的确是一个非常强大的爬虫框架,爬取效率高,扩展性好,基本上是使用 Python 开发爬虫的必备利器。如果使用 Scrapy 做爬虫,那么在爬取时,我们当然完全可以使用自己的主机来完成爬取,但当爬取量非常大的时候,我们肯定不能在自己的机器上来运行爬虫了,一个好的方法就是将 Scrapy 部署到远程服务器上来执行。 所以,这时候就出现了另一个库 Scrapyd,GitHub:https://github.com/scrapy/scrapyd,有了它我们只需要在远程服务器上安装一个 Scrapyd,启动这个服务,就可以将我们写的 Scrapy 项目部署到远程主机上了,Scrapyd 还提供了各种操作 API,可以自由地控制 Scrapy 项目的运行,API 文档:http://scrapyd.readthedocs.io/en/stable/api.html,例如我们将 Scrapyd 安装在 IP 为 88.88.88.88 的服务器上,然后将 Scrapy 项目部署上去,这时候我们通过请求 API 就可以来控制 Scrapy 项目的运行了,命令如下:

1
curl http://88.88.88.88:6800/schedule.json -d project=myproject -d spider=somespider

这样就相当于启动了 myproject 项目的 somespider 爬虫,而不用我们再用命令行方式去启动爬虫,同时 Scrapyd 还提供了查看爬虫状态、取消爬虫任务、添加爬虫版本、删除爬虫版本等等的一系列 API,所以说,有了 Scrapyd,我们可以通过 API 来控制爬虫的运行,摆脱了命令行的依赖。 另外爬虫部署还是个麻烦事,因为我们需要将爬虫代码上传到远程服务器上,这个过程涉及到打包和上传两个过程,在 Scrapyd 中其实提供了这个部署的 API,叫做 addversion,但是它接受的内容是 egg 包文件,所以说要用这个接口,我们必须要把我们的 Scrapy 项目打包成 egg 文件,然后再利用文件上传的方式请求这个 addversion 接口才可以完成上传,这个过程又比较繁琐了,所以又出现了一个工具叫做 Scrapyd-Client,GitHub:https://github.com/scrapy/scrapyd-client,利用它的 scrapyd-deploy 命令我们便可以完成打包和上传的两个功能,可谓是又方便了一步。 这样我们就已经解决了部署的问题,回过头来,如果我们要想实时查看服务器上 Scrapy 的运行状态,那该怎么办呢?像刚才说的,当然是请求 Scrapyd 的 API 了,如果我们想用 Python 程序来控制一下呢?我们还要用 requests 库一次次地请求这些 API ?这就太麻烦了吧,所以为了解决这个需求,Scrapyd-API 又出现了,GitHub:https://github.com/djm/python-scrapyd-api,有了它我们可以只用简单的 Python 代码就可以实现 Scrapy 项目的监控和运行:

1
2
3
from scrapyd_api import ScrapydAPI
scrapyd = ScrapydAPI('http://88.888.88.88:6800')
scrapyd.list_jobs('project_name')

这样它的返回结果就是各个 Scrapy 项目的运行情况。 例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
{
'pending': [
],
'running': [
{
'id': u'14a65...b27ce',
'spider': u'spider_name',
'start_time': u'2018-01-17 22:45:31.975358'
},
],
'finished': [
{
'id': '34c23...b21ba',
'spider': 'spider_name',
'start_time': '2018-01-11 22:45:31.975358',
'end_time': '2018-01-17 14:01:18.209680'
}
]
}

这样我们就可以看到 Scrapy 爬虫的运行状态了。 所以,有了它们,我们可以完成的是:

  • 通过 Scrapyd 完成 Scrapy 项目的部署
  • 通过 Scrapyd 提供的 API 来控制 Scrapy 项目的启动及状态监控
  • 通过 Scrapyd-Client 来简化 Scrapy 项目的部署
  • 通过 Scrapyd-API 来通过 Python 控制 Scrapy 项目

是不是方便多了? 可是?真的达到最方便了吗?肯定没有!如果这一切的一切,从 Scrapy 的部署、启动到监控、日志查看,我们只需要鼠标键盘点几下就可以完成,那岂不是美滋滋?更或者说,连 Scrapy 代码都可以帮你自动生成,那岂不是爽爆了? 有需求就有动力,没错,Gerapy 就是为此而生的,GitHub:https://github.com/Gerapy/Gerapy。 本节我们就来简单了解一下 Gerapy 分布式爬虫管理框架的使用方法。

安装

Gerapy 是一款分布式爬虫管理框架,支持 Python 3,基于 Scrapy、Scrapyd、Scrapyd-Client、Scrapy-Redis、Scrapyd-API、Scrapy-Splash、Jinjia2、Django、Vue.js 开发,Gerapy 可以帮助我们:

  • 更方便地控制爬虫运行
  • 更直观地查看爬虫状态
  • 更实时地查看爬取结果
  • 更简单地实现项目部署
  • 更统一地实现主机管理
  • 更轻松地编写爬虫代码

安装非常简单,只需要运行 pip3 命令即可:

1
$ pip3 install gerapy

安装完成之后我们就可以使用 gerapy 命令了,输入 gerapy 便可以获取它的基本使用方法:

1
2
3
4
5
6
7
$ gerapy
Usage:
gerapy init [--folder=<folder>]
gerapy migrate
gerapy createsuperuser
gerapy runserver [<host:port>]
gerapy makemigrations

如果出现上述结果,就证明 Gerapy 安装成功了。

初始化

接下来我们来开始使用 Gerapy,首先利用如下命令进行一下初始化,在任意路径下均可执行如下命令:

1
$ gerapy init

执行完毕之后,本地便会生成一个名字为 gerapy 的文件夹,接着进入该文件夹,可以看到有一个 projects 文件夹,我们后面会用到。 紧接着执行数据库初始化命令:

1
2
cd gerapy
gerapy migrate

这样它就会在 gerapy 目录下生成一个 SQLite 数据库,同时建立数据库表。 接着我们只需要再运行命令启动服务就好了:

1
gerapy runserver

这样我们就可以看到 Gerapy 已经在 8000 端口上运行了。 全部的操作流程截图如下: 接下来我们在浏览器中打开 http://localhost:8000/,就可以看到 Gerapy 的主界面了: 这里显示了主机、项目的状态,当然由于我们没有添加主机,所以所有的数目都是 0。 如果我们可以正常访问这个页面,那就证明 Gerapy 初始化都成功了。

主机管理

接下来我们可以点击左侧 Clients 选项卡,即主机管理页面,添加我们的 Scrapyd 远程服务,点击右上角的创建按钮即可添加我们需要管理的 Scrapyd 服务: 需要添加 IP、端口,以及名称,点击创建即可完成添加,点击返回即可看到当前添加的 Scrapyd 服务列表,样例如下所示: 这样我们可以在状态一栏看到各个 Scrapyd 服务是否可用,同时可以一目了然当前所有 Scrapyd 服务列表,另外我们还可以自由地进行编辑和删除。

项目管理

Gerapy 的核心功能当然是项目管理,在这里我们可以自由地配置、编辑、部署我们的 Scrapy 项目,点击左侧的 Projects ,即项目管理选项,我们可以看到如下空白的页面: 假设现在我们有一个 Scrapy 项目,如果我们想要进行管理和部署,还记得初始化过程中提到的 projects 文件夹吗?这时我们只需要将项目拖动到刚才 gerapy 运行目录的 projects 文件夹下,例如我这里写好了一个 Scrapy 项目,名字叫做 zhihusite,这时把它拖动到 projects 文件夹下: 这时刷新页面,我们便可以看到 Gerapy 检测到了这个项目,同时它是不可配置、没有打包的: 这时我们可以点击部署按钮进行打包和部署,在右下角我们可以输入打包时的描述信息,类似于 Git 的 commit 信息,然后点击打包按钮,即可发现 Gerapy 会提示打包成功,同时在左侧显示打包的结果和打包名称: 打包成功之后,我们便可以进行部署了,我们可以选择需要部署的主机,点击后方的部署按钮进行部署,同时也可以批量选择主机进行部署,示例如下: 可以发现此方法相比 Scrapyd-Client 的命令行式部署,简直不能方便更多。

监控任务

部署完毕之后就可以回到主机管理页面进行任务调度了,任选一台主机,点击调度按钮即可进入任务管理页面,此页面可以查看当前 Scrapyd 服务的所有项目、所有爬虫及运行状态: 我们可以通过点击新任务、停止等按钮来实现任务的启动和停止等操作,同时也可以通过展开任务条目查看日志详情: 另外我们还可以随时点击停止按钮来取消 Scrapy 任务的运行。 这样我们就可以在此页面方便地管理每个 Scrapyd 服务上的 每个 Scrapy 项目的运行了。

项目编辑

同时 Gerapy 还支持项目编辑功能,有了它我们不再需要 IDE 即可完成项目的编写,我们点击项目的编辑按钮即可进入到编辑页面,如图所示: 这样即使 Gerapy 部署在远程的服务器上,我们不方便用 IDE 打开,也不喜欢用 Vim 等编辑软件,我们可以借助于本功能方便地完成代码的编写。

代码生成

上述的项目主要针对的是我们已经写好的 Scrapy 项目,我们可以借助于 Gerapy 方便地完成编辑、部署、控制、监测等功能,而且这些项目的一些逻辑、配置都是已经写死在代码里面的,如果要修改的话,需要直接修改代码,即这些项目都是不可配置的。 在 Scrapy 中,其实提供了一个可配置化的爬虫 CrawlSpider,它可以利用一些规则来完成爬取规则和解析规则的配置,这样可配置化程度就非常高,这样我们只需要维护爬取规则、提取逻辑就可以了。如果要新增一个爬虫,我们只需要写好对应的规则即可,这类爬虫就叫做可配置化爬虫。 Gerapy 可以做到:我们写好爬虫规则,它帮我们自动生成 Scrapy 项目代码。 我们可以点击项目页面的右上角的创建按钮,增加一个可配置化爬虫,接着我们便可以在此处添加提取实体、爬取规则、抽取规则了,例如这里的解析器,我们可以配置解析成为哪个实体,每个字段使用怎样的解析方式,如 XPath 或 CSS 解析器、直接获取属性、直接添加值等多重方式,另外还可以指定处理器进行数据清洗,或直接指定正则表达式进行解析等等,通过这些流程我们可以做到任何字段的解析。 再比如爬取规则,我们可以指定从哪个链接开始爬取,允许爬取的域名是什么,该链接提取哪些跟进的链接,用什么解析方法来处理等等配置。通过这些配置,我们可以完成爬取规则的设置。 最后点击生成按钮即可完成代码的生成。 生成的代码示例结果如图所示,可见其结构和 Scrapy 代码是完全一致的。 生成代码之后,我们只需要像上述流程一样,把项目进行部署、启动就好了,不需要我们写任何一行代码,即可完成爬虫的编写、部署、控制、监测。

结语

以上便是 Gerapy 分布式爬虫管理框架的基本用法,如需了解更多,可以访问其 GitHub:https://github.com/Gerapy/Gerapy。 如果觉得此框架有不足的地方,欢迎提 Issue,也欢迎发 Pull Request 来贡献代码,如果觉得 Gerapy 有所帮助,还望赐予一个 Star!非常感谢!

Python

本节来介绍一下使用 RNN 的 LSTM 来做 MNIST 分类的方法,RNN 相比 CNN 来说,速度可能会慢,但可以节省更多的内存空间。

初始化

首先我们可以先初始化一些变量,如学习率、节点单元数、RNN 层数等:

1
2
3
4
5
6
7
8
9
10
11
learning_rate = 1e-3
num_units = 256
num_layer = 3
input_size = 28
time_step = 28
total_steps = 2000
category_num = 10
steps_per_validate = 100
steps_per_test = 500
batch_size = tf.placeholder(tf.int32, [])
keep_prob = tf.placeholder(tf.float32, [])

然后还需要声明一下 MNIST 数据生成器:

1
2
3
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

接下来常规声明一下输入的数据,输入数据用 x 表示,标注数据用 y_label 表示:

1
2
x = tf.placeholder(tf.float32, [None, 784])
y_label = tf.placeholder(tf.float32, [None, 10])

这里输入的 x 维度是 [None, 784],代表 batch_size 不确定,输入维度 784,y_label 同理。 接下来我们需要对输入的 x 进行 reshape 操作,因为我们需要将一张图分为多个 time_step 来输入,这样才能构建一个 RNN 序列,所以这里直接将 time_step 设成 28,这样一来 input_size 就变为了 28,batch_size 不变,所以reshape 的结果是一个三维的矩阵:

1
x_shape = tf.reshape(x, [-1, time_step, input_size])

RNN 层

接下来我们需要构建一个 RNN 模型了,这里我们使用的 RNN Cell 是 LSTMCell,而且要搭建一个三层的 RNN,所以这里还需要用到 MultiRNNCell,它的输入参数是 LSTMCell 的列表。 所以我们可以先声明一个方法用于创建 LSTMCell,方法如下:

1
2
3
def cell(num_units):
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=num_units)
return DropoutWrapper(cell, output_keep_prob=keep_prob)

这里还加入了 Dropout,来减少训练过程中的过拟合。 接下来我们再利用它来构建多层的 RNN:

1
cells = tf.nn.rnn_cell.MultiRNNCell([cell(num_units) for _ in range(num_layer)])

注意这里使用了 for 循环,每循环一次新生成一个 LSTMCell,而不是直接使用乘法来扩展列表,因为这样会导致 LSTMCell 是同一个对象,导致构建完 MultiRNNCell 之后出现维度不匹配的问题。 接下来我们需要声明一个初始状态:

1
h0 = cells.zero_state(batch_size, dtype=tf.float32)

然后接下来调用 dynamic_rnn() 方法即可完成模型的构建了:

1
output, hs = tf.nn.dynamic_rnn(cells, inputs=x_shape, initial_state=h0)

这里 inputs 的输入就是 x 做了 reshape 之后的结果,初始状态通过 initial_state 传入,其返回结果有两个,一个 output 是所有 time_step 的输出结果,赋值为 output,它是三维的,第一维长度等于 batch_size,第二维长度等于 time_step,第三维长度等于 num_units。另一个 hs 是隐含状态,是元组形式,长度即 RNN 的层数 3,每一个元素都包含了 c 和 h,即 LSTM 的两个隐含状态。 这样的话 output 的最终结果可以取最后一个 time_step 的结果,所以可以使用:

1
output = output[:, -1, :]

或者直接取隐藏状态最后一层的 h 也是相同的:

1
h = hs[-1].h

在此模型中,二者是等价的。但注意如果用于文本处理,可能由于文本长度不一,而 padding,导致二者不同。

输出层

接下来我们再做一次线性变换和 Softmax 输出结果即可:

1
2
3
4
5
6
# Output Layer
w = tf.Variable(tf.truncated_normal([num_units, category_num], stddev=0.1), dtype=tf.float32)
b = tf.Variable(tf.constant(0.1, shape=[category_num]), dtype=tf.float32)
y = tf.matmul(output, w) + b
# Loss
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_label, logits=y)

这里的 Loss 直接调用了 softmax_cross_entropy_with_logits 先计算了 Softmax,然后计算了交叉熵。

训练和评估

最后再定义训练和评估的流程即可,在训练过程中每隔一定的 step 就输出 Train Accuracy 和 Test Accuracy:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Train
train = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cross_entropy)

# Prediction
correction_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(y_label, axis=1))
accuracy = tf.reduce_mean(tf.cast(correction_prediction, tf.float32))

# Train
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for step in range(total_steps + 1):
batch_x, batch_y = mnist.train.next_batch(100)
sess.run(train, feed_dict={x: batch_x, y_label: batch_y, keep_prob: 0.5, batch_size: batch_x.shape[0]})
# Train Accuracy
if step % steps_per_validate == 0:
print('Train', step, sess.run(accuracy, feed_dict={x: batch_x, y_label: batch_y, keep_prob: 0.5,
batch_size: batch_x.shape[0]}))
# Test Accuracy
if step % steps_per_test == 0:
test_x, test_y = mnist.test.images, mnist.test.labels
print('Test', step,
sess.run(accuracy, feed_dict={x: test_x, y_label: test_y, keep_prob: 1, batch_size: test_x.shape[0]}))

运行

直接运行之后,只训练了几轮就可以达到 98% 的准确率:

1
2
3
4
5
6
7
8
9
10
11
Train 0 0.27
Test 0 0.2223
Train 100 0.87
Train 200 0.91
Train 300 0.94
Train 400 0.94
Train 500 0.99
Test 500 0.9595
Train 600 0.95
Train 700 0.97
Train 800 0.98

可以看出来 LSTM 在做 MNIST 字符分类的任务上还是比较有效的。

本节代码

本节代码地址为:https://github.com/AIDeepLearning/LSTMClassification

Python

本文介绍下 RNN 及几种变种的结构和对应的 TensorFlow 源码实现,另外通过简单的实例来实现 TensorFlow RNN 相关类的调用。

RNN

RNN,循环神经网络,Recurrent Neural Networks。人们思考问题往往不是从零开始的,比如阅读时我们对每个词的理解都会依赖于前面看到的一些信息,而不是把前面看的内容全部抛弃再去理解某处的信息。应用到深度学习上面,如果我们想要学习去理解一些依赖上文的信息,RNN 便可以做到,它有一个循环的操作,可以使其可以保留之前学习到的内容。 RNN 的结构如下: 在上图网络结构中,对于矩形块 A 的那部分,通过输入xt(t时刻的特征向量),它会输出一个结果ht(t时刻的状态或者输出)。网络中的循环结构使得某个时刻的状态能够传到下一个时刻。 这些循环的结构让 RNNs 看起来有些难以理解,但我们可以把 RNNs 看成是一个普通的网络做了多次复制后叠加在一起组成的,每一网络会把它的输出传递到下一个网络中。我们可以把 RNNs 在时间步上进行展开,就得到下图这样: 所以最基本的 RNN Cell 输入就是 xt,它还会输出一个隐含内容传递到下一个 Cell,同时还会生成一个结果 ht,其最基本的结构如如下: 仅仅是输入的 xt 和隐藏状态进行 concat,然后经过线性变换后经过一个 tanh 激活函数便输出了,另外隐含内容和输出结果是相同的内容。 我们来分析一下 TensorFlow 里面 RNN Cell 的实现。 TensorFlow 实现 RNN Cell 的位置在 python/ops/rnncellimpl.py,首先其实现了一个 RNNCell 类,继承了 Layer 类,其内部有三个比较重要的方法,state_size()、output_size()、__call() 方法,其中 state_size() 和 output_size() 方法设置为类属性,可以当做属性来调用,实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
@property
def state_size(self):
"""size(s) of state(s) used by this cell.
It can be represented by an Integer, a TensorShape or a tuple of Integers
or TensorShapes.
"""
raise NotImplementedError("Abstract method")

@property
def output_size(self):
"""Integer or TensorShape: size of outputs produced by this cell."""
raise NotImplementedError("Abstract method")

分别代表 Cell 的状态和输出维度,和 Cell 中的神经元数量有关,但这里两个方法都没有实现,意思是说我们必须要实现一个子类继承 RNNCell 类并实现这两个方法。 另外对于 call() 方法,实际上就是当初始化的对象直接被调用的时候触发的方法,实现如下:

1
2
3
4
5
6
7
8
9
def __call__(self, inputs, state, scope=None):
if scope is not None:
with vs.variable_scope(scope,
custom_getter=self._rnn_get_variable) as scope:
return super(RNNCell, self).__call__(inputs, state, scope=scope)
else:
with vs.variable_scope(vs.get_variable_scope(),
custom_getter=self._rnn_get_variable):
return super(RNNCell, self).__call__(inputs, state)

实际上是调用了父类 Layer 的 call() 方法,但父类中 call() 方法中又调用了 call() 方法,而 Layer 类的 call() 方法的实现如下:

1
2
def call(self, inputs, **kwargs):
return inputs

父类的 call() 方法实现非常简单,所以要实现其真正的功能,只需要在继承 RNNCell 类的子类中实现 call() 方法即可。 接下来我们看下 RNN Cell 的最基本的实现,叫做 BasicRNNCell,其代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class BasicRNNCell(RNNCell):
"""The most basic RNN cell.
Args:
num_units: int, The number of units in the RNN cell.
activation: Nonlinearity to use. Default: `tanh`.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
"""

def __init__(self, num_units, activation=None, reuse=None):
super(BasicRNNCell, self).__init__(_reuse=reuse)
self._num_units = num_units
self._activation = activation or math_ops.tanh
self._linear = None

@property
def state_size(self):
return self._num_units

@property
def output_size(self):
return self._num_units

def call(self, inputs, state):
"""Most basic RNN: output = new_state = act(W * input + U * state + B)."""
if self._linear is None:
self._linear = _Linear([inputs, state], self._num_units, True)

output = self._activation(self._linear([inputs, state]))
return output, output

可以看到在初始化的时候,最终要的一个参数是 numunits,意思就是这个 Cell 中神经元的个数,另外还有一个参数 activation 即默认使用的激活函数,默认使用的 tanh,reuse 代表该 Cell 是否可以被重新使用。 在 statesize()、output_size() 方法里,其返回的内容都是 num_units,即神经元的个数,接下来 call() 方法中,传入的参数为 inputs 和 state,即输入的 x 和 上一次的隐含状态,首先实例化了一个 _Linear 类,这个类实际上就是做线性变换的类,将二者传递过来,然后直接调用,就实现了 w * [inputs, state] + b 的线性变换,其中 _Linear 类的 __call() 方法实现如下:

1
2
3
4
5
6
7
8
9
10
def __call__(self, args):
if not self._is_sequence:
args = [args]
if len(args) == 1:
res = math_ops.matmul(args[0], self._weights)
else:
res = math_ops.matmul(array_ops.concat(args, 1), self._weights)
if self._build_bias:
res = nn_ops.bias_add(res, self._biases)
return res

很明显这里传递了 [inputs, state] 作为 call() 方法的 args,会执行 concat() 和 matmul() 方法,然后接着再执行 bias_add() 方法,这样就实现了线性变换。 最后回到 BasicRNNCell 的 call() 方法中,在 _linear() 方法外面又包括了一层 _activation() 方法,即对线性变换应用一次 tanh 激活函数处理,作为输出结果。 最后返回的结果是 output 和 output,第一个代表 output,第二个代表隐状态,其值也等于 output。 我们用一个实例来感受一下:

1
2
3
4
5
6
7
8
9
import tensorflow as tf

cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
print(cell.state_size)
inputs = tf.placeholder(tf.float32, shape=[32, 100])
h0 = cell.zero_state(32, tf.float32)
output, h1 = cell(inputs=inputs, state=h0)
print(output, output.shape)
print(h1, h1.shape)

这里我们首先初始化了一个神经元个数为 128 的 BasicRNNCell 类,然后构造了一个 shape 为 [32, 100] 的变量作为 inputs,其代表 batch_size 为 32, 维度为 100,随后初始化了初始隐藏状态,调用了 zero_state() 方法,然后直接调用 cell,实际上是最终调用了其 call() 方法,最后得到 output 和 h1,打印输出结果:

1
2
3
128
Tensor("basic_rnn_cell/Tanh:0", shape=(32, 128), dtype=float32) (32, 128)
Tensor("basic_rnn_cell/Tanh:0", shape=(32, 128), dtype=float32) (32, 128)

可以看到,当输入变量维度为 100 的时候,经过一个 128 神经元 Cell 之后,输出维度变成了 128,其输出 shape 变成了 [32, 128],且此时输出结果和隐藏状态是相同的。

LSTM

RNNs 的出现,主要是因为它们能够把以前的信息联系到现在,从而解决现在的问题。比如,利用前面的信息,能够帮助我们理解当前的内容。 有时候,我们在处理当前任务的时候,只需要看一下比较近的一些信息。比如在一个语言模型中,我们要通过上文来预测一下个词可能是什么,那么当我们看到 “the clouds are in the?”时,不需要更多的信息,我们就能够自然而然的想到下一个词应该是“sky”。在这样的情况下,我们所要预测的内容和相关信息之间的间隔很小,这种情况下 RNNs 就能够利用过去的信息, 很容易实现: 但是如果我们想依赖前文距离非常远的信息时,普通的 RNN 就非常难以做到了,随着间隔信息的增大,RNN 难以对其做关联: 但是 LSTM 可以用来解决这个问题。 LSTM,Long Short Term Memory Networks,是 RNN 的一个变种,经试验它可以用来解决更多问题,并取得了非常好的效果。 LSTM Cell 的结构如下: LSTMs 最关键的地方在于 Cell 的状态 和 结构图上面的那条横穿的水平线。 Cell 状态的传输就像一条传送带,向量从整个 Cell 中穿过,只是做了少量的线性操作。这种结构能够很轻松地实现信息从整个 Cell 中穿过而不做改变。 若只有上面的那条水平线是没办法实现添加或者删除信息的,信息的操作是是通过一种叫做门的结构来实现的。 这里我们可以把门分为三个:遗忘门(Forget Gate)、传入门(Input Gate)、输出门(Output Gate)。

遗忘门(Forget Gate)

首先是 LSTM 要决定让那些信息继续通过这个 Cell,这是通过 Forget Gate 的 sigmoid 神经层来实现的。它的输入是ht−1和xt,输出是一个数值都在 0,1 之间的向量,表示让 Ct−1 的各部分信息通过的比重。 0 表示“不让任何信息通过”, 1 表示“让所有信息通过”。

传入门(Input Gate)

下一步是决定让多少新的信息加入到 Cell 中来,一个叫做 Input Gate 的 sigmoid 层决定哪些信息需要更新,一个 New Input 通过 tanh 生成一个向量,也就是备选的用来更新的内容,Ct~ 。在下一步,我们把这两部分联合起来,对 Cell 的状态进行一个更新。 在经过 Forget Gate 和 Input Gate 处理后,我们就可以对输入的 Ct-1 做更新了,即把Ct−1 更新为 Ct,首先我们把旧的状态 Ct−1 和 ft 相乘, 把一些不想保留的信息忘掉。然后加上 it∗Ct~,这部分信息就是我们要添加的新内容,这样就可以完成对 Ct-1 的更新。

输出门 (Output Gate)

最后我们需要来决定输出什么值,输出主要是依赖于 Cell 的状态 Ct,但是又不仅仅依赖于 Ct,而是需要经过一个过滤的处理。首先,我们还是使用一个 sigmoid 层来决定 Ct 中的哪部分信息会被输出。然后我们把 Ct 通过一个 tanh 激活函数处理,然后把其输出和 sigmoid 计算出来的权重相乘,这样就得到了最后输出的结果。 到了最后,其输出结果有三个内容,其中输出结果就是最上面的箭头代指的内容,即最终计算的结果,隐层包括两部分内容,一个是 Ct,一个是最下方的 ht,我们可以将其合并为一个变量来表示。 接下来我们来看下 LSTMCell 的 TensorFlow 代码实现。 首先它的类是 BasicLSTMCell 类,继承了 RNNCell 类,其初始化方法 init() 实现如下:

1
2
3
4
5
6
7
8
9
10
11
def __init__(self, num_units, forget_bias=1.0,
state_is_tuple=True, activation=None, reuse=None):
super(BasicLSTMCell, self).__init__(_reuse=reuse)
if not state_is_tuple:
logging.warn("%s: Using a concatenated state is slower and will soon be "
"deprecated. Use state_is_tuple=True.", self)
self._num_units = num_units
self._forget_bias = forget_bias
self._state_is_tuple = state_is_tuple
self._activation = activation or math_ops.tanh
self._linear = None

这里必须传入的参数仍然是 num_units,即神经元的个数,然后 forget_bias 是初始化 Forget Gate 的偏置大小,state_is_tuple 指的是输出状态类型是元组类型,activation 代表默认激活函数,reuse 代表是否可以被重复使用。 接下来看下 state_size() 方法和 output_size() 方法,实现如下:

1
2
3
4
5
6
7
8
@property
def state_size(self):
return (LSTMStateTuple(self._num_units, self._num_units)
if self._state_is_tuple else 2 * self._num_units)

@property
def output_size(self):
return self._num_units

这里 state_size() 方法变了,因为输出的 state 需要将 Ct 和隐含状态合并,所以它需要包含两部分的内容,如果传入的参数 state_is_tuple 为 True 的话,状态会被表示成一个元组,否则会是 num_units 乘以 2 的数字,默认是元组形式。output_size() 方法则保持不变。 对于 call() 方法,其实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def call(self, inputs, state):
"""Long short-term memory cell (LSTM).

Args:
inputs: `2-D` tensor with shape `[batch_size x input_size]`.
state: An `LSTMStateTuple` of state tensors, each shaped
`[batch_size x self.state_size]`, if `state_is_tuple` has been set to
`True`. Otherwise, a `Tensor` shaped
`[batch_size x 2 * self.state_size]`.

Returns:
A pair containing the new hidden state, and the new state (either a
`LSTMStateTuple` or a concatenated state, depending on
`state_is_tuple`).
"""
sigmoid = math_ops.sigmoid
# Parameters of gates are concatenated into one multiply for efficiency.
if self._state_is_tuple:
c, h = state
else:
c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

if self._linear is None:
self._linear = _Linear([inputs, h], 4 * self._num_units, True)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(
value=self._linear([inputs, h]), num_or_size_splits=4, axis=1)

new_c = (
c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
new_h = self._activation(new_c) * sigmoid(o)

if self._state_is_tuple:
new_state = LSTMStateTuple(new_c, new_h)
else:
new_state = array_ops.concat([new_c, new_h], 1)
return new_h, new_state

首先为了获取 c, h,需要将其从 state 中分离开来,如果传入的 state 是元组的话可以直接分解,否则需要调用 split() 方法来分解:

1
2
3
4
if self._state_is_tuple:
c, h = state
else:
c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

接下来定义了几个门的实现:

1
i, j, f, o = array_ops.split(value=self._linear([inputs, h]), num_or_size_splits=4, axis=1)

放到一起来用 Linear 计算然后分成了 4 份,分别代表 Input Gate、New Input、Forget Gate、Output Gate,用 i、j、f、o 来表示,这时候四个变量都经过了线性变换,乘以权重并做了偏置操作。 接下来就是更新 Ct-1 为 Ct 和得到隐含状态输出了,都是遵循 LSTM 内部的公式实现:

1
2
new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
new_h = self._activation(new_c) * sigmoid(o)

这里值得注意的是还多加了一个 _forget_bias 变量,即设置了初始化偏置,以免初始输出为 0 的问题。 最后将 new_c 和 new_h 进行合并,如果要输出元组,那么就合并为元组,否则二者进行 concat 操作,返回的结果是 new_h、new_state,前者即 Cell 的输出结果,后者代表隐含状态:

1
2
3
4
5
if self._state_is_tuple:
new_state = LSTMStateTuple(new_c, new_h)
else:
new_state = array_ops.concat([new_c, new_h], 1)
return new_h, new_state

我们再用一个实例来感受一下 BasicLSTMCell 的用法:

1
2
3
4
5
6
7
8
9
10
11
import tensorflow as tf

cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
print(cell.state_size)
inputs = tf.placeholder(tf.float32, shape=(32, 100))
h0 = cell.zero_state(32, tf.float32)
output, h1 = cell(inputs=inputs, state=h0)
print(h1)
print(h1.h, h1.h.shape)
print(h1.c, h1.c.shape)
print(output, output.shape)

这里我们首先初始化了一个神经元个数为 128 的 BasicRNNCell 类,然后构造了一个 shape 为 [32, 100] 的变量作为 inputs,其代表 batch_size 为 32, 维度为 100,随后初始化了初始隐藏状态,调用了 zero_state() 方法,然后直接调用 cell,实际上是最终调用了其 call() 方法,最后得到 output 和 h1,此时 h1 是一个元组,它还可以分离成 h 和 c,分别打印其对象和维度,结果如下:

1
2
3
4
5
LSTMStateTuple(c=128, h=128)
LSTMStateTuple(c=<tf.Tensor 'add_1:0' shape=(32, 128) dtype=float32>, h=<tf.Tensor 'mul_2:0' shape=(32, 128) dtype=float32>)
Tensor("mul_2:0", shape=(32, 128), dtype=float32) (32, 128)
Tensor("add_1:0", shape=(32, 128), dtype=float32) (32, 128)
Tensor("mul_2:0", shape=(32, 128), dtype=float32) (32, 128)

可以看到其维度都是 [32, 128],而且 h1.h 和 output 是相同的。 另外 LSTM 有许多变种,其中一个比较有名的就是 Gers & Schmidhuber (2000) 提出的,它在原来的基础上行添加了 Peephole Connections,使得遗忘门可以受 Ct-1 的影响。 另外还有一个变种就是将 Forget Gate 和 Input Gate 二者联合起来,做到要么遗忘老的输入新的,要么保留老的不输入新的。 但接下来还有一个更常用的变种,俺就是 GRU,它是由 Cho, et al. (2014) 提出的,在提出的同时他还提出了 Seq2Seq 模型,为 Generation Model 做好了铺垫。

GRU

GRU,Gated Recurrent Unit,在 GRU 中,只有两个门:重置门(Reset Gate)和更新门(Update Gate)。同时在这个结构中,把 Ct 和隐藏状态进行了合并,整体结构比标准的 LSTM 结构要简单,而且这个结构后来也非常流行。 接下来我们看下 TensorFlow 中 GRUCell 的实现,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
class GRUCell(RNNCell):
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).

Args:
num_units: int, The number of units in the GRU cell.
activation: Nonlinearity to use. Default: `tanh`.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
kernel_initializer: (optional) The initializer to use for the weight and
projection matrices.
bias_initializer: (optional) The initializer to use for the bias.
"""

def __init__(self,
num_units,
activation=None,
reuse=None,
kernel_initializer=None,
bias_initializer=None):
super(GRUCell, self).__init__(_reuse=reuse)
self._num_units = num_units
self._activation = activation or math_ops.tanh
self._kernel_initializer = kernel_initializer
self._bias_initializer = bias_initializer
self._gate_linear = None
self._candidate_linear = None

@property
def state_size(self):
return self._num_units

@property
def output_size(self):
return self._num_units

def call(self, inputs, state):
"""Gated recurrent unit (GRU) with nunits cells."""
if self._gate_linear is None:
bias_ones = self._bias_initializer
if self._bias_initializer is None:
bias_ones = init_ops.constant_initializer(1.0, dtype=inputs.dtype)
with vs.variable_scope("gates"): # Reset gate and update gate.
self._gate_linear = _Linear(
[inputs, state],
2 * self._num_units,
True,
bias_initializer=bias_ones,
kernel_initializer=self._kernel_initializer)

value = math_ops.sigmoid(self._gate_linear([inputs, state]))
r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)

r_state = r * state
if self._candidate_linear is None:
with vs.variable_scope("candidate"):
self._candidate_linear = _Linear(
[inputs, r_state],
self._num_units,
True,
bias_initializer=self._bias_initializer,
kernel_initializer=self._kernel_initializer)
c = self._activation(self._candidate_linear([inputs, r_state]))
new_h = u * state + (1 - u) * c
return new_h, new_h

在 state_size()、output_size() 方法里,其返回的内容都是 num_units,即神经元的个数。 接下来 call() 方法中,因为 Reset Gate rt 和 Update Gate zt 分别用变量 r、u 表示,它们需要先对 ht-1 即 state 和 xt 做合并,然后再实现线性变换,再调用 sigmod 函数得到:

1
2
value = math_ops.sigmoid(self._gate_linear([inputs, state]))
r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)

然后需要求解 ht~,首先用 rt 和 ht-1 即 state 相乘:

1
r_state = r * state

然后将其放到线性函数里面,在调用 tanh 激活函数即可:

1
c = self._activation(self._candidate_linear([inputs, r_state]))

最后计算隐含状态和输出结果,二者一致:

1
2
new_h = u * state + (1 - u) * c
return new_h, new_h

这样即可返回得到输出结果和隐藏状态。 我们用一个实例感受一下:

1
2
3
4
5
6
7
8
9
import tensorflow as tf

cell = tf.nn.rnn_cell.GRUCell(num_units=128)
print(cell.state_size)
inputs = tf.placeholder(tf.float32, shape=[32, 100])
h0 = cell.zero_state(32, tf.float32)
output, h1 = cell(inputs=inputs, state=h0)
print(output, output.shape)
print(h1, h1.shape)

运行结果:

1
2
3
128
Tensor("gru_cell/add:0", shape=(32, 128), dtype=float32) (32, 128)
Tensor("gru_cell/add:0", shape=(32, 128), dtype=float32) (32, 128)

这个结果和 BasicRNNCell 并无二致,但 GRUCell 内部的结构使模型的效果更加优化,一般我们也会选取 GRUCell 来代替原生的 BasicRNNCell。

结语

以上便是对 RNN 及一些变种的说明及代码原理分析和实例用法,此部分掌握之后对 Dynamic RNN、多层 RNN 及 RNN Cell 的改写会有很大帮助,需要好好掌握。

Python

上一节使用了最简单的网络来处理了 MNIST 数据集,但只有 92% 的正确率,接下来我们使用卷积神经网络来实现更高的正确率。

权重初始化

在上一节初始化 w 和 b 的时候,我们使用了置零初始化。但在卷积神经网络中,我们需要在初始化的时候权重加入少量噪声来打破对称性和避免零梯度,偏置项直接使用一个较小的正数来避免节点输出恒为零的问题。 所以权重我们可以使用截尾正态分布函数 truncated_normal() 来生成初始化张量,我们可以给它指定均值或标准差,均值默认是 0, 标准差默认是 1,例如我们生成一个 [10] 的张量,代码如下:

1
2
3
4
import tensorflow as tf
initial = tf.truncated_normal([10], stddev=0.1)
with tf.Session() as sess:
print(sess.run(initial))

结果如下:

1
2
[-0.13058113  0.03201858 -0.19349943 -0.06061752 -0.10267895 -0.11079147
0.1881365 -0.01057311 -0.02797078 0.01180232]

另外 constant() 方法是用于生成常量的方法,例如生成一个 [10] 的常量张量,代码如下:

1
2
3
4
import tensorflow as tf
initial = tf.constant(0.2, shape=[10])
with tf.Session() as sess:
print(sess.run(initial))

结果如下:

1
[ 0.2  0.2  0.2  0.2  0.2  0.2  0.2  0.2  0.2  0.2]

所以这里我们可以将这两个方法封装成一个函数来尝试:

1
2
3
4
5
6
7
def weight(shape, stddev=0.1, mean=0):
initial = tf.truncated_normal(shape=shape, mean=mean, stddev=stddev)
return tf.Variable(initial)

def bias(shape, value):
initial = tf.constant(value=value, shape=shape)
return tf.Variable(initial)

卷积

这次我们需要使用卷积神经网络来处理图片,所以这里的核心部分就是卷积和池化了,首先我们来了解一下卷积和池化。 卷积常用的方法为 conv2d() ,它的 API 如下:

1
tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)

这个方法是 TensorFlow 实现卷积常用的方法,也是搭建卷积神经网络的核心方法,参数介绍如下:

  • input,指需要做卷积的输入图像,它要求是一个 Tensor,具有 [batch_size, in_height, in_width, in_channels] 这样的 shape,具体含义是 [batch_size 的图片数量, 图片高度, 图片宽度, 输入图像通道数],注意这是一个 4 维的 Tensor,要求类型为 float32 和 float64 其中之一。
  • filter,相当于 CNN 中的卷积核,它要求是一个 Tensor,具有 [filter_height, filter_width, in_channels, out_channels] 这样的shape,具体含义是 [卷积核的高度,卷积核的宽度,输入图像通道数,输出通道数(即卷积核个数)],要求类型与参数 input 相同,有一个地方需要注意,第三维 in_channels,就是参数 input 的第四维。
  • strides,卷积时在图像每一维的步长,这是一个一维的向量,长度 4,具有 [stride_batch_size, stride_in_height, stride_in_width, stride_in_channels] 这样的 shape,第一个元素代表在一个样本的特征图上移动,第二三个元素代表在特征图上的高、宽上移动,第四个元素代表在通道上移动。
  • padding,string 类型的量,只能是 SAME、VALID 其中之一,这个值决定了不同的卷积方式。
  • use_cudnn_on_gpu,布尔类型,是否使用 cudnn 加速,默认为true。

返回的结果是 [batch_size, out_height, out_width, out_channels] 维度的结果。 我们这里拿一张 3x3 的图片,单通道(通道为1)的图片,拿一个 1x1 的卷积核进行卷积:

1
2
3
4
input = tf.Variable(tf.random_normal([1, 3, 3, 1]))
filter = tf.Variable(tf.random_normal([1, 1, 1, 1]))
op = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')
print(op.shape)

结果如下:

1
(1, 3, 3, 1)

很清晰,一张图片,拿一个 1x1 的核去做卷积,得到的结果输出是 3x3 的,输出通道为 1,batch_size 照旧。 再将卷积核扩大,用一个 3x3 的卷积核:

1
2
3
4
input = tf.Variable(tf.random_normal([1, 3, 3, 1]))
filter = tf.Variable(tf.random_normal([3, 3, 1, 1]))
op = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')
print(op.shape)

结果如下:

1
(1, 1, 1, 1)

最后输出的是一个 1x1 的值。 将图片扩大为 7x7,卷积核仍然使用 3x3:

1
2
3
4
input = tf.Variable(tf.random_normal([1, 7, 7, 1]))
filter = tf.Variable(tf.random_normal([3, 3, 1, 1]))
op = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')
print(op.shape)

结果如下:

1
(1, 5, 5, 1)

最后输出的是一个 5x5 的值。 这时如果我们把 padding 模式改为 SAME,表示卷积核可以停留在图像边缘:

1
2
3
4
input = tf.Variable(tf.random_normal([1, 7, 7, 1]))
filter = tf.Variable(tf.random_normal([3, 3, 1, 1]))
op = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')
print(op.shape)

结果如下:

1
(1, 7, 7, 1)

则输出的内容和原图像是相同的。 这时如果更改 batch_size 和 out_channels,比如 batch_size 修改为 3,out_channels 修改为 6:

1
2
3
4
input = tf.Variable(tf.random_normal([3, 7, 7, 1]))
filter = tf.Variable(tf.random_normal([3, 3, 1, 6]))
op = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME')
print(op.shape)

结果如下:

1
(3, 7, 7, 6)

输出结果的 batch_size 和 out_channels 会随之变化。 当 strides 的步长不为 1 的时候,我们将 stride_in_height 和 stride_in_width 修改为 2,相当于每次移动两步:

1
2
3
4
input = tf.Variable(tf.random_normal([3, 7, 7, 1]))
filter = tf.Variable(tf.random_normal([3, 3, 1, 6]))
op = tf.nn.conv2d(input, filter, strides=[1, 2, 2, 1], padding='VALID')
print(op.shape)

结果如下:

1
(3, 3, 3, 6)

最后我们用一个例子来感受一下:

1
2
3
4
5
6
7
8
9
import tensorflow as tf

input = tf.Variable(tf.random_normal([2, 4, 4, 5]))
filter = tf.Variable(tf.random_normal([2, 2, 5, 2]))
op = tf.nn.conv2d(input, filter, strides=[1, 1, 1, 1], padding='VALID')
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
print(op.shape)
print(sess.run(op))

这里 input、filter 通过指定 shape 的方式调用 random_normal() 方法进行随机初始化,input 的维度为 [2, 4, 4, 5],即 batch_size 为 2,图片是 4x4,输入通道数为 5,卷积核大小为 2x2,输入通道 5,输出通道 2,步长为 1,padding 方式选用 VALID,最后输出得到输出的 shape 和结果。 结果如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
(2, 3, 3, 2)
[[[[ 2.05039382 -8.82934952]
[ -9.77668381 3.63882256]
[ -4.46390772 -5.91670704]]

[[ 8.41201782 -6.72245312]
[ -1.47592044 13.03628349]
[ 5.44015312 2.46059227]]

[[ -3.18967772 1.24733043]
[-10.1108532 -6.44734669]
[ 1.99426246 2.91549349]]]


[[[ -1.66685319 0.32011557]
[ -5.66163826 -0.37670898]
[ -0.74658942 1.31723833]]

[[ -5.85412216 -0.29930949]
[ -0.75974303 -1.84006214]
[ -2.05475235 4.9572196 ]]

[[ -4.09344864 1.39405775]
[ -1.28887582 -2.82365012]
[ 4.87360907 10.8071022 ]]]]

可以看到 input 维度为 [2, 4, 4, 5],filter 维度为 [2, 2, 5, 2] 时,生成的结果维度为 [2, 3, 3, 2]。

池化

池化层往往在卷积层后面,通过池化来降低卷积层输出的特征向量,同时改善结果。 在这里介绍一个常用的最大值池化 max_pool() 方法,其 API 如下:

1
tf.nn.max_pool(value, ksize, strides, padding, name=None)

是CNN当中的最大值池化操作,其实用法和卷积很类似。 参数介绍如下:

  • value,需要池化的输入,一般池化层接在卷积层后面,所以输入通常是 feature map,依然是 [batch_size, height, width, channels] 这样的shape。
  • ksize,池化窗口的大小,取一个四维向量,一般是 [batch_size, height, width, channels],因为我们不想在 batch 和 channels 上做池化,所以这两个维度设为了1。
  • strides,和卷积类似,窗口在每一个维度上滑动的步长,一般也是 [stride_batch_size, stride_height, stride_width, stride_channels]。
  • padding,和卷积类似,可以取 VALID、SAME,返回一个 Tensor,类型不变,shape 仍然是 [batch_size, height, width, channels] 这种形式。

在这里输入为 [3, 7, 7, 2],池化窗口设置为 [1, 2, 2, 1],步长为 [1, 1, 1, 1],padding 模式设置为 VALID。

1
2
3
input = tf.Variable(tf.random_normal([3, 7, 7, 2]))
op = tf.nn.max_pool(input, ksize=[1, 2, 2, 1], strides=[1, 1, 1, 1], padding='VALID')
print(op.shape)

结果如下:

1
(3, 6, 6, 2)

类似的原理,我们可以得到这样的的结果。

卷积和池化

所以了解了以上卷积和池化方法的用法,我们可以定义如下两个工具方法:

1
2
3
4
5
def conv2d(input, filter, strides=[1, 1, 1, 1], padding='SAME'):
return tf.nn.conv2d(input, filter, strides=strides, padding=padding)

def max_pool(input, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME'):
return tf.nn.max_pool(input, ksize=ksize, strides=strides, padding=padding)

这两个方法分别实现了卷积和池化,并设置了默认步长和核大小。 接下来就让我们开始神经网络的构建吧。

初始化

首先我们需要初始化一些数据,包括输入的 x 和对一个的标注 y_label:

1
2
x = tf.placeholder(tf.float32, shape=[None, 784])
y_label = tf.placeholder(tf.float32, shape=[None, 10])

第一层卷积

现在我们可以开始实现第一层了。它由一个卷积接一个 max pooling 完成。卷积在每个 5x5 的 patch 中算出 32 个特征。卷积的权重张量形状是 [5, 5, 1, 32],前两个维度是 patch 的大小,接着是输入的通道数目,最后是输出的通道数目,而对于每一个输出通道都有一个对应的偏置量,我们首先初始化 w 和 b

1
2
w_conv1 = weight([5, 5, 1, 32])
b_conv1 = bias([32])

为了用这一层,我们把 x 变成一个四维向量,其第 2、3 维对应图片的宽、高,最后一维代表图片的颜色通道数,因为是灰度图所以这里的通道数为 1,如果是彩色图,则为 3。 随后我们需要对图片做 reshape 操作,将其

1
x_reshape = tf.reshape(x, [-1, 28, 28, 1])

我们把 x_reshape 和权值向量进行卷积,加上偏置项,然后应用 ReLU 激活函数,最后进行 max pooling:

1
2
h_conv1 = tf.nn.relu(conv2d(x_reshape, w_conv1) + b_conv1)
h_pool1 = max_pool(h_conv1)

第二层卷积

现在我们已经实现了一层卷积,为了构建一个更深的网络,我们再继续增加一层卷积,将通道数变成 64,所以这里的初始化权重和偏置为:

1
2
w_conv2 = weight([5, 5, 32, 64])
b_conv2 = bias([64])

随后我们把上一层池化结果 h_pool1 和权值向量进行卷积,加上偏置项,然后应用 ReLU 激活函数,最后进行 max pooling:

1
2
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
h_pool2 = max_pool(h_conv2)

密集连接层

现在,图片尺寸减小到7x7,我们再加入一个有 1024 个神经元的全连接层,用于处理整个图片。我们把池化层输出的张量 reshape 成一些向量,乘上权重矩阵,加上偏置,然后对其使用 ReLU。

1
2
3
4
w_fc1 = weight([7 * 7 * 64, 1024])
b_fc1 = bias([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

Dropout

为了减少过拟合,我们在输出层之前加入 dropout。我们用一个 placeholder 来代表一个神经元的输出在 dropout 中保持不变的概率。这样我们可以在训练过程中启用 dropout,在测试过程中关闭 dropout。 TensorFlow 的 tf.nn.dropout 操作除了可以屏蔽神经元的输出外,还会自动处理神经元输出值的 scale,所以用 dropout 的时候可以不用考虑 scale。

1
2
keep_prob = tf.placeholder(tf.float32)
h_fc1_dropout = tf.nn.dropout(h_fc1, keep_prob=keep_prob)

输出层

最后,我们添加一个 Softmax 输出层,这里我们需要将 1024 维转为 10 维,所以需要声明一个 [1024, 10] 的权重和 [10] 的偏置:

1
2
3
w_fc2 = weight([1024, 10])
b_fc1 = bias([10])
y = tf.nn.softmax(tf.matmul(h_fc1_dropout, w_fc2) + b_fc1)

训练和评估模型

为了进行训练和评估,我们使用与之前简单的单层 Softmax 神经网络模型几乎相同的一套代码,只是我们会用更加复杂的 Adam 优化器来做梯度最速下降,在 feed_dict 中加入额外的参数 keep_prob 来控制 dropout 比例,然后每 100次 迭代输出一次日志:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Loss
cross_entropy = -tf.reduce_sum(y_label * tf.log(y))
train = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# Prediction
correct_prediction = tf.equal(tf.argmax(y_label, axis=1), tf.argmax(y, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Train
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for step in range(total_steps + 1):
batch = mnist.train.next_batch(batch_size)
sess.run(train, feed_dict={x: batch[0], y_label: batch[1], keep_prob: dropout_keep_prob})
# Train accuracy
if step % steps_per_test == 0:
print('Training Accuracy', step,
sess.run(accuracy, feed_dict={x: batch[0], y_label: batch[1], keep_prob: 1}))

# Final Test
print('Test Accuracy', sess.run(accuracy, feed_dict={x: mnist.test.images, y_label: mnist.test.labels, keep_prob: 1}))

运行

以上代码,在最终测试集上的准确率大概是99.2%。 运行结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
Training Accuracy 0 0.05
Training Accuracy 100 0.7
Training Accuracy 200 0.85
Training Accuracy 300 0.9
Training Accuracy 400 0.93
Training Accuracy 500 0.91
Training Accuracy 600 0.94
Training Accuracy 700 0.95
Training Accuracy 800 0.95
Training Accuracy 900 0.95
Training Accuracy 1000 0.97
Training Accuracy 1100 0.95
Training Accuracy 1200 0.96
Training Accuracy 1300 0.99
Training Accuracy 1400 0.98
Training Accuracy 1500 0.95
Training Accuracy 1600 0.97
Training Accuracy 1700 1.0
Training Accuracy 1800 0.95
Training Accuracy 1900 0.95
Training Accuracy 2000 0.95
Training Accuracy 2100 0.96
Training Accuracy 2200 0.96
Training Accuracy 2300 0.98
Training Accuracy 2400 0.97
Training Accuracy 2500 0.96
Training Accuracy 2600 0.99
Training Accuracy 2700 0.96
Training Accuracy 2800 0.98
Training Accuracy 2900 0.95
Training Accuracy 3000 0.99

结语

本节我们实现了卷积神经网络来处理图像相关问题,将准确率大大提高,可见神经网络是非常强大的。

本节代码

本节代码地址为:https://github.com/AIDeepLearning/MNIST

Python

我们本节要用 MNIST 数据集训练一个可以识别数据的深度学习模型来帮助识别手写数字。

MNIST

MNIST 是一个入门级计算机视觉数据集,包含了很多手写数字图片,如图所示: 数据集中包含了图片和对应的标注,在 TensorFlow 中提供了这个数据集,我们可以用如下方法进行导入:

1
2
3
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
print(mnist)

输出结果如下:

1
2
3
4
5
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Datasets(train=<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x101707ef0>, validation=<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x1016ae4a8>, test=<tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet object at 0x1016f9358>)

在这里程序会首先下载 MNIST 数据集,然后解压并保存到刚刚制定好的 MNIST_data 文件夹中,然后输出数据集对象。 数据集中包含了 55000 行的训练数据集(mnist.train)、5000 行验证集(mnist.validation)和 10000 行的测试数据集(mnist.test),文件如下所示: 正如前面提到的一样,每一个 MNIST 数据单元有两部分组成:一张包含手写数字的图片和一个对应的标签。我们把这些图片设为 xs,把这些标签设为 ys。训练数据集和测试数据集都包含 xs 和 ys,比如训练数据集的图片是 mnist.train.images ,训练数据集的标签是 mnist.train.labels,每张图片是 28 x 28 像素,即 784 个像素点,我们可以把它展开形成一个向量,即长度为 784 的向量。 所以训练集我们可以转化为 [55000, 784] 的向量,第一维就是训练集中包含的图片个数,第二维是图片的像素点表示的向量。

Softmax

Softmax 可以看成是一个激励(activation)函数或者链接(link)函数,把我们定义的线性函数的输出转换成我们想要的格式,也就是关于 10 个数字类的概率分布。因此,给定一张图片,它对于每一个数字的吻合度可以被 Softmax 函数转换成为一个概率值。Softmax 函数可以定义为: 展开等式右边的子式,可以得到: 比如判断一张图片中的动物是什么,可能的结果有三种,猫、狗、鸡,假如我们可以经过计算得出它们分别的得分为 3.2、5.1、-1.7,Softmax 的过程首先会对各个值进行次幂计算,分别为 24.5、164.0、0.18,然后计算各个次幂结果占总次幂结果的比重,这样就可以得到 0.13、0.87、0.00 这三个数值,所以这样我们就可以实现差别的放缩,即好的更好、差的更差。 如果要进一步求损失值可以进一步求对数然后取负值,这样 Softmax 后的值如果值越接近 1,那么得到的值越小,即损失越小,如果越远离 1,那么得到的值越大。

实现回归模型

首先导入 TensorFlow,命令如下:

1
import tensorflow as tf

接下来我们指定一个输入,在这里输入即为样本数据,如果是训练集那么则是 55000 x 784 的矩阵,如果是验证集则为 5000 x 784 的矩阵,如果是测试集则是 10000 x 784 的矩阵,所以它的行数是不确定的,但是列数是确定的。 所以可以先声明一个 placeholder 对象:

1
x = tf.placeholder(tf.float32, [None, 784])

这里第一个参数指定了矩阵中每个数据的类型,第二个参数指定了数据的维度。 接下来我们需要构建第一层网络,表达式如下: 这里实际上是对输入的 x 乘以 w 权重,然后加上一个偏置项作为输出,而这两个变量实际是在训练的过程中动态调优的,所以我们需要指定它们的类型为 Variable,代码如下:

1
2
w = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

接下来需要实现的就是上图所述的公式了,我们再进一步调用 Softmax 进行计算,得到 y:

1
y = tf.nn.softmax(tf.matmul(x, w) + b)

通过上面几行代码我们就已经把模型构建完毕了,结构非常简单。

损失函数

为了训练我们的模型,我们首先需要定义一个指标来评估这个模型是好的。其实,在机器学习,我们通常定义指标来表示一个模型是坏的,这个指标称为成本(cost)或损失(loss),然后尽量最小化这个指标。但是这两种方式是相同的。 一个非常常见的,非常漂亮的成本函数是“交叉熵”(cross-entropy)。交叉熵产生于信息论里面的信息压缩编码技术,但是它后来演变成为从博弈论到机器学习等其他领域里的重要技术手段。它的定义如下: y 是我们预测的概率分布, y_label 是实际的分布,比较粗糙的理解是,交叉熵是用来衡量我们的预测用于描述真相的低效性。 我们可以首先定义 y_label,它的表达式是:

1
y_label = tf.placeholder(tf.float32, [None, 10])

接下来我们需要计算它们的交叉熵,代码如下:

1
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_label * tf.log(y), reduction_indices=[1]))

首先用 reduce_sum() 方法针对每一个维度进行求和,reduction_indices 是指定沿哪些维度进行求和。 然后调用 reduce_mean() 则求平均值,将一个向量中的所有元素求算平均值。 这样我们最后只需要优化这个交叉熵就好了。 所以这样我们再定义一个优化方法:

1
train = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

这里使用了 GradientDescentOptimizer,在这里,我们要求 TensorFlow 用梯度下降算法(gradient descent algorithm)以 0.5 的学习速率最小化交叉熵。梯度下降算法(gradient descent algorithm)是一个简单的学习过程,TensorFlow 只需将每个变量一点点地往使成本不断降低的方向移动即可。

运行模型

定义好了以上内容之后,相当于我们已经构建好了一个计算图,即设置好了模型,我们把它放到 Session 里面运行即可:

1
2
3
4
5
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for step in range(total_steps + 1):
batch_x, batch_y = mnist.train.next_batch(batch_size)
sess.run(train, feed_dict={x: batch_x, y_label: batch_y})

该循环的每个步骤中,我们都会随机抓取训练数据中的 batch_size 个批处理数据点,然后我们用这些数据点作为参数替换之前的占位符来运行 train。 这里需要一些变量的定义:

1
2
batch_size = 100
total_steps = 5000

测试模型

那么我们的模型性能如何呢? 首先让我们找出那些预测正确的标签。tf.argmax() 是一个非常有用的函数,它能给出某个 Tensor 对象在某一维上的其数据最大值所在的索引值。由于标签向量是由 0,1 组成,因此最大值 1 所在的索引位置就是类别标签,比如 tf.argmax(y, 1) 返回的是模型对于任一输入 x 预测到的标签值,而 tf.argmax(y_label, 1) 代表正确的标签,我们可以用 tf.equal() 方法来检测我们的预测是否真实标签匹配(索引位置一样表示匹配)。

1
correct_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(y_label, axis=1))

这行代码会给我们一组布尔值。为了确定正确预测项的比例,我们可以把布尔值转换成浮点数,然后取平均值。例如,[True, False, True, True] 会变成 [1, 0, 1, 1] ,取平均值后得到 0.75。

1
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

最后,我们计算所学习到的模型在测试数据集上面的正确率,定义如下:

1
2
3
steps_per_test = 100
if step % steps_per_test == 0:
print(step, sess.run(accuracy, feed_dict={x: mnist.test.images, y_label: mnist.test.labels}))

这个最终结果值应该大约是92%。 这样我们就通过完成了训练和测试阶段,实现了一个基本的训练模型,后面我们会继续优化模型来达到更好的效果。 运行结果如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
0 0.453
100 0.8915
200 0.9026
300 0.9081
400 0.9109
500 0.9108
600 0.9175
700 0.9137
800 0.9158
900 0.9176
1000 0.9167
1100 0.9186
1200 0.9206
1300 0.9161
1400 0.9218
1500 0.9179
1600 0.916
1700 0.9196
1800 0.9222
1900 0.921
2000 0.9223
2100 0.9214
2200 0.9191
2300 0.9228
2400 0.9228
2500 0.9218
2600 0.9197
2700 0.9225
2800 0.9238
2900 0.9219
3000 0.9224
3100 0.9184
3200 0.9253
3300 0.9216
3400 0.9218
3500 0.9212
3600 0.9225
3700 0.9224
3800 0.9225
3900 0.9226
4000 0.9201
4100 0.9138
4200 0.9184
4300 0.9222
4400 0.92
4500 0.924
4600 0.9234
4700 0.9219
4800 0.923
4900 0.9254
5000 0.9218

结语

本节通过一个 MNIST 数据集来简单体验了一下真实数据的训练和预测过程,但是准确率还不够高,后面我们会学习用卷积的方式来进行模型训练,准确率会更高。

本节代码

本节代码地址为:https://github.com/AIDeepLearning/MNIST

Python

本篇内容基于 Python3 TensorFlow 1.4 版本。

本节内容

本节通过最简单的示例 —— 平面拟合来说明 TensorFlow 的基本用法。

构造数据

TensorFlow 的引入方式是:

1
import tensorflow as tf

接下来我们构造一些随机的三维数据,然后用 TensorFlow 找到平面去拟合它,首先我们用 Numpy 生成随机三维点,其中变量 x 代表三维点的 (x, y) 坐标,是一个 2x100 的矩阵,即 100 个 (x, y),然后变量 y 代表三位点的 z 坐标,我们用 Numpy 来生成这些随机的点:

1
2
3
4
5
6
import numpy as np
x_data = np.float32(np.random.rand(2, 100))
y_data = np.dot([0.300, 0.200], x_data) + 0.400

print(x_data)
print(y_data)

这里利用 Numpy 的 random 模块的 rand() 方法生成了 2x100 的随机矩阵,这样就生成了 100 个 (x, y) 坐标,然后用了一个 dot() 方法算了矩阵乘法,用了一个长度为 2 的向量跟此矩阵相乘,得到一个长度为 100 的向量,然后再加上一个常量,得到 z 坐标,输出结果样例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
[[ 0.97232962  0.08897641  0.54844421  0.5877986   0.5121088   0.64716059
0.22353953 0.18406206 0.16782761 0.97569454 0.65686035 0.75569868
0.35698661 0.43332314 0.41185728 0.24801297 0.50098598 0.12025958
0.40650111 0.51486945 0.19292323 0.03679928 0.56501174 0.5321334
0.71044683 0.00318134 0.76611853 0.42602748 0.33002195 0.04414672
0.73208278 0.62182301 0.49471655 0.8116194 0.86148429 0.48835048
0.69902027 0.14901569 0.18737803 0.66826463 0.43462989 0.35768151
0.79315376 0.0400687 0.76952982 0.12236254 0.61519378 0.92795062
0.84952474 0.16663995 0.13729768 0.50603199 0.38752931 0.39529857
0.29228279 0.09773371 0.43220878 0.2603009 0.14576958 0.21881725
0.64888018 0.41048348 0.27641159 0.61700606 0.49728736 0.75936913
0.04028837 0.88986284 0.84112513 0.34227493 0.69162005 0.89058989
0.39744586 0.85080278 0.37685293 0.80529863 0.31220895 0.50500977
0.95800418 0.43696108 0.04143282 0.05169986 0.33503434 0.1671818
0.10234453 0.31241918 0.23630807 0.37890589 0.63020509 0.78184551
0.87924582 0.99288088 0.30762389 0.43499199 0.53140771 0.43461791
0.23833922 0.08681628 0.74615192 0.25835371]
[ 0.8174957 0.26717573 0.23811154 0.02851068 0.9627012 0.36802396
0.50543582 0.29964805 0.44869211 0.23191817 0.77344608 0.36636299
0.56170034 0.37465382 0.00471885 0.19509546 0.49715847 0.15201907
0.5642485 0.70218688 0.6031307 0.4705168 0.98698962 0.865367
0.36558965 0.72073907 0.83386165 0.29963031 0.72276717 0.98171854
0.30932376 0.52615297 0.35522953 0.13186514 0.73437029 0.03887378
0.1208882 0.67004597 0.83422536 0.17487818 0.71460873 0.51926661
0.55297899 0.78169805 0.77547258 0.92139858 0.25020468 0.70916855
0.68722379 0.75378138 0.30182058 0.91982585 0.93160367 0.81539184
0.87977934 0.07394848 0.1004181 0.48765802 0.73601437 0.59894943
0.34601998 0.69065076 0.6768015 0.98533565 0.83803362 0.47194552
0.84103006 0.84892255 0.04474261 0.02038293 0.50802571 0.15178065
0.86116213 0.51097614 0.44155359 0.67713588 0.66439205 0.67885226
0.4243969 0.35731083 0.07878648 0.53950399 0.84162414 0.24412845
0.61285144 0.00316137 0.67407191 0.83218956 0.94473189 0.09813353
0.16728765 0.95433819 0.1416636 0.4220584 0.35413414 0.55999744
0.94829601 0.62568033 0.89808714 0.07021013]]
[ 0.85519803 0.48012807 0.61215557 0.58204171 0.74617288 0.66775297
0.56814902 0.51514823 0.5400867 0.739092 0.75174732 0.6999822
0.61943605 0.60492771 0.52450095 0.51342299 0.64972749 0.46648169
0.63480003 0.69489821 0.57850311 0.50514314 0.76690145 0.73271342
0.68625198 0.54510222 0.79660789 0.58773431 0.64356002 0.60958773
0.68148959 0.6917775 0.61946087 0.66985885 0.80531934 0.5542799
0.63388372 0.5787139 0.62305848 0.63545502 0.67331071 0.61115777
0.74854193 0.56836022 0.78595346 0.62098848 0.63459907 0.8202189
0.79230218 0.60074826 0.50155342 0.73577477 0.70257953 0.68166794
0.6636407 0.44410981 0.54974625 0.57562188 0.59093375 0.58543506
0.66386805 0.6612752 0.61828378 0.78216895 0.71679293 0.72219985
0.58029252 0.83674336 0.66128606 0.50675907 0.70909116 0.6975331
0.69146618 0.75743606 0.6013666 0.77701676 0.6265411 0.68727338
0.77228063 0.60255049 0.42818714 0.52341076 0.66883513 0.49898023
0.55327365 0.49435803 0.6057068 0.68010968 0.77800791 0.65418036
0.69723127 0.8887319 0.52061989 0.61490928 0.63024914 0.64238486
0.66116097 0.55118095 0.80346301 0.49154814]

这样我们就得到了一些三维的点。

构造模型

随后我们用 TensorFlow 来根据这些数据拟合一个平面,拟合的过程实际上就是寻找 (x, y) 和 z 的关系,即变量 x_data 和变量 y_data 的关系,而它们之间的关系刚才我们用了线性变换表示出来了,即 z = w * (x, y) + b,所以拟合的过程实际上就是找 w 和 b 的过程,所以这里我们就首先像设变量一样来设两个变量 w 和 b,代码如下:

1
2
3
4
5
x = tf.placeholder(tf.float32, [2, 100])
y_label = tf.placeholder(tf.float32, [100])
b = tf.Variable(tf.zeros([1]))
w = tf.Variable(tf.random_uniform([2], -1.0, 1.0))
y = tf.matmul(tf.reshape(w, [1, 2]), x) + b

在创建模型的时候,我们首先可以将现有的变量来表示出来,用 placeholder() 方法声明即可,一会我们在运行的时候传递给它真实的数据就好,第一个参数是数据类型,第二个参数是形状,因为 x_data 是 2x100 的矩阵,所以这里形状定义为 [2, 100],而 y_data 是长度为 100 的向量,所以这里形状定义为 [100],当然此处使用元组定义也可以,不过要写成 (100, )。 随后我们用 Variable 初始化了 TensorFlow 中的变量,b 初始化为一个常量,w 是一个随机初始化的 1x2 的向量,范围在 -1 和 1 之间,然后 y 再用 w、x、b 表示出来,其中 matmul() 方法就是 TensorFlow 中提供的矩阵乘法,类似 Numpy 的 dot() 方法。不过不同的是 matmul() 不支持向量和矩阵相乘,即不能 BroadCast,所以在这里做乘法前需要先调用 reshape() 一下转成 1x2 的标准矩阵,最后将结果表示为 y。 这样我们就构造出来了一个线性模型。 这里的 y 是我们模型中输出的值,而真实的数据却是我们输入的 y_data,即 y_label。

损失函数

要拟合这个平面的话,我们需要减小 y_label 和 y 的差距就好了,这个差距越小越好。 所以接下来我们可以定义一个损失函数,来代表模型实际输出值和真实值之间的差距,我们的目的就是来减小这个损失,代码实现如下:

1
loss = tf.reduce_mean(tf.square(y - y_label))

这里调用了 square() 方法,传入 y_label 和 y 的差来求得平方和,然后使用 reduce_mean() 方法得到这个值的平均值,这就是现在模型的损失值,我们的目的就是减小这个损失值,所以接下来我们使用梯度下降的方法来减小这个损失值即可,定义如下代码:

1
2
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

这里定义了 GradientDescentOptimizer 优化,即使用梯度下降的方法来减小这个损失值,我们训练模型就是来模拟这个过程。

运行模型

最后我们将模型运行起来即可,运行时必须声明一个 Session 对象,然后初始化所有的变量,然后执行一步步的训练即可,实现如下:

1
2
3
4
5
6
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for step in range(201):
sess.run(train, feed_dict={x: x_data, y: y_data})
if step % 10 == 0:
print(step, sess.run(w), sess.run(b))

这里定义了 200 次循环,每一次循环都会执行一次梯度下降优化,每次循环都调用一次 run() 方法,传入的变量就是刚才定义个 train 对象,feed_dict 就把 placeholder 类型的变量赋值即可。随着训练的进行,损失会越来越小,w 和 b 也会被慢慢调整为拟合的值。 在这里每 10 次 循环我们都打印输出一下拟合的 w 和 b 的值,结果如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
0 [ 0.31494665  0.33602586] [ 0.84270978]
10 [ 0.19601417 0.17301694] [ 0.47917289]
20 [ 0.23550016 0.18053198] [ 0.44838765]
30 [ 0.26029009 0.18700737] [ 0.43032286]
40 [ 0.27547371 0.19152154] [ 0.41897511]
50 [ 0.28481475 0.19454622] [ 0.41185945]
60 [ 0.29058149 0.19652548] [ 0.40740564]
70 [ 0.2941508 0.19780098] [ 0.40462157]
80 [ 0.29636407 0.1986146 ] [ 0.40288284]
90 [ 0.29773837 0.19913 ] [ 0.40179768]
100 [ 0.29859257 0.19945487] [ 0.40112072]
110 [ 0.29912385 0.199659 ] [ 0.40069857]
120 [ 0.29945445 0.19978693] [ 0.40043539]
130 [ 0.29966027 0.19986697] [ 0.40027133]
140 [ 0.29978839 0.19991697] [ 0.40016907]
150 [ 0.29986817 0.19994824] [ 0.40010536]
160 [ 0.29991791 0.1999677 ] [ 0.40006563]
170 [ 0.29994887 0.19997987] [ 0.40004089]
180 [ 0.29996812 0.19998746] [ 0.40002549]
190 [ 0.29998016 0.19999218] [ 0.40001586]
200 [ 0.29998764 0.19999513] [ 0.40000987]

可以看到,随着训练的进行,w 和 b 也慢慢接近真实的值,拟合越来越精确,接近正确的值。

结语

以上便是通过一个最简单的平面拟合的案例来说明了一下 TensorFlow 的用法,是不是很简单?

代码

本节代码地址:https://github.com/AIDeepLearning/TensorFlowBasis

Linux

部署公司生产环境的Splash集群无奈节点太多 差点被搞死·· 还好我有运维神器Ansible,一次编撰终生可用啊!而且这玩意儿 等幂特性 扩容回滚 So Easy!! 闲话少说开搞!

安装Ansible:

看官方文档去:http://www.ansible.com.cn/index.html 好像这个主控端不支持Windows? 大家虚拟机装个Ubuntu吧。

闲话少扯直接上干货:

整体目录如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
study@study:~/文档/ansible-examples$ tree Splash_Load_balancing_cluster
Splash_Load_balancing_cluster
├── group_vars
│   └── all
├── roles
│   ├── common
│   │   ├── files
│   │   │   ├── CentOS-Base.repo
│   │   │   ├── docker-ce.repo
│   │   │   ├── epel.repo
│   │   │   ├── ntp.conf
│   │   │   └── RPM-GPG-KEY-EPEL-7
│   │   ├── tasks
│   │   │   └── main.yml
│   │   └── templates
│   ├── docker
│   │   ├── handlers
│   │   │   └── main.yml
│   │   ├── tasks
│   │   │   └── main.yml
│   │   └── templates
│   │   └── daemon.json.j2
│   ├── haproxy
│   │   ├── handlers
│   │   │   └── main.yml
│   │   ├── tasks
│   │   │   └── main.yml
│   │   └── templates
│   │   └── haproxy.cfg.j2
│   └── splash
│   ├── files
│   │   ├── filters
│   │   │   └── default.txt
│   │   ├── js-profiles
│   │   ├── lua_modules
│   │   └── proxy-profiles
│   │   └── proxy.ini
│   └── tasks
│   └── main.yml
├── site.retry
└── site.yml

Group_vars: 里面定义全局使用的变量 Roles: 存放所有的规则目录 Roles/common :所有服务器初始化配置部署 Roles/common/filters :需要使用的文件或者文件夹 Roles/common/task:部署任务(main.yml为入口必须要有) Roles/common/templates :配置模板(jinja2模板语法 用于可变更的配置文件,可获取定义在Group_vars中的变量) Roles/Docker :Docker的安装配置 Roles/HAproxy : HAproxy的负载均衡配置 Roles/Splash : Splash的镜像拉取配置部署以及启动 site.yml : 启动入口

使用方法:

在你的Inventory文件定义好主机分组:

必须包括HaProxy、和Docker两个分组如下:

1
2
3
4
5
6
7
study@study:~/文档/ansible-examples$ cat /etc/ansible/inventory/splash 
[docker]
1.1.1.1
[haproxy]
10.253.20.25

[splash_ports]

主控端新建SSH秘钥并发布到你你需要配置的所有主机!!!!(一定要注意如果本机当前工作用户在远程主机不存在额时候,需要指定remote_user这个参数):

1
2
3
4
5
study@study:~/文档/ansible-examples$ cat /etc/ansible/ansible.cfg 
[defaults]
inventory= /etc/ansible/inventory/

remote_user=root

好了开始执行:

1
study@study:~/文档/ansible-examples/Splash_Load_balancing_cluster$ ansible-playbook site.yml

效果就像这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
PLAY [all] **********************************************************************************************************************************************************************************

TASK [Gathering Facts] **********************************************************************************************************************************************************************
ok: [10.1.4.101]
ok: [10.1.4.100]

TASK [common : Copy the CentOS repository definition] ***************************************************************************************************************************************
ok: [10.1.4.100]
ok: [10.1.4.101]

TASK [common : Copy the Docker repository definition] ***************************************************************************************************************************************
ok: [10.1.4.100]
ok: [10.1.4.101]

TASK [common : Create the repository for EPEL] **********************************************************************************************************************************************
ok: [10.1.4.100]
ok: [10.1.4.101]

TASK [common : Create the GPG key for EPEL] *************************************************************************************************************************************************
ok: [10.1.4.100]
ok: [10.1.4.101]

TASK [common : Firewalld service stop] ******************************************************************************************************************************************************
ok: [10.1.4.100]
ok: [10.1.4.101]

TASK [common : Chronyd service stop] ********************************************************************************************************************************************************
ok: [10.1.4.100]
ok: [10.1.4.101]

TASK [common : Install Ansible Base package] ************************************************************************************************************************************************
ok: [10.1.4.100] => (item=['libselinux-python', 'libsemanage-python', 'ntp'])
ok: [10.1.4.101] => (item=['libselinux-python', 'libsemanage-python', 'ntp'])

TASK [common : Configure SELinux to disable] ************************************************************************************************************************************************
[WARNING]: SELinux state change will take effect next reboot

ok: [10.1.4.100]
ok: [10.1.4.101]

TASK [common : Change TimeZone] *************************************************************************************************************************************************************
ok: [10.1.4.100]
ok: [10.1.4.101]

TASK [common : Copy NTP conf] ***************************************************************************************************************************************************************
ok: [10.1.4.100]
ok: [10.1.4.101]

TASK [common : NTP Start] *******************************************************************************************************************************************************************
ok: [10.1.4.100]
ok: [10.1.4.101]

PLAY [docker] *******************************************************************************************************************************************************************************

TASK [Gathering Facts] **********************************************************************************************************************************************************************
ok: [10.1.4.101]

TASK [docker : Install Docker package] ******************************************************************************************************************************************************
ok: [10.1.4.101] => (item=['yum-utils', 'device-mapper-persistent-data', 'lvm2', 'docker-ce'])

TASK [docker : Start Docker] ****************************************************************************************************************************************************************
ok: [10.1.4.101]

TASK [docker : Create Docker Speed Configuration file] **************************************************************************************************************************************
ok: [10.1.4.101]

TASK [docker : Restart Docker] **************************************************************************************************************************************************************
changed: [10.1.4.101]

TASK [splash : pull splash] *****************************************************************************************************************************************************************
changed: [10.1.4.101]

TASK [splash : Copy Splash module] **********************************************************************************************************************************************************
ok: [10.1.4.101] => (item=filters)
ok: [10.1.4.101] => (item=js-profiles)
ok: [10.1.4.101] => (item=lua_modules)
ok: [10.1.4.101] => (item=proxy-profiles)

静静等着跑完 就可以愉快的使用啦 ! 需要增加节点的话直接把IP加载Docker分组下 重新执行一遍就可以了! 需要注意如果SSH非默认的22端口还需要指定你的端口号!怎么指定 看看文档去 以上完毕!!! 完整的看这儿:https://github.com/thsheep/ansible-examples

Python

2019年01月04日16:32:17 更新了新的Chrome镜像 将Python版本升级到了3.7 Note: 推荐使用结尾提供的Docker镜像进行二次打包运行代码 各位小伙伴儿的采集日常是不是被JavaScript的各种点击事件折腾的欲仙欲死啊?好不容易找到个Selenium+Chrome可以解决问题! 但是另一个▄█▀█●的事实摆在面前,服务器都特么没有GUI啊·· 好吧!咱们要知难而上!决不能被这个点小困难打倒······· 然而摆在面前的事实是···· 他丫的各种装不上啊!坑爹啊! 那么我来拯救你们于水火之间了! 服务器如下:

1
2
3
4
5
6
7
8
9
10
11
[root@spider01 ~]# hostnamectl 
Static hostname: spider01
Icon name: computer-vm
Chassis: vm
Machine ID: 1c4029c4e7fd42498e25bb75101f85b6
Boot ID: f5a67454b94b454fae3d75ef1ccab69f
Virtualization: kvm
Operating System: CentOS Linux 7 (Core)
CPE OS Name: cpe:/o:centos:centos:7
Kernel: Linux 3.10.0-514.6.2.el7.x86_64
Architecture: x86-64

安装Chromeium:

1
2
3
4
## 安装yum源
[root@spider01 ~]# sudo yum install -y epel-release
## 安装Chrome
[root@spider01 ~]# yum install -y chromium

去这个地方:https://sites.google.com/a/chromium.org/chromedriver/downloads 下载ChromeDriver驱动放在/usr/bin/目录下: 完成结果如下:

1
2
3
[root@spider01 ~]# ll /usr/bin/ | grep chrom
-rwxrwxrwx. 1 root root 7500280 1129 17:32 chromedriver
lrwxrwxrwx. 1 root root 47 1130 09:35 chromium-browser -> /usr/lib64/chromium-browser/chromium-browser.sh

安装XVFB:

1
2
[root@spider01 ~]# yum install Xvfb -y
[root@spider01 ~]# yum install xorg-x11-fonts* -y

新建在/usr/bin/ 一个名叫 xvfb-chromium 的文件写入以下内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
[root@spider01 ~]# cat /usr/bin/xvfb-chromium 
#!/bin/bash

_kill_procs() {
kill -TERM $chromium
wait $chromium
kill -TERM $xvfb
}

# Setup a trap to catch SIGTERM and relay it to child processes
trap _kill_procs SIGTERM

XVFB_WHD=${XVFB_WHD:-1280x720x16}

# Start Xvfb
Xvfb :99 -ac -screen 0 $XVFB_WHD -nolisten tcp &
xvfb=$!

export DISPLAY=:99

chromium --no-sandbox --disable-gpu$@ &
chromium=$!

wait $chromium
wait $xvfb

更改软连接:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
## 更改Chrome启动的软连接
[root@spider01 ~]# ln -s /usr/lib64/chromium-browser/chromium-browser.sh /usr/bin/chromium


[root@spider01 ~]# rm -rf /usr/bin/chromium-browser

[root@spider01 ~]# ln -s /usr/bin/xvfb-chromium /usr/bin/chromium-browser

[root@spider01 ~]# ln -s /usr/bin/xvfb-chromium /usr/bin/google-chrome

[root@spider01 ~]# ll /usr/bin/ | grep chrom*
-rwxrwxrwx. 1 root root 7500280 1129 17:32 chromedriver
lrwxrwxrwx. 1 root root 47 1130 09:47 chromium -> /usr/lib64/chromium-browser/chromium-browser.sh
lrwxrwxrwx. 1 root root 22 1130 09:48 chromium-browser -> /usr/bin/xvfb-chromium
-rwxr-xr-x. 1 root root 73848 127 2016 chronyc
lrwxrwxrwx. 1 root root 22 1130 09:48 google-chrome -> /usr/bin/xvfb-chromium
-rwxrwxrwx. 1 root root 387 1129 18:16 xvfb-chromium

来瞅瞅能不能用哦:

1
2
3
4
5
6
7
8
9
10
11
>>> from selenium import webdriver
>>> options = webdriver.ChromeOptions()
>>> options.add_argument('--headless')
>>> options.add_argument('--no-sandbox')
>>> options.add_argument('--disable-extensions')
>>> options.add_argument('--disable-gpu')
>>> driver = webdriver.Chrome(options=options)
>>> driver.get("http://www.baidu.com")
>>> driver.find_element_by_xpath("./*//input[@id='kw']").send_keys("哎哟卧槽")
>>> driver.find_element_by_xpath("./*//input[@id='su']").click()
>>> driver.page_source

No problem!!!! 好了部署完了!当然Docker这么火贼适合懒人了!来来 看这儿 Docker版的 妥妥滴!

1
docker pull thsheep/python:3.7-debian-chrome

做好了Python3.7和Chrome集成 需要自己使用Dockerfile来重新打包安装你需要的Python包。

Note: 请按照以下方式初始化Webdriver!!!!!!!!

1
2
3
4
5
6
7
8
9
from selenium import webdriver

options = webdriver.ChromeOptions()
options.add_argument('--headless')
options.add_argument('--no-sandbox')
options.add_argument('--disable-extensions')
options.add_argument('--disable-gpu')

driver = webdriver.Chrome(options=options)

否则会出现无法初始化Webdriver的情况

顺便一提!!!!这个玩意儿从事Web测试工作的小伙伴可以用!!!!!!!!

下面是Dockerfile文件:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
FROM python:3.7-stretch

ENV DBUS_SESSION_BUS_ADDRESS=/dev/null

#============================================
# Google Chrome
#============================================
RUN wget -q -O - https://dl-ssl.google.com/linux/linux_signing_key.pub | apt-key add - && \
echo "deb http://dl.google.com/linux/chrome/deb/ stable main" >> /etc/apt/sources.list.d/google-chrome.list && \
apt-get update -qqy && \
apt-get -qqy install google-chrome-stable unzip&& \
rm /etc/apt/sources.list.d/google-chrome.list && \
rm -rf /var/lib/apt/lists/* /var/cache/apt/*

#==================
# Chrome driver
# CHROME_DRIVER_VERSION 需要根据上面安装的Chrome版本来设置(最好设置成最新版本)
# http://chromedriver.chromium.org/downloads 版本号在这页面上查看
#==================
ARG CHROME_DRIVER_VERSION=2.45
RUN wget -O /tmp/chromedriver.zip https://chromedriver.storage.googleapis.com/$CHROME_DRIVER_VERSION/chromedriver_linux64.zip && \
rm -rf /opt/selenium/chromedriver && \
unzip /tmp/chromedriver.zip -d /opt/selenium && \
rm /tmp/chromedriver.zip && \
mv /opt/selenium/chromedriver /opt/selenium/chromedriver-$CHROME_DRIVER_VERSION && \
chmod 755 /opt/selenium/chromedriver-$CHROME_DRIVER_VERSION && \
ln -fs /opt/selenium/chromedriver-$CHROME_DRIVER_VERSION /usr/bin/chromedriver

RUN google-chrome-stable --version

Python

微博登录限制了错误次数···加上Cookie大批账号被封需要从Cookie池中 剔除被封的账号··· 需要使用代理··· 无赖百度了大半天都是特么的啥玩意儿???结果换成了 Google手到擒来 分分钟解决(那么问题来了?百度除了卖假药还会干啥?) Selenium+Chrome认证代理不能通过options处理。只能换个方法使用扩展解决 原文地址:https://stackoverflow.com/questions/29983106/how-can-i-set-proxy-with-authentication-in-selenium-chrome-web-driver-using-pyth#answer-30953780 (Stack Overflow 这是个好地方啊) 走你!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# -*- coding: utf-8 -*-
# @Time : 2017/11/15 9:50
# @Author : 哎哟卧槽
# @Site :
# @File : pubilc.py
# @Software: PyCharm

import string
import zipfile

def create_proxyauth_extension(proxy_host, proxy_port,
proxy_username, proxy_password,
scheme='http', plugin_path=None):
"""代理认证插件

args:
proxy_host (str): 你的代理地址或者域名(str类型)
proxy_port (int): 代理端口号(int类型)
proxy_username (str):用户名(字符串)
proxy_password (str): 密码 (字符串)
kwargs:
scheme (str): 代理方式 默认http
plugin_path (str): 扩展的绝对路径

return str -> plugin_path
"""


if plugin_path is None:
plugin_path = 'vimm_chrome_proxyauth_plugin.zip'

manifest_json = """
{
"version": "1.0.0",
"manifest_version": 2,
"name": "Chrome Proxy",
"permissions": [
"proxy",
"tabs",
"unlimitedStorage",
"storage",
"<all_urls>",
"webRequest",
"webRequestBlocking"
],
"background": {
"scripts": ["background.js"]
},
"minimum_chrome_version":"22.0.0"
}
"""

background_js = string.Template(
"""
var config = {
mode: "fixed_servers",
rules: {
singleProxy: {
scheme: "${scheme}",
host: "${host}",
port: parseInt(${port})
},
bypassList: ["foobar.com"]
}
};

chrome.proxy.settings.set({value: config, scope: "regular"}, function() {});

function callbackFn(details) {
return {
authCredentials: {
username: "${username}",
password: "${password}"
}
};
}

chrome.webRequest.onAuthRequired.addListener(
callbackFn,
{urls: ["<all_urls>"]},
['blocking']
);
"""
).substitute(
host=proxy_host,
port=proxy_port,
username=proxy_username,
password=proxy_password,
scheme=scheme,
)
with zipfile.ZipFile(plugin_path, 'w') as zp:
zp.writestr("manifest.json", manifest_json)
zp.writestr("background.js", background_js)

return plugin_path

使用方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from selenium import webdriver
from common.pubilc import create_proxyauth_extension

proxyauth_plugin_path = create_proxyauth_extension(
proxy_host="XXXXX.com",
proxy_port=9020,
proxy_username="XXXXXXX",
proxy_password="XXXXXXX"
)


co = webdriver.ChromeOptions()
# co.add_argument("--start-maximized")
co.add_extension(proxyauth_plugin_path)


driver = webdriver.Chrome(executable_path="C:\chromedriver.exe", chrome_options=co)
driver.get("http://ip138.com/")
print(driver.page_source)

无认证代理:

1
2
3
4
5
options = webdriver.ChromeOptions()
options.add_argument('--proxy-server=http://ip:port')
driver = webdriver.Chrome(executable_path="C:\chromedriver.exe", chrome_options=0ptions)
driver.get("http://ip138.com/")
print(driver.page_source)
以上完毕 So Easy

Java

PS:此文章为小白提供,大佬请绕道!!!! 首先特别感谢大才哥给我提供这个平台,未来我希望把java这个版块的内容补全。 今天要讲的是数据类型,最最最基础的内容~ java标识符、数据类型、关键字 开始我们先看下如何注释java代码。 标识符:类名,方法名,变量。 有三种方式分别为 //表示注释一行代码 / 表示注释一行或者多行代码 (从上面到下面都是注释的代码) / 下面还有一种注释方式叫做文档注释。 / 通常这样表示 */ 文档注释一般写在代码开头用来简述你所做程序的具体内容,在这之前我们首先看一下javadoc命令,我先编写一个简答的代码: package com.briup.chap02; / @author Twinkle @version 1.0 It’s a text file / public class PrimitiveType{ public static void main(String[] args){ byte b = 123; byte b1 = 300; } } 我们javadoc -d 生成目录 编译文件 编译成功后,我们打开刚刚生成doc里打开index.html看一下,大概是这样的: 类概要 类: Student 说明: It’s a text file 这样我们就可以看出文档注释的意义了,他可以显示在你编译出来文档的说明里,但有人会发现为啥我们编写出来的author没有出来呀? 因为他的最前面有一个@,我们需要编写的时候把它加上去才能显示出来,现在我们来试一下, —javadoc -d bin/doc-author -version src/PrimitiveType.java,这样作者和版本信息就出来了 一.类名 这边我们要记住一些代码的基本格式: 类名的写法:Student(前面首字母要大写) 方法和变量的写法:genderItem(前面单词小写,后面单词开头要大写) 常量写法:MAX_PAGE(常量大写,中间一般加下划线) 二.关键字 关键字其实就是电脑里面已经定义好的有特殊意义的标识符,像int,for,double什么的都是关键字。具体意思请百度一下~ 三.数据类型 数据类型是这篇文章的重点,我们来看下这些基本的数据类型 类型 二进制位 例 范围 byte 8位 11111111~01111111 -2^7~2^7-1 short 16位 16个二进制代码 -2^15~2^15-1 int 32位 32个二进制代码 -2^31~2^31-1 long 64位 64个二进制代码 -2^63~2^63-1 浮点型: float 32位 32个二进制代码 double 64位 64个二进制代码 布尔型: boolean 只有false和true两种类型。 具体解释一下为什么会有这么多类型呢?而且二进制位为什么还不一样? 类型多的原因是因为有些数值本身就很小,传递给大的数据类型的话,虽然可以进去,但是有些二进制位就空闲了,占用了多余的内存却没有什么作用,所以才会有这么多的类型。 我们知道编程最终的目的是我们把代码传递给硬件,通过硬件来工作,但是呢,硬件只识别二进制代码,所以java会有一个把它的代码转化为二进制代码的过渡,上面的二进制位就是二进制码的数目,我们要想看他的范围有多大,可以这样算,二进制的第一位为标志符,通俗一点讲就是正负号,后面还有n位的话它的范围就是-|2^n|~|2^n-1| 如果我们定义的类型超出这个范围的话(也就是盆子里已经装满了东西如果再加),java就会报错,超出指定的范围,所以当我们定义数据类型的时候要搞清楚各数据类型的范围。 还有一个特殊的数据类型:char (‘字符’) char的具体位数要结合unicode编码。问题又来了,unicode编码又是什么鬼!unicode编码是一个字符集,里面包含了中,日,韩,三种文字,我们可以通过char的方法来打印出字符:char(‘u\unicode编码’),unicode表具体百度一下哈~ 数据类型转换: 显式转换:也就是强制转换 隐式转换:由JVM虚拟机自行转换 数据类型的强制转换:int a = (强制转换类型)b 转换规则:从存储范围大的类型到存储范围小的类型。 具体规则为:double→float→long→int→short(char)→byte byte b =10; byte a = (int) b; 如果我们把int类型的b转换给byte类型的a的话,会出现溢出现象,所以会报错。 所以正确强制转换的方式为~~: byte b = 10; int(或者更大的类型) a =(int) b; java基本的数据类型就讲到这里啦~ *--可能发布的内容有点混乱,我会尽快把前面的补齐~有疑问的话可以到大才哥的群里找我哈~

未分类

对于Scrapy处理Ajax 处理方式当然是同家兄弟Splash比较靠谱! 但是Splash有个很坑爹的毛病就是负载承受相对较小·· 一不留神就GG了·········· 然后也就没有然后了~~! 所以准备给Splash做一个负载均衡;后端放一大堆的Splash这样总不会GG了吧。 就算其中一个GG了还有其它的可替代不是? 废话不多少开整·· 环境是基于: CentOS 7.3 Docker 17.06.2-ce Splash 3.0 HAproxy 1.7.9 (CentOS大家可以将yum切换为阿里云的yum源 Docker同理)

阿里yum源: http://mirrors.aliyun.com/help/centos 照葫芦画瓢做一遍(你是CentOS7啊!!!!不要选成其他版本了)

注意以下只需要在你需要运行splash的机器上安装即可

阿里Docker源:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# step 1: 安装必要的一些系统工具

sudo yum install -y yum-utils device-mapper-persistent-data lvm2

# Step 2: 添加软件源信息

sudo yum-config-manager --add-repo http://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo

# Step 3: 更新并安装 Docker-CE

sudo yum makecache fast
sudo yum -y install docker-ce

# Step 4: 开启Docker服务

sudo service docker start

安装Docker加速器:

1
curl -sSL https://get.daocloud.io/daotools/set_mirror.sh | sh -s http://8050f360.m.daocloud.io

重启Docker:

1
systemctl restart docker

这样可以极快的速度拉取镜像。 获取splash最新的docker镜像:

1
docker pull scrapinghub/splash:master

关闭所有机器防火墙firewalld(网络安全的环境关闭,不安全的环境请放行端口,自行百度):

1
2
3
systemctl disable firewalld

systemctl stop firewalld

创建Splash配置文件目录:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# 存放过滤规则文件的目录

[root@localhost ~]# mkdir filters

# 存放JavaScript文件目录

[root@localhost ~]# mkdir js-profiles

# 存放lua模块的目录

[root@localhost ~]# mkdir lua_modules

# 存放代理文件的目录

[root@localhost ~]# mkdir proxy-profiles

# 创建完成如下:

[root@localhost ~]# pwd
/root
[root@localhost ~]# ll
total 4
drwxr-xr-x. 2 root root 25 Sep 26 03:00 filters
drwxr-xr-x. 2 root root 6 Sep 25 21:08 js-profiles
drwxr-xr-x. 2 root root 6 Sep 25 21:08 lua_modules
drwxr-xr-x. 2 root root 32 Sep 25 21:08 proxy-profiles
[root@localhost ~]#

启动Splash:

1
docker run -d -p 8050:8050 --memory=5.0G --restart=always  --name splash       -v /root/proxy-profiles:/etc/splash/proxy-profiles       -v /root/js-profiles:/etc/splash/js-profiles       -v /root/lua_modules:/etc/splash/lua_modules       -v /root/filters:/etc/splash/filters       scrapinghub/splash:master --maxrss 4500

docker run 启动一个容器 -d 后台启动 -p 8050:8050 将容器的8050端口和物理机的8050端口绑定(可以从8050端口访问容器服务应用) —memory=5.0G 容器最大使用内存为5.0GB,超出这个限制会被主进程杀死(使用free -mg 查看并酌情设置你的内存使用) —restart=always 容器退出后无条件重启(满了5GB被杀死,然后重启 释放内存) —name splash 容器的名字叫splash(可以忽略) -v ** 三个-v参数是将宿主机的目录挂载进容器,便于容器能够直接访问挂载目录中的内容 scrapinghub/splash:master 用于启动容器的镜像 —maxrss 4500 Splash最大内存使用为4500MB

查看容器是否启动:

1
2
3
4
[root@localhost ~]# docker ps -a
CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES
1b34f7933095 scrapinghub/splash:master "python3 /app/bin/..." 4 hours ago Up 4 hours 5023/tcp, 0.0.0.0:8050->8050/tcp splash
[root@localhost ~]#

访问Splash是否正常工作:

请注意:以上操作只需要在你需要运行splash的机器上安装即可

安装HAproxy实现负载均衡:

安装zlib-devel(HAproxy使用gzip功能):

1
yum install zlib-devel -y

安装HAproxy:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# 个人喜好 源码放在这个目录
[root@localhost examples]# cd /usr/local/src/

# 安装wget
[root@localhost src]#yum install wget -y

# 下载HAproxy安装包
[root@localhost src]# wget http://www.haproxy.org/download/1.7/src/haproxy-1.7.9.tar.gz

# 解压
[root@localhost src]# tar -zxvf haproxy-1.7.9.tar.gz

# 进入目录
[root@localhost src]# cd haproxy-1.7.9

# 编译
[root@localhost src]# make TARGET=linux2628 PREFIX=/usr/local/haproxy-1.7.9 USE_ZLIB=yes

# 安装
[root@localhost src]# make install

# 拷贝启动文件到目录
[root@localhost src]# cp /usr/local/sbin/haproxy /usr/sbin/

# 测试版本
[root@localhost src]# haproxy -v

# 拷贝启动文件到启动目录
[root@localhost src]# cp examples/haproxy.init /etc/init.d/haproxy

# 赋予可执行权限
[root@localhost src]# chmod 755 /etc/init.d/haproxy

# 创建配置文件目录
[root@localhost src]# mkdir /etc/haproxy

# 创建数据目录
[root@localhost src]# mkdir /var/lib/haproxy

# 创建运行文件目录
[root@localhost src]# mkdir /var/run/haproxy

# 设置日志
[root@localhost src]# vim /etc/rsyslog.conf
# 第15行 $ModLoad imudp #打开注释
# 第16行 $UDPServerRun 514 #打开注释
# 第74行 local3.* /var/log/haproxy.log #local3的路径

# 创建日志文件
[root@localhost src]# touch /var/log/haproxy.log

# 设置权限
[root@localhost src]# chown -R haproxy.haproxy /var/log/haproxy.log

# 启动日志服务
[root@localhost src]# systemctl restart rsyslog.service

配置HAproxy Conf:

1
[root@localhost src]# vim /etc/haproxy/haproxy.cfg

写入以下内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# HAProxy 1.7 config for Splash. It assumes Splash instances are executed
# on the same machine and connected to HAProxy using Docker links.
global
# raise it if necessary
maxconn 512
# required for stats page
stats socket /tmp/haproxy

userlist users
user user insecure-password userpass

defaults
log global
mode http

# remove requests from a queue when clients disconnect;
# see https://cbonte.github.io/haproxy-dconv/1.7/configuration.html#4.2-option%20abortonclose
option abortonclose

# gzip can save quite a lot of traffic with json, html or base64 data
# compression algo gzip
compression type text/html text/plain application/json

# increase these values if you want to
# allow longer request queues in HAProxy
timeout connect 3600s
timeout client 3600s
timeout server 3600s


# visit 0.0.0.0:8036 to see HAProxy stats page
listen stats
bind *:8036
mode http
stats enable
stats hide-version
stats show-legends
stats show-desc Splash Cluster
stats uri /
stats refresh 10s
stats realm Haproxy\ Statistics
stats auth admin:adminpass


# Splash Cluster configuration
# 代理服务器监听全局的8050端口
frontend http-in
bind *:8050
# 如果你需要开启Splash的访问认证
# 则注释default_backend splash-cluster
# 并放开其余default_backend splash-cluster 之上的其余注释
# 账号密码为user userpass
# acl auth_ok http_auth(users)
# http-request auth realm Splash if !auth_ok
# http-request allow if auth_ok
# http-request deny

# acl staticfiles path_beg /_harviewer/
# acl misc path / /info /_debug /debug

# use_backend splash-cluster if auth_ok !staticfiles !misc
# use_backend splash-misc if auth_ok staticfiles
# use_backend splash-misc if auth_ok misc
default_backend splash-cluster


backend splash-cluster
option httpchk GET /
balance leastconn

# try another instance when connection is dropped
retries 2
option redispatch
# 将下面IP地址替换为你自己的Splash服务IP和端口
# 按照以下格式一次增加其余的Splash服务器
server splash-0 10.10.1.41:8050 check maxconn 5 inter 2s fall 10 observe layer4
server splash-1 10.10.1.42:8050 check maxconn 5 inter 2s fall 10 observe layer4
server splash-2 10.10.1.32:8050 check maxconn 5 inter 2s fall 10 observe layer4

backend splash-misc
balance roundrobin
# 将下面IP地址替换为你自己的Splash服务IP和端口
# 按照以下格式一次增加其余的Splash服务器
server splash-0 10.10.1.41:8050 check fall 15
server splash-1 10.10.1.42:8050 check fall 15
server splash-2 10.10.1.32:8050 check fall 15

启动HAproxy:

1
2
3
4
5
6
7
8
# 启动HAproxy
[root@localhost src]# /etc/init.d/haproxy start
Restarting haproxy (via systemctl): [ OK ]

# 如果出现错误则使用:
[root@localhost examples]# systemctl status haproxy.service

# 查看报错

查看HAproxy状态: 用户名和密码为: admin adminpass

查看HAproxy负载是否生效:

完美!!!收工!! 注意:HAproxy这台服务器没有安装Splash服务,是负载到其余安装有Splash的服务器上提供的服务器哦!

Python

大家好,我还是小四毛,不是崔老师!!!!崔老师在隔壁,哈哈哈。

写了一个从网上抓取代理IP,然后构建代理IP池的脚本,放在了这里:https://github.com/xiaosimao/IP_POOL

以后应该还会有很多的改动, 欢迎有兴趣的同学star,以便及时可以收到改动的通知。

目前是从以下几个网站获取IP:66ip,xicidaili,data5u,proxydb。

具体的使用方法文档在readme.md 中, 欢迎交流。

如果你需要从别的网站获得, 那么可以在配置文件中进行相关的配置即可, 如果实在不想自己写,也可以提issue给我,我会看情况更新到代码中。

一般情况下,只要配置一下配置项就可以从新的网站获取IP,最大限度减少写代码的时间。

免费的ip,质量不敢保证,目前测试的目标网站为百度和https://httpbin.org/get, 还是获得了一些通过测试的IP,下面是截图。

Net

HTTP 2xx 范围内的状态码表明了“客户端发送的请求已经被服务器接受并且被成功处理了”。 HTTP/1.1 200 OK 是 HTTP 请求成功后的标准响应,当你在浏览器中打开某个网站后,你通常会得到一个 200 状态码。HTTP/1.1 206 状态码表示的是“客户端通过发送范围请求头Range抓取到了资源的部分数据” 这种请求通常用来:

  • 学习http头和状态
  • 解决网路问题
  • 解决大文件下载问题
  • 解决CDN和原始HTTP服务器问题
  • 使用工具例如lftp,wget,telnet测试断电续传
  • 测试将一个大文件分割成多个部分同时下载

测试

查看服务器是否支持 HTTP 206:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
curl -I https://raw.githubusercontent.com/Germey/LaravelGeetest/master/README.md
HTTP/1.1 200 OK
Content-Security-Policy: default-src 'none'; style-src 'unsafe-inline'
Strict-Transport-Security: max-age=31536000
X-Content-Type-Options: nosniff
X-Frame-Options: deny
X-XSS-Protection: 1; mode=block
ETag: "b29f4639b76cd7f94a4b36b05be6c85acfe478f1"
Content-Type: text/plain; charset=utf-8
Cache-Control: max-age=300
X-Geo-Block-List:
X-GitHub-Request-Id: 850A:16D2:30128BA:3341504:59BBC946
Content-Length: 8709
Accept-Ranges: bytes
Date: Fri, 15 Sep 2017 12:36:31 GMT
Via: 1.1 varnish
Connection: keep-alive
X-Served-By: cache-nrt6123-NRT
X-Cache: HIT
X-Cache-Hits: 1
X-Timer: S1505478991.145862,VS0,VE1
Vary: Authorization,Accept-Encoding
Access-Control-Allow-Origin: *
X-Fastly-Request-ID: ee23d80d2ba507ec0a70c880a075df0d2671aa4d
Expires: Fri, 15 Sep 2017 12:41:31 GMT
Source-Age: 8

其中有两个我们比较关注的请求头: Accept-Ranges: bytes:该响应头表明服务器支持 Range 请求,以及服务器所支持的单位是字节。同时服务器支持断点续传,以及支持同时下载文件的多个部分,也就是说下载工具可以利用范围请求加速下载该文件。Accept-Ranges: none 响应头表示服务器不支持范围请求。 Content-Length: 8709 :Content-Length 响应头表明了响应实体的大小,也就是真实的图片文件的大小是 8709 字节 (8.7K)。

发送

利用 CURL 可以指定请求范围。 获取前 500 字节:

1
curl --header "Range: bytes=0-500" https://raw.githubusercontent.com/Germey/LaravelGeetest/master/README.md

后 500 字节:

1
curl --header "Range: bytes=-500" https://raw.githubusercontent.com/Germey/LaravelGeetest/master/README.md

从 500 - 1000 字节:

1
curl --header "Range: bytes=500-1000" https://raw.githubusercontent.com/Germey/LaravelGeetest/master/README.md

从 500 - 末尾字节:

1
curl --header "Range: bytes=500-" https://raw.githubusercontent.com/Germey/LaravelGeetest/master/README.md

开启

大部分web服务器都原生支持字节范围请求. Apache 2.x用户可以在httpd.conf中尝试 mod_headers:

1
Header set Accept-Ranges bytes

Python

大家好,我是四毛, 不是崔老师。

恩,今天的内容很短, 主要都写在了README.md里面。

写了一个将爬虫基本步骤都封装起来的小框架,地址在https://github.com/xiaosimao/AiSpider, 欢迎Star。

写的很基础,很简单,大道至简(对,其实就是不会写)。

最近也在学一些设计模式的东西。

欢迎有兴趣的同学共同研究,readme.md中有我的微信(加的时候麻烦注明一下来自静觅),提出存在的问题和你的想法,这样大家可以共同讨论,共同进步。

BE A SPIDERMAN。

Python

Neo4j是一个世界领先的开源图形数据库,由 Java 编写。图形数据库也就意味着它的数据并非保存在表或集合中,而是保存为节点以及节点之间的关系。 Neo4j 的数据由下面几部分构成:

  • 节点
  • 属性

Neo4j 除了顶点(Node)和边(Relationship),还有一种重要的部分——属性。无论是顶点还是边,都可以有任意多的属性。属性的存放类似于一个 HashMap,Key 为一个字符串,而 Value 必须是基本类型或者是基本类型数组。 在Neo4j中,节点以及边都能够包含保存值的属性,此外:

  • 可以为节点设置零或多个标签(例如 Author 或 Book)
  • 每个关系都对应一种类型(例如 WROTE 或 FRIEND_OF)
  • 关系总是从一个节点指向另一个节点(但可以在不考虑指向性的情况下进行查询)

具体介绍可以参考:https://www.w3cschool.cn/neo4j

Neo4j的特点

  • 它拥有简单的查询语言 Neo4j CQL
  • 它遵循属性图数据模型
  • 它通过使用 Apache Lucence 支持索引
  • 它支持 UNIQUE 约束
  • 它包含一个用于执行 CQL 命令的 UI:Neo4j 数据浏览器
  • 它支持完整的 ACID(原子性,一致性,隔离性和持久性)规则
  • 它采用原生图形库与本地 GPE(图形处理引擎)
  • 它支持查询的数据导出到 Json 和 XLS 格式
  • 它提供了 REST API,可以被任何编程语言(如 Java,Spring,Scala 等)访问
  • 它提供了可以通过任何 UI MVC 框架(如 Node JS )访问的 Java 脚本
  • 它支持两种 Java API:Cypher API 和 Native Java API 来开发 Java 应用程序

Neo4j安装

可以到官网直接下载安装包安装即可,链接:https://neo4j.com/download/

Neo4j CQL命令

Neo4j 的 CQL 是非常重要的命令,类似于 SQL 语句,具体的用法可以参考:https://www.w3cschool.cn/neo4j/neo4j_cql_introduction.html

Py2Neo用法

Py2Neo 是用来对接 Neo4j 的 Python 库,接下来对其详细介绍。

相关链接

  • 官方文档:http://py2neo.org/v3/index.html
  • GitHub:https://github.com/technige/py2neo

安装方法

使用 Pip 安装即可:

1
pip3 install py2neo

Node & Relationship

Neo4j 里面最重要的两个数据结构就是节点和关系,即 Node 和 Relationship,可以通过 Node 或 Relationship 对象创建,实例如下:

1
2
3
4
5
6
from py2neo import Node, Relationship

a = Node('Person', name='Alice')
b = Node('Person', name='Bob')
r = Relationship(a, 'KNOWS', b)
print(a, b, r)

运行结果:

1
(alice:Person {name:"Alice"}) (bob:Person {name:"Bob"}) (alice)-[:KNOWS]->(bob)

这样我们就成功创建了两个 Node 和两个 Node 之间的 Relationship。 Node 和 Relationship 都继承了 PropertyDict 类,它可以赋值很多属性,类似于字典的形式,例如可以通过如下方式对 Node 或 Relationship 进行属性赋值,接着上面的代码,实例如下:

1
2
3
4
a['age'] = 20
b['age'] = 21
r['time'] = '2017/08/31'
print(a, b, r)

运行结果:

1
(alice:Person {age:20,name:"Alice"}) (bob:Person {age:21,name:"Bob"}) (alice)-[:KNOWS {time:"2017/08/31"}]->(bob)

可见通过类似字典的操作方法就可以成功实现属性赋值。 另外还可以通过 setdefault() 方法赋值默认属性,例如:

1
2
a.setdefault('location', '北京')
print(a)

运行结果:

1
(alice:Person {age:20,location:"北京",name:"Alice"})

可见没有给 a 对象赋值 location 属性,现在就会使用默认属性。 但如果赋值了 location 属性,则它会覆盖默认属性,例如:

1
2
3
a['location'] = '上海'
a.setdefault('location', '北京')
print(a)

运行结果:

1
(alice:Person {age:20,location:"上海",name:"Alice"})

另外也可以使用 update() 方法对属性批量更新,接着上面的例子实例如下:

1
2
3
4
5
6
data = {
'name': 'Amy',
'age': 21
}
a.update(data)
print(a)

运行结果:

1
(alice:Person {age:21,location:"上海",name:"Amy"})

可以看到这里更新了 a 对象的 name 和 age 属性,没有更新 location 属性,则 name 和 age 属性会更新,location 属性则会保留。

Subgraph

Subgraph,子图,是 Node 和 Relationship 的集合,最简单的构造子图的方式是通过关系运算符,实例如下:

1
2
3
4
5
6
7
from py2neo import Node, Relationship

a = Node('Person', name='Alice')
b = Node('Person', name='Bob')
r = Relationship(a, 'KNOWS', b)
s = a | b | r
print(s)

运行结果:

1
({(alice:Person {name:"Alice"}), (bob:Person {name:"Bob"})}, {(alice)-[:KNOWS]->(bob)})

这样就组成了一个 Subgraph。 另外还可以通过 nodes() 和 relationships() 方法获取所有的 Node 和 Relationship,实例如下:

1
2
print(s.nodes())
print(s.relationships())

运行结果:

1
2
frozenset({(alice:Person {name:"Alice"}), (bob:Person {name:"Bob"})})
frozenset({(alice)-[:KNOWS]->(bob)})

可以看到结果是 frozenset 类型。 另外还可以利用 & 取 Subgraph 的交集,例如:

1
2
3
s1 = a | b | r
s2 = a | b
print(s1 & s2)

运行结果:

1
({(alice:Person {name:"Alice"}), (bob:Person {name:"Bob"})}, {})

可以看到结果是二者的交集。 另外我们还可以分别利用 keys()、labels()、nodes()、relationships()、types() 分别获取 Subgraph 的 Key、Label、Node、Relationship、Relationship Type,实例如下:

1
2
3
4
5
6
s = a | b | r
print(s.keys())
print(s.labels())
print(s.nodes())
print(s.relationships())
print(s.types())

运行结果:

1
2
3
4
5
frozenset({'name'})
frozenset({'Person'})
frozenset({(alice:Person {name:"Alice"}), (bob:Person {name:"Bob"})})
frozenset({(alice)-[:KNOWS]->(bob)})
frozenset({'KNOWS'})

另外还可以用 order() 或 size() 方法来获取 Subgraph 的 Node 数量和 Relationship 数量,实例如下:

1
2
3
4
from py2neo import Node, Relationship, size, order
s = a | b | r
print(order(s))
print(size(s))

运行结果:

1
2
2
1

Walkable

Walkable 是增加了遍历信息的 Subgraph,我们通过 + 号便可以构建一个 Walkable 对象,例如:

1
2
3
4
5
6
7
8
9
from py2neo import Node, Relationship

a = Node('Person', name='Alice')
b = Node('Person', name='Bob')
c = Node('Person', name='Mike')
ab = Relationship(a, "KNOWS", b)
ac = Relationship(a, "KNOWS", c)
w = ab + Relationship(b, "LIKES", c) + ac
print(w)

运行结果:

1
(alice)-[:KNOWS]->(bob)-[:LIKES]->(mike)<-[:KNOWS]-(alice)

这样我们就形成了一个 Walkable 对象。 另外我们可以调用 walk() 方法实现遍历,实例如下:

1
2
3
4
from py2neo import walk

for item in walk(w):
print(item)

运行结果:

1
2
3
4
5
6
7
(alice:Person {name:"Alice"})
(alice)-[:KNOWS]->(bob)
(bob:Person {name:"Bob"})
(bob)-[:LIKES]->(mike)
(mike:Person {name:"Mike"})
(alice)-[:KNOWS]->(mike)
(alice:Person {name:"Alice"})

可以看到它从 a 这个 Node 开始遍历,然后到 b,再到 c,最后重新回到 a。 另外还可以利用 start_node()、end_node()、nodes()、relationships() 方法来获取起始 Node、终止 Node、所有 Node 和 Relationship,例如:

1
2
3
4
print(w.start_node())
print(w.end_node())
print(w.nodes())
print(w.relationships())

运行结果:

1
2
3
4
(alice:Person {name:"Alice"})
(alice:Person {name:"Alice"})
((alice:Person {name:"Alice"}), (bob:Person {name:"Bob"}), (mike:Person {name:"Mike"}), (alice:Person {name:"Alice"}))
((alice)-[:KNOWS]->(bob), (bob)-[:LIKES]->(mike), (alice)-[:KNOWS]->(mike))

可以看到本例中起始和终止 Node 都是同一个,这和 walk() 方法得到的结果是一致的。

Graph

在 database 模块中包含了和 Neo4j 数据交互的 API,最重要的当属 Graph,它代表了 Neo4j 的图数据库,同时 Graph 也提供了许多方法来操作 Neo4j 数据库。 Graph 在初始化的时候需要传入连接的 URI,初始化参数有 bolt、secure、host、http_port、https_port、bolt_port、user、password,详情说明可以参考:http://py2neo.org/v3/database.html#py2neo.database.Graph。 初始化的实例如下:

1
2
3
4
from py2neo import Graph
graph_1 = Graph()
graph_2 = Graph(host="localhost")
graph_3 = Graph("http://localhost:7474/db/data/")

另外我们还可以利用 create() 方法传入 Subgraph 对象来将关系图添加到数据库中,实例如下:

1
2
3
4
5
6
7
8
from py2neo import Node, Relationship, Graph

a = Node('Person', name='Alice')
b = Node('Person', name='Bob')
r = Relationship(a, 'KNOWS', b)
s = a | b | r
graph = Graph(password='123456')
graph.create(s)

这里必须确保 Neo4j 正常运行,其密码为 123456,这里调用 create() 方法即可完成图的创建,结果如下: 另外我们也可以单独添加单个 Node 或 Relationship,实例如下:

1
2
3
4
5
6
7
8
from py2neo import Graph, Node, Relationship

graph = Graph(password='123456')
a = Node('Person', name='Alice')
graph.create(a)
b = Node('Person', name='Bob')
ab = Relationship(a, 'KNOWS', b)
graph.create(ab)

运行结果如下: 另外还可以利用 data() 方法来获取查询结果:

1
2
3
4
5
from py2neo import Graph

graph = Graph(password='123456')
data = graph.data('MATCH (p:Person) return p')
print(data)

运行结果:

1
[{'p': (e0d0f96:Person {name:"Alice"})}, {'p': (cfe57d0:Person {name:"Bob"})}]

这里是通过 CQL 语句实现的查询,输出结果即 CQL 语句的返回结果,是列表形式。 另外输出结果还可以直接转化为 DataFrame 对象,实例如下:

1
2
3
4
5
6
from py2neo import Graph
from pandas import DataFrame
graph = Graph(password='123456')
data = graph.data('MATCH (p:Person) return p')
df = DataFrame(data)
print(df)

运行结果:

1
2
3
                   p
0 {'name': 'Alice'}
1 {'name': 'Bob'}

另外可以使用 find_one() 或 find() 方法进行 Node 的查找,可以利用 match() 或 match_one() 方法对 Relationship 进行查找:

1
2
3
4
5
6
7
from py2neo import Graph

graph = Graph(password='123456')
node = graph.find_one(label='Person')
print(node)
relationship = graph.match_one(rel_type='KNOWS')
print(relationship)

运行结果:

1
2
(c7402c7:Person {age:21,name:"Alice"})
(c7402c7)-[:KNOWS]->(e2c42fc)

如果想要更新 Node 的某个属性可以使用 push() 方法,例如:

1
2
3
4
5
6
7
8
from py2neo import Graph, Node

graph = Graph(password='123456')
a = Node('Person', name='Alice')
node = graph.find_one(label='Person')
node['age'] = 21
graph.push(node)
print(graph.find_one(label='Person'))

运行结果:

1
(a90a763:Person {age:21,name:"Alice"})

如果想要删除某个 Node 可以使用 delete() 方法,例如:

1
2
3
4
5
6
7
from py2neo import Graph

graph = Graph(password='123456')
node = graph.find_one(label='Person')
relationship = graph.match_one(rel_type='KNOWS')
graph.delete(relationship)
graph.delete(node)

在删除 Node 时必须先删除其对应的 Relationship,否则无法删除 Node。 另外我们也可以通过 run() 方法直接执行 CQL 语句,例如:

1
2
3
4
5
from py2neo import Graph

graph = Graph(password='123456')
data = graph.run('MATCH (p:Person) RETURN p LIMIT 5')
print(list(data))

运行结果:

1
[('p': (b6f61ff:Person {age:20,name:"Alice"})), ('p': (cc238b1:Person {age:20,name:"Alice"})), ('p': (b09e672:Person {age:20,name:"Alice"}))]

NodeSelector

Graph 有时候用起来不太方便,比如如果要根据多个条件进行 Node 的查询是做不到的,在这里更方便的查询方法是利用 NodeSelector,我们首先新建如下的 Node 和 Relationship,实例如下:

1
2
3
4
5
6
7
8
9
10
11
from py2neo import Graph, Node, Relationship

graph = Graph(password='123456')
a = Node('Person', name='Alice', age=21, location='广州')
b = Node('Person', name='Bob', age=22, location='上海')
c = Node('Person', name='Mike', age=21, location='北京')
r1 = Relationship(a, 'KNOWS', b)
r2 = Relationship(b, 'KNOWS', c)
graph.create(a)
graph.create(r1)
graph.create(r2)

运行结果: 在这里我们用 NodeSelector 来筛选 age 为 21 的 Person Node,实例如下:

1
2
3
4
5
6
from py2neo import Graph, NodeSelector

graph = Graph(password='123456')
selector = NodeSelector(graph)
persons = selector.select('Person', age=21)
print(list(persons))

运行结果:

1
[(d195b2e:Person {age:21,location:"广州",name:"Alice"}), (eefe475:Person {age:21,location:"北京",name:"Mike"})]

另外也可以使用 where() 进行更复杂的查询,例如查找 name 是 A 开头的 Person Node,实例如下:

1
2
3
4
5
6
from py2neo import Graph, NodeSelector

graph = Graph(password='123456')
selector = NodeSelector(graph)
persons = selector.select('Person').where('_.name =~ "A.*"')
print(list(persons))

运行结果:

1
[(bcd8072:Person {age:21,location:"广州",name:"Alice"})]

在这里用了正则表达式匹配查询。 另外也可以使用 order_by() 进行排序:

1
2
3
4
5
6
from py2neo import Graph, NodeSelector

graph = Graph(password='123456')
selector = NodeSelector(graph)
persons = selector.select('Person').order_by('_.age')
print(list(persons))

运行结果:

1
[(e3fc3d7:Person {age:21,location:"广州",name:"Alice"}), (da0179d:Person {age:21,location:"北京",name:"Mike"}), (cafa16e:Person {age:22,location:"上海",name:"Bob"})]

前面返回的都是列表,如果要查询单个节点的话,可以使用 first() 方法,实例如下:

1
2
3
4
5
6
from py2neo import Graph, NodeSelector

graph = Graph(password='123456')
selector = NodeSelector(graph)
person = selector.select('Person').where('_.name =~ "A.*"').first()
print(person)

运行结果:

1
(ea81c04:Person {age:21,location:"广州",name:"Alice"})

更详细的内容可以查看:http://py2neo.org/v3/database.html#cypher-utilities

OGM

OGM 类似于 ORM,意为 Object Graph Mapping,这样可以实现一个对象和 Node 的关联,例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from py2neo.ogm import GraphObject, Property, RelatedTo, RelatedFrom


class Movie(GraphObject):
__primarykey__ = 'title'

title = Property()
released = Property()
actors = RelatedFrom('Person', 'ACTED_IN')
directors = RelatedFrom('Person', 'DIRECTED')
producers = RelatedFrom('Person', 'PRODUCED')

class Person(GraphObject):
__primarykey__ = 'name'

name = Property()
born = Property()
acted_in = RelatedTo('Movie')
directed = RelatedTo('Movie')
produced = RelatedTo('Movie')

我们可以用它来结合 Graph 查询,例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from py2neo import Graph
from py2neo.ogm import GraphObject, Property

graph = Graph(password='123456')


class Person(GraphObject):
__primarykey__ = 'name'

name = Property()
age = Property()
location = Property()

person = Person.select(graph).where(age=21).first()
print(person)
print(person.name)
print(person.age)

运行结果:

1
2
3
<Person name='Alice'>
Alice
21

这样我们就成功实现了对象和 Node 的映射。 我们可以用它动态改变 Node 的属性,例如修改某个 Node 的 age 属性,实例如下:

1
2
3
4
5
person = Person.select(graph).where(age=21).first()
print(person.__ogm__.node)
person.age = 22
print(person.__ogm__.node)
graph.push(person)

运行结果:

1
2
(ccf5640:Person {age:21,location:"北京",name:"Mike"})
(ccf5640:Person {age:22,location:"北京",name:"Mike"})

另外我们也可以通过映射关系进行 Relationship 的调整,例如通过 Relationship 添加一个关联 Node,实例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from py2neo import Graph
from py2neo.ogm import GraphObject, Property, RelatedTo

graph = Graph(password='123456')

class Person(GraphObject):
__primarykey__ = 'name'

name = Property()
age = Property()
location = Property()
knows = RelatedTo('Person', 'KNOWS')

person = Person.select(graph).where(age=21).first()
print(list(person.knows))
new_person = Person()
new_person.name = 'Durant'
new_person.age = 28
person.knows.add(new_person)
print(list(person.knows))

运行结果:

1
2
[<Person name='Bob'>]
[<Person name='Bob'>, <Person name='Durant'>]

这样我们就完成了 Node 和 Relationship 的添加,同时由于设置了 primarykey 为 name,所以不会重复添加。 但是注意此时数据库并没有更新,只是对象更新了,如果要更新到数据库中还需要调用 Graph 对象的 push() 或 pull() 方法,添加如下代码即可:

1
graph.push(person)

也可以通过 remove() 方法移除某个关联 Node,实例如下:

1
2
3
4
5
person = Person.select(graph).where(name='Alice').first()
target = Person.select(graph).where(name='Durant').first()
person.knows.remove(target)
graph.push(person)
graph.delete(target)

这里 target 是 name 为 Durant 的 Node,代码运行完毕后即可删除关联 Relationship 和删除 Node。 以上便是 OGM 的用法,查询修改非常方便,推荐使用此方法进行 Node 和 Relationship 的修改。 更多内容可以查看:http://py2neo.org/v3/ogm.html#module-py2neo.ogm

结语

以上便是对 Neo4j 的相关介绍。

Python

基本步骤: 1、训练素材分类: 我是参考官方的目录结构: 每个目录中放对应的文本,一个txt文件一篇对应的文章:就像下面这样 需要注意的是所有素材比例请保持在相同的比例(根据训练结果酌情调整、不可比例过于悬殊、容易造成过拟合(通俗点就是大部分文章都给你分到素材最多的那个类别去了)) 废话不多说直接上代码吧(测试代码的丑得一逼;将就着看看吧) 需要一个小工具: pip install chinese-tokenizer 这是训练器:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import re
import jieba
import json
from io import BytesIO
from chinese_tokenizer.tokenizer import Tokenizer
from sklearn.datasets import load_files
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB
from sklearn.externals import joblib

jie_ba_tokenizer = Tokenizer().jie_ba_tokenizer

# 加载数据集
training_data = load_files('./data', encoding='utf-8')
# x_train txt内容 y_train 是类别(正 负 中 )
x_train, _, y_train, _ = train_test_split(training_data.data, training_data.target)
print('开始建模.....')
with open('training_data.target', 'w', encoding='utf-8') as f:
f.write(json.dumps(training_data.target_names))
# tokenizer参数是用来对文本进行分词的函数(就是上面我们结巴分词)
count_vect = CountVectorizer(tokenizer=jieba_tokenizer)

tfidf_transformer = TfidfTransformer()
X_train_counts = count_vect.fit_transform(x_train)

X_train_tfidf = tfidf_transformer.fit_transform(X_train_counts)
print('正在训练分类器.....')
# 多项式贝叶斯分类器训练
clf = MultinomialNB().fit(X_train_tfidf, y_train)
# 保存分类器(好在其它程序中使用)
joblib.dump(clf, 'model.pkl')
# 保存矢量化(坑在这儿!!需要使用和训练器相同的 矢量器 不然会报错!!!!!! 提示 ValueError dimension mismatch··)
joblib.dump(count_vect, 'count_vect')
print("分类器的相关信息:")
print(clf)

下面是是使用训练好的分类器分类文章: 需要分类的文章放在predict_data目录中:照样是一篇文章一个txt文件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# -*- coding: utf-8 -*-
# @Time : 2017/8/23 18:02
# @Author : 哎哟卧槽
# @Site :
# @File : 贝叶斯分类器.py
# @Software: PyCharm

import re
import jieba
import json
from sklearn.datasets import load_files
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.externals import joblib


# 加载分类器
clf = joblib.load('model.pkl')

count_vect = joblib.load('count_vect')
testing_data = load_files('./predict_data', encoding='utf-8')
target_names = json.loads(open('training_data.target', 'r', encoding='utf-8').read())
# # 字符串处理
tfidf_transformer = TfidfTransformer()

X_new_counts = count_vect.transform(testing_data.data)
X_new_tfidf = tfidf_transformer.fit_transform(X_new_counts)
# 进行预测
predicted = clf.predict(X_new_tfidf)
for title, category in zip(testing_data.filenames, predicted):
print('%r => %s' % (title, target_names[category]))

这个样子将训练好的分类器在新的程序中使用时候 就不报错: ValueError dimension mismatch·· 这儿有个demo 仅供参考:GitHub地址