|  |  | @ -1,4 +1,3 @@ | 
			
		
	
		
		
			
				
					
					|  |  |  | from binascii import Error |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  | from copy import deepcopy |  |  |  | from copy import deepcopy | 
			
		
	
		
		
			
				
					
					|  |  |  | from dataclasses import dataclass |  |  |  | from dataclasses import dataclass | 
			
		
	
		
		
			
				
					
					|  |  |  | from enum import Enum, auto |  |  |  | from enum import Enum, auto | 
			
		
	
	
		
		
			
				
					|  |  | @ -90,6 +89,9 @@ class projection(ast_node): | 
			
		
	
		
		
			
				
					
					|  |  |  |         elif 'select_distinct' in node: |  |  |  |         elif 'select_distinct' in node: | 
			
		
	
		
		
			
				
					
					|  |  |  |             p = node['select_distinct'] |  |  |  |             p = node['select_distinct'] | 
			
		
	
		
		
			
				
					
					|  |  |  |             self.distinct = True |  |  |  |             self.distinct = True | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         else: | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             raise NotImplementedError('AST node is not a projection node') | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |          | 
			
		
	
		
		
			
				
					
					|  |  |  |         if 'with' in node: |  |  |  |         if 'with' in node: | 
			
		
	
		
		
			
				
					
					|  |  |  |             with_table = node['with']['name'] |  |  |  |             with_table = node['with']['name'] | 
			
		
	
		
		
			
				
					
					|  |  |  |             with_table_name = tuple(with_table.keys())[0] |  |  |  |             with_table_name = tuple(with_table.keys())[0] | 
			
		
	
	
		
		
			
				
					|  |  | @ -946,10 +948,41 @@ class filter(ast_node): | 
			
		
	
		
		
			
				
					
					|  |  |  |         self.add(filter_expr.sql) |  |  |  |         self.add(filter_expr.sql) | 
			
		
	
		
		
			
				
					
					|  |  |  |         if self.datasource is not None: |  |  |  |         if self.datasource is not None: | 
			
		
	
		
		
			
				
					
					|  |  |  |             self.datasource.join_conditions += filter_expr.join_conditions |  |  |  |             self.datasource.join_conditions += filter_expr.join_conditions | 
			
		
	
		
		
			
				
					
					|  |  |  |          |  |  |  | 
 | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | class union_all(ast_node): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     name = 'union_all' | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     first_order = name | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     sql_name = 'UNION ALL' | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     def produce(self, node): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         queries = node[self.name] | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         generated_queries : List[Optional[projection]] = [None] * len(queries) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         is_standard = True | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         for i, q in enumerate(queries): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             if 'select' in q: | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                 generated_queries[i] = projection(self, q) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                 is_standard &= not generated_queries[i].has_postproc | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         if is_standard: | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             self.sql = f' {self.sql_name} '.join([q.sql for q in generated_queries]) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         else: | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |             raise NotImplementedError(f"{self.sql_name} only support standard sql for now") | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     def consume(self, node): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         super().consume(node) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |         self.context.direct_output() | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | class except_clause(union_all): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     name = 'except' | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     first_order = name | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     sql_name = 'EXCEPT' | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | class create_table(ast_node): |  |  |  | class create_table(ast_node): | 
			
		
	
		
		
			
				
					
					|  |  |  |     name = 'create_table' |  |  |  |     name = 'create_table' | 
			
		
	
		
		
			
				
					
					|  |  |  |     first_order = name |  |  |  |     first_order = name | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     allowed_subq = { | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                         'select_distinct': select_distinct,  | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                         'select': projection,  | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                         'union_all': union_all,  | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                         'except': except_clause | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                     } | 
			
		
	
		
		
			
				
					
					|  |  |  |     def init(self, node): |  |  |  |     def init(self, node): | 
			
		
	
		
		
			
				
					
					|  |  |  |         node = node[self.name] |  |  |  |         node = node[self.name] | 
			
		
	
		
		
			
				
					
					|  |  |  |         if 'query' in node: |  |  |  |         if 'query' in node: | 
			
		
	
	
		
		
			
				
					|  |  | @ -957,9 +990,11 @@ class create_table(ast_node): | 
			
		
	
		
		
			
				
					
					|  |  |  |                 raise ValueError("Table name not specified") |  |  |  |                 raise ValueError("Table name not specified") | 
			
		
	
		
		
			
				
					
					|  |  |  |             projection_node = node['query'] |  |  |  |             projection_node = node['query'] | 
			
		
	
		
		
			
				
					
					|  |  |  |             projection_node['into'] = node['name'] |  |  |  |             projection_node['into'] = node['name'] | 
			
		
	
		
		
			
				
					
					|  |  |  |             proj_cls = (select_distinct  |  |  |  |             proj_cls = projection | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                         if 'select_distinct' in projection_node  |  |  |  |             for k in create_table.allowed_subq.keys(): | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                         else projection) |  |  |  |                 if k in projection_node: | 
			
				
				
			
		
	
		
		
	
		
		
	
		
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                     proj_cls = create_table.allowed_subq[k] | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                     break | 
			
		
	
		
		
			
				
					
					|  |  |  |             proj_cls(None, projection_node, self.context) |  |  |  |             proj_cls(None, projection_node, self.context) | 
			
		
	
		
		
			
				
					
					|  |  |  |             self.produce = lambda *_: None |  |  |  |             self.produce = lambda *_: None | 
			
		
	
		
		
			
				
					
					|  |  |  |             self.spawn = lambda *_: None |  |  |  |             self.spawn = lambda *_: None | 
			
		
	
	
		
		
			
				
					|  |  | @ -1073,31 +1108,6 @@ class delete_from(ast_node): | 
			
		
	
		
		
			
				
					
					|  |  |  |         self.sql = f'DELETE FROM {tbl} ' |  |  |  |         self.sql = f'DELETE FROM {tbl} ' | 
			
		
	
		
		
			
				
					
					|  |  |  |         if 'where' in node: |  |  |  |         if 'where' in node: | 
			
		
	
		
		
			
				
					
					|  |  |  |             self.sql += filter(self, node['where']).sql |  |  |  |             self.sql += filter(self, node['where']).sql | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  | class union_all(ast_node): |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |     name = 'union_all' |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |     first_order = name |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |     sql_name = 'UNION ALL' |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |     def produce(self, node): |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |         queries = node[self.name] |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |         generated_queries : List[Optional[projection]] = [None] * len(queries) |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |         is_standard = True |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |         for i, q in enumerate(queries): |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |             if 'select' in q: |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |                 generated_queries[i] = projection(self, q) |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |                 is_standard &= not generated_queries[i].has_postproc |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |         if is_standard: |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |             self.sql = f' {self.sql_name} '.join([q.sql for q in generated_queries]) |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |         else: |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |             raise NotImplementedError(f"{self.sql_name} only support standard sql for now") |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |     def consume(self, node): |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |         super().consume(node) |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |         self.context.direct_output() |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  | class except_clause(union_all): |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |     name = 'except' |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |     first_order = name |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |     sql_name = 'EXCEPT' |  |  |  |  | 
			
		
	
		
		
			
				
					
					|  |  |  |      |  |  |  |      | 
			
		
	
		
		
			
				
					
					|  |  |  | class load(ast_node): |  |  |  | class load(ast_node): | 
			
		
	
		
		
			
				
					
					|  |  |  |     name="load" |  |  |  |     name="load" | 
			
		
	
	
		
		
			
				
					|  |  | 
 |