Skip to content

Commit da629a0

Browse files
committed
Implement keyword suggestion routine
`suggestion.rs` is almost a direct port of the implementations in [this](python/cpython#16856) and [this](python/cpython#25397). Signed-off-by: snowapril <[email protected]>
1 parent 7f6e016 commit da629a0

File tree

4 files changed

+170
-1
lines changed

4 files changed

+170
-1
lines changed

common/src/str.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
use ascii::AsciiString;
22
use once_cell::unsync::OnceCell;
3+
use std::cell::RefCell;
34
use std::fmt;
45
use std::ops::{Bound, RangeBounds};
6+
use std::thread_local;
57

68
#[cfg(not(target_arch = "wasm32"))]
79
#[allow(non_camel_case_types)]
@@ -98,6 +100,97 @@ pub const fn bytes_is_ascii(x: &str) -> bool {
98100
true
99101
}
100102

103+
/// Cost of replacing byte `a` with byte `b` in the weighted edit distance.
///
/// Returns `0` for identical bytes, `1` when the bytes differ only by ASCII
/// case, and `2` (the same price as an insertion or deletion) otherwise.
fn substitution_cost(a: u8, b: u8) -> usize {
    const CASE_COST: usize = 1;
    const MOVE_COST: usize = 2;

    if a == b {
        0
    } else if a.eq_ignore_ascii_case(&b) {
        CASE_COST
    } else {
        MOVE_COST
    }
}

/// Weighted Levenshtein distance between `a` and `b`, compared byte-wise.
///
/// Insertions, deletions and substitutions cost `2`; a substitution that only
/// flips ASCII case costs `1`.
///
/// `max_cost` is a budget that lets the computation stop early. Whenever the
/// distance is known to exceed it, or cannot be computed because either string
/// still has more than 40 bytes after common prefixes and suffixes are
/// stripped, the function returns `max_cost + 1`. That value saturates at
/// `usize::MAX` instead of overflowing.
///
/// The exact distance is returned without consulting `max_cost` when the
/// strings are equal or when one of them is empty after trimming.
pub fn levenshtein_distance(a: &str, b: &str, max_cost: usize) -> usize {
    const MOVE_COST: usize = 2;
    const MAX_STRING_SIZE: usize = 40;
    thread_local! {
        static BUFFER: RefCell<[usize; MAX_STRING_SIZE]> = RefCell::new([0usize; MAX_STRING_SIZE]);
    }

    if a == b {
        return 0;
    }

    // Sentinel meaning "more than `max_cost`". A plain `max_cost + 1` would
    // panic (debug) or wrap to 0 (release) for `max_cost == usize::MAX`.
    let over_budget = max_cost.saturating_add(1);

    let (mut a_bytes, mut b_bytes) = (a.as_bytes(), b.as_bytes());

    // Trim away the common prefix, then the common suffix of what is left.
    let prefix = a_bytes
        .iter()
        .zip(b_bytes)
        .take_while(|(x, y)| x == y)
        .count();
    a_bytes = &a_bytes[prefix..];
    b_bytes = &b_bytes[prefix..];

    let suffix = a_bytes
        .iter()
        .rev()
        .zip(b_bytes.iter().rev())
        .take_while(|(x, y)| x == y)
        .count();
    a_bytes = &a_bytes[..a_bytes.len() - suffix];
    b_bytes = &b_bytes[..b_bytes.len() - suffix];

    if a_bytes.is_empty() || b_bytes.is_empty() {
        // Only insertions (or only deletions) remain.
        return (a_bytes.len() + b_bytes.len()) * MOVE_COST;
    }
    if a_bytes.len() > MAX_STRING_SIZE || b_bytes.len() > MAX_STRING_SIZE {
        return over_budget;
    }

    // Keep the shorter string in `a_bytes` so the DP row is as small as possible.
    if b_bytes.len() < a_bytes.len() {
        std::mem::swap(&mut a_bytes, &mut b_bytes);
    }

    // Quick fail: the length difference alone already exceeds the budget.
    if (b_bytes.len() - a_bytes.len()) * MOVE_COST > max_cost {
        return over_budget;
    }

    BUFFER.with(|buffer| {
        let mut buffer = buffer.borrow_mut();
        // Instead of the full len(a)-by-len(b) matrix, a single row is updated
        // in place. `row[i]` starts as the cost of building `a[..=i]` from "".
        let row = &mut buffer[..a_bytes.len()];
        for (i, cell) in row.iter_mut().enumerate() {
            *cell = (i + 1) * MOVE_COST;
        }

        let mut result = 0usize;
        for (b_index, &b_code) in b_bytes.iter().enumerate() {
            // Cost of turning `b[..b_index]` into the empty prefix of `a`.
            result = b_index * MOVE_COST;
            let mut distance = result;
            let mut minimum = usize::MAX;
            for (cell, &a_code) in row.iter_mut().zip(a_bytes) {
                // `distance` holds the previous row's diagonal neighbour.
                let substitute = distance + substitution_cost(b_code, a_code);
                // `*cell` still holds the previous row's value for this column.
                distance = *cell;
                // `result` holds this row's value for the previous column.
                let insert_delete = result.min(distance) + MOVE_COST;
                result = insert_delete.min(substitute);

                *cell = result;
                minimum = minimum.min(result);
            }
            if minimum > max_cost {
                // Every entry in this row is over budget, so bail out early.
                return over_budget;
            }
        }
        result
    })
}
193+
101194
#[macro_export]
102195
macro_rules! ascii {
103196
($x:literal) => {{

vm/src/builtins/code.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ impl PyRef<PyCode> {
256256
}
257257

258258
#[pyproperty]
259-
fn co_varnames(self, vm: &VirtualMachine) -> PyTupleRef {
259+
pub fn co_varnames(self, vm: &VirtualMachine) -> PyTupleRef {
260260
let varnames = self
261261
.code
262262
.varnames

vm/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ mod sequence;
7070
mod signal;
7171
pub mod sliceable;
7272
pub mod stdlib;
73+
pub mod suggestion;
7374
pub mod types;
7475
pub mod utils;
7576
pub mod version;

vm/src/suggestion.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
use crate::{
2+
builtins::{PyStr, PyStrRef},
3+
exceptions::types::PyBaseExceptionRef,
4+
sliceable::PySliceableSequence,
5+
IdProtocol, PyObjectRef, TypeProtocol, VirtualMachine,
6+
};
7+
use rustpython_common::str::levenshtein_distance;
8+
use std::iter::ExactSizeIterator;
9+
10+
/// Candidate-count cutoff: `calculate_suggestions` returns `None` without
/// comparing anything when it is handed this many candidates or more.
const MAX_CANDIDATE_ITEMS: usize = 750;
11+
12+
fn calculate_suggestions<'a>(
13+
dir_iter: impl ExactSizeIterator<Item = &'a PyObjectRef>,
14+
name: &PyObjectRef,
15+
) -> Option<PyStrRef> {
16+
if dir_iter.len() >= MAX_CANDIDATE_ITEMS {
17+
return None;
18+
}
19+
20+
let mut suggestion: Option<&PyStrRef> = None;
21+
let mut suggestion_distance = usize::MAX;
22+
let name = name.downcast_ref::<PyStr>()?;
23+
24+
for item in dir_iter {
25+
let item_name = item.downcast_ref::<PyStr>()?;
26+
if name.as_str() == item_name.as_str() {
27+
continue;
28+
}
29+
// No more than 1/3 of the characters should need changed
30+
let max_distance = usize::min(
31+
(name.len() + item_name.len() + 3) / 3,
32+
suggestion_distance - 1,
33+
);
34+
let current_distance =
35+
levenshtein_distance(name.as_str(), item_name.as_str(), max_distance);
36+
if current_distance > max_distance {
37+
continue;
38+
}
39+
if suggestion.is_none() || current_distance < suggestion_distance {
40+
suggestion = Some(item_name);
41+
suggestion_distance = current_distance;
42+
}
43+
}
44+
suggestion.cloned()
45+
}
46+
47+
pub fn offer_suggestions(exc: &PyBaseExceptionRef, vm: &VirtualMachine) -> Option<PyStrRef> {
48+
if exc.class().is(&vm.ctx.exceptions.attribute_error) {
49+
let name = exc.as_object().clone().get_attr("name", vm).unwrap();
50+
let obj = exc.as_object().clone().get_attr("obj", vm).unwrap();
51+
52+
calculate_suggestions(vm.dir(Some(obj)).ok()?.borrow_vec().iter(), &name)
53+
} else if exc.class().is(&vm.ctx.exceptions.name_error) {
54+
let name = exc.as_object().clone().get_attr("name", vm).unwrap();
55+
let mut tb = exc.traceback().unwrap();
56+
while let Some(traceback) = tb.next.clone() {
57+
tb = traceback;
58+
}
59+
60+
let varnames = tb.frame.code.clone().co_varnames(vm);
61+
if let Some(suggestions) = calculate_suggestions(varnames.as_slice().iter(), &name) {
62+
return Some(suggestions);
63+
};
64+
65+
let globals = vm.extract_elements(tb.frame.globals.as_object()).ok()?;
66+
if let Some(suggestions) = calculate_suggestions(globals.as_slice().iter(), &name) {
67+
return Some(suggestions);
68+
};
69+
70+
let builtins = vm.extract_elements(tb.frame.builtins.as_object()).ok()?;
71+
calculate_suggestions(builtins.as_slice().iter(), &name)
72+
} else {
73+
None
74+
}
75+
}

0 commit comments

Comments
 (0)