#pragma once

#include <cerrno>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <functional>
#include <istream>
#include <memory>
#include <mutex>
#include <ostream>
#include <string>
#include <tuple>
#include <unordered_set>
#include <vector>

#include <c10/core/Allocator.h>
#include <c10/core/Backend.h>

#include "caffe2/serialize/istream_adapter.h"
#include "caffe2/serialize/read_adapter_interface.h"
#include "caffe2/serialize/versions.h"

extern "C" {
typedef struct mz_zip_archive mz_zip_archive;
}
// PyTorch containers are a special zip archive with the following layout
// archive_name.zip contains:
//   archive_name/
//     version # a file with a single decimal number written in ASCII,
//             # used to establish the version of the archive format
//     model.json # overall model description; this is the JSON output of
//                # ModelDef from torch.proto
//     # the following names are by convention only; model.json will
//     # refer to these files by full names
//     tensors/
//       0 # flat storage for tensor data; meta-data about shapes, etc. is
//         # in model.json
//       1
//       ...
//     # code entries will only exist for modules that have methods attached
//     code/
//       archive_name.py              # serialized torch script code
//                                    # (python syntax, using PythonPrint)
//       archive_name_my_submodule.py # submodules have separate files
//
// The PyTorchStreamWriter also ensures additional useful properties for these
// files:
// 1. All files are stored uncompressed.
// 2. All files in the archive are aligned to 64-byte boundaries such that
//    it is possible to mmap the entire file and get an aligned pointer to
//    tensor data.
// 3. We universally write in ZIP64 format for consistency.
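
// A minimal illustrative sketch (not part of the format implementation) of
// the 64-byte alignment arithmetic implied by property 2 above; the real
// padding computation is detail::getPadding, declared at the bottom of this
// file. The function name is hypothetical.
inline size_t exampleAlignUpTo64Bytes(size_t offset) {
  // Round `offset` up to the next multiple of 64 (a no-op if it is already
  // aligned). 64 is a power of two, so masking the low bits suffices.
  return (offset + 63) & ~static_cast<size_t>(63);
}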
// The PyTorchStreamReader also provides additional properties:
// 1. It can read zip files that are created with common
//    zip tools. This means that even though our writer doesn't compress files,
//    the reader can still read files that were compressed.
// 2. It provides a getRecordOffset function which returns the offset into the
//    raw file where file data lives. If the file was written with
//    PyTorchStreamWriter it is guaranteed to be 64-byte aligned.
// PyTorchStreamReader/Writer handle checking the version number on the archive
// format and ensure that all files are written to an archive_name directory so
// they unzip cleanly.
// When developing this format we want to pay particular attention to the
// following use cases:
//
// -- Reading --
// 1) Reading with full random access
//   a) Reading with file APIs such as fread()
//   b) mmapping the file and jumping around the mapped region
// 2) Reading with 1-pass sequential access
//      -> A reader will need to build up a data structure of parsed structures
//         as it reads
//
// -- Writing --
// 1) Writing with full random access
// 2) Writing with 1-pass sequential access
//      -> We must take care not to require updating values that have already
//         been written. We place the variable-length index at the end and do
//         not put any indices into the header to fulfill this constraint.
//
// The model.json, which contains all the metadata information,
// should be written as the last file. One reason is that the size of tensor
// data is usually stable. As long as the shape and type of the tensor do not
// change, the size of the data won't change. On the other hand, the size of
// the serialized model is likely to change, so we store it as the last record,
// and we don't need to move previous records when updating the model data.
// The zip format is sufficiently flexible to handle the above use cases:
// it puts its central directory at the end of the archive, and we write
// model.json as the last file once we have accumulated all other information.
namespace caffe2 {
namespace serialize {

class TORCH_API PyTorchStreamReader final {
 public:
  explicit PyTorchStreamReader(const std::string& file_name);
  explicit PyTorchStreamReader(std::istream* in);
  explicit PyTorchStreamReader(std::shared_ptr<ReadAdapterInterface> in);

  // Returns the record's data and its size in bytes.
  std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
  size_t getRecordOffset(const std::string& name);
  bool hasRecord(const std::string& name);
  std::vector<std::string> getAllRecords();

  ~PyTorchStreamReader();

  uint64_t version() const {
    return version_;
  }

  void setShouldLoadDebugSymbol(bool should_load_debug_symbol) {
    load_debug_symbol_ = should_load_debug_symbol;
  }

 private:
  void init();
  size_t read(uint64_t pos, char* buf, size_t n);
  void valid(const char* what, const char* info = "");
  size_t getRecordID(const std::string& name);

  friend size_t
  istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n);

  std::unique_ptr<mz_zip_archive> ar_;
  std::string archive_name_;
  std::string archive_name_plus_slash_;
  std::shared_ptr<ReadAdapterInterface> in_;
  int64_t version_;
  std::mutex reader_lock_;
  bool load_debug_symbol_ = true;
};
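
// An illustrative usage sketch (not part of the API; the archive and record
// names are hypothetical). getRecord() copies a record into a freshly
// allocated buffer, while getRecordOffset() reports where the raw bytes live
// for mmap-style access.
inline void examplePyTorchStreamReaderUsage() {
  PyTorchStreamReader reader("archive_name.zip");
  // Enumerate everything stored in the archive.
  for (const std::string& name : reader.getAllRecords()) {
    std::printf("record: %s\n", name.c_str());
  }
  if (reader.hasRecord("tensors/0")) {
    std::tuple<at::DataPtr, size_t> record = reader.getRecord("tensors/0");
    const void* data = std::get<0>(record).get();
    size_t nbytes = std::get<1>(record);
    // Offset of the raw record data within the underlying file; 64-byte
    // aligned if the archive was produced by PyTorchStreamWriter.
    size_t offset = reader.getRecordOffset("tensors/0");
    (void)data;
    (void)nbytes;
    (void)offset;
  }
}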
class TORCH_API PyTorchStreamWriter final {
 public:
  explicit PyTorchStreamWriter(std::string archive_name);
  explicit PyTorchStreamWriter(
      const std::function<size_t(const void*, size_t)> writer_func);

  void setMinVersion(const uint64_t version);

  void writeRecord(
      const std::string& name,
      const void* data,
      size_t size,
      bool compress = false);
  void writeEndOfFile();

  const std::unordered_set<std::string>& getAllWrittenRecords();

  bool finalized() const {
    return finalized_;
  }

  const std::string& archiveName() {
    return archive_name_;
  }

  ~PyTorchStreamWriter();

 private:
  void setup(const std::string& file_name);
  void valid(const char* what, const char* info = "");

  size_t current_pos_ = 0;
  std::unordered_set<std::string> files_written_;
  std::unique_ptr<mz_zip_archive> ar_;
  std::string archive_name_;
  std::string archive_name_plus_slash_;
  std::string padding_;
  std::ofstream file_stream_;
  std::function<size_t(const void*, size_t)> writer_func_;
  // This number will be updated when the model has operators
  // that have valid upgraders.
  uint64_t version_ = kMinProducedFileFormatVersion;
  bool finalized_ = false;
  bool err_seen_ = false;

  friend size_t ostream_write_func(
      void* pOpaque,
      uint64_t file_ofs,
      const void* pBuf,
      size_t n);
};
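
// An illustrative usage sketch (not part of the API; names and payloads are
// hypothetical). Per the format notes at the top of this file, variable-size
// metadata such as model.json is written last, so earlier records never need
// to be moved. A PyTorchStreamWriter can alternatively be constructed with a
// std::function write callback to stream bytes somewhere other than a file.
inline void examplePyTorchStreamWriterUsage() {
  PyTorchStreamWriter writer("archive_name.zip");
  // Fixed-size tensor data first; the writer pads each record so its payload
  // starts on a 64-byte boundary.
  const char tensor_bytes[4] = {0, 1, 2, 3};
  writer.writeRecord("tensors/0", tensor_bytes, sizeof(tensor_bytes));
  // Variable-size metadata last.
  const std::string model_json = "{}";
  writer.writeRecord("model.json", model_json.data(), model_json.size());
  // Flushes the zip central directory and finalizes the archive.
  writer.writeEndOfFile();
}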
namespace detail {
// Writer-specific constants
constexpr uint64_t kFieldAlignment = 64;

// Computes the padding needed so that a record's data begins on a
// kFieldAlignment-byte boundary: fills padding_buf with a dummy entry to be
// appended to the local file header's extra-data field and returns its size.
size_t getPadding(
    size_t cursor,
    size_t filename_size,
    size_t size,
    std::string& padding_buf);
} // namespace detail
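
// An illustrative sketch of the helper above (all values are hypothetical):
// with the write cursor at byte 130 and a 4-byte record whose stored filename
// is 9 characters long, getPadding() fills padding_buf with a dummy
// extra-data entry and returns its length, chosen so that the record's
// payload starts on a kFieldAlignment (64-byte) boundary.
inline size_t exampleComputeRecordPadding() {
  std::string padding_buf;
  return detail::getPadding(
      /*cursor=*/130, /*filename_size=*/9, /*size=*/4, padding_buf);
}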
} // namespace serialize
} // namespace caffe2