1use chrono::{DateTime, Utc};
8use mas_iana::oauth::PkceCodeChallengeMethod;
9use oauth2_types::{
10 pkce::{CodeChallengeError, CodeChallengeMethodExt},
11 requests::ResponseMode,
12 scope::{OPENID, PROFILE, Scope},
13};
14use rand::{
15 RngCore,
16 distributions::{Alphanumeric, DistString},
17};
18use serde::Serialize;
19use ulid::Ulid;
20use url::Url;
21
22use super::session::Session;
23use crate::InvalidTransitionError;
24
25#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
26pub struct Pkce {
27 pub challenge_method: PkceCodeChallengeMethod,
28 pub challenge: String,
29}
30
31impl Pkce {
32 #[must_use]
34 pub fn new(challenge_method: PkceCodeChallengeMethod, challenge: String) -> Self {
35 Pkce {
36 challenge_method,
37 challenge,
38 }
39 }
40
41 pub fn verify(&self, verifier: &str) -> Result<(), CodeChallengeError> {
47 self.challenge_method.verify(&self.challenge, verifier)
48 }
49}
50
51#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
52pub struct AuthorizationCode {
53 pub code: String,
54 pub pkce: Option<Pkce>,
55}
56
57#[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)]
58#[serde(tag = "stage", rename_all = "lowercase")]
59pub enum AuthorizationGrantStage {
60 #[default]
61 Pending,
62 Fulfilled {
63 session_id: Ulid,
64 fulfilled_at: DateTime<Utc>,
65 },
66 Exchanged {
67 session_id: Ulid,
68 fulfilled_at: DateTime<Utc>,
69 exchanged_at: DateTime<Utc>,
70 },
71 Cancelled {
72 cancelled_at: DateTime<Utc>,
73 },
74}
75
76impl AuthorizationGrantStage {
77 #[must_use]
78 pub fn new() -> Self {
79 Self::Pending
80 }
81
82 fn fulfill(
83 self,
84 fulfilled_at: DateTime<Utc>,
85 session: &Session,
86 ) -> Result<Self, InvalidTransitionError> {
87 match self {
88 Self::Pending => Ok(Self::Fulfilled {
89 fulfilled_at,
90 session_id: session.id,
91 }),
92 _ => Err(InvalidTransitionError),
93 }
94 }
95
96 fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
97 match self {
98 Self::Fulfilled {
99 fulfilled_at,
100 session_id,
101 } => Ok(Self::Exchanged {
102 fulfilled_at,
103 exchanged_at,
104 session_id,
105 }),
106 _ => Err(InvalidTransitionError),
107 }
108 }
109
110 fn cancel(self, cancelled_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
111 match self {
112 Self::Pending => Ok(Self::Cancelled { cancelled_at }),
113 _ => Err(InvalidTransitionError),
114 }
115 }
116
117 #[must_use]
121 pub fn is_pending(&self) -> bool {
122 matches!(self, Self::Pending)
123 }
124
125 #[must_use]
129 pub fn is_fulfilled(&self) -> bool {
130 matches!(self, Self::Fulfilled { .. })
131 }
132
133 #[must_use]
137 pub fn is_exchanged(&self) -> bool {
138 matches!(self, Self::Exchanged { .. })
139 }
140}
141
142#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
143pub struct AuthorizationGrant {
144 pub id: Ulid,
145 #[serde(flatten)]
146 pub stage: AuthorizationGrantStage,
147 pub code: Option<AuthorizationCode>,
148 pub client_id: Ulid,
149 pub redirect_uri: Url,
150 pub scope: Scope,
151 pub state: Option<String>,
152 pub nonce: Option<String>,
153 pub response_mode: ResponseMode,
154 pub response_type_id_token: bool,
155 pub created_at: DateTime<Utc>,
156 pub login_hint: Option<String>,
157 pub locale: Option<String>,
158}
159
160impl std::ops::Deref for AuthorizationGrant {
161 type Target = AuthorizationGrantStage;
162
163 fn deref(&self) -> &Self::Target {
164 &self.stage
165 }
166}
167
168impl AuthorizationGrant {
169 pub fn exchange(mut self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
177 self.stage = self.stage.exchange(exchanged_at)?;
178 Ok(self)
179 }
180
181 pub fn fulfill(
189 mut self,
190 fulfilled_at: DateTime<Utc>,
191 session: &Session,
192 ) -> Result<Self, InvalidTransitionError> {
193 self.stage = self.stage.fulfill(fulfilled_at, session)?;
194 Ok(self)
195 }
196
197 pub fn cancel(mut self, canceld_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
209 self.stage = self.stage.cancel(canceld_at)?;
210 Ok(self)
211 }
212
213 #[doc(hidden)]
214 pub fn sample(now: DateTime<Utc>, rng: &mut impl RngCore) -> Self {
215 Self {
216 id: Ulid::from_datetime_with_source(now.into(), rng),
217 stage: AuthorizationGrantStage::Pending,
218 code: Some(AuthorizationCode {
219 code: Alphanumeric.sample_string(rng, 10),
220 pkce: None,
221 }),
222 client_id: Ulid::from_datetime_with_source(now.into(), rng),
223 redirect_uri: Url::parse("http://localhost:8080").unwrap(),
224 scope: Scope::from_iter([OPENID, PROFILE]),
225 state: Some(Alphanumeric.sample_string(rng, 10)),
226 nonce: Some(Alphanumeric.sample_string(rng, 10)),
227 response_mode: ResponseMode::Query,
228 response_type_id_token: false,
229 created_at: now,
230 login_hint: Some(String::from("mxid:@example-user:example.com")),
231 locale: Some(String::from("fr")),
232 }
233 }
234}