Skip to content

Commit 860785d

Browse files
committed
Implement keyword suggestion routine
`suggestion.rs` is almost a direct port of the implementations in [this](python/cpython#16856) and [this](python/cpython#25397). Signed-off-by: snowapril <[email protected]>
1 parent c63f973 commit 860785d

File tree

4 files changed

+170
-1
lines changed

4 files changed

+170
-1
lines changed

common/src/str.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use ascii::AsciiString;
2+
use std::cell::RefCell;
23
use std::ops::{Bound, RangeBounds};
4+
use std::thread_local;
35

46
#[cfg(not(target_arch = "wasm32"))]
57
#[allow(non_camel_case_types)]
@@ -96,6 +98,97 @@ pub const fn bytes_is_ascii(x: &str) -> bool {
9698
true
9799
}
98100

101+
/// Returns the cost of replacing byte `a` with byte `b`.
///
/// * `0` when the bytes are identical,
/// * `1` when they are the same ASCII letter in different case,
/// * `2` for any other pair (the same price as an insertion or deletion).
fn substitution_cost(a: u8, b: u8) -> usize {
    const CASE_COST: usize = 1;
    const MOVE_COST: usize = 2;

    if a == b {
        0
    } else if a.to_ascii_lowercase() == b.to_ascii_lowercase() {
        // Only ASCII `A..=Z` is folded, so non-letters never get the case discount.
        CASE_COST
    } else {
        MOVE_COST
    }
}
120+
121+
pub fn levenshtein_distance(a: &str, b: &str, max_cost: usize) -> usize {
122+
const MAX_STRING_SIZE: usize = 40;
123+
thread_local! {
124+
static BUFFER: RefCell<[usize; MAX_STRING_SIZE]> = RefCell::new([0usize; MAX_STRING_SIZE]);
125+
}
126+
127+
if a == b {
128+
return 0;
129+
}
130+
131+
let (mut a_bytes, mut b_bytes) = (a.as_bytes(), b.as_bytes());
132+
let (mut a_begin, mut a_end) = (0usize, a.len());
133+
let (mut b_begin, mut b_end) = (0usize, b.len());
134+
135+
while a_end > 0 && b_end > 0 && (a_bytes[a_begin] == b_bytes[b_begin]) {
136+
a_begin += 1;
137+
b_begin += 1;
138+
a_end -= 1;
139+
b_end -= 1;
140+
}
141+
while a_end > 0 && b_end > 0 && (a_bytes[a_begin + a_end - 1] == b_bytes[b_begin + b_end - 1]) {
142+
a_end -= 1;
143+
b_end -= 1;
144+
}
145+
if a_end == 0 || b_end == 0 {
146+
return (a_end + b_end) * 2;
147+
}
148+
if a_end > MAX_STRING_SIZE || b_end > MAX_STRING_SIZE {
149+
return max_cost + 1;
150+
}
151+
152+
if b_end < a_end {
153+
std::mem::swap(&mut a_bytes, &mut b_bytes);
154+
std::mem::swap(&mut a_begin, &mut b_begin);
155+
std::mem::swap(&mut a_end, &mut b_end);
156+
}
157+
158+
if (b_end - a_end) * 2 > max_cost {
159+
return max_cost + 1;
160+
}
161+
162+
BUFFER.with(|buffer| {
163+
let mut buffer = buffer.borrow_mut();
164+
for i in 0..a_end {
165+
buffer[i] = (i + 1) * 2;
166+
}
167+
168+
let mut result = 0usize;
169+
for (b_index, b_code) in b_bytes[b_begin..(b_begin + b_end)].iter().enumerate() {
170+
result = b_index * 2;
171+
let mut distance = result;
172+
let mut minimum = usize::MAX;
173+
for (a_index, a_code) in a_bytes[a_begin..(a_begin + a_end)].iter().enumerate() {
174+
let substitute = distance + substitution_cost(*b_code, *a_code);
175+
distance = buffer[a_index];
176+
let insert_delete = usize::min(result, distance) + 2;
177+
result = usize::min(insert_delete, substitute);
178+
179+
buffer[a_index] = result;
180+
if result < minimum {
181+
minimum = result;
182+
}
183+
}
184+
if minimum > max_cost {
185+
return max_cost + 1;
186+
}
187+
}
188+
result
189+
})
190+
}
191+
99192
#[macro_export]
100193
macro_rules! ascii {
101194
($x:literal) => {{

vm/src/builtins/code.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ impl PyRef<PyCode> {
256256
}
257257

258258
#[pyproperty]
259-
fn co_varnames(self, vm: &VirtualMachine) -> PyTupleRef {
259+
pub fn co_varnames(self, vm: &VirtualMachine) -> PyTupleRef {
260260
let varnames = self
261261
.code
262262
.varnames

vm/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ mod sequence;
7070
mod signal;
7171
pub mod sliceable;
7272
pub mod stdlib;
73+
pub mod suggestion;
7374
pub mod types;
7475
pub mod utils;
7576
pub mod version;

vm/src/suggestion.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
use crate::{
2+
builtins::{PyStr, PyStrRef},
3+
exceptions::types::PyBaseExceptionRef,
4+
sliceable::PySliceableSequence,
5+
IdProtocol, PyObjectRef, TypeProtocol, VirtualMachine,
6+
};
7+
use rustpython_common::str::levenshtein_distance;
8+
use std::iter::ExactSizeIterator;
9+
10+
/// Candidate-count cutoff: `calculate_suggestions` returns `None` without
/// comparing anything when it is handed this many names or more.
const MAX_CANDIDATE_ITEMS: usize = 750;
11+
12+
fn calculate_suggestions<'a>(
13+
dir_iter: impl ExactSizeIterator<Item = &'a PyObjectRef>,
14+
name: &PyObjectRef,
15+
) -> Option<PyStrRef> {
16+
if dir_iter.len() >= MAX_CANDIDATE_ITEMS {
17+
return None;
18+
}
19+
20+
let mut suggestion: Option<&PyStrRef> = None;
21+
let mut suggestion_distance = usize::MAX;
22+
let name = name.downcast_ref::<PyStr>()?;
23+
24+
for item in dir_iter {
25+
let item_name = item.downcast_ref::<PyStr>()?;
26+
if name.as_str() == item_name.as_str() {
27+
continue;
28+
}
29+
// No more than 1/3 of the characters should need changed
30+
let max_distance = usize::min(
31+
(name.len() + item_name.len() + 3) / 3,
32+
suggestion_distance - 1,
33+
);
34+
let current_distance =
35+
levenshtein_distance(name.as_str(), item_name.as_str(), max_distance);
36+
if current_distance > max_distance {
37+
continue;
38+
}
39+
if suggestion.is_none() || current_distance < suggestion_distance {
40+
suggestion = Some(item_name);
41+
suggestion_distance = current_distance;
42+
}
43+
}
44+
suggestion.cloned()
45+
}
46+
47+
pub fn offer_suggestions(exc: &PyBaseExceptionRef, vm: &VirtualMachine) -> Option<PyStrRef> {
48+
if exc.class().is(&vm.ctx.exceptions.attribute_error) {
49+
let name = exc.as_object().clone().get_attr("name", vm).unwrap();
50+
let obj = exc.as_object().clone().get_attr("obj", vm).unwrap();
51+
52+
calculate_suggestions(vm.dir(Some(obj)).ok()?.borrow_vec().iter(), &name)
53+
} else if exc.class().is(&vm.ctx.exceptions.name_error) {
54+
let name = exc.as_object().clone().get_attr("name", vm).unwrap();
55+
let mut tb = exc.traceback().unwrap();
56+
while let Some(traceback) = tb.next.clone() {
57+
tb = traceback;
58+
}
59+
60+
let varnames = tb.frame.code.clone().co_varnames(vm);
61+
if let Some(suggestions) = calculate_suggestions(varnames.as_slice().iter(), &name) {
62+
return Some(suggestions);
63+
};
64+
65+
let globals = vm.extract_elements(tb.frame.globals.as_object()).ok()?;
66+
if let Some(suggestions) = calculate_suggestions(globals.as_slice().iter(), &name) {
67+
return Some(suggestions);
68+
};
69+
70+
let builtins = vm.extract_elements(tb.frame.builtins.as_object()).ok()?;
71+
calculate_suggestions(builtins.as_slice().iter(), &name)
72+
} else {
73+
None
74+
}
75+
}

0 commit comments

Comments
 (0)