Coverage for notion_client/helpers.py: 98%

128 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-04 09:02 +0000

1"""Utility functions for notion-sdk-py.""" 

2 

3import re 

4from typing import ( 

5 Any, 

6 AsyncGenerator, 

7 Awaitable, 

8 Callable, 

9 Dict, 

10 Generator, 

11 List, 

12 Optional, 

13) 

14from urllib.parse import urlparse 

15from uuid import UUID 

16 

17 

def pick(base: Dict[Any, Any], *keys: str) -> Dict[Any, Any]:
    """Build a new dict holding only the requested keys found in `base`.

    Args:
        base: Mapping to read values from.
        *keys: Names of the entries to copy.

    Returns:
        A dict with one entry per requested key present in `base`. A
        `start_cursor` entry whose value is `None` is left out; any other
        `None` value is kept.
    """
    selected: Dict[Any, Any] = {}
    for name in keys:
        if name in base:
            item = base.get(name)
            # A `None` cursor is dropped rather than forwarded.
            if name == "start_cursor" and item is None:
                continue
            selected[name] = item
    return selected

29 

30 

def get_url(object_id: str) -> str:
    """Build the notion.so URL for an object id.

    Args:
        object_id: Any string accepted by `uuid.UUID`.

    Returns:
        The URL whose path is the id as 32 hex characters without hyphens.

    Raises:
        ValueError: If `object_id` is not a valid UUID string.
    """
    compact_id = UUID(object_id).hex
    return f"https://notion.so/{compact_id}"

34 

35 

def get_id(url: str) -> str:
    """Extract the object id from a notion.so URL.

    The id is taken from the last 32 characters of the URL path.

    Args:
        url: A URL whose host is `notion.so` or `www.notion.so`.

    Returns:
        The id in hyphenated UUID form.

    Raises:
        ValueError: If the host is not a Notion host, the path is shorter
            than 32 characters, or the trailing characters are not a UUID.
    """
    parts = urlparse(url)
    if parts.netloc != "notion.so" and parts.netloc != "www.notion.so":
        raise ValueError("Not a valid Notion URL.")

    url_path = parts.path
    if len(url_path) < 32:
        raise ValueError("The path in the URL seems to be incorrect.")

    return str(UUID(url_path[-32:]))

46 

47 

def iterate_paginated_api(
    function: Callable[..., Any], **kwargs: Any
) -> Generator[Any, None, None]:
    """Lazily yield every result of a paginated Notion API call.

    Args:
        function: Endpoint callable; it is invoked with `kwargs` plus a
            `start_cursor` keyword and must return a dict-like response.
        **kwargs: Arguments forwarded to `function`. An optional
            `start_cursor` here selects the first page.

    Yields:
        Each item of the response's `results`, page after page.
    """
    cursor = kwargs.pop("start_cursor", None)
    more_pages = True

    while more_pages:
        page = function(**kwargs, start_cursor=cursor)
        yield from page.get("results")

        cursor = page.get("next_cursor")
        # Keep going only while the API reports more data AND hands us a cursor.
        more_pages = bool(page.get("has_more") and cursor)

62 

63 

def collect_paginated_api(function: Callable[..., Any], **kwargs: Any) -> List[Any]:
    """Fetch every page of a paginated API and return all results as one list.

    Args:
        function: Endpoint callable handed to `iterate_paginated_api`.
        **kwargs: Arguments forwarded to `iterate_paginated_api`.

    Returns:
        A list with every yielded result, in order.
    """
    return list(iterate_paginated_api(function, **kwargs))

67 

68 

async def async_iterate_paginated_api(
    function: Callable[..., Awaitable[Any]], **kwargs: Any
) -> AsyncGenerator[Any, None]:
    """Return an async iterator over the results of any paginated Notion API.

    Args:
        function: Async endpoint callable; it is awaited with `kwargs` plus a
            `start_cursor` keyword and must resolve to a dict-like response.
        **kwargs: Arguments forwarded to `function`. An optional
            `start_cursor` here selects the first page.

    Yields:
        Each item of the response's `results`, page after page.
    """
    next_cursor = kwargs.pop("start_cursor", None)

    while True:
        response = await function(**kwargs, start_cursor=next_cursor)
        for result in response.get("results"):
            yield result

        next_cursor = response.get("next_cursor")
        # Same stop rule as the synchronous iterator: a missing or falsy
        # `has_more`, or a missing/empty cursor, ends the iteration.
        if not response.get("has_more") or not next_cursor:
            return

83 

84 

async def async_collect_paginated_api(
    function: Callable[..., Awaitable[Any]], **kwargs: Any
) -> List[Any]:
    """Asynchronously fetch every page of a paginated API into one list.

    Args:
        function: Async endpoint callable handed to
            `async_iterate_paginated_api`.
        **kwargs: Arguments forwarded to `async_iterate_paginated_api`.

    Returns:
        A list with every yielded result, in order.
    """
    collected: List[Any] = []
    async for item in async_iterate_paginated_api(function, **kwargs):
        collected.append(item)
    return collected

90 

91 

def is_full_block(response: Dict[Any, Any]) -> bool:
    """Tell whether `response` is a full block.

    A response counts as a full block when its `object` is `"block"` and it
    carries a `type` key.
    """
    if response.get("object") != "block":
        return False
    return "type" in response

95 

96 

def is_full_page(response: Dict[Any, Any]) -> bool:
    """Tell whether `response` is a full page.

    A response counts as a full page when its `object` is `"page"` and it
    carries a `url` key.
    """
    if response.get("object") != "page":
        return False
    return "url" in response

100 

101 

def is_full_data_source(response: Dict[Any, Any]) -> bool:
    """Return `True` if `response` is a full data source.

    Only the `object` field is inspected; it must equal `"data_source"`.
    """
    object_kind = response.get("object")
    return object_kind == "data_source"

105 

106 

def is_full_database(response: Dict[Any, Any]) -> bool:
    """Tell whether `response` is a full database.

    A response counts as a full database when its `object` is `"database"`
    and it carries a `title` key.
    """
    if response.get("object") != "database":
        return False
    return "title" in response

110 

111 

def is_full_page_or_data_source(response: Dict[Any, Any]) -> bool:
    """Tell whether `response` is a full data source or a full page.

    Responses whose `object` is `"data_source"` are judged by
    `is_full_data_source`; everything else is judged by `is_full_page`.
    """
    looks_like_data_source = response.get("object") == "data_source"
    checker = is_full_data_source if looks_like_data_source else is_full_page
    return checker(response)

117 

118 

def is_full_user(response: Dict[Any, Any]) -> bool:
    """Tell whether `response` is a full user.

    The only signal checked is the presence of a `type` key.
    """
    has_type = "type" in response
    return has_type

122 

123 

def is_full_comment(response: Dict[Any, Any]) -> bool:
    """Tell whether `response` is a full comment.

    The only signal checked is the presence of a `type` key.
    """
    # NOTE(review): this is the same test as `is_full_user`; confirm against
    # the Notion comment object schema that `type` is the right discriminator
    # for comments.
    has_type = "type" in response
    return has_type

127 

128 

def is_text_rich_text_item_response(rich_text: Dict[Any, Any]) -> bool:
    """Tell whether the rich text item's `type` is `"text"`."""
    item_type = rich_text.get("type")
    return item_type == "text"

132 

133 

def is_equation_rich_text_item_response(rich_text: Dict[Any, Any]) -> bool:
    """Tell whether the rich text item's `type` is `"equation"`."""
    item_type = rich_text.get("type")
    return item_type == "equation"

137 

138 

def is_mention_rich_text_item_response(rich_text: Dict[Any, Any]) -> bool:
    """Tell whether the rich text item's `type` is `"mention"`."""
    item_type = rich_text.get("type")
    return item_type == "mention"

142 

143 

def _format_uuid(compact_uuid: str) -> str:
    """Insert hyphens into a 32-character compact UUID (8-4-4-4-12 layout).

    Args:
        compact_uuid: Exactly 32 characters; the content is not validated.

    Returns:
        The same characters split into five hyphen-separated groups.

    Raises:
        ValueError: If the input is not exactly 32 characters long.
    """
    if len(compact_uuid) != 32:
        raise ValueError("UUID must be exactly 32 characters")

    # Group boundaries for the 8-4-4-4-12 layout.
    bounds = (0, 8, 12, 16, 20, 32)
    return "-".join(compact_uuid[lo:hi] for lo, hi in zip(bounds, bounds[1:]))

153 

154 

def extract_notion_id(url_or_id: str) -> Optional[str]:
    """Extract a Notion ID from a Notion URL or return the input if it's already a valid ID.

    Prioritizes path IDs over query parameters to avoid extracting view IDs
    instead of database IDs. Lookup order:

    1. The whole input is a hyphenated UUID.
    2. The whole input is a compact 32-character hex ID.
    3. A URL path segment that is, or ends with `-` followed by, a
       32-character hex ID.
    4. A `p`, `page_id` or `database_id` query parameter holding an ID.
    5. Any 32-character hex run anywhere in the input.

    Inputs containing `://` are only accepted when the host is `notion.so`
    or `notion.site` (optionally prefixed with `www.`).

    Args:
        url_or_id: A Notion URL, a hyphenated UUID, or a compact hex ID.

    Returns:
        The extracted UUID in standard lowercase format (with hyphens), or
        `None` if the input is empty, not a string, or holds no ID.

    ```python
    # Database URL with view ID - extracts database ID, not view ID
    extract_notion_id('https://notion.so/workspace/DB-abc123def456789012345678901234ab?v=viewid123')
    # Returns: 'abc123de-f456-7890-1234-5678901234ab'  # database ID

    # Already formatted UUID
    extract_notion_id('12345678-1234-1234-1234-123456789abc')
    # Returns: '12345678-1234-1234-1234-123456789abc'
    ```
    """
    if not url_or_id or not isinstance(url_or_id, str):
        return None

    trimmed = url_or_id.strip()
    hex32 = r"[0-9a-f]{32}"
    dashed = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"

    # The input itself is an ID (hyphenated or compact); `UUID` normalizes
    # both spellings to lowercase hyphenated form.
    if re.fullmatch(f"{dashed}|{hex32}", trimmed, re.IGNORECASE):
        return str(UUID(trimmed))

    # For URLs, check if it's a valid Notion domain
    if "://" in trimmed:
        if not re.search(r"://(?:www\.)?notion\.(?:so|site)/", trimmed, re.IGNORECASE):
            return None

    # Path first: look only at the part before any query string or fragment
    # so an ID in `?p=...` can never shadow the ID of the page/database
    # named by the path.
    path_only = re.split(r"[?#]", trimmed, maxsplit=1)[0]
    path_match = re.search(
        rf"/(?:[^/]*-)?({hex32})(?=/|$)", path_only, re.IGNORECASE
    )
    if path_match:
        return str(UUID(path_match.group(1)))

    # Fallback to query parameters if no direct ID found
    query_match = re.search(
        rf"[?&](?:p|page_id|database_id)=({dashed}|{hex32})",
        trimmed,
        re.IGNORECASE,
    )
    if query_match:
        return str(UUID(query_match.group(1)))

    # Last resort: any 32-char hex string in the URL
    any_match = re.search(hex32, trimmed, re.IGNORECASE)
    if any_match:
        return str(UUID(any_match.group(0)))

    return None

210 

211 

def extract_database_id(database_url: str) -> Optional[str]:
    """Extract a database ID from a Notion URL, or normalize an existing ID.

    Thin alias of `extract_notion_id`, named for readability in database code.

    Args:
        database_url: A Notion URL or an ID string.

    Returns:
        Whatever `extract_notion_id` returns for the same input: a hyphenated
        UUID, or `None` when no ID is found.
    """
    database_id = extract_notion_id(database_url)
    return database_id

220 

221 

def extract_page_id(page_url: str) -> Optional[str]:
    """Extract a page ID from a Notion URL, or normalize an existing ID.

    Thin alias of `extract_notion_id`, named for readability in page code.

    Args:
        page_url: A Notion URL or an ID string.

    Returns:
        Whatever `extract_notion_id` returns for the same input: a hyphenated
        UUID, or `None` when no ID is found.
    """
    page_id = extract_notion_id(page_url)
    return page_id

230 

231 

def extract_block_id(url_or_id: str) -> Optional[str]:
    """Extract a block ID from a URL fragment, or fall back to general ID extraction.

    The fragment (the part after `#`) may be `block-<id>` or just `<id>`,
    where `<id>` is a hyphenated UUID or 32 hex characters. When no such
    fragment exists, the input is handed to `extract_notion_id`.

    Args:
        url_or_id: A Notion URL (possibly with a block fragment) or an ID.

    Returns:
        The lowercase hyphenated UUID, or `None` if the input is empty, not a
        string, or holds no ID.
    """
    if not isinstance(url_or_id, str) or not url_or_id:
        return None

    fragment = re.search(
        r"#(?:block-)?([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}|[0-9a-f]{32})",
        url_or_id,
        re.IGNORECASE,
    )
    if fragment is None:
        # No block fragment: treat the input like any other URL or ID.
        return extract_notion_id(url_or_id)

    block_id = fragment.group(1).lower()
    if "-" in block_id:
        # Already in hyphenated form.
        return block_id
    return _format_uuid(block_id)

256 

257 

def iterate_data_source_templates(
    function: Callable[..., Any], **kwargs: Any
) -> Generator[Any, None, None]:
    """Lazily yield every template of a data source, following pagination.

    Args:
        function: Endpoint callable; it is invoked with `kwargs` plus a
            `start_cursor` keyword and must return a dict-like response.
        **kwargs: Arguments forwarded to `function`. An optional
            `start_cursor` here selects the first page.

    Yields:
        Each item of the response's `templates` (a missing key yields
        nothing for that page).

    Example:

    ```python
    for template in iterate_data_source_templates(
        client.data_sources.list_templates,
        data_source_id=data_source_id,
    ):
        print(template["name"], template["is_default"])
    ```
    """
    cursor = kwargs.pop("start_cursor", None)
    more_pages = True

    while more_pages:
        page = function(**kwargs, start_cursor=cursor)
        yield from page.get("templates", [])

        cursor = page.get("next_cursor")
        # Keep going only while the API reports more data AND hands us a cursor.
        more_pages = bool(page.get("has_more") and cursor)

283 

284 

def collect_data_source_templates(
    function: Callable[..., Any], **kwargs: Any
) -> List[Any]:
    """Fetch every template of a data source and return them as one list.

    Args:
        function: Endpoint callable handed to `iterate_data_source_templates`.
        **kwargs: Arguments forwarded to `iterate_data_source_templates`.

    Returns:
        A list with every yielded template, in order.

    Example:

    ```python
    templates = collect_data_source_templates(
        client.data_sources.list_templates,
        data_source_id=data_source_id,
    )
    # Do something with templates.
    ```
    """
    return list(iterate_data_source_templates(function, **kwargs))

301 

302 

async def async_iterate_data_source_templates(
    function: Callable[..., Awaitable[Any]], **kwargs: Any
) -> AsyncGenerator[Any, None]:
    """Asynchronously yield every template of a data source, following pagination.

    Args:
        function: Async endpoint callable; it is awaited with `kwargs` plus a
            `start_cursor` keyword and must resolve to a dict-like response.
        **kwargs: Arguments forwarded to `function`. An optional
            `start_cursor` here selects the first page.

    Yields:
        Each item of the response's `templates` (a missing key yields
        nothing for that page).

    Example:

    ```python
    async for template in async_iterate_data_source_templates(
        async_client.data_sources.list_templates,
        data_source_id=data_source_id,
    ):
        print(template["name"], template["is_default"])
    ```
    """
    cursor = kwargs.pop("start_cursor", None)
    more_pages = True

    while more_pages:
        page = await function(**kwargs, start_cursor=cursor)
        for entry in page.get("templates", []):
            yield entry

        cursor = page.get("next_cursor")
        # Keep going only while the API reports more data AND hands us a cursor.
        more_pages = bool(page.get("has_more") and cursor)

328 

329 

async def async_collect_data_source_templates(
    function: Callable[..., Awaitable[Any]], **kwargs: Any
) -> List[Any]:
    """Asynchronously fetch every template of a data source into one list.

    Args:
        function: Async endpoint callable handed to
            `async_iterate_data_source_templates`.
        **kwargs: Arguments forwarded to `async_iterate_data_source_templates`.

    Returns:
        A list with every yielded template, in order.

    Example:

    ```python
    templates = await async_collect_data_source_templates(
        async_client.data_sources.list_templates,
        data_source_id=data_source_id,
    )
    # Do something with templates.
    ```
    """
    collected: List[Any] = []
    async for entry in async_iterate_data_source_templates(function, **kwargs):
        collected.append(entry)
    return collected