]>
Commit | Line | Data |
---|---|---|
53e6db90 DC |
1 | """Tests for elpy.rpc.""" |
2 | ||
3 | import json | |
4 | import unittest | |
5 | import sys | |
6 | ||
7 | from elpy import rpc | |
8 | from elpy.tests.compat import StringIO | |
9 | ||
10 | ||
class TestFault(unittest.TestCase):
    """Check the message, code and data carried by ``rpc.Fault``."""

    def test_should_have_code_and_data(self):
        # Explicit code/data keyword arguments are stored verbatim.
        err = rpc.Fault("Hello", code=250, data="Fnord")
        self.assertEqual(str(err), "Hello")
        self.assertEqual(err.code, 250)
        self.assertEqual(err.data, "Fnord")

    def test_should_have_defaults_for_code_and_data(self):
        # Without keyword arguments, code falls back to 500 and data to None.
        err = rpc.Fault("Hello")
        self.assertEqual(str(err), "Hello")
        self.assertEqual(err.code, 500)
        self.assertIsNone(err.data)
23 | ||
24 | ||
class TestJSONRPCServer(unittest.TestCase):
    """Base fixture: a JSONRPCServer wired to in-memory stdin/stdout buffers."""

    def setUp(self):
        self.stdin = StringIO()
        self.stdout = StringIO()
        self.rpc = rpc.JSONRPCServer(self.stdin, self.stdout)

    def _clear_streams(self):
        """Empty both buffers and rewind them to position 0."""
        for stream in (self.stdin, self.stdout):
            stream.seek(0)
            stream.truncate()

    def write(self, s):
        """Make ``s`` the only pending stdin input and clear stdout.

        stdin is rewound after writing so the server reads ``s`` from the start.
        """
        self._clear_streams()
        self.stdin.write(s)
        self.stdin.seek(0)

    def read(self):
        """Return everything written to stdout, then empty both buffers."""
        captured = self.stdout.getvalue()
        self._clear_streams()
        return captured
46 | ||
47 | ||
class TestInit(TestJSONRPCServer):
    """Check which streams JSONRPCServer keeps after construction."""

    def test_should_use_arguments(self):
        # Streams passed to the constructor are exposed unchanged.
        server = self.rpc
        self.assertEqual(server.stdin, self.stdin)
        self.assertEqual(server.stdout, self.stdout)

    def test_should_default_to_sys(self):
        # With no arguments, the process-wide standard streams are used.
        server = rpc.JSONRPCServer()
        self.assertEqual(sys.stdin, server.stdin)
        self.assertEqual(sys.stdout, server.stdout)
57 | ||
58 | ||
class TestReadJson(TestJSONRPCServer):
    """Tests for JSONRPCServer.read_json, which decodes one JSON value per line."""

    def test_should_read_json(self):
        # Embedded newlines/CRLFs are escaped by json.dumps, so each value
        # still occupies exactly one input line.
        messages = [
            {'foo': 'bar'},
            {'baz': 'qux', 'fnord': 'argl\nbargl'},
            "beep\r\nbeep\r\nbeep",
        ]
        payload = "".join(json.dumps(msg) + "\n" for msg in messages)
        self.write(payload)
        for expected in messages:
            self.assertEqual(self.rpc.read_json(), expected)

    def test_should_raise_eof_on_eof(self):
        # stdin is empty right after setUp.
        with self.assertRaises(EOFError):
            self.rpc.read_json()

    def test_should_fail_on_malformed_json(self):
        self.write("malformed json\n")
        with self.assertRaises(ValueError):
            self.rpc.read_json()
77 | ||
78 | ||
class TestWriteJson(TestJSONRPCServer):
    """Tests for JSONRPCServer.write_json, which serializes its keyword arguments."""

    def test_should_write_json_line(self):
        # Each call's stdout output must decode back to the kwargs dict;
        # self.read() clears stdout between iterations.
        for payload in ({'foo': 'bar'},
                        {'baz': 'qux', 'fnord': 'argl\nbargl'}):
            self.rpc.write_json(**payload)
            written = self.read()
            self.assertEqual(json.loads(written), payload)
88 | ||
89 | ||
class TestHandleRequest(TestJSONRPCServer):
    """Tests for JSONRPCServer.handle_request: dispatch and error replies."""

    def test_should_fail_if_json_does_not_contain_a_method(self):
        self.write(json.dumps(dict(params=[],
                                   id=23)))
        self.assertRaises(ValueError,
                          self.rpc.handle_request)

    def test_should_call_right_method(self):
        # A request for "foo" is dispatched to the rpc_foo attribute, with the
        # params list spread as positional arguments.
        self.write(json.dumps(dict(method='foo',
                                   params=[1, 2, 3],
                                   id=23)))
        self.rpc.rpc_foo = lambda *params: params
        self.rpc.handle_request()
        self.assertEqual(json.loads(self.read()),
                         dict(id=23,
                              result=[1, 2, 3]))

    def test_should_pass_defaults_for_missing_parameters(self):
        # With neither "params" nor "id", the method receives no arguments
        # and nothing is written back.
        def test_method(*params):
            self.args = params

        self.write(json.dumps(dict(method='foo')))
        self.rpc.rpc_foo = test_method
        self.rpc.handle_request()
        self.assertEqual(self.args, ())
        self.assertEqual(self.read(), "")

    def test_should_return_error_for_missing_method(self):
        # No rpc_foo attribute is installed on the server here.
        self.write(json.dumps(dict(method='foo',
                                   id=23)))
        self.rpc.handle_request()
        result = json.loads(self.read())

        self.assertEqual(result["id"], 23)
        self.assertEqual(result["error"]["message"],
                         "Unknown method foo")

    def test_should_return_error_for_exception_in_method(self):
        def test_method():
            raise ValueError("An error was raised")

        self.write(json.dumps(dict(method='foo',
                                   id=23)))
        self.rpc.rpc_foo = test_method

        self.rpc.handle_request()
        result = json.loads(self.read())

        self.assertEqual(result["id"], 23)
        self.assertEqual(result["error"]["message"], "An error was raised")
        # Unexpected exceptions report their traceback under error["data"].
        self.assertIn("traceback", result["error"]["data"])

    def test_should_not_include_traceback_for_faults(self):
        def test_method():
            raise rpc.Fault("This is a fault")

        self.write(json.dumps(dict(method="foo",
                                   id=23)))
        self.rpc.rpc_foo = test_method

        self.rpc.handle_request()
        result = json.loads(self.read())

        self.assertEqual(result["id"], 23)
        self.assertEqual(result["error"]["message"], "This is a fault")
        self.assertNotIn("traceback", result["error"])
        # Tracebacks are reported under error["data"] (see
        # test_should_return_error_for_exception_in_method), so a leaked one
        # would appear there rather than at the top level checked above.
        # The "data" key may be missing entirely for a data-less Fault.
        self.assertNotIn("traceback", result["error"].get("data") or {})

    def test_should_add_data_for_faults(self):
        def test_method():
            raise rpc.Fault("St. Andreas' Fault",
                            code=12345, data="Yippieh")

        self.write(json.dumps(dict(method="foo", id=23)))
        self.rpc.rpc_foo = test_method

        self.rpc.handle_request()
        result = json.loads(self.read())

        self.assertEqual(result["error"]["data"], "Yippieh")

    def test_should_call_handle_for_unknown_method(self):
        # Overriding `handle` gives a fallback for methods without an
        # rpc_<name> attribute; its return value becomes the result.
        def test_handle(method_name, args):
            return "It works"
        self.write(json.dumps(dict(method="doesnotexist",
                                   id=23)))
        self.rpc.handle = test_handle
        self.rpc.handle_request()
        self.assertEqual(json.loads(self.read()),
                         dict(id=23,
                              result="It works"))
180 | ||
181 | ||
class TestServeForever(TestJSONRPCServer):
    """Tests for the JSONRPCServer.serve_forever loop.

    The server's handle_request is replaced by a fake that counts its calls
    and raises ``self.error`` once the count exceeds 10.
    """

    def handle_request(self):
        self.hr_called += 1
        if self.hr_called <= 10:
            return
        raise self.error()

    def setUp(self):
        super(TestServeForever, self).setUp()
        self.hr_called = 0
        self.error = KeyboardInterrupt
        self.rpc.handle_request = self.handle_request

    def test_should_call_handle_request_repeatedly(self):
        self.rpc.serve_forever()
        self.assertEqual(self.hr_called, 11)

    def test_should_return_on_some_errors(self):
        # The counter is not reset, so after the first run every later
        # serve_forever call raises on its very first request.
        for stop_error in (KeyboardInterrupt, EOFError, SystemExit):
            self.error = stop_error
            self.rpc.serve_forever()

    def test_should_fail_on_most_errors(self):
        self.error = RuntimeError
        with self.assertRaises(RuntimeError):
            self.rpc.serve_forever()