Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,9 @@ install: runtime
$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/vector-search
$(ENV_INSTALL) apisix/plugins/ai-rag/vector-search/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/vector-search

$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-cache
$(ENV_INSTALL) apisix/plugins/ai-cache/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-cache

$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-lakera-guard
$(ENV_INSTALL) apisix/plugins/ai-lakera-guard/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-lakera-guard

Expand Down
1 change: 1 addition & 0 deletions apisix/cli/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ local _M = {
"ai-rate-limiting",
"ai-proxy-multi",
"ai-proxy",
"ai-cache",
"ai-aws-content-moderation",
"ai-aliyun-content-moderation",
"ai-lakera-guard",
Expand Down
48 changes: 48 additions & 0 deletions apisix/core/json.lua
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,19 @@ local json_encode = cjson.encode
local json_decode = cjson.decode
local cjson_null = cjson.null
local clear_tab = require("table.clear")
local require = require
local ngx = ngx
local tostring = tostring
local type = type
local pairs = pairs
local ipairs = ipairs
local getmetatable = getmetatable
local cached_tab = {}

local rapidjson
local rapidjson_null
local rapidjson_encode_opts = { sort_keys = true }


cjson.encode_escape_forward_slash(false)
cjson.decode_array_with_array_mt(true)
Expand Down Expand Up @@ -122,6 +129,47 @@ local function encode(data, force)
end
_M.encode = encode


local function to_rapidjson_value(data)
if data == cjson_null then
return rapidjson_null
end

if type(data) ~= "table" then
return data
end

if getmetatable(data) == cjson.array_mt then
local arr = {}
for i, v in ipairs(data) do
arr[i] = to_rapidjson_value(v)
end
return rapidjson.array(arr)
end

local obj = {}
for k, v in pairs(data) do
obj[k] = to_rapidjson_value(v)
end
return obj
end


--- Encode a Lua value to a canonical JSON string with sorted object keys.
-- Unlike core.json.encode, object keys are emitted in a stable (sorted) order,
-- so the same logical value always produces the same string -- suitable for
-- hashing, cache keys and signatures. cjson null / array_mt markers are
-- preserved. Backed by rapidjson, which is loaded on first use.
-- @tparam table data The value to encode.
-- @treturn string The canonically-encoded JSON string.
function _M.canonical_encode(data)
if not rapidjson then
rapidjson = require("rapidjson")
rapidjson_null = rapidjson.null
end
return rapidjson.encode(to_rapidjson_value(data), rapidjson_encode_opts)
end

local max_delay_encode_items = 16
local delay_tab_idx = 0
local delay_tab_arr = {}
Expand Down
216 changes: 216 additions & 0 deletions apisix/plugins/ai-cache.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--

local core = require("apisix.core")
local schema = require("apisix.plugins.ai-cache.schema")
local key_mod = require("apisix.plugins.ai-cache.key")
local binding = require("apisix.plugins.ai-protocols.binding")
local redis_util = require("apisix.utils.redis")

local ngx = ngx
local ngx_null = ngx.null
local ipairs = ipairs
local concat = table.concat

local CACHE_STATUS_HEADER = "X-AI-Cache-Status"
local CACHE_AGE_HEADER = "X-AI-Cache-Age"
local DEFAULT_TTL = 3600
local DEFAULT_MAX_BODY = 1048576

local _M = {
version = 0.1,
priority = 1035,
name = "ai-cache",
schema = schema,
}


function _M.check_schema(conf)
return core.schema.check(schema, conf)
end


local function release(conf, red)
local ok, err = red:set_keepalive(conf.redis_keepalive_timeout or 10000,
conf.redis_keepalive_pool or 100)
if not ok then
core.log.warn("ai-cache: failed to set redis keepalive: ", err)
end
end


local function serve_hit(conf, ctx, cached)
ctx.ai_cache_status = "HIT"
if conf.cache_headers ~= false then
core.response.set_header(CACHE_STATUS_HEADER, "HIT")
local age = ngx.time() - (cached.created_at or ngx.time())
core.response.set_header(CACHE_AGE_HEADER, age < 0 and 0 or age)
end
core.response.set_header("Content-Type", "application/json")
return core.response.exit(200, cached.body)
end
Comment thread
janiussyafiq marked this conversation as resolved.


function _M.access(conf, ctx)
if not ctx.picked_ai_instance then
local handled, code, body = binding.on_unsupported(
conf.fail_mode, _M.name, ctx,
"no ai instance picked (request did not pass through ai-proxy/ai-proxy-multi)",
500, "ai-cache must be used with the ai-proxy or ai-proxy-multi plugin")
if handled then
return code, body
end
ctx.ai_cache_status = "BYPASS"
return
end

-- Streaming responses are not cached in PR-1 (SSE replay is a later
-- increment). ai-proxy (higher priority) has already classified the
-- request, so bypass before doing any work.
if ctx.var.request_type == "ai_stream" then
ctx.ai_cache_status = "BYPASS"
return
end

if conf.bypass_on then
for _, rule in ipairs(conf.bypass_on) do
if core.request.header(ctx, rule.header) == rule.equals then
ctx.ai_cache_status = "BYPASS"
return
end
end
end

local body, err = core.request.get_json_request_body_table()
if not body then
core.log.warn("ai-cache: cannot read request body, bypassing: ", err)
ctx.ai_cache_status = "BYPASS"
return
end

ctx.ai_cache_fingerprint = key_mod.fingerprint(ctx, body)
ctx.ai_cache_key = key_mod.build(conf, ctx, ctx.ai_cache_fingerprint)

local red
red, err = redis_util.new(conf)
if not red then
-- fail-open: never let a cache-backend outage break the request.
core.log.warn("ai-cache: redis unavailable, fail-open as MISS: ", err)
ctx.ai_cache_status = "MISS"
return
end

local res
res, err = red:get(ctx.ai_cache_key)
if err then
red:close()
core.log.warn("ai-cache: redis get failed, fail-open as MISS: ", err)
ctx.ai_cache_status = "MISS"
return
end
release(conf, red)

if res ~= nil and res ~= ngx_null then
local cached = core.json.decode(res)
if cached and cached.body then
return serve_hit(conf, ctx, cached)
end
core.log.warn("ai-cache: discarding malformed cache entry for ", ctx.ai_cache_key)
end

ctx.ai_cache_status = "MISS"
end


function _M.header_filter(conf, ctx)
if ctx.ai_cache_status and conf.cache_headers ~= false then
core.response.set_header(CACHE_STATUS_HEADER, ctx.ai_cache_status)
end
end


function _M.body_filter(conf, ctx)
-- only a MISS gets written back; HIT exited in access, BYPASS opts out.
if ctx.ai_cache_status ~= "MISS" or ctx.ai_cache_oversized then
return
end
local chunk = ngx.arg[1]
if chunk and #chunk > 0 then
local buf = ctx.ai_cache_buf
if not buf then
buf = { n = 0, bytes = 0 }
ctx.ai_cache_buf = buf
end
local n = buf.n + 1
buf.n = n
buf[n] = chunk
buf.bytes = buf.bytes + #chunk
if buf.bytes > (conf.max_cache_body_size or DEFAULT_MAX_BODY) then
ctx.ai_cache_buf = nil
ctx.ai_cache_oversized = true
end
end
end


-- The response-capturing phases (body_filter / log) run in contexts where
-- cosockets are disabled, so the Redis write is deferred to a 0-delay timer
-- (timers run in a light thread where cosockets are allowed).
local function write_to_cache(premature, conf, cache_key, response_body)
if premature then
return
end
local red, err = redis_util.new(conf)
if not red then
core.log.warn("ai-cache: redis unavailable on write: ", err)
return
end
local envelope = core.json.encode({ body = response_body, created_at = ngx.time() })
local ttl = (conf.exact and conf.exact.ttl) or DEFAULT_TTL
local ok
ok, err = red:set(cache_key, envelope, "EX", ttl)
if not ok then
red:close()
core.log.warn("ai-cache: redis set failed: ", err)
return
end
release(conf, red)
Comment thread
janiussyafiq marked this conversation as resolved.
end


function _M.log(conf, ctx)
if ctx.ai_cache_status ~= "MISS" or not ctx.ai_cache_fingerprint then
return
end
if ngx.status ~= 200 then
return
end
Comment thread
janiussyafiq marked this conversation as resolved.
local buf = ctx.ai_cache_buf
if not buf or buf.bytes == 0 then
return
end
local response_body = concat(buf, "", 1, buf.n)

local cache_key = key_mod.build(conf, ctx, ctx.ai_cache_fingerprint)
local ok, err = ngx.timer.at(0, write_to_cache, conf, cache_key, response_body)
if not ok then
core.log.warn("ai-cache: failed to schedule cache write: ", err)
end
end


return _M
88 changes: 88 additions & 0 deletions apisix/plugins/ai-cache/key.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--

local core = require("apisix.core")
local protocols = require("apisix.plugins.ai-protocols")
local sha256 = require("resty.sha256")
local to_hex = require("resty.string").to_hex

local ipairs = ipairs
local pairs = pairs
local concat = table.concat

local KEY_PREFIX = "ai-cache:l1:"

local _M = {}


local function hex_digest(s)
local hash = sha256:new()
hash:update(s)
return to_hex(hash:final())
end


function _M.fingerprint(ctx, body)
local params = {}
for k, v in pairs(body) do
if k ~= "messages" and k ~= "model" and k ~= "stream" then
params[k] = v
end
end

local repr = core.json.canonical_encode({
protocol = ctx.ai_client_protocol or "",
model = ctx.var.request_llm_model or body.model or "",
messages = protocols.get_messages(body, ctx) or {},
params = params,
})
return hex_digest(repr)
end


local function scope(conf, ctx)
local ck = conf.cache_key or {}

local parts = {}
if ctx.picked_ai_instance_name then
parts[#parts + 1] = "instance=" .. ctx.picked_ai_instance_name
end
if not ck.share_across_routes then
parts[#parts + 1] = "route=" .. (ctx.var.route_id or "")
end
if ck.include_consumer then
parts[#parts + 1] = "consumer=" .. (ctx.consumer_name or "")
end
if ck.include_vars then
for _, name in ipairs(ck.include_vars) do
parts[#parts + 1] = name .. "=" .. (ctx.var[name] or "")
end
end

if #parts == 0 then
return "shared"
end
return concat(parts, ":")
end


function _M.build(conf, ctx, fingerprint)
return KEY_PREFIX .. scope(conf, ctx) .. ":" .. fingerprint
end


return _M
Loading
Loading