From df6e6dac673f6c7ba2c4931e89c50682b7d4bd93 Mon Sep 17 00:00:00 2001
From: "guorong.zheng" <360996299@qq.com>
Date: Thu, 28 Nov 2024 14:29:46 +0800
Subject: [PATCH] refactor: log and get_speed

---
 .github/workflows/main.yml | 12 +++-----
 main.py                    | 22 +++--------
 service/app.py             | 11 ++-----
 utils/channel.py           | 43 +++++++++++++---------
 utils/constants.py         |  8 ++++-
 utils/speed.py             | 80 +++++++++++++++++++++++---------------
 utils/tools.py             | 35 +++++++-------------
 7 files changed, 107 insertions(+), 104 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 297b717..d4ad49b 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -96,14 +96,12 @@ jobs:
           if [[ -f "$final_m3u_file" ]]; then
             git add -f "$final_m3u_file"
           fi
-          if [[ -f "output/result_cache.pkl" ]]; then
-            git add -f "output/result_cache.pkl"
-          fi
-          if [[ -f "output/user_result.log" ]]; then
-            git add -f "output/user_result.log"
-          elif [[ -f "output/result.log" ]]; then
-            git add -f "output/result.log"
+          if [[ -f "output/cache.pkl" ]]; then
+            git add -f "output/cache.pkl"
           fi
+          if [[ -f "output/sort.log" ]]; then
+            git add -f "output/sort.log"
+          fi
           if [[ -f "updates/fofa/fofa_hotel_region_result.pkl" ]]; then
             git add -f "updates/fofa/fofa_hotel_region_result.pkl"
           fi
diff --git a/main.py b/main.py
index 8d3ccc5..17edf95 100644
--- a/main.py
+++ b/main.py
@@ -1,5 +1,6 @@
 import asyncio
 from utils.config import config
+import utils.constants as constants
 from service.app import run_service
 from utils.channel import (
     get_channel_items,
@@ -18,8 +19,6 @@ from utils.tools import (
     format_interval,
     check_ipv6_support,
     resource_path,
-    setup_logging,
-    cleanup_logging,
 )
 from updates.subscribe import get_channels_by_subscribe_urls
 from updates.multicast import get_channels_by_multicast
@@ -34,9 +33,6 @@ import pickle
 import copy
 
 
-atexit.register(cleanup_logging)
-
-
 class UpdateSource:
     def __init__(self):
@@ -109,7 +105,6 @@ class UpdateSource:
     async def main(self):
         try:
             if config.open_update:
-                setup_logging()
                 main_start_time = time()
                 self.channel_items = get_channel_items()
                 channel_names = [
@@ -143,7 +138,7 @@ class UpdateSource:
                 )
                 self.start_time = time()
                 self.pbar = tqdm(total=self.total, desc="Sorting")
-                self.channel_data = process_sort_channel_list(
+                self.channel_data = await process_sort_channel_list(
                     self.channel_data,
                     ipv6=ipv6_support,
                     callback=sort_callback,
@@ -160,24 +155,17 @@ class UpdateSource:
                 )
                 self.pbar.close()
                 user_final_file = config.final_file
-                update_file(user_final_file, "output/result_new.txt")
+                update_file(user_final_file, constants.result_path)
                 if config.open_use_old_result:
                     if open_sort:
                         get_channel_data_cache_with_compare(
                             channel_data_cache, self.channel_data
                         )
                     with open(
-                        resource_path("output/result_cache.pkl", persistent=True), "wb"
+                        resource_path(constants.cache_path, persistent=True),
+                        "wb",
                     ) as file:
                         pickle.dump(channel_data_cache, file)
-                if open_sort:
-                    user_log_file = "output/" + (
-                        "user_result.log"
-                        if os.path.exists("config/user_config.ini")
-                        else "result.log"
-                    )
-                    update_file(user_log_file, "output/result_new.log", copy=True)
-                cleanup_logging()
                 convert_to_m3u()
                 total_time = format_interval(time() - main_start_time)
                 print(
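Note on the main.py hunks: process_sort_channel_list is now a coroutine, so UpdateSource.main must await it, and whatever entry point calls main has to drive it from an event loop. A minimal sketch of that calling pattern, separate from this patch (sort_channels and its payload are illustrative stand-ins, not names from this repository):

    import asyncio

    async def sort_channels(data):
        # stand-in for process_sort_channel_list: this is where the patch
        # schedules per-URL speed tests as asyncio tasks and gathers them
        await asyncio.sleep(0)
        return data

    class UpdateSource:
        async def main(self):
            # mirrors the new `await process_sort_channel_list(...)` call
            self.channel_data = await sort_channels({"CCTV-1": []})

    asyncio.run(UpdateSource().main())
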
diff --git a/service/app.py b/service/app.py
index 5f00dd6..13674d1 100644
--- a/service/app.py
+++ b/service/app.py
@@ -3,9 +3,8 @@ import sys
 
 sys.path.append(os.path.dirname(sys.path[0]))
 from flask import Flask, render_template_string
-from utils.tools import get_result_file_content, get_ip_address
+from utils.tools import get_result_file_content, get_ip_address, resource_path
 import utils.constants as constants
-from utils.config import config
 
 app = Flask(__name__)
@@ -32,11 +31,9 @@ def show_content():
 
 @app.route("/log")
 def show_log():
-    user_log_file = "output/" + (
-        "user_result.log" if os.path.exists("config/user_config.ini") else "result.log"
-    )
-    if os.path.exists(user_log_file):
-        with open(user_log_file, "r", encoding="utf-8") as file:
+    log_path = resource_path(constants.sort_log_path)
+    if os.path.exists(log_path):
+        with open(log_path, "r", encoding="utf-8") as file:
             content = file.read()
     else:
         content = constants.waiting_tip
diff --git a/utils/channel.py b/utils/channel.py
index 3d5575d..56502b1 100644
--- a/utils/channel.py
+++ b/utils/channel.py
@@ -8,6 +8,7 @@ from utils.tools import (
     remove_cache_info,
     resource_path,
     write_content_into_txt,
+    get_logger,
 )
 from utils.speed import (
     get_speed,
@@ -22,7 +23,8 @@ import base64
 import pickle
 import copy
 import datetime
-from concurrent.futures import ThreadPoolExecutor
+import asyncio
+from logging import INFO
 
 
 def get_name_url(content, pattern, multiline=False, check_url=True):
@@ -84,7 +86,7 @@ def get_channel_items():
     )
 
     if config.open_use_old_result:
-        result_cache_path = resource_path("output/result_cache.pkl")
+        result_cache_path = resource_path(constants.cache_path)
         if os.path.exists(result_cache_path):
             with open(result_cache_path, "rb") as file:
                 old_result = pickle.load(file)
@@ -543,7 +545,7 @@ def append_total_data(
     )
 
 
-def process_sort_channel_list(data, ipv6=False, callback=None):
+async def process_sort_channel_list(data, ipv6=False, callback=None):
     """
     Processs the sort channel list
     """
@@ -551,22 +553,29 @@ def process_sort_channel_list(data, ipv6=False, callback=None):
     need_sort_data = copy.deepcopy(data)
     process_nested_dict(need_sort_data, seen=set(), flag=r"cache:(.*)", force_str="!")
     result = {}
-    with ThreadPoolExecutor(max_workers=30) as executor:
-        try:
-            for channel_obj in need_sort_data.values():
-                for info_list in channel_obj.values():
-                    for info in info_list:
-                        executor.submit(
-                            get_speed,
-                            info[0],
-                            ipv6_proxy=ipv6_proxy,
-                            callback=callback,
-                        )
-        except Exception as e:
-            print(f"Get speed Error: {e}")
+    semaphore = asyncio.Semaphore(10)
+
+    async def limited_get_speed(info, ipv6_proxy, callback):
+        async with semaphore:
+            return await get_speed(info[0], ipv6_proxy=ipv6_proxy, callback=callback)
+
+    tasks = [
+        asyncio.create_task(
+            limited_get_speed(
+                info,
+                ipv6_proxy=ipv6_proxy,
+                callback=callback,
+            )
+        )
+        for channel_obj in need_sort_data.values()
+        for info_list in channel_obj.values()
+        for info in info_list
+    ]
+    await asyncio.gather(*tasks)
+    logger = get_logger(constants.sort_log_path, level=INFO, init=True)
     for cate, obj in data.items():
         for name, info_list in obj.items():
-            info_list = sort_urls_by_speed_and_resolution(name, info_list)
+            info_list = sort_urls_by_speed_and_resolution(name, info_list, logger)
             append_data_to_info_data(
                 result,
                 cate,
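Note on the utils/channel.py hunk above: the ThreadPoolExecutor(max_workers=30) is replaced by an asyncio.Semaphore(10). A task is created for every URL up front, but the semaphore caps how many speed tests run concurrently. The same pattern in isolation (check_url and the URL list are illustrative stand-ins for get_speed and the channel data):

    import asyncio

    async def check_url(url, semaphore):
        # at most 10 coroutines hold the semaphore at any moment
        async with semaphore:
            await asyncio.sleep(0.1)  # stand-in for the real speed test
            return url

    async def main():
        semaphore = asyncio.Semaphore(10)
        urls = [f"http://example.com/stream/{i}" for i in range(50)]
        tasks = [asyncio.create_task(check_url(url, semaphore)) for url in urls]
        results = await asyncio.gather(*tasks)
        print(f"checked {len(results)} urls")

    asyncio.run(main())

Unlike the old executor version, asyncio.gather propagates the first exception a task raises, so the speed test itself has to catch its own errors, as get_speed_yt_dlp does with its try/except.
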
r"((https?):\/\/)?(\[[0-9a-fA-F:]+\]|([\w-]+\.)+[\w-]+)(:[0-9]{1,5})?(\/[^\s]*)?(\$[^\s]+)?" diff --git a/utils/speed.py b/utils/speed.py index f203a0f..00d2103 100644 --- a/utils/speed.py +++ b/utils/speed.py @@ -3,39 +3,55 @@ from time import time import asyncio import re from utils.config import config -from utils.tools import is_ipv6, remove_cache_info, get_resolution_value +import utils.constants as constants +from utils.tools import is_ipv6, remove_cache_info, get_resolution_value, get_logger import subprocess import yt_dlp -import logging +from concurrent.futures import ProcessPoolExecutor +import functools + +logger = get_logger(constants.log_path) -def get_speed_yt_dlp(url, timeout=config.sort_timeout): +def get_info_yt_dlp(url, timeout=config.sort_timeout): + """ + Get the url info by yt_dlp + """ + ydl_opts = { + "socket_timeout": timeout, + "skip_download": True, + "quiet": True, + "no_warnings": True, + "format": "best", + "logger": logger, + } + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + return ydl.sanitize_info(ydl.extract_info(url, download=False)) + + +async def get_speed_yt_dlp(url, timeout=config.sort_timeout): """ Get the speed of the url by yt_dlp """ try: - ydl_opts = { - "socket_timeout": timeout, - "skip_download": True, - "quiet": True, - "no_warnings": True, - "format": "best", - "logger": logging.getLogger(), - } - with yt_dlp.YoutubeDL(ydl_opts) as ydl: + async with asyncio.timeout(timeout + 2): start_time = time() - info = ydl.extract_info(url, download=False) - fps = info.get("fps", None) or ( - int(round((time() - start_time) * 1000)) - if "id" in info - else float("inf") - ) - resolution = ( - f"{info['width']}x{info['height']}" - if "width" in info and "height" in info - else None - ) - return (fps, resolution) + loop = asyncio.get_running_loop() + with ProcessPoolExecutor() as exc: + info = await loop.run_in_executor( + exc, functools.partial(get_info_yt_dlp, url, timeout) + ) + fps = ( + int(round((time() - start_time) * 1000)) + if len(info) + else float("inf") + ) + resolution = ( + f"{info['width']}x{info['height']}" + if "width" in info and "height" in info + else None + ) + return (fps, resolution) except: return (float("inf"), None) @@ -146,7 +162,7 @@ async def check_stream_speed(url_info): speed_cache = {} -def get_speed(url, ipv6_proxy=None, callback=None): +async def get_speed(url, ipv6_proxy=None, callback=None): """ Get the speed of the url """ @@ -163,7 +179,7 @@ def get_speed(url, ipv6_proxy=None, callback=None): if ipv6_proxy and url_is_ipv6: speed = 0 else: - speed = get_speed_yt_dlp(url) + speed = await get_speed_yt_dlp(url) if cache_key and cache_key not in speed_cache: speed_cache[cache_key] = speed return speed @@ -174,7 +190,7 @@ def get_speed(url, ipv6_proxy=None, callback=None): callback() -def sort_urls_by_speed_and_resolution(name, data): +def sort_urls_by_speed_and_resolution(name, data, logger=None): """ Sort by speed and resolution """ @@ -192,9 +208,13 @@ def sort_urls_by_speed_and_resolution(name, data): resolution = cache_resolution or resolution if response_time != float("inf"): url = remove_cache_info(url) - logging.info( - f"Name: {name}, URL: {url}, Date: {date}, Resolution: {resolution}, Response Time: {response_time} ms" - ) + try: + if logger: + logger.info( + f"Name: {name}, URL: {url}, Date: {date}, Resolution: {resolution}, Response Time: {response_time} ms" + ) + except Exception as e: + print(e) filter_data.append((url, date, resolution, origin)) def combined_key(item): diff --git a/utils/tools.py b/utils/tools.py 
diff --git a/utils/tools.py b/utils/tools.py
index ee6c220..ce48f3b 100644
--- a/utils/tools.py
+++ b/utils/tools.py
@@ -15,35 +15,22 @@ import sys
 import logging
 from logging.handlers import RotatingFileHandler
 
-handler = None
-
 
-def setup_logging():
+def get_logger(path, level=logging.ERROR, init=False):
     """
-    Setup logging
+    Get the logger
     """
-    global handler
     if not os.path.exists(constants.output_dir):
         os.makedirs(constants.output_dir)
-    handler = RotatingFileHandler(constants.log_path, encoding="utf-8")
-    logging.basicConfig(
-        handlers=[handler],
-        format="%(message)s",
-        level=logging.INFO,
-    )
-
-
-def cleanup_logging():
-    """
-    Cleanup logging
-    """
-    global handler
-    if handler:
-        for handler in logging.root.handlers[:]:
-            handler.close()
-            logging.root.removeHandler(handler)
-    if os.path.exists(constants.log_path):
-        os.remove(constants.log_path)
+    logger = logging.getLogger(path)
+    for old_handler in logger.handlers[:]:
+        old_handler.close()
+        logger.removeHandler(old_handler)
+    if init and os.path.exists(path):
+        os.remove(path)
+    logger.addHandler(RotatingFileHandler(path, encoding="utf-8"))
+    logger.setLevel(level)
+    return logger
 
 
 def format_interval(t):
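Note on the utils/tools.py hunk above: get_logger keys each logger off its file path, so the yt-dlp errors (log.log) and the sort results (sort.log) go to separate files instead of the single root logger that setup_logging configured. Since logging.getLogger returns the same object for the same name, the function clears existing handlers before attaching a new one; otherwise repeated sort runs would stack handlers and duplicate every line. A minimal usage sketch (the path is illustrative and written to the current directory so the snippet runs anywhere):

    import logging
    from logging.handlers import RotatingFileHandler

    def get_logger(path, level=logging.ERROR):
        logger = logging.getLogger(path)  # same path -> same logger object
        for old_handler in logger.handlers[:]:
            old_handler.close()  # drop stale handlers before adding a new one
            logger.removeHandler(old_handler)
        logger.addHandler(RotatingFileHandler(path, encoding="utf-8"))
        logger.setLevel(level)
        return logger

    first = get_logger("sort.log", logging.INFO)
    second = get_logger("sort.log", logging.INFO)
    assert first is second and len(second.handlers) == 1
    second.info("Name: CCTV-1, Resolution: 1920x1080, Response Time: 120 ms")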