Skip to content

Commit cca1f6d

Browse files
authored
Merge pull request kenkoooo#770 from magurotuna/language-count-client-sqlx
[replacement with sqlx] LanguageCountClient
2 parents b6ce0bb + 80c71a2 commit cca1f6d

6 files changed

Lines changed: 232 additions & 0 deletions

File tree

atcoder-problems-backend/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

atcoder-problems-backend/sql-client/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ serde = { version = "1.0", features = ["derive"] }
1414
uuid = { version = "0.8", features = ["serde", "v4"] }
1515
anyhow = "1.0.32"
1616
async-std = { version = "1.6", features = ["attributes"] }
17+
regex = "1"
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
use crate::models::{Submission, UserLanguageCount};
2+
use crate::{PgPool, MAX_INSERT_ROWS};
3+
use anyhow::Result;
4+
use async_trait::async_trait;
5+
use regex::Regex;
6+
use sqlx::postgres::PgRow;
7+
use sqlx::Row;
8+
use std::collections::{BTreeMap, BTreeSet};
9+
10+
#[async_trait]
11+
pub trait LanguageCountClient {
12+
async fn update_language_count(&self, submissions: &[Submission]) -> Result<()>;
13+
async fn load_language_count(&self) -> Result<Vec<UserLanguageCount>>;
14+
}
15+
16+
#[async_trait]
17+
impl LanguageCountClient for PgPool {
18+
async fn update_language_count(&self, submissions: &[Submission]) -> Result<()> {
19+
let language_count = submissions
20+
.iter()
21+
.map(|s| {
22+
(
23+
s.user_id.as_str(),
24+
s.problem_id.as_str(),
25+
s.language.as_str(),
26+
)
27+
})
28+
.fold(
29+
BTreeMap::new(),
30+
|mut map, (user_id, problem_id, language)| {
31+
let simplified_language = simplify_language(&language);
32+
map.entry((user_id, simplified_language))
33+
.or_insert_with(BTreeSet::new)
34+
.insert(problem_id);
35+
map
36+
},
37+
)
38+
.into_iter()
39+
.map(|((user_id, language), set)| (user_id, language, set.len() as i32))
40+
.collect::<Vec<_>>();
41+
42+
for chunk in language_count.chunks(MAX_INSERT_ROWS) {
43+
let (user_ids, languages, counts) = chunk.iter().fold(
44+
(vec![], vec![], vec![]),
45+
|(mut user_ids, mut languages, mut counts), cur| {
46+
user_ids.push(cur.0);
47+
languages.push(cur.1.as_str());
48+
counts.push(cur.2);
49+
(user_ids, languages, counts)
50+
},
51+
);
52+
53+
sqlx::query(
54+
r"
55+
INSERT INTO language_count (user_id, simplified_language, problem_count)
56+
VALUES (
57+
UNNEST($1::VARCHAR(255)[]),
58+
UNNEST($2::VARCHAR(255)[]),
59+
UNNEST($3::INTEGER[])
60+
)
61+
ON CONFLICT (user_id, simplified_language)
62+
DO UPDATE SET problem_count = EXCLUDED.problem_count
63+
",
64+
)
65+
.bind(user_ids)
66+
.bind(languages)
67+
.bind(counts)
68+
.execute(self)
69+
.await?;
70+
}
71+
Ok(())
72+
}
73+
74+
async fn load_language_count(&self) -> Result<Vec<UserLanguageCount>> {
75+
let count = sqlx::query(
76+
r"
77+
SELECT
78+
user_id,
79+
simplified_language,
80+
problem_count
81+
FROM language_count
82+
ORDER BY user_id
83+
",
84+
)
85+
.try_map(|row: PgRow| {
86+
let user_id: String = row.try_get("user_id")?;
87+
let simplified_language: String = row.try_get("simplified_language")?;
88+
let problem_count: i32 = row.try_get("problem_count")?;
89+
Ok(UserLanguageCount {
90+
user_id,
91+
simplified_language,
92+
problem_count,
93+
})
94+
})
95+
.fetch_all(self)
96+
.await?;
97+
Ok(count)
98+
}
99+
}
100+
101+
fn simplify_language(lang: &str) -> String {
102+
let re = Regex::new(r"\d*\s*\(.*\)").unwrap();
103+
if lang.starts_with("Perl6") {
104+
"Perl6".to_string()
105+
} else {
106+
re.replace(lang, "").to_string()
107+
}
108+
}
109+
110+
#[cfg(test)]
mod tests {
    use super::simplify_language;

    /// Checks `simplify_language` against a table of (raw label, expected name).
    #[test]
    fn test_simplify_language() {
        let cases = [
            ("language1", "language1"),
            ("Perl (5)", "Perl"),
            ("Perl6", "Perl6"),
            ("Fortran(GNU Fortran 9.2.1)", "Fortran"),
            ("Ada2012 (GNAT 9.2.1)", "Ada"),
            ("PyPy2 (7.3.0)", "PyPy"),
            ("Haxe (4.0.3); js", "Haxe; js"),
        ];

        for (raw, expected) in cases.iter() {
            assert_eq!(simplify_language(raw), *expected, "input: {:?}", raw);
        }
    }
}

atcoder-problems-backend/sql-client/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::time::Duration;
44
pub mod accepted_count;
55
pub mod contest_problem;
66
pub mod internal;
7+
pub mod language_count;
78
pub mod models;
89

910
pub type PgPool = sqlx::postgres::PgPool;

atcoder-problems-backend/sql-client/src/models.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,17 @@ pub struct Submission {
1414
pub execution_time: Option<i32>,
1515
}
1616

17+
#[derive(Debug, Eq, PartialEq, Serialize)]
18+
pub struct UserLanguageCount {
19+
pub user_id: String,
20+
21+
#[serde(rename = "language")]
22+
pub simplified_language: String,
23+
24+
#[serde(rename = "count")]
25+
pub problem_count: i32,
26+
}
27+
1728
#[derive(Debug, Eq, PartialEq, Serialize)]
1829
pub struct UserProblemCount {
1930
pub user_id: String,
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
use sql_client::language_count::LanguageCountClient;
2+
use sql_client::models::{Submission, UserLanguageCount};
3+
4+
mod utils;
5+
6+
#[async_std::test]
7+
async fn test_language_count() {
8+
let pool = utils::initialize_and_connect_to_test_sql().await;
9+
let submissions = [
10+
Submission {
11+
id: 1,
12+
problem_id: "problem1".to_owned(),
13+
user_id: "user1".to_owned(),
14+
language: "language1".to_owned(),
15+
..Default::default()
16+
},
17+
Submission {
18+
id: 2,
19+
problem_id: "problem2".to_owned(),
20+
user_id: "user1".to_owned(),
21+
language: "language1".to_owned(),
22+
..Default::default()
23+
},
24+
Submission {
25+
id: 3,
26+
problem_id: "problem1".to_owned(),
27+
user_id: "user1".to_owned(),
28+
language: "language1".to_owned(),
29+
..Default::default()
30+
},
31+
Submission {
32+
id: 4,
33+
problem_id: "problem1".to_owned(),
34+
user_id: "user1".to_owned(),
35+
language: "language2".to_owned(),
36+
..Default::default()
37+
},
38+
Submission {
39+
id: 5,
40+
problem_id: "problem1".to_owned(),
41+
user_id: "user2".to_owned(),
42+
language: "language1".to_owned(),
43+
..Default::default()
44+
},
45+
Submission {
46+
id: 6,
47+
problem_id: "problem1".to_owned(),
48+
user_id: "user3".to_owned(),
49+
language: "Perl (5)".to_owned(),
50+
..Default::default()
51+
},
52+
Submission {
53+
id: 7,
54+
problem_id: "problem1".to_owned(),
55+
user_id: "user3".to_owned(),
56+
language: "Perl6".to_owned(),
57+
..Default::default()
58+
},
59+
];
60+
pool.update_language_count(&submissions).await.unwrap();
61+
62+
let language_count = pool.load_language_count().await.unwrap();
63+
assert_eq!(
64+
language_count,
65+
vec![
66+
UserLanguageCount {
67+
user_id: "user1".to_owned(),
68+
simplified_language: "language1".to_owned(),
69+
problem_count: 2
70+
},
71+
UserLanguageCount {
72+
user_id: "user1".to_owned(),
73+
simplified_language: "language2".to_owned(),
74+
problem_count: 1
75+
},
76+
UserLanguageCount {
77+
user_id: "user2".to_owned(),
78+
simplified_language: "language1".to_owned(),
79+
problem_count: 1
80+
},
81+
UserLanguageCount {
82+
user_id: "user3".to_owned(),
83+
simplified_language: "Perl".to_owned(),
84+
problem_count: 1
85+
},
86+
UserLanguageCount {
87+
user_id: "user3".to_owned(),
88+
simplified_language: "Perl6".to_owned(),
89+
problem_count: 1
90+
}
91+
]
92+
);
93+
}
94+

0 commit comments

Comments
 (0)