From 6e76179a4e1e76761bfd7be2ad6d12c3f99ec938 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Mon, 30 Dec 2024 16:14:13 +1000 Subject: [PATCH 1/3] Model Builder API Supports creating a model programmatically using the ORT C or C++ API. Supports augmenting an existing model to add nodes. --- cmake/onnxruntime_session.cmake | 1 + cmake/onnxruntime_unittests.cmake | 6 + include/onnxruntime/core/graph/graph.h | 31 +- include/onnxruntime/core/graph/graph_viewer.h | 6 + .../core/session/onnxruntime_c_api.h | 435 +++++++++++++++- .../core/session/onnxruntime_cxx_api.h | 258 +++++++++- .../core/session/onnxruntime_cxx_inline.h | 304 ++++++++++- .../onnxruntime_session_options_config_keys.h | 10 + .../core/framework/onnxruntime_typeinfo.cc | 67 ++- .../core/framework/onnxruntime_typeinfo.h | 2 +- .../core/framework/session_state_utils.cc | 17 +- .../core/framework/tensor_type_and_shape.cc | 35 +- onnxruntime/core/graph/graph.cc | 248 ++++++++- onnxruntime/core/graph/model.cc | 32 +- onnxruntime/core/graph/model.h | 8 +- .../core/graph/model_builder_api_types.h | 48 ++ .../core/session/abi_session_options.cc | 17 +- onnxruntime/core/session/api_utils.cc | 25 - onnxruntime/core/session/api_utils.h | 9 - onnxruntime/core/session/custom_ops.cc | 2 +- onnxruntime/core/session/inference_session.cc | 57 ++- onnxruntime/core/session/inference_session.h | 26 + onnxruntime/core/session/model_builder_api.h | 59 +++ .../core/session/model_builder_c_api.cc | 347 +++++++++++++ onnxruntime/core/session/onnxruntime_c_api.cc | 309 ++++++----- onnxruntime/core/session/ort_apis.h | 21 + onnxruntime/core/session/utils.cc | 125 +++++ onnxruntime/core/session/utils.h | 28 + onnxruntime/test/framework/type_info_test.cc | 26 +- onnxruntime/test/shared_lib/custom_op_utils.h | 6 - onnxruntime/test/shared_lib/test_inference.cc | 162 +++--- .../test/shared_lib/test_model_builder_api.cc | 483 ++++++++++++++++++ .../test/shared_lib/test_ort_format_models.cc | 14 +- onnxruntime/test/shared_lib/utils.h | 52 ++ winml/adapter/winml_adapter_model.cpp | 18 +- 35 files changed, 2897 insertions(+), 397 deletions(-) create mode 100644 onnxruntime/core/graph/model_builder_api_types.h delete mode 100644 onnxruntime/core/session/api_utils.cc delete mode 100644 onnxruntime/core/session/api_utils.h create mode 100644 onnxruntime/core/session/model_builder_api.h create mode 100644 onnxruntime/core/session/model_builder_c_api.cc create mode 100644 onnxruntime/core/session/utils.cc create mode 100644 onnxruntime/core/session/utils.h create mode 100644 onnxruntime/test/shared_lib/test_model_builder_api.cc diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index 47cf2dfc5e7aa..c2fe5d23a220d 100644 --- a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -22,6 +22,7 @@ endif() if (onnxruntime_MINIMAL_BUILD) set(onnxruntime_session_src_exclude "${ONNXRUNTIME_ROOT}/core/session/provider_bridge_ort.cc" + "${ONNXRUNTIME_ROOT}/core/session/model_builder_c_api.cc" ) list(REMOVE_ITEM onnxruntime_session_srcs ${onnxruntime_session_src_exclude}) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 9e3ab4d41f416..40e73c4ec492a 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -511,6 +511,7 @@ set (onnxruntime_shared_lib_test_SRC if (NOT onnxruntime_MINIMAL_BUILD) list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_inference.cc) + list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_builder_api.cc) endif() if(onnxruntime_RUN_ONNX_TESTS) @@ -1350,14 +1351,19 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) LIBS ${onnxruntime_shared_lib_test_LIBS} DEPENDS ${all_dependencies} ) + + target_include_directories(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_ROOT}) + if (onnxruntime_USE_CUDA) target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) endif() + if (onnxruntime_USE_ROCM) target_include_directories(onnxruntime_shared_lib_test PRIVATE ${onnxruntime_ROCM_HOME}/include) target_compile_definitions(onnxruntime_shared_lib_test PRIVATE __HIP_PLATFORM_AMD__) endif() + if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_sources(onnxruntime_shared_lib_test PRIVATE "${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc" diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 7798394b045dc..2c40b41774d78 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -27,6 +27,7 @@ #include "core/common/span_utils.h" #include "core/common/status.h" #include "core/common/logging/logging.h" +#include "core/framework/ort_value.h" #include "core/framework/prepacked_weights_container.h" #include "core/graph/onnx_protobuf.h" #include "core/graph/basic_types.h" @@ -39,6 +40,9 @@ #include "core/graph/node_arg.h" #include "core/graph/ort_format_load_options.h" +// Type from Graph API in ORT C API so can't be in a namespace +struct OrtGraph; + namespace onnxruntime { class Graph; struct IndexedSubGraph; @@ -763,6 +767,10 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi */ bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const; + /** Populate `value` if an externally allocated OrtValue exists for an initializer with the given name. + */ + bool GetOrtValueInitializer(const std::string& name, OrtValue& value) const; + /** Gets all the initializer tensors in this Graph. */ const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return name_to_initial_tensor_; } @@ -1430,6 +1438,16 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& graph); + static Status LoadFromModelBuilderApiModel(const OrtGraph& api_graph, + const Model& owning_model, + const std::unordered_map& domain_to_version, + IOnnxRuntimeOpSchemaCollectionPtr schema_registry, + bool strict_shape_type_inference, + const logging::Logger& logger, + std::unique_ptr& graph); + + Status UpdateUsingModelBuilderApiModel(const OrtModel& api_model); + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const RuntimeOptimizationRecordContainer& RuntimeOptimizations() const { return runtime_optimizations_; @@ -1630,7 +1648,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi // Implementation for initializer replacement Status ReplaceInitializedTensorImpl(ONNX_NAMESPACE::TensorProto new_initializer, bool is_external); - std::vector CreateNodeArgs(const google::protobuf::RepeatedPtrField& names, + template // range-initializer returning std::string + std::vector CreateNodeArgs(const StringRange& names, const ArgNameToTypeMap& name_to_type_map); void ToGraphProtoInternal(ONNX_NAMESPACE::GraphProto& graph_proto) const; @@ -1694,6 +1713,8 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi return nodes_[node_index].get(); } + Status LoadFromModelBuilderApiModel(const OrtGraph& api_graph, bool updating_existing_graph = false); + const Model& owning_model_; // GraphProto to store name, version, initializer. @@ -1708,6 +1729,11 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi InitializedTensorSet name_to_initial_tensor_; + // Initializers that are external to the Graph. e.g. created using Model Builder API from existing memory. + // As we need to convert to TensorProto for the optimizers to work and keep the deleter information we store them + // in the Graph instance and retrieve during session state finalization. + std::unordered_map ortvalue_initializers_; + std::unordered_set, std::hash, std::equal_to> sparse_tensor_names_; @@ -1744,6 +1770,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi // in some case, a fused sub-graph will happens multiple times in one model, we use a map // to store reusable-schema in lookup. InlinedHashMap> reusable_fused_schema_map_; + #endif // !defined(ORT_MINIMAL_BUILD) // Graph nodes. @@ -1806,7 +1833,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi std::unordered_map> node_arg_to_consumer_nodes_; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - const std::unordered_map domain_to_version_; + std::unordered_map domain_to_version_; // Model IR version. Version ir_version_{ONNX_NAMESPACE::Version::IR_VERSION}; diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index 9385e2f092e58..6a664d8be9c05 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -193,6 +193,12 @@ class GraphViewer { IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return graph_->GetSchemaRegistry(); } #endif + /** Populate `value` if an externally allocated OrtValue exists for an initializer with the given name. + */ + bool GetOrtValueInitializer(const std::string& name, OrtValue& value) const { + return graph_->GetOrtValueInitializer(name, value); + } + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphViewer); GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info); diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index a35d975ac8f1b..c883ffa100320 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -665,6 +665,9 @@ typedef struct OrtApi OrtApi; struct OrtTrainingApi; typedef struct OrtTrainingApi OrtTrainingApi; +struct OrtModelBuilderApi; +typedef struct OrtModelBuilderApi OrtModelBuilderApi; + /** \brief The helper interface to get the right version of OrtApi * * Get a pointer to this structure through ::OrtGetApiBase @@ -847,7 +850,8 @@ struct OrtApi { * * \snippet{doc} snippets.dox OrtStatus Return Value */ - ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, + ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); /** \brief Run the model in an ::OrtSession @@ -1340,6 +1344,8 @@ struct OrtApi { * Create a tensor with user's buffer. You can fill the buffer either before calling this function or after. * p_data is owned by caller. ReleaseValue won't release p_data. * + * If you wish to transfer ownership of p_data to ORT use CreateTensorWithDataAndDeleterAsOrtValue. + * * \param[in] info Memory description of where the p_data buffer resides (CPU vs GPU etc). * \param[in] p_data Pointer to the data buffer. * \param[in] p_data_len The number of bytes in the data buffer. @@ -2887,7 +2893,8 @@ struct OrtApi { * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_API2_STATUS(CreateSessionWithPrepackedWeightsContainer, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, - _In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, + _In_ const OrtSessionOptions* options, + _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, _Outptr_ OrtSession** out); /** \brief Create session from memory with prepacked weights container @@ -2910,7 +2917,8 @@ struct OrtApi { */ ORT_API2_STATUS(CreateSessionFromArrayWithPrepackedWeightsContainer, _In_ const OrtEnv* env, _In_ const void* model_data, size_t model_data_length, - _In_ const OrtSessionOptions* options, _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, + _In_ const OrtSessionOptions* options, + _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, _Outptr_ OrtSession** out); /// @} @@ -4778,6 +4786,134 @@ struct OrtApi { */ ORT_API2_STATUS(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys, _In_reads_(kv_len) const char* const* values, _In_ size_t kv_len); + + /** \brief Get the Model Builder API instance + * + * Get the Model Builder API instance to create a new model or augment an existing model. + * + * \return Model Builder API struct + * + * \since Version 1.21. + */ + const OrtModelBuilderApi*(ORT_API_CALL* GetModelBuilderApi)(); + + /** \brief Create an OrtValue for a Tensor that uses pre-existing memory. + * + * Create an OrtValue for a Tensor that uses pre-existing memory. ORT will take ownership of the memory and free it + * using the provided deleter when no longer in use. + * + * \param[in] deleter OrtAllocator instance that will be used to free the memory. + * Only the OrtAllocator:Info and OrtAllocator::Release functions are required. + * The OrtMemoryInfo returned by OrtAllocator::Info must match the location of p_data. + * \param[in] p_data Pointer to the memory that will be used by the Tensor. ORT will take ownership of the memory. + * \param[in] p_data_len Length of the memory in bytes. + * \param[in] shape Dimensions of the Tensor. All values should be > 0. + * \param[in] shape_len Number of dimensions in the shape array. + * \param[in] type Data type of the Tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator* deleter, + _In_ void* p_data, size_t p_data_len, + _In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, + _Outptr_ OrtValue** out); + + /** \brief Query the session for the opset version of a domain. + * + * When using the Model Builder API to augment a model, any new nodes must conform to the opset version of the + * original model. + * + * \param[in] session OrtSession to query + * \param[in] domain Domain to query. The ONNX domain is an empty string. + * \param[out] opset The opset version of the domain. + * + * \snippet{doc} snippets.dox OrtStatus Return Value. Returns an error if the domain is not used in the model. + * + * \since Version 1.21. + */ + ORT_API2_STATUS(SessionGetOpsetForDomain, _In_ const OrtSession* session, _In_ const char* domain, _Out_ int* opset); + + /** \brief Create an OrtTypeInfo instance for a Tensor. + * + * Create an OrtTypeInfo instance for a Tensor to use as graph inputs/outputs with the Model Builder API. + * + * User can release `tensor_info` after creating the OrtTypeInfo. + * + * \param[in] tensor_info Tensor type and shape information. + * \param[out] TypeInfo instance for the tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Out_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for a SparseTensor. + * + * Create an OrtTypeInfo instance for a SparseTensor to use as graph inputs/outputs with the Model Builder API. + * + * User can release `tensor_info` after creating the OrtTypeInfo. + * + * \param[in] tensor_info SparseTensor type and shape information. + * \param[out] TypeInfo instance for the tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateSparseTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Out_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for a Map. + * + * Create an OrtTypeInfo instance for a Map to use as graph inputs/outputs with the Model Builder API. + * + * User can release `map_value_type` after creating the OrtTypeInfo. + * + * \param[in] map_key_type Key type for the map. + * \param[in] map_value_type Value type for the map. + * \param[out] TypeInfo instance for the map. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, _In_ const OrtTypeInfo* map_value_type, + _Out_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for a Sequence. + * + * Create an OrtTypeInfo instance for a Sequence to use as graph inputs/outputs with the Model Builder API. + * + * User can release `sequence_type` after creating the OrtTypeInfo. + * + * \param[in] sequence_type Sequence type and shape information. + * \param[out] TypeInfo instance for the sequence. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type, _Out_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for an Optional. + * + * Create an OrtTypeInfo instance for an Optional to use as graph inputs/outputs with the Model Builder API. + * + * User can release `contained_type` after creating the OrtTypeInfo. + * + * \param[in] tensor_info Tensor type and shape information. + * \param[out] TypeInfo instance for the tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type, _Out_ OrtTypeInfo** type_info); }; /* @@ -4892,6 +5028,299 @@ struct OrtCustomOp { void(ORT_API_CALL* ReleaseAliasMap)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index); }; +/** + * ORT Model Builder API + */ +ORT_RUNTIME_CLASS(Model); +ORT_RUNTIME_CLASS(Graph); +ORT_RUNTIME_CLASS(Node); +ORT_RUNTIME_CLASS(ValueInfo); + +/** + * \brief The OrtModelBuilderApi struct provides functions to create or augment an ONNX model. + * + * See onnxruntime/test/shared_lib/test_model_builder_api.cc for example usage. + * + * \since Version 1.21. + */ +struct OrtModelBuilderApi { + /** \brief Create an OrtValueInfo for use as an OrtGraph input or output. + * + * \param[in] name The name of the input or output. + * \param[in] type_info The type information for the input or output. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info, + _Outptr_ OrtValueInfo** value_info); + + /** \brief Get the name from an OrtValueInfo instance. + * + * \param[in] value_info The OrtValueInfo instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name); + + /** \brief Get the type information from an OrtValueInfo instance. + * + * \param[in] value_info The OrtValueInfo instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info); + + /** \brief Release an OrtValueInfo instance if it was not added to an OrtGraph. + * \since Version 1.21. + */ + ORT_CLASS_RELEASE(ValueInfo); + + /** \brief Create an OrtNode to add to an OrtGraph. + * + * Create an OrtNode. + * + * Create attributes with CreateOpAttr. OrtOpAttr instances are copied. + * + * \param[in] operator_name The name of the operator. + * \param[in] domain_name The domain of the operator. Use an empty string for ONNX operators. + * \param[in] node_name The name of the node. + * \param[in] input_names The names of the inputs. + * \param[in] input_names_len The number of input names. + * \param[in] output_names The names of the outputs. + * \param[in] output_names_len The number of output names. + * \param[in] attributes The optional attributes of the node. + * \param[in] attribs_len The number of attributes. May be zero. + * \param[out] node The OrtNode instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateNode, _In_ const char* operator_name, const char* domain_name, _In_ const char* node_name, + _In_reads_(input_names_len) const char* const* input_names, size_t input_names_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, + _In_reads_(attribs_len) _In_opt_ OrtOpAttr** attributes, _In_ size_t attribs_len, + _Outptr_ OrtNode** node); + + /** \brief Release an OrtNode if it was not added to an OrtGraph. + * \since Version 1.21. + */ + ORT_CLASS_RELEASE(Node); + + /** \brief Create an OrtGraph + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateGraph, _Outptr_ OrtGraph** graph); + + /** \brief Set the inputs for the OrtGraph. + * + * Set the graph inputs. + * The OrtGraph takes ownership of the OrtValueInfo instances and you should NOT call ReleaseOrtValueInfo. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] inputs The input OrtValueInfo instances. + * \param[in] inputs_len The number of input OrtValueInfo instances. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(SetGraphInputs, _In_ OrtGraph* graph, + _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len); + + /** \brief Set the outputs for the OrtGraph. + * + * Set the graph outputs. + * The OrtGraph takes ownership of the OrtValueInfo instances provided and you should NOT call ReleaseOrtValueInfo. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] outputs The output OrtValueInfo instances. + * \param[in] outputs_len The number of output OrtValueInfo instances. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(SetGraphOutputs, _In_ OrtGraph* graph, + _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len); + + /** \brief Add an initializer to the OrtGraph + * + * Add the initializer to the graph. + * ORT will take ownership of the OrtValue and you should NOT call ReleaseOrtValue. + * + * Two options: + * + * Pre-existing memory: + * Use CreateTensorWithDataAsOrtValue or CreateTensorWithDataAndDeleterAsOrtValue to create an OrtValue + * with a tensor that contains a pointer to the existing data. + * User must keep pointer valid for lifetime of the inference session. + * Set `data_is_external` to true. + * + * Allocated memory: + * Use CreateTensorAsOrtValue (allocates memory) and populate the tensor with the data. + * Set `data_is_external` to false. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] name The value name for the initializer. + * \param[in] tensor The OrtValue instance containing the tensor data. + * \param[in] data_is_external Set to true if the data is external and should not be copied. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue* tensor, + bool data_is_external); + + /** \brief Add an OrtNode to an OrtGraph + * + * Add the node to the graph. The OrtGraph will take ownership of OrtNode and you should NOT call ReleaseOrtNode. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] node The OrtNode instance to add to the graph. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(AddNodeToGraph, _In_ OrtGraph* graph, _In_ OrtNode* node); + + /** \brief Release an OrtGraph if it was not added to an OrtModel. + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.21. + */ + ORT_CLASS_RELEASE(Graph); + + /** \brief Create an OrtModel. + * + * Create an OrtModel. + * + * This can be used to build a new model, or to augment an existing model. + * + * \param[in] domain_names The domain names for the model. + * If augmenting an existing model add additional domains if needed. + * \param[in] opset_versions The opset versions for the model. + * If augmenting an existing model add additional opset versions if needed. + * \param[in] opset_entries_len The number of domain_names and opset_versions entries. + * Domain and opset entries should be 1:1 + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateModel, + _In_reads_(opset_entries_len) const char* const* domain_names, + _In_reads_(opset_entries_len) const int* opset_versions, + size_t opset_entries_len, + _Outptr_ OrtModel** model); + + /** \brief Add an OrtGraph to an OrtModel. + * + * Add the graph to a model. This should be called once when creating a new model. + * + * The OrtModel takes ownership of the OrtGraph and you should NOT call ReleaseOrtGraph. + * + * \param[in] model The OrtModel instance to update. + * \param[in] graph The OrtGraph instance to add to the model. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(AddGraphToModel, _In_ OrtModel* model, _Inout_ OrtGraph* graph); + + /** \brief Release an OrtModel. + * + * Release the OrtModel. + * This should be called after the model is added to a session using CreateSessionFromModel or ApplyModelToSession. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_CLASS_RELEASE(Model); + + /** \brief Create an OrtSession using the OrtModel. + * + * Create an inference session using the OrtModel. + * This will validate the model, run optimizers, and prepare the session for inferencing. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const OrtModel* model, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); + + /** \brief Create an OrtSession to augment an existing model. + * + * Create an OrtSession with an existing model that can be augmented with additional nodes. + * Nodes can be added to the model using AddNodeToGraph. + * Graph inputs/outputs should be updated wtih SetGraphInputs and SetGraphOutputs to reflect the new nodes. + * Apply the changes with ApplyModelToSession and prepare the session for inferencing by calling + * FinalizeModelBuilderSession. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateModelBuilderSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, + _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out); + + /** \brief Create an OrtSession to augment an existing model. + * + * Create an OrtSession with an existing model that can be augmented with additional nodes. + * Nodes can be added to the model using AddNodeToGraph. + * Graph inputs/outputs should be updated wtih SetGraphInputs and SetGraphOutputs to reflect the new nodes. + * Apply the changes with ApplyModelToSession and prepare the session for inferencing by calling + * FinalizeModelBuilderSession. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(CreateModelBuilderSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out); + + /** \brief Apply the changes from the model to the session. + * + * Apply the changes from the model to the session that was created using CreateModelBuilderSession[FromArray]. + * All changes will be validated. + * Call FinalizeModelBuilderSession to prepare the session for inferencing. + * + * Existing input/outputs will only be updated if the OrtGraph inputs/outputs are set in the OrtModel. + * i.e. you don't need to call SetGraphInputs/Outputs if they are unchanged. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(ApplyModelToModelBuilderSession, _In_ OrtSession* session, _In_ OrtModel* model); + + /** \brief Finalize the Model Builder session. + * + * Finalize the Model Builder session. + * This will run optimizers and prepare the session for inferencing. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.21. + */ + ORT_API2_STATUS(FinalizeModelBuilderSession, _In_ OrtSession* session, _In_ const OrtSessionOptions* options, + _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container); +}; + /* * This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality * This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index f3e9758766d00..715f61b17144c 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -26,16 +26,17 @@ #include "onnxruntime_c_api.h" #include "onnxruntime_float16.h" +#include #include #include -#include #include #include #include -#include +#include #include #include -#include +#include +#include #ifdef ORT_NO_EXCEPTIONS #include @@ -120,7 +121,7 @@ const OrtApi* Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); #endif #endif -/// This returns a reference to the OrtApi interface in use +/// This returns a reference to the ORT C API. inline const OrtApi& GetApi() noexcept { return *Global::api_; } /// @@ -143,6 +144,20 @@ std::string GetBuildInfoString(); /// vector of strings std::vector GetAvailableProviders(); +/// +/// This returns a reference to the ORT C Model Builder API. Used if building or augmenting a model at runtime. +/// +/// ORT C Model Builder API reference +inline const OrtModelBuilderApi& GetModelBuilderApi() { + auto* api = GetApi().GetModelBuilderApi(); + if (api == nullptr) { + // minimal build + ORT_CXX_API_THROW("Model Builder API is not available in this build", ORT_FAIL); + } + + return *api; +} + /** \brief IEEE 754 half-precision floating point data type * * \details This struct is used for converting float to float16 and back @@ -526,6 +541,15 @@ ORT_DEFINE_RELEASE(KernelInfo); #undef ORT_DEFINE_RELEASE +#define ORT_DEFINE_MODELBUILDER_API_RELEASE(NAME) \ + inline void OrtRelease(Ort##NAME* ptr) { GetModelBuilderApi().Release##NAME(ptr); } + +ORT_DEFINE_MODELBUILDER_API_RELEASE(ValueInfo); +ORT_DEFINE_MODELBUILDER_API_RELEASE(Node); +ORT_DEFINE_MODELBUILDER_API_RELEASE(Graph); +ORT_DEFINE_MODELBUILDER_API_RELEASE(Model); +#undef ORT_DEFINE_MODELBUILDER_API_RELEASE + /** \brief This is a tagging template type. Use it with Base to indicate that the C++ interface object * has no ownership of the underlying C object. */ @@ -559,7 +583,9 @@ struct Base { constexpr Base() = default; constexpr explicit Base(contained_type* p) noexcept : p_{p} {} - ~Base() { OrtRelease(p_); } + ~Base() { + OrtRelease(p_); + } Base(const Base&) = delete; Base& operator=(const Base&) = delete; @@ -639,6 +665,10 @@ struct TypeInfo; struct Value; struct ModelMetadata; +namespace ModelBuilderAPI { +struct Model; +} + /** \brief unique_ptr typedef used to own strings allocated by OrtAllocators * and release them at the end of the scope. The lifespan of the given allocator * must eclipse the lifespan of AllocatedStringPtr instance @@ -1051,6 +1081,10 @@ struct ConstSessionImpl : Base { size_t GetOutputCount() const; ///< Returns the number of model outputs size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden + std::vector GetInputNames() const; + std::vector GetOutputNames() const; + std::vector GetOverridableInitializerNames() const; + /** \brief Returns a copy of input name at the specified index. * * \param index must less than the value returned by GetInputCount() @@ -1084,6 +1118,8 @@ struct ConstSessionImpl : Base { TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo + + int GetOpset(const std::string& domain) const; ///< Wraps OrtApi::SessionGetOpsetForDomain }; template @@ -1161,6 +1197,9 @@ struct SessionImpl : ConstSessionImpl { * \param[in] kv_len Number of elements in the keys and values arrays */ void SetEpDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len); + + void FinalizeModelBuilderSession(const ModelBuilderAPI::Model& model, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr); }; } // namespace detail @@ -1172,13 +1211,32 @@ using UnownedSession = detail::SessionImpl>; * */ struct Session : detail::SessionImpl { - explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used - Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession + /// Create an empty Session object, must be assigned a valid one to be used. Wraps OrtApi::CreateSession + explicit Session(std::nullptr_t) {} + explicit Session(OrtSession* p) : SessionImpl{p} {} ///< C API Interop + + Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); + + /// Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, - OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer - Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray + OrtPrepackedWeightsContainer* prepacked_weights_container); + + /// Wraps OrtApi::CreateSessionFromArray + Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); + + /// Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options, - OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer + OrtPrepackedWeightsContainer* prepacked_weights_container); + + /// Wraps OrtModelBuilderApi::CreateSessionFromModel + Session(const Env& env, const ModelBuilderAPI::Model& model, const SessionOptions& options); + + /// Wraps OrtModelBuilderApi::CreateModelBuilderSession + static Session CreateModelBuilderSession(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); + + /// Wraps OrtModelBuilderApi::CreateModelBuilderSession + static Session CreateModelBuilderSession(const Env& env, const void* model_data, size_t model_data_length, + const SessionOptions& options); ConstSession GetConst() const { return ConstSession{this->p_}; } UnownedSession GetUnowned() const { return UnownedSession{this->p_}; } @@ -1210,7 +1268,7 @@ using ConstMemoryInfo = detail::MemoryInfoImpl { static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created - explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C Api + explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl{p} {} ///< Take ownership of a pointer created by C API MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } }; @@ -1233,6 +1291,7 @@ struct TensorTypeAndShapeInfoImpl : Base { [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions + std::vector GetSymbolicDimensions() const; std::vector GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape }; @@ -1248,8 +1307,18 @@ struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl; using Base::Base; - explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used - explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API + /// Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used + explicit TensorTypeAndShapeInfo(std::nullptr_t) {} + /// Used for interop with the C API + explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} + + // Create a TensorTypeAndShapeInfo object with the specified element type and dimensions + // symbolic_dims are optional, but should be 1:1 with dims. + // The value in symbolic_dims will be used for all entries in dims that are -1. + explicit TensorTypeAndShapeInfo(ONNXTensorElementDataType element_type, + const std::vector& dims, + const std::vector* symbolic_dims = nullptr); + ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; } }; @@ -1344,9 +1413,16 @@ struct TypeInfo : detail::TypeInfoImpl { using Base = detail::TypeInfoImpl; using Base::Base; - explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used + /// Create an empty TypeInfo object, must be assigned a valid one to be used + explicit TypeInfo(std::nullptr_t) {} explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl{p} {} ///< C API Interop + static TypeInfo CreateTensorInfo(ConstTensorTypeAndShapeInfo tensor_info); + static TypeInfo CreateSparseTensorInfo(ConstTensorTypeAndShapeInfo sparse_tensor_info); + static TypeInfo CreateSequenceTypeInfo(ConstTypeInfo sequence_type); + static TypeInfo CreateMapTypeInfo(ONNXTensorElementDataType key_type, ConstTypeInfo value_type); + static TypeInfo CreateOptionalTypeInfo(ConstTypeInfo contained_type); + ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; } }; @@ -1701,7 +1777,8 @@ struct Value : detail::ValueImpl { * \param shape_len The number of tensor shape dimensions. */ template - static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len); + static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, + const int64_t* shape, size_t shape_len); /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue. * @@ -1712,11 +1789,25 @@ struct Value : detail::ValueImpl { * \param shape_len The number of tensor shape dimensions. * \param type The data type. */ - static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, + static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type); + + /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAndDeleterAsOrtValue. + * + * \param deleter OrtAllocator that will be used to free the buffer when no longer required. + * \param p_data Pointer to the data buffer. + * \param p_data_byte_count The number of bytes in the data buffer. + * \param shape Pointer to the tensor shape dimensions. + * \param shape_len The number of tensor shape dimensions. + * \param type The data type. + */ + static Value CreateTensor(OrtAllocator* deleter, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); /** \brief Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue. - * This overload will allocate the buffer for the tensor according to the supplied shape and data type. + * This overload will allocate the buffer for the tensor according to the supplied shape and data type. * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released. * The input data would need to be copied into the allocated buffer. * This API is not suitable for strings. @@ -1740,7 +1831,8 @@ struct Value : detail::ValueImpl { * \param shape_len The number of tensor shape dimensions. * \param type The data type. */ - static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); + static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type); /** \brief Creates an OrtValue with a Map Onnx type representation. * The API would ref-count the supplied OrtValues and they will be released @@ -2459,6 +2551,136 @@ struct CustomOpBase : OrtCustomOp { int end_ver_ = MAX_CUSTOM_OP_END_VER; }; +// +// Model Builder API C++ wrappers +// +namespace ModelBuilderAPI { + +namespace detail { +template +struct ValueInfoImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + + std::string Name() const; + ConstTypeInfo TypeInfo() const; +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstValueInfo = detail::ValueInfoImpl>; + +/** \brief Wrapper around ::OrtValueInfo + * + */ +struct ValueInfo : detail::ValueInfoImpl { + explicit ValueInfo(std::nullptr_t) {} ///< No instance is created + /// Take ownership of a pointer created by C API + explicit ValueInfo(OrtValueInfo* p) : ValueInfoImpl{p} {} + + // Create ValueInfo for a tensor + explicit ValueInfo(const std::string& name, const ConstTypeInfo& type_info); + + ConstValueInfo GetConst() const { return ConstValueInfo{this->p_}; } +}; + +namespace detail { +template +struct NodeImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstNode = detail::NodeImpl>; + +/** \brief Wrapper around ::OrtNode + * + */ +struct Node : detail::NodeImpl { + explicit Node(std::nullptr_t) {} ///< No instance is created + explicit Node(OrtNode* p) : NodeImpl{p} {} ///< Take ownership of a pointer created by C API + + Node(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names); + + /// + /// Wraps CreateNode. Node takes ownership of attributes on success and updates the OpAttr in `attributes` to do so. + /// + Node(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names, + std::vector& attributes); + + ConstNode GetConst() const { return ConstNode{this->p_}; } + + private: + static void Init(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names, + std::vector& attributes, + OrtNode*& node); +}; + +namespace detail { +template +struct GraphImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + + void SetInputs(std::vector& inputs); + void SetOutputs(std::vector& outputs); + void AddInitializer(const std::string& name, Value& initializer, bool data_is_external); // Graph takes ownership of Value + void AddNode(Node& node); // Graph takes ownership of Node +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstGraph = detail::GraphImpl>; + +/** \brief Wrapper around ::OrtGraph + * + */ +struct Graph : detail::GraphImpl { + explicit Graph(std::nullptr_t) {} ///< No instance is created + explicit Graph(OrtGraph* p) : GraphImpl{p} {} ///< Take ownership of a pointer created by C API + Graph(); + + ConstGraph GetConst() const { return ConstGraph{this->p_}; } +}; + +namespace detail { +template +struct ModelImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + + void AddGraph(Graph& graph); +}; +} // namespace detail + +// Const object holder that does not own the underlying object +using ConstModel = detail::ModelImpl>; + +/** \brief Wrapper around ::OrtModel + * + */ +struct Model : detail::ModelImpl { + using DomainOpsetPair = std::pair; + + explicit Model(std::nullptr_t) {} ///< No instance is created + explicit Model(OrtModel* p) : ModelImpl{p} {} ///< Take ownership of a pointer created by C API + Model(const std::vector& opsets); + + ConstModel GetConst() const { return ConstModel{this->p_}; } +}; +} // namespace ModelBuilderAPI + } // namespace Ort #include "onnxruntime_cxx_inline.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 3aeb9412f350e..1de5db266961d 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -995,6 +995,57 @@ inline size_t ConstSessionImpl::GetOverridableInitializerCount() const { return out; } +template +inline std::vector ConstSessionImpl::GetInputNames() const { + AllocatorWithDefaultOptions allocator; + + auto num_inputs = GetInputCount(); + std::vector input_names; + input_names.reserve(num_inputs); + + for (size_t i = 0; i < num_inputs; ++i) { + char* name = nullptr; + ThrowOnError(GetApi().SessionGetInputName(this->p_, i, allocator, &name)); + input_names.push_back(name); + } + + return input_names; +} + +template +inline std::vector ConstSessionImpl::GetOutputNames() const { + AllocatorWithDefaultOptions allocator; + + auto num_inputs = GetOutputCount(); + std::vector output_names; + output_names.reserve(num_inputs); + + for (size_t i = 0; i < num_inputs; ++i) { + char* name = nullptr; + ThrowOnError(GetApi().SessionGetOutputName(this->p_, i, allocator, &name)); + output_names.push_back(name); + } + + return output_names; +} + +template +inline std::vector ConstSessionImpl::GetOverridableInitializerNames() const { + AllocatorWithDefaultOptions allocator; + + auto num_initializers = GetOverridableInitializerCount(); + std::vector initializer_names; + initializer_names.reserve(num_initializers); + + for (size_t i = 0; i < num_initializers; ++i) { + char* name = nullptr; + ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, i, allocator, &name)); + initializer_names.push_back(name); + } + + return initializer_names; +} + template inline AllocatedStringPtr ConstSessionImpl::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const { char* out; @@ -1051,6 +1102,13 @@ inline TypeInfo ConstSessionImpl::GetOverridableInitializerTypeInfo(size_t in return TypeInfo{out}; } +template +inline int ConstSessionImpl::GetOpset(const std::string& domain) const { + int opset; + ThrowOnError(GetApi().SessionGetOpsetForDomain(this->p_, domain.c_str(), &opset)); + return opset; +} + template inline std::vector SessionImpl::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, const char* const* output_names, size_t output_count) { @@ -1098,6 +1156,13 @@ inline void SessionImpl::SetEpDynamicOptions(const char* const* keys, const c ThrowOnError(GetApi().SetEpDynamicOptions(this->p_, keys, values, kv_len)); } +template +inline void SessionImpl::FinalizeModelBuilderSession(const ModelBuilderAPI::Model& model, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container) { + ThrowOnError(GetModelBuilderApi().ApplyModelToModelBuilderSession(this->p_, model)); + ThrowOnError(GetModelBuilderApi().FinalizeModelBuilderSession(this->p_, options, prepacked_weights_container)); +} + } // namespace detail inline SessionOptions::SessionOptions() { @@ -1144,6 +1209,30 @@ inline Session::Session(const Env& env, const void* model_data, size_t model_dat prepacked_weights_container, &this->p_)); } +inline Session::Session(const Env& env, const ModelBuilderAPI::Model& model, const SessionOptions& options) { + ThrowOnError(GetModelBuilderApi().CreateSessionFromModel(env, model.GetConst(), options, &this->p_)); +} + +// static +inline Session Session::CreateModelBuilderSession(const Env& env, const ORTCHAR_T* model_path, + const SessionOptions& options) { + OrtSession* session = nullptr; + ThrowOnError(GetModelBuilderApi().CreateModelBuilderSession(env, model_path, options, &session)); + return Session(session); +} + +// static +inline Session Session::CreateModelBuilderSession(const Env& env, const void* model_data, size_t model_data_length, + const SessionOptions& options) { + OrtSession* session = nullptr; + ThrowOnError(GetModelBuilderApi().CreateModelBuilderSessionFromArray(env, model_data, model_data_length, options, + &session)); + return Session(session); +} + +void FinalizeModelBuilderSession(const ModelBuilderAPI::Model& model, const SessionOptions& options, + OrtPrepackedWeightsContainer* prepacked_weights_container); + inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const { char* out; ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out)); @@ -1211,6 +1300,57 @@ inline int64_t ModelMetadata::GetVersion() const { return out; } +inline TensorTypeAndShapeInfo::TensorTypeAndShapeInfo(ONNXTensorElementDataType element_type, + const std::vector& dims, + const std::vector* symbolic_dims) { + ThrowOnError(GetApi().CreateTensorTypeAndShapeInfo(&p_)); + ThrowOnError(GetApi().SetTensorElementType(p_, element_type)); + ThrowOnError(GetApi().SetDimensions(p_, dims.data(), dims.size())); + + if (symbolic_dims) { + std::vector symbolic_dims_cstr; + symbolic_dims_cstr.reserve(symbolic_dims->size()); + std::transform(symbolic_dims->begin(), symbolic_dims->end(), std::back_inserter(symbolic_dims_cstr), + [](const std::string& s) { return s.c_str(); }); + ThrowOnError(GetApi().SetSymbolicDimensions(p_, symbolic_dims_cstr.data(), symbolic_dims_cstr.size())); + } +} + +// static +inline TypeInfo TypeInfo::CreateTensorInfo(ConstTensorTypeAndShapeInfo tensor_type_and_shape_info) { + OrtTypeInfo* output = nullptr; + ThrowOnError(GetApi().CreateTensorTypeInfo(tensor_type_and_shape_info, &output)); + return TypeInfo{output}; +} + +// static +inline TypeInfo TypeInfo::CreateSparseTensorInfo(ConstTensorTypeAndShapeInfo sparse_tensor_type_and_shape_info) { + OrtTypeInfo* output = nullptr; + ThrowOnError(GetApi().CreateSparseTensorTypeInfo(sparse_tensor_type_and_shape_info, &output)); + return TypeInfo{output}; +} + +// static +inline TypeInfo TypeInfo::CreateSequenceTypeInfo(ConstTypeInfo sequence_type) { + OrtTypeInfo* output; + ThrowOnError(GetApi().CreateSequenceTypeInfo(sequence_type, &output)); + return TypeInfo{output}; +} + +// static +inline TypeInfo TypeInfo::CreateMapTypeInfo(ONNXTensorElementDataType key_type, ConstTypeInfo value_type) { + OrtTypeInfo* output; + ThrowOnError(GetApi().CreateMapTypeInfo(key_type, value_type, &output)); + return TypeInfo{output}; +} + +// static +inline TypeInfo TypeInfo::CreateOptionalTypeInfo(ConstTypeInfo contained_type) { + OrtTypeInfo* output; + ThrowOnError(GetApi().CreateOptionalTypeInfo(contained_type, &output)); + return TypeInfo{output}; +} + namespace detail { template @@ -1244,9 +1384,16 @@ inline void TensorTypeAndShapeInfoImpl::GetSymbolicDimensions(const char** va ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count)); } +template +inline std::vector TensorTypeAndShapeInfoImpl::GetSymbolicDimensions() const { + std::vector out(GetDimensionsCount(), nullptr); + ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, out.data(), out.size())); + return out; +} + template inline std::vector TensorTypeAndShapeInfoImpl::GetShape() const { - std::vector out(GetDimensionsCount(), 0); + std::vector out(GetDimensionsCount(), -1); ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size())); return out; } @@ -1560,23 +1707,35 @@ void ValueImpl::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_inf } // namespace detail template -inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) { +inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, + const int64_t* shape, size_t shape_len) { return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType::type); } -inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, +inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) { OrtValue* out; ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out)); return Value{out}; } +inline Value Value::CreateTensor(OrtAllocator* deleter, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type) { + OrtValue* out; + ThrowOnError(GetApi().CreateTensorWithDataAndDeleterAsOrtValue(deleter, p_data, p_data_byte_count, + shape, shape_len, type, &out)); + return Value{out}; +} + template inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) { return CreateTensor(allocator, shape, shape_len, TypeToTensorType::type); } -inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) { +inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type) { OrtValue* out; ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out)); return Value{out}; @@ -1594,7 +1753,8 @@ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& values_shape, ONNXTensorElementDataType type) { OrtValue* out; ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len, - values_shape.shape, values_shape.shape_len, type, &out)); + values_shape.shape, values_shape.shape_len, type, + &out)); return Value{out}; } @@ -2167,4 +2327,138 @@ inline const OrtOpAttr* ShapeInferContext::GetAttrHdl(const char* attr_name) con return attr_hdl; } +namespace ModelBuilderAPI { +inline std::vector StringsToCharPtrs(const std::vector& strings) { + std::vector ptrs; + ptrs.reserve(strings.size()); + std::transform(strings.begin(), strings.end(), std::back_inserter(ptrs), + [](const std::string& s) { return s.c_str(); }); + + return ptrs; +} + +// static +inline void Node::Init(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names, + std::vector& attributes, + OrtNode*& node) { + auto inputs = StringsToCharPtrs(input_names); + auto outputs = StringsToCharPtrs(output_names); + + std::vector attributes_ptrs; + attributes_ptrs.reserve(attributes.size()); + std::transform(attributes.begin(), attributes.end(), std::back_inserter(attributes_ptrs), + [](OpAttr& attr) -> OrtOpAttr* { return attr; }); + + ThrowOnError(GetModelBuilderApi().CreateNode(operator_name.c_str(), operator_domain.c_str(), node_name.c_str(), + inputs.data(), inputs.size(), + outputs.data(), outputs.size(), + attributes_ptrs.data(), attributes_ptrs.size(), + &node)); + + // Node now owns the attributes + std::for_each(attributes.begin(), attributes.end(), [](OpAttr& attr) { attr.release(); }); +} + +inline Node::Node(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names, + std::vector& attributes) { + Init(operator_name, operator_domain, node_name, input_names, output_names, attributes, p_); +} + +inline Node::Node(const std::string& operator_name, const std::string& operator_domain, + const std::string& node_name, + const std::vector& input_names, + const std::vector& output_names) { + std::vector empty_attributes; + Init(operator_name, operator_domain, node_name, input_names, output_names, empty_attributes, p_); +} + +inline Graph::Graph() { + ThrowOnError(GetModelBuilderApi().CreateGraph(&p_)); +} + +inline Model::Model(const std::vector& opsets) { + std::vector domains; + std::vector versions; + domains.reserve(opsets.size()); + versions.reserve(opsets.size()); + + for (const auto& pair : opsets) { + domains.push_back(pair.first.c_str()); + versions.push_back(pair.second); + } + + ThrowOnError(GetModelBuilderApi().CreateModel(domains.data(), versions.data(), opsets.size(), &p_)); +} + +inline ValueInfo::ValueInfo(const std::string& name, const ConstTypeInfo& type_info) { + ThrowOnError(GetModelBuilderApi().CreateValueInfo(name.c_str(), type_info, &p_)); +} +namespace detail { +template <> +inline std::string ValueInfoImpl::Name() const { + const char* name = nullptr; + ThrowOnError(GetModelBuilderApi().GetValueInfoName(this->p_, &name)); + return name; +} + +template <> +inline ConstTypeInfo ValueInfoImpl::TypeInfo() const { + const OrtTypeInfo* type_info = nullptr; + ThrowOnError(GetModelBuilderApi().GetValueInfoTypeInfo(this->p_, &type_info)); + return ConstTypeInfo{type_info}; +} + +template <> +inline void GraphImpl::SetInputs(std::vector& inputs) { + std::vector inputs_ptrs; + inputs_ptrs.reserve(inputs.size()); + + // Graph takes ownership. + std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_ptrs), + [](ValueInfo& vi) -> OrtValueInfo* { return vi.release(); }); + + ThrowOnError(GetModelBuilderApi().SetGraphInputs(p_, inputs_ptrs.data(), inputs_ptrs.size())); + + // Graph now owns the inputs + std::for_each(inputs.begin(), inputs.end(), [](ValueInfo& vi) { vi.release(); }); +} + +template <> +inline void GraphImpl::SetOutputs(std::vector& outputs) { + std::vector outputs_ptrs; + outputs_ptrs.reserve(outputs.size()); + std::transform(outputs.begin(), outputs.end(), std::back_inserter(outputs_ptrs), + [](ValueInfo& vi) -> OrtValueInfo* { return vi; }); + + ThrowOnError(GetModelBuilderApi().SetGraphOutputs(p_, outputs_ptrs.data(), outputs_ptrs.size())); + + // Graph now owns the outputs + std::for_each(outputs.begin(), outputs.end(), [](ValueInfo& vi) { vi.release(); }); +} + +template <> +inline void GraphImpl::AddInitializer(const std::string& name, Value& initializer, bool data_is_external) { + // Graph takes ownership of `initializer` + ThrowOnError(GetModelBuilderApi().AddInitializerToGraph(p_, name.c_str(), initializer.release(), data_is_external)); +} + +template <> +inline void GraphImpl::AddNode(Node& node) { + // Graph takes ownership of `node` + ThrowOnError(GetModelBuilderApi().AddNodeToGraph(p_, node.release())); +} + +template <> +inline void ModelImpl::AddGraph(Graph& graph) { + // Model takes ownership of `graph` + ThrowOnError(GetModelBuilderApi().AddGraphToModel(p_, graph.release())); +} +} // namespace detail +} // namespace ModelBuilderAPI } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 64a4dd19c12b0..0c3ebebc521d1 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -300,3 +300,13 @@ static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "sessio // “Default”: OS determines the scheduling priority and processor performance to service this workload. [Default] // “Efficient”: OS treats this workload is efficiency oriented with low scheduling priority and efficient processor performance. static const char* const kOrtEpDynamicOptionsWorkloadType = "ep.dynamic.workload_type"; + +// Create an Inference Session that will use the Model Builder API to create/update the model. +// This flag will create the session but not fully initialize it. A model, if provided, will be loaded. +// A session logger will be created, and execution providers will be registered. +// Any device specific allocators and IDataTransfer objects will be registered. +// This allows CreateAllocator to return device specific allocators registered by EPs. +// FUTURE: This will also allow CopyTensors to utilize the IDataTransfer objects +// "0": Disabled. [DEFAULT] +// "1": Enable Model Builder Session +static const char* const kOrtSessionOptionsEnableModelBuilder = "session.model_builder_session"; diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.cc b/onnxruntime/core/framework/onnxruntime_typeinfo.cc index a884927abddb7..91383425f16d9 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.cc +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.cc @@ -40,7 +40,7 @@ OrtTypeInfo::OrtTypeInfo(std::unique_ptr optional_type_info : type(ONNX_TYPE_OPTIONAL), optional_type_info(std::move(optional_type_info)) {} OrtTypeInfo::OrtTypeInfo(ONNXType type, std::unique_ptr data) noexcept - : type(type), data(std::move(data)) { + : type(type), tensor_type_info(std::move(data)) { } OrtTypeInfo::~OrtTypeInfo() = default; @@ -55,7 +55,9 @@ ORT_API_STATUS_IMPL(OrtApis::GetOnnxTypeFromTypeInfo, _In_ const struct OrtTypeI ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToTensorInfo, _In_ const struct OrtTypeInfo* input, _Outptr_result_maybenull_ const struct OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN - *out = (input->type == ONNX_TYPE_TENSOR || input->type == ONNX_TYPE_SPARSETENSOR) ? input->data.get() : nullptr; + *out = (input->type == ONNX_TYPE_TENSOR || input->type == ONNX_TYPE_SPARSETENSOR) + ? input->tensor_type_info.get() + : nullptr; return nullptr; API_IMPL_END } @@ -84,8 +86,8 @@ ORT_API_STATUS_IMPL(OrtApis::CastTypeInfoToOptionalTypeInfo, _In_ const OrtTypeI API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ const char** const out, - _Out_ size_t* len) { +ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* type_info, + _Out_ const char** const out, _Out_ size_t* len) { API_IMPL_BEGIN *out = type_info->denotation.c_str(); *len = type_info->denotation.size(); @@ -93,6 +95,59 @@ ORT_API_STATUS_IMPL(OrtApis::GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::CreateTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Out_ OrtTypeInfo** type_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_TENSOR); + ti->tensor_type_info = tensor_info->Clone(); + *type_info = ti.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::CreateSparseTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Out_ OrtTypeInfo** type_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_SPARSETENSOR); + ti->tensor_type_info = tensor_info->Clone(); + *type_info = ti.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, + _In_ const OrtTypeInfo* map_value_type, _Out_ OrtTypeInfo** type_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_MAP); + ti->map_type_info = std::make_unique(map_key_type, map_value_type->Clone()); + *type_info = ti.release(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type, + _Out_ OrtTypeInfo** type_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_SEQUENCE); + ti->sequence_type_info = std::make_unique(sequence_type->Clone()); + *type_info = ti.release(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type, + _Out_ OrtTypeInfo** type_info) { + API_IMPL_BEGIN + auto ti = std::make_unique(ONNXType::ONNX_TYPE_OPTIONAL); + ti->optional_type_info = std::make_unique(contained_type->Clone()); + *type_info = ti.release(); + + return nullptr; + API_IMPL_END +} + ORT_API(void, OrtApis::ReleaseTypeInfo, _Frees_ptr_opt_ OrtTypeInfo* ptr) { std::unique_ptr p(ptr); } @@ -298,8 +353,8 @@ std::unique_ptr OrtTypeInfo::Clone() const { #endif case ONNX_TYPE_TENSOR: { std::unique_ptr info; - if (data) { - info = data->Clone(); + if (tensor_type_info) { + info = tensor_type_info->Clone(); } result = MakePtr(type, std::move(info)); result->denotation = denotation; diff --git a/onnxruntime/core/framework/onnxruntime_typeinfo.h b/onnxruntime/core/framework/onnxruntime_typeinfo.h index 72d263d5fa442..54bb946e0d36b 100644 --- a/onnxruntime/core/framework/onnxruntime_typeinfo.h +++ b/onnxruntime/core/framework/onnxruntime_typeinfo.h @@ -31,7 +31,7 @@ struct OrtTypeInfo { ONNXType type; std::string denotation; - std::unique_ptr data; + std::unique_ptr tensor_type_info; std::unique_ptr map_type_info; std::unique_ptr sequence_type_info; std::unique_ptr optional_type_info; diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 83a353615bc35..2d4034991cc3a 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -203,13 +203,12 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st } } -common::Status AllocateTensor( - const onnxruntime::MemBuffer* m, - std::unique_ptr& p_tensor, - const onnxruntime::DataTypeImpl* const& type, - onnxruntime::TensorShape& tensor_shape, - bool use_device_allocator_for_initializers, - const onnxruntime::AllocatorPtr& alloc) { +common::Status AllocateTensor(const onnxruntime::MemBuffer* m, + std::unique_ptr& p_tensor, + const onnxruntime::DataTypeImpl* const& type, + onnxruntime::TensorShape& tensor_shape, + bool use_device_allocator_for_initializers, + const onnxruntime::AllocatorPtr& alloc) { if (m != nullptr) { p_tensor = std::make_unique(type, tensor_shape, m->GetBuffer(), m->GetAllocInfo()); if (m->GetLen() < p_tensor->SizeInBytes()) { @@ -354,6 +353,7 @@ common::Status SaveInitializedTensors( } ORT_RETURN_IF_ERROR(planner.Trace(entry.first, entry.second)); } + // 2. allocate weight buffer on different locations // planned_initializers_memory_size_in_byte is not actual physical size. // It's the virtual size computed by planner. @@ -386,6 +386,9 @@ common::Status SaveInitializedTensors( if (user_supplied_initializer_ids.find(entry.first) != user_supplied_initializer_ids.end()) { ort_value = *(session_options.initializers_to_share_map.at(name)); LOGS(logger, INFO) << "Using user supplied initializer with name (" << name << ")."; + + } else if (graph.GetOrtValueInitializer(name, ort_value)) { + // populated OrtValue from the Graph instance } else { const ONNX_NAMESPACE::TensorProto& tensor_proto = *(entry.second); diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index 418e46924fb9f..b9225f95ce7cc 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -49,10 +49,27 @@ ORT_API_STATUS_IMPL(OrtApis::SetTensorElementType, _Inout_ OrtTensorTypeAndShape API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::SetDimensions, OrtTensorTypeAndShapeInfo* this_ptr, +ORT_API_STATUS_IMPL(OrtApis::SetDimensions, OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) { API_IMPL_BEGIN - this_ptr->shape = onnxruntime::TensorShape(dim_values, dim_count); + if (std::any_of(dim_values, dim_values + dim_count, [](int64_t v) { return v < -1; })) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "dim_values must be -1 (symbolic dimension) or larger."); + } + + auto num_dims = std::max(dim_count, info->dim_params.size()); + + // make shape and dim_values consistent + info->dim_params.resize(num_dims, ""); + + std::vector dims; + dims.resize(num_dims, -1); + + for (size_t idx = 0; idx < dim_count; ++idx) { + dims[idx] = dim_values[idx]; + } + + info->shape = onnxruntime::TensorShape(dims); + return nullptr; API_IMPL_END } @@ -88,10 +105,22 @@ ORT_API_STATUS_IMPL(OrtApis::GetSymbolicDimensions, ORT_API_STATUS_IMPL(OrtApis::SetSymbolicDimensions, _In_ struct OrtTensorTypeAndShapeInfo* info, _In_ const char** names, _In_ size_t dim_params_length) { + auto num_dims = std::max(info->shape.NumDimensions(), dim_params_length); + + // make shape and dim_values consistent + if (num_dims > info->shape.NumDimensions()) { + auto dim_values = info->shape.AsShapeVector(); + dim_values.resize(num_dims, -1); + info->shape = onnxruntime::TensorShape(dim_values); + } + info->dim_params.clear(); + info->dim_params.resize(num_dims, ""); + for (size_t idx = 0; idx < dim_params_length; ++idx) { - info->dim_params.push_back(names[idx]); + info->dim_params[idx] = names[idx]; } + return nullptr; } diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 0b6610db5e007..8d1becdb24a9f 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -16,6 +16,7 @@ #include "core/common/logging/logging.h" #include "core/common/narrow.h" #include "core/flatbuffers/flatbuffers_utils.h" +#include "core/framework/tensor_type_and_shape.h" #include "core/flatbuffers/schema/ort.fbs.h" #include "core/framework/tensor_shape.h" #include "core/framework/tensor_external_data_info.h" @@ -25,6 +26,7 @@ #include "core/graph/graph_viewer.h" #include "core/graph/indexed_sub_graph.h" #include "core/graph/model.h" +#include "core/graph/model_builder_api_types.h" #include "core/graph/model_load_utils.h" #include "core/graph/model_saving_options.h" #include "core/graph/node_attr_utils.h" @@ -50,7 +52,9 @@ namespace onnxruntime { #define NO_CHANGE_ON_SYNC_FLAG(...) \ do { \ const bool sync_needed = GraphProtoSyncNeeded(); \ - { __VA_ARGS__; } \ + { \ + __VA_ARGS__; \ + } \ GraphProtoSyncNeeded(sync_needed); \ } while (0) @@ -3496,6 +3500,11 @@ void Graph::RemoveInitializedTensor(const std::string& tensor_name) { #if !defined(DISABLE_SPARSE_TENSORS) sparse_tensor_names_.erase(tensor_name); #endif + + if (auto it = ortvalue_initializers_.find(tensor_name); it != ortvalue_initializers_.end()) { + ortvalue_initializers_.erase(it); + } + SetGraphResolveNeeded(); } else { #if !defined(DISABLE_SPARSE_TENSORS) @@ -3627,8 +3636,8 @@ Status Graph::InjectExternalInitializersFromFilesInMemory( return Status::OK(); } -#endif // DISABLE_EXTERNAL_INITIALIZERS +#endif // DISABLE_EXTERNAL_INITIALIZERS #endif // !defined(ORT_MINIMAL_BUILD) bool Graph::GetInitializedTensor(const std::string& tensor_name, const TensorProto*& value) const { @@ -3641,6 +3650,16 @@ bool Graph::GetInitializedTensor(const std::string& tensor_name, const TensorPro return true; } +bool Graph::GetOrtValueInitializer(const std::string& name, OrtValue& value) const { + auto it = ortvalue_initializers_.find(name); + if (it == ortvalue_initializers_.end()) { + return false; + } + + value = it->second; + return true; +} + void Graph::CleanAllInitializedTensors() noexcept { name_to_initial_tensor_.clear(); #if !defined(DISABLE_SPARSE_TENSORS) @@ -3655,6 +3674,8 @@ void Graph::CleanAllInitializedTensors() noexcept { for (int i = 0; i < num_cleared; i++) { delete graph_proto_->mutable_initializer()->ReleaseCleared(); } + + ortvalue_initializers_.clear(); } const ONNX_NAMESPACE::TensorProto* Graph::GetConstantInitializer(const std::string& initializer_name, @@ -3704,13 +3725,14 @@ void Graph::AddValueInfo(const NodeArg* new_value_info) { value_info_.insert(new_value_info); } -std::vector Graph::CreateNodeArgs(const google::protobuf::RepeatedPtrField& names, +template +std::vector Graph::CreateNodeArgs(const StringRange& names, const ArgNameToTypeMap& name_to_type_map) { const auto name_to_type_map_end = name_to_type_map.end(); std::vector results; results.reserve(names.size()); - for (auto& name : names) { + for (const std::string& name : names) { const TypeProto* type = nullptr; auto name_to_type_iter = name_to_type_map.find(name); @@ -5325,6 +5347,9 @@ Status Graph::InlineFunction(Node& callnode) { } void Graph::SetInputs(gsl::span inputs) { + graph_inputs_including_initializers_.clear(); + graph_inputs_excluding_initializers_.clear(); + // creating graph from scratch // rely on SetGraphInputsOutputs() to fix up graph_inputs_excluding_initializers_ // if is_loaded_from_model_file_ == false @@ -5333,7 +5358,6 @@ void Graph::SetInputs(gsl::span inputs) { if (is_loaded_from_model_file_) { // graph loaded from model file - graph_inputs_excluding_initializers_.clear(); for (const auto* input : inputs) { ORT_ENFORCE(input->Exists(), "Input to set must exist."); if (name_to_initial_tensor_.find(input->Name()) == name_to_initial_tensor_.end()) { @@ -5350,6 +5374,7 @@ void Graph::SetInputs(gsl::span inputs) { } void Graph::SetOutputs(gsl::span outputs) { + graph_outputs_.clear(); graph_outputs_.reserve(outputs.size()); graph_outputs_.assign(outputs.begin(), outputs.end()); @@ -5668,4 +5693,217 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph return Status::OK(); } +#if !defined(ORT_MINIMAL_BUILD) +namespace { +ValueInfoProto OrtValueInfoToOnnx(const OrtValueInfo& vi) { + // the model builder API checks that the OrtValueInfo has a complete and valid OrtTypeInfo instance and that the + // name is not null/empty. + ORT_ENFORCE(vi.type_info->type == ONNX_TYPE_TENSOR, + "Internal error. Model Builder API should only allow OrtValueInfo for tensor to be created."); + + ValueInfoProto value_info_proto; + value_info_proto.set_name(vi.name); + + auto* tensor = value_info_proto.mutable_type()->mutable_tensor_type(); + const OrtTensorTypeAndShapeInfo& tensor_info = *vi.type_info->tensor_type_info.get(); + tensor->set_elem_type(tensor_info.type); + + auto& shape = *tensor->mutable_shape(); + + size_t idx = 0; + for (auto dim : tensor_info.shape.GetDims()) { + auto& dim_proto = *shape.add_dim(); + if (dim >= 0) { + dim_proto.set_dim_value(dim); + } else { + const std::string& dim_param = tensor_info.dim_params[idx]; + // if empty leave the new dim_proto with neither dim_value nor dim_param set. this represents an 'unknown' dim + if (!dim_param.empty()) { + dim_proto.set_dim_param(dim_param); + } + } + } + + return value_info_proto; +} +} // namespace + +Status Graph::LoadFromModelBuilderApiModel(const OrtGraph& api_graph, bool updating_existing_graph) { + ArgNameToTypeMap name_to_type_map; + + // NOTE: need to create NodeArgs as we go along + + // add inputs first. the shape from an input for a non-const initializer is preferred, so we want to create the + // NodeArg for the value using that + + auto add_graph_inputs_outputs = [&, this](const std::vector>& graph_inputs_or_outputs, + bool is_input) { + // when updating a model we don't require the inputs or outputs to be set if they're unchanged. + if (updating_existing_graph && graph_inputs_or_outputs.empty()) { + return; + } + + std::vector node_args; + node_args.reserve(graph_inputs_or_outputs.size()); + for (auto& ort_value_info : graph_inputs_or_outputs) { + ValueInfoProto value_info = OrtValueInfoToOnnx(*ort_value_info); + + name_to_type_map[value_info.name()] = value_info.type(); + node_args.push_back(&GetOrCreateNodeArg(value_info.name(), &value_info.type())); + } + + if (is_input) { + SetInputs(node_args); + } else { + SetOutputs(node_args); + } + }; + + auto add_initializers = [this](const std::unordered_map>& initializers, + bool is_external) { + for (auto& name_and_ortvalue : initializers) { + // convert from OrtValue to TensorProto + const std::string& name = name_and_ortvalue.first; + OrtValue& v = *name_and_ortvalue.second; + + ORT_ENFORCE(v.IsTensor(), "Initializers must be Tensors"); + const Tensor& t = v.Get(); + TensorProto& tensor_proto = *graph_proto_->add_initializer(); + + tensor_proto.set_name(name); + tensor_proto.set_data_type(t.GetElementType()); + for (auto dim : t.Shape().GetDims()) { + tensor_proto.add_dims(dim); + } + + if (is_external) { + // pre-existing memory that we don't own. avoid a copy by storing the pointer in the ExternalDataInfo + tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + const void* data_offset = t.DataRaw(); // address of memory not offset into file + auto offset = narrow(reinterpret_cast(data_offset)); + + ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("location"); + // magic tag for existing memory that causes 'offset' to be treated as a pointer to the memory + entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag)); + entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("offset"); + entry->set_value(std::to_string(offset)); + entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("length"); + entry->set_value(std::to_string(t.SizeInBytes())); + + // copy OrtValue to keep it alive and to store the deleter if provided. + ortvalue_initializers_.emplace(name, v); + v = OrtValue{}; // reset as we have taken a copy + } else { + tensor_proto.set_raw_data(t.DataRaw(), t.SizeInBytes()); + } + + TypeProto type_proto{TypeProtoFromTensorProto(tensor_proto)}; + ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(name, &type_proto)); + + name_to_initial_tensor_.emplace(name, &tensor_proto); + } + }; + + // process graph inputs first as we want the type/shape from them to be preferred if a graph input + // has a matching initializer + add_graph_inputs_outputs(api_graph.inputs, /*input*/ true); + + // add initializers + ortvalue_initializers_.reserve(api_graph.external_initializers.size()); + add_initializers(api_graph.external_initializers, /*is_external*/ true); + add_initializers(api_graph.initializers, /*is_external*/ false); + + // add graph outputs + add_graph_inputs_outputs(api_graph.outputs, /*input*/ false); + + // add nodes + for (const auto& ort_node : api_graph.nodes) { + const OrtNode& node = *ort_node; + + // convert Constant nodes to initializers + if (node.operator_name == "Constant" && node.domain_name == kOnnxDomain) { + // graph_proto_ provides storage + TensorProto& tensor = *graph_proto_->add_initializer(); + + // create NodeProto from OrtNode so we can use the existing conversion functions + NodeProto node_proto; + + // 'Constant' node has no inputs or attributes + ORT_RETURN_IF_NOT(node.input_names.empty() && node.attributes.size() == 1 && node.output_names.size() == 1, + node.node_name, + " is an invalid 'Constant' node. " + "Must have no inputs, one attribute and one output. "); + + node_proto.add_attribute()->CopyFrom(node.attributes[0]); + node_proto.add_output(node.output_names[0]); + + node_proto.set_op_type(node.operator_name); + node_proto.set_name(node.node_name); + node_proto.set_domain(node.domain_name); + + ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(node_proto, /*model_path*/ "", tensor)); + name_to_initial_tensor_.emplace(node.output_names[0], &tensor); + + continue; + } + + auto input_defs = CreateNodeArgs(node.input_names, name_to_type_map); + auto output_defs = CreateNodeArgs(node.output_names, name_to_type_map); + + const auto num_attributes = node.attributes.size(); + + NodeAttributes attributes; + attributes.reserve(num_attributes); + + for (const auto& attr : node.attributes) { + attributes[attr.name()] = attr; + } + + ORT_IGNORE_RETURN_VALUE(AddNode(node.node_name, node.operator_name, /*doc_string*/ "", + input_defs, output_defs, &attributes, node.domain_name)); + } + + return Resolve(); +} + +// static +Status Graph::LoadFromModelBuilderApiModel(const OrtGraph& api_graph, + const Model& owning_model, + const std::unordered_map& domain_to_version, + IOnnxRuntimeOpSchemaCollectionPtr schema_registry, + bool strict_shape_type_inference, + const logging::Logger& logger, + std::unique_ptr& graph) { + graph = std::make_unique(owning_model, + domain_to_version, + schema_registry, + /*parent_graph*/ nullptr, /*parent_node*/ nullptr, + logger, + strict_shape_type_inference); + + return graph->LoadFromModelBuilderApiModel(api_graph); +} + +Status Graph::UpdateUsingModelBuilderApiModel(const OrtModel& api_model) { + for (auto& entry : api_model.domain_to_version) { + if (auto it = domain_to_version_.find(entry.first); it != domain_to_version_.end()) { + if (it->second != entry.second) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Domain version can not be changed for '", entry.first, + "'. Current version: ", it->second); + } + } else { + domain_to_version_.insert(entry); + } + } + + // this will replace inputs/outputs and add nodes. + return LoadFromModelBuilderApiModel(*api_model.graph, /*updating_existing_graph*/ true); +} + +#endif // !defined(ORT_MINIMAL_BUILD) } // namespace onnxruntime diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index be0531e6473fb..01ef75af4076d 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -7,6 +7,7 @@ #include "core/flatbuffers/flatbuffers_utils.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/model.h" +#include "core/graph/model_builder_api_types.h" #include "core/graph/model_load_utils.h" #ifdef _MSC_VER @@ -738,6 +739,36 @@ Status Model::Load(int fd, const PathString& model_path, std::shared_ptr& return Status::OK(); } +// static +common::Status Model::LoadFromModelBuilderApiModel(const OrtModel& model_builder_api_model, + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const ModelOptions& options, + const logging::Logger& logger, + std::unique_ptr& model) { + model = std::make_unique(); + model->model_proto_.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + // The optimizer Initializer class requires a path if external data is used, however in the Graph API usage the + // external data is pointing to pre-allocated memory and does not require a path. Set a dummy value to make it happy. + model->model_path_ = std::filesystem::path("_GRAPH_API_MODEL_"); + + auto schema_registry = std::make_shared(); + if (local_registries != nullptr) { + for (const auto& schema_collection : *local_registries) { + schema_registry->RegisterRegistry(schema_collection); + } + } + + ORT_RETURN_IF_ERROR(Graph::LoadFromModelBuilderApiModel(*model_builder_api_model.graph, + *model, + model_builder_api_model.domain_to_version, + schema_registry, + options.strict_shape_type_inference, + logger, + model->graph_)); + + return Status::OK(); +} + Status Model::Save(Model& model, int p_fd) { if (p_fd < 0) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, " is less than 0."); @@ -917,5 +948,4 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model, #endif return Status::OK(); } - } // namespace onnxruntime diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 2d2086aef41fd..dcf3197f61170 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -280,6 +280,12 @@ class Model { const logging::Logger& logger, const ModelOptions& options = {}); + static common::Status LoadFromModelBuilderApiModel(const OrtModel& graph_api_model, + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const ModelOptions& options, + const logging::Logger& logger, + std::unique_ptr& model); + common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, flatbuffers::Offset& model) const; @@ -333,7 +339,7 @@ class Model { ModelMetaData model_metadata_; // Path to model file. May be empty. - const std::filesystem::path model_path_; + std::filesystem::path model_path_; // Main graph of the model. std::unique_ptr graph_; diff --git a/onnxruntime/core/graph/model_builder_api_types.h b/onnxruntime/core/graph/model_builder_api_types.h new file mode 100644 index 0000000000000..acc29beca0d8d --- /dev/null +++ b/onnxruntime/core/graph/model_builder_api_types.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/ort_value.h" +#include "core/framework/onnxruntime_typeinfo.h" +#include "core/graph/onnx_protobuf.h" + +// ORT C interface types for OrtGraphApi can't be in a namespace. +// We need to define them here so onnxruntime::Model can be created from OrtModel. + +struct OrtValueInfo { + std::string name; + std::unique_ptr type_info; +}; + +struct OrtOpAttr { + ONNX_NAMESPACE::AttributeProto attr_proto; +}; + +struct OrtNode { + std::string operator_name; + std::string domain_name; + std::string node_name; + + // OrtOpAttr is 1:1 with ONNX_NAMESPACE::AttributeProto currently. + // https://github.com/microsoft/onnxruntime/blob/bd5a759d0cdbed6e7f611c990d4eb5457a9ecf60/onnxruntime/core/session/standalone_op_invoker.cc#L318 + // Might be better if it had a wrapper struct so we have more flexibility. + // AFAIK (TBC) that's an implementation detail so we should be able to change it. + std::vector attributes; + std::vector input_names; + std::vector output_names; + + // FUTURE if we need control flow nodes + // std::unordered_map subgraphs; +}; + +struct OrtGraph { + std::vector> inputs; + std::vector> outputs; + std::unordered_map> initializers; + std::unordered_map> external_initializers; + std::vector> nodes; +}; + +struct OrtModel { + std::unique_ptr graph; + std::unordered_map domain_to_version; +}; diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 7ef23d6c9e895..2e733f67a888c 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -1,17 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/graph/onnx_protobuf.h" -#include "core/common/inlined_containers.h" -#include "core/session/onnxruntime_c_api.h" -#include "core/session/ort_apis.h" -#include "core/framework/error_code_helper.h" -#include #include +#include #include + +#include "core/common/inlined_containers.h" +#include "core/framework/error_code_helper.h" +#include "core/graph/onnx_protobuf.h" +#include "core/session/abi_session_options_impl.h" #include "core/session/inference_session.h" -#include "abi_session_options_impl.h" -#include "api_utils.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_apis.h" +#include "core/session/utils.h" OrtSessionOptions::~OrtSessionOptions() = default; diff --git a/onnxruntime/core/session/api_utils.cc b/onnxruntime/core/session/api_utils.cc deleted file mode 100644 index f7cb8520b1e5d..0000000000000 --- a/onnxruntime/core/session/api_utils.cc +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "api_utils.h" - -onnxruntime::common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size) { - const size_t str_len = str.size(); - const size_t req_size = str_len + 1; - - if (out == nullptr) { // User is querying the total output buffer size - *size = req_size; - return onnxruntime::common::Status::OK(); - } - - if (*size >= req_size) { // User provided a buffer of sufficient size - std::memcpy(out, str.data(), str_len); - out[str_len] = '\0'; - *size = req_size; - return onnxruntime::common::Status::OK(); - } - - // User has provided a buffer that is not large enough - *size = req_size; - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, err_msg); -} diff --git a/onnxruntime/core/session/api_utils.h b/onnxruntime/core/session/api_utils.h deleted file mode 100644 index 27c2bbd66f8d5..0000000000000 --- a/onnxruntime/core/session/api_utils.h +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include - -onnxruntime::common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size); diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 33d2a0244b453..68142fed11df6 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -20,7 +20,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/graph/onnx_protobuf.h" #include "core/session/allocator_adapters.h" -#include "core/session/api_utils.h" +#include "core/session/utils.h" #include "core/session/custom_ops.h" #include "core/session/inference_session.h" #include "core/session/ort_apis.h" diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 223eed248800e..b0f8db4538807 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -38,6 +38,7 @@ #include "core/framework/utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" +#include "core/graph/model_builder_api_types.h" #include "core/graph/model_saving_options.h" #include "core/optimizer/graph_transformer_utils.h" #include "core/optimizer/graph_transformer.h" @@ -67,11 +68,11 @@ #include "core/optimizer/stft_decomposition.h" #endif #include "core/session/environment.h" -#include "core/session/user_logging_sink.h" #include "core/session/IOBinding.h" #include "core/session/inference_session_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" +#include "core/session/user_logging_sink.h" #include "core/util/protobuf_parsing_utils.h" #include "core/util/thread_utils.h" @@ -1194,6 +1195,56 @@ common::Status InferenceSession::Load() { return LoadWithLoader(loader, "model_loading_from_saved_proto"); } +common::Status InferenceSession::Load(const OrtModel& model_builder_api_model) { + std::lock_guard l(session_mutex_); + + if (is_model_loaded_) { // already loaded + Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model."); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + return status; + } + + if (is_inited_) { + Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session has already been initialized."); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + return status; + } + + const bool strict_shape_type_inference = session_options_.config_options.GetConfigOrDefault( + kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1"; + + // need to go from unique_ptr to shared_ptr when moving into model_ + std::unique_ptr tmp_model; + ORT_RETURN_IF_ERROR(Model::LoadFromModelBuilderApiModel(model_builder_api_model, + HasLocalSchema() ? &custom_schema_registries_ : nullptr, + ModelOptions(true, strict_shape_type_inference), + *session_logger_, tmp_model)); + + model_ = std::move(tmp_model); + + is_model_loaded_ = true; + + return Status::OK(); +} + +common::Status InferenceSession::ApplyUpdates(const OrtModel& model_builder_api_model) { + std::lock_guard l(session_mutex_); + + if (!is_model_loaded_) { + Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session does not contain a loaded model."); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + return status; + } + + if (is_inited_) { + Status status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session has already been initialized."); + LOGS(*session_logger_, ERROR) << status.ErrorMessage(); + return status; + } + + return model_->MainGraph().UpdateUsingModelBuilderApiModel(model_builder_api_model); +} + common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool saving_model_in_ort_format) { // The transformer order: // 1. Ensure we inline as many functions as possible. We refer to it as Ahead Of Time (AOT) function inlining. @@ -3285,6 +3336,10 @@ common::Status InferenceSession::WaitForNotification(Notification* p_executor_do return Status::OK(); } +const Model& InferenceSession::GetModel() const { + return *model_; +} + SessionIOBinding::SessionIOBinding(InferenceSession* session) : sess_(session) { ORT_ENFORCE(session->NewIOBinding(&binding_).IsOK()); } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index e28ff75345785..f89eacb633e42 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -46,6 +46,9 @@ namespace ONNX_NAMESPACE { class ModelProto; } // namespace ONNX_NAMESPACE +// OrtModelBuilderApi Model. Used to dynamically construct a model via C API at runtime. +struct OrtModel; + namespace onnxruntime { // forward declarations class CustomRegistry; class Environment; @@ -319,6 +322,27 @@ class InferenceSession { * @return OK if success. */ [[nodiscard]] common::Status Load(); + + /** + * Load an OrtModel that was dynamically constructed via OrtModelBuilderApi. + * + * @param graph_api_model OrtModel from OrtModelBuilderApi + * @return OK if success. + */ + [[nodiscard]] common::Status Load(const OrtModel& graph_api_model); + + /** + * Apply updates from an OrtModel that was created via OrtModelBuilderApi. + * This can: + * - add nodes at the start and end of the model + * - add initializers + * - update the graph inputs/outputs + * + * @param graph_api_model OrtModel from OrtModelBuilderApi + * @return OK if success. + */ + [[nodiscard]] common::Status ApplyUpdates(const OrtModel& graph_api_model); + #endif // !defined(ORT_MINIMAL_BUILD) /** @@ -545,6 +569,8 @@ class InferenceSession { */ Status AddPrePackedWeightsContainer(PrepackedWeightsContainer* prepacked_weights_container); + const Model& GetModel() const; + protected: #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/model_builder_api.h b/onnxruntime/core/session/model_builder_api.h new file mode 100644 index 0000000000000..b27888fe418ba --- /dev/null +++ b/onnxruntime/core/session/model_builder_api.h @@ -0,0 +1,59 @@ +namespace OrtModelBuilderAPI { + +// implementation that returns the API struct +ORT_API(const OrtModelBuilderApi*, GetModelBuilderApi); + +ORT_API_STATUS_IMPL(CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info, + _Outptr_ OrtValueInfo** value_info); +ORT_API_STATUS_IMPL(GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name); +ORT_API_STATUS_IMPL(GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info); +ORT_API(void, ReleaseValueInfo, _Frees_ptr_opt_ OrtValueInfo* value_info); + +ORT_API_STATUS_IMPL(CreateNode, const char* operator_name, const char* domain_name, _In_ const char* node_name, + _In_reads_(input_names_len) const char* const* input_names, size_t input_names_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, + _In_reads_(attribs_len) _Inout_opt_ OrtOpAttr** attributes, _In_opt_ size_t attribs_len, + _Outptr_ OrtNode** node); +ORT_API(void, ReleaseNode, _Frees_ptr_opt_ OrtNode* node); + +ORT_API_STATUS_IMPL(CreateGraph, _Outptr_ OrtGraph** graph); +ORT_API_STATUS_IMPL(SetGraphInputs, _In_ OrtGraph* graph, + _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len); +ORT_API_STATUS_IMPL(SetGraphOutputs, _In_ OrtGraph* graph, + _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len); +ORT_API_STATUS_IMPL(AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue* tensor, + bool data_is_external); +ORT_API_STATUS_IMPL(AddNodeToGraph, _In_ OrtGraph* graph, _Inout_ OrtNode* node); +ORT_API(void, ReleaseGraph, _Frees_ptr_opt_ OrtGraph* graph); + +ORT_API_STATUS_IMPL(CreateModel, + _In_reads_(opset_entries_len) const char* const* domain_names, + _In_reads_(opset_entries_len) const int* opset_versions, + size_t opset_entries_len, + _Outptr_ OrtModel** model); +ORT_API_STATUS_IMPL(AddGraphToModel, _In_ OrtModel* model, _Inout_ OrtGraph* graph); +ORT_API(void, ReleaseModel, _Frees_ptr_opt_ OrtModel* model); + +// TODO Do we need this, or could we use CreateModelBuilder with nullptr for model_path? +ORT_API_STATUS_IMPL(CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const OrtModel* model, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); + +// +// Model editing APIs for updating existing model. +// +ORT_API_STATUS_IMPL(CreateModelBuilderSession, _In_ const OrtEnv* env, + _In_ const ORTCHAR_T* model_path, + _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out); + +ORT_API_STATUS_IMPL(CreateModelBuilderSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out); + +ORT_API_STATUS_IMPL(ApplyModelToModelBuilderSession, _In_ OrtSession* session, _In_ OrtModel* model); + +ORT_API_STATUS_IMPL(FinalizeModelBuilderSession, _In_ OrtSession* session, _In_ const OrtSessionOptions* options, + _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container); + +} // namespace OrtModelBuilderAPI diff --git a/onnxruntime/core/session/model_builder_c_api.cc b/onnxruntime/core/session/model_builder_c_api.cc new file mode 100644 index 0000000000000..25e2409805c74 --- /dev/null +++ b/onnxruntime/core/session/model_builder_c_api.cc @@ -0,0 +1,347 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/framework/error_code_helper.h" +#include "core/framework/ort_value.h" +#include "core/framework/onnxruntime_typeinfo.h" +#include "core/framework/tensor_type_and_shape.h" +#include "core/graph/constants.h" +#include "core/graph/model_builder_api_types.h" +#include "core/graph/onnx_protobuf.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/inference_session.h" +#include "core/session/model_builder_api.h" +#include "core/session/ort_apis.h" +#include "core/session/ort_env.h" +#include "core/session/utils.h" + +using namespace onnxruntime; + +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info, + _Outptr_ OrtValueInfo** value_info) { + API_IMPL_BEGIN + if (name == nullptr || *name == '\0') { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "name cannot be null or empty string"); + } + + if (type_info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "type_info cannot be null"); + } + + if (type_info->type != ONNX_TYPE_TENSOR) { + return OrtApis::CreateStatus(ORT_FAIL, "Only tensor types are supported currently"); + } + + if (type_info->tensor_type_info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tensor_type_info cannot be null"); + } + + auto vi = std::make_unique(); + vi->name = name; + vi->type_info = type_info->Clone(); + + *value_info = vi.release(); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name) { + API_IMPL_BEGIN + *name = value_info->name.c_str(); + return nullptr; + API_IMPL_END +} +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info) { + API_IMPL_BEGIN + + *type_info = value_info->type_info.get(); + + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtModelBuilderAPI::ReleaseValueInfo, _Frees_ptr_opt_ OrtValueInfo* value_info) { + delete value_info; +} + +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::CreateNode, const char* operator_name, const char* domain_name, + _In_ const char* node_name, + _In_reads_(input_names_len) const char* const* input_names, size_t input_names_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, + _In_reads_(attribs_len) _Inout_opt_ OrtOpAttr** attributes, _In_opt_ size_t attribs_len, + _Outptr_ OrtNode** node) { + API_IMPL_BEGIN + auto n = std::make_unique(); + n->operator_name = operator_name; + n->domain_name = domain_name == kOnnxDomainAlias ? kOnnxDomain : domain_name; + n->node_name = node_name; + + n->input_names.reserve(input_names_len); + for (size_t i = 0; i < input_names_len; ++i) { + n->input_names.push_back(input_names[i]); + } + + n->output_names.reserve(output_names_len); + for (size_t i = 0; i < output_names_len; ++i) { + n->output_names.push_back(output_names[i]); + } + + if (attributes != nullptr) { + n->attributes.reserve(attribs_len); + for (size_t i = 0; i < attribs_len; ++i) { + n->attributes.push_back(*reinterpret_cast(attributes[i])); + } + } + + *node = n.release(); + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtModelBuilderAPI::ReleaseNode, _Frees_ptr_opt_ OrtNode* node) { + delete node; +} + +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::CreateGraph, _Outptr_ OrtGraph** graph) { + API_IMPL_BEGIN + auto g = std::make_unique(); + + // do some reserves to reduce reallocation. if we had a hint about sizes upfront that would be optimal + g->inputs.reserve(8); + g->outputs.reserve(8); + g->initializers.reserve(32); + g->external_initializers.reserve(32); + g->nodes.reserve(64); + + *graph = g.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::SetGraphInputs, _In_ OrtGraph* graph, + _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len) { + API_IMPL_BEGIN + for (size_t i = 0; i < inputs_len; ++i) { + if (inputs[i] == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "inputs cannot contain null entries"); + } + + graph->inputs.push_back(std::unique_ptr(inputs[i])); // take ownership + inputs[i] = nullptr; + } + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::SetGraphOutputs, _In_ OrtGraph* graph, + _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len) { + API_IMPL_BEGIN + for (size_t i = 0; i < outputs_len; ++i) { + if (outputs[i] == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "outputs cannot contain null entries"); + } + + graph->outputs.push_back(std::unique_ptr(outputs[i])); // take ownership + outputs[i] = nullptr; + } + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, + _Inout_ OrtValue* tensor, bool data_is_external) { + API_IMPL_BEGIN + if (data_is_external) { +#if !defined(DISABLE_EXTERNAL_INITIALIZERS) + graph->external_initializers[name] = std::unique_ptr(tensor); // take ownership +#else + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "External initializers are not supported in this build"); +#endif + } else { + graph->initializers[name] = std::unique_ptr(tensor); // take ownership + } + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::AddNodeToGraph, _In_ OrtGraph* graph, _Inout_ OrtNode* node) { + API_IMPL_BEGIN + graph->nodes.push_back(std::unique_ptr(node)); // take ownership + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtModelBuilderAPI::ReleaseGraph, _Frees_ptr_opt_ OrtGraph* graph) { + delete graph; +} + +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::CreateModel, + _In_reads_(opset_entries_len) const char* const* domain_names, + _In_reads_(opset_entries_len) const int* opset_versions, + size_t opset_entries_len, + _Outptr_ OrtModel** model) { + API_IMPL_BEGIN + auto m = std::make_unique(); + for (size_t i = 0; i < opset_entries_len; ++i) { + m->domain_to_version[domain_names[i]] = opset_versions[i]; + } + + *model = m.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::AddGraphToModel, _In_ OrtModel* model, _Inout_ OrtGraph* graph) { + API_IMPL_BEGIN + + if (graph == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "graph cannot be null"); + } + + model->graph = std::unique_ptr(graph); // take ownership + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtModelBuilderAPI::ReleaseModel, _Frees_ptr_opt_ OrtModel* model) { + delete model; +} + +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const OrtModel* model, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) { + API_IMPL_BEGIN + + std::unique_ptr sess; + OrtStatus* status = nullptr; + *out = nullptr; + + ORT_TRY { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env->GetEnvironment()); + + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(*model)); + + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); + + *out = reinterpret_cast(sess.release()); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = OrtApis::CreateStatus(ORT_FAIL, e.what()); + }); + } + + return status; + + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::CreateModelBuilderSession, + _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out) { + API_IMPL_BEGIN + std::unique_ptr session; + OrtStatus* status = nullptr; + *out = nullptr; + + ORT_TRY { + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, session)); + *out = reinterpret_cast(session.release()); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = OrtApis::CreateStatus(ORT_FAIL, e.what()); + }); + } + + return status; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::CreateModelBuilderSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, + _Outptr_ OrtSession** out) { + API_IMPL_BEGIN + std::unique_ptr session; + OrtStatus* status = nullptr; + *out = nullptr; + + ORT_TRY { + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, session)); + *out = reinterpret_cast(session.release()); + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = OrtApis::CreateStatus(ORT_FAIL, e.what()); + }); + } + + return status; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::ApplyModelToModelBuilderSession, + _In_ OrtSession* session, _In_ OrtModel* model) { + API_IMPL_BEGIN + auto sess = reinterpret_cast(session); + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->ApplyUpdates(*model)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::FinalizeModelBuilderSession, _In_ OrtSession* session, + _In_ const OrtSessionOptions* options, + _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container) { + API_IMPL_BEGIN + auto sess = reinterpret_cast(session); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess, prepacked_weights_container)); + return nullptr; + API_IMPL_END +} + +static constexpr OrtModelBuilderApi ort_graph_api = { + // NOTE: The C# bindings depend on the API order within this struct so all additions must be at the end, + // and no functions can be removed (the implementation needs to change to return an error). + &OrtModelBuilderAPI::CreateValueInfo, + &OrtModelBuilderAPI::GetValueInfoName, + &OrtModelBuilderAPI::GetValueInfoTypeInfo, + &OrtModelBuilderAPI::ReleaseValueInfo, + + &OrtModelBuilderAPI::CreateNode, + &OrtModelBuilderAPI::ReleaseNode, + + &OrtModelBuilderAPI::CreateGraph, + &OrtModelBuilderAPI::SetGraphInputs, + &OrtModelBuilderAPI::SetGraphOutputs, + &OrtModelBuilderAPI::AddInitializerToGraph, + &OrtModelBuilderAPI::AddNodeToGraph, + &OrtModelBuilderAPI::ReleaseGraph, + + &OrtModelBuilderAPI::CreateModel, + &OrtModelBuilderAPI::AddGraphToModel, + &OrtModelBuilderAPI::ReleaseModel, + + &OrtModelBuilderAPI::CreateSessionFromModel, + + &OrtModelBuilderAPI::CreateModelBuilderSession, + &OrtModelBuilderAPI::CreateModelBuilderSessionFromArray, + &OrtModelBuilderAPI::ApplyModelToModelBuilderSession, + &OrtModelBuilderAPI::FinalizeModelBuilderSession, +}; + +// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned +static_assert(offsetof(OrtModelBuilderApi, FinalizeModelBuilderSession) / sizeof(void*) == 19, + "Size of version 21 API cannot change"); // initial version in ORT 1.21 + +ORT_API(const OrtModelBuilderApi*, OrtModelBuilderAPI::GetModelBuilderApi) { + // No constraints on the API version yet. + return &ort_graph_api; +} diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index ca6950af0227a..706ddadb4418b 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1,45 +1,47 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/onnxruntime_c_api.h" -#include "core/session/allocator_adapters.h" -#include "core/session/inference_session_utils.h" -#include "core/session/IOBinding.h" -#include "core/framework/allocator.h" -#include "core/framework/error_code_helper.h" -#include "core/framework/execution_provider.h" -#include "core/framework/tensor_type_and_shape.h" -#include "core/framework/utils.h" #include #include #include +#include #include #include "core/common/common.h" #include "core/common/logging/logging.h" #include "core/common/narrow.h" -#include "core/common/status.h" #include "core/common/safeint.h" -#include "core/graph/constants.h" -#include "core/graph/graph.h" +#include "core/common/status.h" +#include "core/common/string_helper.h" #include "core/framework/allocator.h" -#include "core/framework/tensor.h" +#include "core/framework/allocator.h" +#include "core/framework/callback.h" +#include "core/framework/data_types.h" +#include "core/framework/error_code_helper.h" +#include "core/framework/execution_provider.h" +#include "core/framework/onnxruntime_typeinfo.h" #include "core/framework/ort_value.h" +#include "core/framework/tensor.h" +#include "core/framework/tensor_type_and_shape.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/TensorSeq.h" +#include "core/framework/utils.h" +#include "core/graph/constants.h" +#include "core/graph/graph.h" +#include "core/graph/model.h" #include "core/providers/get_execution_providers.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/allocator_adapters.h" #include "core/session/environment.h" -#include "core/framework/callback.h" -#include "core/framework/tensorprotoutils.h" -#include "core/framework/onnxruntime_typeinfo.h" #include "core/session/inference_session.h" +#include "core/session/inference_session_utils.h" +#include "core/session/IOBinding.h" +#include "core/session/lora_adapters.h" +#include "core/session/model_builder_api.h" +#include "core/session/onnxruntime_c_api.h" #include "core/session/ort_apis.h" #include "core/session/ort_env.h" -#include "core/framework/data_types.h" -#include "abi_session_options_impl.h" -#include "core/framework/TensorSeq.h" -#include -#include "core/common/string_helper.h" - -#include "core/session/lora_adapters.h" +#include "core/session/utils.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_factory.h" @@ -114,6 +116,72 @@ using namespace onnxruntime; auto v = (value); \ auto tensor = v->GetMutable(); +namespace { +// Create tensor. Allocates memory. Tensor owns memory. Allocator is wrapped and stored in a shared_ptr in Tensor. +ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, + OrtAllocator* allocator, OrtValue& value) { + TensorShape tensor_shape(shape, shape_len); + AllocatorPtr alloc_ptr = std::make_shared(allocator); + Tensor::InitOrtValue(ml_type, tensor_shape, std::move(alloc_ptr), value); + return nullptr; +} + +// Create Tensor with existing data. Tensor does not own memory. +ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, + const int64_t* shape, size_t shape_len, + const OrtMemoryInfo* info, + void* p_data, size_t p_data_len, + OrtValue& ort_value) { + TensorShape tensor_shape(shape, shape_len); + if (std::any_of(tensor_shape.GetDims().begin(), tensor_shape.GetDims().end(), [](int64_t v) { return v < 0; })) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape"); + } + + size_t size_to_allocate = 0; + Status status = Tensor::CalculateTensorStorageSize(ml_type, tensor_shape, 0 /*alignment*/, size_to_allocate); + if (!status.IsOK()) { + return ToOrtStatus(status); + } + if (size_to_allocate > p_data_len) { + std::ostringstream oss; + oss << "not enough space: expected " << size_to_allocate << ", got " << p_data_len; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); + } + + Tensor::InitOrtValue(ml_type, tensor_shape, p_data, *info, ort_value); + return nullptr; +} + +ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, + const int64_t* shape, size_t shape_len, + OrtAllocator* deleter, + void* p_data, size_t p_data_len, + OrtValue& ort_value) { + TensorShape tensor_shape(shape, shape_len); + if (std::any_of(tensor_shape.GetDims().begin(), tensor_shape.GetDims().end(), [](int64_t v) { return v < 0; })) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape"); + } + + size_t size_to_allocate = 0; + Status status = Tensor::CalculateTensorStorageSize(ml_type, tensor_shape, 0 /*alignment*/, size_to_allocate); + + if (!status.IsOK()) { + return ToOrtStatus(status); + } + + if (size_to_allocate > p_data_len) { + std::ostringstream oss; + oss << "p_data_len was smaller than expected. Expected:" << size_to_allocate << " Got:" << p_data_len; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); + } + + AllocatorPtr alloc_ptr = std::make_shared(deleter); + Tensor::InitOrtValue(ml_type, tensor_shape, p_data, std::move(alloc_ptr), ort_value); + return nullptr; +} + +} // namespace + ORT_API_STATUS_IMPL(OrtApis::CreateEnvWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, OrtLoggingLevel logging_level, _In_ const char* logid, _Outptr_ OrtEnv** out) { @@ -187,50 +255,6 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateEnvWithCustomLogLevel, _In_ OrtEnv* ort_env, API_IMPL_END } -ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, - _Inout_ OrtAllocator* allocator, OrtValue& value) { - TensorShape tensor_shape(shape, shape_len); - AllocatorPtr alloc_ptr = std::make_shared(allocator); - Tensor::InitOrtValue(ml_type, tensor_shape, std::move(alloc_ptr), value); - return nullptr; -} - -ORT_STATUS_PTR CreateTensorImplForSeq(MLDataType elem_type, const int64_t* shape, size_t shape_len, Tensor& out) { - OrtAllocator* allocator; - // TODO(pranav): what allocator should be used to create the tensor here? - // for the sake of simplicity of the API using the default one here - ORT_API_RETURN_IF_ERROR(OrtApis::GetAllocatorWithDefaultOptions(&allocator)); - AllocatorPtr alloc_ptr = std::make_shared(allocator); - TensorShape tensor_shape(shape, shape_len); - out = Tensor(elem_type, tensor_shape, std::move(alloc_ptr)); - return nullptr; -} - -/** - * - * this function will create a copy of the allocator info - */ -ORT_STATUS_PTR CreateTensorImpl(MLDataType ml_type, const int64_t* shape, size_t shape_len, const OrtMemoryInfo* info, - void* p_data, size_t p_data_len, OrtValue& ort_value) { - TensorShape tensor_shape(shape, shape_len); - if (std::any_of(tensor_shape.GetDims().begin(), tensor_shape.GetDims().end(), [](int64_t v) { return v < 0; })) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "tried creating tensor with negative value in shape"); - } - - size_t size_to_allocate = 0; - Status status = Tensor::CalculateTensorStorageSize(ml_type, tensor_shape, 0 /*alignment*/, size_to_allocate); - if (!status.IsOK()) { - return ToOrtStatus(status); - } - if (size_to_allocate > p_data_len) { - std::ostringstream oss; - oss << "not enough space: expected " << size_to_allocate << ", got " << p_data_len; - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, oss.str().c_str()); - } - Tensor::InitOrtValue(ml_type, tensor_shape, p_data, *info, ort_value); - return nullptr; -} - ORT_API_STATUS_IMPL(OrtApis::CreateTensorWithDataAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data, size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out) { @@ -243,6 +267,20 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTensorWithDataAsOrtValue, _In_ const OrtMemor API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator* deleter, + _In_ void* p_data, size_t p_data_len, + _In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, + _Outptr_ OrtValue** out) { + API_IMPL_BEGIN + auto ml_type = DataTypeImpl::TensorTypeFromONNXEnum(type)->GetElementType(); + auto value = std::make_unique(); + ORT_API_RETURN_IF_ERROR(CreateTensorImpl(ml_type, shape, shape_len, deleter, p_data, p_data_len, *value)); + *out = value.release(); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::CreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out) { @@ -678,97 +716,6 @@ ORT_API_STATUS_IMPL(OrtApis::EnableOrtCustomOps, _Inout_ OrtSessionOptions* opti API_IMPL_END } -namespace { -// provider either model_path, or modal_data + model_data_length. -static ORT_STATUS_PTR CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, - _In_ const OrtEnv* env, - _In_opt_z_ const ORTCHAR_T* model_path, - _In_opt_ const void* model_data, - size_t model_data_length, - std::unique_ptr& sess) { - // quick check here to decide load path. InferenceSession will provide error message for invalid values. - // TODO: Could move to a helper - const Env& os_env = Env::Default(); // OS environment (!= ORT environment) - bool load_config_from_model = - os_env.GetEnvironmentVar(inference_session_utils::kOrtLoadConfigFromModelEnvVar) == "1"; - - if (load_config_from_model) { -#if !defined(ORT_MINIMAL_BUILD) - if (model_path != nullptr) { - sess = std::make_unique( - options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment(), - model_path); - } else { - sess = std::make_unique( - options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment(), - model_data, static_cast(model_data_length)); - } -#else - return OrtApis::CreateStatus(ORT_FAIL, "Loading config from ONNX models is not supported in this build."); -#endif - } else { - sess = std::make_unique( - options == nullptr ? onnxruntime::SessionOptions() : options->value, - env->GetEnvironment()); - } - -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) - // Add custom domains - if (options && !options->custom_op_domains_.empty()) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(options->custom_op_domains_)); - } -#endif - - // Finish load - if (load_config_from_model) { -#if !defined(ORT_MINIMAL_BUILD) - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load()); -#endif - } else { - if (model_path != nullptr) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_path)); - } else { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_data, static_cast(model_data_length))); - } - } - - return nullptr; -} - -static ORT_STATUS_PTR InitializeSession(_In_ const OrtSessionOptions* options, - _In_ std::unique_ptr<::onnxruntime::InferenceSession>& sess, - _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr) { - // we need to disable mem pattern if DML is one of the providers since DML doesn't have the concept of - // byte addressable memory - std::vector> provider_list; - if (options) { - for (auto& factory : options->provider_factories) { - auto provider = factory->CreateProvider(); - provider_list.push_back(std::move(provider)); - } - } - - // register the providers - for (auto& provider : provider_list) { - if (provider) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->RegisterExecutionProvider(std::move(provider))); - } - } - - if (prepacked_weights_container != nullptr) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddPrePackedWeightsContainer( - reinterpret_cast(prepacked_weights_container))); - } - - ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Initialize()); - - return nullptr; -} - -} // namespace - ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) { API_IMPL_BEGIN @@ -778,7 +725,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const O ORT_TRY { ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); *out = reinterpret_cast(sess.release()); } @@ -801,7 +748,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In ORT_TRY { ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); *out = reinterpret_cast(sess.release()); } @@ -1208,7 +1155,6 @@ ORT_API_STATUS_IMPL(OrtApis::GetResizedStringTensorElementBuffer, _Inout_ OrtVal } namespace { - OrtStatusPtr GetTensorStringSpan(const ::OrtValue& v, gsl::span& span) { if (!v.IsAllocated()) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtValue should contain a Tensor or a Sparse Tensor"); @@ -1376,6 +1322,20 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetOverridableInitializerCount, _In_ const O return GetNodeDefListCountHelper(sess, get_overridable_initializers_fn, out); } +ORT_API_STATUS_IMPL(OrtApis::SessionGetOpsetForDomain, _In_ const OrtSession* ort_session, _In_ const char* domain, + _Out_ int* opset) { + const auto& session = *reinterpret_cast(ort_session); + const auto& domain_opset_map = session.GetModel().MainGraph().DomainToVersionMap(); + + auto it = domain_opset_map.find(domain); + if (it == domain_opset_map.cend()) { + return OrtApis::CreateStatus(ORT_FAIL, "Domain not used by model."); + } + + *opset = it->second; + return nullptr; +} + static ORT_STATUS_PTR GetNodeDefTypeInfoHelper(const OrtSession* sess, GetDefListFn get_fn, size_t index, _Outptr_ struct OrtTypeInfo** out) { API_IMPL_BEGIN @@ -2112,7 +2072,6 @@ ORT_API_STATUS_IMPL(OrtApis::GetOpaqueValue, _In_ const char* domain_name, _In_ } namespace { - struct ProviderBuffer { char** buffer_; char* next_write_; @@ -2342,7 +2301,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionWithPrepackedWeightsContainer, _In_ co ORT_TRY { ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess, prepacked_weights_container)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess, prepacked_weights_container)); *out = reinterpret_cast(sess.release()); } @@ -2368,7 +2327,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArrayWithPrepackedWeightsContainer ORT_TRY { ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, sess, prepacked_weights_container)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess, prepacked_weights_container)); *out = reinterpret_cast(sess.release()); } @@ -2419,13 +2378,21 @@ ORT_API(const OrtTrainingApi*, OrtApis::GetTrainingApi, uint32_t version) { version, ORT_API_VERSION); return nullptr; #else - ORT_UNUSED_PARAMETER(version); return nullptr; #endif } +ORT_API(const OrtModelBuilderApi*, OrtApis::GetModelBuilderApi) { +#if !defined(ORT_MINIMAL_BUILD) + return OrtModelBuilderAPI::GetModelBuilderApi(); +#else + fprintf(stderr, "The Model Builder API is not supported in a minimal build.\n"); + return nullptr; +#endif +} + static constexpr OrtApiBase ort_api_base = { &OrtApis::GetApi, &OrtApis::GetVersionString}; @@ -2812,6 +2779,18 @@ static constexpr OrtApi ort_api_1_to_21 = { &OrtApis::SetEpDynamicOptions, // End of Version 20 - DO NOT MODIFY ABOVE (see above text for more information) + + &OrtApis::GetModelBuilderApi, + + &OrtApis::CreateTensorWithDataAndDeleterAsOrtValue, + &OrtApis::SessionGetOpsetForDomain, + + // APIs to create/edit type info when building/modifying a model using the Model Builder API + &OrtApis::CreateTensorTypeInfo, + &OrtApis::CreateSparseTensorTypeInfo, + &OrtApis::CreateMapTypeInfo, + &OrtApis::CreateSequenceTypeInfo, + &OrtApis::CreateOptionalTypeInfo, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 52d3c98d526dc..0f2fbf8b31f12 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -533,4 +533,25 @@ ORT_API_STATUS_IMPL(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* optio ORT_API_STATUS_IMPL(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys, _In_reads_(kv_len) const char* const* values, _In_ size_t kv_len); + +ORT_API(const OrtModelBuilderApi*, GetModelBuilderApi); + +ORT_API_STATUS_IMPL(CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator* deleter, + _In_ void* p_data, size_t p_data_len, + _In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, + _Outptr_ OrtValue** out); + +ORT_API_STATUS_IMPL(SessionGetOpsetForDomain, _In_ const OrtSession* session, _In_ const char* domain, + _Out_ int* opset); + +// APIs to create/edit type info +ORT_API_STATUS_IMPL(CreateTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Out_ OrtTypeInfo** type_info); +ORT_API_STATUS_IMPL(CreateSparseTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Out_ OrtTypeInfo** type_info); +ORT_API_STATUS_IMPL(CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, _In_ const OrtTypeInfo* map_value_type, + _Out_ OrtTypeInfo** type_info); +ORT_API_STATUS_IMPL(CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type, _Out_ OrtTypeInfo** type_info); +ORT_API_STATUS_IMPL(CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type, _Out_ OrtTypeInfo** type_info); } // namespace OrtApis diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc new file mode 100644 index 0000000000000..afb1ed2696c9f --- /dev/null +++ b/onnxruntime/core/session/utils.cc @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/utils.h" + +#include "core/framework/error_code_helper.h" +#include "core/framework/execution_provider.h" +#include "core/session/abi_session_options_impl.h" +// #include "core/session/environment.h" +#include "core/session/inference_session.h" +#include "core/session/inference_session_utils.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_apis.h" +#include "core/session/ort_env.h" + +using namespace onnxruntime; + +common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size) { + const size_t str_len = str.size(); + const size_t req_size = str_len + 1; + + if (out == nullptr) { // User is querying the total output buffer size + *size = req_size; + return onnxruntime::common::Status::OK(); + } + + if (*size >= req_size) { // User provided a buffer of sufficient size + std::memcpy(out, str.data(), str_len); + out[str_len] = '\0'; + *size = req_size; + return onnxruntime::common::Status::OK(); + } + + // User has provided a buffer that is not large enough + *size = req_size; + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, err_msg); +} + +// provider either model_path, or modal_data + model_data_length. +OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, + _In_ const OrtEnv* env, + _In_opt_z_ const ORTCHAR_T* model_path, + _In_opt_ const void* model_data, + size_t model_data_length, + std::unique_ptr& sess) { + // quick check here to decide load path. InferenceSession will provide error message for invalid values. + // TODO: Could move to a helper + const Env& os_env = Env::Default(); // OS environment (!= ORT environment) + bool load_config_from_model = + os_env.GetEnvironmentVar(inference_session_utils::kOrtLoadConfigFromModelEnvVar) == "1"; + + if (load_config_from_model) { +#if !defined(ORT_MINIMAL_BUILD) + if (model_path != nullptr) { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env->GetEnvironment(), + model_path); + } else { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env->GetEnvironment(), + model_data, static_cast(model_data_length)); + } +#else + return OrtApis::CreateStatus(ORT_FAIL, "Loading config from ONNX models is not supported in this build."); +#endif + } else { + sess = std::make_unique( + options == nullptr ? onnxruntime::SessionOptions() : options->value, + env->GetEnvironment()); + } + +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + // Add custom domains + if (options && !options->custom_op_domains_.empty()) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(options->custom_op_domains_)); + } +#endif + + // Finish load + if (load_config_from_model) { +#if !defined(ORT_MINIMAL_BUILD) + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load()); +#endif + } else { + if (model_path != nullptr) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_path)); + } else { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->Load(model_data, static_cast(model_data_length))); + } + } + + return nullptr; +} + +OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, + _In_ onnxruntime::InferenceSession& sess, + _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container) { + // we need to disable mem pattern if DML is one of the providers since DML doesn't have the concept of + // byte addressable memory + std::vector> provider_list; + if (options) { + for (auto& factory : options->provider_factories) { + auto provider = factory->CreateProvider(); + provider_list.push_back(std::move(provider)); + } + } + + // register the providers + for (auto& provider : provider_list) { + if (provider) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess.RegisterExecutionProvider(std::move(provider))); + } + } + + if (prepacked_weights_container != nullptr) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess.AddPrePackedWeightsContainer( + reinterpret_cast(prepacked_weights_container))); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(sess.Initialize()); + + return nullptr; +} diff --git a/onnxruntime/core/session/utils.h b/onnxruntime/core/session/utils.h new file mode 100644 index 0000000000000..ac8ad60758b5b --- /dev/null +++ b/onnxruntime/core/session/utils.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/common/common.h" +#include "core/session/onnxruntime_c_api.h" + +onnxruntime::common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size); + +struct OrtSessionOptions; +struct OrtStatus; +struct OrtPrepackedWeightsContainer; +namespace onnxruntime { +class InferenceSession; +} + +OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, + _In_ const OrtEnv* env, + _In_opt_z_ const ORTCHAR_T* model_path, + _In_opt_ const void* model_data, + size_t model_data_length, + std::unique_ptr& sess); + +OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, + _In_ onnxruntime::InferenceSession& sess, + _Inout_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr); diff --git a/onnxruntime/test/framework/type_info_test.cc b/onnxruntime/test/framework/type_info_test.cc index ee787fb071d97..d8ef668bf1c7e 100644 --- a/onnxruntime/test/framework/type_info_test.cc +++ b/onnxruntime/test/framework/type_info_test.cc @@ -22,9 +22,9 @@ TEST(TypeInfoTests, TensorProto) { auto tensor_type_info = OrtTypeInfo::FromTypeProto(tensor_type.value); ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info->type); - ASSERT_NE(nullptr, tensor_type_info->data); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info->data->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info->data->shape.GetDims())); + ASSERT_NE(nullptr, tensor_type_info->tensor_type_info); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info->tensor_type_info->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info->tensor_type_info->shape.GetDims())); } TEST(TypeInfoTests, SequenceWithTensorElement) { @@ -37,9 +37,9 @@ TEST(TypeInfoTests, SequenceWithTensorElement) { const auto& tensor_type_info = *seq_type_info->sequence_type_info->sequence_key_type_; ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info.type); - ASSERT_NE(nullptr, tensor_type_info.data); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.data->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.data->shape.GetDims())); + ASSERT_NE(nullptr, tensor_type_info.tensor_type_info); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.tensor_type_info->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.tensor_type_info->shape.GetDims())); } TEST(TypeInfoTests, OptionalWithTensorProto) { @@ -54,9 +54,9 @@ TEST(TypeInfoTests, OptionalWithTensorProto) { const auto& contained_type = *optional_type_info->optional_type_info->contained_type_; ASSERT_EQ(ONNX_TYPE_TENSOR, contained_type.type); - ASSERT_NE(nullptr, contained_type.data); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, contained_type.data->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), contained_type.data->shape.GetDims())); + ASSERT_NE(nullptr, contained_type.tensor_type_info); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, contained_type.tensor_type_info->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), contained_type.tensor_type_info->shape.GetDims())); } #if !defined(DISABLE_ML_OPS) @@ -74,11 +74,11 @@ TEST(TypeInfoTests, MapWithTensorValue) { const auto& tensor_type_info = *map_info.map_value_type_; ASSERT_EQ(ONNX_TYPE_TENSOR, tensor_type_info.type); - ASSERT_NE(nullptr, tensor_type_info.data); - ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.data->type); - ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.data->shape.GetDims())); + ASSERT_NE(nullptr, tensor_type_info.tensor_type_info); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, tensor_type_info.tensor_type_info->type); + ASSERT_TRUE(SpanEq(AsSpan({1, 2, 3, 4}), tensor_type_info.tensor_type_info->shape.GetDims())); } #endif } // namespace test -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/test/shared_lib/custom_op_utils.h b/onnxruntime/test/shared_lib/custom_op_utils.h index e11540aaa5691..ea2a5f2771342 100644 --- a/onnxruntime/test/shared_lib/custom_op_utils.h +++ b/onnxruntime/test/shared_lib/custom_op_utils.h @@ -8,12 +8,6 @@ #include #endif -struct Input { - const char* name = nullptr; - std::vector dims; - std::vector values; -}; - struct MyCustomKernel { MyCustomKernel(const OrtApi& ort_api, const OrtKernelInfo* /*info*/) : ort_(ort_api) { diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index e8c8c8db8d08f..438b0cf6f2f24 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -1,17 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include -#include -#include -#include -#include +#include #include +#include +#include +#include #include -#include +#include #include +#include #include +#include + #include "gtest/gtest.h" #include "gmock/gmock.h" @@ -25,13 +27,13 @@ #include "core/session/onnxruntime_run_options_config_keys.h" #include "core/util/thread_utils.h" -#include "onnxruntime_config.h" -#include "providers.h" -#include "test_allocator.h" -#include "test_fixture.h" -#include "utils.h" -#include "custom_op_utils.h" -#include +#include "test/shared_lib/custom_op_utils.h" +#include "test/shared_lib/test_fixture.h" +#include "test/shared_lib/utils.h" +#include "test/util/include/providers.h" +#include "test/util/include/test_allocator.h" + +#include "onnxruntime_config.h" // generated file in build output dir #ifdef _WIN32 #include @@ -62,48 +64,6 @@ constexpr size_t countof(T (&)[N]) { return N; } extern std::unique_ptr ort_env; -template -void RunSession(OrtAllocator* allocator, Ort::Session& session_object, - const std::vector& inputs, - const char* output_name, - const std::vector& dims_y, - const std::vector& values_y, - Ort::Value* output_tensor) { - std::vector ort_inputs; - std::vector input_names; - for (size_t i = 0; i < inputs.size(); i++) { - input_names.emplace_back(inputs[i].name); - ort_inputs.emplace_back( - Ort::Value::CreateTensor(allocator->Info(allocator), const_cast(inputs[i].values.data()), - inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size())); - } - - std::vector ort_outputs; - if (output_tensor) - session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), - &output_name, output_tensor, 1); - else { - ort_outputs = session_object.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), - &output_name, 1); - ASSERT_EQ(ort_outputs.size(), 1u); - output_tensor = &ort_outputs[0]; - } - - auto type_info = output_tensor->GetTensorTypeAndShapeInfo(); - ASSERT_EQ(type_info.GetShape(), dims_y); - size_t total_len = type_info.GetElementCount(); - ASSERT_EQ(values_y.size(), total_len); - - OutT* f = output_tensor->GetTensorMutableData(); - for (size_t i = 0; i != total_len; ++i) { - if constexpr (std::is_same::value || std::is_same::value) { - ASSERT_NEAR(values_y[i], f[i], 1e-3); - } else { - ASSERT_EQ(values_y[i], f[i]); - } - } -} - #ifdef USE_DML struct DmlObjects { ComPtr d3d12_device; @@ -299,12 +259,12 @@ Ort::Value CreateTensorValueFromExistingD3DResource( #endif -template +template > static void TestInference(Ort::Env& env, const std::basic_string& model_uri, const std::vector& inputs, const char* output_name, const std::vector& expected_dims_y, - const std::vector& expected_values_y, + const std::vector& expected_values_y, int provider_type, OrtCustomOpDomain* custom_op_domain_ptr, const ORTCHAR_T* custom_op_library_filename, @@ -361,26 +321,26 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod auto default_allocator = std::make_unique(); // without preallocated output tensor - RunSession(default_allocator.get(), - session, - inputs, - output_name, - expected_dims_y, - expected_values_y, - nullptr); + RunSession(default_allocator.get(), + session, + inputs, + output_name, + expected_dims_y, + expected_values_y, + nullptr); // with preallocated output tensor - Ort::Value value_y = Ort::Value::CreateTensor(default_allocator.get(), - expected_dims_y.data(), expected_dims_y.size()); + Ort::Value value_y = Ort::Value::CreateTensor(default_allocator.get(), + expected_dims_y.data(), expected_dims_y.size()); // test it twice for (int i = 0; i != 2; ++i) - RunSession(default_allocator.get(), - session, - inputs, - output_name, - expected_dims_y, - expected_values_y, - &value_y); + RunSession(default_allocator.get(), + session, + inputs, + output_name, + expected_dims_y, + expected_values_y, + &value_y); } } @@ -449,8 +409,8 @@ class CApiTestWithProvider : public testing::Test, public ::testing::WithParamIn TEST_P(CApiTestWithProvider, simple) { // simple inference test // prepare inputs - std::vector inputs(1); - Input& input = inputs.back(); + std::vector> inputs(1); + auto& input = inputs.back(); input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -620,8 +580,8 @@ TEST(CApiTest, SparseInputModel) { TEST(CApiTest, custom_op_handler) { std::cout << "Running custom op inference" << std::endl; - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -656,8 +616,8 @@ TEST(CApiTest, custom_op_handler) { TEST(CApiTest, custom_op_set_input_memory_type) { std::cout << "Running custom op inference" << std::endl; - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -686,8 +646,8 @@ TEST(CApiTest, custom_op_set_input_memory_type) { #if !defined(ORT_MINIMAL_BUILD) TEST(CApiTest, StandaloneOpHandler) { - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -810,7 +770,7 @@ TEST(CApiTest, test_enable_ort_customops_stringlower) { // test custom op which accepts float and double as inputs TEST(CApiTest, varied_input_custom_op_handler) { - std::vector inputs(2); + std::vector> inputs(2); inputs[0].name = "X"; inputs[0].dims = {3}; inputs[0].values = {2.0f, 3.0f, 4.0f}; @@ -1421,8 +1381,8 @@ TEST(CApiTest, custom_op_with_attributes_handler) { TEST(CApiTest, RegisterCustomOpForCPUAndCUDA) { std::cout << "Tests registration of a custom op of the same name for both CPU and CUDA EPs" << std::endl; - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -1530,7 +1490,7 @@ TEST(CApiTest, test_custom_op_openvino_wrapper_library) { // The custom op extracts the serialized .xml/.bin bytes and creates an in-memory OpenVINO model // during kernel creation. The custom op is passed an image of a hand-drawn "1" as an input during computation, which // is then inferenced using OpenVINO C++ APIs. - std::vector inputs(1); + std::vector> inputs(1); inputs[0].name = "Input3"; inputs[0].dims = {1, 1, 28, 28}; @@ -1629,7 +1589,7 @@ TEST(CApiTest, test_custom_op_library) { #endif std::cout << "Running inference using custom op shared library" << std::endl; - std::vector inputs(2); + std::vector> inputs(2); inputs[0].name = "input_1"; inputs[0].dims = {3, 5}; inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f, @@ -1681,7 +1641,7 @@ TEST(CApiTest, DISABLED_test_custom_op_shape_infer_attr) { #else TEST(CApiTest, test_custom_op_shape_infer_attr) { #endif - std::vector inputs(1); + std::vector> inputs(1); inputs[0].name = "input_0"; inputs[0].dims = {5}; inputs[0].values = {1.f, 2.f, 3.f, 4.f, 5.f}; @@ -1714,7 +1674,7 @@ TEST(CApiTest, test_custom_op_library_copy_variadic) { #endif std::cout << "Running inference using custom op shared library" << std::endl; - std::vector inputs(2); + std::vector> inputs(2); inputs[0].name = "input_0"; inputs[0].dims = {15}; inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f, @@ -1868,8 +1828,8 @@ void PrepareModule() { TEST(CApiTest, test_pyop) { std::call_once(my_module_flag, PrepareModule); - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {2, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f}; @@ -1881,8 +1841,8 @@ TEST(CApiTest, test_pyop) { TEST(CApiTest, test_pyop_multi) { std::call_once(my_module_flag, PrepareModule); - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {2, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f}; @@ -1894,8 +1854,8 @@ TEST(CApiTest, test_pyop_multi) { TEST(CApiTest, test_pyop_kwarg) { std::call_once(my_module_flag, PrepareModule); - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {2, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f}; @@ -1919,7 +1879,7 @@ TEST(ReducedOpsBuildTest, test_excluded_ops) { // In reduced ops build, test a model containing ops not included in required_ops.config cannot be loaded. // See onnxruntime/test/testdata/reduced_build_test.readme.txt for more details of the setup constexpr PATH_TYPE model_uri = TSTR("testdata/reduced_build_test.onnx_model_with_excluded_ops"); - std::vector inputs = {{"X", {3}, {-1.0f, 2.0f, -3.0f}}}; + std::vector> inputs = {{"X", {3}, {-1.0f, 2.0f, -3.0f}}}; std::vector expected_dims_y = {3}; std::vector expected_values_y = {0.1f, 0.1f, 0.1f}; bool failed = false; @@ -3155,8 +3115,8 @@ TEST(CApiTest, TestSharedAllocators) { OrtEnv* env_ptr = (OrtEnv*)(*ort_env); // prepare inputs - std::vector inputs(1); - Input& input = inputs.back(); + std::vector> inputs(1); + auto& input = inputs.back(); input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -3342,8 +3302,8 @@ TEST(CApiTest, TestSharedAllocators) { TEST(CApiTest, TestSharingOfInitializerAndItsPrepackedVersion) { // simple inference test // prepare inputs - std::vector inputs(1); - Input& input = inputs.back(); + std::vector> inputs(1); + auto& input = inputs.back(); input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -3738,8 +3698,8 @@ TEST_P(CApiTensorRTTest, TestConfigureTensorRTProviderOptions) { // simple inference test // prepare inputs - std::vector inputs(1); - Input& input = inputs.back(); + std::vector> inputs(1); + auto& input = inputs.back(); input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc new file mode 100644 index 0000000000000..cd7b774ad64d5 --- /dev/null +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -0,0 +1,483 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +#include "core/common/narrow.h" +#include "core/graph/constants.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_lite_custom_op.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +#include "test/shared_lib/test_fixture.h" +#include "test/shared_lib/utils.h" +#include "test/util/include/test_allocator.h" + +#include "onnxruntime_config.h" // generated file in build output dir + +extern std::unique_ptr ort_env; + +using namespace Ort; + +namespace { + +Ort::Session CreateSession(Ort::Env& env, + ModelBuilderAPI::Model& graph_api_model, + Ort::SessionOptions* session_options_for_test = nullptr) { + Ort::SessionOptions default_session_options; + Ort::SessionOptions& session_options = session_options_for_test ? *session_options_for_test + : default_session_options; + + // Set this to save the model if you want to debug. + // session_options.SetOptimizedModelFilePath(ORT_TSTR("model_builder_output.onnx")); + + Ort::Session session(env, graph_api_model, session_options); + + // Session should not require the model to stay alive so free it now to validate. + graph_api_model = ModelBuilderAPI::Model(nullptr); + + return session; +} + +template +void TestInference(Ort::Session& session, + const std::vector>& inputs, + const char* output_name, + const std::vector& expected_dims, + const std::vector& expected_values) { + auto default_allocator = std::make_unique(); + + // without preallocated output tensor + RunSession(default_allocator.get(), + session, + inputs, + output_name, + expected_dims, + expected_values, + nullptr); +} + +// Create OrtNode using the C API +OrtNode* CreateNode(const OrtModelBuilderApi& api, + const char* operator_name, const char* node_name, + const gsl::span input_names, + const gsl::span output_names, + const gsl::span attributes = {}, + const char* domain_name = onnxruntime::kOnnxDomain) { + OrtNode* node = nullptr; + Ort::ThrowOnError(api.CreateNode(operator_name, domain_name, node_name, + input_names.data(), input_names.size(), + output_names.data(), output_names.size(), + attributes.data(), attributes.size(), + &node)); + return node; +} + +// convenience func to convert initalizer lists to gsl::span +OrtNode* CreateNode(const OrtModelBuilderApi& api, + const char* operator_name, const char* node_name, + const std::initializer_list input_names, + const std::initializer_list output_names, + const std::initializer_list attributes = {}, + const char* domain_name = onnxruntime::kOnnxDomain) { + std::vector inputs(input_names); + std::vector outputs(output_names); + std::vector attrs(attributes); + return CreateNode(api, operator_name, node_name, inputs, outputs, attrs, domain_name); +} +} // namespace + +struct TestAllocator : public OrtAllocator { + TestAllocator() { + version = ORT_API_VERSION; + Info = [](const struct OrtAllocator* this_ptr) -> const struct OrtMemoryInfo* { + auto* test_allocator = static_cast(this_ptr); + return test_allocator->memory_info; + }; + + Free = [](struct OrtAllocator* allocator, void* p) -> void { + auto* test_allocator = static_cast(allocator); + // find the matching pointer and remove it + auto it = std::find_if(test_allocator->weights.begin(), test_allocator->weights.end(), + [p](const std::unique_ptr>& v) { return v->data() == p; }); + if (it == test_allocator->weights.end()) { + throw std::runtime_error("Free called with unknown pointer"); + } + + test_allocator->weights.erase(it); + }; + + Alloc = [](struct OrtAllocator* /*this*/, size_t /*size*/) -> void* { + throw std::runtime_error("This should not be used"); + }; + + Reserve = [](struct OrtAllocator* /*this*/, size_t /*size*/) -> void* { + throw std::runtime_error("This should not be used"); + }; + } + + // initializers that are used directly by the model. as there's no copy they must remain valid. + // we store them in the test allocator so we can validate that Free is called + std::vector>> weights; + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtDeviceAllocator, + OrtMemType::OrtMemTypeDefault); +}; + +// Test the ModelBuilderAPI C api +// Uses the ORT C++ api for the rest for simplicity +TEST(ModelBuilderAPITest, Basic_CApi) { + const auto& api = Ort::GetApi(); + const auto& graph_api = Ort::GetModelBuilderApi(); + + TestAllocator deleter; + + // return void so we can use ASSERT_* in the lambda + const auto build_model = [&](bool use_constant_node, OrtModel*& model) -> void { + OrtGraph* graph = nullptr; + Ort::ThrowOnError(graph_api.CreateGraph(&graph)); + + // + // Create OrtModel with a Gemm. X input is 3x2, Y input is 2x3, Z output is 3x3. + // X is model input. Y is initializer. + // Set the alpha attribute of the Gemm node to 2.0 to test attribute handling. + // + + // model input + OrtTensorTypeAndShapeInfo* tensor_type_info = nullptr; + std::vector input_dims = {3, 2}; + // can use api.SetSymbolicDimensions to set symbolic dimensions. + // the input array should have the same rank as the call to SetDimensions. + // e.g. call SetDimensions with {-1, 3, 2} and SetSymbolicDimensions with {"N", nullptr, nullptr} to create + // a shape of {"N", 3, 2} + + Ort::ThrowOnError(api.CreateTensorTypeAndShapeInfo(&tensor_type_info)); + Ort::ThrowOnError(api.SetTensorElementType(tensor_type_info, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + Ort::ThrowOnError(api.SetDimensions(tensor_type_info, input_dims.data(), input_dims.size())); + + OrtTypeInfo* input_type_info = nullptr; + Ort::ThrowOnError(api.CreateTensorTypeInfo(tensor_type_info, &input_type_info)); + api.ReleaseTensorTypeAndShapeInfo(tensor_type_info); // input_type_info took a copy + + // create ValueInfo and release the type info as CreateValueInfo takes a copy. + OrtValueInfo* input_value_info = nullptr; + Ort::ThrowOnError(graph_api.CreateValueInfo("X", input_type_info, &input_value_info)); + api.ReleaseTypeInfo(input_type_info); // input_value_info took a copy + tensor_type_info = nullptr; + + // model outputs + OrtTypeInfo* output_type_info = nullptr; + std::vector output_dims = {3, 3}; + + Ort::ThrowOnError(api.CreateTensorTypeAndShapeInfo(&tensor_type_info)); + Ort::ThrowOnError(api.SetTensorElementType(tensor_type_info, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + Ort::ThrowOnError(api.SetDimensions(tensor_type_info, output_dims.data(), output_dims.size())); + + Ort::ThrowOnError(api.CreateTensorTypeInfo(tensor_type_info, &output_type_info)); + api.ReleaseTensorTypeAndShapeInfo(tensor_type_info); // input_type_info took a copy + + OrtValueInfo* output_value_info = nullptr; + Ort::ThrowOnError(graph_api.CreateValueInfo("Z", output_type_info, &output_value_info)); + api.ReleaseTypeInfo(output_type_info); + + std::vector graph_inputs = {input_value_info}; + std::vector graph_outputs = {output_value_info}; + Ort::ThrowOnError(graph_api.SetGraphInputs(graph, graph_inputs.data(), graph_inputs.size())); + Ort::ThrowOnError(graph_api.SetGraphOutputs(graph, graph_outputs.data(), graph_outputs.size())); + + // + // Gemm node + // + + OrtOpAttr* alpha_attr = nullptr; + float alpha_value = 2.0; + Ort::ThrowOnError(api.CreateOpAttr("alpha", &alpha_value, 1, OrtOpAttrType::ORT_OP_ATTR_FLOAT, &alpha_attr)); + + std::vector node_input_names = {"X", "Y"}; + const std::string gemm_output_name = use_constant_node ? "Z_temp" : "Z"; + std::vector node_output_names = {gemm_output_name.c_str()}; + std::vector node_attributes{alpha_attr}; + OrtNode* node = CreateNode(graph_api, "Gemm", "Gemm1", node_input_names, node_output_names, node_attributes); + + api.ReleaseOpAttr(alpha_attr); // CreateNode copies all OrtOpAttr instances + + Ort::ThrowOnError(graph_api.AddNodeToGraph(graph, node)); + node = nullptr; // graph now owns node + + // Y input + std::vector y_dims = {2, 3}; + deleter.weights.emplace_back( + std::make_unique>(std::initializer_list{1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f})); + auto& y_values = *deleter.weights.back(); + + // create an initializer for the Y input. add to `weights` so the memory remains valid + OrtValue* y_tensor = nullptr; + auto info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + // if you use this API the initializer data MUST remain valid for the lifetime of the InferenceSession + Ort::ThrowOnError( + api.CreateTensorWithDataAndDeleterAsOrtValue(&deleter, + y_values.data(), y_values.size() * sizeof(y_values[0]), + y_dims.data(), y_dims.size(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + &y_tensor)); + + Ort::ThrowOnError(graph_api.AddInitializerToGraph(graph, "Y", y_tensor, /*data is external*/ true)); + y_tensor = nullptr; // graph now owns + + if (use_constant_node) { + // Test that a Constant node is converted to an intializer + + // create Constant node that is used as the Max in a Clip to limit the output + OrtOpAttr* value_attr = nullptr; + float max = 60.0f; + Ort::ThrowOnError(api.CreateOpAttr("value", &max, sizeof(max), ORT_OP_ATTR_FLOAT, &value_attr)); + + node = CreateNode(graph_api, "Constant", "clip_max", {}, {"max"}, {value_attr}); + Ort::ThrowOnError(graph_api.AddNodeToGraph(graph, node)); + node = nullptr; // graph now owns node + + node = CreateNode(graph_api, "Clip", "Clip1", {gemm_output_name.c_str(), "", "max"}, {"Z"}); + Ort::ThrowOnError(graph_api.AddNodeToGraph(graph, node)); + node = nullptr; // graph now owns node + } + + std::vector domain_names = {onnxruntime::kOnnxDomain}; + std::vector opset_versions = {18}; + Ort::ThrowOnError(graph_api.CreateModel(domain_names.data(), opset_versions.data(), domain_names.size(), + &model)); + Ort::ThrowOnError(graph_api.AddGraphToModel(model, graph)); + graph = nullptr; // model now owns + }; + + auto run_test = [&](bool use_constant_node) -> void { + OrtModel* model = nullptr; + build_model(use_constant_node, model); + + ASSERT_NE(model, nullptr) << "build_model should have created a model"; + + std::vector> inputs(1); + auto& input = inputs[0]; + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + std::vector expected_dims = {3, 3}; + ModelBuilderAPI::Model cxx_model(model); + auto session = CreateSession(*ort_env, cxx_model); + + std::vector expected_output; + if (use_constant_node) { + expected_output = {18.0f, 24.0f, 30.0f, + 38.0f, 52.0f, 60.0f, // clipped + 58.0f, 60.0f, 60.0f}; // clipped + } else { + expected_output = {18.0f, 24.0f, 30.0f, + 38.0f, 52.0f, 66.0f, + 58.0f, 80.0f, 102.0f}; + } + + TestInference(session, inputs, "Z", expected_dims, expected_output); + + api.ReleaseSession(session.release()); + + ASSERT_EQ(deleter.weights.size(), 0) << "All weights should have been freed"; + }; + + run_test(false); + run_test(true); // use Constant node for initializer +} + +TEST(ModelBuilderAPITest, Basic_CxxApi) { + // initializers that are used directly by the model. as there's no copy they must remain valid + std::vector>> weights; + + Ort::ModelBuilderAPI::Graph graph; + + // + // Create OrtModel with a Gemm. X input is 3x2, Y input is 2x3, Z output is 3x3. + // X is model input. Y is initializer. + // Set the alpha attribute of the Gemm node to 2.0 to test attribute handling. + // + + std::vector graph_inputs; + std::vector graph_outputs; + + // model input. it's {3, 2} but use a symbolic dim to test that works. + std::vector input_dims({-1, 2}); + std::vector input_symbolic_dims({"multiple_of_3", ""}); + TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + input_dims, + &input_symbolic_dims); + auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst()); + graph_inputs.emplace_back("X", input_type_info.GetConst()); + + // model outputs + std::vector output_dims = {-1, 3}; + std::vector output_symbolic_dims({"multiple_of_3", ""}); + TensorTypeAndShapeInfo output_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + output_dims, + &output_symbolic_dims); + auto output_type_info = TypeInfo::CreateTensorInfo(output_tensor_info.GetConst()); + graph_outputs.emplace_back("Z", output_type_info.GetConst()); + + graph.SetInputs(graph_inputs); + graph.SetOutputs(graph_outputs); + + // + // Gemm node + // + + std::vector attributes; + float alpha_value = 2.0; + attributes.push_back(OpAttr("alpha", &alpha_value, 1, OrtOpAttrType::ORT_OP_ATTR_FLOAT)); + + ModelBuilderAPI::Node node("Gemm", onnxruntime::kOnnxDomain, "Gemm1", {"X", "Y"}, {"Z"}, attributes); + + graph.AddNode(node); + + // create an initializer for the Y input. + // add to `weights` so it remains valid for the lifetime of the session and we can avoid copying the data. + std::vector y_dims = {2, 3}; + weights.emplace_back(std::make_unique>(std::initializer_list{1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f})); + auto& y_values = *weights.back(); + auto info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + // if you use this API the initializer data MUST remain valid for the lifetime of the InferenceSession + auto y_tensor = Value::CreateTensor(info, y_values.data(), y_values.size(), y_dims.data(), y_dims.size()); + graph.AddInitializer("Y", y_tensor, /*data is external*/ true); + + std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; + ModelBuilderAPI::Model model(opsets); + model.AddGraph(graph); + + std::vector> inputs(1); + auto& input = inputs[0]; + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + std::vector expected_dims = {3, 3}; + + auto session = CreateSession(*ort_env, model); + TestInference(session, inputs, "Z", expected_dims, + {18.0f, 24.0f, 30.0f, + 38.0f, 52.0f, 66.0f, + 58.0f, 80.0f, 102.0f}); +} + +TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) { + // + // Load existing model + // Add Cast to change the model input from float to int64 + // Update model inputs to match + // Run + // + + SessionOptions so; + + // Set this to save the model if you want to debug. + // so.SetOptimizedModelFilePath(ORT_TSTR("model_builder_edited.onnx")); + + Session session = Session::CreateModelBuilderSession(*ort_env, TSTR("testdata/mnist.onnx"), so); + + ASSERT_EQ(session.GetOpset(""), 8); // ONNX domain is empty string + + // we augment the original model with nodes, initializers and the updated model inputs/outputs from this model. + // the original graph is unchanged. nodes can be added before/after it. initializers can be added. + // new nodes must conform to the original domain:opset of the model. + // additional operator domain:opset pairs can be added. + std::vector opsets; // no additional opsets required + ModelBuilderAPI::Model model(opsets); + + std::vector input_names = session.GetInputNames(); + ASSERT_EQ(input_names.size(), 1); + + TypeInfo orig_input = session.GetInputTypeInfo(0); + ASSERT_EQ(orig_input.GetTensorTypeAndShapeInfo().GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + + const std::string new_input_name = "Int64Input"; + + // Add Cast node to convert input from float to int64 + std::vector attributes; + int64_t to = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + attributes.push_back(OpAttr("to", &to, 1, OrtOpAttrType::ORT_OP_ATTR_INT)); + + ModelBuilderAPI::Node node("Cast", onnxruntime::kOnnxDomain, new_input_name, {"Int64Input"}, {input_names[0]}, + attributes); + + // we're replacing the only input, so we don't need to call session.GetInputTypeInfo(x) to copy other inputs + // in order to preserve them + std::vector graph_inputs; + TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, + orig_input.GetTensorTypeAndShapeInfo().GetShape()); + auto input_type_info = TypeInfo::CreateTensorInfo(input_tensor_info.GetConst()); + graph_inputs.emplace_back(new_input_name, input_type_info.GetConst()); + + ModelBuilderAPI::Graph graph; // new info to augment the model with + + graph.AddNode(node); + graph.SetInputs(graph_inputs); + + // the node we added does not require any new opsets. + model.AddGraph(graph); + + session.FinalizeModelBuilderSession(model, so); + + std::vector> inputs(1); + auto& input = inputs[0]; + input.name = new_input_name.c_str(); + input.dims = orig_input.GetTensorTypeAndShapeInfo().GetShape(); + + auto num_values = std::accumulate(input.dims.begin(), input.dims.end(), int64_t(1), std::multiplies()); + input.values.resize(size_t(num_values)); + std::iota(input.values.begin(), input.values.end(), 1); + + std::vector expected_dims = {1, 10}; + std::vector expected_output = {-48.5088f, -1040.2948f, -347.0959f, 101.7392f, 421.3352f, + 750.92145f, 231.5060f, -1694.4152f, 681.5623f, 378.1689f}; + + TestInference(session, inputs, session.GetOutputNames()[0].c_str(), expected_dims, expected_output); + + // double check with original model + { + SessionOptions expected_so; + Session expected_session = Session(*ort_env, TSTR("testdata/mnist.onnx"), expected_so); + std::vector> expected_inputs(1); + auto& expected_input = expected_inputs[0]; + expected_input.name = input_names[0].c_str(); + expected_input.dims = orig_input.GetTensorTypeAndShapeInfo().GetShape(); + expected_input.values.reserve(size_t(num_values)); + std::transform(input.values.begin(), input.values.end(), std::back_inserter(expected_input.values), + [&](int64_t value) { return float(value); }); + + TestInference(expected_session, expected_inputs, session.GetOutputNames()[0].c_str(), + expected_dims, expected_output); + } +} + +TEST(ModelBuilderAPITest, InvalidDimension) { + try { + std::vector input_dims = {-2, 2}; + TensorTypeAndShapeInfo tensor_type_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + input_dims); + // invalid dim of -2 should cause exception + TypeInfo::CreateTensorInfo(tensor_type_info.GetConst()); + FAIL() << "Expected exception for invalid dimension"; + } catch (const Ort::Exception& e) { + ASSERT_STREQ(e.what(), "dim_values must be -1 (symbolic dimension) or larger."); + } +} + +/* +Tests required + +- Create invalid model. Graph::Resolve should fail. +- Invalid edit. Graph::Resolve should fail. +- All the non-tensor Create*TypeInfo functions need to be validated +*/ diff --git a/onnxruntime/test/shared_lib/test_ort_format_models.cc b/onnxruntime/test/shared_lib/test_ort_format_models.cc index 99a9ebc3362ae..b3491e3476f23 100644 --- a/onnxruntime/test/shared_lib/test_ort_format_models.cc +++ b/onnxruntime/test/shared_lib/test_ort_format_models.cc @@ -17,7 +17,7 @@ extern std::unique_ptr ort_env; [[maybe_unused]] static void TestInference(Ort::Env& env, const std::basic_string& model_uri, - const std::vector& inputs, const char* output_name, + const std::vector>& inputs, const char* output_name, const std::vector& expected_dims_y, const std::vector& expected_values_y, Ort::CustomOpDomain& custom_op_domain, void* cuda_compute_stream = nullptr) { Ort::SessionOptions session_options; @@ -100,8 +100,8 @@ TEST(OrtFormatCustomOpTests, ConvertOnnxModelToOrt) { } // now load the ORT format model and execute it - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -130,8 +130,8 @@ TEST(OrtFormatCustomOpTests, LoadOrtModel) { custom_op_domain.Add(&custom_op); // load the ORT format model and execute it - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}; @@ -151,8 +151,8 @@ TEST(OrtFormatCustomOpTests, LoadOrtModelStandaloneCustomOpImplementation) { custom_op_domain.Add(&standalone_op); // load the ORT format model and execute it - std::vector inputs(1); - Input& input = inputs[0]; + std::vector> inputs(1); + auto& input = inputs[0]; input.name = "X"; input.dims = {3, 2}; input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; diff --git a/onnxruntime/test/shared_lib/utils.h b/onnxruntime/test/shared_lib/utils.h index 483753f2ae6b2..5d15582b86cb9 100644 --- a/onnxruntime/test/shared_lib/utils.h +++ b/onnxruntime/test/shared_lib/utils.h @@ -5,4 +5,56 @@ #include "core/session/onnxruntime_cxx_api.h" +#include "gtest/gtest.h" + OrtCUDAProviderOptions CreateDefaultOrtCudaProviderOptionsWithCustomStream(void* cuda_compute_stream = nullptr); + +template +struct Input { + const char* name = nullptr; + std::vector dims; + std::vector values; +}; + +template > +void RunSession(OrtAllocator* allocator, + Ort::Session& session_object, + const std::vector& inputs, + const char* output_name, + const std::vector& output_dims, + const std::vector& expected_output, + Ort::Value* output_tensor) { + std::vector ort_inputs; + std::vector input_names; + for (size_t i = 0; i < inputs.size(); i++) { + input_names.emplace_back(inputs[i].name); + ort_inputs.emplace_back( + Ort::Value::CreateTensor(allocator->Info(allocator), const_cast(inputs[i].values.data()), + inputs[i].values.size(), inputs[i].dims.data(), inputs[i].dims.size())); + } + + std::vector ort_outputs; + if (output_tensor) + session_object.Run(Ort::RunOptions{nullptr}, input_names.data(), ort_inputs.data(), ort_inputs.size(), + &output_name, output_tensor, 1); + else { + ort_outputs = session_object.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), + &output_name, 1); + ASSERT_EQ(ort_outputs.size(), 1u); + output_tensor = &ort_outputs[0]; + } + + auto type_info = output_tensor->GetTensorTypeAndShapeInfo(); + ASSERT_EQ(type_info.GetShape(), output_dims); + size_t total_len = type_info.GetElementCount(); + ASSERT_EQ(expected_output.size(), total_len); + + auto* actual = output_tensor->GetTensorMutableData(); + for (size_t i = 0; i != total_len; ++i) { + if constexpr (std::is_same::value || std::is_same::value) { + EXPECT_NEAR(expected_output[i], actual[i], 1e-3) << "i=" << i; + } else { + EXPECT_EQ(expected_output[i], actual[i]) << "i=" << i; + } + } +} diff --git a/winml/adapter/winml_adapter_model.cpp b/winml/adapter/winml_adapter_model.cpp index 195bf6e5f0ffd..cf02c6fa2328b 100644 --- a/winml/adapter/winml_adapter_model.cpp +++ b/winml/adapter/winml_adapter_model.cpp @@ -593,13 +593,13 @@ ORT_API_STATUS_IMPL( input.set_name(input_name); if (info->type == ONNXType::ONNX_TYPE_TENSOR) { - auto num_dims = info->data->shape.NumDimensions(); + auto num_dims = info->tensor_type_info->shape.NumDimensions(); CreateTypeProto_Tensor( input.mutable_type()->mutable_tensor_type(), input_name, - (num_dims == 0) ? nullptr : &info->data->shape[0], + (num_dims == 0) ? nullptr : &info->tensor_type_info->shape[0], num_dims, - ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type) + ONNXTensorElementDataTypeToTensorProto_DataType(info->tensor_type_info->type) ); } return nullptr; @@ -619,12 +619,12 @@ ORT_API_STATUS_IMPL( ONNX_NAMESPACE::TensorProto& input = *graph.add_initializer(); input.set_name(input_name); - auto num_dims = info->data->shape.NumDimensions(); + auto num_dims = info->tensor_type_info->shape.NumDimensions(); for (size_t i = 0; i < num_dims; i++) { - input.add_dims(info->data->shape[i]); + input.add_dims(info->tensor_type_info->shape[i]); } - input.set_data_type(ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type)); + input.set_data_type(ONNXTensorElementDataTypeToTensorProto_DataType(info->tensor_type_info->type)); auto tensor = value->GetMutable(); input.set_raw_data(tensor->DataRaw(), tensor->SizeInBytes()); @@ -645,9 +645,9 @@ ORT_API_STATUS_IMPL( CreateTypeProto_Tensor( output.mutable_type()->mutable_tensor_type(), output_name, - &info->data->shape[0], - info->data->shape.NumDimensions(), - ONNXTensorElementDataTypeToTensorProto_DataType(info->data->type) + &info->tensor_type_info->shape[0], + info->tensor_type_info->shape.NumDimensions(), + ONNXTensorElementDataTypeToTensorProto_DataType(info->tensor_type_info->type) ); } return nullptr; From 147c574c6143d551ebd580c68f10639b49f1bb25 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 3 Jan 2025 11:01:16 +1000 Subject: [PATCH 2/3] Minor updates --- .../core/session/onnxruntime_c_api.h | 2 +- .../test/shared_lib/test_model_builder_api.cc | 34 ++++++++++--------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index c883ffa100320..c69d1d8471579 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5161,7 +5161,7 @@ struct OrtModelBuilderApi { * Pre-existing memory: * Use CreateTensorWithDataAsOrtValue or CreateTensorWithDataAndDeleterAsOrtValue to create an OrtValue * with a tensor that contains a pointer to the existing data. - * User must keep pointer valid for lifetime of the inference session. + * If using CreateTensorWithDataAsOrtValue you must keep the pointer valid for lifetime of the inference session. * Set `data_is_external` to true. * * Allocated memory: diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc index cd7b774ad64d5..e4870809de141 100644 --- a/onnxruntime/test/shared_lib/test_model_builder_api.cc +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -131,14 +131,14 @@ struct TestAllocator : public OrtAllocator { // Uses the ORT C++ api for the rest for simplicity TEST(ModelBuilderAPITest, Basic_CApi) { const auto& api = Ort::GetApi(); - const auto& graph_api = Ort::GetModelBuilderApi(); + const auto& model_builder_api = Ort::GetModelBuilderApi(); TestAllocator deleter; // return void so we can use ASSERT_* in the lambda const auto build_model = [&](bool use_constant_node, OrtModel*& model) -> void { OrtGraph* graph = nullptr; - Ort::ThrowOnError(graph_api.CreateGraph(&graph)); + Ort::ThrowOnError(model_builder_api.CreateGraph(&graph)); // // Create OrtModel with a Gemm. X input is 3x2, Y input is 2x3, Z output is 3x3. @@ -164,7 +164,7 @@ TEST(ModelBuilderAPITest, Basic_CApi) { // create ValueInfo and release the type info as CreateValueInfo takes a copy. OrtValueInfo* input_value_info = nullptr; - Ort::ThrowOnError(graph_api.CreateValueInfo("X", input_type_info, &input_value_info)); + Ort::ThrowOnError(model_builder_api.CreateValueInfo("X", input_type_info, &input_value_info)); api.ReleaseTypeInfo(input_type_info); // input_value_info took a copy tensor_type_info = nullptr; @@ -180,13 +180,15 @@ TEST(ModelBuilderAPITest, Basic_CApi) { api.ReleaseTensorTypeAndShapeInfo(tensor_type_info); // input_type_info took a copy OrtValueInfo* output_value_info = nullptr; - Ort::ThrowOnError(graph_api.CreateValueInfo("Z", output_type_info, &output_value_info)); + Ort::ThrowOnError(model_builder_api.CreateValueInfo("Z", output_type_info, &output_value_info)); api.ReleaseTypeInfo(output_type_info); std::vector graph_inputs = {input_value_info}; std::vector graph_outputs = {output_value_info}; - Ort::ThrowOnError(graph_api.SetGraphInputs(graph, graph_inputs.data(), graph_inputs.size())); - Ort::ThrowOnError(graph_api.SetGraphOutputs(graph, graph_outputs.data(), graph_outputs.size())); + Ort::ThrowOnError(model_builder_api.SetGraphInputs(graph, graph_inputs.data(), graph_inputs.size())); + Ort::ThrowOnError(model_builder_api.SetGraphOutputs(graph, graph_outputs.data(), graph_outputs.size())); + input_value_info = nullptr; // graph now owns the input/output values + output_value_info = nullptr; // // Gemm node @@ -200,11 +202,11 @@ TEST(ModelBuilderAPITest, Basic_CApi) { const std::string gemm_output_name = use_constant_node ? "Z_temp" : "Z"; std::vector node_output_names = {gemm_output_name.c_str()}; std::vector node_attributes{alpha_attr}; - OrtNode* node = CreateNode(graph_api, "Gemm", "Gemm1", node_input_names, node_output_names, node_attributes); + OrtNode* node = CreateNode(model_builder_api, "Gemm", "Gemm1", node_input_names, node_output_names, node_attributes); api.ReleaseOpAttr(alpha_attr); // CreateNode copies all OrtOpAttr instances - Ort::ThrowOnError(graph_api.AddNodeToGraph(graph, node)); + Ort::ThrowOnError(model_builder_api.AddNodeToGraph(graph, node)); node = nullptr; // graph now owns node // Y input @@ -226,7 +228,7 @@ TEST(ModelBuilderAPITest, Basic_CApi) { ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &y_tensor)); - Ort::ThrowOnError(graph_api.AddInitializerToGraph(graph, "Y", y_tensor, /*data is external*/ true)); + Ort::ThrowOnError(model_builder_api.AddInitializerToGraph(graph, "Y", y_tensor, /*data is external*/ true)); y_tensor = nullptr; // graph now owns if (use_constant_node) { @@ -237,20 +239,20 @@ TEST(ModelBuilderAPITest, Basic_CApi) { float max = 60.0f; Ort::ThrowOnError(api.CreateOpAttr("value", &max, sizeof(max), ORT_OP_ATTR_FLOAT, &value_attr)); - node = CreateNode(graph_api, "Constant", "clip_max", {}, {"max"}, {value_attr}); - Ort::ThrowOnError(graph_api.AddNodeToGraph(graph, node)); + node = CreateNode(model_builder_api, "Constant", "clip_max", {}, {"max"}, {value_attr}); + Ort::ThrowOnError(model_builder_api.AddNodeToGraph(graph, node)); node = nullptr; // graph now owns node - node = CreateNode(graph_api, "Clip", "Clip1", {gemm_output_name.c_str(), "", "max"}, {"Z"}); - Ort::ThrowOnError(graph_api.AddNodeToGraph(graph, node)); + node = CreateNode(model_builder_api, "Clip", "Clip1", {gemm_output_name.c_str(), "", "max"}, {"Z"}); + Ort::ThrowOnError(model_builder_api.AddNodeToGraph(graph, node)); node = nullptr; // graph now owns node } std::vector domain_names = {onnxruntime::kOnnxDomain}; std::vector opset_versions = {18}; - Ort::ThrowOnError(graph_api.CreateModel(domain_names.data(), opset_versions.data(), domain_names.size(), - &model)); - Ort::ThrowOnError(graph_api.AddGraphToModel(model, graph)); + Ort::ThrowOnError(model_builder_api.CreateModel(domain_names.data(), opset_versions.data(), domain_names.size(), + &model)); + Ort::ThrowOnError(model_builder_api.AddGraphToModel(model, graph)); graph = nullptr; // model now owns }; From 4e2d061977d66b75517b8855a42ad5dd159ad895 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 8 Jan 2025 10:23:49 +1000 Subject: [PATCH 3/3] - take ownership of OrtOpAttr in CreateNode - enforce 128 byte minimum for tensors with external data to avoid shape inferencing issues - update unit tests to use 128 byte initializer so external data can be tested - support saving initializer with in-memory external data to ONNX model by copying into TensorProto's raw_data property. --- .../core/session/onnxruntime_c_api.h | 21 +++- .../core/session/onnxruntime_cxx_inline.h | 4 +- onnxruntime/core/graph/graph.cc | 62 ++++++++---- onnxruntime/core/session/inference_session.h | 9 +- .../core/session/model_builder_c_api.cc | 30 +++++- .../test/shared_lib/test_model_builder_api.cc | 97 +++++++++++-------- 6 files changed, 149 insertions(+), 74 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 06b83fdf7319b..fcd42323efe92 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5158,15 +5158,28 @@ struct OrtModelBuilderApi { * * Two options: * + * Allocated memory: + * Use CreateTensorAsOrtValue (allocates memory) and populate the tensor with the data. + * Set `data_is_external` to false. + * * Pre-existing memory: * Use CreateTensorWithDataAsOrtValue or CreateTensorWithDataAndDeleterAsOrtValue to create an OrtValue * with a tensor that contains a pointer to the existing data. - * If using CreateTensorWithDataAsOrtValue you must keep the pointer valid for lifetime of the inference session. * Set `data_is_external` to true. * - * Allocated memory: - * Use CreateTensorAsOrtValue (allocates memory) and populate the tensor with the data. - * Set `data_is_external` to false. + * The pointer must remain valid for the duration of the inference session. + * If using CreateTensorWithDataAsOrtValue you are responsible for freeing the memory after the inference session + * is released. + * If using CreateTensorWithDataAndDeleterAsOrtValue, ORT will free the memory using the provided deleter as + * soon as the OrtValue is no longer in use. + * + * NOTE: A tensor containing pre-existing memory MUST have 128 bytes of data or more. + * For smaller tensors use CreateTensorAsOrtValue. + * + * ONNX shape inferencing does not support external data. An initializer involved in shape inferencing is + * typically small (a single value or limited by the rank of a tensor) and uses less than 128 bytes of + * memory, so this limit acts as a simple catch-all rule to avoid issues. + * e.g. Reshape's `shape`, Clip's `min` and `max`, various ops `axes`. * * \param[in] graph The OrtGraph instance to update. * \param[in] name The value name for the initializer. diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 1de5db266961d..7365d39938fdb 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -2418,10 +2418,8 @@ template <> inline void GraphImpl::SetInputs(std::vector& inputs) { std::vector inputs_ptrs; inputs_ptrs.reserve(inputs.size()); - - // Graph takes ownership. std::transform(inputs.begin(), inputs.end(), std::back_inserter(inputs_ptrs), - [](ValueInfo& vi) -> OrtValueInfo* { return vi.release(); }); + [](ValueInfo& vi) -> OrtValueInfo* { return vi; }); ThrowOnError(GetModelBuilderApi().SetGraphInputs(p_, inputs_ptrs.data(), inputs_ptrs.size())); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 41a3c28e01408..660b0cf288c67 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -7,21 +7,24 @@ #include #include #include -#include #include +#include -#include "core/common/common.h" #include + +#include "core/common/common.h" #include "core/common/inlined_containers.h" #include "core/common/logging/logging.h" #include "core/common/narrow.h" #include "core/flatbuffers/flatbuffers_utils.h" #include "core/framework/tensor_type_and_shape.h" #include "core/flatbuffers/schema/ort.fbs.h" -#include "core/framework/tensor_shape.h" #include "core/framework/tensor_external_data_info.h" +#include "core/framework/tensor_shape.h" +#include "core/framework/tensor_type_and_shape.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" +#include "core/graph/function_utils.h" #include "core/graph/graph_flatbuffers_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/indexed_sub_graph.h" @@ -32,7 +35,6 @@ #include "core/graph/node_attr_utils.h" #include "core/graph/op.h" #include "core/graph/runtime_optimization_record_container.h" -#include "core/graph/function_utils.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/graph/function.h" @@ -4096,27 +4098,51 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProto() const { // This is used for constructing full path for external data // if it exists + auto add_initializer = [](TensorList& output_initializers, const TensorProto& initializer) -> void { + TensorProto& output = *output_initializers.Add(); + output = initializer; + + // copy any in-memory external data into raw data + if (utils::HasExternalData(initializer)) { + const std::filesystem::path ignored; + std::basic_string location; + onnxruntime::FileOffsetType file_offset; + SafeInt tensor_byte_size; + + ORT_THROW_IF_ERROR(utils::GetExternalDataInfo(initializer, ignored, location, file_offset, tensor_byte_size)); + + if (location == onnxruntime::utils::kTensorProtoMemoryAddressTag) { + // file_offset is address + void* data = reinterpret_cast(file_offset); + + // set in raw data + output.clear_data_location(); + output.set_raw_data(data, tensor_byte_size); + } + } + }; + + auto* mutable_initializers = result.mutable_initializer(); + #if !defined(DISABLE_SPARSE_TENSORS) const auto& model_path = ModelPath(); // We want to make sure that sparse initializers do not appear // as dense duplicates within the initializers list. - if (!sparse_tensor_names_.empty()) { - const auto sparse_end = sparse_tensor_names_.end(); - auto* mutable_initializer = result.mutable_initializer(); - for (const auto& initializer : graph_proto_->initializer()) { - if (sparse_end == sparse_tensor_names_.find(initializer.name())) { - *mutable_initializer->Add() = initializer; - } else { - auto& sparse_initializer = *result.add_sparse_initializer(); - auto status = utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer); - ORT_ENFORCE(status.IsOK(), "Failed to convert dense initializer to sparse"); - } + const bool has_sparse_initializers = !sparse_tensor_names_.empty(); + const auto sparse_end = sparse_tensor_names_.end(); + for (const auto& initializer : graph_proto_->initializer()) { + if (!has_sparse_initializers || sparse_end == sparse_tensor_names_.find(initializer.name())) { + add_initializer(*mutable_initializers, initializer); + } else { + auto& sparse_initializer = *result.add_sparse_initializer(); + auto status = utils::DenseTensorToSparseTensorProto(initializer, model_path, sparse_initializer); + ORT_ENFORCE(status.IsOK(), "Failed to convert dense initializer to sparse"); } - } else { - *result.mutable_initializer() = graph_proto_->initializer(); } #else - *result.mutable_initializer() = graph_proto_->initializer(); + for (const auto& initializer : graph_proto_->initializer()) { + add_initializer(*mutable_initializers, initializer); + } #endif return result; diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index f89eacb633e42..89a2693d1956b 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -627,6 +627,12 @@ class InferenceSession { /// convenience pointer to logger. should always be the same as session_state_.Logger(); const logging::Logger* session_logger_; + // The list of execution providers. + // This MUST be prior to model_ in case there are values in the model that were allocated using an allocator + // provided by the EP. If that is the case the allocator's `free` implementation may depend on other parts of the + // EP instance. + ExecutionProviders execution_providers_; + // The model served by this inference session instance. // Currently this has to be a shared ptr because the Model::Load method // returns a shared_ptr only. Ideally factory functions should always return @@ -637,9 +643,6 @@ class InferenceSession { // The file path of where the model was loaded. e.g. /tmp/test_squeezenet/model.onnx PathString model_location_; - // The list of execution providers. - ExecutionProviders execution_providers_; - private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferenceSession); void SetLoggingManager(const SessionOptions& session_options, diff --git a/onnxruntime/core/session/model_builder_c_api.cc b/onnxruntime/core/session/model_builder_c_api.cc index 25e2409805c74..8eac1ebce36ab 100644 --- a/onnxruntime/core/session/model_builder_c_api.cc +++ b/onnxruntime/core/session/model_builder_c_api.cc @@ -93,6 +93,9 @@ ORT_API_STATUS_IMPL(OrtModelBuilderAPI::CreateNode, const char* operator_name, c n->attributes.reserve(attribs_len); for (size_t i = 0; i < attribs_len; ++i) { n->attributes.push_back(*reinterpret_cast(attributes[i])); + // take ownership. as we took a copy that means releasing the original value + OrtApis::ReleaseOpAttr(attributes[i]); + attributes[i] = nullptr; } } @@ -156,12 +159,31 @@ ORT_API_STATUS_IMPL(OrtModelBuilderAPI::SetGraphOutputs, _In_ OrtGraph* graph, ORT_API_STATUS_IMPL(OrtModelBuilderAPI::AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue* tensor, bool data_is_external) { API_IMPL_BEGIN + if (!tensor->IsTensor()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Only Tensor is currently supported."); + } + + if (!tensor->IsAllocated()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Tensor must be allocated."); + } + + const auto& t = tensor->Get(); + if (t.Location().device.Type() != OrtDevice::CPU) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Only CPU based tensors are currently supported."); + } + if (data_is_external) { -#if !defined(DISABLE_EXTERNAL_INITIALIZERS) + // enforce that an external initializer is not used if the data size is < 128 bytes. + // the reason for this is to avoid potential shape inferencing errors if this initializer is providing an + // input involved in that. the ONNX shape inferencing does not support external data for those values. + // e.g. Reshape's `shape` input, Reduce's `axes', Slice's `starts`, `ends`, `steps`, Clip's `min`, `max`, etc. + if (t.SizeInBytes() < 128) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "External initializer should only be used for data >= 128 bytes. " + "Please use CreateTensorAsOrtValue instead."); + } + graph->external_initializers[name] = std::unique_ptr(tensor); // take ownership -#else - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "External initializers are not supported in this build"); -#endif } else { graph->initializers[name] = std::unique_ptr(tensor); // take ownership } diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc index e4870809de141..b6314f48a0e09 100644 --- a/onnxruntime/test/shared_lib/test_model_builder_api.cc +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -141,14 +141,14 @@ TEST(ModelBuilderAPITest, Basic_CApi) { Ort::ThrowOnError(model_builder_api.CreateGraph(&graph)); // - // Create OrtModel with a Gemm. X input is 3x2, Y input is 2x3, Z output is 3x3. + // Create OrtModel with a Gemm. X input is 3x4, Y input is 4x8, Z output is 3x8. // X is model input. Y is initializer. // Set the alpha attribute of the Gemm node to 2.0 to test attribute handling. // // model input OrtTensorTypeAndShapeInfo* tensor_type_info = nullptr; - std::vector input_dims = {3, 2}; + std::vector input_dims = {3, 4}; // can use api.SetSymbolicDimensions to set symbolic dimensions. // the input array should have the same rank as the call to SetDimensions. // e.g. call SetDimensions with {-1, 3, 2} and SetSymbolicDimensions with {"N", nullptr, nullptr} to create @@ -170,7 +170,7 @@ TEST(ModelBuilderAPITest, Basic_CApi) { // model outputs OrtTypeInfo* output_type_info = nullptr; - std::vector output_dims = {3, 3}; + std::vector output_dims = {3, 8}; Ort::ThrowOnError(api.CreateTensorTypeAndShapeInfo(&tensor_type_info)); Ort::ThrowOnError(api.SetTensorElementType(tensor_type_info, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); @@ -203,24 +203,22 @@ TEST(ModelBuilderAPITest, Basic_CApi) { std::vector node_output_names = {gemm_output_name.c_str()}; std::vector node_attributes{alpha_attr}; OrtNode* node = CreateNode(model_builder_api, "Gemm", "Gemm1", node_input_names, node_output_names, node_attributes); - - api.ReleaseOpAttr(alpha_attr); // CreateNode copies all OrtOpAttr instances + alpha_attr = nullptr; // Node now owns Ort::ThrowOnError(model_builder_api.AddNodeToGraph(graph, node)); node = nullptr; // graph now owns node // Y input - std::vector y_dims = {2, 3}; - deleter.weights.emplace_back( - std::make_unique>(std::initializer_list{1.0f, 2.0f, 3.0f, - 4.0f, 5.0f, 6.0f})); + // As it's 128 bytes it could either be allocated using CreateTensorAsOrtValue or use existing memory. + // Under 128 bytes must use CreateTensorAsOrtValue. + std::vector y_dims = {4, 8}; + + deleter.weights.emplace_back(std::make_unique>(32)); auto& y_values = *deleter.weights.back(); + std::iota(y_values.begin(), y_values.end(), 1.0f); - // create an initializer for the Y input. add to `weights` so the memory remains valid + // create an initializer for the Y input. add to `weights` so the memory remains valid. OrtValue* y_tensor = nullptr; - auto info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - - // if you use this API the initializer data MUST remain valid for the lifetime of the InferenceSession Ort::ThrowOnError( api.CreateTensorWithDataAndDeleterAsOrtValue(&deleter, y_values.data(), y_values.size() * sizeof(y_values[0]), @@ -232,18 +230,24 @@ TEST(ModelBuilderAPITest, Basic_CApi) { y_tensor = nullptr; // graph now owns if (use_constant_node) { - // Test that a Constant node is converted to an intializer + // Test that a Constant node is converted to an initializer - // create Constant node that is used as the Max in a Clip to limit the output - OrtOpAttr* value_attr = nullptr; - float max = 60.0f; - Ort::ThrowOnError(api.CreateOpAttr("value", &max, sizeof(max), ORT_OP_ATTR_FLOAT, &value_attr)); + // create Constant nodes for min/max to limit output range + OrtOpAttr* min_attr = nullptr; + float min = 400.0f; + Ort::ThrowOnError(api.CreateOpAttr("value", &min, sizeof(min), ORT_OP_ATTR_FLOAT, &min_attr)); + node = CreateNode(model_builder_api, "Constant", "clip_min", {}, {"min"}, {min_attr}); + Ort::ThrowOnError(model_builder_api.AddNodeToGraph(graph, node)); + node = nullptr; // graph now owns node - node = CreateNode(model_builder_api, "Constant", "clip_max", {}, {"max"}, {value_attr}); + OrtOpAttr* max_attr = nullptr; + float max = 900.0f; + Ort::ThrowOnError(api.CreateOpAttr("value", &max, sizeof(max), ORT_OP_ATTR_FLOAT, &max_attr)); + node = CreateNode(model_builder_api, "Constant", "clip_max", {}, {"max"}, {max_attr}); Ort::ThrowOnError(model_builder_api.AddNodeToGraph(graph, node)); node = nullptr; // graph now owns node - node = CreateNode(model_builder_api, "Clip", "Clip1", {gemm_output_name.c_str(), "", "max"}, {"Z"}); + node = CreateNode(model_builder_api, "Clip", "Clip1", {gemm_output_name.c_str(), "min", "max"}, {"Z"}); Ort::ThrowOnError(model_builder_api.AddNodeToGraph(graph, node)); node = nullptr; // graph now owns node } @@ -265,22 +269,25 @@ TEST(ModelBuilderAPITest, Basic_CApi) { std::vector> inputs(1); auto& input = inputs[0]; input.name = "X"; - input.dims = {3, 2}; - input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + input.dims = {3, 4}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, + 8.0f, 7.0f, 6.0f, 5.0f, + 9.0f, 3.0f, 5.0f, 7.0f}; - std::vector expected_dims = {3, 3}; + std::vector expected_dims = {3, 8}; ModelBuilderAPI::Model cxx_model(model); auto session = CreateSession(*ort_env, cxx_model); std::vector expected_output; if (use_constant_node) { - expected_output = {18.0f, 24.0f, 30.0f, - 38.0f, 52.0f, 60.0f, // clipped - 58.0f, 60.0f, 60.0f}; // clipped + // clipped with min 400 and max 900 + expected_output = {400.0f, 400.0f, 400.0f, 400.0f, 420.0f, 440.0f, 460.0f, 480.0f, + 596.0f, 648.0f, 700.0f, 752.0f, 804.0f, 856.0f, 900.0f, 900.0f, + 592.0f, 640.0f, 688.0f, 736.0f, 784.0f, 832.0f, 880.0f, 900.0f}; } else { - expected_output = {18.0f, 24.0f, 30.0f, - 38.0f, 52.0f, 66.0f, - 58.0f, 80.0f, 102.0f}; + expected_output = {340.0f, 360.0f, 380.0f, 400.0f, 420.0f, 440.0f, 460.0f, 480.0f, + 596.0f, 648.0f, 700.0f, 752.0f, 804.0f, 856.0f, 908.0f, 960.0f, + 592.0f, 640.0f, 688.0f, 736.0f, 784.0f, 832.0f, 880.0f, 928.0f}; } TestInference(session, inputs, "Z", expected_dims, expected_output); @@ -301,7 +308,7 @@ TEST(ModelBuilderAPITest, Basic_CxxApi) { Ort::ModelBuilderAPI::Graph graph; // - // Create OrtModel with a Gemm. X input is 3x2, Y input is 2x3, Z output is 3x3. + // Create OrtModel with a Gemm. X input is 3x4, Y input is 4x8, Z output is 3x8. // X is model input. Y is initializer. // Set the alpha attribute of the Gemm node to 2.0 to test attribute handling. // @@ -309,8 +316,8 @@ TEST(ModelBuilderAPITest, Basic_CxxApi) { std::vector graph_inputs; std::vector graph_outputs; - // model input. it's {3, 2} but use a symbolic dim to test that works. - std::vector input_dims({-1, 2}); + // model input. it's {3, 4} but use a symbolic dim to test that works. + std::vector input_dims({-1, 4}); std::vector input_symbolic_dims({"multiple_of_3", ""}); TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, input_dims, @@ -319,7 +326,7 @@ TEST(ModelBuilderAPITest, Basic_CxxApi) { graph_inputs.emplace_back("X", input_type_info.GetConst()); // model outputs - std::vector output_dims = {-1, 3}; + std::vector output_dims = {-1, 8}; std::vector output_symbolic_dims({"multiple_of_3", ""}); TensorTypeAndShapeInfo output_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, output_dims, @@ -344,10 +351,14 @@ TEST(ModelBuilderAPITest, Basic_CxxApi) { // create an initializer for the Y input. // add to `weights` so it remains valid for the lifetime of the session and we can avoid copying the data. - std::vector y_dims = {2, 3}; - weights.emplace_back(std::make_unique>(std::initializer_list{1.0f, 2.0f, 3.0f, - 4.0f, 5.0f, 6.0f})); + // As it's 128 bytes it could either be allocated using CreateTensorAsOrtValue or use existing memory. + // Under 128 bytes must use CreateTensorAsOrtValue. + std::vector y_dims = {4, 8}; + + weights.emplace_back(std::make_unique>(32)); auto& y_values = *weights.back(); + std::iota(y_values.begin(), y_values.end(), 1.0f); + auto info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); // if you use this API the initializer data MUST remain valid for the lifetime of the InferenceSession @@ -361,16 +372,18 @@ TEST(ModelBuilderAPITest, Basic_CxxApi) { std::vector> inputs(1); auto& input = inputs[0]; input.name = "X"; - input.dims = {3, 2}; - input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + input.dims = {3, 4}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, + 8.0f, 7.0f, 6.0f, 5.0f, + 9.0f, 3.0f, 5.0f, 7.0f}; - std::vector expected_dims = {3, 3}; + std::vector expected_dims = {3, 8}; auto session = CreateSession(*ort_env, model); TestInference(session, inputs, "Z", expected_dims, - {18.0f, 24.0f, 30.0f, - 38.0f, 52.0f, 66.0f, - 58.0f, 80.0f, 102.0f}); + {340.0f, 360.0f, 380.0f, 400.0f, 420.0f, 440.0f, 460.0f, 480.0f, + 596.0f, 648.0f, 700.0f, 752.0f, 804.0f, 856.0f, 908.0f, 960.0f, + 592.0f, 640.0f, 688.0f, 736.0f, 784.0f, 832.0f, 880.0f, 928.0f}); } TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) {