diff --git a/README.md b/README.md
index f9d7acf..f1f2c8c 100644
--- a/README.md
+++ b/README.md
@@ -32,11 +32,11 @@ pip install -U traps
```python
import traps
-traps.get() # Download to `traps` directory.
-traps.get("my_homework") # Or to another directory.
+traps.get() # Download one trap to `traps` directory.
+traps.get("my_homework", 15) # Or download 15 traps to another directory.
```
### Command-line interface
-* `$ traps` to download 10 traps to `traps` directory
-* `$ traps -n 20 my_homework` to download 20 traps to `my_homework` directory
+* `$ traps install` to download 10 traps to `traps` directory
+* `$ traps install -n 20 my_homework` to download 20 traps to `my_homework` directory
* `$ traps --help` for more help
diff --git a/pyproject.toml b/pyproject.toml
index 7eeea13..a485adf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,7 +30,7 @@ loguru = "^0.6.0"
[tool.poetry.dev-dependencies]
[tool.poetry.scripts]
-traps = "traps:main"
+traps = "traps.cli:cli"
[build-system]
requires = ["poetry-core>=1.0.0"]
diff --git a/traps/__init__.py b/traps/__init__.py
index 1a24106..ad9ad1e 100644
--- a/traps/__init__.py
+++ b/traps/__init__.py
@@ -1,107 +1,4 @@
-import os
-import secrets
-import sys
-from typing import Union
-import urllib.parse
-from concurrent.futures import ThreadPoolExecutor
-from pathlib import Path
-from threading import Thread
-
-import click
-import requests
-from loguru import logger
+from .downloader import get
__version__ = "2.0.1"
-API_URL = "https://api.waifu.pics/nsfw/trap"
-
-try:
- logger.remove(0)
-except ValueError:
- pass
-
-
-def fetch_url(urls_list: list = None) -> str:
- url = requests.get(API_URL).json()["url"]
- if urls_list is not None:
- urls_list.append(url)
- return url
-
-
-def get(directory: Union[str, os.PathLike] = "traps", url: str = None,
- create_dir: bool = True):
- if url is None:
- url = fetch_url()
- directory = Path(directory)
- if not directory.exists() and create_dir:
- directory.mkdir()
- filename = urllib.parse.urlparse(url).path
- filename = directory.joinpath(secrets.token_hex(8) + Path(filename).suffix)
- with open(filename, "wb") as f:
- logger.debug(f"downloading {url}")
- response = requests.get(url, stream=True)
- for block in response.iter_content(1024):
- if not block:
- break
- f.write(block)
- else:
- logger.success(f"downloaded {url}")
-
-
-@click.command(help="how about you pip install some traps")
-@click.option(
- "-n",
- "-t",
- "--traps",
- type=click.INT,
- default=10,
- show_default=True,
- help="number of traps to get"
-)
-@click.option(
- "-v",
- "--verbose",
- is_flag=True,
- help="verbose output")
-@click.argument(
- "directory",
- default="traps",
- type=click.Path(
- dir_okay=True,
- file_okay=False,
- path_type=Path
- )
-)
-def main(traps: int, directory: Path, verbose: bool):
- if verbose:
- loglevel = "DEBUG"
- else:
- loglevel = "INFO"
- logger.add(
- sys.stderr,
- level=loglevel,
- format="{time:YYYY-MM-DD HH:mm:ss} | "
- "{level: <8} | {message}"
- )
- if not directory.exists():
- logger.debug(f"creating directory {directory}")
- directory.mkdir()
- logger.debug("done")
- urls = []
- threads = [
- Thread(target=fetch_url, args=(urls,))
- for _ in range(traps)
- ]
- logger.debug("fetching URLs")
- for thread in threads:
- thread.start()
- for thread in threads:
- thread.join()
- logger.debug("done")
- logger.info("downloading traps")
- with ThreadPoolExecutor(max_workers=8) as p:
- p.map(lambda url: get(directory, url, False), urls)
- logger.info(f"downloaded {traps} traps")
-
-
-if __name__ == '__main__':
- main()
+__all__ = ["get"]
diff --git a/traps/__main__.py b/traps/__main__.py
index f8de9dd..98dcca0 100644
--- a/traps/__main__.py
+++ b/traps/__main__.py
@@ -1,4 +1,4 @@
-from traps import main
+from .cli import cli
-if __name__ == '__main__':
- main()
+if __name__ == "__main__":
+ cli()
diff --git a/traps/cli.py b/traps/cli.py
new file mode 100644
index 0000000..52c09ac
--- /dev/null
+++ b/traps/cli.py
@@ -0,0 +1,58 @@
+import pathlib
+import sys
+
+import click
+from loguru import logger
+
+import traps
+from traps import downloader
+
+PATH_TYPE = click.Path(
+ dir_okay=True,
+ file_okay=False,
+ path_type=pathlib.Path
+)
+CONTEXT_SETTINGS = {
+ "help_option_names": ["-h", "--help"]
+}
+
+
+@click.group(context_settings=CONTEXT_SETTINGS)
+@click.option("-v", "--verbose", is_flag=True, help="Verbose output.")
+def cli(verbose: bool):
+ """How about you pip install some traps?"""
+ if verbose:
+ loglevel = "DEBUG"
+ else:
+ loglevel = "INFO"
+
+ try:
+ logger.remove(0)
+ except ValueError:
+ pass
+
+ logger.add(
+ sys.stderr,
+ level=loglevel,
+ format="{time:YYYY-MM-DD HH:mm:ss} | "
+ "{level: <8} | {message}"
+ )
+
+
+@cli.command("install")
+@click.option("-n", "amount", type=int, default=10,
+ show_default=True, help="Number of traps to install.")
+@click.argument("directory", default="traps", type=PATH_TYPE)
+def install(directory: pathlib.Path, amount: int):
+ """Install (download) traps."""
+ downloader.get(directory, amount)
+
+
+@cli.command("version", help="Print version and exit.")
+def version():
+ print(f"traps {traps.__version__}")
+ sys.exit(0)
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/traps/downloader.py b/traps/downloader.py
new file mode 100644
index 0000000..c2a3c24
--- /dev/null
+++ b/traps/downloader.py
@@ -0,0 +1,72 @@
+import random
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from typing import Union, List
+from xml.etree import ElementTree
+
+import requests
+from click import BadParameter
+
+from traps.utils import filename_from_url
+
+__all__ = ["get"]
+API_URL = "https://safebooru.org/index.php"
+MAX_OFFSET = 130 # Do not change.
+
+
+def _fetch_urls(n: int = 1) -> List[str]:
+ if n > 5000:
+ raise BadParameter("you can't download more than 5000 files at a time")
+ if n < 1:
+        raise BadParameter("you must download at least one file")
+ used_offsets = []
+ urls = []
+
+ def fetch(limit):
+ offset = random.randint(1, MAX_OFFSET)
+ while offset in used_offsets:
+ offset = random.randint(1, MAX_OFFSET)
+ else:
+ used_offsets.append(offset)
+ params = {
+ "page": "dapi",
+ "s": "post",
+ "q": "index",
+ "limit": 100,
+ "pid": offset,
+ "tags": "trap"
+ }
+ resp = requests.get(API_URL, params)
+ posts = ElementTree.fromstring(resp.text).iter("post")
+ return [
+ next(posts).attrib["file_url"]
+ for _ in range(limit)
+ ]
+
+ if n > 100:
+ with ThreadPoolExecutor(max_workers=16) as p:
+ for i in p.map(lambda _: fetch(100), range(n // 100)):
+ urls += i
+ n %= 100
+    if n > 0:
+ urls += fetch(n)
+ return urls
+
+
+def _download(directory: Path, url: str) -> None:
+ resp = requests.get(url, stream=True)
+ filename = filename_from_url(url)
+ with open(directory / filename, "wb") as f:
+ for part in resp.iter_content(1024):
+ if not part:
+ break
+ f.write(part)
+
+
+def get(directory: Union[str, Path] = "traps", amount: int = 1) -> None:
+ if not isinstance(directory, Path):
+ directory = Path(directory)
+ directory.mkdir(exist_ok=True)
+ urls = _fetch_urls(amount)
+ with ThreadPoolExecutor(max_workers=16) as p:
+ p.map(lambda url: _download(directory, url), urls)
diff --git a/traps/utils.py b/traps/utils.py
new file mode 100644
index 0000000..3f60979
--- /dev/null
+++ b/traps/utils.py
@@ -0,0 +1,8 @@
+import pathlib
+import urllib.parse
+
+
+def filename_from_url(url: str) -> str:
+ path = urllib.parse.urlparse(url).path
+ filename = pathlib.Path(path).name
+ return filename