Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions src/limiter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,28 +53,30 @@ export class Limiter implements LimiterStoreContract {
*
* @param key - Unique identifier for the rate limit
*/
consume(key: string | number): Promise<LimiterResponse> {
return this.#store.consume(key)
consume(key: string | number, amount?: number): Promise<LimiterResponse> {
return this.#store.consume(key, amount)
}

/**
* Increments the consumed request count for the given key.
* Unlike consume(), this method does not throw when the limit is reached.
*
* @param key - Unique identifier for the rate limit
* @param amount - Number of requests to increment (default: 1)
*/
increment(key: string | number): Promise<LimiterResponse> {
return this.#store.increment(key)
increment(key: string | number, amount?: number): Promise<LimiterResponse> {
return this.#store.increment(key, amount)
}

/**
* Decrements the consumed request count for the given key.
* Will not decrement below zero.
*
* @param key - Unique identifier for the rate limit
* @param amount - Number of requests to decrement (default: 1)
*/
decrement(key: string | number): Promise<LimiterResponse> {
return this.#store.decrement(key)
decrement(key: string | number, amount?: number): Promise<LimiterResponse> {
return this.#store.decrement(key, amount)
}

/**
Expand All @@ -95,7 +97,11 @@ export class Limiter implements LimiterStoreContract {
* }
* ```
*/
async attempt<T>(key: string | number, callback: () => T | Promise<T>): Promise<T | undefined> {
async attempt<T>(
key: string | number,
callback: () => T | Promise<T>,
amount?: number
): Promise<T | undefined> {
/**
* Return early when remaining requests are less than
* zero.
Expand All @@ -110,7 +116,7 @@ export class Limiter implements LimiterStoreContract {
}

try {
await this.consume(key)
await this.consume(key, amount)
return callback()
} catch (error) {
if (error instanceof E_TOO_MANY_REQUESTS === false) {
Expand Down Expand Up @@ -144,7 +150,8 @@ export class Limiter implements LimiterStoreContract {
*/
async penalize<T>(
key: string | number,
callback: () => T | Promise<T>
callback: () => T | Promise<T>,
amount?: number
): Promise<[null, T] | [ThrottleException, null]> {
const response = await this.get(key)

Expand All @@ -169,7 +176,7 @@ export class Limiter implements LimiterStoreContract {
* an error.
*/
if (callbackError) {
const { consumed, limit } = await this.increment(key)
const { consumed, limit } = await this.increment(key, amount)
if (consumed >= limit && this.blockDuration) {
await this.block(key, this.blockDuration)
}
Expand Down
30 changes: 23 additions & 7 deletions src/stores/bridge.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,12 @@ export default abstract class RateLimiterBridge implements LimiterStoreContract
* console.log(`Remaining: ${response.remaining}`)
* ```
*/
async consume(key: string | number): Promise<LimiterResponse> {
async consume(key: string | number, amount?: number): Promise<LimiterResponse> {
const consumeAmount = amount !== undefined && amount > 0 ? amount : 1

try {
const response = await this.rateLimiter.consume(key, 1)
debug('request consumed for key %s', key)
const response = await this.rateLimiter.consume(key, consumeAmount)
debug('request consumed for key %s with amount %d', key, consumeAmount)
return this.makeLimiterResponse(response)
} catch (errorResponse: unknown) {
debug('unable to consume request for key %s, %O', key, errorResponse)
Expand All @@ -117,8 +119,13 @@ export default abstract class RateLimiterBridge implements LimiterStoreContract
* const response = await limiter.increment('user:123')
* ```
*/
async increment(key: string | number): Promise<LimiterResponse> {
const response = await this.rateLimiter.penalty(key, 1)
async increment(key: string | number, amount: number = 1): Promise<LimiterResponse> {
if (amount <= 0) {
debug('invalid increment amount "%d" provided. Falling back to 1', amount)
amount = 1
}

const response = await this.rateLimiter.penalty(key, amount)
debug('increased requests count for key %s', key)

return this.makeLimiterResponse(response)
Expand All @@ -135,7 +142,7 @@ export default abstract class RateLimiterBridge implements LimiterStoreContract
* const response = await limiter.decrement('user:123')
* ```
*/
async decrement(key: string | number): Promise<LimiterResponse> {
async decrement(key: string | number, amount: number = 1): Promise<LimiterResponse> {
const existingKey = await this.rateLimiter.get(key)

/**
Expand All @@ -145,17 +152,26 @@ export default abstract class RateLimiterBridge implements LimiterStoreContract
return this.set(key, 0, this.duration)
}

if (amount <= 0) {
debug('invalid decrement amount "%d" provided. Falling back to 1', amount)
amount = 1
}

/**
* Do not decrement beyond zero
*/
if (existingKey.consumedPoints <= 0) {
return this.makeLimiterResponse(existingKey)
}

if (amount > existingKey.consumedPoints) {
amount = existingKey.consumedPoints
}

/**
* Decrement
*/
const response = await this.rateLimiter.reward(key, 1)
const response = await this.rateLimiter.reward(key, amount)
debug('decreased requests count for key %s', key)

return this.makeLimiterResponse(response)
Expand Down
6 changes: 3 additions & 3 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -186,18 +186,18 @@ export interface LimiterStoreContract {
* when all the requests have already been consumed or if
* the key is blocked.
*/
consume(key: string | number): Promise<LimiterResponse>
consume(key: string | number, amount?: number): Promise<LimiterResponse>

/**
* Increment the number of consumed requests for a given key.
* No errors are thrown when limit has reached
*/
increment(key: string | number): Promise<LimiterResponse>
increment(key: string | number, amount?: number): Promise<LimiterResponse>

/**
* Decrement the number of consumed requests for a given key.
*/
decrement(key: string | number): Promise<LimiterResponse>
decrement(key: string | number, amount?: number): Promise<LimiterResponse>

/**
* Block a given key for the given duration. The duration must be
Expand Down
148 changes: 145 additions & 3 deletions tests/limiter.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,27 @@ test.group('Limiter', () => {
*/
const consumeCall = sinon.spy(store, 'consume')
await limiter.consume('ip_localhost')
assert.isTrue(consumeCall.calledOnceWithExactly('ip_localhost'), 'consume called')
assert.isTrue(consumeCall.calledOnceWithExactly('ip_localhost', undefined), 'consume called')

/**
* increment call
*/
const incrementCall = sinon.spy(store, 'increment')
await limiter.increment('ip_localhost')
assert.isTrue(incrementCall.calledOnceWithExactly('ip_localhost'), 'increment called')
assert.isTrue(
incrementCall.calledOnceWithExactly('ip_localhost', undefined),
'increment called'
)

/**
* decrement call
*/
const decrementCall = sinon.spy(store, 'decrement')
await limiter.decrement('ip_localhost')
assert.isTrue(decrementCall.calledOnceWithExactly('ip_localhost'), 'decrement called')
assert.isTrue(
decrementCall.calledOnceWithExactly('ip_localhost', undefined),
'decrement called'
)

/**
* get call
Expand Down Expand Up @@ -104,6 +110,142 @@ test.group('Limiter', () => {
await assert.doesNotReject(() => limiter.increment('ip_localhost'))
})

test('increment requests count with negative amount should default to 1', async ({ assert }) => {
const redis = createRedis(['rlflx:ip_localhost']).connection()
const store = new LimiterRedisStore(redis, {
duration: '1 minute',
requests: 5,
})

const limiter = new Limiter(store)

await limiter.increment('ip_localhost', -5)
assert.equal(await limiter.remaining('ip_localhost'), 4)
})

test('decrement requests count with negative amount should default to 1', async ({ assert }) => {
const redis = createRedis(['rlflx:ip_localhost']).connection()
const store = new LimiterRedisStore(redis, {
duration: '1 minute',
requests: 5,
})

const limiter = new Limiter(store)

await limiter.increment('ip_localhost', 5)
await limiter.decrement('ip_localhost', -3)
assert.equal(await limiter.remaining('ip_localhost'), 1)
})

test('increment requests count with zero amount should default to 1', async ({ assert }) => {
const redis = createRedis(['rlflx:ip_localhost']).connection()
const store = new LimiterRedisStore(redis, {
duration: '1 minute',
requests: 5,
})

const limiter = new Limiter(store)

await limiter.increment('ip_localhost', 0)
assert.equal(await limiter.remaining('ip_localhost'), 4)
})

test('decrement requests count with zero amount should default to 1', async ({ assert }) => {
const redis = createRedis(['rlflx:ip_localhost']).connection()
const store = new LimiterRedisStore(redis, {
duration: '1 minute',
requests: 5,
})

const limiter = new Limiter(store)

await limiter.increment('ip_localhost', 5)
await limiter.decrement('ip_localhost', 0)
assert.equal(await limiter.remaining('ip_localhost'), 1)
})

test('increment remaining requests by amount', async ({ assert }) => {
const redis = createRedis(['rlflx:ip_localhost']).connection()
const store = new LimiterRedisStore(redis, {
duration: '1 minute',
requests: 5,
})

const limiter = new Limiter(store)

await limiter.increment('ip_localhost', 3)
const response = await limiter.get('ip_localhost')
assert.containsSubset(response, {
consumed: 3,
remaining: 2,
limit: 5,
})
})

test('decrement consumed requests by amount', async ({ assert }) => {
const redis = createRedis(['rlflx:ip_localhost']).connection()
const store = new LimiterRedisStore(redis, {
duration: '1 minute',
requests: 5,
})

const limiter = new Limiter(store)

await limiter.increment('ip_localhost', 4)
await limiter.decrement('ip_localhost', 2)
const response = await limiter.get('ip_localhost')
assert.containsSubset(response, {
consumed: 2,
remaining: 3,
limit: 5,
})
})

test('consume remaining requests by amount', async ({ assert }) => {
const redis = createRedis(['rlflx:ip_localhost']).connection()
const store = new LimiterRedisStore(redis, {
duration: '1 minute',
requests: 5,
})

const limiter = new Limiter(store)

await limiter.consume('ip_localhost', 3)
const response = await limiter.get('ip_localhost')
assert.containsSubset(response, {
consumed: 3,
remaining: 2,
limit: 5,
})
})

test('increment requests count with a custom amount', async ({ assert }) => {
const redis = createRedis(['rlflx:ip_localhost']).connection()
const store = new LimiterRedisStore(redis, {
duration: '1 minute',
requests: 10,
})

const limiter = new Limiter(store)

await limiter.increment('ip_localhost', 3)
assert.equal(await limiter.remaining('ip_localhost'), 7)
})

test('decrement requests count with a custom amount', async ({ assert }) => {
const redis = createRedis(['rlflx:ip_localhost']).connection()
const store = new LimiterRedisStore(redis, {
duration: '1 minute',
requests: 10,
})

const limiter = new Limiter(store)

await limiter.increment('ip_localhost', 10)
await limiter.decrement('ip_localhost', 4)
assert.equal(await limiter.remaining('ip_localhost'), 4)
})

test('do not run action when all requests have been exhausted', async ({ assert }) => {
const executionStack: string[] = []
const redis = createRedis(['rlflx:ip_localhost']).connection()
Expand Down
Loading
Loading