python并发与并行(九) ———— 用asyncio改写通过线程实现的IO

13 minute read

Published:

知道了协程的好处之后,我们可能就想把现有项目之中的代码全都改用协程来写,于是有人就担心,这样修改起来,工作量会不会比较大呢?所幸Python已经将异步执行功能很好地集成到语言里面了,所以我们很容易就能把采用线程实现的阻塞式I/O操作转化为采用协程实现的异步I/O操作。

在这里我们要补充下线程和协程的区别,以及他们在执行阻塞式IO和异步IO上的区别。

在计算机编程中,I/O(输入/输出)操作通常涉及等待外部事件完成,如磁盘读写、网络通信等。线程和协程是两种不同的并发执行单元,它们处理阻塞式I/O和异步I/O的方式不同:

  1. 线程(Thread)
    • 线程是操作系统层面的执行单元,拥有自己的栈和独立的执行路径。
    • 当线程执行阻塞式I/O操作时,它会在操作完成之前被操作系统挂起,不会执行其他任务。这意味着在等待I/O操作完成期间,线程不能做其他工作,从而导致资源的浪费。
    • 为了解决这个问题,可以使用多线程,其中一个线程等待I/O操作时,其他线程可以继续执行。但这增加了程序的复杂性,如需要同步和通信机制来避免竞态条件和死锁。
  2. 协程(Coroutine)
    • 协程是一种更轻量级的执行单元,通常由程序内部进行管理,而不是由操作系统管理。
    • 协程主要用于处理计算密集型任务中的异步操作,它们可以暂停执行并在稍后恢复,而不会阻塞整个程序或系统。
    • 在协程中,当遇到I/O操作时,协程可以主动让出控制权,允许其他协程运行。一旦I/O操作完成,原先挂起的协程可以恢复执行。这种方式称为异步I/O,因为它允许程序在等待I/O操作时继续做其他工作。
    • 协程通常与事件循环(Event Loop)一起使用,事件循环负责处理外部事件(如I/O完成)并恢复相应的协程。

为什么线程实现的是阻塞式I/O:

  • 线程在执行I/O操作时,如果操作系统的I/O模型是阻塞式的,线程将会等待I/O操作完成,无法执行其他任务。

为什么协程实现的是异步I/O:

  • 协程允许程序在等待I/O操作时,通过切换到其他协程来执行其他任务,从而实现非阻塞的行为。
  • 异步I/O库通常提供了一种机制,当I/O操作准备好(例如,数据到达或写入完成)时,可以通知程序并恢复等待的协程。

转换阻塞式I/O到异步I/O的优势:

  • 提高效率:程序可以在等待I/O操作时继续执行其他任务,提高CPU利用率。
  • 改善性能:减少线程切换的开销,因为协程切换通常比线程切换要轻量级。
  • 简化编程模型:使用协程可以简化异步编程,因为它允许使用顺序编程风格来编写逻辑,而不必担心底层的并发和同步问题。

例如,我们要写一个基于TCP的服务器程序,让它跟用户玩猜数字的游戏。用户(也就是客户端)通过lower与upper参数把取值范围告诉服务器,让服务器在这个范围里面猜测用户心中的那个整数值。服务器把自己猜的数告诉用户,如果没猜对,用户会告诉服务器这次所猜的值跟上次相比,是离正确答案更近(warmer)还是更远(colder)。

这样的客户端/服务器(client/server,C/S)系统,通常会利用线程与阻塞式的I/O来实现。这种方案要求我们先编写一个辅助类来管理发送信息与接收信息这两项操作。为了便于演示,我们采用文本信息的形式来表达所要发送和接收的命令数据:

class EOFError(Exception):
    pass

class ConnectionBase:
    def __init__(self, connection):
        self.connection = connection
        self.file = connection.makefile('rb')

    def send(self, command):
        line = command + '\n'
        data = line.encode()
        self.connection.send(data)

    def receive(self):
        line = self.file.readline()
        if not line:
            raise EOFError('Connection closed')
        return line[:-1].decode()

我们用下面这样的ConnectionBase子类来实现服务器端的逻辑。每处理一条连接,就创建这样一个Session实例,并通过实例之中的字段维护跟客户端会话时所要用到的相关状态。

服务器类的主方法叫作loop,它会循环地解析客户端所传来的命令,并根据具体内容把这条命令派发给相关的方法去处理。请注意,为了让代码简单一些,这里用到了Python 3.8引入的新功能,也就是赋值表达式。

第一种命令叫作PARAMS命令,客户端在新游戏开局时,会通过该命令把下边界(lower)与上边界(upper)告诉服务器,让它能够在这个范围里面去猜测自己心中预想的那个值。

第二种命令叫作NUMBER,表示客户端要求服务器做一次猜测。这时,我们先在next_guess函数里判断上次的猜测结果。如果上次已经猜对了,那就把保存在self.secret里的值告诉客户端。如果上次没有猜对,那么就在取值范围内随机选一个值。请注意,我们会专门用一个while循环来判断随机选出的这个值以前是否已经选过,要是选过,那就再选,直到选出以前没猜过的值为止。现在,我们将这次选中的值加入guesses列表以免将来重复猜测。最后,通过send方法把值发送给客户端。

第三种命令叫作REPORT,表示客户端接到了我们在响应NUMBER命令时所发过去的那个猜测值并且发来了报告。看到这份报告之后,服务器端就知道自己刚才猜的数值(也就是guesses列表末尾的那个值),与前一次相比,是离正确答案更近了,还是离正确答案更远了。如果恰好猜对,那就把last变量所表示的值赋给self.secret,以便在客户端下次发来NUMBER请求的时候,作为正确答案回传给它。

import random

WARMER = 'Warmer'
COLDER = 'Colder'
UNSURE = 'Unsure'
CORRECT = 'Correct'

class UnknownCommandError(Exception):
    pass

class Session(ConnectionBase):
    def __init__(self, *args):
        super().__init__(*args)
        self._clear_state(None, None)

    def _clear_state(self, lower, upper):
        self.lower = lower
        self.upper = upper
        self.secret = None
        self.guesses = []

    def loop(self):
        while command := self.receive():
            parts = command.split(' ')
            if parts[0] == 'PARAMS':
                self.set_params(parts)
            elif parts[0] == 'NUMBER':
                self.send_number()
            elif parts[0] == 'REPORT':
                self.receive_report(parts)
            else:
                raise UnknownCommandError(command)
    
    # PARAMS命令相关
    def set_params(self, parts):
        assert len(parts) == 3
        lower = int(parts[1])
        upper = int(parts[2])
        self._clear_state(lower, upper)

    # NUMBER命令相关
    def next_guess(self):
        if self.secret is not None:
            return self.secret

        while True:
            guess = random.randint(self.lower, self.upper)
            if guess not in self.guesses:
                return guess
            
    # NUMBER命令相关
    def send_number(self):
        guess = self.next_guess()
        self.guesses.append(guess)
        self.send(format(guess))
    
    #REPORT命令相关
    def receive_report(self, parts):
        assert len(parts) == 2
        decision = parts[1]

        last = self.guesses[-1]
        if decision == CORRECT:
            self.secret = last

        print(f'Server: {last} is {decision}')

客户端的逻辑也用ConnectionBase的子类来实现,这种实例同样会保存会话时所用到的相关状态。

session方法负责开局,我们在启动猜数字游戏时,会通过这个方法把这局游戏的正确答案记录到self.secret字段里面,并把服务器端在猜测这个答案时所要遵守的下边界(lower)与上边界(upper)通过PARAMS命令发过去。为了让服务器端在这局游戏结束后,能够正确地清理状态,我们用@contextlib.contextmanager修饰session方法,这样就可以把它用在with结构里面了,这种结构会适时地触发finally块里的清理语句

然后,我们还要写这样一个方法,用来向服务器端发送NUMBER命令,要求对方做一次猜测,如果猜得不对,就要求服务器继续猜,直到猜对或者猜测次数超过count为止。

最后,还要给客户端类里面写这样一个方法,用来向服务器发送REPORT命令,告诉对方,这次猜的数与上次相比,是距离正确答案更近(WARMER)还是更远(COLDER)。如果刚好猜对,就报告CORRECT,如果是第一次猜或者这两次猜的数字距离正确答案一样近,就报告UNSURE。

import contextlib
import math

class Client(ConnectionBase):
    def __init__(self, *args):
        super().__init__(*args)
        self._clear_state()

    def _clear_state(self):
        self.secret = None
        self.last_distance = None

    @contextlib.contextmanager
    def session(self, lower, upper, secret):
        print(f'Guess a number between {lower} and {upper}!'
              f' Shhhhh, it\'s {secret}.')
        self.secret = secret
        self.send(f'PARAMS {lower} {upper}')
        try:
            yield
        finally:
            self._clear_state()
            self.send('PARAMS 0 -1')

    def request_numbers(self, count):
        for _ in range(count):
            self.send('NUMBER')
            data = self.receive()
            yield int(data)
            if self.last_distance == 0:
                return

    def report_outcome(self, number):
        new_distance = math.fabs(number - self.secret)
        decision = UNSURE

        if new_distance == 0:
            decision = CORRECT
        elif self.last_distance is None:
            pass
        elif new_distance < self.last_distance:
            decision = WARMER
        elif new_distance > self.last_distance:
            decision = COLDER

        self.last_distance = new_distance

        self.send(f'REPORT {decision}')
        return decision

补充:

@contextlib.contextmanager 是一个装饰器,它用于创建一个上下文管理器,这通常用于实现支持 with 语句的自定义对象。上下文管理器允许你定义一段代码的执行前后分别需要执行的代码块,这在需要资源管理时非常有用,比如文件操作、获取锁、数据库事务等场景。

在你给出的 Client 类中的 session 方法上使用了 @contextlib.contextmanager 装饰器,这意味着 session 方法会返回一个上下文管理器。下面是 session 方法的工作原理:

  1. with 语句开始时,session 方法被调用,其参数 loweruppersecret 被传递进去。

  2. 方法内部首先打印一条消息,提示用户猜测一个在 lowerupper 之间的数字,并显示秘密数字(这里假设是一个游戏或者某种交互式应用的一部分)。

  3. secret 赋值给实例变量 self.secret,这可能用于后续的逻辑判断或其他用途。

  4. 通过调用 self.send() 方法发送一个包含参数范围的字符串,这可能是向服务器或其他客户端发送当前会话的参数。

  5. yield 语句暂停 session 方法的执行,并返回控制权给 with 语句块中的代码。在 with 语句块中执行的代码可以访问 session 方法的局部变量,因为 yield 之前的部分创建了一个生成器。

  6. with 语句块中的代码执行完毕后,控制权返回到 session 方法,继续执行 yield 之后的代码。

  7. finally 子句中,调用 self._clear_state() 方法来清除会话状态,这是为了确保每次会话结束后资源被正确释放,避免潜在的状态污染。

  8. 最后,再次调用 self.send() 方法发送一个参数为 0 -1 的字符串,这可能表示会话结束或重置参数。

使用 @contextlib.contextmanager 的好处是它允许你以一种非常 Pythonic 的方式编写清晰的上下文管理代码,而不需要定义一个类并实现 __enter____exit__ 方法。这种方式更加简洁,易于理解和使用。

现在开始为运行服务器做准备。编写run_server方法给服务器线程来调用,这个方法会在socket上面监听,并接受连接请求。每连接一个客户,它就启动一条线程处理该连接。

import socket
from threading import Thread

def handle_connection(connection):
    with connection:
        session = Session(connection)
        try:
            session.loop()
        except EOFError:
            pass

def run_server(address):
    with socket.socket() as listener:
        # Allow the port to be reused
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind(address)
        listener.listen()
        while True:
            connection, _ = listener.accept()
            thread = Thread(target=handle_connection,
                            args=(connection,),
                            daemon=True)
            thread.start()

补充:

  1. with connection::使用 with 语句确保 connection 对象在使用后能够正确关闭。这里假设 connection 对象实现了上下文管理协议(即有 enterexit 方法。

  2. with socket.socket() as listener::创建一个 socket 对象并命名为 listener,使用 with 语句确保在结束时正确关闭。

  3. socket.SOL_SOCKET:这个常量代表“Socket Level”,是一个通用的选项级别,用于指定接下来的选项是针对套接字本身的。它是一个整数常量,通常用于 setsockopt 函数来指定选项作用的层级。当调用 setsockopt 方法时,第一个参数是 socket.SOL_SOCKET,表示接下来的选项是设置在套接字级别上的。

  4. socket.SO_REUSEADDR:这个常量代表“Socket Option Reuse Address”,是一个选项,用于控制套接字的行为,允许套接字绑定到一个已经被使用(占用)的地址和端口上。通常,当应用程序尝试绑定一个已经在使用中的端口时,系统会抛出一个错误。但是,如果设置了 SO_REUSEADDR 选项,就可以避免这个错误,允许绑定操作成功。 使用 SO_REUSEADDR 的一个常见场景是在重启服务器时,如果服务器在关闭时没有正确释放端口,操作系统通常会保持端口在一定的时间内处于“TIME_WAIT”状态,导致无法立即重新使用该端口。通过设置 SO_REUSEADDR,可以告诉操作系统允许应用程序重新绑定到这个端口上。
  5. 在 Python 的 socket 库中,除了 socket.SOL_SOCKET 这个套接字选项级别外,还有其他几种选项级别,主要用于指定不同的协议层或者用于特定类型的套接字选项。以下是一些常见的选项级别:

socket.SOL_SOCKET:套接字选项,如上所述,用于通用的套接字级选项。

socket.IPPROTO_IP:对应于 IP 协议层的选项。在使用 IPv4 套接字时,这个级别用于设置或获取 IP 层相关的选项。

socket.IPPROTO_TCP:对应于 TCP 协议层的选项。在使用 TCP 套接字时,这个级别用于设置或获取 TCP 层相关的选项,例如 TCP_NODELAY(禁用 Nagle 算法)。

socket.IPPROTO_UDP:对应于 UDP 协议层的选项。在使用 UDP 套接字时,这个级别用于设置或获取 UDP 层相关的选项。

socket.IPPROTO_IPV6:对应于 IPv6 协议层的选项。在使用 IPv6 套接字时,这个级别用于设置或获取 IPv6 相关的选项,例如 IPV6_V6ONLY(限制套接字只使用 IPv6)。

socket.IPPROTO_ICMP:对应于 ICMP 协议层的选项,通常用于设置或获取 ICMP 相关的选项。

socket.IPPROTO_RAW:对应于原始套接字的选项,原始套接字允许你发送和接收任意的原始 IP 数据报。

socket.IPPROTO_ICMPV6:对应于 ICMPv6 协议层的选项,用于设置或获取 ICMPv6 相关的选项。

客户端放在主线程里面执行。让主线程调用下面这个函数,意思是连玩两局游戏,每局最多让服务器猜5次,然后把游戏结果收集到results里面。请注意,这段代码专门使用了Python语言之中的许多特性,例如for循环、with语句、生成器、列表推导等。这样写,是想让我们稍后能够清楚地看到,把这种代码迁移到协程实现方案上面的工作量到底大不大。 

def run_client(address):
    with socket.create_connection(address) as connection:
        client = Client(connection)

        with client.session(1, 5, 3):
            results = [(x, client.report_outcome(x))
                       for x in client.request_numbers(5)]

        with client.session(10, 15, 12):
            for number in client.request_numbers(5):
                outcome = client.report_outcome(number)
                results.append((number, outcome))

    return results

全都准备好之后,我们把代码拼接起来.

def main():
    address = ('127.0.0.1', 1234)
    server_thread = Thread(
        target=run_server, args=(address,), daemon=True)
    server_thread.start()

    results = run_client(address)
    for number, outcome in results:
        print(f'Client: {number} is {outcome}')

main()

完整的代码如下:

# Example 1
class EOFError(Exception):
    pass

class ConnectionBase:
    def __init__(self, connection):
        self.connection = connection
        self.file = connection.makefile('rb')

    def send(self, command):
        line = command + '\n'
        data = line.encode()
        self.connection.send(data)

    def receive(self):
        line = self.file.readline()
        if not line:
            raise EOFError('Connection closed')
        return line[:-1].decode()


# Example 2
import random

WARMER = 'Warmer'
COLDER = 'Colder'
UNSURE = 'Unsure'
CORRECT = 'Correct'

class UnknownCommandError(Exception):
    print('unknown command')

class Session(ConnectionBase):
    def __init__(self, *args):
        super().__init__(*args)
        self._clear_state(None, None)

    def _clear_state(self, lower, upper):
        self.lower = lower
        self.upper = upper
        self.secret = None
        self.guesses = []


# Example 3
    def loop(self):
        while command := self.receive():
            parts = command.split(' ')
            if parts[0] == 'PARAMS':
                self.set_params(parts)
            elif parts[0] == 'NUMBER':
                self.send_number()
            elif parts[0] == 'REPORT':
                self.receive_report(parts)
            else:
                raise UnknownCommandError(command)


# Example 4
    def set_params(self, parts):
        assert len(parts) == 3
        lower = int(parts[1])
        upper = int(parts[2])
        self._clear_state(lower, upper)


# Example 5
    def next_guess(self):
        if self.secret is not None:
            return self.secret

        while True:
            guess = random.randint(self.lower, self.upper)
            if guess not in self.guesses:
                return guess

    def send_number(self):
        guess = self.next_guess()
        self.guesses.append(guess)
        self.send(format(guess))


# Example 6
    def receive_report(self, parts):
        assert len(parts) == 2
        decision = parts[1]

        last = self.guesses[-1]
        if decision == CORRECT:
            self.secret = last

        print(f'Server: {last} is {decision}')


# Example 7
import contextlib
import math

class Client(ConnectionBase):
    def __init__(self, *args):
        super().__init__(*args)
        self._clear_state()

    def _clear_state(self):
        self.secret = None
        self.last_distance = None


# Example 8
    @contextlib.contextmanager
    def session(self, lower, upper, secret):
        print(f'Guess a number between {lower} and {upper}!'
              f' Shhhhh, it\'s {secret}.')
        self.secret = secret
        self.send(f'PARAMS {lower} {upper}')
        try:
            yield
        finally:
            self._clear_state()
            self.send('PARAMS 0 -1')


# Example 9
    def request_numbers(self, count):
        for _ in range(count):
            self.send('NUMBER')
            data = self.receive()
            yield int(data)
            if self.last_distance == 0:
                return


# Example 10
    def report_outcome(self, number):
        new_distance = math.fabs(number - self.secret)
        decision = UNSURE

        if new_distance == 0:
            decision = CORRECT
        elif self.last_distance is None:
            pass
        elif new_distance < self.last_distance:
            decision = WARMER
        elif new_distance > self.last_distance:
            decision = COLDER

        self.last_distance = new_distance

        self.send(f'REPORT {decision}')
        return decision


# Example 11
import socket
from threading import Thread

def handle_connection(connection):
    with connection:
        session = Session(connection)
        try:
            session.loop()
        except EOFError:
            pass

def run_server(address):
    with socket.socket() as listener:
        # Allow the port to be reused
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind(address)
        listener.listen()
        while True:
            connection, _ = listener.accept()
            thread = Thread(target=handle_connection,
                            args=(connection,),
                            daemon=True)
            thread.start()


# Example 12
def run_client(address):
    with socket.create_connection(address) as connection:
        client = Client(connection)

        with client.session(1, 5, 3):
            results = [(x, client.report_outcome(x))
                       for x in client.request_numbers(5)]

        with client.session(10, 15, 12):
            for number in client.request_numbers(5):
                outcome = client.report_outcome(number)
                results.append((number, outcome))

    return results


# Example 13
def main():
    address = ('127.0.0.1', 1234)
    server_thread = Thread(
        target=run_server, args=(address,), daemon=True)
    server_thread.start()

    results = run_client(address)
    for number, outcome in results:
        print(f'Client: {number} is {outcome}')

main()

output:

Guess a number between 1 and 5! Shhhhh, it's 3.
Server: 4 is Unsure
Server: 5 is Colder
Server: 2 is Warmer
Server: 1 is Colder
Guess a number between 10 and 15! Shhhhh, it's 12.
Server: 3 is Correct
Client: 4 is Unsure
Client: 5 is Colder
Client: 2 is Warmer
Client: 1 is Colder
Client: 3 is Correct
Client: 12 is Correct

如果用内置的asyncio模块搭配async与await关键字来实现,那么需要修改的地方,究竟有多少呢?

首先,服务器逻辑与客户端逻辑共用的那个ConnectionBase基类必须修改,这次它不能通过send与receive方法直接执行阻塞式的I/O了,而是必须把这两个方法变为协程,也就是在声明的时候加上async关键字。

class AsyncConnectionBase:
    def __init__(self, reader, writer):             # Changed
        self.reader = reader                        # Changed
        self.writer = writer                        # Changed

    async def send(self, command):
        line = command + '\n'
        data = line.encode()
        self.writer.write(data)                     # Changed
        await self.writer.drain()                   # Changed

    async def receive(self):
        line = await self.reader.readline()         # Changed
        if not line:
            raise EOFError('Connection closed')
        return line[:-1].decode()

改完之后,我们可以创建这样一个有状态的子类,用来在服务器这边维护某条连接的会话状态。这个类跟早前表示服务器逻辑的那个Session类一样,也从刚才那个基类继承,只不过基类的名字现在已经变成AsyncConnectionBase,而不是ConnectionBase。

然后,我们来修改服务器逻辑里面的主要入口点,也就是处理命令所用的loop方法。其实只需要稍微改几个地方,就能把它变为协程。

处理第一种命令(也就是PARAMS命令)的那个方法不需要改动。

处理第二种命令(也就是NUMBER命令)的send_number方法,需要加上async关键字,这样它才能变为协程。在实现代码里面,只有一个地方要改,也就是必须用异步I/O向客户端发送所猜的数值。

处理第三种命令(也就是REPORT命令)的那个方法保持不变。

class AsyncSession(AsyncConnectionBase):            # Changed
    def __init__(self, *args):
        super().__init__(*args)
        self._clear_values(None, None)

    def _clear_values(self, lower, upper):
        self.lower = lower
        self.upper = upper
        self.secret = None
        self.guesses = []

    async def loop(self):                           # Changed
        while command := await self.receive():      # Changed
            parts = command.split(' ')
            if parts[0] == 'PARAMS':
                self.set_params(parts)
            elif parts[0] == 'NUMBER':
                await self.send_number()            # Changed
            elif parts[0] == 'REPORT':
                self.receive_report(parts)
            else:
                raise UnknownCommandError(command)
            
    def set_params(self, parts):
        assert len(parts) == 3
        lower = int(parts[1])
        upper = int(parts[2])
        self._clear_values(lower, upper)
        
    def next_guess(self):
        if self.secret is not None:
            return self.secret

        while True:
            guess = random.randint(self.lower, self.upper)
            if guess not in self.guesses:
                return guess

    async def send_number(self):                    # Changed
        guess = self.next_guess()
        self.guesses.append(guess)
        await self.send(format(guess))              # Changed
        
    def receive_report(self, parts):
        assert len(parts) == 2
        decision = parts[1]

        last = self.guesses[-1]
        if decision == CORRECT:
            self.secret = last

        print(f'Server: {last} is {decision}')

跟服务器端的逻辑类相似,客户端的逻辑类也需要继承AsyncConnectionBase。

客户端中负责向服务器发送PARAMS命令的那个方法,现在必须声明成async方法,实现代码里面有几个地方需要加上await关键字。此外,要改用contextlib这个内置模块之中的另一个辅助函数(也就是asynccontextmanager)来修饰该方法,而不能像原来那样,用contextmanager修饰。

负责向服务器发送NUMBER命令的那个方法必须加async关键字,这样才能变为协程。另外就是必须在执行send与receive操作的那两个地方分别加上await关键字。

负责向服务器发送REPORT命令的那个方法要加上async关键字,它里面的send操作要用await来执行。

class AsyncClient(AsyncConnectionBase):             # Changed
    def __init__(self, *args):
        super().__init__(*args)
        self._clear_state()

    def _clear_state(self):
        self.secret = None
        self.last_distance = None


# Example 21
    @contextlib.asynccontextmanager                 # Changed
    async def session(self, lower, upper, secret):  # Changed
        print(f'Guess a number between {lower} and {upper}!'
              f' Shhhhh, it\'s {secret}.')
        self.secret = secret
        await self.send(f'PARAMS {lower} {upper}')  # Changed
        try:
            yield
        finally:
            self._clear_state()
            await self.send('PARAMS 0 -1')          # Changed


# Example 22
    async def request_numbers(self, count):         # Changed
        for _ in range(count):
            await self.send('NUMBER')               # Changed
            data = await self.receive()             # Changed
            yield int(data)
            if self.last_distance == 0:
                return


# Example 23
    async def report_outcome(self, number):         # Changed
        new_distance = math.fabs(number - self.secret)
        decision = UNSURE

        if new_distance == 0:
            decision = CORRECT
        elif self.last_distance is None:
            pass
        elif new_distance < self.last_distance:
            decision = WARMER
        elif new_distance > self.last_distance:
            decision = COLDER

        self.last_distance = new_distance

        await self.send(f'REPORT {decision}')       # Changed
        # Make it so the output printing is in
        # the same order as the threaded version.
        await asyncio.sleep(0.01)
        return decision

用来运行服务器的那个run_server方法,现在必须重新实现。这次通过内置的asyncio模块里面的start_server函数启动服务器。

import asyncio

async def handle_async_connection(reader, writer):
    session = AsyncSession(reader, writer)
    try:
        await session.loop()
    except EOFError:
        pass

async def run_async_server(address):
    server = await asyncio.start_server(
        handle_async_connection, *address)
    async with server:
        await server.serve_forever()

用来运行客户端并启动游戏的run_client函数,几乎每行都要改,因为它现在不能再通过阻塞式的I/O去跟socket实例交互了,而是必须改用asyncio里面提供的类似功能来实现。另外,凡是与协程交互的那些代码行都必须适当地添加async或await关键字。如果某个地方忘了写,那么程序在运行时就会出现异常.

async def run_async_client(address):
    # Wait for the server to listen before trying to connect
    await asyncio.sleep(0.1)

    streams = await asyncio.open_connection(*address)   # New
    client = AsyncClient(*streams)                      # New

    async with client.session(1, 5, 3):
        results = [(x, await client.report_outcome(x))
                   async for x in client.request_numbers(5)]

    async with client.session(10, 15, 12):
        async for number in client.request_numbers(5):
            outcome = await client.report_outcome(number)
            results.append((number, outcome))

    _, writer = streams                                 # New
    writer.close()                                      # New
    await writer.wait_closed()                          # New

    return results


把run_client改写为run_async_client的过程中,最妙的地方在于,原函数操作客户端的这套流程基本上不用调整,只要在适当的位置写上await或async关键字,就能够使用这个新的AsyncClient客户端,并调用其中的相关协程了。笔者原来说过,这个函数故意运用了Python之中的许多特性,在这里我们看到,这些特性都有对应的异步版本,所以很容易就能实现迁移。

当然,并不是所有代码都能这么容易地迁移到协程方案上面。例如,目前还没有异步版本的next与iter内置函数,所以我们必须直接在__anext__与__aiter__方法上面做await。另外,yield from也没有异步版本,所以要想把生成器组合起来,必须多写一些代码。

最后,负责把整个程序拼合起来的那个main函数也需要改成异步版本,这样我们才能从头到尾看到完整的游戏效果。笔者在这里通过asyncio.create_task函数把运行服务器的那项操作(也就是run_async_server(address))安排到事件循环里面,这样的话,等函数推进到await语句时,系统就可以让该操作与另一项操作(也就是运行客户端的那项run_async_client(address)操作)平行地执行了。这当然也是一种实现fan-out模式的方法,但它跟我们在之前的康威生命游戏里所讲的那种办法有个区别,那种办法分派的是同一种任务(也就是更新单元格的状态)并且要通过asyncio.gather来收集运行结果,而这里要分派的,则是两种不同的任务(一种是运行服务器,另一种是运行客户端)。

async def main_async():
    address = ('127.0.0.1', 4321)

    server = run_async_server(address)
    asyncio.create_task(server)

    results = await run_async_client(address)
    for number, outcome in results:
        print(f'Client: {number} is {outcome}')

asyncio.run(main_async())

这样写,能够实现出正确的运行效果,而且协程版本的代码要比原来更容易理解,因为我们不用再跟线程交互了,那些操作全都可以删掉。内置的asyncio模块提供了许多辅助函数,让我们能够用比较少的代码实现跟早前一样的服务器逻辑,而不用再像原来那样,必须编写许多例行代码来操纵socket。

完整的代码如下:

import contextlib
import random
import math

WARMER = 'Warmer'
COLDER = 'Colder'
UNSURE = 'Unsure'
CORRECT = 'Correct'
class AsyncConnectionBase:
    def __init__(self, reader, writer):             # Changed
        self.reader = reader                        # Changed
        self.writer = writer                        # Changed

    async def send(self, command):
        line = command + '\n'
        data = line.encode()
        self.writer.write(data)                     # Changed
        await self.writer.drain()                   # Changed

    async def receive(self):
        line = await self.reader.readline()         # Changed
        if not line:
            raise EOFError('Connection closed')
        return line[:-1].decode()


# Example 15
class AsyncSession(AsyncConnectionBase):            # Changed
    def __init__(self, *args):
        super().__init__(*args)
        self._clear_values(None, None)

    def _clear_values(self, lower, upper):
        self.lower = lower
        self.upper = upper
        self.secret = None
        self.guesses = []


# Example 16
    async def loop(self):                           # Changed
        while command := await self.receive():      # Changed
            parts = command.split(' ')
            if parts[0] == 'PARAMS':
                self.set_params(parts)
            elif parts[0] == 'NUMBER':
                await self.send_number()            # Changed
            elif parts[0] == 'REPORT':
                self.receive_report(parts)
            else:
                raise UnknownCommandError(command)


# Example 17
    def set_params(self, parts):
        assert len(parts) == 3
        lower = int(parts[1])
        upper = int(parts[2])
        self._clear_values(lower, upper)


# Example 18
    def next_guess(self):
        if self.secret is not None:
            return self.secret

        while True:
            guess = random.randint(self.lower, self.upper)
            if guess not in self.guesses:
                return guess

    async def send_number(self):                    # Changed
        guess = self.next_guess()
        self.guesses.append(guess)
        await self.send(format(guess))              # Changed


# Example 19
    def receive_report(self, parts):
        assert len(parts) == 2
        decision = parts[1]

        last = self.guesses[-1]
        if decision == CORRECT:
            self.secret = last

        print(f'Server: {last} is {decision}')


# Example 20
class AsyncClient(AsyncConnectionBase):             # Changed
    def __init__(self, *args):
        super().__init__(*args)
        self._clear_state()

    def _clear_state(self):
        self.secret = None
        self.last_distance = None


# Example 21
    @contextlib.asynccontextmanager                 # Changed
    async def session(self, lower, upper, secret):  # Changed
        print(f'Guess a number between {lower} and {upper}!'
              f' Shhhhh, it\'s {secret}.')
        self.secret = secret
        await self.send(f'PARAMS {lower} {upper}')  # Changed
        try:
            yield
        finally:
            self._clear_state()
            await self.send('PARAMS 0 -1')          # Changed


# Example 22
    async def request_numbers(self, count):         # Changed
        for _ in range(count):
            await self.send('NUMBER')               # Changed
            data = await self.receive()             # Changed
            yield int(data)
            if self.last_distance == 0:
                return


# Example 23
    async def report_outcome(self, number):         # Changed
        new_distance = math.fabs(number - self.secret)
        decision = UNSURE

        if new_distance == 0:
            decision = CORRECT
        elif self.last_distance is None:
            pass
        elif new_distance < self.last_distance:
            decision = WARMER
        elif new_distance > self.last_distance:
            decision = COLDER

        self.last_distance = new_distance

        await self.send(f'REPORT {decision}')       # Changed
        # Make it so the output printing is in
        # the same order as the threaded version.
        await asyncio.sleep(0.01)
        return decision


# Example 24
import asyncio

async def handle_async_connection(reader, writer):
    session = AsyncSession(reader, writer)
    try:
        await session.loop()
    except EOFError:
        pass

async def run_async_server(address):
    server = await asyncio.start_server(
        handle_async_connection, *address)
    async with server:
        await server.serve_forever()


# Example 25
async def run_async_client(address):
    # Wait for the server to listen before trying to connect
    await asyncio.sleep(0.1)

    streams = await asyncio.open_connection(*address)   # New
    client = AsyncClient(*streams)                      # New

    async with client.session(1, 5, 3):
        results = [(x, await client.report_outcome(x))
                   async for x in client.request_numbers(5)]

    async with client.session(10, 15, 12):
        async for number in client.request_numbers(5):
            outcome = await client.report_outcome(number)
            results.append((number, outcome))

    _, writer = streams                                 # New
    writer.close()                                      # New
    await writer.wait_closed()                          # New

    return results


# Example 26
async def main_async():
    address = ('127.0.0.1', 4321)

    server = run_async_server(address)
    asyncio.create_task(server)

    results = await run_async_client(address)
    for number, outcome in results:
        print(f'Client: {number} is {outcome}')



asyncio.run(main_async())

Output:

Guess a number between 1 and 5! Shhhhh, it's 3.
Server: 2 is Unsure
Server: 5 is Colder
Server: 3 is Correct
Guess a number between 10 and 15! Shhhhh, it's 12.
Server: 12 is Correct
Client: 2 is Unsure
Client: 5 is Colder
Client: 3 is Correct
Client: 12 is Correct