33import shutil
44
55import requests
6+ import math
7+
8+ import requests .exceptions
9+ import tqdm
10+ import sys
11+ from concurrent .futures import ThreadPoolExecutor
12+ import concurrent .futures
613from fedml .api .modules .utils import authenticate
714from fedml .core .distributed .communication .s3 .remote_storage import S3Storage
815from fedml .core .mlops .mlops_configs import Configs , MLOpsConfigs
@@ -19,6 +26,7 @@ def __init__(self, data: dict):
1926 self .updatedAt = data .get ("updateTime" , None )
2027 self .size = _get_size (data .get ("fileSize" ,None ))
2128 self .tag_list = data .get ("tags" , None )
29+ self .download_url = data .get ("fileUrl" , None )
2230
2331
2432# Todo (alaydshah): Store service name in metadata
@@ -40,16 +48,16 @@ def upload(data_path, api_key, name, description, tag_list, service, show_progre
4048 if not archive_path :
4149 return FedMLResponse (code = ResponseCode .FAILURE , message = message )
4250
43- store = _get_storage_service (service )
4451 name = os .path .splitext (os .path .basename (archive_path ))[0 ] if name is None else name
4552 file_name = name + ".zip"
4653 dest_path = os .path .join (user_id , file_name )
4754 file_size = os .path .getsize (archive_path )
4855
49- file_uploaded_url = store .upload_file_with_progress (src_local_path = archive_path , dest_s3_path = dest_path ,
50- show_progress = show_progress ,
51- out_progress_to_err = out_progress_to_err ,
52- progress_desc = progress_desc , metadata = metadata )
56+ file_uploaded_url , message = _upload_multipart (api_key , file_name , archive_path , show_progress ,
57+ out_progress_to_err ,
58+ progress_desc , metadata )
59+
60+
5361 os .remove (archive_path )
5462 if not file_uploaded_url :
5563 return FedMLResponse (code = ResponseCode .FAILURE , message = f"Failed to upload file: { archive_path } " )
@@ -81,24 +89,36 @@ def download(data_name, api_key, service, dest_path, show_progress=True) -> FedM
8189 if user_id is None :
8290 return FedMLResponse (code = ResponseCode .FAILURE , message = message )
8391
84- store = _get_storage_service (service )
85- zip_file_name = data_name + ".zip"
86- key = os .path .join (user_id , zip_file_name )
87- path_local = os .path .abspath (zip_file_name )
88- dest_path = os .path .abspath (dest_path ) if dest_path else data_name
89- if store .download_file_with_progress (path_s3 = key , path_local = path_local , show_progress = show_progress ):
90- try :
91- shutil .unpack_archive (path_local , dest_path )
92- os .remove (path_local )
93- abs_dest_path = os .path .abspath (dest_path )
94- return FedMLResponse (code = ResponseCode .SUCCESS , message = f"Successfully downloaded and unzipped data at "
95- f"{ abs_dest_path } " , data = abs_dest_path )
96- except Exception as e :
97- error_message = f"Failed to unpack archive: { e } "
92+ metadata_response = get_metadata (data_name , api_key )
93+ if metadata_response .code == ResponseCode .SUCCESS :
94+ metadata = metadata_response .data
95+ if not metadata or not isinstance (metadata , StorageMetadata ):
96+ error_message = f"Unable to get the download URL"
97+ logging .error (error_message )
98+ return FedMLResponse (code = ResponseCode .FAILURE , message = error_message )
99+ download_url = metadata .download_url
100+ zip_file_name = data_name + ".zip"
101+ path_local = os .path .abspath (zip_file_name )
102+ dest_path = os .path .abspath (dest_path ) if dest_path else data_name
103+ if _download_using_presigned_url (download_url , zip_file_name , show_progress = show_progress ):
104+ try :
105+ shutil .unpack_archive (path_local , dest_path )
106+ os .remove (path_local )
107+ abs_dest_path = os .path .abspath (dest_path )
108+ return FedMLResponse (code = ResponseCode .SUCCESS , message = f"Successfully downloaded and unzipped data at "
109+ f"{ abs_dest_path } " , data = abs_dest_path )
110+ except Exception as e :
111+ error_message = f"Failed to unpack archive: { e } "
112+ logging .error (error_message )
113+ return FedMLResponse (code = ResponseCode .FAILURE , message = error_message )
114+
115+ else :
116+ error_message = "Failed to download data from source"
98117 logging .error (error_message )
99118 return FedMLResponse (code = ResponseCode .FAILURE , message = error_message )
119+
100120 else :
101- error_message = f"Failed to download data: { data_name } "
121+ error_message = "Unable to get the download URL "
102122 logging .error (error_message )
103123 return FedMLResponse (code = ResponseCode .FAILURE , message = error_message )
104124
@@ -196,6 +216,194 @@ def delete(data_name, service, api_key=None) -> FedMLResponse:
196216 logging .error (message , data_name , service )
197217 return FedMLResponse (code = ResponseCode .FAILURE , message = message , data = False )
198218
219+ def _get_num_chunks (file_size , max_chunk_size ):
220+ num_chunks = math .ceil (file_size / max_chunk_size )
221+ return num_chunks
222+
223+
224+ def get_chunks (file_path , chunk_size ):
225+ with open (file_path , 'rb' ) as file :
226+ while True :
227+ chunk = file .read (chunk_size )
228+ if not chunk :
229+ break
230+ yield chunk
231+
232+
233+ def _get_presigned_url (api_key , request_url , file_name , part_number = None ):
234+ cert_path = MLOpsConfigs .get_cert_path_with_version ()
235+ headers = ServerConstants .API_HEADERS
236+ headers ["Authorization" ] = f"Bearer { api_key } "
237+ params_dict = {'fileKey' : file_name }
238+ if part_number is not None :
239+ params_dict ['partNumber' ] = part_number
240+ if cert_path is None :
241+ try :
242+ requests .session ().verify = cert_path
243+ response = requests .get (request_url , verify = True , headers = headers , params = params_dict )
244+ except requests .exceptions .SSLError as err :
245+ MLOpsConfigs .install_root_ca_file ()
246+ response = requests .get (request_url , verify = True , headers = headers , params = params_dict )
247+ else :
248+ response = requests .get (request_url , verify = True , headers = headers , params = params_dict )
249+ return response
250+
251+
252+ def _upload_part (url ,part_data ,session ):
253+ response = session .put (url ,data = part_data ,verify = True )
254+ return response
255+
256+
257+ def _upload_chunk (presigned_url , chunk , part , pbar = None , max_retries = 20 ,session = None ):
258+ for retry_attempt in range (max_retries ):
259+ try :
260+ response = _upload_part (presigned_url ,chunk ,session )
261+ except requests .exceptions .RequestException as e :
262+ if retry_attempt < max_retries :
263+ continue
264+ else :
265+ raise requests .exceptions .RequestException
266+
267+ if (pbar is not None ):
268+ pbar .update (chunk .__sizeof__ ())
269+ return {'etag' : response .headers ['ETag' ], 'partNumber' : part }
270+ raise requests .exceptions .RequestException
271+
272+ def _process_post_response (response ):
273+ if response .status_code != 200 :
274+ message = (f"Failed to complete multipart upload with status code = { response .status_code } , "
275+ f"response.content: { response .content } " )
276+ logging .error (message )
277+ return None , message
278+ else :
279+ resp_data = response .json ()
280+ code = resp_data .get ("code" , None )
281+ data_url = resp_data .get ("data" , None )
282+
283+ if code is None or data_url is None or code == "FAILURE" :
284+ message = resp_data .get ("message" , None )
285+ message = (f"Failed to complete multipart upload with following message: { message } , "
286+ f"response.content: { response .content } " )
287+ return None , message
288+
289+ return data_url , "Successfully uploaded the data! "
290+
291+ def _complete_multipart_upload (api_key , file_key , part_info , upload_id ):
292+ complete_multipart_url = ServerConstants .get_complete_multipart_upload_url ()
293+ body_dict = {"fileKey" : file_key , 'partETags' : part_info , 'uploadId' : upload_id }
294+
295+ cert_path = MLOpsConfigs .get_cert_path_with_version ()
296+ headers = ServerConstants .API_HEADERS
297+ headers ["Authorization" ] = f"Bearer { api_key } "
298+ if cert_path is None :
299+ try :
300+ requests .session ().verify = cert_path
301+ complete_multipart_response = requests .post (complete_multipart_url , json = body_dict , verify = True ,
302+ headers = headers )
303+ except requests .exceptions .SSLError as err :
304+ MLOpsConfigs .install_root_ca_file ()
305+ complete_multipart_response = requests .post (complete_multipart_url , json = body_dict , verify = True ,
306+ headers = headers )
307+ else :
308+ complete_multipart_response = requests .post (complete_multipart_url , json = body_dict , verify = True ,
309+ headers = headers )
310+
311+ return _process_post_response (complete_multipart_response )
312+
313+ def _upload_multipart (api_key : str , file_key , archive_path , show_progress , out_progress_to_err ,
314+ progress_desc_text , metadata ):
315+ request_url = ServerConstants .get_presigned_multi_part_url ()
316+
317+ file_size = os .path .getsize (archive_path )
318+
319+ max_chunk_size = 20 * 1024 * 1024
320+
321+ num_chunks = _get_num_chunks (file_size , max_chunk_size )
322+
323+ upload_id = ""
324+ presigned_urls = []
325+
326+ presigned_url_response = _get_presigned_url (api_key , request_url , file_key , num_chunks )
327+
328+ if presigned_url_response .status_code != 200 :
329+ message = (f"Failed to get presigned URL with status code = { presigned_url_response .status_code } , "
330+ f"response.content: { presigned_url_response .content } " )
331+ logging .error (message )
332+ return None , message
333+ else :
334+ resp_data = presigned_url_response .json ()
335+ code = resp_data .get ("code" , None )
336+ data = resp_data .get ("data" , None )
337+
338+ if code is None or data is None or code == "FAILURE" :
339+ message = resp_data .get ("message" , None )
340+ message = (f"Failed getting presigned URL with following message: { message } , "
341+ f"response.content: { presigned_url_response .content } " )
342+ return None , message
343+
344+ upload_id = data ['uploadId' ]
345+ presigned_urls = data ['urls' ]
346+
347+ parts = []
348+ chunks = get_chunks (archive_path , max_chunk_size )
349+ part_info = []
350+ chunk_count = 0
351+ successful_chunks = 0
352+
353+ atomic_session = requests .session ()
354+ atomic_session .verify = MLOpsConfigs .get_cert_path_with_version ()
355+ with tqdm .tqdm (total = file_size , unit = "B" , unit_scale = True ,
356+ file = sys .stderr if out_progress_to_err else sys .stdout ,
357+ desc = progress_desc_text , leave = False ) as pbar :
358+ for part , chunk in enumerate (chunks , start = 1 ):
359+ presigned_url = presigned_urls [part - 1 ]
360+ chunk_count += 1
361+ # Upload chunk to presigned_url in a separate thread from the thread pool of 10 workers.
362+ if show_progress :
363+ try :
364+ part_data = _upload_chunk (presigned_url = presigned_url , chunk = chunk , part = part ,
365+ pbar = pbar ,session = atomic_session )
366+ part_info .append (part_data )
367+ successful_chunks += 1
368+ except Exception as e :
369+ return None , "unsuccessful"
370+
371+ else :
372+ try :
373+ part_data = _upload_chunk (presigned_url = presigned_url , chunk = chunk , part = part ,
374+ pbar = pbar ,session = atomic_session )
375+ part_info .append (part_data )
376+ successful_chunks += 1
377+ except Exception as e :
378+ return None , "unsuccessful"
379+
380+ if successful_chunks == chunk_count :
381+ return _complete_multipart_upload (api_key , file_key , part_info , upload_id )
382+ else :
383+ return None , "Unsuccessful!"
384+
385+
386+ def _download_using_presigned_url (url , fname , chunk_size = 1024 * 1024 , show_progress = True ):
387+ download_response = requests .get (url , verify = True , stream = True )
388+ if download_response .status_code == 200 :
389+ total = int (download_response .headers .get ('content-length' , 0 ))
390+ if show_progress :
391+ with open (fname , 'wb' ) as file , tqdm .tqdm (
392+ desc = fname ,
393+ total = total ,
394+ unit = 'B' ,
395+ unit_scale = True ,
396+ unit_divisor = 1024 ,
397+ ) as bar :
398+ for data in download_response .iter_content (chunk_size = chunk_size ):
399+ size = file .write (data )
400+ bar .update (size )
401+ else :
402+ with open (fname , "wb" ) as file :
403+ for data in download_response .iter_content (chunk_size = chunk_size ):
404+ size = file .write (data )
405+ return True
406+ return False
199407
200408def _get_user_id_from_api_key (api_key : str ) -> (str , str ):
201409 user_url = ServerConstants .get_user_url ()
0 commit comments