refactor:get_speed

guorong.zheng 2024-11-20 18:09:57 +08:00
parent b032d985c6
commit 483190831d
6 changed files with 110 additions and 191 deletions


@@ -6,8 +6,6 @@ from utils.channel import (
append_total_data,
process_sort_channel_list,
write_channel_to_file,
setup_logging,
cleanup_logging,
get_channel_data_cache_with_compare,
format_channel_url_info,
)
@@ -21,6 +19,8 @@ from utils.tools import (
format_interval,
check_ipv6_support,
resource_path,
setup_logging,
cleanup_logging,
)
from updates.subscribe import get_channels_by_subscribe_urls
from updates.multicast import get_channels_by_multicast


@@ -1,7 +1,5 @@
from asyncio import create_task, gather
from utils.config import config
import utils.constants as constants
from utils.speed import get_speed
from utils.channel import (
format_channel_name,
get_results_from_soup,


@@ -2,7 +2,7 @@ from asyncio import Semaphore
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio
from utils.config import config
from utils.speed import get_speed
from utils.speed import get_speed_requests
from concurrent.futures import ThreadPoolExecutor
from driver.utils import get_soup_driver
from requests_custom.utils import get_soup_requests, close_session
@@ -71,7 +71,7 @@ async def get_proxy_list_with_test(base_url, proxy_list):
async def get_speed_task(url, timeout, proxy):
async with semaphore:
return await get_speed(url, timeout=timeout, proxy=proxy)
return await get_speed_requests(url, timeout=timeout, proxy=proxy)
response_times = await tqdm_asyncio.gather(
*(get_speed_task(base_url, timeout=30, proxy=url) for url in proxy_list),


@@ -10,15 +10,13 @@ from utils.tools import (
write_content_into_txt,
)
from utils.speed import (
get_speed,
sort_urls_by_speed_and_resolution,
speed_cache,
)
import os
from collections import defaultdict
import re
from bs4 import NavigableString
import logging
from logging.handlers import RotatingFileHandler
from opencc import OpenCC
import base64
import pickle
@@ -26,36 +24,6 @@ import copy
import datetime
from concurrent.futures import ThreadPoolExecutor
handler = None
def setup_logging():
"""
Setup logging
"""
global handler
if not os.path.exists(constants.output_dir):
os.makedirs(constants.output_dir)
handler = RotatingFileHandler(constants.log_path, encoding="utf-8")
logging.basicConfig(
handlers=[handler],
format="%(message)s",
level=logging.INFO,
)
def cleanup_logging():
"""
Cleanup logging
"""
global handler
if handler:
for handler in logging.root.handlers[:]:
handler.close()
logging.root.removeHandler(handler)
if os.path.exists(constants.log_path):
os.remove(constants.log_path)
def get_name_url(content, pattern, multiline=False, check_url=True):
"""
@@ -462,9 +430,7 @@ def init_info_data(data, cate, name):
data[cate][name] = []
def append_data_to_info_data(
info_data, cate, name, data, origin=None, check=True, insert=False
):
def append_data_to_info_data(info_data, cate, name, data, origin=None, check=True):
"""
Append channel data to total info data
"""
@@ -485,14 +451,7 @@ def append_data_to_info_data(
or (not check)
or (check and check_url_by_patterns(pure_url))
):
if insert:
info_data[cate][name].insert(
0, (url, date, resolution, url_origin)
)
else:
info_data[cate][name].append(
(url, date, resolution, url_origin)
)
info_data[cate][name].append((url, date, resolution, url_origin))
urls.append(pure_url)
except:
continue
@@ -584,34 +543,6 @@ def append_total_data(
)
def sort_channel_list(
cate,
name,
info_list,
ipv6_proxy=None,
callback=None,
):
"""
Sort the channel list
"""
data = []
try:
if info_list:
sorted_data = sort_urls_by_speed_and_resolution(
info_list, ipv6_proxy=ipv6_proxy, callback=callback
)
if sorted_data:
for (url, date, resolution, origin), response_time in sorted_data:
logging.info(
f"Name: {name}, URL: {url}, Date: {date}, Resolution: {resolution}, Response Time: {response_time} ms"
)
data.append((url, date, resolution, origin))
except Exception as e:
logging.error(f"Error: {e}")
finally:
return {"cate": cate, "name": name, "data": data}
def process_sort_channel_list(data, ipv6=False, callback=None):
"""
Process the sort channel list
@@ -619,72 +550,31 @@ def process_sort_channel_list(data, ipv6=False, callback=None):
ipv6_proxy = None if (not config.open_ipv6 or ipv6) else constants.ipv6_proxy
need_sort_data = copy.deepcopy(data)
process_nested_dict(need_sort_data, seen=set(), flag=r"cache:(.*)", force_str="!")
sort_data = {}
result = {}
with ThreadPoolExecutor(max_workers=30) as executor:
futures = [
executor.submit(
sort_channel_list,
try:
for channel_obj in need_sort_data.values():
for info_list in channel_obj.values():
for info in info_list:
executor.submit(
get_speed,
info[0],
ipv6_proxy=ipv6_proxy,
callback=callback,
)
except Exception as e:
print(f"Get speed Error: {e}")
for cate, obj in data.items():
for name, info_list in obj.items():
info_list = sort_urls_by_speed_and_resolution(name, info_list)
append_data_to_info_data(
result,
cate,
name,
info_list,
ipv6_proxy=ipv6_proxy,
callback=callback,
check=False,
)
for cate, channel_obj in need_sort_data.items()
for name, info_list in channel_obj.items()
]
for future in futures:
result = future.result()
if result:
cate, name, result_data = result["cate"], result["name"], result["data"]
append_data_to_info_data(
sort_data, cate, name, result_data, check=False
)
for cate, obj in data.items():
for name, info_list in obj.items():
sort_info_list = sort_data.get(cate, {}).get(name, [])
sort_urls = {
remove_cache_info(sort_url[0])
for sort_url in sort_info_list
if sort_url and sort_url[0]
}
for url, date, resolution, origin in info_list:
if "$" in url:
info = url.partition("$")[2]
if info and info.startswith("!"):
append_data_to_info_data(
sort_data,
cate,
name,
[(url, date, resolution, origin)],
check=False,
insert=True,
)
continue
matcher = re.search(r"cache:(.*)", info)
if matcher:
cache_key = matcher.group(1)
if not cache_key:
continue
url = remove_cache_info(url)
if url in sort_urls or cache_key not in speed_cache:
continue
cache = speed_cache[cache_key]
if not cache:
continue
response_time, resolution = cache
if response_time and response_time != float("inf"):
append_data_to_info_data(
sort_data,
cate,
name,
[(url, date, resolution, origin)],
check=False,
)
logging.info(
f"Name: {name}, URL: {url}, Date: {date}, Resolution: {resolution}, Response Time: {response_time} ms"
)
return sort_data
return result
def write_channel_to_file(data, ipv6=False, callback=None):
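
Taken together with the utils/speed.py changes below, the sort step is now two passes: a ThreadPoolExecutor first calls the new synchronous get_speed on every URL (which populates speed_cache), then sort_urls_by_speed_and_resolution reads those cached results per channel. A condensed sketch of that flow, not the verbatim function body; sort_channels_sketch and the plain dict assignment (instead of append_data_to_info_data) are simplifications:

```python
# Condensed sketch of the two-pass flow; sort_channels_sketch is a hypothetical
# stand-in for process_sort_channel_list, and result-building is simplified.
from concurrent.futures import ThreadPoolExecutor
from utils.speed import get_speed, sort_urls_by_speed_and_resolution

def sort_channels_sketch(data, ipv6_proxy=None):
    result = {}
    # Pass 1: probe every URL once; get_speed stores results in utils.speed.speed_cache.
    # Leaving the executor's with-block waits for all submitted probes to finish.
    with ThreadPoolExecutor(max_workers=30) as executor:
        for channel_obj in data.values():
            for info_list in channel_obj.values():
                for info in info_list:
                    executor.submit(get_speed, info[0], ipv6_proxy=ipv6_proxy)
    # Pass 2: sort each channel's list from the cache and collect the survivors.
    for cate, obj in data.items():
        for name, info_list in obj.items():
            result.setdefault(cate, {})[name] = sort_urls_by_speed_and_resolution(name, info_list)
    return result
```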


@@ -3,9 +3,10 @@ from time import time
import asyncio
import re
from utils.config import config
from utils.tools import is_ipv6, add_url_info, remove_cache_info, get_resolution_value
from utils.tools import is_ipv6, remove_cache_info, get_resolution_value
import subprocess
import yt_dlp
import logging
def get_speed_yt_dlp(url, timeout=config.sort_timeout):
@@ -20,16 +21,16 @@ def get_speed_yt_dlp(url, timeout=config.sort_timeout):
"no_warnings": True,
}
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
start = time()
start_time = time()
info = ydl.extract_info(url, download=False)
return int(round((time() - start) * 1000)) if info else float("inf")
except Exception as e:
return int(round((time() - start_time) * 1000)) if info else float("inf")
except:
return float("inf")
async def get_speed(url, timeout=config.sort_timeout, proxy=None):
async def get_speed_requests(url, timeout=config.sort_timeout, proxy=None):
"""
Get the speed of the url
Get the speed of the url by requests
"""
async with ClientSession(
connector=TCPConnector(verify_ssl=False), trust_env=True
@@ -133,42 +134,26 @@ async def check_stream_speed(url_info):
speed_cache = {}
def get_speed_by_info(url_info, ipv6_proxy=None, callback=None):
def get_speed(url, ipv6_proxy=None, callback=None):
"""
Get the info with speed
Get the speed of the url
"""
url, _, resolution, _ = url_info
url_info = list(url_info)
cache_key = None
url_is_ipv6 = is_ipv6(url)
if "$" in url:
url, _, cache_info = url.partition("$")
matcher = re.search(r"cache:(.*)", cache_info)
if matcher:
cache_key = matcher.group(1)
url_show_info = remove_cache_info(cache_info)
url_info[0] = url
if cache_key in speed_cache:
speed = speed_cache[cache_key][0]
url_info[2] = speed_cache[cache_key][1]
if speed != float("inf"):
if url_show_info:
url_info[0] = add_url_info(url, url_show_info)
return (tuple(url_info), speed)
else:
return float("inf")
try:
cache_key = None
url_is_ipv6 = is_ipv6(url)
if "$" in url:
url, _, cache_info = url.partition("$")
matcher = re.search(r"cache:(.*)", cache_info)
if matcher:
cache_key = matcher.group(1)
if cache_key in speed_cache:
return speed_cache[cache_key][0]
if ipv6_proxy and url_is_ipv6:
url_speed = 0
speed = (url_info, url_speed)
speed = 0
else:
url_speed = get_speed_yt_dlp(url)
speed = (url_info, url_speed) if url_speed != float("inf") else float("inf")
speed = get_speed_yt_dlp(url)
if cache_key and cache_key not in speed_cache:
speed_cache[cache_key] = (url_speed, resolution)
if url_show_info:
speed[0][0] = add_url_info(speed[0][0], url_show_info)
speed = (tuple(speed[0]), speed[1])
speed_cache[cache_key] = (speed, None)
return speed
except:
return float("inf")
@@ -177,24 +162,39 @@ def get_speed_by_info(url_info, ipv6_proxy=None, callback=None):
callback()
def sort_urls_by_speed_and_resolution(data, ipv6_proxy=None, callback=None):
def sort_urls_by_speed_and_resolution(name, data):
"""
Sort by speed and resolution
"""
response = []
for url_info in data:
response.append(
get_speed_by_info(url_info, ipv6_proxy=ipv6_proxy, callback=callback)
)
valid_response = [res for res in response if res != float("inf")]
filter_data = []
for url, date, resolution, origin in data:
if origin == "important":
filter_data.append((url, date, resolution, origin))
continue
cache_key_match = re.search(r"cache:(.*)", url.partition("$")[2])
cache_key = cache_key_match.group(1) if cache_key_match else None
if cache_key and cache_key in speed_cache:
cache = speed_cache[cache_key]
if cache:
response_time, cache_resolution = cache
resolution = cache_resolution or resolution
if response_time != float("inf"):
url = remove_cache_info(url)
logging.info(
f"Name: {name}, URL: {url}, Date: {date}, Resolution: {resolution}, Response Time: {response_time} ms"
)
filter_data.append((url, date, resolution, origin))
def combined_key(item):
(_, _, resolution, _), response_time = item
resolution_value = get_resolution_value(resolution) if resolution else 0
return (
-(config.response_time_weight * response_time)
+ config.resolution_weight * resolution_value
)
_, _, resolution, origin = item
if origin == "important":
return -float("inf")
else:
resolution_value = get_resolution_value(resolution) if resolution else 0
return (
config.response_time_weight * response_time
- config.resolution_weight * resolution_value
)
sorted_res = sorted(valid_response, key=combined_key, reverse=True)
return sorted_res
filter_data.sort(key=combined_key)
return filter_data
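
The new get_speed keys its cache on a cache:<key> tag carried inside the URL after a "$" separator: the tag is stripped, the bare URL is measured once via get_speed_yt_dlp, and any later URL carrying the same key is answered from speed_cache. A small usage sketch with made-up URLs and key:

```python
# The URLs and the "cctv1" key are made up; "$cache:<key>" is the tag get_speed parses.
from utils.speed import get_speed, speed_cache

url_a = "http://example.com/live/cctv1.m3u8$cache:cctv1"
url_b = "http://mirror.example.com/cctv1.m3u8$cache:cctv1"  # same key, different host

t1 = get_speed(url_a)  # measured with get_speed_yt_dlp, stored as speed_cache["cctv1"] = (t1, None)
t2 = get_speed(url_b)  # cache hit: returned straight from speed_cache["cctv1"][0], no second probe
print(t1, t2)
```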


@@ -3,7 +3,6 @@ import datetime
import os
import urllib.parse
import ipaddress
from urllib.parse import urlparse
import socket
from utils.config import config
import utils.constants as constants
@@ -13,6 +12,38 @@ from flask import render_template_string, send_file
import shutil
import requests
import sys
import logging
from logging.handlers import RotatingFileHandler
handler = None
def setup_logging():
"""
Setup logging
"""
global handler
if not os.path.exists(constants.output_dir):
os.makedirs(constants.output_dir)
handler = RotatingFileHandler(constants.log_path, encoding="utf-8")
logging.basicConfig(
handlers=[handler],
format="%(message)s",
level=logging.INFO,
)
def cleanup_logging():
"""
Cleanup logging
"""
global handler
if handler:
for handler in logging.root.handlers[:]:
handler.close()
logging.root.removeHandler(handler)
if os.path.exists(constants.log_path):
os.remove(constants.log_path)
def format_interval(t):
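
With setup_logging and cleanup_logging now defined in utils.tools (and imported from there in the first file of this commit), the caller brackets the logging-heavy phase itself. A minimal sketch of that bracketing, assuming the speed-sort phase is what gets wrapped, as the old utils.channel placement suggested; run_sort is a hypothetical wrapper:

```python
# Hypothetical wrapper; assumes the speed-sort phase is the part that writes the log.
from utils.tools import setup_logging, cleanup_logging
from utils.channel import process_sort_channel_list

def run_sort(channel_data):
    setup_logging()       # creates constants.output_dir if needed and attaches the rotating log handler
    try:
        return process_sort_channel_list(channel_data)
    finally:
        cleanup_logging() # closes and removes the handlers, then deletes constants.log_path
```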