Commit 1b0a4376 authored by erodrigu's avatar erodrigu
Browse files

Moved code from main into functions

parent da30cebf
......@@ -7,15 +7,14 @@ from AutoRPE.UtilsRPE.SourceManager import preprocess_sources, parse_sources, lo
from AutoRPE.UtilsRPE.Cleaner import remove_comments
from AutoRPE.UtilsRPE.Finder import find_declaration_line
from argparse import ArgumentParser
from os import mkdir
from os.path import join, isdir
from shutil import rmtree
import re
import itertools
def get_command_line_arguments():
from argparse import ArgumentParser
help_text = """
This script takes as arguments the input source (-s/--input-source)
......@@ -64,47 +63,18 @@ def replace_subroutine_name(old_name, new_name, code_text):
return pat.sub(replacement, code_text)
def xfind_declaration_line(variable, routine_lines):
import re
import warnings
# Declaration pattern
# TODO: This might have issues with split lines
pattern = r"::.*\b%s\b" % variable.name
exception_pattern = r"&*\b%s\b" % variable.name
real_pattern = r"real.*::"
current_block = CurrentBlock(vault[variable.procedure.module.name].main())
for _index, _line in enumerate(routine_lines):
# _line = remove_comments(_line).strip()
if re.search(pattern, _line):
if current_block.procedure.name == variable.procedure.name:
return _index
else:
current_block = current_block.in_which_block(_line, vault=vault)
print("Declaration of variable %s with id %s not found in module %s" % (variable.name, variable.id, variable.procedure.module.name))
return None
def obtain_routine_source(routine, sources_path):
source_file = join(sources_path, "%s.f90" % routine.module.name)
indices = []
with open(source_file) as f:
# Routine name pattern
pattern = r"routine.*\b%s\b" % routine.name
lines = [l for l in f]
indices = []
for index, line in enumerate(lines):
line = remove_comments(line).strip()
pattern = r"routine.*\b%s\b" % routine.name
lines = [l for l in f]
# Search for the start/end of the routine
if re.search(pattern, line, re.I):
indices.append(index)
routine_code = "".join(lines[indices[0]:indices[1] + 1])
for index, line in enumerate(lines):
line = remove_comments(line).strip()
if re.search(pattern, line, re.I):
indices.append(index)
routine_code = "".join(lines[indices[0]:indices[1]+1])
routine_code = merge_declaration_lines(routine_code)
return routine_code, indices
......@@ -116,11 +86,7 @@ def replace_routine_name(old_name, new_name, code_text):
replacement = r"\1 %s" % new_name
return pat.sub(replacement, code_text)
if __name__ == "__main__":
# Get command line arguments
vault_path, path_to_input_sources, routine_name = get_command_line_arguments()
def prepare_vault(vault_path=None):
path_to_preprocessed_sources = "./tmp_processedsources/"
# In case no vault has been provided it will create it from the sources in the input sources folder
if vault_path is None:
......@@ -128,63 +94,69 @@ if __name__ == "__main__":
# Path_to_input_sources
if not isdir(path_to_preprocessed_sources):
mkdir(path_to_preprocessed_sources)
# Format text + replace real declarations with type(rpe_var)
preprocess_sources(path_to_input_sources, path_to_preprocessed_sources)
# Obtain source information
vault = parse_sources(path_to_preprocessed_sources, save_vault=True)
# Remove generated folder
rmtree(path_to_preprocessed_sources)
else:
vault = load_vault(vault_path)
routine = None
for procedure in vault.procedures:
if procedure.name == routine_name:
routine = procedure
break
if routine is None:
print("Routine %s not found in vault." % routine_name)
exit(0)
return vault
routine_code, indices = obtain_routine_source(routine, path_to_input_sources)
routine_lines = routine_code.split("\n")
sensitive_variables = [v for v in routine.variables if v.is_dummy_argument and v.type == "double"]
def get_routine(vault, routine_name=None):
# Find procedure with the routine name
routine = [p for p in vault.procedures if p.name == routine_name]
# Check if we found it
if len(routine) == 0:
print("Routine %s not found in vault." % routine_name)
exit(0)
# All possible combinations of precisions
precision_combinations = list(itertools.product(["dp", "sp"], repeat=len(sensitive_variables)))
return routine[0]
# Get filepath
filepath = join(path_to_input_sources, "%s.f90" % routine.module.name)
def update_line(version_lines, variable, filepath, vault, routines_indices, variable_precision):
# Find the declaration/original lines
d_lines = find_declaration_line(variable, filepath, vault)
# Search in declaration lines
for d_line in d_lines:
# Update line if it's between routine lines routine
if d_line >= routines_indices[0] and d_line <= routines_indices[1]:
original_line = version_lines[d_line - routines_indices[0]]
new_line = replace_real_declaration(variable, original_line, variable_precision)
version_lines[d_line - routines_indices[0]] = new_line
def generate_routine_versions(precision_combinations, sensitive_variables, vault, routine_lines, routine, filepath, indices):
code_versions = {}
interface = "module procedure "
for c_index, combination in enumerate(precision_combinations):
version_lines = routine_lines[:]
new_routine_name = "%s_%04i" % (routine.name, c_index)
# Search and update real declaration
for index, variable in enumerate(sensitive_variables):
variable_precision = combination[index]
# Find the declaration/original lines
d_lines = find_declaration_line(variable, filepath, vault)
for d_line in d_lines:
if d_line >= indices[0] and d_line <= indices[1]:
original_line = version_lines[d_line - indices[0]]
new_line = replace_real_declaration(variable, original_line, variable_precision)
version_lines[d_line - indices[0]] = new_line
version_text = "\n".join(version_lines).strip()
update_line(routine_lines, variable, filepath, vault, indices, combination[index])
# Join the routine lines and uodate nane
version_text = "\n".join(routine_lines).strip()
version_text = replace_routine_name(routine.name, new_routine_name, version_text)
code_versions[c_index] = version_text
if c_index == 0:
interface += "%s" % new_routine_name
else:
interface += ", %s" % new_routine_name
return code_versions
if __name__ == "__main__":
# Get command line arguments
vault_path, path_to_input_sources, routine_name = get_command_line_arguments()
# Get the vault
vault = prepare_vault(vault_path)
# Get routine from vault
routine = get_routine(vault, routine_name)
# Get the routine code
routine_code, indices = obtain_routine_source(routine, path_to_input_sources)
# All possible combinations of precisions x sensitive variables
sensitive_variables = [v for v in routine.variables if v.is_dummy_argument and v.type == "double"]
precision_combinations = list(itertools.product(["dp", "sp"], repeat=len(sensitive_variables)))
# Get filepath
filepath = join(path_to_input_sources, "%s.f90" % routine.module.name)
# Generate the different routine versions
code_versions = generate_routine_versions(precision_combinations, sensitive_variables, vault, routine_code.split("\n"), routine, filepath, indices)
interface = "interface %s\n %s\nend interface %s" % (routine.name, interface, routine.name)
for key, code in code_versions.items():
print(code)
# Deleting temporary folders
if isdir(path_to_preprocessed_sources):
rmtree(path_to_preprocessed_sources)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment