Skip to content

Commit 6d31d9f

Browse files
committed
adding cron expression handler for scheduling scale up / down
1 parent 4ff84aa commit 6d31d9f

File tree

10 files changed

+142
-67
lines changed

10 files changed

+142
-67
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package com.redis.autoscaler;
2+
3+
import com.redis.autoscaler.documents.Rule;
4+
import com.redis.autoscaler.documents.RuleRepository;
5+
import com.redis.autoscaler.documents.TriggerType;
6+
import com.redis.autoscaler.services.RedisCloudDatabaseService;
7+
import com.redis.autoscaler.services.SchedulingService;
8+
import org.springframework.boot.CommandLineRunner;
9+
import org.springframework.stereotype.Component;
10+
11+
@Component
12+
public class ScheduleStartup implements CommandLineRunner {
13+
private final SchedulingService schedulingService;
14+
private final RuleRepository ruleRepository;
15+
private final RedisCloudDatabaseService redisCloudDatabaseService;
16+
17+
public ScheduleStartup(SchedulingService schedulingService, RuleRepository ruleRepository, RedisCloudDatabaseService redisCloudDatabaseService) {
18+
this.schedulingService = schedulingService;
19+
this.ruleRepository = ruleRepository;
20+
this.redisCloudDatabaseService = redisCloudDatabaseService;
21+
}
22+
23+
@Override
24+
public void run(String... args) throws Exception {
25+
Iterable<Rule> rules = ruleRepository.findByTriggerType(TriggerType.Scheduled);
26+
for (Rule rule : rules) {
27+
schedulingService.scheduleTask(rule.getRuleId(), rule.getTriggerValue(), () -> {
28+
try {
29+
redisCloudDatabaseService.applyRule(rule);
30+
} catch (Exception e) {
31+
throw new RuntimeException(e);
32+
}
33+
});
34+
}
35+
}
36+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package com.redis.autoscaler;
2+
3+
import org.springframework.context.annotation.Bean;
4+
import org.springframework.scheduling.TaskScheduler;
5+
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
6+
7+
public class SchedulerConfig {
8+
@Bean
9+
public TaskScheduler taskScheduler(){
10+
return new ThreadPoolTaskScheduler();
11+
}
12+
}

autoscaler/src/main/java/com/redis/autoscaler/controllers/AlertController.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,15 @@ public ResponseEntity<Map<String,Task>> inboundAlert(@RequestBody String jsonBod
8181
}
8282

8383
// 3. Find Rules associated alert type and DB ID
84-
Iterable<Rule> rules = ruleRepository.findByDbIdAndRuleType(dbId, ruleType);
84+
Iterable<Rule> rules = ruleRepository.findByDbIdAndRuleTypeAndTriggerType(dbId, ruleType, TriggerType.Webhook);
8585
if(!rules.iterator().hasNext()){
8686
LOG.info("No rule found for dbId: {} and alertName: {} JSON Body: {}", dbId, ruleType, jsonBody);
8787
continue; // move onto next alert
8888
}
8989

9090
// 4. If not, run scaling
9191
Rule rule = rules.iterator().next();
92-
Optional<Task> res = redisCloudDatabaseService.applyRule(rule, dbId);
92+
Optional<Task> res = redisCloudDatabaseService.applyRule(rule);
9393
if(res.isEmpty()){
9494
LOG.info("Failed to apply rule for dbId: {} and alertName: {}", dbId, ruleType);
9595
continue;

autoscaler/src/main/java/com/redis/autoscaler/controllers/RulesController.java

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,36 @@
33
import com.redis.autoscaler.documents.RuleType;
44
import com.redis.autoscaler.documents.RuleRepository;
55
import com.redis.autoscaler.documents.Rule;
6+
import com.redis.autoscaler.documents.TriggerType;
7+
import com.redis.autoscaler.services.RedisCloudDatabaseService;
8+
import com.redis.autoscaler.services.SchedulingService;
69
import org.slf4j.Logger;
710
import org.springframework.http.HttpEntity;
811
import org.springframework.http.HttpStatus;
912
import org.springframework.http.ResponseEntity;
13+
import org.springframework.scheduling.support.CronExpression;
1014
import org.springframework.web.bind.annotation.*;
1115

16+
import java.io.IOException;
17+
import java.util.Objects;
1218
import java.util.Optional;
1319

1420
@RestController
1521
@RequestMapping("/rules")
1622
public class RulesController {
1723
private static final Logger LOG = org.slf4j.LoggerFactory.getLogger(RulesController.class);
24+
private final RedisCloudDatabaseService redisCloudDatabaseService;
1825
private final RuleRepository ruleRepository;
26+
private final SchedulingService schedulingService;
1927

20-
public RulesController(RuleRepository ruleRepository) {
28+
public RulesController(RedisCloudDatabaseService redisCloudDatabaseService, RuleRepository ruleRepository, SchedulingService schedulingService) {
29+
this.redisCloudDatabaseService = redisCloudDatabaseService;
2130
this.ruleRepository = ruleRepository;
31+
this.schedulingService = schedulingService;
2232
}
2333

2434
@PostMapping
25-
public HttpEntity<Rule> createRule(@RequestBody Rule rule) {
35+
public HttpEntity<Object> createRule(@RequestBody Rule rule) {
2636
LOG.info("Received request to create rule: {}", rule);
2737

2838
if(rule.getRuleType() == RuleType.IncreaseMemory || rule.getRuleType() == RuleType.DecreaseMemory) {
@@ -32,11 +42,30 @@ public HttpEntity<Rule> createRule(@RequestBody Rule rule) {
3242
}
3343

3444
LOG.info("Attempting to create rule: {}", rule);
35-
if(ruleRepository.findByDbIdAndRuleType(rule.getDbId(), rule.getRuleType()).iterator().hasNext()) {
45+
if(ruleRepository.findByDbIdAndRuleTypeAndTriggerType(rule.getDbId(), rule.getRuleType(), rule.getTriggerType()).iterator().hasNext()) {
3646
return ResponseEntity.status(HttpStatus.CONFLICT).build();
3747
}
3848

39-
ruleRepository.save(rule);
49+
if(rule.getTriggerType() == TriggerType.Scheduled){
50+
try {
51+
CronExpression.parse(rule.getTriggerValue());
52+
ruleRepository.save(rule);
53+
schedulingService.scheduleTask(rule.getRuleId(), rule.getTriggerValue(), () -> {
54+
try {
55+
redisCloudDatabaseService.applyRule(rule);
56+
} catch (IOException | InterruptedException e) {
57+
throw new RuntimeException(e);
58+
}
59+
});
60+
} catch (Exception e) {
61+
LOG.error("Invalid cron expression: {} {}", rule.getTriggerValue(), e.toString());
62+
return ResponseEntity.status(HttpStatus.BAD_REQUEST).body("Invalid cron expression: " + rule.getTriggerValue());
63+
}
64+
}
65+
else{
66+
ruleRepository.save(rule);
67+
}
68+
4069
return ResponseEntity.of(Optional.of(rule));
4170
}
4271

@@ -58,6 +87,10 @@ public HttpEntity<Void> deleteRule(@PathVariable String ruleId) {
5887
return ResponseEntity.status(HttpStatus.NOT_FOUND).build();
5988
}
6089

90+
if(rule.getTriggerValue() != null && rule.getTriggerType() == TriggerType.Scheduled) {
91+
schedulingService.cancelTask(rule.getRuleId());
92+
}
93+
6194
ruleRepository.delete(rule);
6295
return ResponseEntity.status(HttpStatus.NO_CONTENT).build();
6396
}

autoscaler/src/main/java/com/redis/autoscaler/documents/Rule.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ public class Rule {
2121
@Indexed
2222
protected ScaleType scaleType;
2323

24+
@Indexed
25+
protected TriggerType triggerType;
26+
27+
protected String triggerValue;
28+
2429
protected double scaleValue;
2530
protected double scaleCeiling;
2631
protected double scaleFloor;

autoscaler/src/main/java/com/redis/autoscaler/documents/RuleRepository.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import com.redis.om.spring.repository.RedisDocumentRepository;
44

55
public interface RuleRepository extends RedisDocumentRepository<Rule, String> {
6-
Iterable<Rule> findByDbIdAndRuleType(String dbId, RuleType ruleType);
6+
Iterable<Rule> findByDbIdAndRuleTypeAndTriggerType(String dbId, RuleType ruleType, TriggerType triggerType);
77
Iterable<Rule> findByDbId(String dbId);
8+
Iterable<Rule> findByTriggerType(TriggerType triggerType);
89
}

autoscaler/src/main/java/com/redis/autoscaler/documents/RuleType.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
public enum RuleType {
66
IncreaseMemory("IncreaseMemory"),
77
DecreaseMemory("DecreaseMemory"),
8-
IncreaseShards("IncreaseShards"),
9-
DecreaseShards("DecreaseShards"),
108
IncreaseThroughput("IncreaseThroughput"),
119
DecreaseThroughput("DecreaseThroughput");
1210

autoscaler/src/main/java/com/redis/autoscaler/documents/TriggerType.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
public enum TriggerType {
66
Webhook("webhook"),
7-
Schedule("schedule");
7+
Scheduled("scheduled");
88

99
private final String value;
1010
TriggerType(String value){

autoscaler/src/main/java/com/redis/autoscaler/services/RedisCloudDatabaseService.java

Lines changed: 2 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ public RedisCloudDatabase getDatabase(String dbId) throws IOException, Interrupt
5353
return objectMapper.readValue(response.body(), RedisCloudDatabase.class);
5454
}
5555

56-
public Optional<Task> applyRule(Rule rule, String dbId) throws IOException, InterruptedException {
56+
public Optional<Task> applyRule(Rule rule) throws IOException, InterruptedException {
57+
String dbId = rule.getDbId();
5758
// Apply the rule to the database
5859
RedisCloudDatabase db = getDatabase(dbId);
5960

@@ -92,20 +93,6 @@ public Optional<Task> applyRule(Rule rule, String dbId) throws IOException, Inte
9293

9394
scaleRequest = ScaleRequest.builder().throughputMeasurement(new ThroughputMeasurement(ThroughputMeasurement.ThroughputMeasureBy.OperationsPerSecond, newThroughput)).build();
9495
}
95-
case IncreaseShards, DecreaseShards -> {
96-
if(db.getThroughputMeasurement().getBy() != ThroughputMeasurement.ThroughputMeasureBy.NumberOfShards && rule.getScaleType() != ScaleType.Deterministic){
97-
LOG.info("DB: {} ID: {} is not measured by number of shards, cannot apply shard rule: {}",db.getName(), dbId, rule.getRuleType());
98-
return Optional.empty();
99-
}
100-
101-
long newShardCount = getNewShardCount(rule, db);
102-
if(newShardCount == db.getThroughputMeasurement().getValue()){
103-
LOG.info("DB: {} ID: {} is already at the min/max shard count: {}",db.getName(), dbId, newShardCount);
104-
return Optional.empty();
105-
}
106-
107-
scaleRequest = ScaleRequest.builder().throughputMeasurement(new ThroughputMeasurement(ThroughputMeasurement.ThroughputMeasureBy.NumberOfShards, newShardCount)).build();
108-
}
10996
default -> {
11097
return Optional.empty();
11198
}
@@ -114,48 +101,6 @@ public Optional<Task> applyRule(Rule rule, String dbId) throws IOException, Inte
114101
return Optional.of(scaleDatabase(dbId, scaleRequest));
115102
}
116103

117-
private static long getNewShardCount(Rule rule, RedisCloudDatabase db){
118-
long newShards;
119-
long currentShards = db.getThroughputMeasurement().getValue();
120-
if(rule.getRuleType() == RuleType.IncreaseShards){
121-
switch (rule.getScaleType()){
122-
case Step -> {
123-
newShards = currentShards + (long)rule.getScaleValue();
124-
}
125-
case Exponential -> {
126-
newShards = (long)Math.ceil(currentShards * rule.getScaleValue());
127-
}
128-
case Deterministic -> {
129-
newShards = (long)rule.getScaleValue();
130-
}
131-
132-
default -> throw new IllegalStateException("Unexpected value: " + rule.getScaleType());
133-
}
134-
135-
newShards = Math.min(newShards, (long)rule.getScaleCeiling());
136-
} else if(rule.getRuleType() == RuleType.DecreaseShards){
137-
switch (rule.getScaleType()){
138-
case Step -> {
139-
newShards = currentShards - (long)rule.getScaleValue();
140-
}
141-
case Exponential -> {
142-
newShards = (long)Math.ceil(db.getThroughputMeasurement().getValue() * rule.getScaleValue());
143-
}
144-
case Deterministic -> {
145-
newShards = (long)rule.getScaleValue();
146-
}
147-
148-
default -> throw new IllegalStateException("Unexpected value: " + rule.getScaleType());
149-
}
150-
151-
newShards = (long)Math.max(newShards, rule.getScaleFloor());
152-
} else {
153-
throw new IllegalStateException("Unexpected value: " + rule.getRuleType());
154-
}
155-
156-
return newShards;
157-
}
158-
159104
private static long getNewThroughput(Rule rule, RedisCloudDatabase db){
160105
long newThroughput;
161106
long currentThroughput = db.getThroughputMeasurement().getValue();
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package com.redis.autoscaler.services;
2+
3+
4+
import org.slf4j.Logger;
5+
import org.springframework.scheduling.TaskScheduler;
6+
import org.springframework.scheduling.support.CronTrigger;
7+
import org.springframework.stereotype.Service;
8+
9+
import java.util.Map;
10+
import java.util.concurrent.ConcurrentHashMap;
11+
import java.util.concurrent.ScheduledFuture;
12+
13+
@Service
14+
public class SchedulingService{
15+
private static Logger LOG = org.slf4j.LoggerFactory.getLogger(SchedulingService.class);
16+
private final TaskScheduler taskScheduler;
17+
private final Map<String,ScheduledFuture<?>> scheduledTasks = new ConcurrentHashMap<>();
18+
19+
public SchedulingService(TaskScheduler taskScheduler) {
20+
this.taskScheduler = taskScheduler;
21+
}
22+
23+
public String scheduleTask(String id, String cronExpression, Runnable task){
24+
if(scheduledTasks.containsKey(id)){
25+
LOG.info("Task with id {} already exists", id);
26+
return "Task with id " + id + " already exists";
27+
}
28+
29+
ScheduledFuture<?> scheduledFuture = taskScheduler.schedule(task, new CronTrigger(cronExpression));
30+
scheduledTasks.put(id, scheduledFuture);
31+
return "Scheduled task with id " + id;
32+
}
33+
34+
public String cancelTask(String id){
35+
ScheduledFuture<?> scheduledFuture = scheduledTasks.get(id);
36+
if(scheduledFuture == null){
37+
LOG.info("Task with id {} does not exist", id);
38+
return "Task with id " + id + " does not exist";
39+
}
40+
41+
scheduledFuture.cancel(true);
42+
scheduledTasks.remove(id);
43+
return "Cancelled task with id " + id;
44+
}
45+
}

0 commit comments

Comments
 (0)