# SPDX-FileCopyrightText: 2017 Scott Shawcroft, written for Adafruit Industries
# SPDX-FileCopyrightText: Copyright (c) 2020 Melissa LeBlanc-Williams for Adafruit Industries
#
# SPDX-License-Identifier: MIT
"""
`adafruit_shell`
================================================================================

Python helper for running Shell scripts in Python


* Author(s): Melissa LeBlanc-Williams

Implementation Notes
--------------------

**Software and Dependencies:**

* Linux

"""

# imports
import sys
import os
import shutil
import subprocess
import fcntl
import platform
import fileinput
import re
from datetime import datetime
from clint.textui import colored, prompt
import adafruit_platformdetect

__version__ = "0.0.0-auto.0"
__repo__ = "https://github.com/adafruit/Adafruit_Python_Shell.git"


# pylint: disable=too-many-public-methods
class Shell:
    """
    Class to help with converting Shell scripts over to Python. Having all
    the functions in one place makes updates easier and code shorter.
    """

    def __init__(self):
        # Optional label printed in color in front of info/warn/error messages
        self._group = None
        # Directories remembered by pushd() and restored by popd()
        self._dirstack = []

    @staticmethod
    def select_n(message, selections):
        """
        Display a numbered list of selections for the user to choose from

        :param message: Text passed to ``prompt.options`` as the question
        :param selections: Iterable of labels, one per choice
        :return: The result of ``prompt.options``; every option is configured
                 with its 1-based position as the return value
        """
        options = [
            {"selector": str(number), "prompt": selection, "return": number}
            for number, selection in enumerate(selections, start=1)
        ]
        return prompt.options(message, options)

    def run_command(self, cmd, suppress_message=False, return_output=False):
        """
        Run a shell command and show the output as it runs

        :param cmd: Command line handed to the shell (``shell=True``)
        :param suppress_message: When True, do not echo stdout/stderr
        :param return_output: When True, return the captured stdout text
                              instead of a success flag
        :return: The captured stdout if ``return_output`` is set, otherwise
                 True when the exit status is 0 and False when it is not
        """

        def read_stream(stream):
            # Put the pipe in non-blocking mode so reading one stream cannot
            # stall while the child is writing to the other one.
            file_descriptor = stream.fileno()
            file_flags = fcntl.fcntl(file_descriptor, fcntl.F_GETFL)
            fcntl.fcntl(file_descriptor, fcntl.F_SETFL, file_flags | os.O_NONBLOCK)
            try:
                # Normalise a possible None to "" so callers can concatenate
                return stream.read() or ""
            except TypeError:
                # NOTE(review): presumably raised by the text layer when a
                # non-blocking read has nothing available yet
                return ""

        def pump(proc):
            # Echo whatever is currently available and hand back the stdout part
            err = read_stream(proc.stderr)
            if err != "" and not suppress_message:
                self.error(err.strip(), end="\n\r")
            output = read_stream(proc.stdout)
            if output != "" and not suppress_message:
                self.info(output.strip(), end="\n\r")
            return output

        chunks = []
        with subprocess.Popen(
            cmd,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        ) as proc:
            while proc.poll() is None:
                chunks.append(pump(proc))
            # The process may exit with data still buffered in the pipes
            # (always the case for very short commands); read it once more
            # so no output is lost.
            chunks.append(pump(proc))
            return_code = proc.poll()
        full_output = "".join(chunks)
        if return_output:
            return full_output
        return not return_code

    def info(self, message, **kwargs):
        """
        Display a message with the group in green

        Extra keyword arguments are forwarded to ``print``.
        """
        if self._group is not None:
            print(colored.green(self._group) + " " + message, **kwargs)
        else:
            print(message, **kwargs)

    def warn(self, message, **kwargs):
        """
        Display a message with the group in yellow

        Extra keyword arguments are forwarded to ``print``.
        """
        if self._group is not None:
            print(colored.yellow(self._group) + " " + message, **kwargs)
        else:
            print(message, **kwargs)

    def bail(self, message=None, **kwargs):
        """
        Exit and display an error message if given

        Always terminates the interpreter with exit status 1.
        """
        if message is None:
            self.error("Exiting due to error", **kwargs)
        else:
            self.error(f"Exiting due to error: {message}", **kwargs)
        sys.exit(1)

    def error(self, message, **kwargs):
        """
        Display a message with the group in red

        Extra keyword arguments are forwarded to ``print``.
        """
        if self._group is not None:
            print(colored.red(self._group) + " " + message, **kwargs)
        else:
            print(message, **kwargs)

    @staticmethod
    def print_colored(message, color):
        """
        Print out a message in a specific color

        Nothing is printed when ``color`` is not one of the supported names.
        """
        colors = ("red", "green", "yellow", "blue", "black", "magenta", "cyan", "white")
        if color in colors:
            colorize = getattr(colored, color)
            print(colorize(message))

    def prompt(self, message, *, default=None, force_arg=None, force_arg_value=True):
        """
        A Yes/No prompt that accepts optional defaults
        Returns True for Yes and False for No

        :param message: Question shown to the user
        :param default: "y" or "n"; used when the user just presses enter.
                        Any other non-None value is treated as "y"
        :param force_arg: Command line argument name; when it was supplied
                          the question is skipped
        :param force_arg_value: Value returned when ``force_arg`` was supplied
        """
        if force_arg is not None and self.argument_exists(force_arg):
            return force_arg_value
        if default is None:
            choicebox = "[y/n]"
        else:
            if default not in ["y", "n"]:
                default = "y"
            choicebox = "[Y/n]" if default == "y" else "[y/N]"
        # Keep asking until the reply can be understood
        while True:
            reply = input(message + " " + choicebox + " ").strip()
            if reply == "" and default is not None:
                return default == "y"
            if re.match("y(?:es)?", reply, re.I):
                return True
            if re.match("n(?:o)?", reply, re.I):
                return False

    @staticmethod
    def clear():
        """
        Clear the screen
        """
        os.system("clear")

    @staticmethod
    def reboot():
        """
        Reboot the system
        """
        os.system("reboot")

    @staticmethod
    def getcwd():
        """
        Get the Current Working Directory
        """
        return os.getcwd()

    def chdir(self, directory):
        """
        Change directory

        :raises ValueError: if the location is missing or is not a directory
        """
        directory = self.path(directory)
        if not self.exists(directory):
            raise ValueError("Directory does not exist")
        if not self.isdir(directory):
            raise ValueError("Given location is not a directory")
        os.chdir(directory)

    def pushd(self, directory):
        """
        Remember the current directory on the stack and change to the given one
        """
        previous = self.getcwd()
        self.chdir(directory)
        # Only record the old location once the change actually succeeded so
        # a failed pushd does not leave a stale entry for popd
        self._dirstack.append(previous)

    def popd(self):
        """
        Change back to the directory most recently remembered by pushd

        :raises RuntimeError: if no directory has been pushed
        """
        if not self._dirstack:
            raise RuntimeError("Directory stack empty")
        self.chdir(self._dirstack.pop())

    @staticmethod
    def path(file_path):
        """
        Return the path with a leading ~ expanded to the user's home directory
        """
        return os.path.expanduser(file_path)

    @staticmethod
    def home_dir():
        """
        Return the User's home directory
        """
        return os.path.expanduser("~")

    @staticmethod
    def is_root():
        """
        Return whether the current user is logged in as root or has super user access
        """
        return os.geteuid() == 0

    @staticmethod
    def script():
        """
        Return the name of the script that is running
        """
        return sys.argv[0]

    def grep(self, search_term, location):
        """
        Run the grep command and return the result

        Returns True when grep exits with status 0, otherwise False.
        """
        location = self.path(location)
        # NOTE(review): both values are interpolated into a shell command line
        # unquoted, so callers must quote terms containing spaces or shell
        # metacharacters themselves.
        return self.run_command(f"grep {search_term} {location}", suppress_message=True)

    @staticmethod
    def date():
        """
        Return a string containing the current date and time
        """
        return datetime.now().ctime()

    def reconfig(self, file, pattern, replacement):
        """
        Given a filename, a regex pattern to match and a replacement string,
        perform replacement if found, else append replacement to end of file.
        """
        if self.isdir(file):
            return
        if self.pattern_search(file, pattern):
            # Pattern found; replace in file
            self.pattern_replace(file, pattern, replacement)
        else:
            # Not found; append (silently)
            self.write_text_file(file, replacement, append=True)

    def pattern_search(self, location, pattern, multi_line=False, return_match=False):
        """
        Similar to grep, but uses pure python
        multi_line will search the entire file as a large text glob,
        but certain regex patterns such as ^ and $ will not work on a
        line-by-line basis

        :param location: File to search (~ is expanded)
        :param pattern: Regular expression to look for
        :param multi_line: Search the whole file content at once
        :param return_match: Return the match object (or None) instead of a bool
        :return: True/False if found, or the match object when
                 ``return_match`` is set
        """
        location = self.path(location)
        # Start from None so a missing or empty file yields "no match"
        match = None
        if self.exists(location) and not self.isdir(location):
            if multi_line:
                # Searching only needs read access
                with open(location, "r", encoding="utf-8") as file:
                    match = re.search(pattern, file.read(), flags=re.DOTALL)
            else:
                regex = re.compile(pattern)
                # The context manager closes the file even when we stop early
                with fileinput.FileInput(location) as lines:
                    for line in lines:
                        match = regex.search(line)
                        if match:
                            break
        if return_match:
            return match
        return match is not None

    def pattern_replace(self, location, pattern, replace="", multi_line=False):
        """
        Similar to sed, but uses pure python
        multi_line will search the entire file as a large text glob,
        but certain regex patterns such as ^ and $ will not work on a
        line-by-line basis

        The file is left untouched when the pattern does not occur in it.
        """
        location = self.path(location)
        if not self.pattern_search(location, pattern, multi_line):
            return
        if multi_line:
            regex = re.compile(pattern, flags=re.DOTALL)
            with open(location, "r+", encoding="utf-8") as file:
                data = file.read()
                # Rewrite from the start and drop any leftover tail
                file.seek(0)
                file.write(regex.sub(replace, data))
                file.truncate()
        else:
            regex = re.compile(pattern)
            # With inplace=True everything printed is written back to the file;
            # the context manager restores stdout even if an error occurs
            with fileinput.FileInput(location, inplace=True) as lines:
                for line in lines:
                    # Lines without a match come back from sub() unchanged
                    print(regex.sub(replace, line), end="")

    def isdir(self, location):
        """
        Check if a location exists and is a directory
        """
        location = self.path(location)
        return os.path.exists(location) and os.path.isdir(location)

    def exists(self, location):
        """
        Check if a path or file exists
        """
        location = self.path(location)
        return os.path.exists(location)

    def move(self, source, destination):
        """
        Move a file or directory from source to destination

        Does nothing when the source does not exist.
        """
        source = self.path(source)
        destination = self.path(destination)
        if os.path.exists(source):
            # Moving a file into a directory keeps the file's own name
            if not os.path.isdir(source) and os.path.isdir(destination):
                destination += os.sep + os.path.basename(source)
            shutil.move(source, destination)

    def copy(self, source, destination):
        """
        Copy a file or directory from source to destination

        Does nothing when the source does not exist.
        """
        source = self.path(source)
        destination = self.path(destination)
        if os.path.exists(source):
            if os.path.isdir(source):
                shutil.copytree(source, destination)
            else:
                # Copying a file into a directory keeps the file's own name
                if os.path.isdir(destination):
                    destination += os.sep + os.path.basename(source)
                shutil.copy(source, destination)

    def chmod(self, location, mode):
        """
        Change the permissions of a file or directory

        Does nothing when the location does not exist.
        """
        location = self.path(location)
        if os.path.exists(location):
            os.chmod(location, mode)

    def chown(self, location, user, group=None, recursive=False):
        """
        Change the owner of a file or directory

        :param location: File or directory to change (~ is expanded)
        :param user: New owner
        :param group: New group; defaults to the same value as ``user``
        :param recursive: Also change everything below a directory
        """
        if group is None:
            group = user

        location = self.path(location)
        # The location itself is always changed, including the top directory
        # of a recursive request
        shutil.chown(location, user, group)
        if recursive and os.path.isdir(location):
            for root, dirs, files in os.walk(location):
                for name in dirs + files:
                    shutil.chown(os.path.join(root, name), user, group)

    def remove(self, location):
        """
        Remove a file or directory if it exists
        """
        location = self.path(location)
        if os.path.exists(location):
            if os.path.isdir(location):
                shutil.rmtree(location)
            else:
                os.remove(location)

    def require_root(self):
        """
        Check if the current user has root access and exit if not.
        """
        if not self.is_root():
            print("Installer must be run as root.")
            print(f"Try 'sudo python3 {self.script()}'")
            sys.exit(1)

    def write_text_file(self, path, content, append=True):
        """
        Write the contents to a file at the specified path

        When appending, a newline is written in front of the content.
        """
        if append:
            mode = "a"
            content = "\n" + content
        else:
            mode = "w"
        with open(self.path(path), mode, encoding="utf-8") as service_file:
            service_file.write(content)

    @staticmethod
    def is_python3():
        "Check if we are running Python 3 or later"
        # Compare the numeric major version rather than the first character
        # of the version string
        return sys.version_info.major >= 3

    @staticmethod
    def is_linux():
        """
        Check that we are running linux

        Note that a ``platform.system()`` of "Darwin" is accepted as well.
        """
        return platform.system() in ("Linux", "Darwin")

    @staticmethod
    def is_armhf():
        """
        Check if Platform.machine() (same as uname -m) returns an ARM platform that
        supports hardware floating point
        """
        return bool(re.match("armv.l", platform.machine()))

    @staticmethod
    def is_armv6():
        """
        Check if Platform.machine() returns ARM v6
        """
        return platform.machine() == "armv6l"

    @staticmethod
    def is_armv7():
        """
        Check if Platform.machine() returns ARM v7
        """
        return platform.machine() == "armv7l"

    @staticmethod
    def is_armv8():
        """
        Check if Platform.machine() returns ARM v8
        """
        return platform.machine() == "armv8l"

    @staticmethod
    def is_arm64():
        """
        Check if Platform.machine() returns ARM 64
        """
        return platform.machine() == "aarch64"

    @staticmethod
    def get_arch():
        """Return a string containing the architecture"""
        return platform.machine()

    # pylint: disable=invalid-name
    def get_os(self):
        """Return a string containing the release which we can use to compare in the script"""
        os_releases = (
            "Raspbian",
            "Debian",
            "Kano",
            "Mate",
            "PiTop",
            "Ubuntu",
            "Darwin",
            "Kali",
        )
        release = None
        # Read the release file once; a missing file is treated as empty so
        # the checks below simply do not match
        release_file = ""
        if os.path.exists("/etc/os-release"):
            with open("/etc/os-release", encoding="utf-8") as f:
                release_file = f.read()
        if "Raspbian" in release_file:
            release = "Raspbian"
        if self.run_command("command -v apt-get", suppress_message=True):
            # Later entries in os_releases win when several names appear
            for opsys in os_releases:
                if opsys in release_file:
                    release = opsys
        if release == "Debian" and os.path.exists("/etc/rpi-issue"):
            release = "Raspbian"
        if os.path.isdir(os.path.expanduser("~/.kano-settings")) or os.path.isdir(
            os.path.expanduser("~/.kanoprofile")
        ):
            release = "Kano"
        if os.path.isdir(os.path.expanduser("~/.config/ubuntu-mate")):
            release = "Mate"
        if platform.system() == "Darwin":
            release = "Darwin"
        return release

    def get_raspbian_version(self):
        """
        Return a string containing the raspbian version

        Returns None when the OS is not Raspbian or no known release name is
        found in /etc/os-release, and "unstable" when the file mentions "/sid".
        """
        if self.get_os() != "Raspbian":
            return None
        raspbian_releases = ("bullseye", "buster", "stretch", "jessie", "wheezy")
        if os.path.exists("/etc/os-release"):
            with open("/etc/os-release", encoding="utf-8") as f:
                release_file = f.read()
            if "/sid" in release_file:
                return "unstable"
            for raspbian in raspbian_releases:
                if raspbian in release_file:
                    return raspbian
        return None

    def prompt_reboot(self, default="y", **kwargs):
        """
        Prompt the user for a reboot

        Extra keyword arguments are forwarded to :meth:`prompt`.
        """
        if not self.prompt("REBOOT NOW?", default=default, **kwargs):
            print("Exiting without reboot.")
        else:
            print("Reboot started...")
            # Flush pending writes to disk before asking the system to reboot
            os.sync()
            self.reboot()
        self.exit()

    def check_kernel_update_reboot_required(self):
        """Checks if the pi needs to be rebooted since the last kernel update"""
        # A missing modules directory for the running release is taken to
        # mean the kernel was updated without a reboot
        if not self.exists(f"/lib/modules/{self.release()}"):
            self.error(
                "OS has not been rebooted since last kernel update. "
                "Please reboot and re-run the script."
            )
            self.prompt_reboot()

    # pylint: enable=invalid-name

    @staticmethod
    def is_raspberry_pi():
        """
        Use PlatformDetect to check if this is a Raspberry Pi
        """
        detector = adafruit_platformdetect.Detector()
        return detector.board.any_raspberry_pi

    @staticmethod
    def get_board_model():
        """
        Use PlatformDetect to get the board model
        """
        detector = adafruit_platformdetect.Detector()
        return detector.board.id

    @staticmethod
    def get_architecture():
        """
        Get the type of Processor
        """
        return platform.machine()

    @staticmethod
    def kernel_minimum(version):
        """
        Check that we are running on at least the specified version

        The leading dotted numbers of both versions are compared numerically,
        so a running "5.10.63-v7l+" satisfies a minimum of "5.4". If either
        value has no leading number the plain string comparison is used.
        """

        def numeric_parts(text):
            # "5.10.63-v7l+" -> [5, 10, 63]; [] when there is no leading number
            prefix = re.match(r"\d+(?:\.\d+)*", str(text).strip())
            if prefix is None:
                return []
            return [int(part) for part in prefix.group(0).split(".")]

        running = numeric_parts(platform.release())
        required = numeric_parts(version)
        if not running or not required:
            return platform.release() >= str(version)
        # Pad the shorter one with zeros so "5.10" and "5.10.0" compare equal
        width = max(len(running), len(required))
        running += [0] * (width - len(running))
        required += [0] * (width - len(required))
        return running >= required

    @staticmethod
    def release():
        """
        Return the kernel release reported by platform.release()
        """
        return platform.release()

    def argument_exists(self, arg, prefix="-"):
        """
        Check if the given argument was supplied

        :param arg: Argument name without its prefix
        :param prefix: Text expected in front of the name on the command line
        """
        return prefix + arg in self.args

    @staticmethod
    def exit(status_code=0):
        """
        Exit and return the status code to the OS
        """
        sys.exit(status_code)

    @property
    def group(self):
        """
        Get or set the current group that is displayed in color along with messages

        Setting it to None removes the group again.
        """
        return self._group

    @group.setter
    def group(self, value):
        # Keep None as None so the "no group" state can be restored; anything
        # else is shown as text
        self._group = None if value is None else str(value)

    @property
    def args(self):
        """
        Get a list of supplied arguments
        """
        return sys.argv