From 35c5b40e2a0a5c806be15f7d71591dfcd44add96 Mon Sep 17 00:00:00 2001 From: nsaccente Date: 2025年7月23日 21:56:53 -0400 Subject: [PATCH 01/18] Add check for Literal type annotation in get_sqlalchemy_type to return an AutoString --- sqlmodel/main.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 38c85915aa..404d1efd0d 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -655,6 +655,9 @@ def get_sqlalchemy_type(field: Any) -> Any: type_ = get_sa_type_from_field(field) metadata = get_field_metadata(field) + # Checks for `Literal` type annotation + if type_ is Literal: + return AutoString # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI if issubclass(type_, Enum): return sa_Enum(type_) From e562654a23be28f13d4ca2a820f267431073c937 Mon Sep 17 00:00:00 2001 From: nsaccente Date: Thu, 7 Aug 2025 17:58:19 -0400 Subject: [PATCH 02/18] Add unit test for Literal parsing --- tests/test_main.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/test_main.py b/tests/test_main.py index 60d5c40ebb..0155a06c64 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Literal import pytest from sqlalchemy.exc import IntegrityError @@ -125,3 +125,26 @@ class Hero(SQLModel, table=True): # The next statement should not raise an AttributeError assert hero_rusty_man.team assert hero_rusty_man.team.name == "Preventers" + + +def test_literal_typehints_are_treated_as_strings(clear_sqlmodel): + """Test https://github.com/fastapi/sqlmodel/issues/57""" + + class Hero(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(unique=True) + weakness: Literal["Kryptonite", "Dehydration", "Munchies"] + + + superman = Hero(name="Superman", weakness="Kryptonite") + + engine = create_engine("sqlite://", echo=True) + + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + session.add(superman) + session.commit() + session.refresh(superman) + assert superman.weakness == "Kryptonite" + assert isinstance(superman.weakness, str) From fcabf3f4204c8c861b14da50d6ad31ca78f96622 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Aug 2025 21:58:31 +0000 Subject: [PATCH 03/18] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_main.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_main.py b/tests/test_main.py index 0155a06c64..297002de3c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Literal +from typing import List, Literal, Optional import pytest from sqlalchemy.exc import IntegrityError @@ -135,7 +135,6 @@ class Hero(SQLModel, table=True): name: str = Field(unique=True) weakness: Literal["Kryptonite", "Dehydration", "Munchies"] - superman = Hero(name="Superman", weakness="Kryptonite") engine = create_engine("sqlite://", echo=True) From 859a4af460230179e4d112f769efad116d20eb04 Mon Sep 17 00:00:00 2001 From: nsaccente Date: Thu, 7 Aug 2025 19:11:57 -0400 Subject: [PATCH 04/18] Testing git test runner with commented out code --- sqlmodel/_compat.py | 12 ++++++++---- tests/test_main.py | 10 +++++----- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 38dd501c4a..9103f8e2cd 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -10,6 +10,7 @@ Dict, ForwardRef, Generator, + Literal, Mapping, Optional, Set, @@ -22,6 +23,7 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo from typing_extensions import Annotated, get_args, get_origin +from .sql.sqltypes import AutoString # Reassign variable to make it reexported for mypy PYDANTIC_VERSION = P_VERSION @@ -458,10 +460,12 @@ def is_field_noneable(field: "FieldInfo") -> bool: ) return field.allow_none # type: ignore[no-any-return, attr-defined] - def get_sa_type_from_field(field: Any) -> Any: - if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: - return field.type_ - raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") + # def get_sa_type_from_field(field: Any) -> Any: + # if field is Literal: + # return AutoString + # elif isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: + # return field.type_ + # raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") def get_field_metadata(field: Any) -> Any: metadata = FakeMetadata() diff --git a/tests/test_main.py b/tests/test_main.py index 297002de3c..5416bfc666 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -135,15 +135,15 @@ class Hero(SQLModel, table=True): name: str = Field(unique=True) weakness: Literal["Kryptonite", "Dehydration", "Munchies"] - superman = Hero(name="Superman", weakness="Kryptonite") + superguy = Hero(name="Superguy", weakness="Kryptonite") engine = create_engine("sqlite://", echo=True) SQLModel.metadata.create_all(engine) with Session(engine) as session: - session.add(superman) + session.add(superguy) session.commit() - session.refresh(superman) - assert superman.weakness == "Kryptonite" - assert isinstance(superman.weakness, str) + session.refresh(superguy) + assert superguy.weakness == "Kryptonite" + assert isinstance(superguy.weakness, str) From e193bcdc2369245ae39b9577b218c67bddd94d2f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Aug 2025 23:12:08 +0000 Subject: [PATCH 05/18] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 9103f8e2cd..8d511020b3 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -10,7 +10,6 @@ Dict, ForwardRef, Generator, - Literal, Mapping, Optional, Set, @@ -23,7 +22,6 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo from typing_extensions import Annotated, get_args, get_origin -from .sql.sqltypes import AutoString # Reassign variable to make it reexported for mypy PYDANTIC_VERSION = P_VERSION From cdc863dc48a197626138fa4de04e1c7554be9171 Mon Sep 17 00:00:00 2001 From: nsaccente Date: Thu, 7 Aug 2025 19:18:23 -0400 Subject: [PATCH 06/18] Test literal patch in _compat --- sqlmodel/_compat.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 9103f8e2cd..f2473aa626 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -460,12 +460,12 @@ def is_field_noneable(field: "FieldInfo") -> bool: ) return field.allow_none # type: ignore[no-any-return, attr-defined] - # def get_sa_type_from_field(field: Any) -> Any: - # if field is Literal: - # return AutoString - # elif isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: - # return field.type_ - # raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") + def get_sa_type_from_field(field: Any) -> Any: + if get_origin(field.type_) is Literal: + return AutoString + elif isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: + return field.type_ + raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") def get_field_metadata(field: Any) -> Any: metadata = FakeMetadata() From 4e8b303bb439742566f2a21a2847b75b01562c7c Mon Sep 17 00:00:00 2001 From: svlandeg Date: 2025年8月22日 14:19:44 +0200 Subject: [PATCH 07/18] import Literal from typing or typing_extensions --- sqlmodel/_compat.py | 7 +++++++ sqlmodel/main.py | 3 ++- tests/test_main.py | 3 ++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 7f8669cac7..2b95a51ee1 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -1,3 +1,4 @@ +import sys import types from contextlib import contextmanager from contextvars import ContextVar @@ -62,6 +63,12 @@ def _is_union_type(t: Any) -> bool: return t is UnionType or t is Union +if sys.version_info>= (3, 9): + from typing import Literal +else: + from typing_extensions import Literal + + finish_init: ContextVar[bool] = ContextVar("finish_init", default=True) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 404d1efd0d..b49fc4f4f1 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -52,12 +52,13 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid -from typing_extensions import Literal, TypeAlias, deprecated, get_origin +from typing_extensions import TypeAlias, deprecated, get_origin from ._compat import ( # type: ignore[attr-defined] IS_PYDANTIC_V2, PYDANTIC_MINOR_VERSION, BaseConfig, + Literal, ModelField, ModelMetaclass, Representation, diff --git a/tests/test_main.py b/tests/test_main.py index 5416bfc666..5ee61f446f 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,9 +1,10 @@ -from typing import List, Literal, Optional +from typing import List, Optional import pytest from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import RelationshipProperty from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select +from sqlmodel._compat import Literal def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel): From c659666cdb0a98a4c5f07f23ae8dda83717839bf Mon Sep 17 00:00:00 2001 From: svlandeg Date: 2025年8月22日 14:23:46 +0200 Subject: [PATCH 08/18] fix import of AutoString --- sqlmodel/_compat.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 2b95a51ee1..48d6fab162 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -24,6 +24,8 @@ from pydantic.fields import FieldInfo from typing_extensions import Annotated, get_args, get_origin +from . import AutoString + # Reassign variable to make it reexported for mypy PYDANTIC_VERSION = P_VERSION PYDANTIC_MINOR_VERSION = tuple(int(i) for i in P_VERSION.split(".")[:2]) From ce968ea67500bde5cf5f401556534f04d26af0b6 Mon Sep 17 00:00:00 2001 From: svlandeg Date: 2025年8月22日 14:25:48 +0200 Subject: [PATCH 09/18] avoid circular import --- sqlmodel/_compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 48d6fab162..a6c97af659 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -24,7 +24,7 @@ from pydantic.fields import FieldInfo from typing_extensions import Annotated, get_args, get_origin -from . import AutoString +from .sql.sqltypes import AutoString # Reassign variable to make it reexported for mypy PYDANTIC_VERSION = P_VERSION From e522d9dbd80d333216cf5e517a7e7e92ca93b7c6 Mon Sep 17 00:00:00 2001 From: svlandeg Date: 2025年8月26日 15:27:36 +0200 Subject: [PATCH 10/18] fix case where AutoString is being converted --- sqlmodel/main.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index b49fc4f4f1..47ec456947 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -39,6 +39,7 @@ Numeric, inspect, ) +from sqlalchemy import types as sa_types from sqlalchemy import Enum as sa_Enum from sqlalchemy.orm import ( Mapped, @@ -656,6 +657,10 @@ def get_sqlalchemy_type(field: Any) -> Any: type_ = get_sa_type_from_field(field) metadata = get_field_metadata(field) + # If it's already an SQLAlchemy type (eg. AutoString), use it directly + if isinstance(type_, type) and issubclass(type_, sa_types.TypeEngine): + return type_ + # Checks for `Literal` type annotation if type_ is Literal: return AutoString From 510e4219473e4940f7e996333849b8938a4c3b84 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: 2025年8月26日 13:29:25 +0000 Subject: [PATCH 11/18] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 47ec456947..f15037e3f7 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -39,8 +39,8 @@ Numeric, inspect, ) -from sqlalchemy import types as sa_types from sqlalchemy import Enum as sa_Enum +from sqlalchemy import types as sa_types from sqlalchemy.orm import ( Mapped, RelationshipProperty, From 9068bad194698f14ac559d5e40de122e6d34b9c6 Mon Sep 17 00:00:00 2001 From: svlandeg Date: 2025年8月26日 15:46:50 +0200 Subject: [PATCH 12/18] fix enum cast --- sqlmodel/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index f15037e3f7..abb0b5dac0 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -666,7 +666,7 @@ def get_sqlalchemy_type(field: Any) -> Any: return AutoString # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI if issubclass(type_, Enum): - return sa_Enum(type_) + return sa_Enum(cast(Type[Enum], type_)) if issubclass( type_, ( From 545a55aa31fd2b31ffde18ed18932c47c1956145 Mon Sep 17 00:00:00 2001 From: svlandeg Date: 2025年8月26日 16:15:17 +0200 Subject: [PATCH 13/18] set typing-modules to avoid ruff error --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 766b055819..28cd9fadf2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,6 +109,7 @@ disallow_untyped_defs = false disallow_untyped_calls = false [tool.ruff.lint] +typing-modules = ["sqlmodel._compat"] select = [ "E", # pycodestyle errors "W", # pycodestyle warnings From 1fc1717794e66f3462679a3e6347c21d315e584b Mon Sep 17 00:00:00 2001 From: svlandeg Date: 2025年8月27日 12:19:17 +0200 Subject: [PATCH 14/18] simply import Literal from typing_extensions always --- pyproject.toml | 1 - sqlmodel/_compat.py | 9 +-------- sqlmodel/main.py | 3 +-- tests/test_main.py | 2 +- 4 files changed, 3 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 28cd9fadf2..766b055819 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,7 +109,6 @@ disallow_untyped_defs = false disallow_untyped_calls = false [tool.ruff.lint] -typing-modules = ["sqlmodel._compat"] select = [ "E", # pycodestyle errors "W", # pycodestyle warnings diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index a6c97af659..05f33cae64 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -1,4 +1,3 @@ -import sys import types from contextlib import contextmanager from contextvars import ContextVar @@ -22,7 +21,7 @@ from pydantic import VERSION as P_VERSION from pydantic import BaseModel from pydantic.fields import FieldInfo -from typing_extensions import Annotated, get_args, get_origin +from typing_extensions import Annotated, Literal, get_args, get_origin from .sql.sqltypes import AutoString @@ -65,12 +64,6 @@ def _is_union_type(t: Any) -> bool: return t is UnionType or t is Union -if sys.version_info>= (3, 9): - from typing import Literal -else: - from typing_extensions import Literal - - finish_init: ContextVar[bool] = ContextVar("finish_init", default=True) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index abb0b5dac0..42a4ac2d9f 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -53,13 +53,12 @@ from sqlalchemy.orm.instrumentation import is_instrumented from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid -from typing_extensions import TypeAlias, deprecated, get_origin +from typing_extensions import Literal, TypeAlias, deprecated, get_origin from ._compat import ( # type: ignore[attr-defined] IS_PYDANTIC_V2, PYDANTIC_MINOR_VERSION, BaseConfig, - Literal, ModelField, ModelMetaclass, Representation, diff --git a/tests/test_main.py b/tests/test_main.py index 5ee61f446f..98b9abcd67 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4,7 +4,7 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import RelationshipProperty from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select -from sqlmodel._compat import Literal +from typing_extensions import Literal def test_should_allow_duplicate_row_if_unique_constraint_is_not_passed(clear_sqlmodel): From ab774b5d15200e51b4bc53091a6a2b507cb504e0 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: 2025年8月27日 12:23:32 +0200 Subject: [PATCH 15/18] commit suggestion by Yurii Co-authored-by: Motov Yurii <109919500+yuriimotov@users.noreply.github.com> --- sqlmodel/_compat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index 05f33cae64..fb756548f9 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -462,8 +462,8 @@ def is_field_noneable(field: "FieldInfo") -> bool: def get_sa_type_from_field(field: Any) -> Any: if get_origin(field.type_) is Literal: - return AutoString - elif isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: + return Literal + if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: return field.type_ raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") From 7e550228d1882cfd0f0e81a599e7e6fc2f5cc549 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: 2025年8月27日 10:23:38 +0000 Subject: [PATCH 16/18] =?UTF-8?q?=F0=9F=8E=A8=20[pre-commit.ci]=20Auto=20f?= =?UTF-8?q?ormat=20from=20pre-commit.com=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sqlmodel/_compat.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index fb756548f9..af90cfa823 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -23,8 +23,6 @@ from pydantic.fields import FieldInfo from typing_extensions import Annotated, Literal, get_args, get_origin -from .sql.sqltypes import AutoString - # Reassign variable to make it reexported for mypy PYDANTIC_VERSION = P_VERSION PYDANTIC_MINOR_VERSION = tuple(int(i) for i in P_VERSION.split(".")[:2]) From 1f0e0ea3cf4ecafad73138ba3f566a7f7c613a3d Mon Sep 17 00:00:00 2001 From: svlandeg Date: 2025年8月27日 12:25:34 +0200 Subject: [PATCH 17/18] cleanup --- sqlmodel/main.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 42a4ac2d9f..404d1efd0d 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -40,7 +40,6 @@ inspect, ) from sqlalchemy import Enum as sa_Enum -from sqlalchemy import types as sa_types from sqlalchemy.orm import ( Mapped, RelationshipProperty, @@ -656,16 +655,12 @@ def get_sqlalchemy_type(field: Any) -> Any: type_ = get_sa_type_from_field(field) metadata = get_field_metadata(field) - # If it's already an SQLAlchemy type (eg. AutoString), use it directly - if isinstance(type_, type) and issubclass(type_, sa_types.TypeEngine): - return type_ - # Checks for `Literal` type annotation if type_ is Literal: return AutoString # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI if issubclass(type_, Enum): - return sa_Enum(cast(Type[Enum], type_)) + return sa_Enum(type_) if issubclass( type_, ( From 11cc55e5ddeac1203ec6fa28395bc772db64a1db Mon Sep 17 00:00:00 2001 From: Yurii Motov Date: 2025年9月30日 22:33:02 +0200 Subject: [PATCH 18/18] Handle special cases with `Literal` (all `int` and all `bool`) --- sqlmodel/_compat.py | 15 ++++++++++++++- sqlmodel/main.py | 3 --- tests/test_main.py | 45 +++++++++++++++++++++++++++++++++++---------- 3 files changed, 49 insertions(+), 14 deletions(-) diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index af90cfa823..6295662aec 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -208,6 +208,13 @@ def get_sa_type_from_type_annotation(annotation: Any) -> Any: # Optional unions are allowed use_type = bases[0] if bases[0] is not NoneType else bases[1] return get_sa_type_from_type_annotation(use_type) + if origin is Literal: + literal_args = get_args(annotation) + if all(isinstance(arg, bool) for arg in literal_args): # all bools + return bool + if all(isinstance(arg, int) for arg in literal_args): # all ints + return int + return str return origin def get_sa_type_from_field(field: Any) -> Any: @@ -460,7 +467,13 @@ def is_field_noneable(field: "FieldInfo") -> bool: def get_sa_type_from_field(field: Any) -> Any: if get_origin(field.type_) is Literal: - return Literal + literal_args = get_args(field.type_) + if all(isinstance(arg, bool) for arg in literal_args): # all bools + return bool + if all(isinstance(arg, int) for arg in literal_args): # all ints + return int + return str + if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON: return field.type_ raise ValueError(f"The field {field.name} has no matching SQLAlchemy type") diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 404d1efd0d..38c85915aa 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -655,9 +655,6 @@ def get_sqlalchemy_type(field: Any) -> Any: type_ = get_sa_type_from_field(field) metadata = get_field_metadata(field) - # Checks for `Literal` type annotation - if type_ is Literal: - return AutoString # Check enums first as an enum can also be a str, needed by Pydantic/FastAPI if issubclass(type_, Enum): return sa_Enum(type_) diff --git a/tests/test_main.py b/tests/test_main.py index 98b9abcd67..188bf9df66 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -128,23 +128,48 @@ class Hero(SQLModel, table=True): assert hero_rusty_man.team.name == "Preventers" -def test_literal_typehints_are_treated_as_strings(clear_sqlmodel): +def test_literal_str(clear_sqlmodel, caplog): """Test https://github.com/fastapi/sqlmodel/issues/57""" - class Hero(SQLModel, table=True): + class Model(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) - name: str = Field(unique=True) - weakness: Literal["Kryptonite", "Dehydration", "Munchies"] - - superguy = Hero(name="Superguy", weakness="Kryptonite") + all_str: Literal["a", "b", "c"] + mixed: Literal["yes", "no", 1, 0] + all_int: Literal[1, 2, 3] + int_bool: Literal[0, 1, True, False] + all_bool: Literal[True, False] + + obj = Model( + all_str="a", + mixed="yes", + all_int=1, + int_bool=True, + all_bool=False, + ) engine = create_engine("sqlite://", echo=True) SQLModel.metadata.create_all(engine) + # Check DDL + assert "all_str VARCHAR NOT NULL" in caplog.text + assert "mixed VARCHAR NOT NULL" in caplog.text + assert "all_int INTEGER NOT NULL" in caplog.text + assert "int_bool INTEGER NOT NULL" in caplog.text + assert "all_bool BOOLEAN NOT NULL" in caplog.text + + # Check query with Session(engine) as session: - session.add(superguy) + session.add(obj) session.commit() - session.refresh(superguy) - assert superguy.weakness == "Kryptonite" - assert isinstance(superguy.weakness, str) + session.refresh(obj) + assert isinstance(obj.all_str, str) + assert obj.all_str == "a" + assert isinstance(obj.mixed, str) + assert obj.mixed == "yes" + assert isinstance(obj.all_int, int) + assert obj.all_int == 1 + assert isinstance(obj.int_bool, int) + assert obj.int_bool == 1 + assert isinstance(obj.all_bool, bool) + assert obj.all_bool is False

AltStyle によって変換されたページ (->オリジナル) /