95 lines
2.8 KiB
Python
95 lines
2.8 KiB
Python
|
from typing import TYPE_CHECKING
|
||
|
|
||
|
from aides_spec.replacers.base import BaseReplacer
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from tree_sitter import Node
|
||
|
|
||
|
|
||
|
class LocalSourcesReplacer(BaseReplacer):
    """Move local (non-URL) entries out of ``sources`` into ``prepare()``.

    Every entry of a ``sources=`` assignment whose text does not contain
    ``://`` is treated as a file that lives next to the build script. Such
    entries are removed from ``sources`` and, for each of them, a
    ``cp "${scriptdir}/<file>" "${srcdir}"`` command is added either at the
    end of the existing ``prepare`` function or in a newly appended one.
    """

    def process(self):
        """Collect the edits for local sources and apply them.

        Side effects: resets and fills ``self.local_files`` (local entries in
        source order) and ``self.prepare_func_body`` (body node of the last
        ``prepare`` function found, or ``None``), and appends edit dicts to
        ``self.replaces`` / ``self.appends``.

        Returns:
            Whatever ``self._apply_replacements()`` returns.
        """
        root_node = self.tree.root_node

        self.local_files = []
        self.prepare_func_body = None

        self._scan_sources_and_prepare(root_node)

        # Only touch prepare() when there is something to copy: a function
        # whose body holds no command is a bash syntax error.
        if self.local_files:
            self._add_local_copy_commands(root_node)

        return self._apply_replacements()

    def _scan_sources_and_prepare(self, root_node: "Node"):
        """Walk the tree in pre-order, handling ``prepare`` and ``sources``.

        An explicit stack is used instead of recursion so deeply nested
        scripts cannot hit the interpreter's recursion limit; children are
        pushed in reverse so nodes are visited in document order.
        """
        stack = [root_node]
        while stack:
            node = stack.pop()

            if node.type == "function_definition":
                name_node = node.child_by_field_name("name")
                if (
                    name_node is not None
                    and self._node_text(name_node) == "prepare"
                ):
                    # If several are present, the last one wins.
                    self.prepare_func_body = node.child_by_field_name("body")
            elif node.type == "variable_assignment":
                var_node = node.child_by_field_name("name")
                value_node = node.child_by_field_name("value")
                if (
                    var_node is not None
                    and value_node is not None
                    and self._node_text(var_node) == "sources"
                ):
                    self._remove_local_files(value_node)

            stack.extend(reversed(node.children))

    def _add_local_copy_commands(self, root_node: "Node"):
        """Queue the edit that copies ``self.local_files`` into ``$srcdir``.

        Extends the existing ``prepare`` body if one was found, otherwise
        appends a new ``prepare()`` function after ``root_node``.
        """
        copy_commands = "\n".join(
            f'    cp "${{scriptdir}}/{file}" "${{srcdir}}"'
            for file in self.local_files
        )

        if self.prepare_func_body is not None:
            text = self._node_text(self.prepare_func_body)
            # The body node ends with its closing delimiter ('}' for a
            # compound statement, ')' for a subshell); insert right before
            # it. Searching for '}' could instead match one from '${...}'.
            closing_index = len(text.rstrip()) - 1
            text = (
                text[:closing_index]
                + f"\n{copy_commands}\n"
                + text[closing_index:]
            )
            self.replaces.append(
                {
                    "node": self.prepare_func_body,
                    "content": text,
                }
            )
        else:
            self.appends.append(
                {
                    "node": root_node,
                    "content": f"prepare() {{\n{copy_commands}\n}}\n",
                }
            )

    def _remove_local_files(self, source_node: "Node"):
        """Drop local entries from a ``sources`` value and remember them.

        Args:
            source_node: value node of the ``sources=`` assignment. An
                ``array`` contributes each element; any other value is
                treated as a single entry.

        Entries whose text contains ``://`` are kept verbatim; all others are
        unquoted and appended to ``self.local_files``. A replacement that
        rewrites the value as an array of the kept entries is queued on
        ``self.replaces``.
        """
        if source_node.type == "array":
            items = source_node.children
        else:
            items = [source_node]

        kept_items = []
        for item_node in items:
            item_text = self._node_text(item_node)

            # Array punctuation and comments are not source entries.
            if item_node.type == "comment" or item_text in ("(", ")"):
                continue

            if "://" in item_text:
                kept_items.append(item_text)
            else:
                self.local_files.append(self._unquote(item_node, item_text))

        new_content = "(" + "".join(f"\n    {item}" for item in kept_items) + "\n)"
        self.replaces.append(
            {
                "node": source_node,
                "content": new_content,
            }
        )

    @staticmethod
    def _unquote(item_node: "Node", item_text: str) -> str:
        """Return ``item_text`` without its surrounding quotes, if quoted.

        Handles both ``"..."`` (``string``) and ``'...'`` (``raw_string``)
        nodes using the full node text, so multi-part strings such as
        ``"${pkgname}.patch"`` are kept whole. The result is later embedded
        inside double quotes, so any ``$`` in it will be expanded by the
        shell.
        """
        if item_node.type in ("string", "raw_string") and len(item_text) >= 2:
            return item_text[1:-1]
        return item_text
|