Springboot集成阿里云通义千问(灵积模型)
我这里集成后,做成了一个工具jar包,如果有不同方式的,欢迎大家讨论,共同进步。
集成限制:
1、灵积模型有QPM(QPS)限制,每个模型不一样,需要根据每个模型适配
集成开发思路:
因有QPS限制,无法支持多任务并发执行,所以使用任务池操作,定时监听任务池中任务状态;
因系统中执行不能等待QPS释放后执行,故使用异步调用;
开发思路:
1、创建任务,提交到任务池中
2、任务监听器每10秒检查任务池中的任务执行情况:
1)任务未执行:获任务token,获取到执行任务,否则不执行
2)任务执行中:判断任务执行是否超时,如果超时,重置任务状态,重试计数加1
3)任务执行失败:执行失败回调。从任务池中清除
4)任务执行成功:从任务池中清除
3、任务执行:
1)获取任务token,如果获取到就执行,否则不执行
2)利用工具类请求灵积模型
3)判断任务执行状态:成功:执行成功回调;失败:重试计数加1,重置任务状态
4)归还token
集成编码
1、前置操作
详见阿里云灵积模型服务开发参开https://help.aliyun.com/zh/dashscope/developer-reference/acquisition-and-configuration-of-api-key?spm=a2c4g.11186623.0.0.1403193eLiHQfl
开发参考中获取到的API-KEY需要写到项目的配置文件中
2、创建灵积服务jar(aliyun-dashscope)
按照灵积模型Java jdk最佳实践的方式实现集成模型灵积模型Java jdk最佳实践https://help.aliyun.com/zh/dashscope/java-sdk-best-practices?spm=a2c4g.11186623.0.0.4da417d9T9NKfMpom文件中引入jar
<dependency><groupId>com.alibaba</groupId><artifactId>dashscope-sdk-java</artifactId><version>2.15.0</version></dependency><dependency><groupId>org.apache.commons</groupId><artifactId>commons-pool2</artifactId></dependency><dependency><groupId>com.aa.bb</groupId><artifactId>common-redis</artifactId><version>1.0.0</version></dependency>
dashscope-sdk-java : 灵积服务模型jar
commons-pool2 : 对象池工具jar
common-redis :个人项目中redis工具包(可以自己封装一个)
3、编码
1)创建config
@Data
@Configuration
@ConfigurationProperties(prefix = "aliyun.dashscope")
public class DashScopeConfig {/*** api密钥*/@Value("${aliyun.dashscope.apiKey}")private String apiKey;/*** 最大tokens数*/private int maxTokens = 800;/*** 模型*/private String model = "qwen-plus";/*** QPS*/private int qps = 15;/*** qps缓存密钥*/private String qpsRedisKey = "aliyun:dashscope:token";/*** 尝试计数*/private int tryCount = 3;/*** task间隔时间*/private int time = 10000;}
2)创建对象池工厂
public class DashScopePoolFactory extends BasePooledObjectFactory<Generation> {@Overridepublic Generation create() throws Exception {return new Generation();}@Overridepublic PooledObject<Generation> wrap(Generation generation) {return new DefaultPooledObject<>(generation);}
}
3)创建task
DashTask:任务类
@Data
@Slf4j
public class DashTask {/*** qps令牌*/private Long qpsToken;/*** 正在执行*/private boolean execute = false;/*** 成功*/private boolean success = false;/*** 尝试计数*/private int tryCount = 0;/*** 生成参数*/private GenerationParam generationParam;/*** 结果*/private Message result;/*** 成功回调*/private Consumer<DashTask> successCallback;/*** 失败回调*/private Consumer<DashTask> failCallback;public void setSuccess(boolean success) {if (success) {this.onSuccess();} else {this.onFail();}}/*** 论成功*/public void onSuccess() {this.success = true;try {if (this.successCallback != null) {this.successCallback.accept(this);}} catch (Exception ex) {log.error("dash task onSuccess error:" + ex.getMessage());}}/*** 失败*/public void onFail() {this.success = false;try {if (this.failCallback != null) {this.failCallback.accept(this);}} catch (Exception ex) {log.error("spark task onFail error:" + ex.getMessage());}}}
DashListener:任务监听类
@Slf4j
public class DashListener extends Listener {public DashListener(long interval) {super(interval, "dash-listener");}@Overridepublic void run() {log.info("灵积服务(通义千问)任务监听 start");setExecute(true);while (isExecute()) {try {DashScopeUtils.asyncTaskStart();Thread.sleep(getInterval());} catch (Exception e) {log.error("灵积服务(通义千问)任务监听 error", e);}}}
}
4)创建工具类
DashScopeUtils:灵积模型基础工具类
@Slf4j
public class DashScopeUtils {private static volatile DashScopeConfig config;private static volatile RedisService redisService;/*** 获取令牌*/public static final int GET_TOKEN_STATUS = 0;/*** 归还令牌*/public static final int BACK_TOKEN_STATUS = 1;private static CopyOnWriteArraySet<DashTask> taskList = new CopyOnWriteArraySet<DashTask>();/*** 通用池*/private static volatile GenericObjectPool<Generation> pool;/*** 创建消息** @param role 角色* @param content 所容纳之物* @return {@link Message }*/public static Message createMessage(Role role, String content) {return Message.builder().role(role.getValue()).content(content).build();}/*** 调用服务** @param param param* @return {@link GenerationResult }*/public static GenerationResult call(GenerationParam param) {try {if (param.getMaxTokens() == null) {param.setMaxTokens(getConfig().getMaxTokens());}Generation gen = getPool().borrowObject();GenerationResult call = gen.call(param);getPool().returnObject(gen);return call;} catch (Exception e) {log.error(e.getMessage(), e);throw new RuntimeException(e.getMessage());}}/*** 获取对象池** @return {@link GenericObjectPool }<{@link Generation }>*/public static GenericObjectPool<Generation> getPool() {if (pool == null) {synchronized (DashScopeUtils.class) {if (pool == null) {DashScopePoolFactory poolFactory = new DashScopePoolFactory();GenericObjectPoolConfig<Generation> config = new GenericObjectPoolConfig<>();config.setMaxTotal(64);config.setMaxIdle(64);config.setMinIdle(64);Constants.apiKey = getConfig().getApiKey();pool = new GenericObjectPool<>(poolFactory, config);}}}return pool;}/*** 获取配置** @return {@link DashScopeConfig }*/public static DashScopeConfig getConfig() {if (config == null) {synchronized (DashScopeConfig.class) {if (config == null) {config = SpringUtils.getBean(DashScopeConfig.class);}}}return config;}/*** 异步任务启动*/public static void asyncTaskStart() {instanceRedis();getConfig();// 令牌数量int current = 0;if (redisService.hasKey(config.getQpsRedisKey())) {current = Integer.parseInt(redisService.get(config.getQpsRedisKey()).toString());}if (current > 0) {String all = config.getQpsRedisKey() + ":*";int size = redisService.keys(all).size();if (size < current) {redisService.decr(config.getQpsRedisKey(), current - size);}}if (!taskList.isEmpty()) {Iterator<DashTask> iterator = taskList.iterator();while (iterator.hasNext()) {DashTask dashTask = iterator.next();if (dashTask.isExecute()) {if (!redisService.hasKey(config.getQpsRedisKey() + ":" + dashTask.getQpsToken())) {dashTask.setExecute(false);dashTask.setTryCount(dashTask.getTryCount()+1);}continue;} else if (dashTask.isSuccess()) {taskList.remove(dashTask);} else if (dashTask.getTryCount() > config.getTryCount()) {dashTask.setSuccess(false);taskList.remove(dashTask);} else if (!asyncTaskStart(dashTask)) {break;}}}}/*** 提交任务** @param dashTask 短跑任务*/public static void submitTask(DashTask dashTask) {taskList.add(dashTask);}/*** 异步任务启动** @param task 任务* @return boolean*/private static boolean asyncTaskStart(DashTask task) {if (qpsToken(GET_TOKEN_STATUS, task)) {AsyncManager.me().execute(() -> {try {task.setExecute(true);GenerationResult call = call(task.getGenerationParam());task.setResult(call.getOutput().getChoices().get(0).getMessage());task.setSuccess(true);} catch (Exception e) {task.setTryCount(task.getTryCount() + 1);}task.setExecute(false);qpsToken(BACK_TOKEN_STATUS, task);});return true;}return false;}/*** qps令牌** @param status 地位* @param task 任务* @return boolean*/private static synchronized boolean qpsToken(int status, DashTask task) {instanceRedis();getConfig();int current = 0;if (redisService.hasKey(config.getQpsRedisKey())) {current = Integer.parseInt(redisService.get(config.getQpsRedisKey()).toString());}// 获取tokenif (status == GET_TOKEN_STATUS) {if (current < config.getQps()) {Long incr = redisService.incr(config.getQpsRedisKey());task.setQpsToken(incr);redisService.set(config.getQpsRedisKey() + ":" + incr, "1", 1, TimeUnit.MINUTES);return true;} else {return false;}} else {if (current > 0) {redisService.decr(config.getQpsRedisKey());}redisService.del(config.getQpsRedisKey() + ":" + task.getQpsToken());return true;}}/*** 实例redis** @return {@link RedisService}*/private static RedisService instanceRedis() {if (redisService == null) {synchronized (DashScopeUtils.class) {if (redisService == null) {redisService = SpringUtils.getBean(RedisService.class);}if (redisService == null) {throw new RuntimeException("redisService is null");}}}return redisService;}
}
QiamwenUtils:通义千问工具类
public class QianWenUtils {/*** 单轮对话** @param content 内容* @param success 成功*/public static void call(String content, Consumer<Message> success) {Message message = DashScopeUtils.createMessage(Role.USER, content);call(Collections.singletonList(message), success);}/*** 多轮对话** @param messages 对话列表* @return {@link Message }*/public static void call(List<Message> messages, Consumer<Message> success) {try {GenerationParam param = GenerationParam.builder().model(DashScopeUtils.getConfig().getModel()).messages(messages).resultFormat(GenerationParam.ResultFormat.MESSAGE).topP(0.8).maxTokens(600).build();DashTask dashTask = new DashTask();dashTask.setGenerationParam(param);dashTask.setSuccessCallback(dash -> success.accept(dash.getResult()));DashScopeUtils.submitTask(dashTask);} catch (Exception e) {throw new RuntimeException("通义千问失败:" + e.getMessage());}}
}
5)创建runner
runner主要作用:
(1)检查配置文件是否正确配置;
(2)启动任务监听器
@Slf4j
@Component
public class DashScopeRunner {private DashListener dashListener;@PostConstructpublic void run() {DashScopeConfig config = DashScopeUtils.getConfig();if (config == null || ObjectUtil.isEmpty(config.getApiKey())) {throw new RuntimeException("灵积服务(通义千问)启动失败,请检查配置文件");} else {log.info("灵积服务(通义千问)启动");}dashListener = new DashListener(config.getTime());dashListener.start();}@PostConstructpublic void shutdown() {if (dashListener != null) {dashListener.shutdown();}}
}
4、测试
5、踩坑
1)token数量验证:每次开始执行任务池中任务状态检查时,要先检查任务token是否和实际一致,避免实际可用token数不足,导致进入死循环
2)任务池中的数据不能使用缓存(redis)
3)成功和失败回调必须是public
4)使用对象池(GenericObjectPool),借出对象,使用完成后必须归还,否则会出现无法借出的情况
5)config中QPS最好小于15,否则会出现限流情况