Streamlit experimental_connection 实战:构建可复用数据连接类

1. 项目概述:用 Streamlit 新连接机制构建可复用的空气质量可视化应用

我第一次看到 st.experimental_connection 这个 API 的时候,手边正调试一个卡在数据源切换上的旧项目——每次换数据库、换 API、换认证方式,都要重写一整套连接逻辑,还要手动处理缓存、错误重试、凭据隔离。那种“改一行代码,测三小时”的疲惫感,相信很多做过数据类 Web 应用的朋友都深有体会。而 Stavros Theocharis 在 Towards AI 上发布的这个 Aeroa 示例,恰恰踩中了我们日常开发中最痛的那个点: 如何让数据连接这件事,从“每次重写”的负担,变成“一次定义、随处调用”的能力 。它不是炫技,而是把过去需要自己封装几十行代码才能完成的连接管理、缓存策略、上下文隔离,浓缩成一行 st.experimental_connection("openaq", type=OpenAQConnection) 。关键词里反复出现的“Air Quality”,其实只是这个新机制的一个绝佳落地场景——因为空气质量数据天然具备多源性(不同国家、不同传感器、不同参数)、时效敏感性(分钟级更新)、地理空间性(必须落图),恰好能完整验证连接类的生命周期管理、缓存 TTL 控制、参数透传和错误兜底能力。这篇文章要讲的,不是怎么照着抄一个地图 App,而是带你真正吃透 experimental_connection 的设计哲学:它为什么必须继承 ExperimentalBaseConnection ?为什么 cursor() 方法看似多余却不可或缺?为什么 @st.cache_data(ttl=3600) 必须嵌套在 query_* 方法内部而不是放在顶层?我会用 Aeroa 的 OpenAQ 连接为例,把官方文档里没写的“潜规则”、调试时踩过的坑、以及生产环境部署前必须确认的五个细节,全部摊开来讲。无论你是刚学 Streamlit 的新手,还是已经用它上线过三个项目的工程师,只要你的应用需要对接外部数据源,这篇内容就值得你花 20 分钟读完。

2. 核心设计思路拆解:为什么连接必须是“类”,而不是“函数”

2.1 从“硬编码连接”到“声明式连接”的范式迁移

st.experimental_connection 出现之前,Streamlit 应用里处理外部数据源,基本逃不开这三种写法:

  • 方式一:全局变量 + requests

    import requests
    API_URL = "https://api.openaq.org/v2/"
    def get_countries():
        return requests.get(f"{API_URL}countries").json()
    

    问题在哪?第一,凭据(如 API Key)只能塞进代码或环境变量,无法利用 Streamlit 的 secrets 管理;第二,所有请求共享同一个 session,无法为不同数据源设置独立超时、重试策略;第三,缓存必须手动加 @st.cache_data ,且无法按数据源粒度控制 TTL。

  • 方式二:配置字典 + 工厂函数

    CONNECTIONS = {
        "openaq": {"base_url": "https://api.openaq.org/v2/", "timeout": 10},
        "weather": {"base_url": "https://api.weather.com/v3/", "timeout": 5}
    }
    def get_connection(name):
        config = CONNECTIONS[name]
        return requests.Session(**config)
    

    看似灵活,但实际运行时你会发现:每次调用 get_connection("openaq") 都会新建一个 Session,连接池失效;更麻烦的是,当你要在 st.sidebar.selectbox 变化时自动刷新数据,根本没法让这个“工厂函数”感知到 UI 状态变化——它和 Streamlit 的状态管理是割裂的。

  • 方式三: st.experimental_connection 的正确打开方式

    conn = st.experimental_connection("openaq", type=OpenAQConnection)
    countries = conn.query_countries()
    

    这行代码背后发生了什么?它不是简单地实例化一个类,而是触发了 Streamlit 的 连接注册中心 (Connection Registry)。这个中心会做三件关键事:

    1. 凭据注入 :自动从 secrets.toml 中查找 [connections.openaq] 区块,将 api_key timeout 等参数注入到 OpenAQConnection.__init__ **kwargs 中;
    2. 单例复用 :对同一个连接名(如 "openaq" ),整个应用生命周期内只创建一个 OpenAQConnection 实例,避免重复初始化 Session;
    3. 状态绑定 :当用户在 sidebar 切换国家时, conn.query(country_id="US") 的调用会自动关联当前 widget 的状态快照,确保缓存键(cache key)包含 country_id 值,实现“选 US 缓存 US,选 CN 缓存 CN”。

提示: st.experimental_connection 的本质,是把“数据连接”从应用逻辑层,提升到了 Streamlit 框架层。它让连接行为具备了和 st.session_state 同等的地位——可持久化、可状态感知、可安全隔离。

2.2 为什么必须继承 ExperimentalBaseConnection[requests.Session]

看原始代码里这行:

class OpenAQConnection(ExperimentalBaseConnection[requests.Session]):

初学者常误以为 [requests.Session] 是类型提示而已,删掉也不影响运行。实则不然。 ExperimentalBaseConnection 是一个泛型基类,它的泛型参数决定了两个核心能力:

  • cursor() 方法的返回类型 cursor() 必须返回泛型指定的类型(这里是 requests.Session )。Streamlit 内部会用这个返回值作为后续 query_* 方法的执行上下文。如果你写成 ExperimentalBaseConnection (不带泛型), cursor() 默认返回 None ,调用 self._resource.get(...) 就会直接报 AttributeError

  • 连接资源的生命周期管理 :Streamlit 会监控泛型类型是否实现了 __enter__ / __exit__ (如 requests.Session 没有,但 sqlalchemy.engine.Engine 有)。如果类型支持上下文管理,框架会在每次 query 调用前后自动调用 with cursor() as c: ... ,确保资源安全释放。对于 requests.Session ,虽然它本身不支持 with ,但 Streamlit 会退化为直接调用 cursor() 获取实例,这正是我们需要的。

注意:不要试图绕过泛型约束。我曾见过有人写 ExperimentalBaseConnection[object] ,结果在 query 方法里 self._resource 变成 <object object> ,调试半小时才发现是泛型写错了。正确的做法是,明确告诉框架:“我要用的底层资源是什么”,让类型系统帮你守住第一道防线。

2.3 cursor() 方法:被严重低估的“连接句柄”设计

原始代码中:

def cursor(self):
    return self._resource

看起来像一句废话。但它的存在,是 Streamlit 连接机制最精妙的设计之一。我们来对比两种写法:

  • 错误示范:在 query 里直接 requests.get

    def query(self, country_id):
        # ❌ 错误:每次都新建 Session,连接池失效
        session = requests.Session()
        return session.get(...).json()
    
  • 正确示范:通过 cursor() 复用资源

    def query(self, country_id):
        # ✅ 正确:复用 _resource(即 requests.Session)
        with self.cursor() as s:  # 即使 Session 不支持 with,这行也安全
            return s.get(...).json()
    

cursor() 的价值在于 解耦资源创建与资源使用 _connect() 方法负责“建连接”(可能涉及认证、握手、配置 session),而 cursor() 负责“取连接”。这种分离让以下场景成为可能:

  • 多线程安全 :当应用并发处理多个用户请求时,每个请求拿到的 cursor() 是同一个 Session 实例,但 Session 内部的连接池会自动分配不同 socket,无需额外加锁;
  • 连接健康检查 :你可以在 cursor() 里加入心跳逻辑,比如 if not self._resource.is_alive(): self._resource = self._connect()
  • 动态凭据刷新 :如果 API Token 有 1 小时有效期,你可以在 cursor() 里检查 token 过期时间,自动调用 refresh_token()

实操心得:我在一个金融数据项目中,把 cursor() 改造成返回一个自定义的 AuthedSession 类,该类在每次 get() 前自动检查 JWT 是否过期,并静默刷新。用户完全无感,而旧版代码里每个 query 方法都要写一遍 token 刷新逻辑,维护成本极高。

3. OpenAQ 连接类深度解析:从协议细节到缓存策略

3.1 OpenAQ API 协议特性与连接类设计映射

OpenAQ 的 v2 API 有两个关键特征,直接决定了 OpenAQConnection 的设计:

  • 分页式国家列表接口 GET /v2/countries 返回分页数据, limit=100 时需请求 page=1 page=2 才能覆盖全球 200+ 国家。原始代码用循环请求两页,但这隐藏了一个风险:如果某页请求失败(如网络抖动),整个国家列表加载就会中断。更健壮的做法是:

    def query_countries(self, limit=100, max_pages=5, ttl=3600):
        @st.cache_data(ttl=ttl)
        def _fetch_all_pages():
            all_results = []
            for page in range(1, max_pages + 1):
                try:
                    params = {"limit": limit, "page": page}
                    with self.cursor() as s:
                        resp = s.get("https://api.openaq.org/v2/countries", params=params)
                        resp.raise_for_status()
                        data = resp.json()
                        all_results.extend(data.get("results", []))
                        # 检查是否还有下一页
                        if len(data.get("results", [])) < limit:
                            break
                except Exception as e:
                    st.warning(f"Page {page} failed: {e}")
                    continue  # 失败页跳过,不影响其他页
            return all_results
        return _fetch_all_pages()
    

    这里增加了 max_pages 参数和 resp.raise_for_status() ,确保 HTTP 错误(如 429 限流)能被捕获,而不是静默返回空列表。

  • 地理位置查询的半径参数陷阱 GET /v2/locations radius 参数单位是 ,但文档没写清楚“半径中心点”是谁。实测发现,当 country_id 为空时(即查全球),API 会以 (0,0) 为圆心画圆,导致大部分数据丢失。原始代码直接传 country_id ,规避了这个问题,但如果你要做城市级查询,就必须显式传入经纬度。因此, query() 方法应该支持双模式:

    def query(self, country_id=None, latitude=None, longitude=None, radius=1000, **kwargs):
        params = {"radius": radius}
        if country_id:
            params["country_id"] = country_id
        elif latitude and longitude:
            params.update({"coordinates": f"{latitude},{longitude}"})
        else:
            raise ValueError("Either country_id or (latitude, longitude) must be provided")
        # ... rest of request
    

3.2 缓存策略: @st.cache_data(ttl=3600) 的三层作用域

原始代码在 query_countries query 方法内部都用了 @st.cache_data(ttl=3600) ,这是非常关键的设计。缓存不是加在“方法上”,而是加在“方法调用的结果上”,其作用域有三层:

  • 第一层:参数组合决定缓存键
    query_countries(page=1) query_countries(page=2) 是两个完全不同的缓存项。Streamlit 会把所有 *args **kwargs 的值序列化为哈希键。所以 ttl=3600 对每个参数组合独立生效。

  • 第二层:连接实例隔离缓存
    如果你创建了两个连接: conn1 = st.experimental_connection("openaq-us", type=OpenAQConnection) conn2 = st.experimental_connection("openaq-eu", type=OpenAQConnection) ,它们的缓存是完全隔离的。 conn1.query_countries() 的结果不会污染 conn2 的缓存。

  • 第三层:TTL 的“软过期”机制
    ttl=3600 并不是说缓存一小时后立即删除。Streamlit 采用“懒清理”:当某个缓存项被访问时,才检查它是否过期。如果过期,则异步触发重新计算,同时返回旧值(stale-while-revalidate)。这对用户体验至关重要——用户切国家时,地图不会白屏等待 2 秒,而是先显示上一秒的数据,后台静默更新。

实操心得:我在部署 Aeroa 到生产环境时,把 query_countries ttl 设为 86400 (24 小时),因为国家列表极少变动;而 query ttl 设为 300 (5 分钟),因为空气质量数据每 5-10 分钟更新一次。这样既保证数据新鲜度,又避免频繁请求压垮 API。

3.3 错误处理与降级:当 OpenAQ API 不可用时怎么办

任何依赖外部 API 的应用,都必须回答一个问题:当 API 返回 503(服务不可用)或超时时,用户看到什么?原始代码用 try...except Exception 捕获,但只打印日志,UI 上没有任何反馈。一个生产级的连接类应该提供降级策略:

def query_with_fallback(self, country_id, fallback_data=None, **kwargs):
    try:
        return self.query(country_id, **kwargs)
    except requests.exceptions.RequestException as e:
        st.error(f"OpenAQ API is unreachable: {e}")
        if fallback_data:
            st.info("Showing cached data from last successful load.")
            return fallback_data
        else:
            st.warning("No fallback data available. Please try again later.")
            return {"results": []}

更进一步,可以结合 st.session_state 做本地缓存:

if "last_valid_data" not in st.session_state:
    st.session_state.last_valid_data = {}

def query(self, country_id, **kwargs):
    try:
        data = super().query(country_id, **kwargs)
        st.session_state.last_valid_data[country_id] = data
        return data
    except:
        return st.session_state.last_valid_data.get(country_id, {"results": []})

这样即使 API 宕机 1 小时,用户依然能看到 1 小时前的有效数据,而不是空白地图。

4. Plotly 地图可视化实现:从数据清洗到交互增强

4.1 数据结构解析:OpenAQ 返回的嵌套 JSON 如何扁平化

OpenAQ 的 /v2/locations 接口返回的数据结构极其复杂,典型响应如下:

{
  "results": [
    {
      "id": "L_12345",
      "name": "Beijing Central Station",
      "coordinates": {"latitude": 39.9042, "longitude": 116.4074},
      "parameters": [
        {
          "parameter": "pm25",
          "displayName": "PM2.5",
          "lastValue": 42.5,
          "unit": "µg/m³",
          "lastUpdated": "2023-08-07T08:30:00Z"
        },
        {
          "parameter": "humidity",
          "displayName": "Relative Humidity",
          "lastValue": 65.2,
          "unit": "%",
          "lastUpdated": "2023-08-07T08:30:00Z"
        }
      ]
    }
  ]
}

原始 visualize_variable_on_map 函数用两层循环提取数据:

for result in data_dict.get("results", []):
    measurements = result.get("parameters", [])
    for measurement in measurements:
        if measurement["parameter"] == variable:
            # extract value, lat, lon...

这段代码在数据量大时(如查全球, results 可能上千条)会明显卡顿。优化方向有两个:

  • 提前过滤,减少循环次数 :用生成器表达式替代嵌套 for 循环:

    # 一行搞定:生成 (lat, lon, value, display_name, last_updated) 元组列表
    points = [
        (
            result["coordinates"]["latitude"],
            result["coordinates"]["longitude"],
            m["lastValue"],
            m["displayName"],
            result["lastUpdated"]
        )
        for result in data_dict.get("results", [])
        for m in result.get("parameters", [])
        if m["parameter"] == variable
    ]
    if not points:
        return create_custom_markdown_card(f"{variable} data not found.")
    latitudes, longitudes, values, display_names, last_updated = zip(*points)
    
  • 处理缺失坐标 :部分传感器可能没有 coordinates 字段(如仅提供城市名)。原始代码遇到 KeyError 会直接崩溃。应增加防御:

    coords = result.get("coordinates", {})
    if not coords.get("latitude") or not coords.get("longitude"):
        continue  # 跳过无坐标的点
    

4.2 Plotly Scattermapbox 配置详解:从配色到交互

原始代码用 go.Scattermapbox 绘制地图,但有几个关键配置被简化了,影响专业度:

  • 颜色映射(colorscale)的合理性 colorscale="Viridis" 对 PM2.5 这类有明确健康阈值的参数并不友好。WHO 推荐的 PM2.5 24 小时均值限值是 15 µg/m³,超标即有害。更好的做法是用分段色阶(custom colorscale):

    # 定义 WHO 健康标准色阶
    colorscale = [
        [0, "green"],     # 0-15: Good
        [0.33, "yellow"], # 15-30: Fair
        [0.66, "orange"], # 30-55: Poor
        [1, "red"]        # >55: Very Poor
    ]
    marker=dict(
        size=20,
        color=values,
        colorscale=colorscale,
        cmin=0, cmax=75,  # 强制色阶范围,避免单点异常值拉伸
        colorbar=dict(title=f"{variable.capitalize()} (µg/m³)")
    )
    
  • 文本标注(text)的性能优化 :原始代码用列表推导式生成 text

    text=[f"{marker[0]}{display_name}:{values[i]}<br>Last Updated:{last_updated[i]}" for i, display_name in enumerate(display_names)]
    

    当点数超过 1000 时,字符串拼接会变慢。改为用 zip f-string 更高效:

    text = [f"📍{dn}: {v:.1f}<br>⏰{lu[:16]}" for dn, v, lu in zip(display_names, values, last_updated)]
    
  • 交互增强:点击弹窗详情
    Scattermapbox 默认只支持 hover,但用户常需要点击查看详情。添加 customdata clickmode

    fig.add_trace(go.Scattermapbox(
        # ... other args
        customdata=[(r["id"], r["name"], r["city"]) for r in data_dict.get("results", [])],
        hovertemplate="%{text}<extra>%{customdata[1]}, %{customdata[2]}</extra>",
        clickmode="event+select"
    ))
    # 在 app.py 中监听点击事件
    clicked_point = st.plotly_chart(fig, use_container_width=True, on_select="rerun")
    if clicked_point and clicked_point["selection"]["points"]:
        point = clicked_point["selection"]["points"][0]
        st.info(f"You clicked on sensor ID: {point['customdata'][0]}")
    

4.3 主题适配与响应式布局:暗色模式与移动端

原始代码用 is_daytime() 切换地图样式,但 is_daytime() 函数未给出。一个更鲁棒的方式是读取系统偏好:

# utils.py
def get_system_theme():
    """Detect system dark mode preference"""
    try:
        # Streamlit 1.28+ 支持 st.theme
        return st.theme["base"] == "dark"
    except:
        # fallback to browser detection via JS
        return st._get_query_params().get("theme", "light") == "dark"

# In visualize function
mapbox_style = "carto-darkmatter" if get_system_theme() else "open-street-map"

对于移动端, Scattermapbox zoom=5 在手机上可能太小。应动态调整:

import streamlit as st
# Detect screen width
screen_width = st.session_state.get("screen_width", 1200)
if screen_width < 768:  # Mobile
    zoom_level = 3
elif screen_width < 1024:  # Tablet
    zoom_level = 4
else:
    zoom_level = 5

fig.update_layout(
    mapbox=dict(zoom=zoom_level),
    margin=dict(l=10, r=10, t=30, b=10),  # 移动端留出底部空间
)

5. 完整应用搭建与部署避坑指南

5.1 app.py 关键模块拆解:从配置到状态管理

原始 app.py 代码将所有逻辑堆在一个文件,不利于维护。我将其重构为标准三层:

  • config/ 目录 :存放 secrets.toml config_readme.toml
    secrets.toml 示例:

    [connections.openaq]
    api_key = "your_openaq_api_key_here"  # 可选,OpenAQ 免费 API 无需 key
    timeout = 10
    retries = 3
    
  • connections/ 目录 openaq_connection.py ,专注连接逻辑

    • 添加 __all__ = ["OpenAQConnection"] 显式导出
    • __init__.py 中统一导入,方便 from connections import *
  • utils/ 目录 map_utils.py data_utils.py ui_components.py
    ui_components.py 封装可复用组件:

    def country_selector(transformed_countries, default="Global"):
        """Reusable country selector with error handling"""
        try:
            return st.sidebar.selectbox(
                "Select country",
                transformed_countries,
                index=list(transformed_countries.keys()).index(default),
                help="Choose a country to view air quality data"
            )
        except ValueError:
            st.error("Default country not found. Using first country.")
            return list(transformed_countries.keys())[0]
    
    # 在 app.py 中调用
    selected_country = country_selector(transformed_countries)
    

5.2 Streamlit Cloud 部署五步 checklist

Aeroa 部署到 Streamlit Community Cloud 是免费的,但新手常因忽略以下细节导致失败:

  1. requirements.txt 必须显式声明 streamlit>=1.25.0
    Streamlit Cloud 默认安装最新版,但 st.experimental_connection 在 1.25.0 才稳定。旧版会报 AttributeError: module 'streamlit' has no attribute 'experimental_connection'

  2. secrets.toml 不能提交到 GitHub
    .gitignore 必须包含 secrets.toml 。Cloud 后台的 Secrets 页面才是配置凭据的唯一位置。

  3. 静态资源路径必须用 st.image() st.markdown()
    原始代码 st.sidebar.image(load_image("logo.png")) load_image 未定义。正确做法:

    # assets/logo.png 存放在项目根目录
    st.sidebar.image("assets/logo.png", use_column_width=True)
    
  4. st.spinner 的作用域要精确
    原始代码 with st.spinner("Loading countries..."): 包裹了整个国家列表加载逻辑,但如果 query_countries() 内部已用 @st.cache_data ,首次加载后 spinner 就不会出现。应改为:

    # 只在首次加载时显示 spinner
    if "countries_loaded" not in st.session_state:
        with st.spinner("Loading countries..."):
            countries = conn.query_countries()
            st.session_state.countries = countries
            st.session_state.countries_loaded = True
    else:
        countries = st.session_state.countries
    
  5. st.set_page_config 必须在任何 st.* 调用之前
    这是硬性规定。把 st.set_page_config(...) 放在文件最顶部,否则会报 StreamlitAPIException: set_page_config() can only be called once per app

5.3 常见问题速查表与独家避坑技巧

问题现象 根本原因 解决方案 我的实操经验
AttributeError: 'NoneType' object has no attribute 'get' query_countries() 返回 None ,因 @st.cache_data 抛异常后返回 None query_countries() 外层加 or {"results": []} ,或捕获 st.cache_data CacheKeyNotFoundError 我在测试时故意断网,发现 st.cache_data 在首次失败后会缓存 None ,必须用 try/except 包裹整个 @st.cache_data 装饰的函数
地图 markers 全部挤在 (0,0) OpenAQ API 的 coordinates 字段为 null ,代码未过滤 在数据提取循环中加 if not result.get("coordinates") or not result["coordinates"].get("latitude"): continue 查全球数据时,约 15% 的传感器缺失坐标,必须过滤,否则 sum(latitudes)/len(latitudes) 会除零
st.experimental_connection TypeError: cannot pickle '_thread.RLock' object 自定义连接类中存储了不可序列化的对象(如 threading.Lock 确保 _connect() 返回的对象是可 pickle 的;所有状态变量(如 self._lock )应在 __getstate__ 中排除 我曾加了一个 self._cache_lock = threading.Lock() ,结果整个连接实例无法序列化,删掉后问题解决
选择国家后地图不刷新 conn.query() 的缓存键未包含 country_id 检查 query() 方法的 **kwargs 是否完整传递给 @st.cache_data ;确保 country_id query() 的显式参数,而非从 st.session_state 读取 最佳实践:所有影响数据的参数,必须作为 query() 的函数参数,让 Streamlit 自动生成缓存键
Streamlit Cloud 部署后白屏 requirements.txt plotly 版本过高(如 plotly>=6.0.0 锁定版本: plotly==5.18.0 ;Streamlit Cloud 的 Python 环境较旧,新版 plotly 依赖 numpy>=1.24 ,而 Cloud 默认 numpy==1.21 我用 pipreqs . --force 生成依赖,再手动降级 plotly numpy

最后分享一个小技巧:在 app.py 开头加一段诊断代码,帮助快速定位环境问题:

# Debug environment
st.write("#### Environment Check")
st.write(f"- Streamlit version: {st.__version__}")
st.write(f"- Python version: {sys.version}")
st.write(f"- Cache status: {st.cache_data.__doc__[:50]}...")
try:
    conn = st.experimental_connection("openaq", type=OpenAQConnection)
    st.success("✅ Connection initialized")
except Exception as e:
    st.error(f"❌ Connection failed: {e}")

这个 Aeroa 项目的价值,远不止于展示一个空气质量地图。它是一把钥匙,打开了 Streamlit 应用架构升级的大门——从此,数据连接不再是散落在各处的 requests.get ,而是一个可配置、可缓存、可监控、可复用的一等公民。当你下次接到“接入公司内部 CRM 系统”的需求时,你会自然想到:先写一个 CRMConnection 类,再 st.experimental_connection("crm", type=CRMConnection) ,剩下的交给框架。这种从“写代码”到“配能力”的思维转变,才是 experimental_connection 真正想送给我们的礼物。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值