from fastapi import FastAPI, Depends
from sqlalchemy import select
from sqlalchemy.orm import Session
from database import SessionLocal, engine, Base
from models import User
# 启动时自动建表(开发阶段方便,生产环境用 Alembic 迁移)
Base.metadata.create_all(bind=engine)
app = FastAPI()
# ---------- 依赖注入:每个请求分配一个数据库会话 ----------
def get_db():
db = SessionLocal() # 打开会话
try:
yield db # 把会话交给路由函数使用
finally:
db.close() # 请求结束后关闭会话
# ---------- 创建用户 ----------
@app.post("/users/")
def create_user(name: str, email: str, db: Session = Depends(get_db)):
user = User(name=name, email=email)
db.add(user) # 添加到会话
db.commit() # 提交事务,写入数据库
db.refresh(user) # 刷新,拿到数据库生成的 id
return user
# ---------- 查询所有用户 ----------
@app.get("/users/")
def list_users(db: Session = Depends(get_db)):
return db.execute(select(User)).scalars().all()
# ---------- 查询单个用户 ----------
@app.get("/users/{user_id}")
def get_user(user_id: int, db: Session = Depends(get_db)):
return db.execute(select(User).where(User.id == user_id)).scalars().first()