限流

/ JDK / 没有评论 / 60浏览

在单机中,我们常用的限流方式有Semaphore(Java包中用来限制并发量)、RateLimit(Guava中令牌桶实现,用来控制并发速率),但是在分布式系统中,就没啥用了,下面介绍基于Redis实现。

Semaphore

原理

例如,我们设置的并发量为100,我们利用string数据结构将凭据设置为100,每个线程获取时减一,说明拿到了。如果不大于0,说明获取不到了。释放凭据时,再加一。

我们很容易写出这样的代码:

public boolean tryAcquire() {
    Jedis jedis = null;
    try {
        jedis = pool.getResource();
        // 获取当前剩余的凭据数
        Long current = Long.valueOf(jedis.get(key));
        if (current > 0) {
            // 凭据数大于0,则获取成功,减一
            jedis.incr(key);
            return true;
        }
        return false;
    } catch (JedisException e) {
        LOG.error("tryAcquire error", e);
        return false;
    } finally {
        returnResource(jedis);
    }
}

感觉上这么简单,然而get(key)incr(key)不是原子性,因此不行滴

我们可以用lua脚本将getincr命令封装下,由于redis单线程,因此可以解决这个问题。

提示eval命令。

代码

/**
 * 基于redis lua实现Semaphore
 *
 * @author xuan
 * @since 1.0.0
 */
public class RedisSemaphore implements Semaphore {

    private static final Logger LOG = LoggerFactory.getLogger(Semaphore.class);

    /**
     * redis默认存储的key
     */
    private static final String DEFAULT_KEY = "rateLimit:semaphore";

    /**
     * lua执行脚本,如果大于0,则减一,返回1,代表获取成功
     */
    private static final String SCRIPT_LIMIT =
            "local key = KEYS[1] " +
            "local current = tonumber(redis.call('get', key)) " +
            "local res = 0 " +
            "if current > 0 then " +
            "   redis.call('decr', key) " +
            "   res = 1 " +
            "end " +
            "return res ";

    /**
     * Redis连接池
     */
    private final Pool<Jedis> pool;

    /**
     * redis存储的key
     */
    private final String key;

    /**
     * 凭据限制的数目
     */
    private final Long limits;

    public RedisSemaphore(Pool<Jedis> pool, Long limits) {
        this(pool, DEFAULT_KEY, limits);
    }

    public RedisSemaphore(Pool<Jedis> pool, String key, Long limits) {
        this.pool = pool;
        this.key = key;
        this.limits = limits;
        setup();
    }

    /**
     * 尝试获取凭据,获取不到凭据不等待,直接返回
     *
     * @return 获取到凭据返回true,否则返回false
     */
    @Override
    public boolean tryAcquire() {
        Jedis jedis = null;
        try {
            jedis = pool.getResource();
            Long res = (Long) jedis.eval(SCRIPT_LIMIT, Collections.singletonList(key), Collections.<String>emptyList());
            return res > 0;
        } catch (JedisException e) {
            LOG.error("tryAcquire error", e);
            return false;
        } finally {
            returnResource(jedis);
        }
    }

    /**
     * 释放获取到的凭据
     */
    @Override
    public void release() {
        Jedis jedis = null;
        try {
            jedis = pool.getResource();
            jedis.incr(key);
        } catch (JedisException e) {
            LOG.error("release error", e);
        } finally {
            returnResource(jedis);
        }
    }

    private void setup() {
        Jedis jedis = null;
        try {
            jedis = pool.getResource();
            jedis.del(key);
            jedis.incrBy(key, limits);
        } finally {
            returnResource(jedis);
        }
    }

    private void returnResource(Jedis jedis) {
        if (jedis != null) {
            jedis.close();
        }
    }

}

emm,代码比较简单,没有实现获取不到锁,然后等待那种操作。这里直接是快速失败了,让业务逻辑去处理。

RateLimit

Semaphore控制了并发量,但是没有控制速率,就是那种每秒最多给你多少张凭据,获取完了等待下一秒。。

原理

令牌桶算法

emm,实现起来还是挺复杂的。发现有位道友写的挺棒的,就按照他的代码思路实现了。

代码

/**
 * 基于redis lua实现令牌桶算法
 * 依赖于本地时间,分布式需要服务器时钟同步
 *
 * @author xuan
 * @since 1.0.0
 */
public class RedisBucket implements Bucket {

    private static final Logger LOG = LoggerFactory.getLogger(RedisBucket.class);

    /**
     * redis默认存储的key
     */
    private static final String DEFAULT_KEY = "rateLimit:bucket:";

    /**
     * lua执行脚本,如果超过限制则返回0,否则返回1
     */
    private static final String SCRIPT_LIMIT =
            "local key = KEYS[1] " +
            "local limit = tonumber(ARGV[1]) " +
            "local current = tonumber(redis.call('get', key) or '0') " +
            "local res " +
            // 如果超出限流大小
            "if current + 1 > limit then " +
            "   res = 0 " +
            // 请求数+1,并设置2秒过期
            "else " +
            "   redis.call('incrBy', key, 1) " +
            "   redis.call('expire', key, 2) " +
            "   res = 1 " +
            "end " +
            "return res ";

    /**
     * Redis连接池
     */
    private final Pool<Jedis> pool;

    /**
     * redis存储的key
     */
    private final String key;

    /**
     * 每秒凭据限制的数目
     */
    private final String permitsPerSecond;

    public RedisBucket(Pool<Jedis> pool, Integer permitsPerSecond) {
        this(pool, DEFAULT_KEY, permitsPerSecond);
    }

    public RedisBucket(Pool<Jedis> pool, String key, Integer permitsPerSecond) {
        this.pool = pool;
        this.key = key;
        this.permitsPerSecond = String.valueOf(permitsPerSecond);
    }

    @Override
    public boolean tryAcquire() {
        Jedis jedis = null;
        try {
            jedis = pool.getResource();
            // 将当前时间戳取秒数
            String key = this.key + System.currentTimeMillis() / 1000;
            Long res = (Long) jedis.eval(SCRIPT_LIMIT, Collections.singletonList(key), Collections.singletonList(permitsPerSecond));
            return res == 1;
        } catch (JedisException e) {
            LOG.error("tryAcquire error" + e.getMessage(), e);
            return false;
        } finally {
            returnResource(jedis);
        }
    }

    private void returnResource(Jedis jedis) {
        if (jedis != null) {
            jedis.close();
        }
    }
}

由于依赖于本地时间,因此分布式下需要服务器时钟同步,否则。。。

Test

下面是相关测试。

提示:代码中用了CountDownLatch是为了让子线程执行完成后,主线程再退出。

SemaphoreTest

/**
 * 测试Semaphore
 *
 * @author xuan
 * @since 1.0.0
 */
public class SemaphoreTest {

    private static final Logger LOG = LoggerFactory.getLogger(SemaphoreTest.class);

    private CountDownLatch latch = new CountDownLatch(200);

    private JedisPool pool;
    private Semaphore semaphore;

    @Before
    public void setup() {
        pool = new JedisPool();
        semaphore = new RedisSemaphore(pool, 100L);
    }

    /**
     * 测试获取凭据
     */
    @Test
    public void tryAcquire() {
        for (int i = 0; i < 200; i++) {
            new Thread(() -> {
                // 尝试获取凭据
                if (semaphore.tryAcquire()) {
                    LOG.info("{} - acquired", Thread.currentThread().getName());
                    // 获取到凭据之后,做一些什么
                    doSomething();
                    // 释放凭据
                    semaphore.release();
                } else {
                    LOG.warn("{} - not acquired", Thread.currentThread().getName());
                }
                latch.countDown();
            }).start();
        }
        try {
            latch.await();
            LOG.info("main await end");
        } catch (InterruptedException e) {
            LOG.error("main await error", e);
        }
    }

    private void doSomething() {
        try {
            Thread.sleep(50L);
        } catch (InterruptedException e) {
            LOG.error("InterruptedException", e);
        }
    }

    @After
    public void cleanup() {
        semaphore = null;
        pool.destroy();
    }

}

RateLimitTest

/**
 * 测试令牌桶
 *
 * @author xuan
 * @since 1.0.0
 */
public class RedisBucketTest {

    private static final Logger LOG = LoggerFactory.getLogger(SemaphoreTest.class);

    private CountDownLatch latch = new CountDownLatch(200);

    private JedisPool pool;
    private Bucket bucket;

    @Before
    public void setup() {
        pool = new JedisPool();
        // 每秒10个凭据
        bucket = new RedisBucket(pool, 10);
    }

    /**
     * 测试获取凭据
     */
    @Test
    public void tryAcquire() {
        for (int i = 0; i < 200; i++) {
            new Thread(() -> {
                // 尝试获取凭据
                if (bucket.tryAcquire()) {
                    LOG.info("{} - acquired", Thread.currentThread().getName());
                    // 获取到凭据之后,做一些什么
                    doSomething();
                } else {
                    LOG.warn("{} - not acquired", Thread.currentThread().getName());
                }
                latch.countDown();
            }).start();
        }
        try {
            latch.await();
            LOG.info("main await end");
        } catch (InterruptedException e) {
            LOG.error("main await error", e);
        }
    }

    private void doSomething() {
        try {
            Thread.sleep(500L);
        } catch (InterruptedException e) {
            LOG.error("InterruptedException", e);
        }
    }

    @After
    public void cleanup() {
        bucket = null;
        pool.destroy();
    }

}