oneAPI Deep Neural Network Library (oneDNN)
Performance library for Deep Learning
2.3.0
dnnl.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2 * Copyright 2016-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
19 
20 #ifndef ONEAPI_DNNL_DNNL_HPP
21 #define ONEAPI_DNNL_DNNL_HPP
22 
23 #include "oneapi/dnnl/dnnl_config.h"
24 
26 #include <algorithm>
27 #include <cstdlib>
28 #include <iterator>
29 #include <memory>
30 #include <string>
31 #include <vector>
32 #include <unordered_map>
33 
34 #include "oneapi/dnnl/dnnl.h"
35 
37 
38 // Decide whether errors are reported by throwing C++ exceptions.
39 // __cpp_exceptions is described at https://gcc.gnu.org/onlinedocs/libstdc++/manual/using_exceptions.html
40 // gcc < 5 does not define __cpp_exceptions but defines __EXCEPTIONS instead.
41 // The Microsoft C++ compiler does not provide an option to disable exceptions,
// so they are treated as enabled there (unless the compiler is clang).
// A DNNL_ENABLE_EXCEPTIONS definition supplied before this point is kept as is.
42 #ifndef DNNL_ENABLE_EXCEPTIONS
43 #if __cpp_exceptions || __EXCEPTIONS \
44  || (defined(_MSC_VER) && !defined(__clang__))
45 #define DNNL_ENABLE_EXCEPTIONS 1
46 #else
47 #define DNNL_ENABLE_EXCEPTIONS 0
48 #endif
49 #endif
50 
// DNNL_TRAP() expands to a compiler-specific trap intrinsic:
// __builtin_trap() when __GNUC__ or __clang__ is defined, __debugbreak() when
// __INTEL_COMPILER or _MSC_VER is defined. Any other compiler is rejected at
// preprocessing time with an #error.
51 #if defined(__GNUC__) || defined(__clang__)
52 #define DNNL_TRAP() __builtin_trap()
53 #elif defined(__INTEL_COMPILER) || defined(_MSC_VER)
54 #define DNNL_TRAP() __debugbreak()
55 #else
56 #error "unknown compiler"
57 #endif
58 
// DNNL_THROW_ERROR(status, msg) reports a failure.
// With exceptions enabled it throws `error(status, msg)`. Otherwise it writes
// msg to stderr and invokes DNNL_TRAP(); status is not used in that case.
// The do { ... } while (0) wrapper makes the expansion a single statement.
59 #if DNNL_ENABLE_EXCEPTIONS
60 #define DNNL_THROW_ERROR(status, msg) throw error(status, msg)
61 #else
62 #include <cstdio>
63 #define DNNL_THROW_ERROR(status, msg) \
64  do { \
65  fputs(msg, stderr); \
66  DNNL_TRAP(); \
67  } while (0)
68 #endif
69 
72 
74 namespace dnnl {
75 
79 
84 struct error : public std::exception {
86  const char *message;
87 
92  error(dnnl_status_t status, const char *message)
93  : status(status), message(message) {}
94 
96  const char *what() const noexcept override { return message; }
97 
103  static void wrap_c_api(dnnl_status_t status, const char *message) {
104  if (status != dnnl_success) DNNL_THROW_ERROR(status, message);
105  }
106 };
107 
109 template <typename T>
110 void validate_container_size(const T &v, const char *error_message,
111  int min_size = 1, int max_size = -1) {
112  const int size = (int)v.size();
113  if (size < min_size || (max_size >= 0 && size > max_size))
114  DNNL_THROW_ERROR(dnnl_invalid_arguments, error_message);
115 }
117 
/// Traits class for C API handle types.
/// The primary template is empty; the specializations for concrete handle
/// types provide a static `destructor` function that `handle` uses as the
/// deleter of the wrapped C handle.
119 template <typename T>
120 struct handle_traits {};
121 
135 template <typename T, typename traits = handle_traits<T>>
136 struct handle {
137 private:
138  static dnnl_status_t dummy_destructor(T) { return dnnl_success; }
139  std::shared_ptr<typename std::remove_pointer<T>::type> data_ {0};
140 
141 protected:
142  bool operator==(const T other) const { return other == data_.get(); }
143  bool operator!=(const T other) const { return !(*this == other); }
144 
145 public:
153  handle() = default;
154 
156  handle(const handle<T, traits> &) = default;
160  handle(handle<T, traits> &&) = default;
163 
169  explicit handle(T t, bool weak = false) { reset(t, weak); }
170 
176  void reset(T t, bool weak = false) {
177  data_.reset(t, weak ? &dummy_destructor : traits::destructor);
178  }
179 
185  T get(bool allow_empty = false) const {
186  T result = data_.get();
187  if (allow_empty == false && result == nullptr)
188  DNNL_THROW_ERROR(
189  dnnl_invalid_arguments, "object is not initialized");
190  return result;
191  }
192 
197  explicit operator T() const { return get(true); }
198 
202  explicit operator bool() const { return get(true) != nullptr; }
203 
210  bool operator==(const handle<T, traits> &other) const {
211  return other.data_.get() == data_.get();
212  }
213 
220  bool operator!=(const handle &other) const { return !(*this == other); }
221 };
222 
224 template <>
225 struct handle_traits<dnnl_memory_t> {
226  static dnnl_status_t destructor(dnnl_memory_t p) {
227  return dnnl_memory_destroy(p);
228  }
229 };
230 
231 template <>
232 struct handle_traits<dnnl_primitive_desc_t> {
233  static dnnl_status_t destructor(dnnl_primitive_desc_t p) {
234  return dnnl_primitive_desc_destroy(p);
235  }
236 };
237 
238 template <>
239 struct handle_traits<dnnl_primitive_t> {
240  static dnnl_status_t destructor(dnnl_primitive_t p) {
241  return dnnl_primitive_destroy(p);
242  }
243 };
244 
245 template <>
246 struct handle_traits<dnnl_primitive_desc_iterator_t> {
247  static dnnl_status_t destructor(dnnl_primitive_desc_iterator_t p) {
249  }
250 };
252 
254 
255 struct stream;
256 struct memory;
257 struct primitive_desc;
258 
263 
267 
/// Base class for primitives; wraps a C API `dnnl_primitive_t` handle.
// NOTE(review): the embedded line numbers of this listing jump inside this
// struct, so enumerators and member declarations may be missing from this
// view -- confirm against the complete header before relying on it.
269 struct primitive : public handle<dnnl_primitive_t> {
 /// Kinds of primitives. Each enumerator takes the value of the
 /// same-named C API constant.
271  enum class kind {
281  sum = dnnl_sum,
293  lrn = dnnl_lrn,
301  rnn = dnnl_rnn,
315  prelu = dnnl_prelu,
316  };
317 
 // Inherit the constructors of the handle base class.
318  using handle::handle;
319 
 /// Default constructor. Constructs an empty object.
321  primitive() = default;
322 
327 
332 
338 
 /// Returns the kind of the primitive (defined outside the class body).
342  inline kind get_kind() const;
343 
 /// Executes the primitive in the given stream.
 ///
 /// @param astream Stream object to execute in.
 /// @param args Map from integer argument index to memory object
 ///     (presumably DNNL_ARG_* indices -- TODO confirm).
356  void execute(const stream &astream,
357  const std::unordered_map<int, memory> &args) const;
358 };
359 
365  return static_cast<dnnl_primitive_kind_t>(akind);
366 }
367 
371  "could not get a primitive descriptor from a primitive");
372  return pd;
373 }
374 
377  // TODO (Roma): the code below is only needed because get_primitive_desc
378  // returns a C type.
381  pd, dnnl_query_primitive_kind, 0, (void *)&kind),
382  "could not get a primitive kind from a primitive descriptor");
383  return static_cast<dnnl::primitive::kind>(kind);
384 }
385 
387 
399 
401 enum class scratchpad_mode {
424 };
425 
431  return static_cast<dnnl_scratchpad_mode_t>(mode);
432 }
433 
435 enum class prop_kind {
459 };
460 
466  return static_cast<dnnl_prop_kind_t>(akind);
467 }
468 
470 enum class algorithm {
472  undef = dnnl_alg_kind_undef,
616 };
617 
622  return static_cast<dnnl_alg_kind_t>(aalgorithm);
623 }
624 
626 
629 
631 enum class normalization_flags : unsigned {
637 
646 
653 
659 
664 
669 };
670 
675  return static_cast<dnnl_normalization_flags_t>(flags);
676 }
677 
679 
682 
684 enum class rnn_flags : unsigned {
687 };
688 
693  return static_cast<dnnl_rnn_flags_t>(flags);
694 }
695 
/// Defines the bitwise operators `~`, `|`, `&`, `^` and the compound
/// assignments `|=`, `&=`, `^=` for the enumeration @p enum_name.
/// Operands are converted to `unsigned`, combined, and converted back to
/// @p enum_name; the compound forms store the result in the left operand
/// and return a reference to it.
#define DNNL_DEFINE_BITMASK_OPS(enum_name) \
    inline enum_name operator~(enum_name rhs) { \
        return static_cast<enum_name>(~static_cast<unsigned>(rhs)); \
    } \
\
    inline enum_name operator|(enum_name lhs, enum_name rhs) { \
        const unsigned bits \
                = static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs); \
        return static_cast<enum_name>(bits); \
    } \
\
    inline enum_name operator&(enum_name lhs, enum_name rhs) { \
        const unsigned bits \
                = static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs); \
        return static_cast<enum_name>(bits); \
    } \
\
    inline enum_name operator^(enum_name lhs, enum_name rhs) { \
        const unsigned bits \
                = static_cast<unsigned>(lhs) ^ static_cast<unsigned>(rhs); \
        return static_cast<enum_name>(bits); \
    } \
\
    inline enum_name &operator|=(enum_name &lhs, enum_name rhs) { \
        return lhs = lhs | rhs; \
    } \
\
    inline enum_name &operator&=(enum_name &lhs, enum_name rhs) { \
        return lhs = lhs & rhs; \
    } \
\
    inline enum_name &operator^=(enum_name &lhs, enum_name rhs) { \
        return lhs = lhs ^ rhs; \
    }
733 
734 DNNL_DEFINE_BITMASK_OPS(normalization_flags)
735 DNNL_DEFINE_BITMASK_OPS(rnn_flags)
736 
737 enum class rnn_direction {
751 };
752 
757  return static_cast<dnnl_rnn_direction_t>(dir);
758 }
759 
761 
764 
771 enum class query {
774 
779 
784 
791 
796 
801 
804 
807 
842 
861 };
862 
867  return static_cast<dnnl_query_t>(aquery);
868 }
869 
871 
873 
884 
886 template <>
887 struct handle_traits<dnnl_engine_t> {
888  static dnnl_status_t destructor(dnnl_engine_t p) {
889  return dnnl_engine_destroy(p);
890  }
891 };
893 
/// An execution engine; wraps a C API `dnnl_engine_t` handle.
// NOTE(review): the embedded line numbers of this listing jump inside several
// member bodies below, so some statements are not visible in this view --
// confirm against the complete header.
895 struct engine : public handle<dnnl_engine_t> {
896  friend struct primitive;
897  friend struct reorder;
898 
 /// Kinds of engines. Each enumerator takes the value of the same-named
 /// C API constant.
900  enum class kind {
904  cpu = dnnl_cpu,
906  gpu = dnnl_gpu,
907  };
908 
 // Inherit the constructors of the handle base class.
909  using handle::handle;
910 
 /// Constructs an empty engine object.
913  engine() = default;
914 
 /// Returns the value reported by dnnl_engine_get_count() for the given
 /// engine kind.
 ///
 /// @param akind The kind of engines to count.
919  static size_t get_count(kind akind) {
920  return dnnl_engine_get_count(convert_to_c(akind));
921  }
922 
 /// Constructs an engine through dnnl_engine_create().
 ///
 /// @param akind The kind of engine to construct.
 /// @param index The index passed to dnnl_engine_create().
 // NOTE(review): the declaration of the local `engine` handle and the
 // opening of the error-checking call are not visible here.
928  engine(kind akind, size_t index) {
931  dnnl_engine_create(&engine, convert_to_c(akind), index),
932  "could not create an engine");
933  reset(engine);
934  }
935 
 // NOTE(review): the header of the following member is not visible in this
 // listing. The visible body obtains an engine for dnnl::query::engine and
 // wraps it as a weak handle (reset(..., true)).
941  dnnl_engine_t c_engine;
944  dnnl::convert_to_c(dnnl::query::engine), 0, &c_engine),
945  "could not get an engine from a primitive_desc");
946  reset(c_engine, true);
947  }
948 
 /// Returns the kind of the engine.
 // NOTE(review): the statements that obtain `kind` are not visible here.
951  kind get_kind() const {
954  "could not get kind of an engine");
955  return static_cast<engine::kind>(kind);
956  }
957 
 /// Returns the engine of a primitive descriptor by forwarding to the
 /// private two-argument overload with dnnl::query::engine.
 ///
 /// @param pd The primitive descriptor to query.
963  template <typename primitive_desc>
964  static engine query(const primitive_desc &pd) {
965  return query(pd, dnnl::query::engine);
966  }
967 
968 private:
 /// Converts the C++ engine kind to the C API `dnnl_engine_kind_t`.
969  static dnnl_engine_kind_t convert_to_c(kind akind) {
970  return static_cast<dnnl_engine_kind_t>(akind);
971  }
972 
 /// Queries a primitive descriptor for an engine and returns it wrapped
 /// as a weak handle.
 ///
 /// @param pd The primitive descriptor to query.
 /// @param what The query to perform.
 // NOTE(review): the opening of the query call is not visible here.
973  template <typename primitive_desc>
974  static engine query(const primitive_desc &pd, dnnl::query what) {
975  dnnl_engine_t c_engine;
977  dnnl::convert_to_c(what), 0, &c_engine),
978  "could not get an engine from a primitive_desc");
979  return engine(c_engine, true);
980  }
981 };
982 
988  return static_cast<dnnl_engine_kind_t>(akind);
989 }
990 
992 
1000 
1002 template <>
1003 struct handle_traits<dnnl_stream_t> {
1004  static dnnl_status_t destructor(dnnl_stream_t p) {
1005  return dnnl_stream_destroy(p);
1006  }
1007 };
1009 
/// An execution stream; wraps a C API `dnnl_stream_t` handle.
// NOTE(review): the embedded line numbers of this listing jump inside this
// struct, so some enumerators and statements are not visible in this view --
// confirm against the complete header.
1011 struct stream : public handle<dnnl_stream_t> {
 // Inherit the constructors of the handle base class.
1012  using handle::handle;
1013 
 /// Stream flags; usable as a bitmask. Enumerators take the value of the
 /// same-named C API constant.
1015  enum class flags : unsigned {
1017  in_order = dnnl_stream_in_order,
1022  };
1023 
 /// Constructs an empty stream object.
1026  stream() = default;
1027 
 /// Constructs a stream for the given engine with the given flags.
 ///
 /// @param aengine Engine to create the stream on.
 /// @param aflags Flags controlling stream behavior.
 // NOTE(review): the declaration of the local `stream` handle and the
 // opening of the creation call are not visible here.
1033  stream(const engine &aengine, flags aflags = flags::default_flags) {
1036  static_cast<dnnl_stream_flags_t>(aflags)),
1037  "could not create a stream");
1038  reset(stream);
1039  }
1040 
 /// Returns the engine associated with the stream, wrapped as a weak
 /// handle.
 // NOTE(review): the call that fills `c_engine` is not visible here.
1042  engine get_engine() const {
1043  dnnl_engine_t c_engine;
1045  "could not get an engine from a stream object");
1046  return engine(c_engine, true);
1047  }
1048 
 // NOTE(review): the header of the following member is not visible in this
 // listing. The visible body calls dnnl_stream_wait() on the wrapped handle
 // and returns *this.
1053  dnnl_stream_wait(get()), "could not wait on a stream");
1054  return *this;
1055  }
1056 };
1057 
1058 DNNL_DEFINE_BITMASK_OPS(stream::flags)
1059 
1060 
1127 
1134 struct memory : public handle<dnnl_memory_t> {
1135  using handle::handle;
1136 
1138  typedef dnnl_dim_t dim;
1141  typedef std::vector<dim> dims;
1142 
1149  template <typename T>
1150  static void validate_dims(const std::vector<T> &v, int min_size = 0) {
1151  validate_container_size(
1152  v, "dimensions are invalid", min_size, DNNL_MAX_NDIMS);
1153  }
1154 
 /// Data type specification. Each enumerator takes the value of the
 /// same-named C API constant.
 // NOTE(review): the embedded line numbers jump after the opening line, so an
 // enumerator may be missing from this view -- confirm against the header.
1156  enum class data_type {
 // presumably 16-bit floating point -- TODO confirm
1160  f16 = dnnl_f16,
 // presumably bfloat16 -- TODO confirm
1163  bf16 = dnnl_bf16,
 // presumably 32-bit floating point -- TODO confirm
1165  f32 = dnnl_f32,
 // presumably 32-bit signed integer -- TODO confirm
1167  s32 = dnnl_s32,
 // presumably 8-bit signed integer -- TODO confirm
1169  s8 = dnnl_s8,
 // presumably 8-bit unsigned integer -- TODO confirm
1171  u8 = dnnl_u8,
1172  };
1173 
1176  static size_t data_type_size(data_type adata_type) {
1177  return dnnl_data_type_size(convert_to_c(adata_type));
1178  }
1179 
 /// Memory format kind. Each enumerator takes the value of a C API
 /// format-kind constant.
 // NOTE(review): the embedded line numbers jump after the opening line, so an
 // enumerator may be missing from this view -- confirm against the header.
1181  enum class format_kind {
1186  any = dnnl_format_kind_any,
1190  blocked = dnnl_blocked,
1192  wino = dnnl_format_kind_wino,
 // Note the name difference: `packed` maps to the RNN-packed C constant.
1194  packed = dnnl_format_kind_rnn_packed,
1195  };
1196 
1237  enum class format_tag {
1242  any = dnnl_format_tag_any,
1243 
1245  a = dnnl_a,
1246 
1248  ab = dnnl_ab,
1250  ba = dnnl_ba,
1251 
1253  abc = dnnl_abc,
1255  acb = dnnl_acb,
1257  bac = dnnl_bac,
1259  bca = dnnl_bca,
1261  cba = dnnl_cba,
1262 
1264  abcd = dnnl_abcd,
1266  abdc = dnnl_abdc,
1268  acdb = dnnl_acdb,
1270  bacd = dnnl_bacd,
1272  bcda = dnnl_bcda,
1274  cdba = dnnl_cdba,
1276  dcab = dnnl_dcab,
1277 
1279  abcde = dnnl_abcde,
1281  abdec = dnnl_abdec,
1283  acbde = dnnl_acbde,
1285  acdeb = dnnl_acdeb,
1287  bacde = dnnl_bacde,
1289  bcdea = dnnl_bcdea,
1291  cdeba = dnnl_cdeba,
1293  decab = dnnl_decab,
1295  abced = dnnl_abced,
1296 
1298  abcdef = dnnl_abcdef,
1300  abdfce = dnnl_abdfce,
1302  acbdef = dnnl_acbdef,
1304  abdefc = dnnl_abdefc,
1306  defcab = dnnl_defcab,
1308  abcdfe = dnnl_abcdfe,
1309 
1311  abcdefg = dnnl_abcdefg,
1313  abcdegf = dnnl_abcdegf,
1314 
1316  abcdefgh = dnnl_abcdefgh,
1318  abcdefhg = dnnl_abcdefhg,
1319 
1321  abcdefghi = dnnl_abcdefghi,
1323  abcdefgih = dnnl_abcdefgih,
1324 
1326  abcdefghij = dnnl_abcdefghij,
1328  abcdefghji = dnnl_abcdefghji,
1329 
1331  abcdefghijk = dnnl_abcdefghijk,
1333  abcdefghikj = dnnl_abcdefghikj,
1334 
1336  abcdefghijkl = dnnl_abcdefghijkl,
1338  abcdefghijlk = dnnl_abcdefghijlk,
1339 
1341  x = a,
1343  nc = ab,
1345  cn = ba,
1347  tn = ab,
1349  nt = ba,
1351  ncw = abc,
1353  nwc = acb,
1355  nchw = abcd,
1357  nhwc = acdb,
1359  chwn = bcda,
1361  ncdhw = abcde,
1363  ndhwc = acdeb,
1364 
1366  oi = ab,
1368  io = ba,
1370  oiw = abc,
1372  owi = acb,
1374  wio = cba,
1376  iwo = bca,
1378  oihw = abcd,
1380  hwio = cdba,
1382  ohwi = acdb,
1384  ihwo = bcda,
1386  iohw = bacd,
1388  oidhw = abcde,
1390  dhwio = cdeba,
1392  odhwi = acdeb,
1394  iodhw = bacde,
1396  idhwo = bcdea,
1397 
1399  goiw = abcd,
1401  gowi = abdc,
1403  wigo = dcab,
1405  gohwi = abdec,
1407  goihw = abcde,
1409  hwigo = decab,
1411  giohw = acbde,
1413  goidhw = abcdef,
1415  giodhw = acbdef,
1417  godhwi = abdefc,
1419  dhwigo = defcab,
1420 
1423  tnc = abc,
1426  ntc = bac,
1429  ldnc = abcd,
1437  ldigo = abcde,
1445  ldgoi = abdec,
1449  ldio = abcd,
1453  ldoi = abdc,
1461  ldgo = abcd,
1462 
1463  // Opaque blocked formats
1464 
1465  AB16b16a = dnnl_AB16b16a,
1466  AB16b32a = dnnl_AB16b32a,
1467  AB16b64a = dnnl_AB16b64a,
1468  AB8b16a2b = dnnl_AB8b16a2b,
1469  AB8b32a2b = dnnl_AB8b32a2b,
1470  AB8b64a2b = dnnl_AB8b64a2b,
1471  AB4b16a4b = dnnl_AB4b16a4b,
1472  AB4b32a4b = dnnl_AB4b32a4b,
1473  AB4b64a4b = dnnl_AB4b64a4b,
1474  AB16b16a4b = dnnl_AB16b16a4b,
1475  AB16b32a4b = dnnl_AB16b32a4b,
1476  AB16b48a4b = dnnl_AB16b48a4b,
1477  AB16b64a4b = dnnl_AB16b64a4b,
1478  AB16b16a2b = dnnl_AB16b16a2b,
1479  AB16b32a2b = dnnl_AB16b32a2b,
1480  AB16b48a2b = dnnl_AB16b48a2b,
1481  AB16b64a2b = dnnl_AB16b64a2b,
1482  Abc16a = dnnl_Abc16a,
1483  ABc16a16b = dnnl_ABc16a16b,
1484  ABc4a4b = dnnl_ABc4a4b,
1485  aBc16b = dnnl_aBc16b,
1486  aBc32b = dnnl_aBc32b,
1487  ABc16b16a = dnnl_ABc16b16a,
1488  ABc16b32a = dnnl_ABc16b32a,
1489  ABc16b64a = dnnl_ABc16b64a,
1490  Abc4a = dnnl_Abc4a,
1491  aBc4b = dnnl_aBc4b,
1492  ABc4b16a4b = dnnl_ABc4b16a4b,
1493  ABc4b32a4b = dnnl_ABc4b32a4b,
1494  ABc4b64a4b = dnnl_ABc4b64a4b,
1495  ABc2b8a4b = dnnl_ABc2b8a4b,
1496  ABc16a16b2a = dnnl_ABc16a16b2a,
1497  ABc16b16a4b = dnnl_ABc16b16a4b,
1498  ABc16b32a4b = dnnl_ABc16b32a4b,
1499  ABc16b48a4b = dnnl_ABc16b48a4b,
1500  ABc16b64a4b = dnnl_ABc16b64a4b,
1501  ABc16b16a2b = dnnl_ABc16b16a2b,
1502  ABc16b32a2b = dnnl_ABc16b32a2b,
1503  ABc16b48a2b = dnnl_ABc16b48a2b,
1504  ABc16b64a2b = dnnl_ABc16b64a2b,
1505  ABc4b4a = dnnl_ABc4b4a,
1506  ABc8a16b2a = dnnl_ABc8a16b2a,
1507  ABc8a8b = dnnl_ABc8a8b,
1508  ABc8a4b = dnnl_ABc8a4b,
1509  aBc8b = dnnl_aBc8b,
1510  ABc8b16a2b = dnnl_ABc8b16a2b,
1511  ABc8b32a2b = dnnl_ABc8b32a2b,
1512  ABc8b64a2b = dnnl_ABc8b64a2b,
1513  ABc8b8a = dnnl_ABc8b8a,
1514  Abcd8a = dnnl_Abcd8a,
1515  Abcd16a = dnnl_Abcd16a,
1516  Abcd32a = dnnl_Abcd32a,
1517  ABcd16a16b = dnnl_ABcd16a16b,
1518  aBcd16b = dnnl_aBcd16b,
1519  aBcd32b = dnnl_aBcd32b,
1520  ABcd16b16a = dnnl_ABcd16b16a,
1521  ABcd16b32a = dnnl_ABcd16b32a,
1522  ABcd16b64a = dnnl_ABcd16b64a,
1523  aBCd16b16c = dnnl_aBCd16b16c,
1524  aBCd16c16b = dnnl_aBCd16c16b,
1525  Abcd4a = dnnl_Abcd4a,
1526  aBcd4b = dnnl_aBcd4b,
1527  ABcd4b16a4b = dnnl_ABcd4b16a4b,
1528  ABcd4b32a4b = dnnl_ABcd4b32a4b,
1529  ABcd4b64a4b = dnnl_ABcd4b64a4b,
1530  ABcd2b8a4b = dnnl_ABcd2b8a4b,
1531  ABcd4b4a = dnnl_ABcd4b4a,
1532  ABcd4a4b = dnnl_ABcd4a4b,
1533  aBCd4c16b4c = dnnl_aBCd4c16b4c,
1534  aBCd2c8b4c = dnnl_aBCd2c8b4c,
1535  ABcd16a16b2a = dnnl_ABcd16a16b2a,
1536  ABcd16b16a4b = dnnl_ABcd16b16a4b,
1537  ABcd16b32a4b = dnnl_ABcd16b32a4b,
1538  ABcd16b48a4b = dnnl_ABcd16b48a4b,
1539  ABcd16b64a4b = dnnl_ABcd16b64a4b,
1540  ABcd16b16a2b = dnnl_ABcd16b16a2b,
1541  ABcd16b32a2b = dnnl_ABcd16b32a2b,
1542  ABcd16b48a2b = dnnl_ABcd16b48a2b,
1543  ABcd16b64a2b = dnnl_ABcd16b64a2b,
1544  aBCd16b16c2b = dnnl_aBCd16b16c2b,
1545  aBCd16c16b4c = dnnl_aBCd16c16b4c,
1546  aBCd16c16b2c = dnnl_aBCd16c16b2c,
1547  aBCd4c4b = dnnl_aBCd4c4b,
1548  aBCd4b4c = dnnl_aBCd4b4c,
1549  ABcd8a16b2a = dnnl_ABcd8a16b2a,
1550  ABcd8a8b = dnnl_ABcd8a8b,
1551  ABcd8a4b = dnnl_ABcd8a4b,
1553  aBcd8b = dnnl_aBcd8b,
1554  ABcd8b16a2b = dnnl_ABcd8b16a2b,
1555  ABcd8b32a2b = dnnl_ABcd8b32a2b,
1556  ABcd8b64a2b = dnnl_ABcd8b64a2b,
1557  aBCd8b16c2b = dnnl_aBCd8b16c2b,
1559  ABcd8b8a = dnnl_ABcd8b8a,
1560  aBCd8b8c = dnnl_aBCd8b8c,
1561  aBCd8b4c = dnnl_aBCd8b4c,
1562  aBCd8c16b2c = dnnl_aBCd8c16b2c,
1563  aBCd8c8b = dnnl_aBCd8c8b,
1564  Abcde16a = dnnl_Abcde16a,
1565  Abcde32a = dnnl_Abcde32a,
1566  ABcde16a16b = dnnl_ABcde16a16b,
1567  aBcde16b = dnnl_aBcde16b,
1568  aBcde32b = dnnl_aBcde32b,
1569  ABcde16b16a = dnnl_ABcde16b16a,
1570  ABcde16b32a = dnnl_ABcde16b32a,
1571  ABcde16b64a = dnnl_ABcde16b64a,
1572  aBCde16b16c = dnnl_aBCde16b16c,
1573  aBCde16c16b = dnnl_aBCde16c16b,
1574  aBCde2c8b4c = dnnl_aBCde2c8b4c,
1575  Abcde4a = dnnl_Abcde4a,
1576  aBcde4b = dnnl_aBcde4b,
1577  ABcde4b4a = dnnl_ABcde4b4a,
1578  ABcde4a4b = dnnl_ABcde4a4b,
1579  aBCde4b4c = dnnl_aBCde4b4c,
1580  aBCde4c16b4c = dnnl_aBCde4c16b4c,
1581  aBCde16b16c2b = dnnl_aBCde16b16c2b,
1582  aBCde16c16b4c = dnnl_aBCde16c16b4c,
1583  aBCde16c16b2c = dnnl_aBCde16c16b2c,
1584  aBCdef16c16b2c = dnnl_aBCdef16c16b2c,
1585  aBCde4c4b = dnnl_aBCde4c4b,
1586  Abcde8a = dnnl_Abcde8a,
1587  ABcde8a8b = dnnl_ABcde8a8b,
1588  ABcde8a4b = dnnl_ABcde8a4b,
1589  aBcde8b = dnnl_aBcde8b,
1590  ABcde8b16a2b = dnnl_ABcde8b16a2b,
1591  ABcde8b32a2b = dnnl_ABcde8b32a2b,
1592  ABcde8b64a2b = dnnl_ABcde8b64a2b,
1593  ABcde4b16a4b = dnnl_ABcde4b16a4b,
1594  ABcde4b32a4b = dnnl_ABcde4b32a4b,
1595  ABcde4b64a4b = dnnl_ABcde4b64a4b,
1596  ABcde16b16a4b = dnnl_ABcde16b16a4b,
1597  ABcde16b32a4b = dnnl_ABcde16b32a4b,
1598  ABcde16b48a4b = dnnl_ABcde16b48a4b,
1599  ABcde16b64a4b = dnnl_ABcde16b64a4b,
1600  ABcde16b16a2b = dnnl_ABcde16b16a2b,
1601  ABcde16b32a2b = dnnl_ABcde16b32a2b,
1602  ABcde16b48a2b = dnnl_ABcde16b48a2b,
1603  ABcde16b64a2b = dnnl_ABcde16b64a2b,
1604  ABcde2b8a4b = dnnl_ABcde2b8a4b,
1605  aBCde8b16c2b = dnnl_aBCde8b16c2b,
1606  ABcde8b8a = dnnl_ABcde8b8a,
1607  aBCde8b8c = dnnl_aBCde8b8c,
1608  aBCde8b4c = dnnl_aBCde8b4c,
1609  ABcd4a8b8a4b = dnnl_ABcd4a8b8a4b,
1610  ABcd2a8b8a2b = dnnl_ABcd2a8b8a2b,
1611  aBCde4b8c8b4c = dnnl_aBCde4b8c8b4c,
1612  aBCde2b8c8b2c = dnnl_aBCde2b8c8b2c,
1613  aBCde8c16b2c = dnnl_aBCde8c16b2c,
1614  aBCde8c8b = dnnl_aBCde8c8b,
1615  aBcdef16b = dnnl_aBcdef16b,
1616  aBCdef16b16c = dnnl_aBCdef16b16c,
1617  aBCdef16c16b = dnnl_aBCdef16c16b,
1618  aBcdef4b = dnnl_aBcdef4b,
1619  aBCdef2c8b4c = dnnl_aBCdef2c8b4c,
1620  aBCdef4c4b = dnnl_aBCdef4c4b,
1621  aBCdef4b4c = dnnl_aBCdef4b4c,
1622  aBCdef8b8c = dnnl_aBCdef8b8c,
1623  aBCdef8b4c = dnnl_aBCdef8b4c,
1624  aBCdef8c16b2c = dnnl_aBCdef8c16b2c,
1625  aBCdef4c16b4c = dnnl_aBCdef4c16b4c,
1626  aBCdef8c8b = dnnl_aBCdef8c8b,
1627  aBdc16b = dnnl_aBdc16b,
1628  aBdc4b = dnnl_aBdc4b,
1629  aBdc8b = dnnl_aBdc8b,
1630  aBdec16b = dnnl_aBdec16b,
1631  aBdec4b = dnnl_aBdec4b,
1632  aBdec8b = dnnl_aBdec8b,
1633  aBdefc16b = dnnl_aBdefc16b,
1634  aCBdef16c16b = dnnl_aCBdef16c16b,
1635  aCBdef16b16c = dnnl_aCBdef16b16c,
1636  aBdefc4b = dnnl_aBdefc4b,
1637  aBdefc8b = dnnl_aBdefc8b,
1638  Acb16a = dnnl_Acb16a,
1639  Acb4a = dnnl_Acb4a,
1640  Acb8a = dnnl_Acb8a,
1641  aCBd16b16c = dnnl_aCBd16b16c,
1642  aCBd16c16b = dnnl_aCBd16c16b,
1643  aCBde16b16c = dnnl_aCBde16b16c,
1644  aCBde16c16b = dnnl_aCBde16c16b,
1645  Acdb16a = dnnl_Acdb16a,
1646  Acdb4a = dnnl_Acdb4a,
1647  Acdb8a = dnnl_Acdb8a,
1648  Acdeb16a = dnnl_Acdeb16a,
1649  Acdeb4a = dnnl_Acdeb4a,
1650  Acdeb8a = dnnl_Acdeb8a,
1651  BAc16a16b = dnnl_BAc16a16b,
1652  BAc16b16a = dnnl_BAc16b16a,
1653  BAcd16a16b = dnnl_BAcd16a16b,
1654  BAcd16b16a = dnnl_BAcd16b16a,
1655  ABcd32a32b = dnnl_ABcd32a32b,
1656  BAcde16b16a = dnnl_BAcde16b16a,
1657  BAcde16a16b = dnnl_BAcde16a16b,
1658  aBdec32b = dnnl_aBdec32b,
1659  Abcdef16a = dnnl_Abcdef16a,
1660  Abcdef32a = dnnl_Abcdef32a,
1661  Acdb32a = dnnl_Acdb32a,
1662  aBCd2b4c2b = dnnl_aBCd2b4c2b,
1663  aBCde2b4c2b = dnnl_aBCde2b4c2b,
1664  aBCdef2b4c2b = dnnl_aBCdef2b4c2b,
1665  aBCd2c4b2c = dnnl_aBCd2c4b2c,
1666  aBCde2c4b2c = dnnl_aBCde2c4b2c,
1667  aBCdef2c4b2c = dnnl_aBCdef2c4b2c,
1668  aBCd4b8c2b = dnnl_aBCd4b8c2b,
1669  aBCde4b8c2b = dnnl_aBCde4b8c2b,
1670  aBCdef4b8c2b = dnnl_aBCdef4b8c2b,
1671  aBCd4c8b2c = dnnl_aBCd4c8b2c,
1672  aBCde4c8b2c = dnnl_aBCde4c8b2c,
1673  aBCdef4c8b2c = dnnl_aBCdef4c8b2c,
1674  AB32a32b8a4b = dnnl_AB32a32b8a4b,
1675  AB32a32b8a2b = dnnl_AB32a32b8a2b,
1676  AB8a4b = dnnl_AB8a4b,
1677  AB8a2b = dnnl_AB8a2b,
1678  abDc32d = dnnl_abDc32d,
1679  abDC32d4c = dnnl_abDC32d4c,
1680  abCd32c = dnnl_abCd32c,
1681  abdEc32e = dnnl_abdEc32e,
1682  abdEC32e2c = dnnl_abdEC32e2c,
1683  abdEC32e4c = dnnl_abdEC32e4c,
1684  abdCe32c = dnnl_abdCe32c,
1685  abdCE32c2e = dnnl_abdCE32c2e,
1686  aBCdef16c16b4c = dnnl_aBCdef16c16b4c,
1687  aBdC16b4c = dnnl_aBdC16b4c,
1688  aBdeC16b4c = dnnl_aBdeC16b4c,
1689  AcB16a4b = dnnl_AcB16a4b,
1690  AcdB16a2b = dnnl_AcdB16a2b,
1691  aBdefC16b4c = dnnl_aBdefC16b4c,
1692  AcdeB16a4b = dnnl_AcdeB16a4b,
1693 
1694  Acb32a = dnnl_Acb32a,
1695  AcB32a2b = dnnl_AcB32a2b,
1696  AcB32a4b = dnnl_AcB32a4b,
1697  Acb48a = dnnl_Acb48a,
1698  AcB48a2b = dnnl_AcB48a2b,
1699  AcB48a4b = dnnl_AcB48a4b,
1700  Acb64a = dnnl_Acb64a,
1701  AcB64a2b = dnnl_AcB64a2b,
1702  AcB64a4b = dnnl_AcB64a4b,
1703  cBa2b = dnnl_cBa2b,
1704  cBa4b = dnnl_cBa4b,
1705  aBdc32b = dnnl_aBdc32b,
1706  aBdC32b2c = dnnl_aBdC32b2c,
1707  aBdC32b4c = dnnl_aBdC32b4c,
1708  aBdc48b = dnnl_aBdc48b,
1709  aBdC48b2c = dnnl_aBdC48b2c,
1710  aBdC48b4c = dnnl_aBdC48b4c,
1711  aBdc64b = dnnl_aBdc64b,
1712  aBdC64b2c = dnnl_aBdC64b2c,
1713  aBdC64b4c = dnnl_aBdC64b4c,
1714  adcb = dnnl_adcb,
1715  adCb2c = dnnl_adCb2c,
1716  adCb4c = dnnl_adCb4c,
1717  AcdB32a2b = dnnl_AcdB32a2b,
1718  AcdB32a4b = dnnl_AcdB32a4b,
1719  Acdb48a = dnnl_Acdb48a,
1720  AcdB48a2b = dnnl_AcdB48a2b,
1721  AcdB48a4b = dnnl_AcdB48a4b,
1722  Acdb64a = dnnl_Acdb64a,
1723  AcdB64a2b = dnnl_AcdB64a2b,
1724  AcdB64a4b = dnnl_AcdB64a4b,
1725  cdBa2b = dnnl_cdBa2b,
1726  cdBa4b = dnnl_cdBa4b,
1727  aBdeC32b2c = dnnl_aBdeC32b2c,
1728  aBdeC32b4c = dnnl_aBdeC32b4c,
1729  aBdec48b = dnnl_aBdec48b,
1730  aBdeC48b2c = dnnl_aBdeC48b2c,
1731  aBdeC48b4c = dnnl_aBdeC48b4c,
1732  aBdec64b = dnnl_aBdec64b,
1733  aBdeC64b2c = dnnl_aBdeC64b2c,
1734  aBdeC64b4c = dnnl_aBdeC64b4c,
1735  adecb = dnnl_adecb,
1736  adeCb2c = dnnl_adeCb2c,
1737  adeCb4c = dnnl_adeCb4c,
1738  Acdeb32a = dnnl_Acdeb32a,
1739  AcdeB32a2b = dnnl_AcdeB32a2b,
1740  AcdeB32a4b = dnnl_AcdeB32a4b,
1741  Acdeb48a = dnnl_Acdeb48a,
1742  AcdeB48a2b = dnnl_AcdeB48a2b,
1743  AcdeB48a4b = dnnl_AcdeB48a4b,
1744  Acdeb64a = dnnl_Acdeb64a,
1745  AcdeB64a2b = dnnl_AcdeB64a2b,
1746  AcdeB64a4b = dnnl_AcdeB64a4b,
1747  cdeBa2b = dnnl_cdeBa2b,
1748  cdeBa4b = dnnl_cdeBa4b,
1749  aBdefc32b = dnnl_aBdefc32b,
1750  aBdefC32b2c = dnnl_aBdefC32b2c,
1751  aBdefC32b4c = dnnl_aBdefC32b4c,
1752  aBdefc48b = dnnl_aBdefc48b,
1753  aBdefC48b2c = dnnl_aBdefC48b2c,
1754  aBdefC48b4c = dnnl_aBdefC48b4c,
1755  aBdefc64b = dnnl_aBdefc64b,
1756  aBdefC64b2c = dnnl_aBdefC64b2c,
1757  aBdefC64b4c = dnnl_aBdefC64b4c,
1758  adefcb = dnnl_adefcb,
1759  adefCb2c = dnnl_adefCb2c,
1760  adefCb4c = dnnl_adefCb4c,
1761 
1762  format_tag_last = dnnl_format_tag_last,
1763 
1764  nCdhw16c = dnnl_nCdhw16c,
1765  nCdhw4c = dnnl_nCdhw4c,
1766  nCdhw8c = dnnl_nCdhw8c,
1767  nChw16c = dnnl_nChw16c,
1768  nChw4c = dnnl_nChw4c,
1769  nChw8c = dnnl_nChw8c,
1770  nCw16c = dnnl_nCw16c,
1771  nCw4c = dnnl_nCw4c,
1772  nCw8c = dnnl_nCw8c,
1773  NCw16n16c = dnnl_NCw16n16c,
1774  NChw16n16c = dnnl_NChw16n16c,
1775  NCdhw16n16c = dnnl_NCdhw16n16c,
1776  NCdhw32n32c = dnnl_NCdhw32n32c,
1777  NChw32n32c = dnnl_NChw32n32c,
1778  IOhw16i16o = dnnl_IOhw16i16o,
1779  OI16i16o = dnnl_OI16i16o,
1780  OI16i32o = dnnl_OI16i32o,
1781  OI16i64o = dnnl_OI16i64o,
1782  OI8i16o2i = dnnl_OI8i16o2i,
1783  OI8i32o2i = dnnl_OI8i32o2i,
1784  OI8i64o2i = dnnl_OI8i64o2i,
1785  OI4i16o4i = dnnl_OI4i16o4i,
1786  OI4i32o4i = dnnl_OI4i32o4i,
1787  OI4i64o4i = dnnl_OI4i64o4i,
1788  Ohwi32o = dnnl_Ohwi32o,
1789  IOdhw16i16o = dnnl_IOdhw16i16o,
1790  gIOhw16i16o = dnnl_gIOhw16i16o,
1791  gOhwi32o = dnnl_gOhwi32o,
1792  Goidhw16g = dnnl_Goidhw16g,
1793  IOw16o16i = dnnl_IOw16o16i,
1794  OIw16i16o = dnnl_OIw16i16o,
1795  OIw16i32o = dnnl_OIw16i32o,
1796  OIw16i64o = dnnl_OIw16i64o,
1797  IOw16i16o = dnnl_IOw16i16o,
1798  gIOw16i16o = dnnl_gIOw16i16o,
1799  OIw16o16i = dnnl_OIw16o16i,
1800  Oiw16o = dnnl_Oiw16o,
1801  OIw4i16o4i = dnnl_OIw4i16o4i,
1802  OIw4i32o4i = dnnl_OIw4i32o4i,
1803  OIw4i64o4i = dnnl_OIw4i64o4i,
1804  OIw2i8o4i = dnnl_OIw2i8o4i,
1805  OIw4i4o = dnnl_OIw4i4o,
1806  OIw4o4i = dnnl_OIw4o4i,
1807  Oiw4o = dnnl_Oiw4o,
1808  OIw8i16o2i = dnnl_OIw8i16o2i,
1809  OIw8i32o2i = dnnl_OIw8i32o2i,
1810  OIw8i64o2i = dnnl_OIw8i64o2i,
1811  OIw8i8o = dnnl_OIw8i8o,
1812  OIw8o16i2o = dnnl_OIw8o16i2o,
1813  OIw8o8i = dnnl_OIw8o8i,
1814  OIw8o4i = dnnl_OIw8o4i,
1815  OIw16i16o4i = dnnl_OIw16i16o4i,
1816  OIw16i32o4i = dnnl_OIw16i32o4i,
1817  OIw16i48o4i = dnnl_OIw16i48o4i,
1818  OIw16i64o4i = dnnl_OIw16i64o4i,
1819  OIw16i16o2i = dnnl_OIw16i16o2i,
1820  OIw16i32o2i = dnnl_OIw16i32o2i,
1821  OIw16i48o2i = dnnl_OIw16i48o2i,
1822  OIw16i64o2i = dnnl_OIw16i64o2i,
1823  OIw16o16i2o = dnnl_OIw16o16i2o,
1824  Owi16o = dnnl_Owi16o,
1825  OwI16o2i = dnnl_OwI16o2i,
1826  Owi4o = dnnl_Owi4o,
1827  Owi8o = dnnl_Owi8o,
1828  IOhw16o16i = dnnl_IOhw16o16i,
1829  Ohwi16o = dnnl_Ohwi16o,
1830  OhwI16o2i = dnnl_OhwI16o2i,
1831  Ohwi4o = dnnl_Ohwi4o,
1832  Ohwi8o = dnnl_Ohwi8o,
1833  OIhw16i16o = dnnl_OIhw16i16o,
1834  OIhw16i32o = dnnl_OIhw16i32o,
1835  OIhw16i64o = dnnl_OIhw16i64o,
1836  OIhw16o16i = dnnl_OIhw16o16i,
1837  Oihw16o = dnnl_Oihw16o,
1838  OIhw4i16o4i = dnnl_OIhw4i16o4i,
1839  OIhw4i32o4i = dnnl_OIhw4i32o4i,
1840  OIhw4i64o4i = dnnl_OIhw4i64o4i,
1841  OIhw4i4o = dnnl_OIhw4i4o,
1842  OIhw4o4i = dnnl_OIhw4o4i,
1843  Oihw4o = dnnl_Oihw4o,
1844  OIhw8i16o2i = dnnl_OIhw8i16o2i,
1845  OIhw8i32o2i = dnnl_OIhw8i32o2i,
1846  OIhw8i64o2i = dnnl_OIhw8i64o2i,
1847  OIhw8i8o = dnnl_OIhw8i8o,
1848  OIhw8o16i2o = dnnl_OIhw8o16i2o,
1849  OIhw8o8i = dnnl_OIhw8o8i,
1850  OIhw8o4i = dnnl_OIhw8o4i,
1851  OIhw2i8o4i = dnnl_OIhw2i8o4i,
1852  IOdhw16o16i = dnnl_IOdhw16o16i,
1853  Odhwi16o = dnnl_Odhwi16o,
1854  OdhwI16o2i = dnnl_OdhwI16o2i,
1855  Odhwi4o = dnnl_Odhwi4o,
1856  Odhwi8o = dnnl_Odhwi8o,
1857  OIdhw16i16o = dnnl_OIdhw16i16o,
1858  OIdhw16i32o = dnnl_OIdhw16i32o,
1859  OIdhw16i64o = dnnl_OIdhw16i64o,
1860  OIdhw16o16i = dnnl_OIdhw16o16i,
1861  Oidhw16o = dnnl_Oidhw16o,
1862  OIdhw4i4o = dnnl_OIdhw4i4o,
1863  OIdhw4o4i = dnnl_OIdhw4o4i,
1864  Oidhw4o = dnnl_Oidhw4o,
1865  OIdhw8i16o2i = dnnl_OIdhw8i16o2i,
1866  OIdhw8i32o2i = dnnl_OIdhw8i32o2i,
1867  OIdhw8i64o2i = dnnl_OIdhw8i64o2i,
1868  OIdhw4i16o4i = dnnl_OIdhw4i16o4i,
1869  OIdhw16i16o4i = dnnl_OIdhw16i16o4i,
1870  OIdhw16i32o4i = dnnl_OIdhw16i32o4i,
1871  OIdhw16i48o4i = dnnl_OIdhw16i48o4i,
1872  OIdhw16i64o4i = dnnl_OIdhw16i64o4i,
1873  OIdhw16i16o2i = dnnl_OIdhw16i16o2i,
1874  OIdhw16i32o2i = dnnl_OIdhw16i32o2i,
1875  OIdhw16i48o2i = dnnl_OIdhw16i48o2i,
1876  OIdhw16i64o2i = dnnl_OIdhw16i64o2i,
1877  OIdhw4i32o4i = dnnl_OIdhw4i32o4i,
1878  OIdhw4i64o4i = dnnl_OIdhw4i64o4i,
1879  OIdhw2i8o4i = dnnl_OIdhw2i8o4i,
1880  OIdhw8i8o = dnnl_OIdhw8i8o,
1881  OIdhw8o8i = dnnl_OIdhw8o8i,
1882  OIdhw8o4i = dnnl_OIdhw8o4i,
1883  gIOw16o16i = dnnl_gIOw16o16i,
1884  gOIw16i16o = dnnl_gOIw16i16o,
1885  gOIw16o16i = dnnl_gOIw16o16i,
1886  gOiw16o = dnnl_gOiw16o,
1887  gOIw4i16o4i = dnnl_gOIw4i16o4i,
1888  gOIw2i8o4i = dnnl_gOIw2i8o4i,
1889  gOIw4i4o = dnnl_gOIw4i4o,
1890  gOIw4o4i = dnnl_gOIw4o4i,
1891  gOiw4o = dnnl_gOiw4o,
1892  gOIw8i16o2i = dnnl_gOIw8i16o2i,
1893  gOIw8i8o = dnnl_gOIw8i8o,
1894  gOIw8o16i2o = dnnl_gOIw8o16i2o,
1895  gOIw8o8i = dnnl_gOIw8o8i,
1896  gOIw8o4i = dnnl_gOIw8o4i,
1897  gOIw16i16o4i = dnnl_gOIw16i16o4i,
1898  gOIw16i16o2i = dnnl_gOIw16i16o2i,
1899  gOIw16o16i2o = dnnl_gOIw16o16i2o,
1900  gOwi16o = dnnl_gOwi16o,
1901  gOwI16o2i = dnnl_gOwI16o2i,
1902  gOwi4o = dnnl_gOwi4o,
1903  gOwi8o = dnnl_gOwi8o,
1904  Goiw8g = dnnl_Goiw8g,
1905  Goiw16g = dnnl_Goiw16g,
1906  gIOhw16o16i = dnnl_gIOhw16o16i,
1907  gOhwi16o = dnnl_gOhwi16o,
1908  gOhwI16o2i = dnnl_gOhwI16o2i,
1909  gOhwi4o = dnnl_gOhwi4o,
1910  gOhwi8o = dnnl_gOhwi8o,
1911  Goihw16g = dnnl_Goihw16g,
1912  gOIhw16i16o = dnnl_gOIhw16i16o,
1913  gOIhw16o16i = dnnl_gOIhw16o16i,
1914  gOihw16o = dnnl_gOihw16o,
1915  gOIhw4i16o4i = dnnl_gOIhw4i16o4i,
1916  gOIhw2i8o4i = dnnl_gOIhw2i8o4i,
1917  gOIhw4i4o = dnnl_gOIhw4i4o,
1918  gOIhw4o4i = dnnl_gOIhw4o4i,
1919  gOihw4o = dnnl_gOihw4o,
1920  Goihw8g = dnnl_Goihw8g,
1921  gOIhw8i16o2i = dnnl_gOIhw8i16o2i,
1922  gOIhw8i8o = dnnl_gOIhw8i8o,
1923  gOIhw8o16i2o = dnnl_gOIhw8o16i2o,
1924  OIw4o8i8o4i = dnnl_OIw4o8i8o4i,
1925  OIdhw4o8i8o4i = dnnl_OIdhw4o8i8o4i,
1926  OIhw4o8i8o4i = dnnl_OIhw4o8i8o4i,
1927  OIhw2o8i8o2i = dnnl_OIhw2o8i8o2i,
1928  gOIw4o8i8o4i = dnnl_gOIw4o8i8o4i,
1929  gOIdhw4o8i8o4i = dnnl_gOIdhw4o8i8o4i,
1930  gOIhw4o8i8o4i = dnnl_gOIhw4o8i8o4i,
1931  gOIhw2o8i8o2i = dnnl_gOIhw2o8i8o2i,
1932  OIhw16i16o4i = dnnl_OIhw16i16o4i,
1933  OIhw16i32o4i = dnnl_OIhw16i32o4i,
1934  OIhw16i48o4i = dnnl_OIhw16i48o4i,
1935  OIhw16i64o4i = dnnl_OIhw16i64o4i,
1936  OIhw16i16o2i = dnnl_OIhw16i16o2i,
1937  OIhw16i32o2i = dnnl_OIhw16i32o2i,
1938  OIhw16i48o2i = dnnl_OIhw16i48o2i,
1939  OIhw16i64o2i = dnnl_OIhw16i64o2i,
1940  OIhw16o16i2o = dnnl_OIhw16o16i2o,
1941  gOIhw16i16o4i = dnnl_gOIhw16i16o4i,
1942  gOIhw16i16o2i = dnnl_gOIhw16i16o2i,
1943  gOIhw16o16i2o = dnnl_gOIhw16o16i2o,
1944  gOIhw8o8i = dnnl_gOIhw8o8i,
1945  gOIhw8o4i = dnnl_gOIhw8o4i,
1946  gIOdhw16i16o = dnnl_gIOdhw16i16o,
1947  gIOdhw16o16i = dnnl_gIOdhw16o16i,
1948  gOdhwi16o = dnnl_gOdhwi16o,
1949  gOdhwI16o2i = dnnl_gOdhwI16o2i,
1950  gOdhwi4o = dnnl_gOdhwi4o,
1951  gOdhwi8o = dnnl_gOdhwi8o,
1952  gOIdhw16i16o = dnnl_gOIdhw16i16o,
1953  gOIdhw16o16i = dnnl_gOIdhw16o16i,
1954  gOidhw16o = dnnl_gOidhw16o,
1955  gOIdhw4i4o = dnnl_gOIdhw4i4o,
1956  gOIdhw4o4i = dnnl_gOIdhw4o4i,
1957  gOidhw4o = dnnl_gOidhw4o,
1958  gOIdhw8i16o2i = dnnl_gOIdhw8i16o2i,
1959  gOIdhw4i16o4i = dnnl_gOIdhw4i16o4i,
1960  gOIdhw16i16o4i = dnnl_gOIdhw16i16o4i,
1961  gOIdhw16i16o2i = dnnl_gOIdhw16i16o2i,
1962  gOIdhw2i8o4i = dnnl_gOIdhw2i8o4i,
1963  gOIdhw8i8o = dnnl_gOIdhw8i8o,
1964  gOIdhw8o8i = dnnl_gOIdhw8o8i,
1965  gOIdhw8o4i = dnnl_gOIdhw8o4i,
1966  gOIw2i4o2i = dnnl_gOIw2i4o2i,
1967  gOIhw2i4o2i = dnnl_gOIhw2i4o2i,
1968  gOIdhw2i4o2i = dnnl_gOIdhw2i4o2i,
1969  gOIw2o4i2o = dnnl_gOIw2o4i2o,
1970  gOIhw2o4i2o = dnnl_gOIhw2o4i2o,
1971  gOIdhw2o4i2o = dnnl_gOIdhw2o4i2o,
1972  gOIw4i8o2i = dnnl_gOIw4i8o2i,
1973  gOIhw4i8o2i = dnnl_gOIhw4i8o2i,
1974  gOIdhw4i8o2i = dnnl_gOIdhw4i8o2i,
1975  gOIw4o8i2o = dnnl_gOIw4o8i2o,
1976  gOIhw4o8i2o = dnnl_gOIhw4o8i2o,
1977  gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o,
1978  ldOi32o = abDc32d,
1979  ldOI32o4i = abDC32d4c,
1980  ldgOi32o = abdEc32e,
1981  ldgOI32o2i = abdEC32e2c,
1982  ldgOI32o4i = abdEC32e4c,
1983  OwI16o4i = dnnl_OwI16o4i,
1984  OhwI16o4i = dnnl_OhwI16o4i,
1985  gOwI16o4i = dnnl_gOwI16o4i,
1986  gOhwI16o4i = dnnl_gOhwI16o4i,
1987  OdhwI16o4i = dnnl_OdhwI16o4i,
1988  gOdhwI16o4i = dnnl_gOdhwI16o4i,
1989 
1990  Owi32o = dnnl_Owi32o,
1991  OwI32o2i = dnnl_OwI32o2i,
1992  OwI32o4i = dnnl_OwI32o4i,
1993  Owi48o = dnnl_Owi48o,
1994  OwI48o2i = dnnl_OwI48o2i,
1995  OwI48o4i = dnnl_OwI48o4i,
1996  Owi64o = dnnl_Owi64o,
1997  OwI64o2i = dnnl_OwI64o2i,
1998  OwI64o4i = dnnl_OwI64o4i,
1999  wIo2i = dnnl_wIo2i,
2000  wIo4i = dnnl_wIo4i,
2001  gOwi32o = dnnl_gOwi32o,
2002  gOwI32o2i = dnnl_gOwI32o2i,
2003  gOwI32o4i = dnnl_gOwI32o4i,
2004  gOwi48o = dnnl_gOwi48o,
2005  gOwI48o2i = dnnl_gOwI48o2i,
2006  gOwI48o4i = dnnl_gOwI48o4i,
2007  gOwi64o = dnnl_gOwi64o,
2008  gOwI64o2i = dnnl_gOwI64o2i,
2009  gOwI64o4i = dnnl_gOwI64o4i,
2010  gwio = dnnl_gwio,
2011  gwIo2i = dnnl_gwIo2i,
2012  gwIo4i = dnnl_gwIo4i,
2013  OhwI32o = dnnl_OhwI32o,
2014  OhwI32o2i = dnnl_OhwI32o2i,
2015  OhwI32o4i = dnnl_OhwI32o4i,
2016  Ohwi48o = dnnl_Ohwi48o,
2017  OhwI48o2i = dnnl_OhwI48o2i,
2018  OhwI48o4i = dnnl_OhwI48o4i,
2019  Ohwi64o = dnnl_Ohwi64o,
2020  OhwI64o2i = dnnl_OhwI64o2i,
2021  OhwI64o4i = dnnl_OhwI64o4i,
2022  hwIo2i = dnnl_hwIo2i,
2023  hwIo4i = dnnl_hwIo4i,
2024  gOhwI32o = dnnl_gOhwI32o,
2025  gOhwI32o2i = dnnl_gOhwI32o2i,
2026  gOhwI32o4i = dnnl_gOhwI32o4i,
2027  gOhwi48o = dnnl_gOhwi48o,
2028  gOhwI48o2i = dnnl_gOhwI48o2i,
2029  gOhwI48o4i = dnnl_gOhwI48o4i,
2030  gOhwi64o = dnnl_gOhwi64o,
2031  gOhwI64o2i = dnnl_gOhwI64o2i,
2032  gOhwI64o4i = dnnl_gOhwI64o4i,
2033  ghwio = dnnl_ghwio,
2034  ghwIo2i = dnnl_ghwIo2i,
2035  ghwIo4i = dnnl_ghwIo4i,
2036  Odhwi32o = dnnl_Odhwi32o,
2037  OdhwI32o2i = dnnl_OdhwI32o2i,
2038  OdhwI32o4i = dnnl_OdhwI32o4i,
2039  Odhwi48o = dnnl_Odhwi48o,
2040  OdhwI48o2i = dnnl_OdhwI48o2i,
2041  OdhwI48o4i = dnnl_OdhwI48o4i,
2042  Odhwi64o = dnnl_Odhwi64o,
2043  OdhwI64o2i = dnnl_OdhwI64o2i,
2044  OdhwI64o4i = dnnl_OdhwI64o4i,
2045  dhwIo2i = dnnl_dhwIo2i,
2046  dhwIo4i = dnnl_dhwIo4i,
2047  gOdhwi32o = dnnl_gOdhwi32o,
2048  gOdhwI32o2i = dnnl_gOdhwI32o2i,
2049  gOdhwI32o4i = dnnl_gOdhwI32o4i,
2050  gOdhwi48o = dnnl_gOdhwi48o,
2051  gOdhwI48o2i = dnnl_gOdhwI48o2i,
2052  gOdhwI48o4i = dnnl_gOdhwI48o4i,
2053  gOdhwi64o = dnnl_gOdhwi64o,
2054  gOdhwI64o2i = dnnl_gOdhwI64o2i,
2055  gOdhwI64o4i = dnnl_gOdhwI64o4i,
2056  gdhwio = dnnl_gdhwio,
2057  gdhwIo2i = dnnl_gdhwIo2i,
2058  gdhwIo4i = dnnl_gdhwIo4i,
2059  };
2060 
/// Memory descriptor: C++ wrapper around a dnnl_memory_desc_t kept in the
/// `data` member (see the converting constructor below).
// NOTE(review): this listing has lost several source lines (the embedded line
// numbers skip, e.g. 2063 -> 2066 and 2089 -> 2091). Statements below that
// begin with a bare argument list or a bare string literal are missing their
// first line; restore them from the upstream header before compiling.
2062  struct desc {
2063  friend struct memory;
2066 
      /// Constructs a descriptor with a value-initialized `data`.
2069  desc() : data() {}
2070 
      /// Constructs a descriptor from dimensions, a data type and a format tag.
      /// @param adims Dimensions; checked with validate_dims().
      /// @param adata_type Data type, converted with convert_to_c().
      /// @param aformat_tag Format tag, converted with convert_to_c().
      /// @param allow_empty When false, the `if` below reports the error text
      ///     that follows it (the reporting call head is missing here).
2086  desc(const dims &adims, data_type adata_type, format_tag aformat_tag,
2087  bool allow_empty = false)
2088  : data() {
2089  validate_dims(adims);
      // NOTE(review): head of the C API call (listing line 2090) is missing.
2091  (int)adims.size(), adims.data(), convert_to_c(adata_type),
2092  convert_to_c(aformat_tag));
2093  if (!allow_empty)
2095  "could not construct a memory descriptor using a "
2096  "format tag");
2097  }
2098 
      /// Constructs a descriptor from dimensions, a data type and strides.
      /// @param adims Dimensions; checked with validate_dims().
      /// @param adata_type Data type, converted with convert_to_c().
      /// @param strides Strides; when non-empty they are validated against
      ///     adims.size(), when empty a null pointer is passed instead.
      /// @param allow_empty When false, the `if` below reports the error text
      ///     that follows it (the reporting call head is missing here).
2114  desc(const dims &adims, data_type adata_type, const dims &strides,
2115  bool allow_empty = false)
2116  : data() {
2117  validate_dims(adims);
2118  if (!strides.empty()) validate_dims(strides, (int)adims.size());
      // NOTE(review): head of the C API call (listing line 2119) is missing.
2120  (int)adims.size(), adims.data(), convert_to_c(adata_type),
2121  strides.empty() ? nullptr : &strides[0]);
2122  if (!allow_empty)
2124  "could not construct a memory descriptor using "
2125  "strides");
2126  }
2127 
      /// Constructs a descriptor by copying a C API memory descriptor.
2131  desc(const dnnl_memory_desc_t &data) : data(data) {}
2132 
2135  //
      /// Builds a descriptor for a sub-region of this descriptor.
      /// @param adims Region sizes; validated against data.ndims.
      /// @param offsets Region offsets; validated against data.ndims.
      /// @param allow_empty When false, a failing status is passed to
      ///     error::wrap_c_api().
      /// @returns desc(sub_md).
2144  desc submemory_desc(const dims &adims, const dims &offsets,
2145  bool allow_empty = false) const {
2146  validate_dims(adims, data.ndims);
2147  validate_dims(offsets, data.ndims);
      // NOTE(review): the declarations of `sub_md` / `status` and the call
      // head (listing lines 2148-2149) are missing.
2150  &sub_md, &data, adims.data(), offsets.data());
2151  if (!allow_empty)
2152  error::wrap_c_api(status, "could not construct a sub-memory");
2153  return desc(sub_md);
2154  }
2155 
      /// Builds a descriptor with new dimensions @p adims from this one.
      /// @p adims is validated only when this descriptor has a non-zero ndims.
      /// @returns desc(out_md).
2200  desc reshape(const dims &adims, bool allow_empty = false) const {
2201  if (data.ndims) validate_dims(adims, 1);
      // NOTE(review): declarations and call head (lines 2202-2203) missing.
2204  &out_md, &data, (int)adims.size(), adims.data());
2205  if (!allow_empty)
2207  status, "could not reshape a memory descriptor");
2208  return desc(out_md);
2209  }
2210 
      /// Builds a descriptor with axes permuted according to @p permutation,
      /// which is validated against data.ndims.
      /// @returns desc(out_md).
2248  desc permute_axes(const std::vector<int> &permutation,
2249  bool allow_empty = false) const {
2250  validate_dims(permutation, data.ndims);
      // NOTE(review): declarations and call head (lines 2251-2252) missing.
2253  &out_md, &data, permutation.data());
2254  if (!allow_empty)
2256  "could not permute axes of a memory descriptor");
2257  return desc(out_md);
2258  }
2259 
      // NOTE(review): the signature of this accessor (listing line 2262) is
      // missing; the body returns data.data_type cast to memory::data_type.
2263  return static_cast<memory::data_type>(data.data_type);
2264  }
2265 
      /// Returns a copy of the first data.ndims entries of data.dims.
2270  memory::dims dims() const {
2271  return memory::dims(data.dims, data.dims + data.ndims);
2272  }
2273 
      /// Returns the value of dnnl_memory_desc_get_size() for this descriptor.
2278  size_t get_size() const { return dnnl_memory_desc_get_size(&data); }
2279 
      /// Returns true when data.ndims is zero.
2283  bool is_zero() const { return data.ndims == 0; }
2284 
      /// Returns true when dnnl_memory_desc_equal() reports a non-zero result
      /// for this descriptor and @p other.
2289  bool operator==(const desc &other) const {
2290  return dnnl_memory_desc_equal(&data, &other.data) != 0;
2291  }
2292 
      /// Negation of operator==.
2297  bool operator!=(const desc &other) const { return !operator==(other); }
2298 
      /// True when data.ndims is non-zero (the opposite of is_zero()).
2302  explicit operator bool() const { return data.ndims != 0; }
2303  };
2304 
2309  memory() = default;
2310 
/// Constructs a memory object from a descriptor, an engine and a handle:
/// dnnl_memory_create() fills `result`, which is then adopted via reset().
// NOTE(review): the head of the status-checking call (listing line 2332) is
// missing; the error text suggests it is error::wrap_c_api( - confirm.
2330  memory(const desc &md, const engine &aengine, void *handle) {
2331  dnnl_memory_t result;
2333  dnnl_memory_create(&result, &md.data, aengine.get(), handle),
2334  "could not create a memory object");
2335  reset(result);
2336  }
2337 
/// Constructs a memory object by delegating to the three-argument
/// constructor with DNNL_MEMORY_ALLOCATE as the handle.
2344  memory(const desc &md, const engine &aengine)
2345  : memory(md, aengine, DNNL_MEMORY_ALLOCATE) {}
2346 
/// Returns a desc built from a copy of the C descriptor pointed to by `cdesc`.
// NOTE(review): the call that fills `cdesc` (listing line 2350) is missing
// from this listing.
2348  desc get_desc() const {
2349  const dnnl_memory_desc_t *cdesc;
2351  "could not get a memory descriptor from a memory object");
2352  return desc(*cdesc);
2353  }
2354 
2356  engine get_engine() const {
2357  dnnl_engine_t c_engine;
2358  error::wrap_c_api(dnnl_memory_get_engine(get(), &c_engine),
2359  "could not get an engine from a memory object");
2360  return engine(c_engine, true);
2361  }
2362 
/// Returns the native handle stored in the local `handle`.
// NOTE(review): the call that fills `handle` (listing line 2369) is missing
// from this listing.
2367  void *get_data_handle() const {
2368  void *handle;
2370  "could not get a native handle from a memory object");
2371  return handle;
2372  }
2373 
/// Sets the native handle of this memory object, passing
/// astream.get(true) as the stream argument.
// NOTE(review): the head of the call (listing line 2403) is missing; the
// sibling overload uses dnnl_memory_set_data_handle_v2 - confirm it is the
// same here.
2402  void set_data_handle(void *handle, const stream &astream) const {
2404  get(), handle, astream.get(true)),
2405  "could not set native handle of a memory object");
2406  }
2407 
/// Sets the native handle of this memory object through
/// dnnl_memory_set_data_handle_v2() with a null stream argument.
// NOTE(review): the head of the status-checking call (listing line 2419) is
// missing from this listing.
2418  void set_data_handle(void *handle) const {
2420  dnnl_memory_set_data_handle_v2(get(), handle, nullptr),
2421  "could not set native handle of a memory object");
2422  }
2423 
2445  template <typename T = void>
2446  T *map_data() const {
2447  void *mapped_ptr;
2448  error::wrap_c_api(dnnl_memory_map_data(get(), &mapped_ptr),
2449  "could not map memory object data");
2450  return static_cast<T *>(mapped_ptr);
2451  }
2452 
2463  void unmap_data(void *mapped_ptr) const {
2464  error::wrap_c_api(dnnl_memory_unmap_data(get(), mapped_ptr),
2465  "could not unmap memory object data");
2466  }
2467 
2468  static dnnl_data_type_t convert_to_c(data_type adata_type) {
2469  return static_cast<dnnl_data_type_t>(adata_type);
2470  }
2471  static dnnl_format_tag_t convert_to_c(format_tag format) {
2472  return static_cast<dnnl_format_tag_t>(format);
2473  }
2474 };
2475 
2476 inline bool operator==(dnnl_data_type_t a, memory::data_type b) {
2477  return a == memory::convert_to_c(b);
2478 }
2479 inline bool operator!=(dnnl_data_type_t a, memory::data_type b) {
2480  return !(a == b);
2481 }
2482 inline bool operator==(memory::data_type a, dnnl_data_type_t b) {
2483  return b == a;
2484 }
2485 inline bool operator!=(memory::data_type a, dnnl_data_type_t b) {
2486  return !(a == b);
2487 }
2488 
2489 inline bool operator==(dnnl_format_tag_t a, memory::format_tag b) {
2490  return a == memory::convert_to_c(b);
2491 }
2492 inline bool operator!=(dnnl_format_tag_t a, memory::format_tag b) {
2493  return !(a == b);
2494 }
2495 inline bool operator==(memory::format_tag a, dnnl_format_tag_t b) {
2496  return b == a;
2497 }
2498 inline bool operator!=(memory::format_tag a, dnnl_format_tag_t b) {
2499  return !(a == b);
2500 }
2501 
2503 
2511 
2513 template <>
2514 struct handle_traits<dnnl_post_ops_t> {
2515  static dnnl_status_t destructor(dnnl_post_ops_t p) {
2516  return dnnl_post_ops_destroy(p);
2517  }
2518 };
2520 
/// C++ wrapper over a dnnl_post_ops_t handle: thin member functions that
/// forward to the dnnl_post_ops_* C API.
// NOTE(review): this listing has lost several source lines (the embedded line
// numbers skip). Statements below that begin with a bare argument list or a
// bare string literal, and bodies without a signature, are missing their
// first line; restore them from the upstream header before compiling.
2528 struct post_ops : public handle<dnnl_post_ops_t> {
2530 
     // NOTE(review): the signature of this function (listing line 2532) is
     // missing. The body creates a post-ops object with
     // dnnl_post_ops_create() and adopts it with reset(), so this is
     // presumably the default constructor - confirm.
2533  dnnl_post_ops_t result;
2535  dnnl_post_ops_create(&result), "could not create post-ops");
2536  reset(result);
2537  }
2538 
      /// Returns the value of dnnl_post_ops_len() for this object.
2540  int len() const { return dnnl_post_ops_len(get()); }
2541 
      /// Returns the kind reported by dnnl_post_ops_get_kind() for @p index,
      /// cast to primitive::kind.
      // NOTE(review): the range-check call head (listing line 2546) is missing.
2545  primitive::kind kind(int index) const {
2547  "post-ops index is out of range");
2548  return static_cast<primitive::kind>(
2549  dnnl_post_ops_get_kind(get(), index));
2550  }
2551 
      /// Appends a sum post-op. Two paths are taken depending on whether
      /// data_type equals memory::data_type::undef; only the non-undef path
      /// passes a converted data type to the C API.
      // NOTE(review): the second parameter line (2581) and both call heads
      // (2583, 2586) are missing.
2580  void append_sum(float scale = 1.f,
2582  if (data_type == memory::data_type::undef)
2584  "could not append a sum post-op");
2585  else
2587  memory::convert_to_c(data_type)),
2588  "could not append a sum post-op");
2589  }
2590 
      /// Reads the parameters of the sum post-op at @p index into @p scale.
      // NOTE(review): the call head (listing line 2596) is missing.
2595  void get_params_sum(int index, float &scale) const {
2597  "could not get parameters of a sum post-op");
2598  }
2599 
      // NOTE(review): the first signature line (2605) is missing; judging by
      // the error text this is the get_params_sum overload that also returns
      // the data type - confirm. The C data type is cast to
      // memory::data_type on output.
2606  int index, float &scale, memory::data_type &data_type) const {
2607  dnnl_data_type_t c_data_type;
2609  get(), index, &scale, &c_data_type),
2610  "could not get parameters of a sum post-op");
2611  data_type = static_cast<memory::data_type>(c_data_type);
2612  }
2613 
      // NOTE(review): the first signature line (2627) and the call head
      // (2629) are missing; the error text indicates this appends an
      // elementwise post-op with the given scale, algorithm, alpha and beta.
2628  float scale, algorithm aalgorithm, float alpha, float beta) {
2630  convert_to_c(aalgorithm), alpha, beta),
2631  "could not append an elementwise post-op");
2632  }
2633 
      /// Reads the parameters of the elementwise post-op at @p index; the C
      /// algorithm kind is cast to dnnl::algorithm on output.
      // NOTE(review): the call head (listing line 2644) is missing.
2641  void get_params_eltwise(int index, float &scale, algorithm &aalgorithm,
2642  float &alpha, float &beta) const {
2643  dnnl_alg_kind_t c_alg;
2645  get(), index, &scale, &c_alg, &alpha, &beta),
2646  "could not get parameters of an elementwise post-op");
2647  aalgorithm = static_cast<dnnl::algorithm>(c_alg);
2648  }
2649 
      /// Appends a depthwise post-op (k3s1p1 variant) with the given data
      /// types, mask and scales.
      // NOTE(review): the call head (listing line 2682) is missing.
      // NOTE(review): `&scales[0]` is undefined behavior when `scales` is
      // empty; `scales.data()` (as used by primitive_attr::set_output_scales)
      // would be safe - consider changing.
2678  void append_dw_k3s1p1(memory::data_type weights_data_type,
2679  memory::data_type bias_data_type, memory::data_type dst_data_type,
2680  int mask, const std::vector<float> &scales) {
2681 
2683  memory::convert_to_c(weights_data_type),
2684  memory::convert_to_c(bias_data_type),
2685  memory::convert_to_c(dst_data_type),
2686  scales.size(), mask, &scales[0]),
2687  "could not append depthwise post-op");
2688  }
2689 
      /// Reads the parameters of the depthwise post-op (k3s1p1 variant) at
      /// @p index. Data types are cast to memory::data_type, @p scales is
      /// resized to `count` and filled element by element from `c_scales`.
      // NOTE(review): the call head (listing line 2714) is missing.
2704  void get_params_dw_k3s1p1(int index, memory::data_type &weights_data_type,
2705  memory::data_type &bias_data_type, memory::data_type &dst_data_type,
2706  int &mask, std::vector<float> &scales) const {
2707 
2708  dnnl_data_type_t c_weights_data_type;
2709  dnnl_data_type_t c_bias_data_type;
2710  dnnl_data_type_t c_dst_data_type;
2711  dnnl_dim_t count;
2712  int c_mask;
2713  const float *c_scales;
2715  &c_weights_data_type, &c_bias_data_type,
2716  &c_dst_data_type, &count, &c_mask, &c_scales),
2717  "could not get parameters of depthwise post-op");
2718 
2719  weights_data_type = static_cast<memory::data_type>(c_weights_data_type);
2720  bias_data_type = static_cast<memory::data_type>(c_bias_data_type);
2721  dst_data_type = static_cast<memory::data_type>(c_dst_data_type);
2722  scales.resize(count);
2723 
2724  mask = c_mask;
2725  for (dnnl_dim_t c = 0; c < count; ++c)
2726  scales[c] = c_scales[c];
2727  return;
2728  }
2729 
      /// Appends a depthwise post-op (k3s2p1 variant) with the given data
      /// types, mask and scales.
      // NOTE(review): the call head (listing line 2767) is missing.
      // NOTE(review): same `&scales[0]` concern as in append_dw_k3s1p1.
2763  void append_dw_k3s2p1(memory::data_type weights_data_type,
2764  memory::data_type bias_data_type, memory::data_type dst_data_type,
2765  int mask, const std::vector<float> &scales) {
2766 
2768  memory::convert_to_c(weights_data_type),
2769  memory::convert_to_c(bias_data_type),
2770  memory::convert_to_c(dst_data_type),
2771  scales.size(), mask, &scales[0]),
2772  "could not append depthwise post-op");
2773  }
2774 
      /// Reads the parameters of the depthwise post-op (k3s2p1 variant) at
      /// @p index. Data types are cast to memory::data_type, @p scales is
      /// resized to `count` and filled element by element from `c_scales`.
      // NOTE(review): the call head (listing line 2799) is missing.
2789  void get_params_dw_k3s2p1(int index, memory::data_type &weights_data_type,
2790  memory::data_type &bias_data_type, memory::data_type &dst_data_type,
2791  int &mask, std::vector<float> &scales) const {
2792 
2793  dnnl_data_type_t c_weights_data_type;
2794  dnnl_data_type_t c_bias_data_type;
2795  dnnl_data_type_t c_dst_data_type;
2796  dnnl_dim_t count;
2797  int c_mask;
2798  const float *c_scales;
2800  &c_weights_data_type, &c_bias_data_type,
2801  &c_dst_data_type, &count, &c_mask, &c_scales),
2802  "could not get parameters of depthwise post-op");
2803 
2804  weights_data_type = static_cast<memory::data_type>(c_weights_data_type);
2805  bias_data_type = static_cast<memory::data_type>(c_bias_data_type);
2806  dst_data_type = static_cast<memory::data_type>(c_dst_data_type);
2807  scales.resize(count);
2808 
2809  mask = c_mask;
2810  for (dnnl_dim_t c = 0; c < count; ++c)
2811  scales[c] = c_scales[c];
2812  return;
2813  }
2814 
      /// Appends a binary post-op with algorithm @p aalgorithm and the
      /// second-source descriptor @p src1_desc.
      // NOTE(review): the call head (listing line 2830) is missing.
2829  void append_binary(algorithm aalgorithm, const memory::desc &src1_desc) {
2831  convert_to_c(aalgorithm), &src1_desc.data),
2832  "could not append a binary post-op");
2833  }
2834 
      // NOTE(review): the first signature line (2840) and the head of the
      // status-checking call (2844) are missing. The body queries
      // dnnl_post_ops_get_params_binary(), casts the algorithm to
      // dnnl::algorithm and copies the returned C descriptor into
      // src1_desc.data.
2841  int index, algorithm &aalgorithm, memory::desc &src1_desc) const {
2842  dnnl_alg_kind_t c_alg;
2843  const dnnl_memory_desc_t *data;
2845  dnnl_post_ops_get_params_binary(get(), index, &c_alg, &data),
2846  "could not get parameters of a binary post-op");
2847  aalgorithm = static_cast<dnnl::algorithm>(c_alg);
2848  src1_desc.data = *data;
2849  }
2850 };
2851 
2853 template <>
2854 struct handle_traits<dnnl_primitive_attr_t> {
2855  static dnnl_status_t destructor(dnnl_primitive_attr_t p) {
2856  return dnnl_primitive_attr_destroy(p);
2857  }
2858 };
2860 
/// C++ wrapper over a dnnl_primitive_attr_t handle: getters and setters for
/// scratchpad mode, scales, zero points, post-ops and RNN quantization
/// parameters.
// NOTE(review): this listing has lost several source lines (the embedded line
// numbers skip). Statements below that begin with a bare argument list or a
// bare string literal, and bodies without a signature, are missing their
// first line; restore them from the upstream header before compiling.
2864 struct primitive_attr : public handle<dnnl_primitive_attr_t> {
2866 
     // NOTE(review): signature (2868) and call head (2870) are missing; the
     // body adopts a newly created attribute via reset(), so this is
     // presumably the default constructor - confirm.
2869  dnnl_primitive_attr_t result;
2871  "could not create primitive attribute");
2872  reset(result);
2873  }
2874 
     // NOTE(review): the signature of this constructor (2880) is missing; it
     // forwards `attr` to the handle<dnnl_primitive_attr_t> base.
2881  : handle<dnnl_primitive_attr_t>(attr) {}
2882 
     // NOTE(review): signature (2884) and call head (2886-2887) are missing;
     // the body returns the queried C value converted to scratchpad_mode.
2885  dnnl_scratchpad_mode_t result;
2888  "could not get scratchpad mode primitive attribute");
2889  return scratchpad_mode(result);
2890  }
2891 
     // NOTE(review): signature and call head (2895-2896) are missing; this
     // sets the scratchpad mode from `mode` converted with
     // dnnl::convert_to_c().
2897  get(), dnnl::convert_to_c(mode)),
2898  "could not set scratchpad mode primitive attribute");
2899  }
2900 
      /// Reads the output scales: @p mask receives the C mask, @p scales is
      /// resized to `count` and filled element by element from `c_scales`.
      // NOTE(review): the call head (listing line 2914) is missing.
2910  void get_output_scales(int &mask, std::vector<float> &scales) const {
2911  dnnl_dim_t count;
2912  int c_mask;
2913  const float *c_scales;
2915  get(), &count, &c_mask, &c_scales),
2916  "could not get output scales primitive attribute");
2917  scales.resize(count);
2918 
2919  mask = c_mask;
2920  for (dnnl_dim_t c = 0; c < count; ++c)
2921  scales[c] = c_scales[c];
2922  }
2923 
      /// Sets the output scales, passing scales.size() (as dnnl_dim_t), @p mask
      /// and scales.data() to the C API.
      // NOTE(review): the call head (listing lines 2967-2968) is missing.
2966  void set_output_scales(int mask, const std::vector<float> &scales) {
2969  get(), (dnnl_dim_t)scales.size(), mask, scales.data()),
2970  "could not set output scales primitive attribute");
2971  }
2972 
      /// Reads the scales for argument @p arg: @p mask receives the C mask,
      /// @p scales is resized to `count` and filled from `c_scales`.
      // NOTE(review): the call head (listing line 2988) is missing.
2984  void get_scales(int arg, int &mask, std::vector<float> &scales) const {
2985  dnnl_dim_t count;
2986  int c_mask;
2987  const float *c_scales;
2989  get(), arg, &count, &c_mask, &c_scales),
2990  "could not get scales primitive attributes");
2991  scales.resize(count);
2992 
2993  mask = c_mask;
2994  for (dnnl_dim_t c = 0; c < count; ++c)
2995  scales[c] = c_scales[c];
2996  }
2997 
      /// Sets the scales for argument @p arg, passing scales.size() (as
      /// dnnl_dim_t), @p mask and scales.data() to the C API.
      // NOTE(review): the call head (listing lines 3015-3016) is missing.
3014  void set_scales(int arg, int mask, const std::vector<float> &scales) {
3017  (dnnl_dim_t)scales.size(), mask, scales.data()),
3018  "could not set scales primitive attribute");
3019  }
3020 
      // NOTE(review): the first signature line (3031) and call head (3036)
      // are missing. The body reads zero points for `arg`: `mask` receives the
      // C mask and `zero_points` is resized to `count` and filled from
      // `c_zero_points`.
3032  int arg, int &mask, std::vector<int32_t> &zero_points) const {
3033  dnnl_dim_t count;
3034  int c_mask;
3035  const int32_t *c_zero_points;
3037  get(), arg, &count, &c_mask, &c_zero_points),
3038  "could not get zero points primitive attribute");
3039  zero_points.resize(count);
3040 
3041  mask = c_mask;
3042  for (dnnl_dim_t c = 0; c < count; ++c)
3043  zero_points[c] = c_zero_points[c];
3044  }
3045 
      // NOTE(review): the first signature line (3066) and call head (3068)
      // are missing. The body sets zero points for `arg`, passing
      // zero_points.size() (as dnnl_dim_t), `mask` and zero_points.data().
3067  int arg, int mask, const std::vector<int32_t> &zero_points) {
3069  (dnnl_dim_t)zero_points.size(), mask,
3070  zero_points.data()),
3071  "could not set zero points primitive attribute");
3072  }
3073 
      /// Returns a post_ops object that has been reset() to the post-ops
      /// handle queried from this attribute (const qualifier cast away, second
      /// reset() argument `true`).
      // NOTE(review): the query call head (listing line 3080) is missing.
3077  const post_ops get_post_ops() const {
3078  post_ops result;
3079  const_dnnl_post_ops_t c_result;
3081  "could not get post-ops primitive attribute");
3082  result.reset(const_cast<dnnl_post_ops_t>(c_result), true);
3083  return result;
3084  }
3085 
      /// Sets the post-ops of this attribute from @p ops (taken by value).
      // NOTE(review): the call head (listing line 3095) is missing.
3094  void set_post_ops(const post_ops ops) {
3096  "could not set post-ops primitive attribute");
3097  }
3098 
      /// Sets the RNN data quantization parameters @p scale and @p shift.
      // NOTE(review): the call head (listing lines 3133-3134) is missing.
3132  void set_rnn_data_qparams(float scale, float shift) {
3135  "could not set RNN data quantization parameters primitive "
3136  "attribute");
3137  }
3138 
      /// Reads the RNN data quantization parameters into @p scale and
      /// @p shift through the temporaries `c_scale` / `c_shift`.
      // NOTE(review): the call head (listing line 3150) is missing.
      // NOTE(review): the error text below says "could not set" although this
      // is a getter - looks like a copy-paste slip from
      // set_rnn_data_qparams; consider "could not get".
3148  void get_rnn_data_qparams(float &scale, float &shift) {
3149  float c_scale, c_shift;
3151  get(), &c_scale, &c_shift),
3152  "could not set RNN data quantization parameters primitive "
3153  "attribute");
3154  scale = c_scale;
3155  shift = c_shift;
3156  }
3157 
      /// Sets the RNN weights quantization parameters, passing scales.size()
      /// (as int), @p mask and scales.data() to the C API.
      // NOTE(review): the call head (listing line 3185) is missing.
      // NOTE(review): the size is cast to `int` here but to `dnnl_dim_t` in
      // set_output_scales / set_scales - confirm which the C API expects.
3184  void set_rnn_weights_qparams(int mask, const std::vector<float> &scales) {
3186  (int)scales.size(), mask, scales.data()),
3187  "could not set RNN weights quantization parameters primitive "
3188  "attribute");
3189  }
3190 
      /// Reads the RNN weights quantization parameters: @p mask receives the C
      /// mask, @p scales is resized to `count` and filled from `c_scales`.
      // NOTE(review): the call head (listing line 3214) is missing.
3210  void get_rnn_weights_qparams(int &mask, std::vector<float> &scales) {
3211  dnnl_dim_t count;
3212  int c_mask;
3213  const float *c_scales;
3215  get(), &count, &c_mask, &c_scales),
3216  "could not get primitive RNN weights quantization "
3217  "parameters attributes");
3218  scales.resize(count);
3219 
3220  mask = c_mask;
3221  for (dnnl_dim_t c = 0; c < count; c++)
3222  scales[c] = c_scales[c];
3223  }
3224 
     // NOTE(review): the comment below is truncated in this listing, and the
     // first signature line (3251) and call head (3253-3254) of the setter it
     // documents are missing. Per the error text the setter stores the RNN
     // weights projection quantization parameters.
3226  // The low-precision configuration of the RNN primitives expect input
3227  // weights to use the signed 8-bit integer data type. The scaling factors
3228  // are used to quantize floating-point data to signed integer and must be
3252  int mask, const std::vector<float> &scales) {
3255  get(), (int)scales.size(), mask, scales.data()),
3256  "could not set primitive RNN weights projection quantization "
3257  "parameters attributes");
3258  }
3259 
      // NOTE(review): the first signature line (3279) and call head
      // (3284-3285) are missing. The body reads the RNN weights projection
      // quantization parameters: `mask` receives the C mask and `scales` is
      // resized to `count` and filled from `c_scales`.
3280  int &mask, std::vector<float> &scales) {
3281  dnnl_dim_t count;
3282  int c_mask;
3283  const float *c_scales;
3286  get(), &count, &c_mask, &c_scales),
3287  "could not get primitive RNN weights projection quantization "
3288  "parameters attributes");
3289  scales.resize(count);
3290 
3291  mask = c_mask;
3292  for (dnnl_dim_t c = 0; c < count; c++)
3293  scales[c] = c_scales[c];
3294  }
3295 };
3296 
3298 
3301 
/// Base wrapper over a dnnl_primitive_desc_t handle: query helpers shared by
/// the concrete primitive descriptors, plus protected constructors that check
/// the primitive / propagation kind before cloning a C descriptor.
// NOTE(review): this listing has lost several source lines (the embedded line
// numbers skip). Statements below that begin with a bare argument list or a
// bare string literal, and bodies without a signature, are missing their
// first line; restore them from the upstream header before compiling.
3303 struct primitive_desc_base : public handle<dnnl_primitive_desc_t> {
3305 
      /// Defaulted default constructor.
3307  primitive_desc_base() = default;
3308 
      /// Returns the result of engine::query() for this descriptor.
3311  engine get_engine() const { return engine::query(*this); }
3312 
      /// Returns the string obtained with the dnnl_query_impl_info_str query.
      // NOTE(review): the call head (listing line 3317) is missing.
3315  const char *impl_info_str() const {
3316  const char *res;
3318  get(), dnnl_query_impl_info_str, 0, &res),
3319  "could not retrieve implementation info string from a "
3320  "primitive descriptor");
3321  return res;
3322  }
3323 
     // NOTE(review): signature (3327) and call head (3329) are missing. The
     // body runs the query `what` and returns the result, or 0 when the
     // status is not dnnl_success.
3328  memory::dim res;
3330  get(), dnnl::convert_to_c(what), 0, &res);
3331  return status == dnnl_success ? res : 0;
3332  }
3333 
      /// Returns the memory descriptor for query @p what at index @p idx.
      /// Throws dnnl_invalid_arguments (via DNNL_THROW_ERROR) when @p what is
      /// not in `valid_q`; returns an empty memory::desc() when the queried
      /// pointer is null.
      // NOTE(review): the remainder of the `valid_q` initializer (3350-3352)
      // and the head of the query call (3358) are missing.
3348  memory::desc query_md(query what, int idx = 0) const {
3349  std::vector<query> valid_q {query::src_md, query::diff_src_md,
3353  if (!std::any_of(valid_q.cbegin(), valid_q.cend(),
3354  [=](query q) { return what == q; }))
3355  DNNL_THROW_ERROR(dnnl_invalid_arguments,
3356  "memory descriptor query is invalid");
3357 
3359  get(), dnnl::convert_to_c(what), idx);
3360  return cdesc ? memory::desc(*cdesc) : memory::desc();
3361  }
3362 
      /// Returns query_md(query::src_md, idx).
3368  memory::desc src_desc(int idx) const {
3369  return query_md(query::src_md, idx);
3370  }
3371 
      /// Returns query_md(query::dst_md, idx).
3377  memory::desc dst_desc(int idx) const {
3378  return query_md(query::dst_md, idx);
3379  }
3380 
      /// Returns query_md(query::weights_md, idx).
3386  memory::desc weights_desc(int idx) const {
3387  return query_md(query::weights_md, idx);
3388  }
3389 
      /// Returns query_md(query::diff_src_md, idx).
3395  memory::desc diff_src_desc(int idx) const {
3396  return query_md(query::diff_src_md, idx);
3397  }
3398 
      /// Returns query_md(query::diff_dst_md, idx).
3404  memory::desc diff_dst_desc(int idx) const {
3405  return query_md(query::diff_dst_md, idx);
3406  }
3407 
      // NOTE(review): signature (3413) missing; the body returns
      // query_md(query::diff_weights_md, idx).
3414  return query_md(query::diff_weights_md, idx);
3415  }
3416 
3417  // Separate versions without the index argument for documentation
3418  // purposes.
3419 
      /// Returns src_desc(0).
3424  memory::desc src_desc() const { return src_desc(0); }
3425 
      /// Returns dst_desc(0).
3430  memory::desc dst_desc() const { return dst_desc(0); }
3431 
      /// Returns weights_desc(0).
3436  memory::desc weights_desc() const { return weights_desc(0); }
3437 
     // NOTE(review): three one-line accessors (listing lines 3442, 3448,
     // 3454) are missing here.
3443 
3449 
3455 
     // NOTE(review): signature (3460) missing; the body returns
     // query_md(query::workspace_md, 0).
3461  return query_md(query::workspace_md, 0);
3462  }
3463 
     // NOTE(review): signature (3469) missing; the body returns
     // query_md(query::scratchpad_md, 0).
3470  return query_md(query::scratchpad_md, 0);
3471  }
3472 
     // NOTE(review): signature (3475) and call head (3477-3478) missing; per
     // the error text the body retrieves the scratchpad engine and wraps it
     // as engine(c_engine, true).
3476  dnnl_engine_t c_engine;
3479  0, &c_engine),
3480  "could not retrieve scratchpad engine from a primitive "
3481  "descriptor");
3482  return engine(c_engine, true);
3483  }
3484 
     // NOTE(review): signature (3487) and first call head (3489) missing. The
     // body obtains the descriptor's attributes, clones them with
     // dnnl_primitive_attr_clone() and returns the clone as primitive_attr.
3488  const_dnnl_primitive_attr_t const_c_attr;
3490  "could not get attributes from a primitive descriptor");
3491  dnnl_primitive_attr_t c_attr;
3492  error::wrap_c_api(dnnl_primitive_attr_clone(&c_attr, const_c_attr),
3493  "could not clone primitive attributes");
3494  return primitive_attr(c_attr);
3495  }
3496 
     // NOTE(review): signature (3499) and call head (3501) missing; the body
     // runs the dnnl_query_primitive_kind query and returns the result cast
     // to dnnl::primitive::kind.
3500  dnnl_primitive_kind_t kind;
3502  dnnl_query_primitive_kind, 0, (void *)&kind),
3503  "could not get primitive kind from a primitive descriptor");
3504  return static_cast<dnnl::primitive::kind>(kind);
3505  }
3506 
3507 protected:
     // NOTE(review): signature (3511) and call head (3513) missing; per the
     // error text and the call in the checking constructor below, this is
     // reset_with_clone(pd): it clones `pd` and adopts the clone via reset().
3512  dnnl_primitive_desc_t new_pd;
3514  "could not clone a primitive descriptor");
3515  reset(new_pd);
3516  }
3517 
     // NOTE(review): signature lines (3531-3532) missing; this constructor
     // delegates with dnnl::prop_kind::undef as the propagation kind.
3533  : primitive_desc_base(pd, prim_kind, dnnl::prop_kind::undef) {}
3534 
     // NOTE(review): first signature line (3546) missing; this constructor
     // delegates, passing `aprop_kind` as both accepted propagation kinds.
3547  dnnl::primitive::kind prim_kind, dnnl::prop_kind aprop_kind)
3548  : primitive_desc_base(pd, prim_kind, aprop_kind, aprop_kind) {}
3549 
     // Checking constructor: a null `pd` is accepted and leaves the object
     // untouched. Otherwise the primitive kind must equal `prim_kind`, and the
     // propagation kind must equal `prop_kind1` or `prop_kind2` (or the query
     // must be unimplemented while `prop_kind1` is undef). On success `pd` is
     // cloned via reset_with_clone(); on mismatch dnnl_invalid_arguments is
     // raised through DNNL_THROW_ERROR.
     // NOTE(review): first signature line (3563) and the heads of both query
     // calls (3577, 3587) and of the status check (3579) are missing.
3564  dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1,
3565  dnnl::prop_kind prop_kind2) {
3566  // It is OK to pass an empty primitive descriptor
3567  if (pd == nullptr) return;
3568 
3569  dnnl_status_t rc;
3570 
3571  dnnl_primitive_kind_t c_prim_kind = convert_to_c(prim_kind);
3572  dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
3573  dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
3574 
3575  // Check that primitive kind matches
3576  dnnl_primitive_kind_t pd_kind;
3578  pd, dnnl_query_primitive_kind, 0, (void *)&pd_kind);
3580  rc, "could not get primitive kind from a primitive descriptor");
3581  if (pd_kind != c_prim_kind)
3582  DNNL_THROW_ERROR(dnnl_invalid_arguments,
3583  "primitive descriptor operation kind mismatch");
3584 
3585  // Check that propagation kind matches
3586  dnnl_prop_kind_t pd_prop_kind;
3588  pd, dnnl_query_prop_kind, 0, (void *)&pd_prop_kind);
3589 
3590  // Something went wrong
3591  if (rc != dnnl_success && rc != dnnl_unimplemented)
3592  DNNL_THROW_ERROR(dnnl_invalid_arguments,
3593  "could not get propagation kind from the primitive "
3594  "descriptor");
3595 
3596  // Everything is fine
3597  if ((rc == dnnl_unimplemented && c_prop_kind1 == dnnl_prop_kind_undef)
3598  || (rc == dnnl_success
3599  && (pd_prop_kind == c_prop_kind1
3600  || pd_prop_kind == c_prop_kind2))) {
3601  reset_with_clone(pd);
3602  return;
3603  }
3604 
3605  // We could get the propagation kind but there is a mismatch
3606  DNNL_THROW_ERROR(dnnl_invalid_arguments,
3607  "primitive descriptor propagation kind mismatch");
3608  }
3609 
      // Shorthand used by derived descriptors to reach the base accessors.
3610  using base = primitive_desc_base;
3611 };
3612 
3614 
3623 
/// Reorder primitive together with its nested primitive descriptor.
// NOTE(review): this listing has lost several source lines (the embedded line
// numbers skip), including the header of the nested primitive_desc struct
// (around listing line 3627) and the heads of several calls. Restore them from
// the upstream header before compiling.
3625 struct reorder : public primitive {
3629 
      /// Defaulted default constructor of the nested primitive descriptor.
3631  primitive_desc() = default;
3632 
      /// Constructs a reorder primitive descriptor from source / destination
      /// engines and memory descriptors plus optional attributes.
      /// @param allow_empty When false, the `if` below reports the error text
      ///     that follows it (the reporting call head is missing here).
      // NOTE(review): the create-call head (3655), the error-call head (3659)
      // and the statement at 3662 are missing.
3650  primitive_desc(const engine &src_engine, const memory::desc &src_md,
3651  const engine &dst_engine, const memory::desc &dst_md,
3652  const primitive_attr &attr = primitive_attr(),
3653  bool allow_empty = false) {
3654  dnnl_primitive_desc_t result;
3656  &src_md.data, src_engine.get(), &dst_md.data,
3657  dst_engine.get(), attr.get());
3658  if (!allow_empty)
3660  "could not create a primitive descriptor for a reorder "
3661  "primitive");
3663  }
3664 
      /// Constructs a reorder primitive descriptor from two memory objects;
      /// their descriptors and engines are taken from get_desc() and
      /// get_engine().
      // NOTE(review): the create-call head (3682), the error-call head (3686)
      // and the statement at 3689 are missing.
3676  primitive_desc(const memory &src, const memory &dst,
3677  const primitive_attr &attr = primitive_attr(),
3678  bool allow_empty = false) {
3679  dnnl_primitive_desc_t result;
3680  auto src_md = src.get_desc();
3681  auto dst_md = dst.get_desc();
3683  &src_md.data, src.get_engine().get(), &dst_md.data,
3684  dst.get_engine().get(), attr.get());
3685  if (!allow_empty)
3687  "could not create a primitive descriptor for a reorder "
3688  "primitive");
3690  }
3691 
     // NOTE(review): three member functions (listing lines 3692-3709) are
     // almost entirely missing; only their closing braces remain.
3698 
3703  }
3704 
3709  }
3710 
      /// Returns base::src_desc(0).
3712  memory::desc src_desc() const { return base::src_desc(0); }
3713 
      /// Returns base::dst_desc(0).
3715  memory::desc dst_desc() const { return base::dst_desc(0); }
3716  };
3717 
      /// Defaulted default constructor.
3719  reorder() = default;
3720 
      /// Constructs the primitive from the handle of @p pd.
3723  reorder(const primitive_desc &pd) : primitive(pd.get()) {}
3724 
      /// Constructs the primitive from a temporary primitive_desc built from
      /// @p src, @p dst and @p attr.
3732  reorder(const memory &src, const memory &dst,
3733  const primitive_attr &attr = primitive_attr())
3734  : primitive(primitive_desc(src, dst, attr).get()) {}
3735 
3736  using primitive::execute;
3737 
      /// Executes the reorder on @p astream with @p src bound to DNNL_ARG_FROM
      /// and @p dst bound to DNNL_ARG_TO.
3744  void execute(const stream &astream, memory &src, memory &dst) const {
3745  primitive::execute(astream, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}});
3746  }
3747 };
3748 
3750 
3758 
3760 inline std::vector<dnnl_memory_desc_t> convert_to_c(
3761  const std::vector<memory::desc> &mems) {
3762  std::vector<dnnl_memory_desc_t> c_mems;
3763  c_mems.reserve(mems.size());
3764  for (const auto &s : mems)
3765  c_mems.push_back(s.data);
3766  return c_mems;
3767 }
3769 
/// Concat primitive together with its nested primitive descriptor.
// NOTE(review): this listing has lost several source lines (the embedded line
// numbers skip), including the header of the nested primitive_desc struct
// (around listing line 3773) and the heads of several calls. Restore them from
// the upstream header before compiling.
3771 struct concat : public primitive {
3775 
      /// Defaulted default constructor of the nested primitive descriptor.
3777  primitive_desc() = default;
3778 
      /// Constructs a concat primitive descriptor with an explicit destination
      /// descriptor. The sources are converted with convert_to_c() and the
      /// created handle is adopted via reset().
      // NOTE(review): the call head (listing lines 3795-3796) is missing.
3789  primitive_desc(const memory::desc &dst, int concat_dimension,
3790  const std::vector<memory::desc> &srcs, const engine &aengine,
3791  const primitive_attr &attr = primitive_attr()) {
3792  auto c_srcs = convert_to_c(srcs);
3793 
3794  dnnl_primitive_desc_t result;
3797  (int)c_srcs.size(), concat_dimension, c_srcs.data(),
3798  attr.get(), aengine.get()),
3799  "could not create a primitive descriptor for a concat "
3800  "primitive");
3801  reset(result);
3802  }
3803 
      /// Constructs a concat primitive descriptor without a destination
      /// descriptor: nullptr is passed in its place to
      /// dnnl_concat_primitive_desc_create().
      // NOTE(review): the status-check call head (listing line 3822) is
      // missing.
3816  primitive_desc(int concat_dimension,
3817  const std::vector<memory::desc> &srcs, const engine &aengine,
3818  const primitive_attr &attr = primitive_attr()) {
3819  auto c_api_srcs = convert_to_c(srcs);
3820 
3821  dnnl_primitive_desc_t result;
3823  dnnl_concat_primitive_desc_create(&result, nullptr,
3824  (int)c_api_srcs.size(), concat_dimension,
3825  c_api_srcs.data(), attr.get(), aengine.get()),
3826  "could not create a primitive descriptor for a concat "
3827  "primitive");
3828  reset(result);
3829  }
3830 
     // NOTE(review): one member (listing lines 3831-3836) is missing here.
3837 
      /// Returns base::src_desc(idx).
3839  memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
3840 
      /// Returns base::dst_desc(0).
3842  memory::desc dst_desc() const { return base::dst_desc(0); }
3843  };
3844 
      /// Defaulted default constructor.
3846  concat() = default;
3847 
      /// Constructs the primitive from the handle of @p pd.
3850  concat(const primitive_desc &pd) : primitive(pd.get()) {}
3851 };
3852 
3854 
3862 
/// Sum primitive together with its nested primitive descriptor.
// NOTE(review): this listing has lost several source lines (the embedded line
// numbers skip), including the header of the nested primitive_desc struct
// (around listing line 3866) and the heads of several calls. Restore them from
// the upstream header before compiling.
3864 struct sum : public primitive {
3868 
      /// Defaulted default constructor of the nested primitive descriptor.
3870  primitive_desc() = default;
3871 
      // NOTE(review): the first signature line (3880) is missing; it
      // presumably declares the `dst` descriptor used below - confirm.
      // The body checks that `scales` has exactly srcs.size() entries,
      // converts the sources with convert_to_c(), creates the descriptor with
      // dnnl_sum_primitive_desc_create() and adopts it via reset().
      // NOTE(review): the status-check call head (3891) is missing.
3881  const std::vector<float> &scales,
3882  const std::vector<memory::desc> &srcs, const engine &aengine,
3883  const primitive_attr &attr = primitive_attr()) {
3884  validate_container_size(scales,
3885  "counts of scales and sources are not equal",
3886  (int)srcs.size(), (int)srcs.size());
3887 
3888  auto c_api_srcs = convert_to_c(srcs);
3889 
3890  dnnl_primitive_desc_t result;
3892  dnnl_sum_primitive_desc_create(&result, &dst.data,
3893  (int)c_api_srcs.size(), scales.data(),
3894  c_api_srcs.data(), attr.get(), aengine.get()),
3895  "could not create a primitive descriptor for a sum "
3896  "primitive");
3897  reset(result);
3898  }
3899 
      /// Constructs a sum primitive descriptor without a destination
      /// descriptor: nullptr is passed in its place to
      /// dnnl_sum_primitive_desc_create(). @p scales must have exactly
      /// srcs.size() entries.
      // NOTE(review): the status-check call head (3919) is missing.
3910  primitive_desc(const std::vector<float> &scales,
3911  const std::vector<memory::desc> &srcs, const engine &aengine,
3912  const primitive_attr &attr = primitive_attr()) {
3913  validate_container_size(scales,
3914  "counts of scales and sources are not equal",
3915  (int)srcs.size(), (int)srcs.size());
3916 
3917  auto c_api_srcs = convert_to_c(srcs);
3918  dnnl_primitive_desc_t result;
3920  dnnl_sum_primitive_desc_create(&result, nullptr,
3921  (int)c_api_srcs.size(), scales.data(),
3922  c_api_srcs.data(), attr.get(), aengine.get()),
3923  "could not create a primitive descriptor for a sum "
3924  "primitive");
3925  reset(result);
3926  }
3927 
     // NOTE(review): one member (listing lines 3928-3933) is missing here.
3934 
      /// Returns base::src_desc(idx).
3936  memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
3937 
      /// Returns base::dst_desc(0).
3939  memory::desc dst_desc() const { return base::dst_desc(0); }
3940  };
3941 
      /// Defaulted default constructor.
3943  sum() = default;
3944 
      /// Constructs the primitive from the handle of @p pd.
3947  sum(const primitive_desc &pd) : primitive(pd.get()) {}
3948 };
3949 
3951 
3954 
3959 
3960  primitive_desc() = default;
3961 
3985  const engine &aengine, const_dnnl_primitive_desc_t hint_fwd_pd,
3986  bool allow_empty = false)
3987  : allow_empty_(allow_empty) {
3988  dnnl_primitive_desc_iterator_t iterator = nullptr;
3990  desc, attr ? attr->get() : nullptr, aengine.get(), hint_fwd_pd);
3991  if (!allow_empty)
3993  status, "could not create a primitive descriptor iterator");
3994  pd_iterator.reset(iterator);
3995  fetch_impl();
3996  }
3997 
/// Advances the primitive descriptor iterator and refetches the descriptor.
/// @returns false when the iterator reports dnnl_iterator_ends, true after a
///     successful advance followed by fetch_impl().
// NOTE(review): the declaration of `status` (listing line 4003) and the head
// of the status-check call (4006) are missing from this listing.
4002  bool next_impl() {
4004  = dnnl_primitive_desc_iterator_next(pd_iterator.get());
4005  if (status == dnnl_iterator_ends) return false;
4007  status, "could not advance a primitive descriptor iterator");
4008  fetch_impl();
4009  return true;
4010  }
4011 
4012 private:
4013  bool allow_empty_ = false;
/// Fetches the current primitive descriptor from `pd_iterator` and adopts it
/// via reset(). A null `pd` is treated as success when `allow_empty_` is set.
// NOTE(review): the head of the fetch call (listing line 4016) and the
// failure branch of the conditional (4019) are missing from this listing.
4015  void fetch_impl() {
4017  pd_iterator.get(allow_empty_));
4018  error::wrap_c_api(pd != nullptr || allow_empty_ ? dnnl_success
4020  "could not fetch a primitive descriptor from a primitive "
4021  "descriptor iterator");
4022  reset(pd);
4023  }
4024 };
4025 
4027 
4037 
/// Descriptor for a convolution forward propagation primitive. Four
/// constructors cover the with/without bias and with/without dilation cases.
// NOTE(review): this listing has lost several source lines (the embedded line
// numbers skip), including the member declaration at listing line 4042 and
// the head of the C API call in every constructor. Restore them from the
// upstream header before compiling.
// NOTE(review): every constructor validates strides / padding (and dilates)
// against `src_desc.data.ndims - 2` and then passes `&strides[0]`,
// `&padding_l[0]`, ... to the C API. For an empty vector `&v[0]` is
// undefined; `v.data()` would be safe - consider changing.
4041  struct desc {
4043 
      /// Convolution with bias, no dilation.
4074  desc(prop_kind aprop_kind, algorithm aalgorithm,
4075  const memory::desc &src_desc, const memory::desc &weights_desc,
4076  const memory::desc &bias_desc, const memory::desc &dst_desc,
4077  const memory::dims &strides, const memory::dims &padding_l,
4078  const memory::dims &padding_r) {
4079  memory::validate_dims(strides, src_desc.data.ndims - 2);
4080  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4081  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4084  dnnl::convert_to_c(aprop_kind),
4085  convert_to_c(aalgorithm), &src_desc.data,
4086  &weights_desc.data, &bias_desc.data, &dst_desc.data,
4087  &strides[0], &padding_l[0], &padding_r[0]),
4088  "could not create a descriptor for a convolution forward "
4089  "propagation primitive");
4090  }
4091 
      /// Convolution without bias (nullptr is passed in the bias position),
      /// no dilation.
4120  desc(prop_kind aprop_kind, algorithm aalgorithm,
4121  const memory::desc &src_desc, const memory::desc &weights_desc,
4122  const memory::desc &dst_desc, const memory::dims &strides,
4123  const memory::dims &padding_l, const memory::dims &padding_r) {
4124  memory::validate_dims(strides, src_desc.data.ndims - 2);
4125  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4126  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4129  dnnl::convert_to_c(aprop_kind),
4130  convert_to_c(aalgorithm), &src_desc.data,
4131  &weights_desc.data, nullptr, &dst_desc.data,
4132  &strides[0], &padding_l[0], &padding_r[0]),
4133  "could not create a descriptor for a convolution forward "
4134  "propagation primitive");
4135  }
4136 
      /// Dilated convolution with bias.
4169  desc(prop_kind aprop_kind, algorithm aalgorithm,
4170  const memory::desc &src_desc, const memory::desc &weights_desc,
4171  const memory::desc &bias_desc, const memory::desc &dst_desc,
4172  const memory::dims &strides, const memory::dims &dilates,
4173  const memory::dims &padding_l, const memory::dims &padding_r) {
4174  memory::validate_dims(strides, src_desc.data.ndims - 2);
4175  memory::validate_dims(dilates, src_desc.data.ndims - 2);
4176  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4177  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4179  dnnl::convert_to_c(aprop_kind),
4180  convert_to_c(aalgorithm), &src_desc.data,
4181  &weights_desc.data, &bias_desc.data,
4182  &dst_desc.data, &strides[0], &dilates[0],
4183  &padding_l[0], &padding_r[0]),
4184  "could not create a descriptor for a dilated convolution "
4185  "forward propagation primitive");
4186  }
4187 
      /// Dilated convolution without bias (nullptr is passed in the bias
      /// position).
4218  desc(prop_kind aprop_kind, algorithm aalgorithm,
4219  const memory::desc &src_desc, const memory::desc &weights_desc,
4220  const memory::desc &dst_desc, const memory::dims &strides,
4221  const memory::dims &dilates, const memory::dims &padding_l,
4222  const memory::dims &padding_r) {
4223  memory::validate_dims(strides, src_desc.data.ndims - 2);
4224  memory::validate_dims(dilates, src_desc.data.ndims - 2);
4225  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4226  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4228  dnnl::convert_to_c(aprop_kind),
4229  convert_to_c(aalgorithm), &src_desc.data,
4230  &weights_desc.data, nullptr,
4231  &dst_desc.data, &strides[0], &dilates[0],
4232  &padding_l[0], &padding_r[0]),
4233  "could not create a descriptor for a dilated convolution "
4234  "forward propagation primitive");
4235  }
4236  };
4237 
4241  primitive_desc() = default;
4242 
4253  primitive_desc(const desc &adesc, const engine &aengine,
4254  bool allow_empty = false)
4255  : dnnl::primitive_desc(
4256  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
4257 
4269  primitive_desc(const desc &adesc, const primitive_attr &attr,
4270  const engine &aengine, bool allow_empty = false)
4271  : dnnl::primitive_desc(
4272  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
4273 
4281  : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
4284 
4286  memory::desc src_desc() const { return base::src_desc(0); }
4287 
4290 
4292  memory::desc dst_desc() const { return base::dst_desc(0); }
4293 
4299  };
4300 
4302  convolution_forward() = default;
4303 
4308 };
4309 
4312 
4314  struct desc {
4316 
4342  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
4343  const memory::desc &weights_desc,
4344  const memory::desc &diff_dst_desc, const memory::dims &strides,
4345  const memory::dims &padding_l, const memory::dims &padding_r) {
4346  memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
4347  memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
4348  memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
4351  convert_to_c(aalgorithm), &diff_src_desc.data,
4352  &weights_desc.data, &diff_dst_desc.data,
4353  &strides[0], &padding_l[0], &padding_r[0]),
4354  "could not create a descriptor for a convolution backward "
4355  "propagation primitive");
4356  }
4357 
4385  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
4386  const memory::desc &weights_desc,
4387  const memory::desc &diff_dst_desc, const memory::dims &strides,
4388  const memory::dims &dilates, const memory::dims &padding_l,
4389  const memory::dims &padding_r) {
4390  memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
4391  memory::validate_dims(dilates, diff_src_desc.data.ndims - 2);
4392  memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
4393  memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
4396  convert_to_c(aalgorithm), &diff_src_desc.data,
4397  &weights_desc.data, &diff_dst_desc.data,
4398  &strides[0], &dilates[0], &padding_l[0],
4399  &padding_r[0]),
4400  "could not create a descriptor for a dilated convolution "
4401  "backward propagation primitive");
4402  }
4403  };
4404 
4408  primitive_desc() = default;
4409 
4423  primitive_desc(const desc &adesc, const engine &aengine,
4424  const convolution_forward::primitive_desc &hint_fwd_pd,
4425  bool allow_empty = false)
4426  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
4427  hint_fwd_pd.get(), allow_empty) {}
4428 
4443  primitive_desc(const desc &adesc, const primitive_attr &attr,
4444  const engine &aengine,
4445  const convolution_forward::primitive_desc &hint_fwd_pd,
4446  bool allow_empty = false)
4447  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
4448  hint_fwd_pd.get(), allow_empty) {}
4449 
4457  : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
4459 
4462 
4465 
4468  };
4469 
4472 
4477 };
4478 
4482  struct desc {
4484 
4512  desc(algorithm aalgorithm, const memory::desc &src_desc,
4513  const memory::desc &diff_weights_desc,
4514  const memory::desc &diff_bias_desc,
4515  const memory::desc &diff_dst_desc, const memory::dims &strides,
4516  const memory::dims &padding_l, const memory::dims &padding_r) {
4517  memory::validate_dims(strides, src_desc.data.ndims - 2);
4518  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4519  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4522  convert_to_c(aalgorithm), &src_desc.data,
4523  &diff_weights_desc.data, &diff_bias_desc.data,
4524  &diff_dst_desc.data, &strides[0], &padding_l[0],
4525  &padding_r[0]),
4526  "could not create a descriptor for a convolution weights "
4527  "update primitive");
4528  }
4529 
4555  desc(algorithm aalgorithm, const memory::desc &src_desc,
4556  const memory::desc &diff_weights_desc,
4557  const memory::desc &diff_dst_desc, const memory::dims &strides,
4558  const memory::dims &padding_l, const memory::dims &padding_r) {
4559  memory::validate_dims(strides, src_desc.data.ndims - 2);
4560  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4561  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4563  convert_to_c(aalgorithm), &src_desc.data,
4564  &diff_weights_desc.data, nullptr,
4565  &diff_dst_desc.data, &strides[0],
4566  &padding_l[0], &padding_r[0]),
4567  "could not create a descriptor for a convolution weights "
4568  "update primitive");
4569  }
4570 
4600  desc(algorithm aalgorithm, const memory::desc &src_desc,
4601  const memory::desc &diff_weights_desc,
4602  const memory::desc &diff_bias_desc,
4603  const memory::desc &diff_dst_desc, const memory::dims &strides,
4604  const memory::dims &dilates, const memory::dims &padding_l,
4605  const memory::dims &padding_r) {
4606  memory::validate_dims(strides, src_desc.data.ndims - 2);
4607  memory::validate_dims(dilates, src_desc.data.ndims - 2);
4608  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4609  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4612  convert_to_c(aalgorithm), &src_desc.data,
4613  &diff_weights_desc.data, &diff_bias_desc.data,
4614  &diff_dst_desc.data, &strides[0], &dilates[0],
4615  &padding_l[0], &padding_r[0]),
4616  "could not create a descriptor for a dilated convolution "
4617  "weights gradient primitive");
4618  }
4619 
4647  desc(algorithm aalgorithm, const memory::desc &src_desc,
4648  const memory::desc &diff_weights_desc,
4649  const memory::desc &diff_dst_desc, const memory::dims &strides,
4650  const memory::dims &dilates, const memory::dims &padding_l,
4651  const memory::dims &padding_r) {
4652  memory::validate_dims(strides, src_desc.data.ndims - 2);
4653  memory::validate_dims(dilates, src_desc.data.ndims - 2);
4654  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4655  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4658  convert_to_c(aalgorithm), &src_desc.data,
4659  &diff_weights_desc.data, nullptr,
4660  &diff_dst_desc.data, &strides[0], &dilates[0],
4661  &padding_l[0], &padding_r[0]),
4662  "could not create a descriptor for a dilated convolution "
4663  "weights gradient primitive");
4664  }
4665  };
4666 
4670  primitive_desc() = default;
4671 
4684  primitive_desc(const desc &adesc, const engine &aengine,
4685  const convolution_forward::primitive_desc &hint_fwd_pd,
4686  bool allow_empty = false)
4687  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
4688  hint_fwd_pd.get(), allow_empty) {}
4689 
4703  primitive_desc(const desc &adesc, const primitive_attr &attr,
4704  const engine &aengine,
4705  const convolution_forward::primitive_desc &hint_fwd_pd,
4706  bool allow_empty = false)
4707  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
4708  hint_fwd_pd.get(), allow_empty) {}
4709 
4717  : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
4719 
4721  memory::desc src_desc() const { return base::src_desc(0); }
4722 
4725  return base::diff_weights_desc(0);
4726  }
4727 
4730 
4736  return base::diff_weights_desc(1);
4737  }
4738  };
4739 
4742 
4747 };
4748 
4750 //
4758 
4762  struct desc {
4764 
4794  desc(prop_kind aprop_kind, algorithm aalgorithm,
4795  const memory::desc &src_desc, const memory::desc &weights_desc,
4796  const memory::desc &bias_desc, const memory::desc &dst_desc,
4797  const memory::dims &strides, const memory::dims &padding_l,
4798  const memory::dims &padding_r) {
4799  memory::validate_dims(strides, src_desc.data.ndims - 2);
4800  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4801  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4804  dnnl::convert_to_c(aprop_kind),
4805  convert_to_c(aalgorithm), &src_desc.data,
4806  &weights_desc.data, &bias_desc.data, &dst_desc.data,
4807  &strides[0], &padding_l[0], &padding_r[0]),
4808  "could not create a descriptor for a deconvolution forward "
4809  "propagation primitive");
4810  }
4811 
4839  desc(prop_kind aprop_kind, algorithm aalgorithm,
4840  const memory::desc &src_desc, const memory::desc &weights_desc,
4841  const memory::desc &dst_desc, const memory::dims &strides,
4842  const memory::dims &padding_l, const memory::dims &padding_r) {
4843  memory::validate_dims(strides, src_desc.data.ndims - 2);
4844  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4845  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4848  dnnl::convert_to_c(aprop_kind),
4849  convert_to_c(aalgorithm), &src_desc.data,
4850  &weights_desc.data, nullptr, &dst_desc.data,
4851  &strides[0], &padding_l[0], &padding_r[0]),
4852  "could not create a descriptor for a deconvolution forward "
4853  "propagation primitive");
4854  }
4855 
4887  desc(prop_kind aprop_kind, algorithm aalgorithm,
4888  const memory::desc &src_desc, const memory::desc &weights_desc,
4889  const memory::desc &bias_desc, const memory::desc &dst_desc,
4890  const memory::dims &strides, const memory::dims &dilates,
4891  const memory::dims &padding_l, const memory::dims &padding_r) {
4892  memory::validate_dims(strides, src_desc.data.ndims - 2);
4893  memory::validate_dims(dilates, src_desc.data.ndims - 2);
4894  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4895  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4897  &data, dnnl::convert_to_c(aprop_kind),
4898  convert_to_c(aalgorithm), &src_desc.data,
4899  &weights_desc.data, &bias_desc.data,
4900  &dst_desc.data, &strides[0], &dilates[0],
4901  &padding_l[0], &padding_r[0]),
4902  "could not create a descriptor for a dilated deconvolution "
4903  "forward propagation primitive");
4904  }
4905 
4935  desc(prop_kind aprop_kind, algorithm aalgorithm,
4936  const memory::desc &src_desc, const memory::desc &weights_desc,
4937  const memory::desc &dst_desc, const memory::dims &strides,
4938  const memory::dims &dilates, const memory::dims &padding_l,
4939  const memory::dims &padding_r) {
4940  memory::validate_dims(strides, src_desc.data.ndims - 2);
4941  memory::validate_dims(dilates, src_desc.data.ndims - 2);
4942  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4943  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4945  &data, dnnl::convert_to_c(aprop_kind),
4946  convert_to_c(aalgorithm), &src_desc.data,
4947  &weights_desc.data, nullptr,
4948  &dst_desc.data, &strides[0], &dilates[0],
4949  &padding_l[0], &padding_r[0]),
4950  "could not create a descriptor for a dilated deconvolution "
4951  "forward propagation primitive");
4952  }
4953  };
4954 
4958  primitive_desc() = default;
4959 
4970  primitive_desc(const desc &adesc, const engine &aengine,
4971  bool allow_empty = false)
4972  : dnnl::primitive_desc(
4973  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
4974 
4986  primitive_desc(const desc &adesc, const primitive_attr &attr,
4987  const engine &aengine, bool allow_empty = false)
4988  : dnnl::primitive_desc(
4989  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
4990 
4998  : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
5001 
5003  memory::desc src_desc() const { return base::src_desc(0); }
5004 
5007 
5009  memory::desc dst_desc() const { return base::dst_desc(0); }
5010 
5013  };
5014 
5017 
5022 };
5023 
5027  struct desc {
5029 
5054  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
5055  const memory::desc &weights_desc,
5056  const memory::desc &diff_dst_desc, const memory::dims &strides,
5057  const memory::dims &padding_l, const memory::dims &padding_r) {
5058  memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
5059  memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
5060  memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
5063  convert_to_c(aalgorithm), &diff_src_desc.data,
5064  &weights_desc.data, &diff_dst_desc.data,
5065  &strides[0], &padding_l[0], &padding_r[0]),
5066  "could not create a descriptor for a deconvolution "
5067  "backward propagation primitive");
5068  }
5069 
5096  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
5097  const memory::desc &weights_desc,
5098  const memory::desc &diff_dst_desc, const memory::dims &strides,
5099  const memory::dims &dilates, const memory::dims &padding_l,
5100  const memory::dims &padding_r) {
5101  memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
5102  memory::validate_dims(dilates, diff_src_desc.data.ndims - 2);
5103  memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
5104  memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
5107  convert_to_c(aalgorithm), &diff_src_desc.data,
5108  &weights_desc.data, &diff_dst_desc.data,
5109  &strides[0], &dilates[0], &padding_l[0],
5110  &padding_r[0]),
5111  "could not create a descriptor for a dilated deconvolution "
5112  "backward propagation primitive");
5113  }
5114  };
5115 
5119  primitive_desc() = default;
5120 
5134  primitive_desc(const desc &adesc, const engine &aengine,
5135  const deconvolution_forward::primitive_desc &hint_fwd_pd,
5136  bool allow_empty = false)
5137  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
5138  hint_fwd_pd.get(), allow_empty) {}
5139 
5154  primitive_desc(const desc &adesc, const primitive_attr &attr,
5155  const engine &aengine,
5156  const deconvolution_forward::primitive_desc &hint_fwd_pd,
5157  bool allow_empty = false)
5158  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
5159  hint_fwd_pd.get(), allow_empty) {}
5160 
5168  : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
5170 
5173 
5176 
5179  };
5180 
5183 
5188 };
5189 
5193  struct desc {
5195 
5222  desc(algorithm aalgorithm, const memory::desc &src_desc,
5223  const memory::desc &diff_weights_desc,
5224  const memory::desc &diff_bias_desc,
5225  const memory::desc &diff_dst_desc, const memory::dims &strides,
5226  const memory::dims &padding_l, const memory::dims &padding_r) {
5227  memory::validate_dims(strides, src_desc.data.ndims - 2);
5228  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
5229  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
5232  convert_to_c(aalgorithm), &src_desc.data,
5233  &diff_weights_desc.data, &diff_bias_desc.data,
5234  &diff_dst_desc.data, &strides[0], &padding_l[0],
5235  &padding_r[0]),
5236  "could not create a descriptor for a deconvolution weights "
5237  "update primitive");
5238  }
5239 
5264  desc(algorithm aalgorithm, const memory::desc &src_desc,
5265  const memory::desc &diff_weights_desc,
5266  const memory::desc &diff_dst_desc, const memory::dims &strides,
5267  const memory::dims &padding_l, const memory::dims &padding_r) {
5268  memory::validate_dims(strides, src_desc.data.ndims - 2);
5269  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
5270  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
5272  &data, convert_to_c(aalgorithm),
5273  &src_desc.data, &diff_weights_desc.data,
5274  nullptr, &diff_dst_desc.data, &strides[0],
5275  &padding_l[0], &padding_r[0]),
5276  "could not create a descriptor for a deconvolution weights "
5277  "update primitive");
5278  }
5279 
5308  desc(algorithm aalgorithm, const memory::desc &src_desc,
5309  const memory::desc &diff_weights_desc,
5310  const memory::desc &diff_bias_desc,
5311  const memory::desc &diff_dst_desc, const memory::dims &strides,
5312  const memory::dims &dilates, const memory::dims &padding_l,
5313  const memory::dims &padding_r) {
5314  memory::validate_dims(strides, src_desc.data.ndims - 2);
5315  memory::validate_dims(dilates, src_desc.data.ndims - 2);
5316  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
5317  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
5320  convert_to_c(aalgorithm), &src_desc.data,
5321  &diff_weights_desc.data, &diff_bias_desc.data,
5322  &diff_dst_desc.data, &strides[0], &dilates[0],
5323  &padding_l[0], &padding_r[0]),
5324  "could not create a descriptor for a dilated deconvolution "
5325  "weights gradient primitive");
5326  }
5327 
5354  desc(algorithm aalgorithm, const memory::desc &src_desc,
5355  const memory::desc &diff_weights_desc,
5356  const memory::desc &diff_dst_desc, const memory::dims &strides,
5357  const memory::dims &dilates, const memory::dims &padding_l,
5358  const memory::dims &padding_r) {
5359  memory::validate_dims(strides, src_desc.data.ndims - 2);
5360  memory::validate_dims(dilates, src_desc.data.ndims - 2);
5361  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
5362  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
5365  convert_to_c(aalgorithm), &src_desc.data,
5366  &diff_weights_desc.data, nullptr,
5367  &diff_dst_desc.data, &strides[0], &dilates[0],
5368  &padding_l[0], &padding_r[0]),
5369  "could not create a descriptor for a dilated deconvolution "
5370  "weights gradient primitive");
5371  }
5372  };
5373 
5377  primitive_desc() = default;
5378 
5392  primitive_desc(const desc &adesc, const engine &aengine,
5393  const deconvolution_forward::primitive_desc &hint_fwd_pd,
5394  bool allow_empty = false)
5395  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
5396  hint_fwd_pd.get(), allow_empty) {}
5397 
5412  primitive_desc(const desc &adesc, const primitive_attr &attr,
5413  const engine &aengine,
5414  const deconvolution_forward::primitive_desc &hint_fwd_pd,
5415  bool allow_empty = false)
5416  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
5417  hint_fwd_pd.get(), allow_empty) {}
5418 
5426  : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
5428 
5430  memory::desc src_desc() const { return base::src_desc(0); }
5431 
5434  return base::diff_weights_desc(0);
5435  }
5436 
5439 
5442  return base::diff_weights_desc(1);
5443  }
5444  };
5445 
5448 
5453 };
5454 
5456 
5465 
5467 struct lrn_forward : public primitive {
5469  struct desc {
5470  dnnl_lrn_desc_t data;
5471 
5485  desc(prop_kind aprop_kind, algorithm aalgorithm,
5486  const memory::desc &data_desc, memory::dim local_size,
5487  float alpha, float beta, float k = 1.f) {
5489  dnnl::convert_to_c(aprop_kind),
5490  convert_to_c(aalgorithm), &data_desc.data,
5491  local_size, alpha, beta, k),
5492  "could not create a descriptor for a lrn forward "
5493  "propagation primitive");
5494  }
5495  };
5496 
5500  primitive_desc() = default;
5501 
5511  primitive_desc(const desc &adesc, const engine &aengine,
5512  bool allow_empty = false)
5513  : dnnl::primitive_desc(
5514  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
5515 
5526  primitive_desc(const desc &adesc, const primitive_attr &attr,
5527  const engine &aengine, bool allow_empty = false)
5528  : dnnl::primitive_desc(
5529  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
5530 
5538  : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
5541 
5543  memory::desc src_desc() const { return base::src_desc(0); }
5544 
5546  memory::desc dst_desc() const { return base::dst_desc(0); }
5547 
5550  };
5551 
5553  lrn_forward() = default;
5554 
5559 };
5560 
5562 struct lrn_backward : public primitive {
5564  struct desc {
5565  dnnl_lrn_desc_t data;
5566 
5579  desc(algorithm aalgorithm, const memory::desc &data_desc,
5580  const memory::desc &diff_data_desc, memory::dim local_size,
5581  float alpha, float beta, float k = 1.f) {
5583  dnnl_lrn_backward_desc_init(&data, convert_to_c(aalgorithm),
5584  &diff_data_desc.data, &data_desc.data, local_size,
5585  alpha, beta, k),
5586  "could not create a descriptor for a lrn backward "
5587  "propagation primitive");
5588  }
5589  };
5590 
5594  primitive_desc() = default;
5595 
5608  primitive_desc(const desc &adesc, const engine &aengine,
5609  const lrn_forward::primitive_desc &hint_fwd_pd,
5610  bool allow_empty = false)
5611  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
5612  hint_fwd_pd.get(), allow_empty) {}
5613 
5627  primitive_desc(const desc &adesc, const primitive_attr &attr,
5628  const engine &aengine,
5629  const lrn_forward::primitive_desc &hint_fwd_pd,
5630  bool allow_empty = false)
5631  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
5632  hint_fwd_pd.get(), allow_empty) {}
5633 
5641  : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
5643 
5646 
5649 
5652  };
5653 
5655  lrn_backward() = default;
5656 
5661 };
5662 
5664 
5672 
5674 struct pooling_forward : public primitive {
5676  struct desc {
5677  dnnl_pooling_desc_t data;
5678 
5703  desc(prop_kind aprop_kind, algorithm aalgorithm,
5704  const memory::desc &src_desc, const memory::desc &dst_desc,
5705  const memory::dims &strides, const memory::dims &kernel,
5706  const memory::dims &padding_l, const memory::dims &padding_r) {
5707  memory::validate_dims(strides, src_desc.data.ndims - 2);
5708  memory::validate_dims(kernel, src_desc.data.ndims - 2);
5709  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
5710  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
5712  dnnl::convert_to_c(aprop_kind),
5713  convert_to_c(aalgorithm), &src_desc.data,
5714  &dst_desc.data, &strides[0], &kernel[0],
5715  &padding_l[0], &padding_r[0]),
5716  "could not create a descriptor for a pooling forward "
5717  "propagation primitive");
5718  }
5719  };
5720 
5724  primitive_desc() = default;
5725 
5735  primitive_desc(const desc &adesc, const engine &aengine,
5736  bool allow_empty = false)
5737  : dnnl::primitive_desc(
5738  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
5739 
5750  primitive_desc(const desc &adesc, const primitive_attr &attr,
5751  const engine &aengine, bool allow_empty = false)
5752  : dnnl::primitive_desc(
5753  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
5754 
5762  : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
5765 
5767  memory::desc src_desc() const { return base::src_desc(0); }
5768 
5770  memory::desc dst_desc() const { return base::dst_desc(0); }
5771 
5774  };
5775 
5777  pooling_forward() = default;
5778 
5783 };
5784 
5786 struct pooling_backward : public primitive {
5788  struct desc {
5789  dnnl_pooling_desc_t data;
5790 
5812  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
5813  const memory::desc &diff_dst_desc, const memory::dims &strides,
5814  const memory::dims &kernel, const memory::dims &padding_l,
5815  const memory::dims &padding_r) {
5816  memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
5817  memory::validate_dims(kernel, diff_src_desc.data.ndims - 2);
5818  memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
5819  memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
5822  convert_to_c(aalgorithm), &diff_src_desc.data,
5823  &diff_dst_desc.data, &strides[0], &kernel[0],
5824  &padding_l[0], &padding_r[0]),
5825  "could not create a descriptor for a pooling backward "
5826  "propagation primitive");
5827  }
5828  };
5829 
5833  primitive_desc() = default;
5834 
5847  primitive_desc(const desc &adesc, const engine &aengine,
5848  const pooling_forward::primitive_desc &hint_fwd_pd,
5849  bool allow_empty = false)
5850  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
5851  hint_fwd_pd.get(), allow_empty) {}
5852 
5866  primitive_desc(const desc &adesc, const primitive_attr &attr,
5867  const engine &aengine,
5868  const pooling_forward::primitive_desc &hint_fwd_pd,
5869  bool allow_empty = false)
5870  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
5871  hint_fwd_pd.get(), allow_empty) {}
5872 
5880  : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
5882 
5885 
5888 
5891  };
5892 
5894  pooling_backward() = default;
5895 
5900 };
5901 
5903 
5924 
5926 struct eltwise_forward : public primitive {
5928  struct desc {
5929  dnnl_eltwise_desc_t data;
5930 
5943  desc(prop_kind aprop_kind, algorithm aalgorithm,
5944  const memory::desc &data_desc, float alpha = 0,
5945  float beta = 0) {
5947  dnnl::convert_to_c(aprop_kind),
5948  dnnl::convert_to_c(aalgorithm),
5949  &data_desc.data, alpha, beta),
5950  "could not create a descriptor for an eltwise forward "
5951  "propagation primitive");
5952  }
5953  };
5954 
5958  primitive_desc() = default;
5959 
5970  primitive_desc(const desc &adesc, const engine &aengine,
5971  bool allow_empty = false)
5972  : dnnl::primitive_desc(
5973  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
5974 
5986  primitive_desc(const desc &adesc, const primitive_attr &attr,
5987  const engine &aengine, bool allow_empty = false)
5988  : dnnl::primitive_desc(
5989  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
5990 
5998  : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
6001 
6003  memory::desc src_desc() const { return base::src_desc(0); }
6004 
6006  memory::desc dst_desc() const { return base::dst_desc(0); }
6007  };
6008 
6010  eltwise_forward() = default;
6011 
6016 };
6017 
6019 struct eltwise_backward : public primitive {
6021  struct desc {
6022  dnnl_eltwise_desc_t data;
6023 
6035  desc(algorithm aalgorithm, const memory::desc &diff_data_desc,
6036  const memory::desc &data_desc, float alpha = 0,
6037  float beta = 0) {
6040  dnnl::convert_to_c(aalgorithm),
6041  &diff_data_desc.data, &data_desc.data, alpha, beta),
6042  "could not create a descriptor for an eltwise backward "
6043  "propagation primitive");
6044  }
6045  };
6046 
6050  primitive_desc() = default;
6051 
6065  primitive_desc(const desc &adesc, const engine &aengine,
6066  const eltwise_forward::primitive_desc &hint_fwd_pd,
6067  bool allow_empty = false)
6068  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
6069  hint_fwd_pd.get(), allow_empty) {}
6070 
6085  primitive_desc(const desc &adesc, const primitive_attr &attr,
6086  const engine &aengine,
6087  const eltwise_forward::primitive_desc &hint_fwd_pd,
6088  bool allow_empty = false)
6089  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
6090  hint_fwd_pd.get(), allow_empty) {}
6091 
6099  : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
6101 
6103  memory::desc src_desc() const { return base::src_desc(0); }
6104 
6107 
6110  };
6111 
6113  eltwise_backward() = default;
6114 
6119 };
6120 
6122 
6130 
6132 struct softmax_forward : public primitive {
6134  struct desc {
6135  dnnl_softmax_desc_t data;
6136 
6138  desc() = default;
6139 
6148  desc(prop_kind aprop_kind, const memory::desc &data_desc,
6149  int softmax_axis) {
6151  dnnl::convert_to_c(aprop_kind),
6152  &data_desc.data, softmax_axis),
6153  "could not create a descriptor for a softmax forward "
6154  "propagation primitive");
6155  }
6156  };
6157 
6161  primitive_desc() = default;
6162 
6173  primitive_desc(const desc &adesc, const engine &aengine,
6174  bool allow_empty = false)
6175  : dnnl::primitive_desc(
6176  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
6177 
6189  primitive_desc(const desc &adesc, const primitive_attr &attr,
6190  const engine &aengine, bool allow_empty = false)
6191  : dnnl::primitive_desc(
6192  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
6193 
6201  : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
6204 
6206  memory::desc src_desc() const { return base::src_desc(0); }
6207 
6209  memory::desc dst_desc() const { return base::dst_desc(0); }
6210  };
6211 
6213  softmax_forward() = default;
6214 
6219 };
6220 
6222 struct softmax_backward : public primitive {
6224  struct desc {
6225  dnnl_softmax_desc_t data;
6226 
6228  desc() = default;
6229 
6237  desc(const memory::desc &diff_data_desc, const memory::desc &data_desc,
6238  int softmax_axis) {
6240  dnnl_softmax_backward_desc_init(&data, &diff_data_desc.data,
6241  &data_desc.data, softmax_axis),
6242  "could not create a descriptor for a softmax backward "
6243  "propagation primitive");
6244  }
6245  };
6246 
6250  primitive_desc() = default;
6251 
6265  primitive_desc(const desc &adesc, const engine &aengine,
6266  const softmax_forward::primitive_desc &hint_fwd_pd,
6267  bool allow_empty = false)
6268  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
6269  hint_fwd_pd.get(), allow_empty) {}
6270 
6285  primitive_desc(const desc &adesc, const primitive_attr &attr,
6286  const engine &aengine,
6287  const softmax_forward::primitive_desc &hint_fwd_pd,
6288  bool allow_empty = false)
6289  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
6290  hint_fwd_pd.get(), allow_empty) {}
6291 
6299  : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
6301 
6303  memory::desc dst_desc() const { return base::dst_desc(0); }
6304 
6307 
6310  };
6311 
6313  softmax_backward() = default;
6314 
6319 };
6320 
6322 
6330 
6334  struct desc {
6336 
6338  desc() = default;
6339 
6348  desc(prop_kind aprop_kind, const memory::desc &data_desc,
6349  int logsoftmax_axis) {
6351  dnnl::convert_to_c(aprop_kind),
6352  &data_desc.data, logsoftmax_axis),
6353  "could not create a descriptor for a logsoftmax forward "
6354  "propagation primitive");
6355  }
6356  };
6357 
6361  primitive_desc() = default;
6362 
6373  primitive_desc(const desc &adesc, const engine &aengine,
6374  bool allow_empty = false)
6375  : dnnl::primitive_desc(
6376  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
6377 
6389  primitive_desc(const desc &adesc, const primitive_attr &attr,
6390  const engine &aengine, bool allow_empty = false)
6391  : dnnl::primitive_desc(
6392  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
6393 
6401  : dnnl::primitive_desc(pd,
6402  // Logsoftmax and softmax share the implementation and
6403  // currently report the same primitive kind. Hence this
6404  // must be softmax and not logsoftmax.
6405  dnnl::primitive::kind::softmax,
6408 
6410  memory::desc src_desc() const { return base::src_desc(0); }
6411 
6413  memory::desc dst_desc() const { return base::dst_desc(0); }
6414  };
6415 
6417  logsoftmax_forward() = default;
6418 
6423 };
6424 
6428  struct desc {
6430 
6432  desc() = default;
6433 
6441  desc(const memory::desc &diff_data_desc, const memory::desc &data_desc,
6442  int logsoftmax_axis) {
6444  &diff_data_desc.data, &data_desc.data,
6445  logsoftmax_axis),
6446  "could not create a descriptor for a logsoftmax backward "
6447  "propagation primitive");
6448  }
6449  };
6450 
6454  primitive_desc() = default;
6455 
6469  primitive_desc(const desc &adesc, const engine &aengine,
6470  const logsoftmax_forward::primitive_desc &hint_fwd_pd,
6471  bool allow_empty = false)
6472  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
6473  hint_fwd_pd.get(), allow_empty) {}
6474 
6489  primitive_desc(const desc &adesc, const primitive_attr &attr,
6490  const engine &aengine,
6491  const logsoftmax_forward::primitive_desc &hint_fwd_pd,
6492  bool allow_empty = false)
6493  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
6494  hint_fwd_pd.get(), allow_empty) {}
6495 
6503  : dnnl::primitive_desc(pd,
6504  // Logsoftmax and softmax share the implementation and
6505  // currently report the same primitive kind. Hence this
6506  // must be softmax and not logsoftmax.
6507  dnnl::primitive::kind::softmax,
6509 
6511  memory::desc dst_desc() const { return base::dst_desc(0); }
6512 
6515 
6518  };
6519 
6521  logsoftmax_backward() = default;
6522 
6527 };
6528 
6530 
6550 
6554  struct desc {
6556 
6571  desc(prop_kind aprop_kind, const memory::desc &data_desc, float epsilon,
6572  normalization_flags flags) {
6575  dnnl::convert_to_c(aprop_kind), &data_desc.data,
6576  epsilon, convert_to_c(flags)),
6577  "could not create a descriptor for a batch normalization "
6578  "forward propagation primitive");
6579  }
6580  };
6581 
6586  primitive_desc() = default;
6587 
6598  primitive_desc(const desc &adesc, const engine &aengine,
6599  bool allow_empty = false)
6600  : dnnl::primitive_desc(
6601  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
6602 
// Attribute-taking overload: delegates to the dnnl::primitive_desc base
// constructor, passing &adesc.data, the address of attr, the engine, a
// nullptr fourth argument, and allow_empty (default false).
6614  primitive_desc(const desc &adesc, const primitive_attr &attr,
6615  const engine &aengine, bool allow_empty = false)
6616  : dnnl::primitive_desc(
6617  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
6618 
6626  : dnnl::primitive_desc(pd,
6627  dnnl::primitive::kind::batch_normalization,
6630 
6632  memory::desc src_desc() const { return base::src_desc(0); }
6633 
6635  memory::desc dst_desc() const { return base::dst_desc(0); }
6636 
6639 
6642 
6645  memory::desc mean_desc() const { return stat_desc(mean); }
6646 
6649  memory::desc variance_desc() const { return stat_desc(var); }
6650 
6651  private:
// Selector values handed to stat_desc(): mean_desc() passes `mean` and
// variance_desc() passes `var`. stat_desc() forwards the value as the last
// argument of query_md().
6652  enum {
6653  mean = 1,
6654  var = 2,
6655  };
6656  memory::desc stat_desc(int kind) const {
6661  &p),
6662  "could not retrieve a descriptor from a primitive "
6663  "descriptor for batch normalization forward propagation "
6664  "primitive");
6665  return query_md(p->flags & dnnl_use_global_stats ? query::src_md
6666  : query::dst_md,
6667  kind);
6668  }
6669  };
6670 
6673 
6678 };
6679 
6683  struct desc {
6685 
6698  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
6699  const memory::desc &data_desc, float epsilon,
6700  normalization_flags flags) {
6702  dnnl::convert_to_c(aprop_kind),
6703  &diff_data_desc.data, &data_desc.data,
6704  epsilon, convert_to_c(flags)),
6705  "could not create a descriptor for a batch normalization "
6706  "backward propagation primitive");
6707  }
6708  };
6709 
6714  primitive_desc() = default;
6715 
6729  primitive_desc(const desc &adesc, const engine &aengine,
6731  bool allow_empty = false)
6732  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
6733  hint_fwd_pd.get(), allow_empty) {}
6734 
6749  primitive_desc(const desc &adesc, const primitive_attr &attr,
6750  const engine &aengine,
6752  bool allow_empty = false)
6753  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
6754  hint_fwd_pd.get(), allow_empty) {}
6755 
6763  : dnnl::primitive_desc(pd,
6764  dnnl::primitive::kind::batch_normalization,
6766  }
6767 
6769  memory::desc src_desc() const { return base::src_desc(0); }
6770 
6773 
6775  memory::desc dst_desc() const { return base::dst_desc(0); }
6776 
6779 
6782 
6785  return base::diff_weights_desc(0);
6786  }
6787 
6790 
6793  return query_md(query::src_md, 2);
6794  }
6795 
6798  };
6799 
6802 
6807 };
6808 
6810 
6832 
6836  struct desc {
6838 
6850  desc(prop_kind aprop_kind, const memory::desc &data_desc,
6851  const memory::desc &stat_desc, float epsilon,
6852  normalization_flags flags) {
6855  dnnl::convert_to_c(aprop_kind), &data_desc.data,
6856  &stat_desc.data, epsilon, convert_to_c(flags)),
6857  "could not create a descriptor for a layer normalization "
6858  "forward propagation primitive");
6859  }
6860 
6871  desc(prop_kind aprop_kind, const memory::desc &data_desc, float epsilon,
6872  normalization_flags flags) {
6875  dnnl::convert_to_c(aprop_kind), &data_desc.data,
6876  nullptr, epsilon, convert_to_c(flags)),
6877  "could not create a descriptor for a layer normalization "
6878  "forward propagation primitive");
6879  }
6880  };
6881 
6886  primitive_desc() = default;
6887 
// Constructs the primitive descriptor by delegating to the
// dnnl::primitive_desc base constructor with the operation descriptor data
// and the engine. nullptr is passed in the second position (where the
// attribute-taking overload passes &attr) and in the fourth position.
// allow_empty defaults to false and is forwarded as-is.
6898  primitive_desc(const desc &adesc, const engine &aengine,
6899  bool allow_empty = false)
6900  : dnnl::primitive_desc(
6901  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
6902 
// Attribute-taking overload: forwards &adesc.data, the address of attr, the
// engine, a nullptr fourth argument, and allow_empty (default false) to the
// dnnl::primitive_desc base constructor.
6914  primitive_desc(const desc &adesc, const primitive_attr &attr,
6915  const engine &aengine, bool allow_empty = false)
6916  : dnnl::primitive_desc(
6917  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
6918 
6926  : dnnl::primitive_desc(pd,
6927  dnnl::primitive::kind::layer_normalization,
6930 
6932  memory::desc src_desc() const { return base::src_desc(0); }
6933 
6935  memory::desc dst_desc() const { return base::dst_desc(0); }
6936 
6939 
6942 
6944  memory::desc mean_desc() const { return stat_desc(mean); }
6945 
6947  memory::desc variance_desc() const { return stat_desc(var); }
6948 
6949  private:
// Selector values for stat_desc(): `mean` is used by mean_desc() and `var`
// by variance_desc(); stat_desc() passes the value on to query_md() as its
// last argument.
6950  enum {
6951  mean = 1,
6952  var = 2,
6953  };
6954  memory::desc stat_desc(int kind) const {
6959  &p),
6960  "could not retrieve a descriptor from a primitive "
6961  "descriptor for layer normalization forward propagation "
6962  "primitive");
6963  return query_md(p->flags & dnnl_use_global_stats ? query::src_md
6964  : query::dst_md,
6965  kind);
6966  }
6967  };
6968 
6971 
6976 };
6977 
6981  struct desc {
6983 
// Descriptor constructor for the variant that takes an explicit statistics
// descriptor (stat_desc). The arguments are converted with convert_to_c()
// where needed and handed to the C initializer; a failure is reported with
// the message below.
// The enclosing primitive_desc checks primitive::kind::layer_normalization,
// so the message names layer normalization.
// NOTE(review): this listing omits source lines 7000-7001 (the head of the
// wrapped C call); they are not reproduced here.
6997  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
6998  const memory::desc &data_desc, const memory::desc &stat_desc,
6999  float epsilon, normalization_flags flags) {
7002  dnnl::convert_to_c(aprop_kind),
7003  &diff_data_desc.data, &data_desc.data,
7004  &stat_desc.data, epsilon, convert_to_c(flags)),
7005  "could not create a descriptor for a layer normalization "
7006  "backward propagation primitive");
7007  }
7008 
// Descriptor constructor for the variant without a statistics descriptor:
// nullptr is passed in the position where the other overload passes
// &stat_desc.data. A failure of the wrapped C initializer is reported with
// the message below.
// The enclosing primitive_desc checks primitive::kind::layer_normalization,
// so the message names layer normalization.
// NOTE(review): this listing omits source line 7024 (the head of the wrapped
// C call); it is not reproduced here.
7021  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
7022  const memory::desc &data_desc, float epsilon,
7023  normalization_flags flags) {
7025  dnnl::convert_to_c(aprop_kind),
7026  &diff_data_desc.data, &data_desc.data,
7027  nullptr, epsilon, convert_to_c(flags)),
7028  "could not create a descriptor for a layer normalization "
7029  "backward propagation primitive");
7030  }
7031  };
7032 
7037  primitive_desc() = default;
7038 
7052  primitive_desc(const desc &adesc, const engine &aengine,
7054  bool allow_empty = false)
7055  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
7056  hint_fwd_pd.get(), allow_empty) {}
7057 
7072  primitive_desc(const desc &adesc, const primitive_attr &attr,
7073  const engine &aengine,
7075  bool allow_empty = false)
7076  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
7077  hint_fwd_pd.get(), allow_empty) {}
7078 
7086  : dnnl::primitive_desc(pd,
7087  dnnl::primitive::kind::layer_normalization,
7089  }
7090 
7092  memory::desc src_desc() const { return base::src_desc(0); }
7093 
7096 
7098  memory::desc dst_desc() const { return base::dst_desc(0); }
7099 
7102 
7105 
7108  return base::diff_weights_desc(0);
7109  }
7110 
7113 
7116  return query_md(query::src_md, 2);
7117  }
7118 
7121  };
7122 
7125 
7130 };
7131 
7133 
7141 
7145  struct desc {
7147 
7162  desc(prop_kind aprop_kind, const memory::desc &src_desc,
7163  const memory::desc &weights_desc, const memory::desc &bias_desc,
7164  const memory::desc &dst_desc) {
7166  dnnl::convert_to_c(aprop_kind),
7167  &src_desc.data, &weights_desc.data,
7168  &bias_desc.data, &dst_desc.data),
7169  "could not create a descriptor for an inner product "
7170  "forward propagation primitive");
7171  }
7172 
7186  desc(prop_kind aprop_kind, const memory::desc &src_desc,
7187  const memory::desc &weights_desc,
7188  const memory::desc &dst_desc) {
7191  dnnl::convert_to_c(aprop_kind), &src_desc.data,
7192  &weights_desc.data, nullptr, &dst_desc.data),
7193  "could not create a descriptor for an inner product "
7194  "forward propagation primitive");
7195  }
7196  };
7197 
7201  primitive_desc() = default;
7202 
// Constructs the primitive descriptor by delegating to the
// dnnl::primitive_desc base constructor with the operation descriptor data
// and the engine. nullptr is passed in the second position (where the
// attribute-taking overload passes &attr) and in the fourth position.
// allow_empty defaults to false and is forwarded unchanged.
7213  primitive_desc(const desc &adesc, const engine &aengine,
7214  bool allow_empty = false)
7215  : dnnl::primitive_desc(
7216  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
7217 
// Attribute-taking overload: forwards &adesc.data, the address of attr, the
// engine, a nullptr fourth argument, and allow_empty (default false) to the
// dnnl::primitive_desc base constructor.
7229  primitive_desc(const desc &adesc, const primitive_attr &attr,
7230  const engine &aengine, bool allow_empty = false)
7231  : dnnl::primitive_desc(
7232  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
7233 
7241  : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
7244 
7246  memory::desc src_desc() const { return base::src_desc(0); }
7247 
7250 
7252  memory::desc dst_desc() const { return base::dst_desc(0); }
7253 
7256  };
7257 
7260 
7265 };
7266 
7270  struct desc {
7272 
7283  desc(const memory::desc &diff_src_desc,
7284  const memory::desc &weights_desc,
7285  const memory::desc &diff_dst_desc) {
7287  &diff_src_desc.data, &weights_desc.data,
7288  &diff_dst_desc.data),
7289  "could not create a descriptor for an inner product "
7290  "backward propagation primitive");
7291  }
7292  };
7293 
7298  primitive_desc() = default;
7299 
// Constructs the primitive descriptor by delegating to the
// dnnl::primitive_desc base constructor. No attributes are supplied (nullptr
// sits where the attribute-taking overload passes &attr); the value of
// hint_fwd_pd.get(), taken from an inner product forward primitive
// descriptor, is forwarded along with allow_empty (default false).
7313  primitive_desc(const desc &adesc, const engine &aengine,
7314  const inner_product_forward::primitive_desc &hint_fwd_pd,
7315  bool allow_empty = false)
7316  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
7317  hint_fwd_pd.get(), allow_empty) {}
7318 
// Attribute-taking overload: delegates to the dnnl::primitive_desc base
// constructor with &adesc.data, the address of attr, the engine,
// hint_fwd_pd.get() from an inner product forward primitive descriptor, and
// allow_empty (default false).
7333  primitive_desc(const desc &adesc, const primitive_attr &attr,
7334  const engine &aengine,
7335  const inner_product_forward::primitive_desc &hint_fwd_pd,
7336  bool allow_empty = false)
7337  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
7338  hint_fwd_pd.get(), allow_empty) {}
7339 
7347  : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
7349 
7352 
7355 
7358  };
7359 
7362 
7367 };
7368 
7372  struct desc {
7374 
7386  desc(const memory::desc &src_desc,
7387  const memory::desc &diff_weights_desc,
7388  const memory::desc &diff_bias_desc,
7389  const memory::desc &diff_dst_desc) {
7392  &src_desc.data, &diff_weights_desc.data,
7393  &diff_bias_desc.data, &diff_dst_desc.data),
7394  "could not create a descriptor for an inner product "
7395  "weights gradient primitive");
7396  }
7397 
7408  desc(const memory::desc &src_desc,
7409  const memory::desc &diff_weights_desc,
7410  const memory::desc &diff_dst_desc) {
7413  &src_desc.data, &diff_weights_desc.data, nullptr,
7414  &diff_dst_desc.data),
7415  "could not create a descriptor for an inner product "
7416  "weights gradient primitive");
7417  }
7418  };
7419 
7423  primitive_desc() = default;
7424 
// Constructs the primitive descriptor by delegating to the
// dnnl::primitive_desc base constructor. nullptr is passed where the
// attribute-taking overload passes &attr; hint_fwd_pd.get(), taken from an
// inner product forward primitive descriptor, and allow_empty (default
// false) are forwarded unchanged.
7438  primitive_desc(const desc &adesc, const engine &aengine,
7439  const inner_product_forward::primitive_desc &hint_fwd_pd,
7440  bool allow_empty = false)
7441  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
7442  hint_fwd_pd.get(), allow_empty) {}
7443 
// Attribute-taking overload: delegates to the dnnl::primitive_desc base
// constructor with &adesc.data, the address of attr, the engine,
// hint_fwd_pd.get() from an inner product forward primitive descriptor, and
// allow_empty (default false).
7458  primitive_desc(const desc &adesc, const primitive_attr &attr,
7459  const engine &aengine,
7460  const inner_product_forward::primitive_desc &hint_fwd_pd,
7461  bool allow_empty = false)
7462  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
7463  hint_fwd_pd.get(), allow_empty) {}
7464 
7472  : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
7474 
7476  memory::desc src_desc() const { return base::src_desc(0); }
7477 
7480  return base::diff_weights_desc(0);
7481  }
7482 
7485 
7488  return base::diff_weights_desc(1);
7489  }
7490  };
7491 
7494 
7499 };
7500 
7502 
7510 
7513  using primitive_desc::primitive_desc;
7514 
7517 
7526  dnnl::prop_kind aprop_kind, dnnl::algorithm cell_kind)
7527  : rnn_primitive_desc_base(pd, aprop_kind, aprop_kind, cell_kind) {}
7528 
7533  }
7534 
7541  }
7542 
7547  }
7548 
7553  }
7554 
7559  }
7560 
7565  }
7566 
7571  }
7572 
7579  }
7580 
7585  }
7586 
7593  }
7594 
7599  }
7600 
7605  }
7606 
7613  }
7614 
7619  }
7620 
7625  }
7626 
7631  }
7632 
7636  return base::query_md(
7638  }
7639 
7643  return base::query_md(
7645  }
7646 
7653  }
7654 
7659  }
7660 
7667  }
7668 
7673  }
7674 
7675 protected:
7676  using rnn_base = rnn_primitive_desc_base;
7677 
7678  // (Deliberately not using doxygen comments)
7679  //
7680  // Constructs an RNN primitive descriptor base from a C API primitive
7681  // descriptor while checking that it actually describes the expected
7682  // primitive by comparing propagation and primitive kinds. Caller can
7683  // pass two alternative propagation kinds. This is typically used to check
7684  // that propagation kind is inference or training forward propagation.
7685  //
7686  // @param pd C API primitive descriptor.
7687  // @param prop_kind1 Expected propagation kind.
7688  // @param prop_kind2 Expected propagation kind.
7689  // @param cell_kind Expected cell kind.
7691  dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2,
7692  dnnl::algorithm cell_kind) {
7694  dnnl_status_t rc;
7695  rc = dnnl_primitive_desc_query(pd, dnnl_query_rnn_d, 0, &rnn_d);
7696  error::wrap_c_api(rc,
7697  "could not retrieve a descriptor from a primitive descriptor "
7698  "for an RNN primitive");
7699 
7700  dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
7701  dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
7702  dnnl_alg_kind_t c_cell_kind = convert_to_c(cell_kind);
7703 
7704  bool ok = rnn_d->primitive_kind == dnnl_rnn
7705  && (rnn_d->prop_kind == c_prop_kind1
7706  || rnn_d->prop_kind == c_prop_kind2)
7707  && rnn_d->cell_kind == c_cell_kind;
7708 
7709  if (!ok)
7710  DNNL_THROW_ERROR(dnnl_invalid_arguments,
7711  "mismatch between expected and provided descriptors for an "
7712  "RNN primitive");
7713 
7714  reset_with_clone(pd);
7715  }
7716 };
7717 
7721  struct desc {
7722  dnnl_rnn_desc_t data;
7723 
7764  desc(prop_kind aprop_kind, algorithm activation,
7765  rnn_direction direction, const memory::desc &src_layer_desc,
7766  const memory::desc &src_iter_desc,
7767  const memory::desc &weights_layer_desc,
7768  const memory::desc &weights_iter_desc,
7769  const memory::desc &bias_desc,
7770  const memory::desc &dst_layer_desc,
7771  const memory::desc &dst_iter_desc,
7772  rnn_flags flags = rnn_flags::undef, float alpha = 0.0f,
7773  float beta = 0.0f) {
7776  dnnl::convert_to_c(aprop_kind),
7777  dnnl::convert_to_c(activation),
7778  dnnl::convert_to_c(direction), &src_layer_desc.data,
7779  &src_iter_desc.data, &weights_layer_desc.data,
7780  &weights_iter_desc.data, &bias_desc.data,
7781  &dst_layer_desc.data, &dst_iter_desc.data,
7782  dnnl::convert_to_c(flags), alpha, beta),
7783  "could not create a descriptor for a vanilla RNN forward "
7784  "propagation primitive");
7785  }
7786  };
7787 
7791  primitive_desc() = default;
7792 
7803  primitive_desc(const desc &adesc, const engine &aengine,
7804  bool allow_empty = false)
7806  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
7807 
7819  primitive_desc(const desc &adesc, const primitive_attr &attr,
7820  const engine &aengine, bool allow_empty = false)
7822  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
7823 
7833  dnnl::algorithm::vanilla_rnn) {}
7834 
7837  return rnn_base::src_layer_desc();
7838  }
7839 
7842 
7846  }
7847 
7850  return rnn_base::weights_iter_desc();
7851  }
7852 
7855 
7858  return rnn_base::dst_layer_desc();
7859  }
7860 
7863 
7866  return rnn_base::workspace_desc();
7867  }
7868  };
7869 
7871  vanilla_rnn_forward() = default;
7872 
7877 };
7878 
7882  struct desc {
7883  dnnl_rnn_desc_t data;
7884 
7937  desc(prop_kind aprop_kind, algorithm activation,
7938  rnn_direction direction, const memory::desc &src_layer_desc,
7939  const memory::desc &src_iter_desc,
7940  const memory::desc &weights_layer_desc,
7941  const memory::desc &weights_iter_desc,
7942  const memory::desc &bias_desc,
7943  const memory::desc &dst_layer_desc,
7944  const memory::desc &dst_iter_desc,
7945  const memory::desc &diff_src_layer_desc,
7946  const memory::desc &diff_src_iter_desc,
7947  const memory::desc &diff_weights_layer_desc,
7948  const memory::desc &diff_weights_iter_desc,
7949  const memory::desc &diff_bias_desc,
7950  const memory::desc &diff_dst_layer_desc,
7951  const memory::desc &diff_dst_iter_desc,
7952  rnn_flags flags = rnn_flags::undef, float alpha = 0.0f,
7953  float beta = 0.0f) {
7956  dnnl::convert_to_c(aprop_kind),
7957  dnnl::convert_to_c(activation),
7958  dnnl::convert_to_c(direction), &src_layer_desc.data,
7959  &src_iter_desc.data, &weights_layer_desc.data,
7960  &weights_iter_desc.data, &bias_desc.data,
7961  &dst_layer_desc.data, &dst_iter_desc.data,
7962  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
7963  &diff_weights_layer_desc.data,
7964  &diff_weights_iter_desc.data, &diff_bias_desc.data,
7965  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
7966  dnnl::convert_to_c(flags), alpha, beta),
7967  "could not create a descriptor for a vanilla RNN backward "
7968  "propagation primitive");
7969  }
7970  };
7971 
7975  primitive_desc() = default;
7976 
// Constructs the primitive descriptor by delegating to the
// rnn_primitive_desc_base constructor. nullptr is passed where the
// attribute-taking overload passes &attr; hint_fwd_pd.get(), taken from a
// vanilla RNN forward primitive descriptor, and allow_empty (default false)
// are forwarded unchanged.
7990  primitive_desc(const desc &adesc, const engine &aengine,
7991  const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
7992  bool allow_empty = false)
7993  : rnn_primitive_desc_base(&adesc.data, nullptr, aengine,
7994  hint_fwd_pd.get(), allow_empty) {}
7995 
// Attribute-taking overload: delegates to the rnn_primitive_desc_base
// constructor with &adesc.data, the address of attr, the engine,
// hint_fwd_pd.get() from a vanilla RNN forward primitive descriptor, and
// allow_empty (default false).
8010  primitive_desc(const desc &adesc, const primitive_attr &attr,
8011  const engine &aengine,
8012  const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
8013  bool allow_empty = false)
8014  : rnn_primitive_desc_base(&adesc.data, &attr, aengine,
8015  hint_fwd_pd.get(), allow_empty) {}
8016 
8025  dnnl::algorithm::vanilla_rnn) {}
8026 
8029  return rnn_base::src_layer_desc();
8030  }
8031 
8034 
8038  }
8039 
8042  return rnn_base::weights_iter_desc();
8043  }
8044 
8047 
8050  return rnn_base::dst_layer_desc();
8051  }
8052 
8055 
8058  return rnn_base::workspace_desc();
8059  }
8060 
8064  }
8065 
8069  }
8070 
8074  }
8075 
8079  }
8080 
8083  return rnn_base::diff_bias_desc();
8084  }
8085 
8089  }
8090 
8094  }
8095  };
8096 
8099 
8104 };
8105 
8107 struct lstm_forward : public primitive {
8109  struct desc {
8110  dnnl_rnn_desc_t data;
8111 
8160  desc(prop_kind aprop_kind, rnn_direction direction,
8161  const memory::desc &src_layer_desc,
8162  const memory::desc &src_iter_desc,
8163  const memory::desc &src_iter_c_desc,
8164  const memory::desc &weights_layer_desc,
8165  const memory::desc &weights_iter_desc,
8166  const memory::desc &weights_peephole_desc,
8167  const memory::desc &weights_projection_desc,
8168  const memory::desc &bias_desc,
8169  const memory::desc &dst_layer_desc,
8170  const memory::desc &dst_iter_desc,
8171  const memory::desc &dst_iter_c_desc,
8172  rnn_flags flags = rnn_flags::undef) {
8175  dnnl::convert_to_c(aprop_kind),
8176  dnnl::convert_to_c(direction), &src_layer_desc.data,
8177  &src_iter_desc.data, &src_iter_c_desc.data,
8178  &weights_layer_desc.data, &weights_iter_desc.data,
8179  &weights_peephole_desc.data,
8180  &weights_projection_desc.data, &bias_desc.data,
8181  &dst_layer_desc.data, &dst_iter_desc.data,
8182  &dst_iter_c_desc.data, dnnl::convert_to_c(flags)),
8183  "could not create a descriptor for an LSTM forward "
8184  "propagation primitive");
8185  }
8186 
8228  desc(prop_kind aprop_kind, rnn_direction direction,
8229  const memory::desc &src_layer_desc,
8230  const memory::desc &src_iter_desc,
8231  const memory::desc &src_iter_c_desc,
8232  const memory::desc &weights_layer_desc,
8233  const memory::desc &weights_iter_desc,
8234  const memory::desc &weights_peephole_desc,
8235  const memory::desc &bias_desc,
8236  const memory::desc &dst_layer_desc,
8237  const memory::desc &dst_iter_desc,
8238  const memory::desc &dst_iter_c_desc,
8239  rnn_flags flags = rnn_flags::undef) {
8242  dnnl::convert_to_c(aprop_kind),
8243  dnnl::convert_to_c(direction), &src_layer_desc.data,
8244  &src_iter_desc.data, &src_iter_c_desc.data,
8245  &weights_layer_desc.data, &weights_iter_desc.data,
8246  &weights_peephole_desc.data, &bias_desc.data,
8247  &dst_layer_desc.data, &dst_iter_desc.data,
8248  &dst_iter_c_desc.data, dnnl::convert_to_c(flags)),
8249  "could not create a descriptor for an LSTM forward "
8250  "propagation primitive");
8251  }
8252 
8289  desc(prop_kind aprop_kind, rnn_direction direction,
8290  const memory::desc &src_layer_desc,
8291  const memory::desc &src_iter_desc,
8292  const memory::desc &src_iter_c_desc,
8293  const memory::desc &weights_layer_desc,
8294  const memory::desc &weights_iter_desc,
8295  const memory::desc &bias_desc,
8296  const memory::desc &dst_layer_desc,
8297  const memory::desc &dst_iter_desc,
8298  const memory::desc &dst_iter_c_desc,
8299  rnn_flags flags = rnn_flags::undef) {
8302  dnnl::convert_to_c(aprop_kind),
8303  dnnl::convert_to_c(direction), &src_layer_desc.data,
8304  &src_iter_desc.data, &src_iter_c_desc.data,
8305  &weights_layer_desc.data, &weights_iter_desc.data,
8306  &bias_desc.data, &dst_layer_desc.data,
8307  &dst_iter_desc.data, &dst_iter_c_desc.data,
8308  dnnl::convert_to_c(flags)),
8309  "could not create a descriptor for an LSTM forward "
8310  "propagation primitive");
8311  }
8312  };
8313 
8317  primitive_desc() = default;
8318 
8328  primitive_desc(const desc &adesc, const engine &aengine,
8329  bool allow_empty = false)
8331  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
8332 
8343  primitive_desc(const desc &adesc, const primitive_attr &attr,
8344  const engine &aengine, bool allow_empty = false)
8346  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
8347 
8358 
8361  return rnn_base::src_layer_desc();
8362  }
8363 
8366 
8369  return rnn_base::src_iter_c_desc();
8370  }
8371 
8375  }
8376 
8379  return rnn_base::weights_iter_desc();
8380  }
8381 
8385  }
8386 
8390  }
8391 
8394 
8397  return rnn_base::dst_layer_desc();
8398  }
8399 
8402 
8405  return rnn_base::dst_iter_c_desc();
8406  }
8407 
8410  return rnn_base::workspace_desc();
8411  }
8412  };
8413 
8415  lstm_forward() = default;
8416 
8421 };
8422 
8424 struct lstm_backward : public primitive {
8426  struct desc {
8427  dnnl_rnn_desc_t data;
8428 
8504  desc(prop_kind aprop_kind, rnn_direction direction,
8505  const memory::desc &src_layer_desc,
8506  const memory::desc &src_iter_desc,
8507  const memory::desc &src_iter_c_desc,
8508  const memory::desc &weights_layer_desc,
8509  const memory::desc &weights_iter_desc,
8510  const memory::desc &weights_peephole_desc,
8511  const memory::desc &weights_projection_desc,
8512  const memory::desc &bias_desc,
8513  const memory::desc &dst_layer_desc,
8514  const memory::desc &dst_iter_desc,
8515  const memory::desc &dst_iter_c_desc,
8516  const memory::desc &diff_src_layer_desc,
8517  const memory::desc &diff_src_iter_desc,
8518  const memory::desc &diff_src_iter_c_desc,
8519  const memory::desc &diff_weights_layer_desc,
8520  const memory::desc &diff_weights_iter_desc,
8521  const memory::desc &diff_weights_peephole_desc,
8522  const memory::desc &diff_weights_projection_desc,
8523  const memory::desc &diff_bias_desc,
8524  const memory::desc &diff_dst_layer_desc,
8525  const memory::desc &diff_dst_iter_desc,
8526  const memory::desc &diff_dst_iter_c_desc,
8527  rnn_flags flags = rnn_flags::undef) {
8530  dnnl::convert_to_c(aprop_kind),
8531  dnnl::convert_to_c(direction), &src_layer_desc.data,
8532  &src_iter_desc.data, &src_iter_c_desc.data,
8533  &weights_layer_desc.data, &weights_iter_desc.data,
8534  &weights_peephole_desc.data,
8535  &weights_projection_desc.data, &bias_desc.data,
8536  &dst_layer_desc.data, &dst_iter_desc.data,
8537  &dst_iter_c_desc.data, &diff_src_layer_desc.data,
8538  &diff_src_iter_desc.data,
8539  &diff_src_iter_c_desc.data,
8540  &diff_weights_layer_desc.data,
8541  &diff_weights_iter_desc.data,
8542  &diff_weights_peephole_desc.data,
8543  &diff_weights_projection_desc.data,
8544  &diff_bias_desc.data, &diff_dst_layer_desc.data,
8545  &diff_dst_iter_desc.data,
8546  &diff_dst_iter_c_desc.data,
8547  dnnl::convert_to_c(flags)),
8548  "could not create a descriptor for an LSTM backward "
8549  "propagation primitive");
8550  }
8551 
8616  desc(prop_kind aprop_kind, rnn_direction direction,
8617  const memory::desc &src_layer_desc,
8618  const memory::desc &src_iter_desc,
8619  const memory::desc &src_iter_c_desc,
8620  const memory::desc &weights_layer_desc,
8621  const memory::desc &weights_iter_desc,
8622  const memory::desc &weights_peephole_desc,
8623  const memory::desc &bias_desc,
8624  const memory::desc &dst_layer_desc,
8625  const memory::desc &dst_iter_desc,
8626  const memory::desc &dst_iter_c_desc,
8627  const memory::desc &diff_src_layer_desc,
8628  const memory::desc &diff_src_iter_desc,
8629  const memory::desc &diff_src_iter_c_desc,
8630  const memory::desc &diff_weights_layer_desc,
8631  const memory::desc &diff_weights_iter_desc,
8632  const memory::desc &diff_weights_peephole_desc,
8633  const memory::desc &diff_bias_desc,
8634  const memory::desc &diff_dst_layer_desc,
8635  const memory::desc &diff_dst_iter_desc,
8636  const memory::desc &diff_dst_iter_c_desc,
8637  rnn_flags flags = rnn_flags::undef) {
8640  dnnl::convert_to_c(aprop_kind),
8641  dnnl::convert_to_c(direction), &src_layer_desc.data,
8642  &src_iter_desc.data, &src_iter_c_desc.data,
8643  &weights_layer_desc.data, &weights_iter_desc.data,
8644  &weights_peephole_desc.data, &bias_desc.data,
8645  &dst_layer_desc.data, &dst_iter_desc.data,
8646  &dst_iter_c_desc.data, &diff_src_layer_desc.data,
8647  &diff_src_iter_desc.data,
8648  &diff_src_iter_c_desc.data,
8649  &diff_weights_layer_desc.data,
8650  &diff_weights_iter_desc.data,
8651  &diff_weights_peephole_desc.data,
8652  &diff_bias_desc.data, &diff_dst_layer_desc.data,
8653  &diff_dst_iter_desc.data,
8654  &diff_dst_iter_c_desc.data,
8655  dnnl::convert_to_c(flags)),
8656  "could not create a descriptor for an LSTM backward "
8657  "propagation primitive");
8658  }
8659 
8715  desc(prop_kind aprop_kind, rnn_direction direction,
8716  const memory::desc &src_layer_desc,
8717  const memory::desc &src_iter_desc,
8718  const memory::desc &src_iter_c_desc,
8719  const memory::desc &weights_layer_desc,
8720  const memory::desc &weights_iter_desc,
8721  const memory::desc &bias_desc,
8722  const memory::desc &dst_layer_desc,
8723  const memory::desc &dst_iter_desc,
8724  const memory::desc &dst_iter_c_desc,
8725  const memory::desc &diff_src_layer_desc,
8726  const memory::desc &diff_src_iter_desc,
8727  const memory::desc &diff_src_iter_c_desc,
8728  const memory::desc &diff_weights_layer_desc,
8729  const memory::desc &diff_weights_iter_desc,
8730  const memory::desc &diff_bias_desc,
8731  const memory::desc &diff_dst_layer_desc,
8732  const memory::desc &diff_dst_iter_desc,
8733  const memory::desc &diff_dst_iter_c_desc,
8734  rnn_flags flags = rnn_flags::undef) {
8737  dnnl::convert_to_c(aprop_kind),
8738  dnnl::convert_to_c(direction), &src_layer_desc.data,
8739  &src_iter_desc.data, &src_iter_c_desc.data,
8740  &weights_layer_desc.data, &weights_iter_desc.data,
8741  &bias_desc.data, &dst_layer_desc.data,
8742  &dst_iter_desc.data, &dst_iter_c_desc.data,
8743  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
8744  &diff_src_iter_c_desc.data,
8745  &diff_weights_layer_desc.data,
8746  &diff_weights_iter_desc.data, &diff_bias_desc.data,
8747  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
8748  &diff_dst_iter_c_desc.data,
8749  dnnl::convert_to_c(flags)),
8750  "could not create a descriptor for an LSTM backward "
8751  "propagation primitive");
8752  }
8753  };
8754 
8758  primitive_desc() = default;
8759 
// Constructs the primitive descriptor by delegating to the
// rnn_primitive_desc_base constructor. nullptr is passed where the
// attribute-taking overload passes &attr; hint_fwd_pd.get(), taken from an
// LSTM forward primitive descriptor, and allow_empty (default false) are
// forwarded unchanged.
8772  primitive_desc(const desc &adesc, const engine &aengine,
8773  const lstm_forward::primitive_desc &hint_fwd_pd,
8774  bool allow_empty = false)
8775  : rnn_primitive_desc_base(&adesc.data, nullptr, aengine,
8776  hint_fwd_pd.get(), allow_empty) {}
8777 
// Attribute-taking overload: delegates to the rnn_primitive_desc_base
// constructor with &adesc.data, the address of attr, the engine,
// hint_fwd_pd.get() from an LSTM forward primitive descriptor, and
// allow_empty (default false).
8791  primitive_desc(const desc &adesc, const primitive_attr &attr,
8792  const engine &aengine,
8793  const lstm_forward::primitive_desc &hint_fwd_pd,
8794  bool allow_empty = false)
8795  : rnn_primitive_desc_base(&adesc.data, &attr, aengine,
8796  hint_fwd_pd.get(), allow_empty) {}
8797 
8807 
8810  return rnn_base::src_layer_desc();
8811  }
8812 
8815 
8818  return rnn_base::src_iter_c_desc();
8819  }
8820 
8824  }
8825 
8828  return rnn_base::weights_iter_desc();
8829  }
8830 
8834  }
8835 
8839  }
8840 
8843 
8846  return rnn_base::dst_layer_desc();
8847  }
8848 
8851 
8854  return rnn_base::dst_iter_c_desc();
8855  }
8856 
8859  return rnn_base::workspace_desc();
8860  }
8861 
8865  }
8866 
8870  }
8871 
8875  }
8876 
8880  }
8881 
8885  }
8886 
8890  }
8891 
8895  }
8896 
8899  return rnn_base::diff_bias_desc();
8900  }
8901 
8905  }
8906 
8910  }
8911 
8915  }
8916  };
8917 
8919  lstm_backward() = default;
8920 
8925 };
8926 
8928 struct gru_forward : public primitive {
8930  struct desc {
8931  dnnl_rnn_desc_t data;
8932 
8965  desc(prop_kind aprop_kind, rnn_direction direction,
8966  const memory::desc &src_layer_desc,
8967  const memory::desc &src_iter_desc,
8968  const memory::desc &weights_layer_desc,
8969  const memory::desc &weights_iter_desc,
8970  const memory::desc &bias_desc,
8971  const memory::desc &dst_layer_desc,
8972  const memory::desc &dst_iter_desc,
8973  rnn_flags flags = rnn_flags::undef) {
8976  dnnl::convert_to_c(aprop_kind),
8977  dnnl::convert_to_c(direction), &src_layer_desc.data,
8978  &src_iter_desc.data, &weights_layer_desc.data,
8979  &weights_iter_desc.data, &bias_desc.data,
8980  &dst_layer_desc.data, &dst_iter_desc.data,
8981  dnnl::convert_to_c(flags)),
8982  "could not create a descriptor for a GRU forward "
8983  "propagation primitive");
8984  }
8985  };
8986 
8990  primitive_desc() = default;
8991 
9001  primitive_desc(const desc &adesc, const engine &aengine,
9002  bool allow_empty = false)
9004  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
9005 
9016  primitive_desc(const desc &adesc, const primitive_attr &attr,
9017  const engine &aengine, bool allow_empty = false)
9019  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
9020 
9030  dnnl::algorithm::vanilla_gru) {}
9031 
9034  return rnn_base::src_layer_desc();
9035  }
9036 
9039 
9043  }
9044 
9047  return rnn_base::weights_iter_desc();
9048  }
9049 
9052 
9055  return rnn_base::dst_layer_desc();
9056  }
9057 
9060 
9063  return rnn_base::workspace_desc();
9064  }
9065  };
9066 
9068  gru_forward() = default;
9069 
9074 };
9075 
9077 struct gru_backward : public primitive {
9079  struct desc {
9080  dnnl_rnn_desc_t data;
9081 
9126  desc(prop_kind aprop_kind, rnn_direction direction,
9127  const memory::desc &src_layer_desc,
9128  const memory::desc &src_iter_desc,
9129  const memory::desc &weights_layer_desc,
9130  const memory::desc &weights_iter_desc,
9131  const memory::desc &bias_desc,
9132  const memory::desc &dst_layer_desc,
9133  const memory::desc &dst_iter_desc,
9134  const memory::desc &diff_src_layer_desc,
9135  const memory::desc &diff_src_iter_desc,
9136  const memory::desc &diff_weights_layer_desc,
9137  const memory::desc &diff_weights_iter_desc,
9138  const memory::desc &diff_bias_desc,
9139  const memory::desc &diff_dst_layer_desc,
9140  const memory::desc &diff_dst_iter_desc,
9141  rnn_flags flags = rnn_flags::undef) {
9144  dnnl::convert_to_c(aprop_kind),
9145  dnnl::convert_to_c(direction), &src_layer_desc.data,
9146  &src_iter_desc.data, &weights_layer_desc.data,
9147  &weights_iter_desc.data, &bias_desc.data,
9148  &dst_layer_desc.data, &dst_iter_desc.data,
9149  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
9150  &diff_weights_layer_desc.data,
9151  &diff_weights_iter_desc.data, &diff_bias_desc.data,
9152  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
9153  dnnl::convert_to_c(flags)),
9154  "could not create a descriptor for a GRU backward "
9155  "propagation primitive");
9156  }
9157  };
9158 
9162  primitive_desc() = default;
9163 
// Constructs the primitive descriptor by delegating to the
// rnn_primitive_desc_base constructor. nullptr is passed where the
// attribute-taking overload passes &attr; hint_fwd_pd.get(), taken from a
// GRU forward primitive descriptor, and allow_empty (default false) are
// forwarded unchanged.
9176  primitive_desc(const desc &adesc, const engine &aengine,
9177  const gru_forward::primitive_desc &hint_fwd_pd,
9178  bool allow_empty = false)
9179  : rnn_primitive_desc_base(&adesc.data, nullptr, aengine,
9180  hint_fwd_pd.get(), allow_empty) {}
9181 
// Attribute-taking overload: delegates to the rnn_primitive_desc_base
// constructor with &adesc.data, the address of attr, the engine,
// hint_fwd_pd.get() from a GRU forward primitive descriptor, and
// allow_empty (default false).
9195  primitive_desc(const desc &adesc, const primitive_attr &attr,
9196  const engine &aengine,
9197  const gru_forward::primitive_desc &hint_fwd_pd,
9198  bool allow_empty = false)
9199  : rnn_primitive_desc_base(&adesc.data, &attr, aengine,
9200  hint_fwd_pd.get(), allow_empty) {}
9201 
9210  dnnl::algorithm::vanilla_gru) {}
9211 
9214  return rnn_base::src_layer_desc();
9215  }
9216 
9219 
9223  }
9224 
9227  return rnn_base::weights_iter_desc();
9228  }
9229 
9232 
9235  return rnn_base::dst_layer_desc();
9236  }
9237 
9240 
9243  return rnn_base::workspace_desc();
9244  }
9245 
9249  }
9250 
9254  }
9255 
9259  }
9260 
9264  }
9265 
9268  return rnn_base::diff_bias_desc();
9269  }
9270 
9274  }
9275 
9279  }
9280  };
9281 
9283  gru_backward() = default;
9284 
9289 };
9290 
9292 struct lbr_gru_forward : public primitive {
9294  struct desc {
9295  dnnl_rnn_desc_t data;
9296 
9330  desc(prop_kind aprop_kind, rnn_direction direction,
9331  const memory::desc &src_layer_desc,
9332  const memory::desc &src_iter_desc,
9333  const memory::desc &weights_layer_desc,
9334  const memory::desc &weights_iter_desc,
9335  const memory::desc &bias_desc,
9336  const memory::desc &dst_layer_desc,
9337  const memory::desc &dst_iter_desc,
9338  rnn_flags flags = rnn_flags::undef) {
9341  dnnl::convert_to_c(aprop_kind),
9342  dnnl::convert_to_c(direction), &src_layer_desc.data,
9343  &src_iter_desc.data, &weights_layer_desc.data,
9344  &weights_iter_desc.data, &bias_desc.data,
9345  &dst_layer_desc.data, &dst_iter_desc.data,
9346  dnnl::convert_to_c(flags)),
9347  "could not create a descriptor for an LBR GRU forward "
9348  "propagation primitive");
9349  }
9350  };
9351 
9355  primitive_desc() = default;
9356 
9367  primitive_desc(const desc &adesc, const engine &aengine,
9368  bool allow_empty = false)
9370  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
9371 
9383  primitive_desc(const desc &adesc, const primitive_attr &attr,
9384  const engine &aengine, bool allow_empty = false)
9386  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
9387 
9397  dnnl::algorithm::lbr_gru) {}
9398 
9401  return rnn_base::src_layer_desc();
9402  }
9403 
9406 
9410  }
9411 
9414  return rnn_base::weights_iter_desc();
9415  }
9416 
9419 
9422  return rnn_base::dst_layer_desc();
9423  }
9424 
9427 
9430  return rnn_base::workspace_desc();
9431  }
9432  };
9433 
9435  lbr_gru_forward() = default;
9436 
9441 };
9442 
9444 struct lbr_gru_backward : public primitive {
9446  struct desc {
9447  dnnl_rnn_desc_t data;
9448 
9494  desc(prop_kind aprop_kind, rnn_direction direction,
9495  const memory::desc &src_layer_desc,
9496  const memory::desc &src_iter_desc,
9497  const memory::desc &weights_layer_desc,
9498  const memory::desc &weights_iter_desc,
9499  const memory::desc &bias_desc,
9500  const memory::desc &dst_layer_desc,
9501  const memory::desc &dst_iter_desc,
9502  const memory::desc &diff_src_layer_desc,
9503  const memory::desc &diff_src_iter_desc,
9504  const memory::desc &diff_weights_layer_desc,
9505  const memory::desc &diff_weights_iter_desc,
9506  const memory::desc &diff_bias_desc,
9507  const memory::desc &diff_dst_layer_desc,
9508  const memory::desc &diff_dst_iter_desc,
9509  rnn_flags flags = rnn_flags::undef) {
9512  dnnl::convert_to_c(aprop_kind),
9513  dnnl::convert_to_c(direction), &src_layer_desc.data,
9514  &src_iter_desc.data, &weights_layer_desc.data,
9515  &weights_iter_desc.data, &bias_desc.data,
9516  &dst_layer_desc.data, &dst_iter_desc.data,
9517  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
9518  &diff_weights_layer_desc.data,
9519  &diff_weights_iter_desc.data, &diff_bias_desc.data,
9520  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
9521  dnnl::convert_to_c(flags)),
9522  "could not create a descriptor for an LBR GRU backward "
9523  "propagation primitive");
9524  }
9525  };
9526 
9530  primitive_desc() = default;
9531 
9545  primitive_desc(const desc &adesc, const engine &aengine,
9546  const lbr_gru_forward::primitive_desc &hint_fwd_pd,
9547  bool allow_empty = false)
9548  : rnn_primitive_desc_base(&adesc.data, nullptr, aengine,
9549  hint_fwd_pd.get(), allow_empty) {}
9550 
9565  primitive_desc(const desc &adesc, const primitive_attr &attr,
9566  const engine &aengine,
9567  const lbr_gru_forward::primitive_desc &hint_fwd_pd,
9568  bool allow_empty = false)
9569  : rnn_primitive_desc_base(&adesc.data, &attr, aengine,
9570  hint_fwd_pd.get(), allow_empty) {}
9571 
9581 
9584  return rnn_base::src_layer_desc();
9585  }
9586 
9589 
9593  }
9594 
9597  return rnn_base::weights_iter_desc();
9598  }
9599 
9602 
9605  return rnn_base::dst_layer_desc();
9606  }
9607 
9610 
9613  return rnn_base::workspace_desc();
9614  }
9615 
9619  }
9620 
9624  }
9625 
9629  }
9630 
9634  }
9635 
9638  return rnn_base::diff_bias_desc();
9639  }
9640 
9644  }
9645 
9649  }
9650  };
9651 
9653  lbr_gru_backward() = default;
9654 
9659 };
9660 
9662 
9670 
9672 struct shuffle_forward : public primitive {
9674  struct desc {
9675  dnnl_shuffle_desc_t data;
9676 
9686  desc(prop_kind aprop_kind, const memory::desc &data_desc, int axis,
9687  int group_size) {
9689  dnnl::convert_to_c(aprop_kind),
9690  &data_desc.data, axis, group_size),
9691  "could not create a descriptor for a shuffle forward "
9692  "propagation primitive");
9693  }
9694  };
9695 
9699  primitive_desc() = default;
9700 
9712  primitive_desc(const desc &adesc, const engine &aengine,
9713  const primitive_attr &attr = primitive_attr(),
9714  bool allow_empty = false)
9715  : dnnl::primitive_desc(
9716  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
9717 
9725  : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
9728 
9730  memory::desc src_desc() const { return base::src_desc(0); }
9731 
9733  memory::desc dst_desc() const { return base::dst_desc(0); }
9734  };
9735 
9737  shuffle_forward() = default;
9738 
9743 };
9744 
9746 struct shuffle_backward : public primitive {
9749  struct desc {
9750  dnnl_shuffle_desc_t data;
9751 
9759  desc(const memory::desc &diff_data_desc, int axis, int group_size) {
9761  &diff_data_desc.data, axis, group_size),
9762  "could not create a descriptor for a shuffle backward "
9763  "propagation primitive");
9764  }
9765  };
9766 
9770  primitive_desc() = default;
9771 
9786  primitive_desc(const desc &adesc, const engine &aengine,
9787  const shuffle_forward::primitive_desc &hint_fwd_pd,
9788  const primitive_attr &attr = primitive_attr(),
9789  bool allow_empty = false)
9790  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
9791  hint_fwd_pd.get(), allow_empty) {}
9792 
9800  : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
9802 
9805 
9808  };
9809 
9811  shuffle_backward() = default;
9812 
9817 };
9818 
9820 
9828 
9830 struct binary : public primitive {
9832  struct desc {
9835 
9837  desc() = default;
9838 
9846  desc(algorithm aalgorithm, const memory::desc &src0,
9847  const memory::desc &src1, const memory::desc &dst) {
9850  &src0.data, &src1.data, &dst.data),
9851  "could not create a descriptor for a binary operation "
9852  "primitive");
9853  }
9854  };
9855 
9859  primitive_desc() = default;
9860 
9870  primitive_desc(const desc &adesc, const engine &aengine,
9871  bool allow_empty = false)
9872  : dnnl::primitive_desc(
9873  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
9874 
9885  primitive_desc(const desc &adesc, const primitive_attr &attr,
9886  const engine &aengine, bool allow_empty = false)
9887  : dnnl::primitive_desc(
9888  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
9889 
9896 
9898  memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
9899 
9901  memory::desc src0_desc() const { return base::src_desc(0); }
9902 
9904  memory::desc src1_desc() const { return base::src_desc(1); }
9905 
9907  memory::desc dst_desc() const { return base::dst_desc(0); }
9908  };
9909 
9911  binary() = default;
9912 
9916  binary(const primitive_desc &pd) : primitive(pd) {}
9917 };
9918 
9920 
9930 
9932 struct matmul : public primitive {
9934  struct desc {
9935  dnnl_matmul_desc_t data;
9936 
9942  desc(const memory::desc &src_desc, const memory::desc &weights_desc,
9943  const memory::desc &dst_desc) {
9945  dnnl_matmul_desc_init(&data, &src_desc.data,
9946  &weights_desc.data, nullptr, &dst_desc.data),
9947  "could not create a descriptor for a matmul primitive");
9948  }
9949 
9956  desc(const memory::desc &src_desc, const memory::desc &weights_desc,
9957  const memory::desc &bias_desc, const memory::desc &dst_desc) {
9958  error::wrap_c_api(dnnl_matmul_desc_init(&data, &src_desc.data,
9959  &weights_desc.data, &bias_desc.data,
9960  &dst_desc.data),
9961  "could not create a descriptor for a matmul primitive");
9962  }
9963  };
9964 
9968  primitive_desc() = default;
9969 
9978  primitive_desc(const desc &adesc, const engine &aengine,
9979  bool allow_empty = false)
9980  : dnnl::primitive_desc(
9981  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
9982 
9992  primitive_desc(const desc &adesc, const primitive_attr &attr,
9993  const engine &aengine, bool allow_empty = false)
9994  : dnnl::primitive_desc(
9995  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
9996 
10003 
10006 
10009  return query_md(query::weights_md, 0);
10010  }
10011 
10014  return query_md(query::weights_md, 1);
10015  }
10016 
10019  };
10020 
10022  matmul() = default;
10023 
10026  matmul(const primitive_desc &pd) : primitive(pd) {}
10027 };
10028 
10030 
10040 
10044  struct desc {
10046 
10062  desc(prop_kind aprop_kind, algorithm aalgorithm,
10063  const memory::desc &src_desc, const memory::desc &dst_desc) {
10065  dnnl::convert_to_c(aprop_kind),
10066  convert_to_c(aalgorithm), nullptr,
10067  &src_desc.data, &dst_desc.data),
10068  "could not create a resampling forward descriptor");
10069  }
10070 
10082  desc(prop_kind aprop_kind, algorithm aalgorithm,
10083  const std::vector<float> &factors,
10084  const memory::desc &src_desc) {
10085  memory::validate_dims(factors, src_desc.data.ndims - 2);
10087  dnnl::convert_to_c(aprop_kind),
10088  convert_to_c(aalgorithm), &factors[0],
10089  &src_desc.data, nullptr),
10090  "could not create a resampling forward descriptor");
10091  }
10092 
10109  desc(prop_kind aprop_kind, algorithm aalgorithm,
10110  const std::vector<float> &factors, const memory::desc &src_desc,
10111  const memory::desc &dst_desc) {
10112  if (!factors.empty())
10113  memory::validate_dims(factors, src_desc.data.ndims - 2);
10115  dnnl::convert_to_c(aprop_kind),
10116  convert_to_c(aalgorithm), factors.data(),
10117  &src_desc.data, &dst_desc.data),
10118  "could not create a resampling forward descriptor");
10119  }
10120  };
10121 
10125  primitive_desc() = default;
10126 
10137  primitive_desc(const desc &adesc, const engine &aengine,
10138  bool allow_empty = false)
10139  : dnnl::primitive_desc(
10140  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
10141 
10153  primitive_desc(const desc &adesc, const primitive_attr &attr,
10154  const engine &aengine, bool allow_empty = false)
10155  : dnnl::primitive_desc(
10156  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
10157 
10165  : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
10168 
10170  memory::desc src_desc() const { return base::src_desc(0); }
10171 
10173  memory::desc dst_desc() const { return base::dst_desc(0); }
10174  };
10175 
10177  resampling_forward() = default;
10178 
10183 };
10184 
10188  struct desc {
10190 
10199  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
10200  const memory::desc &diff_dst_desc) {
10202  convert_to_c(aalgorithm), nullptr,
10203  &diff_src_desc.data, &diff_dst_desc.data),
10204  "could not create a resampling backward data descriptor");
10205  }
10206 
10216  desc(algorithm aalgorithm, const std::vector<float> &factors,
10217  const memory::desc &diff_src_desc,
10218  const memory::desc &diff_dst_desc) {
10219  if (!factors.empty())
10220  memory::validate_dims(factors, diff_src_desc.data.ndims - 2);
10222  convert_to_c(aalgorithm), factors.data(),
10223  &diff_src_desc.data, &diff_dst_desc.data),
10224  "could not create a resampling backward data descriptor");
10225  }
10226  };
10227 
10231  primitive_desc() = default;
10232 
10246  primitive_desc(const desc &adesc, const engine &aengine,
10247  const resampling_forward::primitive_desc &hint_fwd_pd,
10248  bool allow_empty = false)
10249  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
10250  hint_fwd_pd.get(), allow_empty) {}
10251 
10266  primitive_desc(const desc &adesc, const primitive_attr &attr,
10267  const engine &aengine,
10268  const resampling_forward::primitive_desc &hint_fwd_pd,
10269  bool allow_empty = false)
10270  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
10271  hint_fwd_pd.get(), allow_empty) {}
10272 
10280  : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
10282 
10285 
10288  };
10289 
10291  resampling_backward() = default;
10292 
10297 };
10298 
10300 
10308 
10312  struct desc {
10314 
10341  desc(prop_kind aprop_kind, algorithm aalgorithm,
10342  const memory::desc &src_desc, const memory::desc &dst_desc,
10343  const memory::dims &strides, const memory::dims &kernel,
10344  const memory::dims &dilation, const memory::dims &padding_l,
10345  const memory::dims &padding_r) {
10346  memory::validate_dims(strides, src_desc.data.ndims - 2);
10347  memory::validate_dims(kernel, src_desc.data.ndims - 2);
10348  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
10349  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
10350  memory::validate_dims(dilation, src_desc.data.ndims - 2);
10353  dnnl::convert_to_c(aprop_kind),
10354  convert_to_c(aalgorithm), &src_desc.data,
10355  &dst_desc.data, &strides[0], &kernel[0],
10356  &dilation[0], &padding_l[0], &padding_r[0]),
10357  "could not create a descriptor for a pooling forward "
10358  "propagation primitive");
10359  }
10360  };
10361 
10365  primitive_desc() = default;
10366 
10377  primitive_desc(const desc &adesc, const engine &aengine,
10378  bool allow_empty = false)
10379  : dnnl::primitive_desc(
10380  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
10381 
10393  primitive_desc(const desc &adesc, const primitive_attr &attr,
10394  const engine &aengine, bool allow_empty = false)
10395  : dnnl::primitive_desc(
10396  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
10397 
10406  : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling_v2,
10409 
10411  memory::desc src_desc() const { return base::src_desc(0); }
10412 
10414  memory::desc dst_desc() const { return base::dst_desc(0); }
10415 
10418  };
10419 
10421  pooling_v2_forward() = default;
10422 
10428 };
10429 
10433  struct desc {
10435 
10459  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
10460  const memory::desc &diff_dst_desc, const memory::dims &strides,
10461  const memory::dims &kernel, const memory::dims &dilation,
10462  const memory::dims &padding_l, const memory::dims &padding_r) {
10463  memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
10464  memory::validate_dims(kernel, diff_src_desc.data.ndims - 2);
10465  memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
10466  memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
10467  memory::validate_dims(dilation, diff_src_desc.data.ndims - 2);
10470  convert_to_c(aalgorithm), &diff_src_desc.data,
10471  &diff_dst_desc.data, &strides[0], &kernel[0],
10472  &dilation[0], &padding_l[0], &padding_r[0]),
10473  "could not create a descriptor for a pooling backward "
10474  "propagation primitive");
10475  }
10476  };
10477 
10482  primitive_desc() = default;
10483 
10497  primitive_desc(const desc &adesc, const engine &aengine,
10498  const pooling_v2_forward::primitive_desc &hint_fwd_pd,
10499  bool allow_empty = false)
10500  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
10501  hint_fwd_pd.get(), allow_empty) {}
10502 
10517  primitive_desc(const desc &adesc, const primitive_attr &attr,
10518  const engine &aengine,
10519  const pooling_v2_forward::primitive_desc &hint_fwd_pd,
10520  bool allow_empty = false)
10521  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
10522  hint_fwd_pd.get(), allow_empty) {}
10523 
10532  : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling_v2,
10534 
10537 
10540 
10543  };
10544 
10546  pooling_v2_backward() = default;
10547 
10553 };
10554 
10556 
10565 
10567 struct prelu_forward : public primitive {
10569  struct desc {
10570  dnnl_prelu_desc_t data;
10571 
10580  desc(prop_kind aprop_kind, const memory::desc &data_desc,
10581  const memory::desc &weight_desc) {
10583  dnnl::convert_to_c(aprop_kind),
10584  &data_desc.data, &weight_desc.data),
10585  "could not create a descriptor for a prelu forward "
10586  "propagation primitive");
10587  }
10588  };
10589 
10593  primitive_desc() = default;
10594 
10605  primitive_desc(const desc &adesc, const engine &aengine,
10606  bool allow_empty = false)
10607  : dnnl::primitive_desc(
10608  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
10609 
10621  primitive_desc(const desc &adesc, const primitive_attr &attr,
10622  const engine &aengine, bool allow_empty = false)
10623  : dnnl::primitive_desc(
10624  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
10625 
10633  : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu,
10636 
10638  memory::desc src_desc() const { return base::src_desc(0); }
10639 
10641  memory::desc dst_desc() const { return base::dst_desc(0); }
10642  };
10643 
10645  prelu_forward() = default;
10646 
10651 };
10652 
10654 struct prelu_backward : public primitive {
10656  struct desc {
10657  dnnl_prelu_desc_t data;
10658 
10667  desc(const memory::desc &data_desc, const memory::desc &weight_desc,
10668  const memory::desc &diff_data_desc,
10669  const memory::desc &diff_weights_desc) {
10671  dnnl_prelu_backward_desc_init(&data, &data_desc.data,
10672  &weight_desc.data, &diff_data_desc.data,
10673  &diff_weights_desc.data),
10674  "could not create a descriptor for a prelu backward "
10675  "propagation primitive");
10676  }
10677  };
10678 
10682  primitive_desc() = default;
10683 
10697  primitive_desc(const desc &adesc, const engine &aengine,
10698  const prelu_forward::primitive_desc &hint_fwd_pd,
10699  bool allow_empty = false)
10700  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
10701  hint_fwd_pd.get(), allow_empty) {}
10702 
10717  primitive_desc(const desc &adesc, const primitive_attr &attr,
10718  const engine &aengine,
10719  const prelu_forward::primitive_desc &hint_fwd_pd,
10720  bool allow_empty = false)
10721  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
10722  hint_fwd_pd.get(), allow_empty) {}
10723 
10731  : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu,
10733 
10735  memory::desc src_desc() const { return base::src_desc(0); }
10736 
10739 
10742  };
10743 
10745  prelu_backward() = default;
10746 
10751 };
10752 
10754 
10763 
10765 struct reduction : public primitive {
10767  struct desc {
10768  dnnl_reduction_desc_t data;
10769 
10771  desc() = default;
10772 
10790  desc(algorithm aalgorithm, const memory::desc &src_desc,
10791  const memory::desc &dst_desc, float p, float eps) {
10793  dnnl_reduction_desc_init(&data, convert_to_c(aalgorithm),
10794  &src_desc.data, &dst_desc.data, p, eps),
10795  "could not create a reduction descriptor");
10796  }
10797  };
10798 
10802  primitive_desc() = default;
10803 
10812  primitive_desc(const desc &adesc, const engine &aengine,
10813  bool allow_empty = false)
10814  : dnnl::primitive_desc(
10815  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
10816 
10826  primitive_desc(const desc &adesc, const primitive_attr &attr,
10827  const engine &aengine, bool allow_empty = false)
10828  : dnnl::primitive_desc(
10829  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
10830 
10837 
10839  memory::desc src_desc() const { return base::src_desc(0); }
10840 
10842  memory::desc dst_desc() const { return base::dst_desc(0); }
10843  };
10844 
10846  reduction() = default;
10847 
10850  reduction(const primitive_desc &pd) : primitive(pd) {}
10851 };
10852 
10854 
10856 
10862 
10865 
10867 enum class status {
10882 };
10883 
/// Forwards @p level to the C API function dnnl_set_verbose() and converts
/// the returned C status code to dnnl::status.
///
/// @param level Verbosity level, passed through unchanged.
/// @returns The status reported by dnnl_set_verbose().
10885 inline status set_verbose(int level) {
10886  return static_cast<status>(dnnl_set_verbose(level));
10887 }
10888 
/// Returns the library version information by calling dnnl_version().
///
/// @returns The pointer returned by dnnl_version(), unchanged.
/// NOTE(review): ownership/lifetime of the pointed-to structure is defined
/// by the C API and is not visible here -- confirm before freeing or caching.
10890 inline const version_t *version() {
10891  return dnnl_version();
10892 }
10893 
/// Forwards @p enable to the C API function dnnl_set_jit_dump() and converts
/// the returned C status code to dnnl::status.
///
/// @param enable Flag value, passed through unchanged.
/// @returns The status reported by dnnl_set_jit_dump().
10895 inline status set_jit_dump(int enable) {
10896  return static_cast<status>(dnnl_set_jit_dump(enable));
10897 }
10898 
/// Forwards @p flags to the C API function dnnl_set_jit_profiling_flags()
/// and converts the returned C status code to dnnl::status.
///
/// @param flags Profiling flags, passed through unchanged.
/// @returns The status reported by dnnl_set_jit_profiling_flags().
10900 inline status set_jit_profiling_flags(unsigned flags) {
10901  return static_cast<status>(dnnl_set_jit_profiling_flags(flags));
10902 }
10903 
/// Passes @p dir, as a NUL-terminated C string, to the C API function
/// dnnl_set_jit_profiling_jitdumpdir() and converts the returned C status
/// code to dnnl::status.
///
/// @param dir Directory path; only dir.c_str() is handed to the C API.
/// @returns The status reported by dnnl_set_jit_profiling_jitdumpdir().
10905 inline status set_jit_profiling_jitdumpdir(const std::string &dir) {
10906  return static_cast<status>(dnnl_set_jit_profiling_jitdumpdir(dir.c_str()));
10907 }
10908 
10910 enum class cpu_isa {
10933 };
10934 
10937  return static_cast<status>(
10938  dnnl_set_max_cpu_isa(static_cast<dnnl_cpu_isa_t>(isa)));
10939 }
10940 
10943  return static_cast<cpu_isa>(dnnl_get_effective_cpu_isa());
10944 }
10945 
10947 enum class cpu_isa_hints {
10952 };
10953 
10956  return static_cast<status>(dnnl_set_cpu_isa_hints(
10957  static_cast<dnnl_cpu_isa_hints_t>(isa_hints)));
10958 }
10959 
10962  return static_cast<cpu_isa_hints>(dnnl_get_cpu_isa_hints());
10963 }
10964 
10966 
10972 
10976  int result = 0;
10978  "could not get primitive cache capacity");
10979  return result;
10980 }
10981 
10983 inline void set_primitive_cache_capacity(int capacity) {
10985  "could not set primitive cache capacity");
10986 }
10987 
10989 
10996 
/// Thin C++ wrapper over dnnl_sgemm(): every argument is forwarded unchanged
/// and in the same order, and the returned C status code is converted to
/// dnnl::status.
///
/// A, B and C are single-precision (float) buffers; C is the only one that is
/// not const-qualified.
/// NOTE(review): presumably the usual GEMM update
/// C = alpha * op(A) * op(B) + beta * C, with transa/transb selecting the
/// transposition and lda/ldb/ldc the leading dimensions -- confirm against
/// the dnnl_sgemm() C API documentation.
///
/// @returns The status reported by dnnl_sgemm().
10998 inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N,
10999  dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
11000  const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc) {
11001  return static_cast<status>(dnnl_sgemm(
11002  transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc));
11003 }
11004 
/// Thin C++ wrapper over dnnl_gemm_u8s8s32(): every argument is forwarded
/// unchanged and in the same order, and the returned C status code is
/// converted to dnnl::status.
///
/// Per the signature, A holds uint8_t values, B holds int8_t values and the
/// result C holds int32_t values.
/// NOTE(review): ao/bo/co and offsetc are presumably the A/B/C offsets and
/// the C-offset mode of the C API -- confirm against the
/// dnnl_gemm_u8s8s32() documentation.
///
/// @returns The status reported by dnnl_gemm_u8s8s32().
11006 inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
11007  dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
11008  dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
11009  float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
11010  return static_cast<status>(dnnl_gemm_u8s8s32(transa, transb, offsetc, M, N,
11011  K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
11012 }
11013 
/// Thin C++ wrapper over dnnl_gemm_s8s8s32(): every argument is forwarded
/// unchanged and in the same order, and the returned C status code is
/// converted to dnnl::status.
///
/// Per the signature, A and B both hold int8_t values and the result C holds
/// int32_t values.
/// NOTE(review): ao/bo/co and offsetc are presumably the A/B/C offsets and
/// the C-offset mode of the C API -- confirm against the
/// dnnl_gemm_s8s8s32() documentation.
///
/// @returns The status reported by dnnl_gemm_s8s8s32().
11015 inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
11016  dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
11017  dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
11018  float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
11019  return static_cast<status>(dnnl_gemm_s8s8s32(transa, transb, offsetc, M, N,
11020  K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
11021 }
11022 
11024 
11025 // implementation section
11026 
11029  dnnl_primitive_t result;
11031  "could not create a primitive");
11032  reset(result);
11033 }
11034 
/// Constructs a primitive from a C++ primitive descriptor by delegating to
/// the constructor that takes the descriptor's underlying handle (pd.get()).
11035 inline primitive::primitive(const primitive_desc &pd) : primitive(pd.get()) {}
11036 
/// Executes the primitive on @p astream with the given arguments.
///
/// Converts @p args into a contiguous array of dnnl_exec_arg_t entries
/// (argument index plus the memory object's C handle) and passes that array
/// to dnnl_primitive_execute().
///
/// @param astream Stream whose C handle is passed to the C API.
/// @param args Map from argument index to the memory object bound to it.
/// NOTE(review): error::wrap_c_api is defined elsewhere in this file; it
/// presumably reports a non-success status using the supplied message --
/// confirm.
11037 inline void primitive::execute(const stream &astream,
11038  const std::unordered_map<int, memory> &args) const {
11039  std::vector<dnnl_exec_arg_t> c_args;
// One C argument per map entry; reserve up front to avoid regrowth.
11040  c_args.reserve(args.size());
// std::unordered_map iteration order is unspecified, so the entries are
// appended in no particular order.
// NOTE(review): the `true` passed to get() presumably permits an empty
// handle -- confirm against the handle accessor.
11041  for (const auto &a : args)
11042  c_args.push_back({a.first, a.second.get(true)});
11043 
// The argument count is handed to the C API as an int, hence the cast.
11044  error::wrap_c_api(dnnl_primitive_execute(get(), astream.get(),
11045  (int)c_args.size(), c_args.data()),
11046  "could not execute a primitive");
11047 }
11048 
11050 
11051 #undef DNNL_DEFINE_BITMASK_OPS
11052 
11053 } // namespace dnnl
11054 
11056 
// Namespace `oneapi` only introduces an alias, so that `oneapi::dnnl`
// names the same namespace as the global `::dnnl`.
11059 namespace oneapi {
11060 // Note: without this guard, doxygen warns of potentially recursive namespace
11061 #ifndef DOXYGEN_SHOULD_SKIP_THIS
// Namespace alias: oneapi::dnnl refers to ::dnnl.
11063 namespace dnnl = ::dnnl;
11064 #endif
11065 } // namespace oneapi
11066 
11068 
11069 #endif /* ONEAPI_DNNL_DNNL_HPP */
algorithm
Kinds of algorithms.
Definition: dnnl.hpp:470
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_data_qparams(dnnl_primitive_attr_t attr, const float scale, const float shift)
Set quantization scale and shift parameters for RNN data tensors.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum_v2(const_dnnl_post_ops_t post_ops, int index, float *scale, dnnl_data_type_t *data_type)
Returns the parameters of an accumulation (sum) post-op with a data type parameter.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw_k3s2p1(const_dnnl_post_ops_t post_ops, int index, dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type, dnnl_data_type_t *dst_data_type, dnnl_dim_t *count, int *mask, const float **scales)
Returns the parameters of a depthwise post-op with stride 2.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t mode)
Sets primitive attributes scratchpad mode.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_post_ops(const_dnnl_primitive_attr_t attr, const_dnnl_post_ops_t *post_ops)
Returns primitive attributes post-ops.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_qparams(const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask, const float **scales)
Returns the quantization scaling factors for RNN weights tensors.
dnnl_status_t DNNL_API dnnl_post_ops_append_dw_k3s1p1(dnnl_post_ops_t post_ops, dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type, dnnl_data_type_t dst_data_type, dnnl_dim_t count, int mask, const float *scales)
Appends a depthwise post-op convolution with stride 1.
dnnl_status_t DNNL_API dnnl_post_ops_destroy(dnnl_post_ops_t post_ops)
Destroys post-ops.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points(dnnl_primitive_attr_t attr, int arg, dnnl_dim_t count, int mask, const int32_t *zero_points)
Sets primitive attributes zero points for primitive operations for a given memory argument.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_post_ops(dnnl_primitive_attr_t attr, const_dnnl_post_ops_t post_ops)
Sets primitive attributes post-ops.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum(dnnl_post_ops_t post_ops, float scale)
Appends an accumulation (sum) to post-ops.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_qparams(dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, const float *scales)
Sets quantization scaling factors for RNN weights tensors.
dnnl_status_t DNNL_API dnnl_primitive_attr_destroy(dnnl_primitive_attr_t attr)
Destroys primitive attributes.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum_v2(dnnl_post_ops_t post_ops, float scale, dnnl_data_type_t data_type)
Appends an accumulation v2 (sum) to post-ops.
int DNNL_API dnnl_post_ops_len(const_dnnl_post_ops_t post_ops)
Returns the length of post-ops.
dnnl_status_t DNNL_API dnnl_post_ops_append_binary(dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src1_desc)
Appends a binary post-op.
dnnl_status_t DNNL_API dnnl_post_ops_append_dw_k3s2p1(dnnl_post_ops_t post_ops, dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type, dnnl_data_type_t dst_data_type, dnnl_dim_t count, int mask, const float *scales)
Appends a depthwise post-op convolution with stride 2.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_projection_qparams(const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask, const float **scales)
Returns the quantization scaling factors for RNN projection weights tensors.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_eltwise(const_dnnl_post_ops_t post_ops, int index, float *scale, dnnl_alg_kind_t *alg_kind, float *alpha, float *beta)
Returns the parameters of an elementwise post-op.
dnnl_status_t DNNL_API dnnl_post_ops_create(dnnl_post_ops_t *post_ops)
Creates empty post-ops sequence.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales(dnnl_primitive_attr_t attr, int arg, dnnl_dim_t count, int mask, const float *scales)
Sets primitive attributes scaling factors for primitive operations for a given memory argument.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_scratchpad_mode(const_dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t *mode)
Returns the primitive attributes scratchpad mode.
dnnl_status_t DNNL_API dnnl_primitive_attr_clone(dnnl_primitive_attr_t *attr, const_dnnl_primitive_attr_t existing_attr)
Clones primitive attributes.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw_k3s1p1(const_dnnl_post_ops_t post_ops, int index, dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type, dnnl_data_type_t *dst_data_type, dnnl_dim_t *count, int *mask, const float **scales)
Returns the parameters of a depthwise post-op with stride 1.
dnnl_primitive_kind_t DNNL_API dnnl_post_ops_get_kind(const_dnnl_post_ops_t post_ops, int index)
Returns the kind of a post-op entry.
scratchpad_mode
Scratchpad mode.
Definition: dnnl.hpp:401
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_projection_qparams(dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, const float *scales)
Sets quantization scaling factors for RNN projection weights tensors.
prop_kind
Propagation kind.
Definition: dnnl.hpp:435
dnnl_scratchpad_mode_t
Scratchpad mode.
Definition: dnnl_types.h:2316
dnnl_status_t DNNL_API dnnl_post_ops_append_eltwise(dnnl_post_ops_t post_ops, float scale, dnnl_alg_kind_t alg_kind, float alpha, float beta)
Appends an elementwise post-op.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_zero_points(const_dnnl_primitive_attr_t attr, int arg, dnnl_dim_t *count, int *mask, const int32_t **zero_points)
Returns count, correspondence zero point mask, and a pointer to a constant int32_t array of zero_poin...
dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum(const_dnnl_post_ops_t post_ops, int index, float *scale)
Returns the parameters of an accumulation (sum) post-op.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_binary(const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind, const dnnl_memory_desc_t **src1_desc)
Returns the parameters of a binary post-op.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_data_qparams(const_dnnl_primitive_attr_t attr, float *scale, float *shift)
Returns the quantization scale and shift parameters for RNN data tensors.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_scales(dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, const float *scales)
Sets output scaling factors correspondence mask and values.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_scales(dnnl_primitive_attr_t attr, int arg, dnnl_dim_t *count, int *mask, const float **scales)
Returns primitive attributes scaling factors correspondence mask and values for a given memory argume...
dnnl_status_t DNNL_API dnnl_primitive_attr_create(dnnl_primitive_attr_t *attr)
Creates empty (default) primitive attributes with all the parameters set to their default values.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_output_scales(const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask, const float **scales)
Returns primitive attributes output scaling factors correspondence mask and values.
@ eltwise_mish
Elementwise: mish.
@ resampling_linear
Linear (Bilinear, Trilinear) resampling method.
@ binary_mul
Binary mul.
@ resampling_nearest
Nearest Neighbor resampling method.
@ eltwise_elu_use_dst_for_bwd
Elementwise: exponential linear unit (ELU) (dst for backward)
@ eltwise_tanh_use_dst_for_bwd
Elementwise: hyperbolic tangent non-linearity (tanh) (dst for backward)
@ reduction_norm_lp_power_p_sum
Reduction using norm_lp_power_p_sum operation.
@ eltwise_linear
Elementwise: linear.
@ eltwise_clip_v2
Eltwise: clip version 2.
@ eltwise_soft_relu
Elementwise: soft_relu.
@ vanilla_gru
GRU cell.
@ eltwise_logistic
Elementwise: logistic.
@ binary_div
Binary div.
@ eltwise_clip
Elementwise: clip.
@ binary_ge
Binary greater than or equal.
@ eltwise_abs
Elementwise: abs.
@ eltwise_pow
Elementwise: pow.
@ eltwise_tanh
Elementwise: hyperbolic tangent non-linearity (tanh)
@ eltwise_logistic_use_dst_for_bwd
Elementwise: logistic (dst for backward)
@ eltwise_bounded_relu
Elementwise: bounded_relu.
@ reduction_norm_lp_power_p_max
Reduction using norm_lp_power_p_max operation.
@ reduction_max
Reduction using max operation.
@ eltwise_clip_v2_use_dst_for_bwd
Elementwise: clip version 2 (dst for backward)
@ eltwise_square
Elementwise: square.
@ binary_max
Binary max.
@ convolution_direct
Direct convolution.
@ eltwise_exp
Elementwise: exponent.
@ binary_gt
Binary greater than.
@ reduction_norm_lp_max
Reduction using norm_lp_max operation.
@ eltwise_elu
Elementwise: exponential linear unit (ELU)
@ convolution_winograd
Winograd convolution.
@ vanilla_lstm
LSTM cell.
@ deconvolution_direct
Direct deconvolution.
@ pooling_avg
Average pooling exclude padding, alias for dnnl::algorithm::pooling_avg_exclude_padding.
@ lbr_gru
GRU cell with linear before reset.
@ binary_eq
Binary equal.
@ pooling_avg_exclude_padding
Average pooling exclude padding.
@ eltwise_gelu
Elementwise: gelu, alias for dnnl::algorithm::eltwise_gelu_tanh.
@ eltwise_sqrt
Elementwise: square root.
@ pooling_max
Max pooling.
@ reduction_min
Reduction using min operation.
@ eltwise_gelu_erf
Elementwise: erf-based gelu.
@ eltwise_swish
Elementwise: swish ($x \cdot \mathrm{sigmoid}(a \cdot x)$)
@ binary_sub
Binary sub.
@ binary_ne
Binary not equal.
@ lrn_within_channel
LRN within a single channel.
@ binary_le
Binary less than or equal.
@ eltwise_hardswish
Elementwise: hardswish.
@ reduction_mul
Reduction using mul operation.
@ vanilla_rnn
RNN cell.
@ binary_add
Binary add.
@ lrn_across_channels
Local response normalization (LRN) across multiple channels.
@ eltwise_relu
Elementwise: rectified linear unit (ReLU)
@ eltwise_gelu_tanh
Elementwise: tanh-based gelu.
@ eltwise_relu_use_dst_for_bwd
Elementwise: rectified linear unit (ReLU) (dst for backward)
@ eltwise_logsigmoid
Elementwise: logsigmoid.
@ convolution_auto
Convolution algorithm that is chosen to be either direct or Winograd automatically.
@ binary_min
Binary min.
@ eltwise_exp_use_dst_for_bwd
Elementwise: exponent (dst for backward)
@ eltwise_round
Elementwise: round.
@ eltwise_sqrt_use_dst_for_bwd
Elementwise: square root (dst for backward)
@ pooling_avg_include_padding
Average pooling include padding.
@ reduction_norm_lp_sum
Reduction using norm_lp_sum operation.
@ reduction_mean
Reduction using mean operation.
@ deconvolution_winograd
Winograd deconvolution.
@ eltwise_log
Elementwise: natural logarithm.
@ undef
Undefined algorithm.
@ binary_lt
Binary less than.
@ reduction_sum
Reduction using sum operation.
@ library
The library manages the scratchpad allocation according to the policy specified by the DNNL_ENABLE_CONCURRENT_EXEC build option (default).
@ user
The user manages the scratchpad allocation by querying and providing the scratchpad memory to primitives.
@ backward
Backward propagation (with respect to all parameters).
@ backward_weights
Backward weights propagation.
@ forward_training
Forward data propagation (training mode).
@ forward_inference
Forward data propagation (inference mode).
@ forward_scoring
Forward data propagation, alias for dnnl::prop_kind::forward_inference.
@ forward
Forward data propagation, alias for dnnl::prop_kind::forward_training.
@ backward_data
Backward data propagation.
@ backward_bias
Backward bias propagation.
@ undef
Undefined propagation kind.
@ dnnl_scratchpad_mode_user
The user manages the scratchpad allocation by querying and providing the scratchpad memory to primitives.
Definition: dnnl_types.h:2338
@ dnnl_scratchpad_mode_library
The library manages the scratchpad allocation according to the policy specified by the DNNL_ENABLE_CONCURRENT_EXEC build option (default).
Definition: dnnl_types.h:2333
dnnl_status_t DNNL_API dnnl_batch_normalization_backward_desc_init(dnnl_batch_normalization_desc_t *bnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a descriptor for a batch normalization backward propagation primitive.
dnnl_status_t DNNL_API dnnl_batch_normalization_forward_desc_init(dnnl_batch_normalization_desc_t *bnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a descriptor for a batch normalization forward propagation primitive.
dnnl_status_t DNNL_API dnnl_binary_desc_init(dnnl_binary_desc_t *binary_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src0_desc, const dnnl_memory_desc_t *src1_desc, const dnnl_memory_desc_t *dst_desc)
Initializes a descriptor for a binary primitive.
dnnl_status_t DNNL_API dnnl_gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A, dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co)
Performs integer matrix-matrix multiply on 8-bit signed matrix A, 8-bit signed matrix B, and 32-bit signed resulting matrix C.
status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A, dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co)
Performs integer matrix-matrix multiply on 8-bit unsigned matrix A, 8-bit signed matrix B, and 32-bit signed resulting matrix C.
Definition: dnnl.hpp:11006
status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A, dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co)
Performs integer matrix-matrix multiply on 8-bit signed matrix A, 8-bit signed matrix B, and 32-bit signed resulting matrix C.
Definition: dnnl.hpp:11015
dnnl_status_t DNNL_API dnnl_sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda, const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc)
Performs single-precision matrix-matrix multiply.
status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda, const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc)
Performs single-precision matrix-matrix multiply.
Definition: dnnl.hpp:10998
dnnl_status_t DNNL_API dnnl_gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A, dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co)
Performs integer matrix-matrix multiply on 8-bit unsigned matrix A, 8-bit signed matrix B, and 32-bit signed resulting matrix C.
dnnl_status_t DNNL_API dnnl_concat_primitive_desc_create(dnnl_primitive_desc_t *concat_primitive_desc, const dnnl_memory_desc_t *dst_desc, int n, int concat_dimension, const dnnl_memory_desc_t *src_descs, const_dnnl_primitive_attr_t attr, dnnl_engine_t engine)
Creates a primitive descriptor for an out-of-place concatenation primitive.
dnnl_status_t DNNL_API dnnl_convolution_forward_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a convolution forward propagation primitive.
dnnl_status_t DNNL_API dnnl_dilated_convolution_backward_weights_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated convolution weights gradient primitive.
dnnl_status_t DNNL_API dnnl_dilated_convolution_forward_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated convolution forward propagation primitive.
dnnl_status_t DNNL_API dnnl_convolution_backward_weights_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a convolution weights gradient primitive.
dnnl_status_t DNNL_API dnnl_dilated_convolution_backward_data_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated convolution backward propagation primitive.
dnnl_status_t DNNL_API dnnl_convolution_backward_data_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a convolution backward propagation primitive.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_backward_data_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated deconvolution backward propagation primitive.
dnnl_status_t DNNL_API dnnl_deconvolution_forward_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a deconvolution forward propagation primitive.
dnnl_status_t DNNL_API dnnl_deconvolution_backward_weights_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a deconvolution weights gradient primitive.
dnnl_status_t DNNL_API dnnl_deconvolution_backward_data_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a deconvolution backward propagation primitive.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_forward_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated deconvolution forward propagation primitive.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_backward_weights_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated deconvolution weights gradient primitive.
dnnl_status_t DNNL_API dnnl_eltwise_forward_desc_init(dnnl_eltwise_desc_t *eltwise_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *data_desc, float alpha, float beta)
Initializes a descriptor for eltwise forward propagation primitive.
dnnl_status_t DNNL_API dnnl_eltwise_backward_desc_init(dnnl_eltwise_desc_t *eltwise_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, float alpha, float beta)
Initializes a descriptor for eltwise backward propagation primitive.
dnnl_engine_kind_t
Kinds of engines.
Definition: dnnl_types.h:2262
dnnl_status_t DNNL_API dnnl_engine_get_kind(dnnl_engine_t engine, dnnl_engine_kind_t *kind)
Returns the kind of an engine.
dnnl_status_t DNNL_API dnnl_engine_destroy(dnnl_engine_t engine)
Destroys an engine.
dnnl_status_t DNNL_API dnnl_engine_create(dnnl_engine_t *engine, dnnl_engine_kind_t kind, size_t index)
Creates an engine.
size_t DNNL_API dnnl_engine_get_count(dnnl_engine_kind_t kind)
Returns the number of engines of a particular kind.
dnnl_engine_kind_t convert_to_c(engine::kind akind)
Converts engine kind enum value from C++ API to C API type.
Definition: dnnl.hpp:987
@ dnnl_gpu
GPU engine.
Definition: dnnl_types.h:2268
@ dnnl_cpu
CPU engine.
Definition: dnnl_types.h:2266
@ dnnl_any_engine
An unspecified engine.
Definition: dnnl_types.h:2264
dnnl_status_t DNNL_API dnnl_inner_product_forward_desc_init(dnnl_inner_product_desc_t *ip_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc)
Initializes descriptor for inner product forward propagation.
dnnl_status_t DNNL_API dnnl_inner_product_backward_weights_desc_init(dnnl_inner_product_desc_t *ip_desc, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc)
Initializes descriptor for inner product weights gradient primitive.
dnnl_status_t DNNL_API dnnl_inner_product_backward_data_desc_init(dnnl_inner_product_desc_t *ip_desc, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc)
Initializes descriptor for inner product backward propagation.
dnnl_status_t DNNL_API dnnl_layer_normalization_backward_desc_init(dnnl_layer_normalization_desc_t *lnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, const dnnl_memory_desc_t *stat_desc, float epsilon, unsigned flags)
Initializes a descriptor for a layer normalization backward propagation primitive.
dnnl_status_t DNNL_API dnnl_layer_normalization_forward_desc_init(dnnl_layer_normalization_desc_t *lnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, const dnnl_memory_desc_t *stat_desc, float epsilon, unsigned flags)
Initializes a descriptor for layer normalization forward propagation primitive.
dnnl_status_t DNNL_API dnnl_logsoftmax_forward_desc_init(dnnl_logsoftmax_desc_t *logsoftmax_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, int logsoftmax_axis)
Initializes a descriptor for logsoftmax forward propagation primitive.
dnnl_status_t DNNL_API dnnl_logsoftmax_backward_desc_init(dnnl_logsoftmax_desc_t *logsoftmax_desc, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, int logsoftmax_axis)
Initializes a descriptor for logsoftmax backward propagation primitive.
dnnl_status_t DNNL_API dnnl_lrn_backward_desc_init(dnnl_lrn_desc_t *lrn_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha, float beta, float k)
Initializes a descriptor for LRN backward propagation primitive.
dnnl_status_t DNNL_API dnnl_lrn_forward_desc_init(dnnl_lrn_desc_t *lrn_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha, float beta, float k)
Initializes a descriptor for LRN forward propagation primitive.
dnnl_status_t DNNL_API dnnl_matmul_desc_init(dnnl_matmul_desc_t *matmul_desc, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc)
Initializes a matrix multiplication descriptor.
dnnl_data_type_t
Data type specification.
Definition: dnnl_types.h:62
size_t DNNL_API dnnl_data_type_size(dnnl_data_type_t data_type)
Returns the size of data type.
dnnl_status_t DNNL_API dnnl_memory_desc_init_submemory(dnnl_memory_desc_t *memory_desc, const dnnl_memory_desc_t *parent_memory_desc, const dnnl_dims_t dims, const dnnl_dims_t offsets)
Initializes a memory descriptor for a region inside an area described by an existing memory descriptor.
dnnl_format_tag_t
Memory format tag specification.
Definition: dnnl_types.h:164
dnnl_status_t DNNL_API dnnl_memory_desc_permute_axes(dnnl_memory_desc_t *out_memory_desc, const dnnl_memory_desc_t *in_memory_desc, const int *permutation)
Initializes a memory descriptor by permuting axes in an existing one.
dnnl_status_t DNNL_API dnnl_memory_unmap_data(const_dnnl_memory_t memory, void *mapped_ptr)
Unmaps a memory object and writes back any changes made to the previously mapped memory buffer.
dnnl_status_t DNNL_API dnnl_memory_create(dnnl_memory_t *memory, const dnnl_memory_desc_t *memory_desc, dnnl_engine_t engine, void *handle)
Creates a memory object.
dnnl_status_t DNNL_API dnnl_memory_get_engine(const_dnnl_memory_t memory, dnnl_engine_t *engine)
Returns the engine of a memory object.
dnnl_status_t DNNL_API dnnl_memory_desc_reshape(dnnl_memory_desc_t *out_memory_desc, const dnnl_memory_desc_t *in_memory_desc, int ndims, const dnnl_dims_t dims)
Initializes a memory descriptor by reshaping an existing one.
dnnl_status_t DNNL_API dnnl_memory_get_memory_desc(const_dnnl_memory_t memory, const dnnl_memory_desc_t **memory_desc)
Returns the memory descriptor for a memory object.
dnnl_status_t DNNL_API dnnl_memory_get_data_handle(const_dnnl_memory_t memory, void **handle)
Returns memory object's data handle.
dnnl_status_t DNNL_API dnnl_memory_set_data_handle_v2(dnnl_memory_t memory, void *handle, dnnl_stream_t stream)
Sets the underlying memory buffer.
dnnl_status_t DNNL_API dnnl_memory_desc_init_by_strides(dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims, dnnl_data_type_t data_type, const dnnl_dims_t strides)
Initializes a memory descriptor using dimensions and strides.
int64_t dnnl_dim_t
A type to describe tensor dimension.
Definition: dnnl_types.h:1444
dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory)
Destroys a memory object.
int DNNL_API dnnl_memory_desc_equal(const dnnl_memory_desc_t *lhs, const dnnl_memory_desc_t *rhs)
Compares two memory descriptors.
#define DNNL_MAX_NDIMS
Maximum number of dimensions a tensor can have.
Definition: dnnl_types.h:1412
dnnl_status_t DNNL_API dnnl_memory_map_data(const_dnnl_memory_t memory, void **mapped_ptr)
Maps a memory object and returns a host-side pointer to a memory buffer with a copy of its contents.
size_t DNNL_API dnnl_memory_desc_get_size(const dnnl_memory_desc_t *memory_desc)
Returns the size of a memory descriptor.
#define DNNL_MEMORY_ALLOCATE
Special pointer value that indicates that the library needs to allocate an underlying buffer for a memory object.
Definition: dnnl_types.h:1622
dnnl_status_t DNNL_API dnnl_memory_desc_init_by_tag(dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims, dnnl_data_type_t data_type, dnnl_format_tag_t tag)
Initializes a memory descriptor using dimensions and memory format tag.
@ dnnl_f16
16-bit/half-precision floating point.
Definition: dnnl_types.h:66
@ dnnl_bf16
non-standard 16-bit (bfloat16 w/ 7 bit mantissa) floating point.
Definition: dnnl_types.h:68
@ dnnl_f32
32-bit/single-precision floating point.
Definition: dnnl_types.h:70
@ dnnl_data_type_undef
Undefined data type, used for empty memory descriptors.
Definition: dnnl_types.h:64
@ dnnl_s8
8-bit signed integer.
Definition: dnnl_types.h:74
@ dnnl_s32
32-bit signed integer.
Definition: dnnl_types.h:72
@ dnnl_u8
8-bit unsigned integer.
Definition: dnnl_types.h:76
@ dnnl_abcdefhg
permuted 8D tensor
Definition: dnnl_types.h:216
@ dnnl_aBCdef2b4c2b
6D tensor blocked by 3rd dimension with block size 4
Definition: dnnl_types.h:362
@ dnnl_abcdefghi
plain 9D tensor
Definition: dnnl_types.h:186
@ dnnl_acdeb
permuted 5D tensor
Definition: dnnl_types.h:199
@ dnnl_abcdefgh
plain 8D tensor
Definition: dnnl_types.h:185
@ dnnl_abcdefghikj
permuted 11D tensor
Definition: dnnl_types.h:219
@ dnnl_ab
plain 2D tensor
Definition: dnnl_types.h:178
@ dnnl_ABcd8b8a
4D tensor blocked by 1st and 2nd dimension with block size 8
Definition: dnnl_types.h:288
@ dnnl_cdba
permuted 4D tensor
Definition: dnnl_types.h:208
@ dnnl_abcdefghijkl
plain 12D tensor
Definition: dnnl_types.h:189
@ dnnl_aBcdef4b
6D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:364
@ dnnl_abcdegf
permuted 7D tensor
Definition: dnnl_types.h:215
@ dnnl_abcdfe
permuted 6D tensor
Definition: dnnl_types.h:214
@ dnnl_aBcd4b
4D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:263
@ dnnl_nCdhw16c
5D CNN activations tensor blocked by channels with block size 16, an alias to dnnl_aBcde16b
Definition: dnnl_types.h:732
@ dnnl_abcde
plain 5D tensor
Definition: dnnl_types.h:182
@ dnnl_decab
permuted 5D tensor
Definition: dnnl_types.h:211
@ dnnl_bca
permuted 3D tensor
Definition: dnnl_types.h:204
@ dnnl_aBcde4b
5D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:315
@ dnnl_aBc16b
3D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:229
@ dnnl_aBcdef16b
6D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:354
@ dnnl_aBCde2b4c2b
5D tensor blocked by 3rd dimension with block size 4
Definition: dnnl_types.h:352
@ dnnl_aBc4b
3D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:235
@ dnnl_abcdefghijk
plain 11D tensor
Definition: dnnl_types.h:188
@ dnnl_bacde
permuted 5D tensor
Definition: dnnl_types.h:203
@ dnnl_aBcd16b
4D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:255
@ dnnl_cba
permuted 3D tensor
Definition: dnnl_types.h:207
@ dnnl_ba
permuted 2D tensor
Definition: dnnl_types.h:200
@ dnnl_ABcde2b8a4b
5D tensor blocked by 1st dimension with block size 8
Definition: dnnl_types.h:304
@ dnnl_abcd
plain 4D tensor
Definition: dnnl_types.h:180
@ dnnl_format_tag_undef
Undefined memory format tag.
Definition: dnnl_types.h:166
@ dnnl_nCdhw4c
5D CNN activations tensor blocked by channels with block size 4, an alias to dnnl_aBcde4b
Definition: dnnl_types.h:735
@ dnnl_defcab
permuted 6D tensor
Definition: dnnl_types.h:212
@ dnnl_abcdef
plain 6D tensor
Definition: dnnl_types.h:183
@ dnnl_nChw8c
4D CNN activations tensor blocked by channels with block size 8, an alias to dnnl_aBcd8b
Definition: dnnl_types.h:750
@ dnnl_a
plain 1D tensor
Definition: dnnl_types.h:177
@ dnnl_nChw4c
4D CNN activations tensor blocked by channels with block size 4, an alias to dnnl_aBcd4b
Definition: dnnl_types.h:747
@ dnnl_acbdef
permuted 6D tensor
Definition: dnnl_types.h:197
@ dnnl_acdb
permuted 4D tensor
Definition: dnnl_types.h:198
@ dnnl_aBcd8b
4D tensor blocked by 2nd dimension with block size 8
Definition: dnnl_types.h:282
@ dnnl_aBc8b
3D tensor blocked by 2nd dimension with block size 8
Definition: dnnl_types.h:245
@ dnnl_nCw4c
3D CNN activations tensor blocked by channels with block size 4, an alias to dnnl_aBc4b
Definition: dnnl_types.h:759
@ dnnl_abcdefg
plain 7D tensor
Definition: dnnl_types.h:184
@ dnnl_aBcde8b
5D tensor blocked by 2nd dimension with block size 8
Definition: dnnl_types.h:330
@ dnnl_nChw16c
4D CNN activations tensor blocked by channels with block size 16, an alias to dnnl_aBcd16b
Definition: dnnl_types.h:744
@ dnnl_abdfce
permuted 6D tensor
Definition: dnnl_types.h:424
@ dnnl_abdec
permuted 5D tensor
Definition: dnnl_types.h:194
@ dnnl_bacd
permuted 4D tensor
Definition: dnnl_types.h:202
@ dnnl_nCdhw8c
5D CNN activations tensor blocked by channels with block size 8, an alias to dnnl_aBcde8b
Definition: dnnl_types.h:738
@ dnnl_aBcde32b
5D tensor blocked by 2nd dimension with block size 32
Definition: dnnl_types.h:313
@ dnnl_abced
permuted 5D tensor
Definition: dnnl_types.h:213
@ dnnl_bcda
permuted 4D tensor
Definition: dnnl_types.h:205
@ dnnl_acbde
permuted 5D tensor
Definition: dnnl_types.h:196
@ dnnl_aBCd2b4c2b
4D tensor blocked by 3rd dimension with block size 4
Definition: dnnl_types.h:300
@ dnnl_abcdefgih
permuted 9D tensor
Definition: dnnl_types.h:217
@ dnnl_bcdea
permuted 5D tensor
Definition: dnnl_types.h:206
@ dnnl_abdefc
permuted 6D tensor
Definition: dnnl_types.h:425
@ dnnl_aBcde16b
5D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:306
@ dnnl_nCw8c
3D CNN activations tensor blocked by channels with block size 8, an alias to dnnl_aBc8b
Definition: dnnl_types.h:762
@ dnnl_abdc
permuted 4D tensor
Definition: dnnl_types.h:193
@ dnnl_ABcde4b16a4b
5D tensor blocked by 1st dimension with block size 16
Definition: dnnl_types.h:302
@ dnnl_aBcd32b
4D tensor blocked by 2nd dimension with block size 32
Definition: dnnl_types.h:261
@ dnnl_abcdefghijlk
permuted 12D tensor
Definition: dnnl_types.h:220
@ dnnl_format_tag_last
Just a sentinel, not real memory format tag.
Definition: dnnl_types.h:590
@ dnnl_abc
plain 3D tensor
Definition: dnnl_types.h:179
@ dnnl_bac
permuted 3D tensor
Definition: dnnl_types.h:201
@ dnnl_dcab
permuted 4D tensor
Definition: dnnl_types.h:209
@ dnnl_cdeba
permuted 5D tensor
Definition: dnnl_types.h:210
@ dnnl_acb
permuted 3D tensor
Definition: dnnl_types.h:195
@ dnnl_aBc32b
3D tensor blocked by 2nd dimension with block size 32
Definition: dnnl_types.h:233
@ dnnl_abcdefghji
permuted 10D tensor
Definition: dnnl_types.h:218
@ dnnl_nCw16c
3D CNN activations tensor blocked by channels with block size 16, an alias to dnnl_aBc16b
Definition: dnnl_types.h:756
@ dnnl_aBCdef2c8b4c
6D tensor blocked by 2nd dimension with block size 8
Definition: dnnl_types.h:359
@ dnnl_abcdefghij
plain 10D tensor
Definition: dnnl_types.h:187
@ dnnl_format_tag_any
Unspecified memory format tag; the primitive selects a format automatically.
Definition: dnnl_types.h:169
@ dnnl_blocked
A tensor in a generic format described by the stride and blocking values in each dimension.
Definition: dnnl_types.h:89
@ dnnl_format_kind_wino
Weights format used in 8bit Winograd convolution.
Definition: dnnl_types.h:91
@ dnnl_format_kind_any
Unspecified format kind.
Definition: dnnl_types.h:85
@ dnnl_format_kind_undef
Undefined memory format kind, used for empty memory descriptors.
Definition: dnnl_types.h:82
@ dnnl_format_kind_rnn_packed
Packed weights format used in RNN.
Definition: dnnl_types.h:93
dnnl_status_t DNNL_API dnnl_pooling_v2_backward_desc_init(dnnl_pooling_v2_desc_t *pool_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t kernel, const dnnl_dims_t dilation, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for pooling v2 (pooling with dilation support) backward propagation primitive.
dnnl_status_t DNNL_API dnnl_pooling_v2_forward_desc_init(dnnl_pooling_v2_desc_t *pool_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t kernel, const dnnl_dims_t dilation, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for pooling v2 (pooling with dilation support) forward propagation primitive.
dnnl_status_t DNNL_API dnnl_pooling_forward_desc_init(dnnl_pooling_desc_t *pool_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t kernel, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for pooling forward propagation primitive.
dnnl_status_t DNNL_API dnnl_pooling_backward_desc_init(dnnl_pooling_desc_t *pool_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t kernel, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for pooling backward propagation primitive.
dnnl_status_t DNNL_API dnnl_prelu_forward_desc_init(dnnl_prelu_desc_t *prelu_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, const dnnl_memory_desc_t *weights_desc)
Initializes a descriptor for PReLU (leaky ReLU with trainable alpha parameter) forward propagation primitive.
dnnl_status_t DNNL_API dnnl_prelu_backward_desc_init(dnnl_prelu_desc_t *prelu_desc, const dnnl_memory_desc_t *data_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *diff_weights_desc)
Initializes a descriptor for PReLU (leaky ReLU with trainable alpha parameter) backward propagation primitive.
void set_primitive_cache_capacity(int capacity)
Sets a number of primitives that can be held in the primitive cache at a time.
Definition: dnnl.hpp:10983
dnnl_status_t DNNL_API dnnl_set_primitive_cache_capacity(int capacity)
Sets a number of primitives that can be held in the primitive cache at a time.
dnnl_status_t DNNL_API dnnl_get_primitive_cache_capacity(int *capacity)
Returns the number of primitives that can be held in the primitive cache at the same time.
int get_primitive_cache_capacity()
Returns the number of primitives that can be held in the primitive cache at the same time.
Definition: dnnl.hpp:10975
dnnl_status_t DNNL_API dnnl_primitive_desc_query(const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what, int index, void *result)
Queries a primitive descriptor for various pieces of information.
#define DNNL_ARG_DST_ITER
A special mnemonic for RNN output recurrent hidden state vector.
Definition: dnnl_types.h:2433
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_destroy(dnnl_primitive_desc_iterator_t iterator)
Destroys a primitive descriptor iterator.
#define DNNL_ARG_WEIGHTS_LAYER
A special mnemonic for RNN weights applied to the layer input.
Definition: dnnl_types.h:2451
#define DNNL_ARG_DIFF_BIAS
Gradient (diff) of the bias tensor argument.
Definition: dnnl_types.h:2563
#define DNNL_ARG_DIFF_SRC_ITER_C
A special mnemonic for gradient (diff) of RNN input recurrent cell state vector.
Definition: dnnl_types.h:2509
#define DNNL_ARG_DIFF_SRC_LAYER
A special mnemonic for gradient (diff) of RNN input vector.
Definition: dnnl_types.h:2497
#define DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE
A special mnemonic for diff of RNN weights applied to the peephole weights.
Definition: dnnl_types.h:2554
#define DNNL_ARG_WEIGHTS_PROJECTION
A special mnemonic for RNN weights applied to the projection weights.
Definition: dnnl_types.h:2469
dnnl_normalization_flags_t
Flags for normalization primitives.
Definition: dnnl_types.h:1335
#define DNNL_ARG_DIFF_WEIGHTS_PROJECTION
A special mnemonic for diff of RNN weights applied to the projection weights.
Definition: dnnl_types.h:2560
const dnnl_memory_desc_t DNNL_API * dnnl_primitive_desc_query_md(const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what, int index)
Queries primitive descriptor for a memory descriptor.
dnnl_status_t DNNL_API dnnl_primitive_desc_get_attr(const_dnnl_primitive_desc_t primitive_desc, const_dnnl_primitive_attr_t *attr)
Returns a constant reference to the attributes of a primitive descriptor.
#define DNNL_ARG_DIFF_WEIGHTS_ITER
A special mnemonic for diff of RNN weights applied to the recurrent input.
Definition: dnnl_types.h:2548
#define DNNL_ARG_DIFF_SRC_ITER
A special mnemonic for gradient (diff) of RNN input recurrent hidden state vector.
Definition: dnnl_types.h:2503
#define DNNL_ARG_DIFF_DST_ITER_C
A special mnemonic for gradient (diff) of RNN output recurrent cell state vector.
Definition: dnnl_types.h:2530
dnnl_status_t DNNL_API dnnl_primitive_execute(const_dnnl_primitive_t primitive, dnnl_stream_t stream, int nargs, const dnnl_exec_arg_t *args)
Executes a primitive.
#define DNNL_ARG_WEIGHTS_ITER
A special mnemonic for RNN weights applied to the recurrent input.
Definition: dnnl_types.h:2457
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_next(dnnl_primitive_desc_iterator_t iterator)
Advances the primitive descriptor iterator to point to the next available implementation.
dnnl_status_t DNNL_API dnnl_primitive_desc_destroy(dnnl_primitive_desc_t primitive_desc)
Destroys a primitive descriptor.
const void * const_dnnl_op_desc_t
A pointer to any of the operation descriptors (constant variant).
Definition: dnnl_types.h:1634
const_dnnl_primitive_desc_t get_primitive_desc() const
Returns the C API primitive descriptor of the underlying C API primitive.
Definition: dnnl.hpp:368
dnnl_status_t DNNL_API dnnl_primitive_get_primitive_desc(const_dnnl_primitive_t primitive, const_dnnl_primitive_desc_t *primitive_desc)
Retrieves a constant reference to the primitive descriptor of a given primitive.
#define DNNL_ARG_DST_ITER_C
A special mnemonic for LSTM output recurrent cell state vector.
Definition: dnnl_types.h:2439
#define DNNL_ARG_SRC_ITER_C
A special mnemonic for RNN input recurrent cell state vector.
Definition: dnnl_types.h:2416
query
Primitive descriptor query specification.
Definition: dnnl.hpp:771
#define DNNL_ARG_FROM
A special mnemonic for reorder source argument.
Definition: dnnl_types.h:2404
dnnl_alg_kind_t
Kinds of algorithms.
Definition: dnnl_types.h:1185
dnnl_primitive_kind_t
Kinds of primitives.
Definition: dnnl_types.h:1131
dnnl_query_t
Primitive descriptor query specification.
Definition: dnnl_types.h:2639
dnnl_primitive_kind_t convert_to_c(primitive::kind akind)
Converts primitive kind enum value from C++ API to C API type.
Definition: dnnl.hpp:364
struct dnnl_primitive_desc * dnnl_primitive_desc_t
A primitive descriptor handle.
Definition: dnnl_types.h:2305
#define DNNL_ARG_WEIGHTS_PEEPHOLE
A special mnemonic for RNN weights applied to the peephole weights.
Definition: dnnl_types.h:2463
kind get_kind() const
Returns the kind of the primitive.
Definition: dnnl.hpp:375
#define DNNL_ARG_SRC_LAYER
A special mnemonic for RNN input vector.
Definition: dnnl_types.h:2401
dnnl_status_t DNNL_API dnnl_primitive_destroy(dnnl_primitive_t primitive)
Destroys a primitive.
#define DNNL_ARG_DIFF_WEIGHTS_LAYER
A special mnemonic for diff of RNN weights applied to the layer input.
Definition: dnnl_types.h:2542
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_create(dnnl_primitive_desc_iterator_t *iterator, const_dnnl_op_desc_t op_desc, const_dnnl_primitive_attr_t attr, dnnl_engine_t engine, const_dnnl_primitive_desc_t hint_forward_primitive_desc)
Creates a primitive descriptor iterator.
#define DNNL_ARG_DST_LAYER
A special mnemonic for RNN output vector. An alias for DNNL_ARG_DST_0.
Definition: dnnl_types.h:2427
dnnl_status_t DNNL_API dnnl_primitive_create(dnnl_primitive_t *primitive, const_dnnl_primitive_desc_t primitive_desc)
Creates a primitive.
#define DNNL_ARG_BIAS
Bias tensor argument.
Definition: dnnl_types.h:2472
normalization_flags
Flags for normalization primitives.
Definition: dnnl.hpp:631
#define DNNL_ARG_DIFF_DST_ITER
A special mnemonic for gradient (diff) of RNN output recurrent hidden state vector.
Definition: dnnl_types.h:2524
dnnl_prop_kind_t
Kinds of propagation.
Definition: dnnl_types.h:1104
dnnl_status_t DNNL_API dnnl_primitive_desc_clone(dnnl_primitive_desc_t *primitive_desc, const_dnnl_primitive_desc_t existing_primitive_desc)
Clones a primitive descriptor.
#define DNNL_ARG_SRC_ITER
A special mnemonic for RNN input recurrent hidden state vector.
Definition: dnnl_types.h:2410
dnnl_primitive_desc_t DNNL_API dnnl_primitive_desc_iterator_fetch(const_dnnl_primitive_desc_iterator_t iterator)
Fetches the current primitive descriptor from a primitive descriptor iterator.
#define DNNL_ARG_TO
A special mnemonic for reorder destination argument.
Definition: dnnl_types.h:2425
#define DNNL_ARG_DIFF_DST_LAYER
A special mnemonic for gradient (diff) of RNN output vector.
Definition: dnnl_types.h:2518
@ dnnl_use_scale
Use scale parameter.
Definition: dnnl_types.h:1391
@ dnnl_fuse_norm_relu
Fuse with ReLU.
Definition: dnnl_types.h:1383
@ dnnl_normalization_flags_none
Use no normalization flags.
Definition: dnnl_types.h:1344
@ dnnl_use_scaleshift
Use scale and shift parameters.
Definition: dnnl_types.h:1370
@ dnnl_use_global_stats
Use global statistics.
Definition: dnnl_types.h:1357
@ dnnl_use_shift
Use shift parameter.
Definition: dnnl_types.h:1400
@ batch_normalization_d
batch normalization descriptor
@ weights_md
weights memory desc
@ memory_consumption_s64
memory required for scratchpad (bytes)
@ shuffle_d
shuffle descriptor
@ deconvolution_d
deconvolution descriptor
@ impl_info_str
implementation name
@ diff_weights_md
weights gradient (diff) memory desc
@ workspace_md
workspace memory desc
@ reduction_d
reduction descriptor
@ eltwise_d
eltwise descriptor
@ matmul_d
matmul descriptor
@ rnn_d
rnn descriptor
@ softmax_d
softmax descriptor
@ num_of_outputs_s32
number of outputs expected
@ primitive_kind
primitive kind
@ dst_md
destination memory desc
@ scratchpad_engine
scratchpad engine
@ reorder_src_engine
reorder source engine
@ op_d
operation descriptor
@ layer_normalization_d
layer normalization descriptor
@ logsoftmax_d
logsoftmax descriptor
@ pooling_d
pooling descriptor
@ num_of_inputs_s32
number of inputs expected
@ diff_src_md
source gradient (diff) memory desc
@ src_md
source memory desc
@ scratchpad_md
scratchpad memory desc
@ reorder_dst_engine
reorder destination engine
@ engine
execution engine
@ convolution_d
convolution descriptor
@ time_estimate_f64
runtime estimation (seconds), unimplemented
@ binary_d
binary descriptor
@ diff_dst_md
destination gradient (diff) memory desc
@ exec_arg_md
memory desc of an execute argument
@ inner_product_d
inner product descriptor
@ lrn_d
lrn descriptor
@ undef
no query
@ resampling_d
resampling descriptor
@ dnnl_pooling_avg_exclude_padding
Average pooling exclude padding.
Definition: dnnl_types.h:1265
@ dnnl_eltwise_clip
Eltwise: clip.
Definition: dnnl_types.h:1231
@ dnnl_eltwise_tanh_use_dst_for_bwd
Eltwise: hyperbolic tangent non-linearity (tanh) (dst for backward)
Definition: dnnl_types.h:1249
@ dnnl_eltwise_logsigmoid
Eltwise: logsigmoid.
Definition: dnnl_types.h:1241
@ dnnl_pooling_avg
Average pooling (alias for dnnl_pooling_avg_exclude_padding)
Definition: dnnl_types.h:1267
@ dnnl_eltwise_gelu_tanh
Eltwise: tanh-based gelu.
Definition: dnnl_types.h:1223
@ dnnl_resampling_linear
Linear Resampling Method.
Definition: dnnl_types.h:1313
@ dnnl_eltwise_sqrt
Eltwise: square root.
Definition: dnnl_types.h:1208
@ dnnl_binary_min
Binary min.
Definition: dnnl_types.h:1293
@ dnnl_reduction_norm_lp_sum
Reduction using lp norm.
Definition: dnnl_types.h:1327
@ dnnl_eltwise_abs
Eltwise: abs.
Definition: dnnl_types.h:1206
@ dnnl_reduction_norm_lp_power_p_max
Reduction using lp norm without final pth-root.
Definition: dnnl_types.h:1329
@ dnnl_reduction_min
Reduction using min.
Definition: dnnl_types.h:1317
@ dnnl_binary_ne
Binary not equal.
Definition: dnnl_types.h:1309
@ dnnl_eltwise_sqrt_use_dst_for_bwd
Eltwise: square root (dst for backward)
Definition: dnnl_types.h:1253
@ dnnl_eltwise_exp
Eltwise: exponent.
Definition: dnnl_types.h:1218
@ dnnl_eltwise_square
Eltwise: square.
Definition: dnnl_types.h:1204
@ dnnl_eltwise_gelu
Eltwise: tanh-based gelu (alias for dnnl_eltwise_gelu_tanh)
Definition: dnnl_types.h:1225
@ dnnl_convolution_winograd
Winograd convolution.
Definition: dnnl_types.h:1190
@ dnnl_eltwise_clip_v2_use_dst_for_bwd
Eltwise: clip version 2 (dst for backward)
Definition: dnnl_types.h:1259
@ dnnl_lrn_across_channels
Local response normalization (LRN) across multiple channels.
Definition: dnnl_types.h:1269
@ dnnl_binary_sub
Binary sub.
Definition: dnnl_types.h:1297
@ dnnl_deconvolution_direct
Direct deconvolution.
Definition: dnnl_types.h:1194
@ dnnl_binary_eq
Binary equal.
Definition: dnnl_types.h:1307
@ dnnl_eltwise_relu
Eltwise: ReLU.
Definition: dnnl_types.h:1198
@ dnnl_convolution_auto
Convolution algorithm (either direct or Winograd) is chosen just in time.
Definition: dnnl_types.h:1192
@ dnnl_eltwise_swish
Eltwise: swish.
Definition: dnnl_types.h:1227
@ dnnl_vanilla_rnn
RNN cell.
Definition: dnnl_types.h:1273
@ dnnl_eltwise_gelu_erf
Eltwise: erf-based gelu.
Definition: dnnl_types.h:1237
@ dnnl_vanilla_lstm
LSTM cell.
Definition: dnnl_types.h:1275
@ dnnl_eltwise_elu
Eltwise: exponential linear unit (elu)
Definition: dnnl_types.h:1202
@ dnnl_vanilla_gru
GRU cell.
Definition: dnnl_types.h:1277
@ dnnl_lbr_gru
GRU cell with linear before reset.
Definition: dnnl_types.h:1285
@ dnnl_eltwise_tanh
Eltwise: hyperbolic tangent non-linearity (tanh)
Definition: dnnl_types.h:1200
@ dnnl_convolution_direct
Direct convolution.
Definition: dnnl_types.h:1188
@ dnnl_eltwise_soft_relu
Eltwise: soft_relu.
Definition: dnnl_types.h:1214
@ dnnl_binary_ge
Binary greater or equal.
Definition: dnnl_types.h:1299
@ dnnl_eltwise_log
Eltwise: natural logarithm.
Definition: dnnl_types.h:1229
@ dnnl_eltwise_clip_v2
Eltwise: clip version 2.
Definition: dnnl_types.h:1233
@ dnnl_lrn_within_channel
LRN within a single channel.
Definition: dnnl_types.h:1271
@ dnnl_eltwise_elu_use_dst_for_bwd
Eltwise: exponential linear unit (elu) (dst for backward)
Definition: dnnl_types.h:1251
@ dnnl_deconvolution_winograd
Winograd deconvolution.
Definition: dnnl_types.h:1196
@ dnnl_eltwise_hardswish
Eltwise: hardswish.
Definition: dnnl_types.h:1245
@ dnnl_reduction_mul
Reduction using mul.
Definition: dnnl_types.h:1321
@ dnnl_eltwise_pow
Eltwise: pow.
Definition: dnnl_types.h:1235
@ dnnl_eltwise_relu_use_dst_for_bwd
Eltwise: ReLU (dst for backward)
Definition: dnnl_types.h:1247
@ dnnl_binary_gt
Binary greater than.
Definition: dnnl_types.h:1301
@ dnnl_reduction_max
Reduction using max.
Definition: dnnl_types.h:1315
@ dnnl_eltwise_logistic
Eltwise: logistic.
Definition: dnnl_types.h:1216
@ dnnl_binary_lt
Binary less than.
Definition: dnnl_types.h:1305
@ dnnl_pooling_avg_include_padding
Average pooling include padding.
Definition: dnnl_types.h:1263
@ dnnl_reduction_mean
Reduction using mean.
Definition: dnnl_types.h:1323
@ dnnl_binary_le
Binary less or equal.
Definition: dnnl_types.h:1303
@ dnnl_pooling_max
Max pooling.
Definition: dnnl_types.h:1261
@ dnnl_eltwise_logistic_use_dst_for_bwd
Eltwise: logistic (dst for backward)
Definition: dnnl_types.h:1255
@ dnnl_binary_add
Binary add.
Definition: dnnl_types.h:1287
@ dnnl_binary_div
Binary div.
Definition: dnnl_types.h:1295
@ dnnl_reduction_norm_lp_max
Reduction using lp norm.
Definition: dnnl_types.h:1325
@ dnnl_reduction_norm_lp_power_p_sum
Reduction using lp norm without final pth-root.
Definition: dnnl_types.h:1331
@ dnnl_eltwise_round
Eltwise: round.
Definition: dnnl_types.h:1239
@ dnnl_binary_mul
Binary mul.
Definition: dnnl_types.h:1289
@ dnnl_eltwise_mish
Eltwise: mish.
Definition: dnnl_types.h:1243
@ dnnl_reduction_sum
Reduction using sum.
Definition: dnnl_types.h:1319
@ dnnl_eltwise_exp_use_dst_for_bwd
Eltwise: exp (dst for backward)
Definition: dnnl_types.h:1257
@ dnnl_eltwise_bounded_relu
Eltwise: bounded_relu.
Definition: dnnl_types.h:1212
@ dnnl_eltwise_linear
Eltwise: linear.
Definition: dnnl_types.h:1210
@ dnnl_resampling_nearest
Nearest Neighbor Resampling Method.
Definition: dnnl_types.h:1311
@ dnnl_binary_max
Binary max.
Definition: dnnl_types.h:1291
@ dnnl_binary
A binary primitive.
Definition: dnnl_types.h:1165
@ dnnl_concat
A (out-of-place) concat primitive.
Definition: dnnl_types.h:1139
@ dnnl_reorder
A reorder primitive.
Definition: dnnl_types.h:1135
@ dnnl_convolution
A convolution primitive.
Definition: dnnl_types.h:1143
@ dnnl_inner_product
An inner product primitive.
Definition: dnnl_types.h:1159
@ dnnl_resampling
A resampling primitive.
Definition: dnnl_types.h:1171
@ dnnl_batch_normalization
A batch normalization primitive.
Definition: dnnl_types.h:1155
@ dnnl_undefined_primitive
Undefined primitive.
Definition: dnnl_types.h:1133
@ dnnl_sum
A sum primitive.
Definition: dnnl_types.h:1141
@ dnnl_pooling_v2
A pooling version 2 primitive (pooling with dilation support).
Definition: dnnl_types.h:1173
@ dnnl_layer_normalization
A layer normalization primitive.
Definition: dnnl_types.h:1157
@ dnnl_prelu
A PReLU primitive.
Definition: dnnl_types.h:1177
@ dnnl_eltwise
An element-wise primitive.
Definition: dnnl_types.h:1147
@ dnnl_matmul
A matrix multiplication primitive.
Definition: dnnl_types.h:1169
@ dnnl_shuffle
A shuffle primitive.
Definition: dnnl_types.h:1137
@ dnnl_logsoftmax
A logsoftmax primitive.
Definition: dnnl_types.h:1167
@ dnnl_pooling
A pooling primitive.
Definition: dnnl_types.h:1151
@ dnnl_deconvolution
A deconvolution primitive.
Definition: dnnl_types.h:1145
@ dnnl_softmax
A softmax primitive.
Definition: dnnl_types.h:1149
@ dnnl_rnn
An RNN primitive.
Definition: dnnl_types.h:1161
@ dnnl_reduction
A reduction primitive.
Definition: dnnl_types.h:1175
@ dnnl_lrn
An LRN primitive.
Definition: dnnl_types.h:1153
@ dnnl_query_resampling_d
resampling descriptor
Definition: dnnl_types.h:2682
@ dnnl_query_num_of_outputs_s32
number of outputs expected
Definition: dnnl_types.h:2646
@ dnnl_query_convolution_d
convolution descriptor
Definition: dnnl_types.h:2667
@ dnnl_query_weights_md
weights memory desc
Definition: dnnl_types.h:2691
@ dnnl_query_src_md
source memory desc
Definition: dnnl_types.h:2689
@ dnnl_query_softmax_d
softmax descriptor
Definition: dnnl_types.h:2671
@ dnnl_query_binary_d
binary descriptor
Definition: dnnl_types.h:2679
@ dnnl_query_workspace_md
workspace memory desc
Definition: dnnl_types.h:2695
@ dnnl_query_matmul_d
matrix multiplication (matmul) descriptor
Definition: dnnl_types.h:2681
@ dnnl_query_num_of_inputs_s32
number of inputs expected
Definition: dnnl_types.h:2645
@ dnnl_query_op_d
op descriptor
Definition: dnnl_types.h:2666
@ dnnl_query_diff_src_md
source gradient memory desc
Definition: dnnl_types.h:2690
@ dnnl_query_scratchpad_md
scratchpad memory desc
Definition: dnnl_types.h:2696
@ dnnl_query_shuffle_d
shuffle descriptor
Definition: dnnl_types.h:2669
@ dnnl_query_memory_consumption_s64
memory consumption – extra (scratch) memory, additional to all inputs and outputs memory (bytes)
Definition: dnnl_types.h:2649
@ dnnl_query_inner_product_d
inner product descriptor
Definition: dnnl_types.h:2676
@ dnnl_query_deconvolution_d
deconvolution descriptor
Definition: dnnl_types.h:2668
@ dnnl_query_primitive_kind
primitive kind
Definition: dnnl_types.h:2643
@ dnnl_query_batch_normalization_d
batch normalization descriptor
Definition: dnnl_types.h:2674
@ dnnl_query_impl_info_str
implementation name
Definition: dnnl_types.h:2657
@ dnnl_query_time_estimate_f64
runtime estimation (seconds)
Definition: dnnl_types.h:2648
@ dnnl_query_eltwise_d
eltwise descriptor
Definition: dnnl_types.h:2670
@ dnnl_query_diff_weights_md
weights grad. memory desc
Definition: dnnl_types.h:2692
@ dnnl_query_reduction_d
reduction descriptor
Definition: dnnl_types.h:2684
@ dnnl_query_reorder_dst_engine
destination engine
Definition: dnnl_types.h:2660
@ dnnl_query_reorder_src_engine
source engine
Definition: dnnl_types.h:2659
@ dnnl_query_scratchpad_engine
scratchpad engine – engine to be used for creating scratchpad memory
Definition: dnnl_types.h:2654
@ dnnl_query_undef
no query
Definition: dnnl_types.h:2640
@ dnnl_query_prop_kind
propagation kind
Definition: dnnl_types.h:2662
@ dnnl_query_pooling_d
pooling descriptor
Definition: dnnl_types.h:2672
@ dnnl_query_exec_arg_md
memory desc of an execute argument
Definition: dnnl_types.h:2697
@ dnnl_query_engine
execution engine
Definition: dnnl_types.h:2642
@ dnnl_query_rnn_d
rnn descriptor
Definition: dnnl_types.h:2677
@ dnnl_query_layer_normalization_d
layer normalization descriptor
Definition: dnnl_types.h:2675
@ dnnl_query_lrn_d
lrn descriptor
Definition: dnnl_types.h:2673
@ dnnl_query_dst_md
destination memory desc
Definition: dnnl_types.h:2693
@ dnnl_query_diff_dst_md
destination grad. memory desc
Definition: dnnl_types.h:2694
@ dnnl_query_logsoftmax_d
logsoftmax descriptor
Definition: dnnl_types.h:2680
@ use_scale_shift
Use scale and shift parameters.
@ none
Use no normalization flags.
@ fuse_norm_relu
Fuse normalization with ReLU.
@ use_global_stats
Use global statistics.
@ use_scale
Use scale parameter.
@ use_shift
Use shift parameter.
@ dnnl_backward_weights
Backward weights propagation.
Definition: dnnl_types.h:1124
@ dnnl_forward_inference
Forward data propagation (inference mode).
Definition: dnnl_types.h:1114
@ dnnl_backward
Backward propagation (with respect to all parameters).
Definition: dnnl_types.h:1120
@ dnnl_backward_data
Backward data propagation.
Definition: dnnl_types.h:1122
@ dnnl_prop_kind_undef
Undefined propagation type.
Definition: dnnl_types.h:1107
@ dnnl_forward
Forward data propagation (alias for dnnl_forward_training).
Definition: dnnl_types.h:1118
@ dnnl_forward_training
Forward data propagation (training mode).
Definition: dnnl_types.h:1110
@ dnnl_backward_bias
Backward bias propagation.
Definition: dnnl_types.h:1126
@ dnnl_forward_scoring
Forward data propagation (alias for dnnl_forward_inference).
Definition: dnnl_types.h:1116
dnnl_status_t DNNL_API dnnl_reduction_desc_init(dnnl_reduction_desc_t *desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc, float p, float eps)
Initializes a descriptor for a reduction primitive.
dnnl_status_t DNNL_API dnnl_reorder_primitive_desc_create(dnnl_primitive_desc_t *reorder_primitive_desc, const dnnl_memory_desc_t *src_desc, dnnl_engine_t src_engine, const dnnl_memory_desc_t *dst_desc, dnnl_engine_t dst_engine, const_dnnl_primitive_attr_t attr)
Creates a primitive descriptor for a reorder primitive.
dnnl_status_t DNNL_API dnnl_resampling_backward_desc_init(dnnl_resampling_desc_t *resampling_desc, dnnl_alg_kind_t alg_kind, const float *factors, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *diff_dst_desc)
Initializes a descriptor for resampling backward propagation primitive.
dnnl_status_t DNNL_API dnnl_resampling_forward_desc_init(dnnl_resampling_desc_t *resampling_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const float *factors, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc)
Initializes a descriptor for a resampling forward propagation primitive.
dnnl_status_t DNNL_API dnnl_lbr_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags)
Initializes a descriptor for LBR GRU backward propagation primitive.
dnnl_status_t DNNL_API dnnl_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, unsigned flags)
Initializes a descriptor for GRU forward propagation primitive.
rnn_direction
A direction of RNN primitive execution.
Definition: dnnl.hpp:738
dnnl_rnn_flags_t
Flags for RNN cell.
Definition: dnnl_types.h:2046
dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, unsigned flags, float alpha, float beta)
Initializes a descriptor for vanilla RNN forward propagation primitive.
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_src_iter_c_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM backward propagation primitive.
dnnl_rnn_direction_t
A direction of RNN primitive execution.
Definition: dnnl_types.h:2052
dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags, float alpha, float beta)
Initializes a descriptor for vanilla RNN backward propagation primitive.
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init_v3(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *weights_peephole_desc, const dnnl_memory_desc_t *weights_projection_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_src_iter_c_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_weights_peephole_desc, const dnnl_memory_desc_t *diff_weights_projection_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM (with or without peephole and with or without recurrent project...
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init_v2(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *weights_peephole_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_src_iter_c_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_weights_peephole_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM (with or without peephole) backward propagation primitive.
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags)
Initializes a descriptor for LSTM forward propagation primitive.
dnnl_status_t DNNL_API dnnl_lbr_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, unsigned flags)
Initializes a descriptor for LBR GRU forward propagation primitive.
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init_v3(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *weights_peephole_desc, const dnnl_memory_desc_t *weights_projection_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM (with or without peephole and with or without recurrent projecti...
rnn_flags
RNN cell flags.
Definition: dnnl.hpp:684
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init_v2(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *weights_peephole_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM (with or without peephole) forward propagation primitive.
dnnl_status_t DNNL_API dnnl_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags)
Initializes a descriptor for GRU backward propagation primitive.
@ unidirectional_left2right
Unidirectional execution of RNN primitive from left to right.
@ unidirectional_right2left
Unidirectional execution of RNN primitive from right to left.
@ bidirectional_concat
Bidirectional execution of RNN primitive with concatenation of the results.
@ unidirectional
Alias for dnnl::rnn_direction::unidirectional_left2right.
@ bidirectional_sum
Bidirectional execution of RNN primitive with summation of the results.
@ dnnl_rnn_flags_undef
Undefined RNN flags.
Definition: dnnl_types.h:2048
@ dnnl_unidirectional
Alias for dnnl_unidirectional_left2right.
Definition: dnnl_types.h:2064
@ dnnl_bidirectional_concat
Bidirectional execution of RNN primitive with concatenation of the results.
Definition: dnnl_types.h:2059
@ dnnl_bidirectional_sum
Bidirectional execution of RNN primitive with summation of the results.
Definition: dnnl_types.h:2062
@ dnnl_unidirectional_left2right
Unidirectional execution of RNN primitive from left to right.
Definition: dnnl_types.h:2054
@ dnnl_unidirectional_right2left
Unidirectional execution of RNN primitive from right to left.
Definition: dnnl_types.h:2056
@ undef
Undefined RNN flags.
dnnl_status_t DNNL_API dnnl_set_jit_dump(int enable)
Configures dumping of JIT-generated code.
status set_max_cpu_isa(cpu_isa isa)
Sets the maximal ISA the library can dispatch to on the CPU.
Definition: dnnl.hpp:10936
dnnl_status_t DNNL_API dnnl_set_verbose(int level)
Configures verbose output to stdout.
status set_jit_dump(int enable)
Configures dumping of JIT-generated code.
Definition: dnnl.hpp:10895
status set_cpu_isa_hints(cpu_isa_hints isa_hints)
Sets the hints flag for the CPU ISA.
Definition: dnnl.hpp:10955
dnnl_cpu_isa_t
CPU instruction set flags.
Definition: dnnl_types.h:2789
status set_verbose(int level)
Configures verbose output to stdout.
Definition: dnnl.hpp:10885
cpu_isa get_effective_cpu_isa()
Gets the maximal ISA the library can dispatch to on the CPU.
Definition: dnnl.hpp:10942
dnnl_status_t DNNL_API dnnl_set_max_cpu_isa(dnnl_cpu_isa_t isa)
Sets the maximal ISA the library can dispatch to on the CPU.
dnnl_status_t DNNL_API dnnl_set_jit_profiling_flags(unsigned flags)
Sets library profiling flags.
status set_jit_profiling_jitdumpdir(const std::string &dir)
Sets JIT dump output path.
Definition: dnnl.hpp:10905
const dnnl_version_t DNNL_API * dnnl_version(void)
Returns library version information.
status
Status values returned by the library functions.
Definition: dnnl.hpp:10867
cpu_isa_hints get_cpu_isa_hints()
Gets the ISA specific hints that library can follow.
Definition: dnnl.hpp:10961
status set_jit_profiling_flags(unsigned flags)
Sets library profiling flags.
Definition: dnnl.hpp:10900
const version_t * version()
Returns library version information.
Definition: dnnl.hpp:10890
cpu_isa
CPU instruction set flags.
Definition: dnnl.hpp:10910
dnnl_cpu_isa_t DNNL_API dnnl_get_effective_cpu_isa(void)
Gets the maximal ISA the library can dispatch to on the CPU.
dnnl_status_t DNNL_API dnnl_set_cpu_isa_hints(dnnl_cpu_isa_hints_t isa_hints)
Sets the hints flag for the CPU ISA.
dnnl_cpu_isa_hints_t DNNL_API dnnl_get_cpu_isa_hints(void)
Gets the ISA specific hints that library can follow.
dnnl_cpu_isa_hints_t
CPU ISA hints flags.
Definition: dnnl_types.h:2835
cpu_isa_hints
CPU ISA hints flags.
Definition: dnnl.hpp:10947
dnnl_status_t DNNL_API dnnl_set_jit_profiling_jitdumpdir(const char *dir)
Sets JIT dump output path.
@ dnnl_cpu_isa_avx512_mic
Intel Advanced Vector Extensions 512 (Intel AVX-512) subset for Intel Xeon Phi processors x200 Series...
Definition: dnnl_types.h:2804
@ dnnl_cpu_isa_avx
Intel Advanced Vector Extensions (Intel AVX)
Definition: dnnl_types.h:2797
@ dnnl_cpu_isa_avx512_core_amx
Intel AVX-512, Intel DL Boost and bfloat16 support and Intel AMX with 8-bit integer and bfloat16 supp...
Definition: dnnl_types.h:2827
@ dnnl_cpu_isa_avx512_core_vnni
Intel AVX-512 and Intel Deep Learning Boost (Intel DL Boost) support for Intel Xeon Scalable processo...
Definition: dnnl_types.h:2817
@ dnnl_cpu_isa_avx2
Intel Advanced Vector Extensions 2 (Intel AVX2)
Definition: dnnl_types.h:2800
@ dnnl_cpu_isa_all
Any ISA (excepting those listed as initial support)
Definition: dnnl_types.h:2791
@ dnnl_cpu_isa_avx512_core
Intel AVX-512 subset for Intel Xeon Scalable processor family and Intel Core processor family.
Definition: dnnl_types.h:2812
@ dnnl_cpu_isa_sse41
Intel Streaming SIMD Extensions 4.1 (Intel SSE4.1)
Definition: dnnl_types.h:2794
@ dnnl_cpu_isa_avx2_vnni
Intel AVX2 and Intel Deep Learning Boost (Intel DL Boost) support.
Definition: dnnl_types.h:2830
@ dnnl_cpu_isa_avx512_core_bf16
Intel AVX-512, Intel DL Boost and bfloat16 support for Intel Xeon Scalable processor family and Intel...
Definition: dnnl_types.h:2822
@ dnnl_cpu_isa_avx512_mic_4ops
Intel AVX-512 subset for Intel Xeon Phi processors 7235, 7285, 7295 Series.
Definition: dnnl_types.h:2808
@ not_required
Queried element is not required for given primitive.
@ invalid_arguments
The operation failed because of incorrect function arguments.
@ success
The operation was successful.
@ unimplemented
The operation failed because requested functionality is not implemented.
@ runtime_error
Primitive or engine failed on execution.
@ out_of_memory
The operation failed due to an out-of-memory condition.
@ iterator_ends
Primitive iterator passed over last primitive descriptor.
@ avx512_mic
Intel Advanced Vector Extensions 512 (Intel AVX-512) subset for Intel Xeon Phi processors x200 Series...
@ avx2
Intel Advanced Vector Extensions 2 (Intel AVX2)
@ avx2_vnni
Intel AVX2 and Intel Deep Learning Boost (Intel DL Boost) support.
@ avx
Intel Advanced Vector Extensions (Intel AVX)
@ all
Any ISA (excepting those listed as initial support)
@ avx512_core
Intel AVX-512 subset for Intel Xeon Scalable processor family and Intel Core processor family.
@ avx512_mic_4ops
Intel AVX-512 subset for Intel Xeon Phi processors 7235, 7285, 7295 Series.
@ sse41
Intel Streaming SIMD Extensions 4.1 (Intel SSE4.1)
@ avx512_core_vnni
Intel AVX-512 and Intel Deep Learning Boost (Intel DL Boost) support for Intel Xeon Scalable processo...
@ avx512_core_amx
Intel AVX-512, Intel DL Boost and bfloat16 support and Intel AMX with 8-bit integer and bfloat16 supp...
@ avx512_core_bf16
Intel AVX-512, Intel DL Boost and bfloat16 support for Intel Xeon Scalable processor family and Intel...
@ dnnl_cpu_isa_no_hints
No hints (use default features)
Definition: dnnl_types.h:2837
@ dnnl_cpu_isa_prefer_ymm
Prefer to exclusively use Ymm registers for computations.
Definition: dnnl_types.h:2840
@ no_hints
No hints (use default features)
@ prefer_ymm
Prefer to exclusively use Ymm registers for computations.
dnnl_status_t DNNL_API dnnl_shuffle_forward_desc_init(dnnl_shuffle_desc_t *shuffle_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, int axis, dnnl_dim_t group_size)
Initializes a descriptor for shuffle forward propagation primitive.
dnnl_status_t DNNL_API dnnl_shuffle_backward_desc_init(dnnl_shuffle_desc_t *shuffle_desc, const dnnl_memory_desc_t *diff_data_desc, int axis, dnnl_dim_t group_size)
Initializes a descriptor for shuffle backward propagation primitive.
dnnl_status_t DNNL_API dnnl_softmax_backward_desc_init(dnnl_softmax_desc_t *softmax_desc, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, int softmax_axis)
Initializes a descriptor for softmax backward propagation primitive.
dnnl_status_t DNNL_API dnnl_softmax_forward_desc_init(dnnl_softmax_desc_t *softmax_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, int softmax_axis)
Initializes a descriptor for softmax forward propagation primitive.
dnnl_stream_flags_t
Stream flags.
Definition: dnnl_types.h:2711
dnnl_status_t DNNL_API dnnl_stream_wait(dnnl_stream_t stream)
Waits for all primitives in the execution stream to finish computations.
dnnl_status_t DNNL_API dnnl_stream_get_engine(const_dnnl_stream_t stream, dnnl_engine_t *engine)
Returns the engine of a stream object.
dnnl_status_t DNNL_API dnnl_stream_destroy(dnnl_stream_t stream)
Destroys an execution stream.
dnnl_status_t DNNL_API dnnl_stream_create(dnnl_stream_t *stream, dnnl_engine_t engine, unsigned flags)
Creates an execution stream.
@ dnnl_stream_out_of_order
Out-of-order execution.
Definition: dnnl_types.h:2715
@ dnnl_stream_default_flags
Default stream configuration.
Definition: dnnl_types.h:2717
dnnl_status_t DNNL_API dnnl_sum_primitive_desc_create(dnnl_primitive_desc_t *sum_primitive_desc, const dnnl_memory_desc_t *dst_desc, int n, const float *scales, const dnnl_memory_desc_t *src_descs, const_dnnl_primitive_attr_t attr, dnnl_engine_t engine)
Creates a primitive descriptor for an (out-of-place) sum primitive.
dnnl_status_t
Status values returned by the library functions.
Definition: dnnl_types.h:39
@ dnnl_iterator_ends
Primitive iterator passed over last primitive descriptor.
Definition: dnnl_types.h:49
@ dnnl_runtime_error
Primitive or engine failed on execution.
Definition: dnnl_types.h:51
@ dnnl_unimplemented
The operation failed because requested functionality is not implemented.
Definition: dnnl_types.h:47
@ dnnl_out_of_memory
The operation failed due to an out-of-memory condition.
Definition: dnnl_types.h:43
@ dnnl_success
The operation was successful.
Definition: dnnl_types.h:41
@ dnnl_invalid_arguments
The operation failed because of incorrect function arguments.
Definition: dnnl_types.h:45
@ dnnl_not_required
Queried element is not required for given primitive.
Definition: dnnl_types.h:53
oneDNN namespace
Definition: dnnl.hpp:74
oneAPI namespace
Definition: dnnl.hpp:11059
C API.
Descriptor for a batch normalization backward propagation primitive.
Definition: dnnl.hpp:6683
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Constructs a batch normalization descriptor for backward propagation.
Definition: dnnl.hpp:6698
Primitive descriptor for a batch normalization backward propagation primitive.
Definition: dnnl.hpp:6712
primitive_desc(const desc &adesc, const engine &aengine, const batch_normalization_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a batch normalization backward propagation primitive.
Definition: dnnl.hpp:6729
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:6772
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a batch normalization backward propagation primitive from a C A...
Definition: dnnl.hpp:6762
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:6797
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:6778
memory::desc variance_desc() const
Returns memory descriptor for variance.
Definition: dnnl.hpp:6792
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const batch_normalization_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a batch normalization backward propagation primitive.
Definition: dnnl.hpp:6749
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6775
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6769
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:6781
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:6784
memory::desc mean_desc() const
Returns memory descriptor for mean.
Definition: dnnl.hpp:6789
Batch normalization backward propagation primitive.
Definition: dnnl.hpp:6681
batch_normalization_backward()=default
Default constructor. Produces an empty object.
batch_normalization_backward(const primitive_desc &pd)
Constructs a batch normalization backward propagation primitive.
Definition: dnnl.hpp:6806
Descriptor for a batch normalization forward propagation primitive.
Definition: dnnl.hpp:6554
desc(prop_kind aprop_kind, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Constructs a batch normalization descriptor for forward propagation.
Definition: dnnl.hpp:6571
Primitive descriptor for a batch normalization forward propagation primitive.
Definition: dnnl.hpp:6584
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6632
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:6638
memory::desc mean_desc() const
Returns memory descriptor for mean.
Definition: dnnl.hpp:6645
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a batch normalization forward propagation primitive.
Definition: dnnl.hpp:6598
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a batch normalization forward propagation primitive.
Definition: dnnl.hpp:6614
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:6641
memory::desc variance_desc() const
Returns memory descriptor for variance.
Definition: dnnl.hpp:6649
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a batch normalization forward propagation primitive from a C AP...
Definition: dnnl.hpp:6625
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6635
Batch normalization forward propagation primitive.
Definition: dnnl.hpp:6552
batch_normalization_forward()=default
Default constructor. Produces an empty object.
batch_normalization_forward(const primitive_desc &pd)
Constructs a batch normalization forward propagation primitive.
Definition: dnnl.hpp:6677
Descriptor for an elementwise binary operator primitive.
Definition: dnnl.hpp:9832
desc()=default
Default constructor. Produces an empty object.
dnnl_binary_desc_t data
Underlying C operation descriptor.
Definition: dnnl.hpp:9834
desc(algorithm aalgorithm, const memory::desc &src0, const memory::desc &src1, const memory::desc &dst)
Constructs a descriptor for an elementwise binary operator primitive.
Definition: dnnl.hpp:9846
Primitive descriptor for an elementwise binary operator primitive.
Definition: dnnl.hpp:9857
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise binary operator primitive.
Definition: dnnl.hpp:9885
memory::desc src_desc(int idx=0) const
Returns a source memory descriptor.
Definition: dnnl.hpp:9898
memory::desc src0_desc() const
Returns the memory descriptor for source #0.
Definition: dnnl.hpp:9901
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a binary primitive from a C API primitive descriptor that must ...
Definition: dnnl.hpp:9894
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:9907
memory::desc src1_desc() const
Returns the memory descriptor for source #1.
Definition: dnnl.hpp:9904
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise binary operator primitive.
Definition: dnnl.hpp:9870
Elementwise binary operator primitive.
Definition: dnnl.hpp:9830
binary()=default
Default constructor. Produces an empty object.
binary(const primitive_desc &pd)
Constructs an elementwise binary operation primitive.
Definition: dnnl.hpp:9916
Primitive descriptor for a concat primitive.
Definition: dnnl.hpp:3773
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:3842
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a concat primitive from a C API primitive descriptor which must h...
Definition: dnnl.hpp:3835
primitive_desc(const memory::desc &dst, int concat_dimension, const std::vector< memory::desc > &srcs, const engine &aengine, const primitive_attr &attr=primitive_attr())
Constructs a primitive descriptor for an out-of-place concatenation primitive.
Definition: dnnl.hpp:3789
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(int concat_dimension, const std::vector< memory::desc > &srcs, const engine &aengine, const primitive_attr &attr=primitive_attr())
Constructs a primitive descriptor for an out-of-place concatenation primitive.
Definition: dnnl.hpp:3816
memory::desc src_desc(int idx=0) const
Returns a source memory descriptor.
Definition: dnnl.hpp:3839
Tensor concatenation (concat) primitive.
Definition: dnnl.hpp:3771
concat()=default
Default constructor. Produces an empty object.
concat(const primitive_desc &pd)
Constructs a concatenation primitive.
Definition: dnnl.hpp:3850
Descriptor for a convolution backward propagation primitive.
Definition: dnnl.hpp:4314
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated convolution backward propagation primitive.
Definition: dnnl.hpp:4385
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution backward propagation primitive.
Definition: dnnl.hpp:4342
Primitive descriptor for a convolution backward propagation primitive.
Definition: dnnl.hpp:4406
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:4464
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:4467
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a convolution backward propagation primitive from a C API primi...
Definition: dnnl.hpp:4456
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a convolution backward propagation primitive.
Definition: dnnl.hpp:4423
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:4461
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a convolution backward propagation primitive.
Definition: dnnl.hpp:4443
Convolution backward propagation primitive.
Definition: dnnl.hpp:4311
convolution_backward_data()=default
Default constructor. Produces an empty object.
convolution_backward_data(const primitive_desc &pd)
Constructs a convolution backward propagation primitive.
Definition: dnnl.hpp:4476
Descriptor for a convolution weights gradient primitive.
Definition: dnnl.hpp:4482
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution weights gradient primitive without bias.
Definition: dnnl.hpp:4555
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated convolution weights gradient primitive with bias.
Definition: dnnl.hpp:4600
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated convolution weights gradient primitive without bias.
Definition: dnnl.hpp:4647
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution weights gradient primitive with bias.
Definition: dnnl.hpp:4512
Primitive descriptor for a convolution weights gradient primitive.
Definition: dnnl.hpp:4668
memory::desc diff_bias_desc() const
Returns the diff bias memory descriptor.
Definition: dnnl.hpp:4735
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:4724
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a convolution weights gradient primitive.
Definition: dnnl.hpp:4703
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a convolution weights gradient primitive from a C API primitive...
Definition: dnnl.hpp:4716
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:4721
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a convolution weights gradient primitive.
Definition: dnnl.hpp:4684
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:4729
Convolution weights gradient primitive.
Definition: dnnl.hpp:4480
convolution_backward_weights()=default
Default constructor. Produces an empty object.
convolution_backward_weights(const primitive_desc &pd)
Constructs a convolution weights gradient primitive.
Definition: dnnl.hpp:4746
Descriptor for a convolution forward propagation primitive.
Definition: dnnl.hpp:4041
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated convolution forward propagation primitive with bias.
Definition: dnnl.hpp:4169
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution forward propagation primitive without bias.
Definition: dnnl.hpp:4120
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated convolution forward propagation primitive without bias.
Definition: dnnl.hpp:4218
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution forward propagation primitive with bias.
Definition: dnnl.hpp:4074
Primitive descriptor for a convolution forward propagation primitive.
Definition: dnnl.hpp:4239
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a convolution forward propagation primitive.
Definition: dnnl.hpp:4253
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a convolution forward propagation primitive.
Definition: dnnl.hpp:4269
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:4286
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a convolution forward propagation primitive from a C API primit...
Definition: dnnl.hpp:4280
memory::desc bias_desc() const
Returns the bias memory descriptor.
Definition: dnnl.hpp:4298
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:4289
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:4292
Convolution forward propagation primitive.
Definition: dnnl.hpp:4039
convolution_forward(const primitive_desc &pd)
Constructs a convolution forward propagation primitive.
Definition: dnnl.hpp:4307
convolution_forward()=default
Default constructor. Produces an empty object.
Descriptor for a deconvolution backward propagation primitive.
Definition: dnnl.hpp:5027
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution backward propagation primitive.
Definition: dnnl.hpp:5096
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution backward propagation primitive.
Definition: dnnl.hpp:5054
Primitive descriptor for a deconvolution backward propagation primitive.
Definition: dnnl.hpp:5117
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution backward propagation primitive.
Definition: dnnl.hpp:5154
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:5175
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:5178
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:5172
primitive_desc(const desc &adesc, const engine &aengine, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution backward propagation primitive.
Definition: dnnl.hpp:5134
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a deconvolution backward propagation primitive from a C API pri...
Definition: dnnl.hpp:5167
Deconvolution backward propagation primitive.
Definition: dnnl.hpp:5025
deconvolution_backward_data()=default
Default constructor. Produces an empty object.
deconvolution_backward_data(const primitive_desc &pd)
Constructs a deconvolution backward propagation primitive.
Definition: dnnl.hpp:5187
Descriptor for a deconvolution weights gradient primitive.
Definition: dnnl.hpp:5193
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution weights gradient primitive without bias.
Definition: dnnl.hpp:5354
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution weights gradient primitive with bias.
Definition: dnnl.hpp:5308
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution weights gradient primitive without bias.
Definition: dnnl.hpp:5264
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution weights gradient primitive with bias.
Definition: dnnl.hpp:5222
Primitive descriptor for a deconvolution weights gradient primitive.
Definition: dnnl.hpp:5375
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:5430
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:5438
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a deconvolution weights gradient primitive from a C API primiti...
Definition: dnnl.hpp:5425
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution weights gradient primitive.
Definition: dnnl.hpp:5412
primitive_desc(const desc &adesc, const engine &aengine, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution weights gradient primitive.
Definition: dnnl.hpp:5392
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:5433
memory::desc diff_bias_desc() const
Returns the diff bias memory descriptor.
Definition: dnnl.hpp:5441
primitive_desc()=default
Default constructor. Produces an empty object.
Deconvolution weights gradient primitive.
Definition: dnnl.hpp:5191
deconvolution_backward_weights()=default
Default constructor. Produces an empty object.
deconvolution_backward_weights(const primitive_desc &pd)
Constructs a deconvolution weights gradient primitive.
Definition: dnnl.hpp:5452
Descriptor for a deconvolution forward propagation primitive.
Definition: dnnl.hpp:4762
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution forward propagation primitive with bias.
Definition: dnnl.hpp:4887
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution forward propagation primitive without bias.
Definition: dnnl.hpp:4839
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution forward propagation primitive with bias.
Definition: dnnl.hpp:4794
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution forward propagation primitive without bias.
Definition: dnnl.hpp:4935
Primitive descriptor for a deconvolution forward propagation primitive.
Definition: dnnl.hpp:4956
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a deconvolution forward propagation primitive from a C API prim...
Definition: dnnl.hpp:4997
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:5009
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:5003
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution forward propagation primitive.
Definition: dnnl.hpp:4986
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution forward propagation primitive.
Definition: dnnl.hpp:4970
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc bias_desc() const
Returns the bias memory descriptor.
Definition: dnnl.hpp:5012
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:5006
Deconvolution forward propagation primitive.
Definition: dnnl.hpp:4760
deconvolution_forward(const primitive_desc &pd)
Constructs a deconvolution forward propagation primitive.
Definition: dnnl.hpp:5021
deconvolution_forward()=default
Default constructor. Produces an empty object.
Descriptor for an elementwise backward propagation primitive.
Definition: dnnl.hpp:6021
desc(algorithm aalgorithm, const memory::desc &diff_data_desc, const memory::desc &data_desc, float alpha=0, float beta=0)
Constructs a descriptor for an elementwise backward propagation primitive.
Definition: dnnl.hpp:6035
Primitive descriptor for eltwise backward propagation.
Definition: dnnl.hpp:6048
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:6106
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const eltwise_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise backward propagation primitive.
Definition: dnnl.hpp:6065
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const eltwise_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise backward propagation primitive.
Definition: dnnl.hpp:6085
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6103
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:6109
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an eltwise backward propagation primitive from a C API primitiv...
Definition: dnnl.hpp:6098
Elementwise unary operation backward propagation primitive.
Definition: dnnl.hpp:6019
eltwise_backward()=default
Default constructor. Produces an empty object.
eltwise_backward(const primitive_desc &pd)
Constructs an eltwise backward propagation primitive.
Definition: dnnl.hpp:6118
Descriptor for an elementwise forward propagation primitive.
Definition: dnnl.hpp:5928
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &data_desc, float alpha=0, float beta=0)
Constructs a descriptor for an elementwise forward propagation primitive.
Definition: dnnl.hpp:5943
Primitive descriptor for an elementwise forward propagation primitive.
Definition: dnnl.hpp:5956
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise forward propagation primitive.
Definition: dnnl.hpp:5986
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6006
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6003
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise forward propagation primitive.
Definition: dnnl.hpp:5970
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an eltwise forward propagation primitive from a C API primitive...
Definition: dnnl.hpp:5997
Elementwise unary operation forward propagation primitive.
Definition: dnnl.hpp:5926
eltwise_forward(const primitive_desc &pd)
Constructs an eltwise forward propagation primitive.
Definition: dnnl.hpp:6015
eltwise_forward()=default
Default constructor. Produces an empty object.
An execution engine.
Definition: dnnl.hpp:895
static engine query(const primitive_desc &pd)
Returns the engine of a primitive descriptor.
Definition: dnnl.hpp:964
kind
Kinds of engines.
Definition: dnnl.hpp:900
@ gpu
GPU engine.
@ any
An unspecified engine.
@ cpu
CPU engine.
engine(kind akind, size_t index)
Constructs an engine.
Definition: dnnl.hpp:928
engine()=default
Constructs an empty engine.
static size_t get_count(kind akind)
Returns the number of engines of a certain kind.
Definition: dnnl.hpp:919
engine(const handle< dnnl_primitive_desc_t > &pd)
Constructs an engine from the primitive descriptor pd by querying its engine.
Definition: dnnl.hpp:940
kind get_kind() const
Returns the kind of the engine.
Definition: dnnl.hpp:951
oneDNN exception class.
Definition: dnnl.hpp:84
error(dnnl_status_t status, const char *message)
Constructs an instance of an exception class.
Definition: dnnl.hpp:92
static void wrap_c_api(dnnl_status_t status, const char *message)
A convenience function for wrapping calls to C API functions.
Definition: dnnl.hpp:103
const char * what() const noexcept override
Returns the explanatory string.
Definition: dnnl.hpp:96
Descriptor for a GRU backward propagation primitive.
Definition: dnnl.hpp:9079
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for a GRU backward propagation primitive.
Definition: dnnl.hpp:9126
Primitive descriptor for a GRU backward propagation primitive.
Definition: dnnl.hpp:9160
primitive_desc(const desc &adesc, const engine &aengine, const gru_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a GRU backward propagation primitive.
Definition: dnnl.hpp:9176
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition: dnnl.hpp:9262
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:9234
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:9221
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:9218
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition: dnnl.hpp:9267
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:9226
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:9231
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition: dnnl.hpp:9277
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition: dnnl.hpp:9272
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a GRU backward propagation primitive from a C API primitive des...
Definition: dnnl.hpp:9208
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:9213
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:9242
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const gru_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a GRU backward propagation primitive.
Definition: dnnl.hpp:9195
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition: dnnl.hpp:9247
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition: dnnl.hpp:9252
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition: dnnl.hpp:9257
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:9239
GRU backward propagation primitive.
Definition: dnnl.hpp:9077
gru_backward()=default
Default constructor. Produces an empty object.
gru_backward(const primitive_desc &pd)
Constructs a GRU backward propagation primitive.
Definition: dnnl.hpp:9288
Descriptor for a GRU forward propagation primitive.
Definition: dnnl.hpp:8930
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for a GRU forward propagation primitive.
Definition: dnnl.hpp:8965
Primitive descriptor for a GRU forward propagation primitive.
Definition: dnnl.hpp:8988
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a GRU forward propagation primitive.
Definition: dnnl.hpp:9001
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:9046
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:9033
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:9054
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:9041
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:9051
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:9059
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:9062
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:9038
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a GRU forward propagation primitive from a C API primitive desc...
Definition: dnnl.hpp:9027
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a GRU forward propagation primitive.
Definition: dnnl.hpp:9016
GRU forward propagation primitive.
Definition: dnnl.hpp:8928
gru_forward(const primitive_desc &pd)
Constructs a GRU forward propagation primitive.
Definition: dnnl.hpp:9073
gru_forward()=default
Default constructor. Produces an empty object.
A class that provides the destructor for a oneDNN C API handle.
Definition: dnnl.hpp:120
oneDNN C API handle wrapper class.
Definition: dnnl.hpp:136
handle(const handle< T, traits > &)=default
Copy constructor.
bool operator==(const handle< T, traits > &other) const
Equality operator.
Definition: dnnl.hpp:210
bool operator!=(const handle &other) const
Inequality operator.
Definition: dnnl.hpp:220
T get(bool allow_empty=false) const
Returns the underlying C API handle.
Definition: dnnl.hpp:185
handle< T, traits > & operator=(const handle< T, traits > &)=default
Assignment operator.
handle()=default
Constructs an empty handle object.
void reset(T t, bool weak=false)
Resets the handle wrapper objects to wrap a new C API handle.
Definition: dnnl.hpp:176
handle(T t, bool weak=false)
Constructs a handle wrapper object from a C API handle.
Definition: dnnl.hpp:169
handle(handle< T, traits > &&)=default
Move constructor.
handle< T, traits > & operator=(handle< T, traits > &&)=default
Move assignment operator.
Descriptor for an inner product backward propagation primitive.
Definition: dnnl.hpp:7270
desc(const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for an inner product backward propagation primitive.
Definition: dnnl.hpp:7283
Primitive descriptor for an inner product backward propagation primitive.
Definition: dnnl.hpp:7296
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:7357
primitive_desc(const desc &adesc, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an inner product backward propagation primitive.
Definition: dnnl.hpp:7313
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:7354
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an inner product backward propagation primitive from a C API pr...
Definition: dnnl.hpp:7346
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an inner product backward propagation primitive.
Definition: dnnl.hpp:7333
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:7351
primitive_desc()=default
Default constructor. Produces an empty object.
Inner product backward propagation primitive.
Definition: dnnl.hpp:7268
inner_product_backward_data(const primitive_desc &pd)
Constructs an inner product backward propagation primitive.
Definition: dnnl.hpp:7366
inner_product_backward_data()=default
Default constructor. Produces an empty object.
Descriptor for an inner product weights gradient primitive.
Definition: dnnl.hpp:7372
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for an inner product weights update primitive with bias.
Definition: dnnl.hpp:7386
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for an inner product weights update primitive without bias.
Definition: dnnl.hpp:7408
Primitive descriptor for an inner product weights gradient primitive.
Definition: dnnl.hpp:7421
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:7476
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:7479
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:7484
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an inner product weights update primitive from a C API primitiv...
Definition: dnnl.hpp:7471
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an inner product weights update primitive.
Definition: dnnl.hpp:7458
memory::desc diff_bias_desc() const
Returns the diff bias memory descriptor.
Definition: dnnl.hpp:7487
primitive_desc(const desc &adesc, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an inner product weights update primitive.
Definition: dnnl.hpp:7438
Inner product weights gradient primitive.
Definition: dnnl.hpp:7370
inner_product_backward_weights(const primitive_desc &pd)
Constructs an inner product weights gradient primitive.
Definition: dnnl.hpp:7498
inner_product_backward_weights()=default
Default constructor. Produces an empty object.
Descriptor for an inner product forward propagation primitive.
Definition: dnnl.hpp:7145
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)
Constructs a descriptor for an inner product forward propagation primitive without bias.
Definition: dnnl.hpp:7186
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)
Constructs a descriptor for an inner product forward propagation primitive with bias.
Definition: dnnl.hpp:7162
Primitive descriptor for an inner product forward propagation primitive.
Definition: dnnl.hpp:7199
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an inner product forward propagation primitive from a C API pri...
Definition: dnnl.hpp:7240
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:7252
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an inner product forward propagation primitive.
Definition: dnnl.hpp:7213
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:7249
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc bias_desc() const
Returns the bias memory descriptor.
Definition: dnnl.hpp:7255
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:7246
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an inner product forward propagation primitive.
Definition: dnnl.hpp:7229
Inner product forward propagation primitive.
Definition: dnnl.hpp:7143
inner_product_forward(const primitive_desc &pd)
Constructs an inner product forward propagation primitive.
Definition: dnnl.hpp:7264
inner_product_forward()=default
Default constructor. Produces an empty object.
Descriptor for a layer normalization backward propagation primitive.
Definition: dnnl.hpp:6981
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Constructs a descriptor for a layer normalization backward propagation primitive.
Definition: dnnl.hpp:7021
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, const memory::desc &stat_desc, float epsilon, normalization_flags flags)
Constructs a descriptor for a layer normalization backward propagation primitive.
Definition: dnnl.hpp:6997
Primitive descriptor for a layer normalization backward propagation primitive.
Definition: dnnl.hpp:7035
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:7101
memory::desc mean_desc() const
Returns memory descriptor for mean.
Definition: dnnl.hpp:7112
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:7120
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a layer normalization backward propagation primitive from a C A...
Definition: dnnl.hpp:7085
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:7104
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:7098
memory::desc variance_desc() const
Returns memory descriptor for variance.
Definition: dnnl.hpp:7115
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:7095
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const layer_normalization_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a layer normalization backward propagation primitive.
Definition: dnnl.hpp:7072
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:7107
primitive_desc(const desc &adesc, const engine &aengine, const layer_normalization_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a layer normalization backward propagation primitive.
Definition: dnnl.hpp:7052
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:7092
Layer normalization backward propagation primitive.
Definition: dnnl.hpp:6979
layer_normalization_backward(const primitive_desc &pd)
Constructs a layer normalization backward propagation primitive.
Definition: dnnl.hpp:7129
layer_normalization_backward()=default
Default constructor. Produces an empty object.
Descriptor for a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6836
desc(prop_kind aprop_kind, const memory::desc &data_desc, const memory::desc &stat_desc, float epsilon, normalization_flags flags)
Constructs a descriptor for a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6850
desc(prop_kind aprop_kind, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Constructs a descriptor for a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6871
Primitive descriptor for a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6884
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6898
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6935
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6932
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:6941
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a layer normalization forward propagation primitive from a C AP...
Definition: dnnl.hpp:6925
memory::desc variance_desc() const
Returns memory descriptor for variance.
Definition: dnnl.hpp:6947
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6914
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:6938
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc mean_desc() const
Returns memory descriptor for mean.
Definition: dnnl.hpp:6944
Layer normalization forward propagation primitive.
Definition: dnnl.hpp:6834
layer_normalization_forward()=default
Default constructor. Produces an empty object.
layer_normalization_forward(const primitive_desc &pd)
Constructs a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6975
Descriptor for an LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9446
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9494
Primitive descriptor for an LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9528
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:9591
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition: dnnl.hpp:9627
primitive_desc(const desc &adesc, const engine &aengine, const lbr_gru_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9545
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition: dnnl.hpp:9647
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition: dnnl.hpp:9637
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const lbr_gru_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9565
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:9609
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:9596
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:9588
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition: dnnl.hpp:9622
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LBR GRU backward propagation primitive from a C API primitive...
Definition: dnnl.hpp:9578
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:9612
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:9601
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:9604
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:9583
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition: dnnl.hpp:9632
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition: dnnl.hpp:9642
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition: dnnl.hpp:9617
LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9444
lbr_gru_backward(const primitive_desc &pd)
Constructs an LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9658
lbr_gru_backward()=default
Default constructor. Produces an empty object.
Descriptor for an LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9294
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9330
Primitive descriptor for an LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9353
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:9426
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:9405
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9383
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:9421
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:9429
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LBR GRU forward propagation primitive from a C API primitive ...
Definition: dnnl.hpp:9394
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9367
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:9418
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:9400
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:9413
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:9408
LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9292
lbr_gru_forward()=default
Default constructor. Produces an empty object.
lbr_gru_forward(const primitive_desc &pd)
Constructs an LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9440
Descriptor for a logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6428
desc()=default
Default constructor. Produces an empty object.
desc(const memory::desc &diff_data_desc, const memory::desc &data_desc, int logsoftmax_axis)
Constructs a descriptor for a logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6441
Primitive descriptor for a logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6452
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6511
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:6517
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:6514
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a logsoftmax backward propagation primitive from a C API primit...
Definition: dnnl.hpp:6502
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const logsoftmax_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6489
primitive_desc(const desc &adesc, const engine &aengine, const logsoftmax_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6469
Logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6426
logsoftmax_backward(const primitive_desc &pd)
Constructs a logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6526
logsoftmax_backward()=default
Default constructor. Produces an empty object.
Descriptor for a logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6334
desc(prop_kind aprop_kind, const memory::desc &data_desc, int logsoftmax_axis)
Constructs a descriptor for a logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6348
desc()=default
Default constructor. Produces an empty object.
Primitive descriptor for a logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6359
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6413
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6410
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6389
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a logsoftmax forward propagation primitive from a C API primiti...
Definition: dnnl.hpp:6400
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6373
primitive_desc()=default
Default constructor. Produces an empty object.
Logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6332
logsoftmax_forward()=default
Default constructor. Produces an empty object.
logsoftmax_forward(const primitive_desc &pd)
Constructs a logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6422
Descriptor for an LRN backward propagation primitive.
Definition: dnnl.hpp:5564
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, memory::dim local_size, float alpha, float beta, float k=1.f)
Constructs a descriptor for an LRN backward propagation primitive.
Definition: dnnl.hpp:5579
Primitive descriptor for an LRN backward propagation primitive.
Definition: dnnl.hpp:5592
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const lrn_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LRN backward propagation primitive.
Definition: dnnl.hpp:5627
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LRN backward propagation primitive from a C API primitive de...
Definition: dnnl.hpp:5640
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:5648
primitive_desc(const desc &adesc, const engine &aengine, const lrn_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LRN backward propagation primitive.
Definition: dnnl.hpp:5608
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:5651
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:5645
primitive_desc()=default
Default constructor. Produces an empty object.
Local response normalization (LRN) backward propagation primitive.
Definition: dnnl.hpp:5562
lrn_backward(const primitive_desc &pd)
Constructs an LRN backward propagation primitive.
Definition: dnnl.hpp:5660
lrn_backward()=default
Default constructor. Produces an empty object.
Descriptor for an LRN forward propagation primitive.
Definition: dnnl.hpp:5469
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &data_desc, memory::dim local_size, float alpha, float beta, float k=1.f)
Constructs a descriptor for an LRN forward propagation primitive.
Definition: dnnl.hpp:5485
Primitive descriptor for an LRN forward propagation primitive.
Definition: dnnl.hpp:5498
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:5543
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:5546
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an LRN forward propagation primitive.
Definition: dnnl.hpp:5526
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:5549
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an LRN forward propagation primitive.
Definition: dnnl.hpp:5511
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LRN forward propagation primitive from a C API primitive des...
Definition: dnnl.hpp:5537
Local response normalization (LRN) forward propagation primitive.
Definition: dnnl.hpp:5467
lrn_forward()=default
Default constructor. Produces an empty object.
lrn_forward(const primitive_desc &pd)
Constructs an LRN forward propagation primitive.
Definition: dnnl.hpp:5558
Descriptor for an LSTM backward propagation primitive.
Definition: dnnl.hpp:8426
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &weights_peephole_desc, const memory::desc &weights_projection_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_src_iter_c_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_weights_peephole_desc, const memory::desc &diff_weights_projection_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, const memory::desc &diff_dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs an LSTM (with or without peephole and with or without projection) descriptor for backward ...
Definition: dnnl.hpp:8504
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_src_iter_c_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, const memory::desc &diff_dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs an LSTM descriptor for backward propagation using prop_kind, direction,...
Definition: dnnl.hpp:8715
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &weights_peephole_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_src_iter_c_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_weights_peephole_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, const memory::desc &diff_dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs an LSTM (with or without peephole) descriptor for backward propagation using prop_kind,...
Definition: dnnl.hpp:8616
Primitive descriptor for an LSTM backward propagation primitive.
Definition: dnnl.hpp:8756
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:8827
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition: dnnl.hpp:8908
memory::desc diff_weights_projection_desc() const
Returns diff weights projection memory descriptor.
Definition: dnnl.hpp:8893
memory::desc weights_peephole_desc() const
Returns weights peephole memory descriptor.
Definition: dnnl.hpp:8832
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LSTM backward propagation primitive from a C API primitive d...
Definition: dnnl.hpp:8804
memory::desc diff_weights_peephole_desc() const
Returns diff weights peephole memory descriptor.
Definition: dnnl.hpp:8888
memory::desc dst_iter_c_desc() const
Returns destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:8853
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:8809
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:8850
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const lstm_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LSTM backward propagation primitive.
Definition: dnnl.hpp:8772
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition: dnnl.hpp:8863
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:8814
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition: dnnl.hpp:8883
memory::desc weights_projection_desc() const
Returns weights projection memory descriptor.
Definition: dnnl.hpp:8837
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition: dnnl.hpp:8898
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const lstm_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LSTM backward propagation primitive.
Definition: dnnl.hpp:8791
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:8842
memory::desc src_iter_c_desc() const
Returns source recurrent cell state memory descriptor.
Definition: dnnl.hpp:8817
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:8845
memory::desc diff_dst_iter_c_desc() const
Returns diff destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:8913
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition: dnnl.hpp:8868
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition: dnnl.hpp:8903
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:8822
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition: dnnl.hpp:8878
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:8858
memory::desc diff_src_iter_c_desc() const
Returns diff source recurrent cell state memory descriptor.
Definition: dnnl.hpp:8873
LSTM backward propagation primitive.
Definition: dnnl.hpp:8424
lstm_backward()=default
Default constructor. Produces an empty object.
lstm_backward(const primitive_desc &pd)
Constructs an LSTM backward propagation primitive.
Definition: dnnl.hpp:8924
Descriptor for an LSTM forward propagation primitive.
Definition: dnnl.hpp:8109
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LSTM forward propagation primitive.
Definition: dnnl.hpp:8289
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &weights_peephole_desc, const memory::desc &weights_projection_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LSTM (with or without peephole and with or without projection) forward...
Definition: dnnl.hpp:8160
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &weights_peephole_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LSTM (with or without peephole) forward propagation primitive.
Definition: dnnl.hpp:8228
Primitive descriptor for an LSTM forward propagation primitive.
Definition: dnnl.hpp:8315
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:8401
memory::desc weights_peephole_desc() const
Returns weights peephole memory descriptor.
Definition: dnnl.hpp:8383
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:8378
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:8396
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:8409
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an LSTM forward propagation primitive.
Definition: dnnl.hpp:8328
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LSTM forward propagation primitive from a C API primitive de...
Definition: dnnl.hpp:8354
memory::desc dst_iter_c_desc() const
Returns destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:8404
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:8373
memory::desc weights_projection_desc() const
Returns weights projection memory descriptor.
Definition: dnnl.hpp:8388
memory::desc src_iter_c_desc() const
Returns source recurrent cell state memory descriptor.
Definition: dnnl.hpp:8368
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an LSTM forward propagation primitive.
Definition: dnnl.hpp:8343
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:8365
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:8393
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:8360
LSTM forward propagation primitive.
Definition: dnnl.hpp:8107
lstm_forward(const primitive_desc &pd)
Constructs an LSTM forward propagation primitive.
Definition: dnnl.hpp:8420
lstm_forward()=default
Default constructor. Produces an empty object.
Descriptor for a matmul primitive.
Definition: dnnl.hpp:9934
desc(const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)
Constructs a descriptor for a matmul primitive.
Definition: dnnl.hpp:9942
desc(const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)
Constructs a descriptor for a matmul primitive.
Definition: dnnl.hpp:9956
Primitive descriptor for a matmul primitive.
Definition: dnnl.hpp:9966
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a matmul primitive.
Definition: dnnl.hpp:9992
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:10008
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a matmul primitive from a C API primitive descriptor that must ...
Definition: dnnl.hpp:10001
memory::desc bias_desc() const
Returns the bias memory descriptor.
Definition: dnnl.hpp:10013
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:10005
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:10018
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a matmul primitive.
Definition: dnnl.hpp:9978
Matrix multiplication (matmul) primitive.
Definition: dnnl.hpp:9932
matmul(const primitive_desc &pd)
Constructs a matmul primitive.
Definition: dnnl.hpp:10026
matmul()=default
Default constructor. Produces an empty object.
A memory descriptor.
Definition: dnnl.hpp:2062
desc(const dims &adims, data_type adata_type, format_tag aformat_tag, bool allow_empty=false)
Constructs a memory descriptor.
Definition: dnnl.hpp:2086
desc()
Constructs a zero (empty) memory descriptor.
Definition: dnnl.hpp:2069
bool operator!=(const desc &other) const
An inequality operator.
Definition: dnnl.hpp:2297
desc permute_axes(const std::vector< int > &permutation, bool allow_empty=false) const
Constructs a memory descriptor by permuting axes in an existing one.
Definition: dnnl.hpp:2248
desc submemory_desc(const dims &adims, const dims &offsets, bool allow_empty=false) const
Constructs a memory descriptor for a region inside an area described by this memory descriptor.
Definition: dnnl.hpp:2144
bool operator==(const desc &other) const
An equality operator.
Definition: dnnl.hpp:2289
bool is_zero() const
Checks whether the memory descriptor is zero (empty).
Definition: dnnl.hpp:2283
memory::dims dims() const
Returns dimensions of the memory descriptor.
Definition: dnnl.hpp:2270
memory::data_type data_type() const
Returns the data type of the memory descriptor.
Definition: dnnl.hpp:2262
desc reshape(const dims &adims, bool allow_empty=false) const
Constructs a memory descriptor by reshaping an existing one.
Definition: dnnl.hpp:2200
desc(const dims &adims, data_type adata_type, const dims &strides, bool allow_empty=false)
Constructs a memory descriptor by strides.
Definition: dnnl.hpp:2114
size_t get_size() const
Returns size of the memory descriptor in bytes.
Definition: dnnl.hpp:2278
desc(const dnnl_memory_desc_t &data)
Constructs a memory descriptor from a C API data structure.
Definition: dnnl.hpp:2131
dnnl_memory_desc_t data
The underlying C API data structure.
Definition: dnnl.hpp:2065
Memory object.
Definition: dnnl.hpp:1134
void unmap_data(void *mapped_ptr) const
Unmaps a memory object and writes back any changes made to the previously mapped memory buffer.
Definition: dnnl.hpp:2463
T * map_data() const
Maps a memory object and returns a host-side pointer to a memory buffer with a copy of its contents.
Definition: dnnl.hpp:2446
static void validate_dims(const std::vector< T > &v, int min_size=0)
Helper function that validates that an std::vector of dimensions can be safely converted to the C API...
Definition: dnnl.hpp:1150
memory()=default
Default constructor.
dnnl_dim_t dim
Integer type for representing dimension sizes and indices.
Definition: dnnl.hpp:1138
memory(const desc &md, const engine &aengine, void *handle)
Constructs a memory object.
Definition: dnnl.hpp:2330
void set_data_handle(void *handle, const stream &astream) const
Sets the underlying memory buffer.
Definition: dnnl.hpp:2402
void * get_data_handle() const
Returns the underlying memory buffer.
Definition: dnnl.hpp:2367
format_tag
Memory format tag specification.
Definition: dnnl.hpp:1237
data_type
Data type specification.
Definition: dnnl.hpp:1156
@ undef
Undefined data type (used for empty memory descriptors).
engine get_engine() const
Returns the associated engine.
Definition: dnnl.hpp:2356
format_kind
Memory format kind.
Definition: dnnl.hpp:1181
memory(const desc &md, const engine &aengine)
Constructs a memory object.
Definition: dnnl.hpp:2344
void set_data_handle(void *handle) const
Sets the underlying memory buffer.
Definition: dnnl.hpp:2418
static size_t data_type_size(data_type adata_type)
Returns size of data type in bytes.
Definition: dnnl.hpp:1176
desc get_desc() const
Returns the associated memory descriptor.
Definition: dnnl.hpp:2348
std::vector< dim > dims
Vector of dimensions.
Definition: dnnl.hpp:1141
Descriptor for a pooling backward propagation primitive.
Definition: dnnl.hpp:5788
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a pooling backward propagation primitive.
Definition: dnnl.hpp:5812
Primitive descriptor for a pooling backward propagation primitive.
Definition: dnnl.hpp:5831
primitive_desc(const desc &adesc, const engine &aengine, const pooling_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a pooling backward propagation primitive.
Definition: dnnl.hpp:5847
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:5887
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:5890
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const pooling_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a pooling backward propagation primitive.
Definition: dnnl.hpp:5866
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:5884
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a pooling backward propagation primitive from a C API primitive...
Definition: dnnl.hpp:5879
Pooling backward propagation primitive.
Definition: dnnl.hpp:5786
pooling_backward()=default
Default constructor. Produces an empty object.
pooling_backward(const primitive_desc &pd)
Constructs a pooling backward propagation primitive.
Definition: dnnl.hpp:5899
Descriptor for a pooling forward propagation primitive.
Definition: dnnl.hpp:5676
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a pooling forward propagation primitive.
Definition: dnnl.hpp:5703
Primitive descriptor for a pooling forward propagation primitive.
Definition: dnnl.hpp:5722
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a pooling forward propagation primitive.
Definition: dnnl.hpp:5750
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:5770
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:5767
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a pooling forward propagation primitive from a C API primitive ...
Definition: dnnl.hpp:5761
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:5773
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a pooling forward propagation primitive.
Definition: dnnl.hpp:5735
Pooling forward propagation primitive.
Definition: dnnl.hpp:5674
pooling_forward(const primitive_desc &pd)
Constructs a pooling forward propagation primitive.
Definition: dnnl.hpp:5782
pooling_forward()=default
Default constructor. Produces an empty object.
Descriptor for a pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10433
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &dilation, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10459
Primitive descriptor for a pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10480
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:10539
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:10536
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a pooling v2 (dilated pooling) backward propagation primitive f...
Definition: dnnl.hpp:10531
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const pooling_v2_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10497
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const pooling_v2_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10517
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:10542
Pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10431
pooling_v2_backward(const primitive_desc &pd)
Constructs a pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10552
pooling_v2_backward()=default
Default constructor. Produces an empty object.
Descriptor for a pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10312
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &dilation, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10341
Primitive descriptor for a pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10363
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:10417
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:10414
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:10411
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10393
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a pooling v2 (dilated pooling) forward propagation primitive fr...
Definition: dnnl.hpp:10405
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10377
Pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10310
pooling_v2_forward()=default
Default constructor. Produces an empty object.
pooling_v2_forward(const primitive_desc &pd)
Constructs a pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10427
Post-ops.
Definition: dnnl.hpp:2528
void get_params_dw_k3s1p1(int index, memory::data_type &weights_data_type, memory::data_type &bias_data_type, memory::data_type &dst_data_type, int &mask, std::vector< float > &scales) const
Returns the parameters of a depthwise post-op with stride 1.
Definition: dnnl.hpp:2704
void get_params_binary(int index, algorithm &aalgorithm, memory::desc &src1_desc) const
Returns the parameters of a binary post-op.
Definition: dnnl.hpp:2840
void get_params_sum(int index, float &scale, memory::data_type &data_type) const
Returns the parameters of an accumulation (sum) post-op.
Definition: dnnl.hpp:2605
void append_eltwise(float scale, algorithm aalgorithm, float alpha, float beta)
Appends an elementwise post-op.
Definition: dnnl.hpp:2627
void append_binary(algorithm aalgorithm, const memory::desc &src1_desc)
Appends a binary post-op.
Definition: dnnl.hpp:2829
void append_dw_k3s1p1(memory::data_type weights_data_type, memory::data_type bias_data_type, memory::data_type dst_data_type, int mask, const std::vector< float > &scales)
Appends a depthwise post-op convolution with stride 1.
Definition: dnnl.hpp:2678
primitive::kind kind(int index) const
Returns the primitive kind of post-op at entry with a certain index.
Definition: dnnl.hpp:2545
int len() const
Returns the number of post-ops entries.
Definition: dnnl.hpp:2540
void append_dw_k3s2p1(memory::data_type weights_data_type, memory::data_type bias_data_type, memory::data_type dst_data_type, int mask, const std::vector< float > &scales)
Appends a depthwise post-op convolution with stride 2.
Definition: dnnl.hpp:2763
post_ops()
Constructs an empty sequence of post-ops.
Definition: dnnl.hpp:2532
void get_params_dw_k3s2p1(int index, memory::data_type &weights_data_type, memory::data_type &bias_data_type, memory::data_type &dst_data_type, int &mask, std::vector< float > &scales) const
Returns the parameters of a depthwise post-op with stride 2.
Definition: dnnl.hpp:2789
void get_params_eltwise(int index, float &scale, algorithm &aalgorithm, float &alpha, float &beta) const
Returns parameters of an elementwise post-op.
Definition: dnnl.hpp:2641
void get_params_sum(int index, float &scale) const
Returns the parameters of an accumulation (sum) post-op.
Definition: dnnl.hpp:2595
void append_sum(float scale=1.f, memory::data_type data_type=memory::data_type::undef)
Appends an accumulation (sum) post-op.
Definition: dnnl.hpp:2580
Descriptor for a PReLU backward propagation primitive.
Definition: dnnl.hpp:10656
desc(const memory::desc &data_desc, const memory::desc &weight_desc, const memory::desc &diff_data_desc, const memory::desc &diff_weights_desc)
Constructs a descriptor for a PReLU backward propagation primitive.
Definition: dnnl.hpp:10667
Primitive descriptor for a PReLU backward propagation primitive.
Definition: dnnl.hpp:10680
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:10735
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:10738
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const prelu_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a PReLU backward propagation primitive.
Definition: dnnl.hpp:10717
primitive_desc(const desc &adesc, const engine &aengine, const prelu_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a PReLU backward propagation primitive.
Definition: dnnl.hpp:10697
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:10741
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a PReLU backward propagation primitive from a C API primitive d...
Definition: dnnl.hpp:10730
primitive_desc()=default
Default constructor. Produces an empty object.
PReLU backward propagation primitive.
Definition: dnnl.hpp:10654
prelu_backward()=default
Default constructor. Produces an empty object.
prelu_backward(const primitive_desc &pd)
Constructs a PReLU backward propagation primitive.
Definition: dnnl.hpp:10750
Descriptor for a PReLU forward propagation primitive.
Definition: dnnl.hpp:10569
desc(prop_kind aprop_kind, const memory::desc &data_desc, const memory::desc &weight_desc)
Constructs a descriptor for a PReLU forward propagation primitive.
Definition: dnnl.hpp:10580
Primitive descriptor for a PReLU forward propagation primitive.
Definition: dnnl.hpp:10591
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:10641
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:10638
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a PReLU forward propagation primitive.
Definition: dnnl.hpp:10605
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a PReLU forward propagation primitive.
Definition: dnnl.hpp:10621
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a PReLU forward propagation primitive from a C API primitive de...
Definition: dnnl.hpp:10632
PReLU forward propagation primitive.
Definition: dnnl.hpp:10567
prelu_forward(const primitive_desc &pd)
Constructs a PReLU forward propagation primitive.
Definition: dnnl.hpp:10650
prelu_forward()=default
Default constructor. Produces an empty object.
Primitive attributes.
Definition: dnnl.hpp:2864
void get_zero_points(int arg, int &mask, std::vector< int32_t > &zero_points) const
Returns zero points correspondence mask and values.
Definition: dnnl.hpp:3031
const post_ops get_post_ops() const
Returns post-ops previously set via set_post_ops().
Definition: dnnl.hpp:3077
void set_rnn_data_qparams(float scale, float shift)
Sets quantization scale and shift parameters for RNN data tensors.
Definition: dnnl.hpp:3132
void get_rnn_weights_qparams(int &mask, std::vector< float > &scales)
Returns the quantization scaling factors for RNN weights tensors.
Definition: dnnl.hpp:3210
void get_rnn_data_qparams(float &scale, float &shift)
Returns the quantization scale and shift parameters for RNN data tensors.
Definition: dnnl.hpp:3148
void set_output_scales(int mask, const std::vector< float > &scales)
Sets output scaling factors correspondence mask and values.
Definition: dnnl.hpp:2966
void get_rnn_weights_projection_qparams(int &mask, std::vector< float > &scales)
Returns the quantization scaling factors for RNN projection weights tensors.
Definition: dnnl.hpp:3279
void set_rnn_weights_qparams(int mask, const std::vector< float > &scales)
Sets quantization scaling factors for RNN weights tensors.
Definition: dnnl.hpp:3184
void set_rnn_weights_projection_qparams(int mask, const std::vector< float > &scales)
Sets quantization scaling factors for RNN projection weights tensors.
Definition: dnnl.hpp:3251
void set_scratchpad_mode(scratchpad_mode mode)
Sets scratchpad mode.
Definition: dnnl.hpp:2895
void set_scales(int arg, int mask, const std::vector< float > &scales)
Sets scaling factors for primitive operations for a given memory argument.
Definition: dnnl.hpp:3014
void get_scales(int arg, int &mask, std::vector< float > &scales) const
Returns scaling factors correspondence mask and values for a given memory argument.
Definition: dnnl.hpp:2984
void get_output_scales(int &mask, std::vector< float > &scales) const
Returns output scaling factors correspondence mask and values.
Definition: dnnl.hpp:2910
primitive_attr(dnnl_primitive_attr_t attr)
Creates primitive attributes from a C API dnnl_primitive_attr_t handle.
Definition: dnnl.hpp:2880
void set_post_ops(const post_ops ops)
Sets post-ops.
Definition: dnnl.hpp:3094
primitive_attr()
Constructs default (empty) primitive attributes.
Definition: dnnl.hpp:2868
void set_zero_points(int arg, int mask, const std::vector< int32_t > &zero_points)
Sets zero points for primitive operations for a given memory argument.
Definition: dnnl.hpp:3066
scratchpad_mode get_scratchpad_mode() const
Returns the scratchpad mode.
Definition: dnnl.hpp:2884
Base class for all primitive descriptors.
Definition: dnnl.hpp:3303
primitive_attr get_primitive_attr() const
Returns the primitive attributes.
Definition: dnnl.hpp:3487
memory::desc diff_weights_desc(int idx) const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:3413
primitive_desc_base()=default
Default constructor. Produces an empty object.
engine get_engine() const
Returns the engine of the primitive descriptor.
Definition: dnnl.hpp:3311
memory::desc query_md(query what, int idx=0) const
Returns a memory descriptor.
Definition: dnnl.hpp:3348
memory::desc dst_desc(int idx) const
Returns a destination memory descriptor.
Definition: dnnl.hpp:3377
memory::desc diff_dst_desc(int idx) const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:3404
memory::desc scratchpad_desc() const
Returns the scratchpad memory descriptor.
Definition: dnnl.hpp:3469
void reset_with_clone(const_dnnl_primitive_desc_t pd)
Resets the value of the handle to a clone of a C API primitive descriptor.
Definition: dnnl.hpp:3511
dnnl::primitive::kind get_kind() const
Returns the kind of the primitive descriptor.
Definition: dnnl.hpp:3499
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:3448
memory::desc diff_src_desc(int idx) const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:3395
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:3436
primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2)
Constructs a primitive descriptor base object from a clone of a C API primitive descriptor after veri...
Definition: dnnl.hpp:3563
primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind)
Constructs a primitive descriptor base object from a clone of a C API primitive descriptor after veri...
Definition: dnnl.hpp:3531
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:3442
memory::desc weights_desc(int idx) const
Returns a weights memory descriptor.
Definition: dnnl.hpp:3386
memory::dim query_s64(query what) const
Returns a memory::dim value (same as int64_t).
Definition: dnnl.hpp:3327
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:3460
engine scratchpad_engine() const
Returns the engine on which the scratchpad memory is located.
Definition: dnnl.hpp:3475
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:3430
const char * impl_info_str() const
Returns implementation name.
Definition: dnnl.hpp:3315
memory::desc src_desc(int idx) const
Returns a source memory descriptor.
Definition: dnnl.hpp:3368
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:3424
primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind, dnnl::prop_kind aprop_kind)
Constructs a primitive descriptor base object from a clone of a C API primitive descriptor after veri...
Definition: dnnl.hpp:3546
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:3454
A base class for descriptors of all primitives that have an operation descriptor and that support ite...
Definition: dnnl.hpp:3957
primitive_desc(const_dnnl_op_desc_t desc, const primitive_attr *attr, const engine &aengine, const_dnnl_primitive_desc_t hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor.
Definition: dnnl.hpp:3984
bool next_impl()
Advances the primitive iterator to the next implementation.
Definition: dnnl.hpp:4002
Base class for all computational primitives.
Definition: dnnl.hpp:269
void execute(const stream &astream, const std::unordered_map< int, memory > &args) const
Executes computations specified by the primitive in a specified stream.
primitive()=default
Default constructor. Constructs an empty object.
primitive(const primitive_desc &pd)
Constructs a primitive from a primitive descriptor.
kind
Kinds of primitives supported by the library.
Definition: dnnl.hpp:271
@ deconvolution
A deconvolution primitive.
@ pooling_v2
A pooling version 2 primitive.
@ inner_product
An inner product primitive.
@ logsoftmax
A logsoftmax primitive.
@ layer_normalization
A layer normalization primitive.
@ pooling
A pooling primitive.
@ resampling
A resampling primitive.
@ shuffle
A shuffle primitive.
@ rnn
An RNN primitive.
@ batch_normalization
A batch normalization primitive.
@ lrn
An LRN primitive.
@ prelu
A PReLU primitive.
@ eltwise
An element-wise primitive.
@ convolution
A convolution primitive.
@ softmax
A softmax primitive.
@ undef
Undefined primitive.
primitive(const_dnnl_primitive_desc_t c_pd)
Constructs a primitive from a C API primitive descriptor.
Descriptor for reduction.
Definition: dnnl.hpp:10767
desc()=default
Default constructor. Produces an empty object.
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, float p, float eps)
Constructs a descriptor for a reduction primitive using algorithm specific parameters,...
Definition: dnnl.hpp:10790
Primitive descriptor for a reduction primitive.
Definition: dnnl.hpp:10800
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:10839
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:10842
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a reduction primitive from a C API primitive descriptor that mu...
Definition: dnnl.hpp:10835
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a reduction primitive.
Definition: dnnl.hpp:10826
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a reduction primitive.
Definition: dnnl.hpp:10812
Reduction.
Definition: dnnl.hpp:10765
reduction(const primitive_desc &pd)
Constructs a reduction primitive.
Definition: dnnl.hpp:10850
reduction()=default
Default constructor. Produces an empty object.
Primitive descriptor for a reorder primitive.
Definition: dnnl.hpp:3627
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:3712
primitive_desc(const engine &src_engine, const memory::desc &src_md, const engine &dst_engine, const memory::desc &dst_md, const primitive_attr &attr=primitive_attr(), bool allow_empty=false)
Constructs a primitive descriptor for a reorder primitive.
Definition: dnnl.hpp:3650
primitive_desc(const memory &src, const memory &dst, const primitive_attr &attr=primitive_attr(), bool allow_empty=false)
Constructs a primitive descriptor for a reorder primitive.
Definition: dnnl.hpp:3676
engine get_src_engine() const
Returns the engine on which the source memory is allocated.
Definition: dnnl.hpp:3701
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a reorder primitive from a C API primitive descriptor which must ...
Definition: dnnl.hpp:3696
engine get_dst_engine() const
Returns the engine on which the destination memory is allocated.
Definition: dnnl.hpp:3707
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:3715
Reorder primitive.
Definition: dnnl.hpp:3625
reorder(const primitive_desc &pd)
Constructs a reorder primitive.
Definition: dnnl.hpp:3723
void execute(const stream &astream, memory &src, memory &dst) const
Executes the reorder primitive.
Definition: dnnl.hpp:3744
reorder()=default
Default constructor. Produces an empty object.
reorder(const memory &src, const memory &dst, const primitive_attr &attr=primitive_attr())
Constructs a reorder primitive that would reorder data between memory objects having the same memory ...
Definition: dnnl.hpp:3732
Descriptor for a resampling backward propagation primitive.
Definition: dnnl.hpp:10188
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for a resampling backward propagation primitive using source and destination ...
Definition: dnnl.hpp:10199
desc(algorithm aalgorithm, const std::vector< float > &factors, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for a resampling backward propagation primitive.
Definition: dnnl.hpp:10216
Primitive descriptor for a resampling backward propagation primitive.
Definition: dnnl.hpp:10229
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a resampling backward propagation primitive from a C API primit...
Definition: dnnl.hpp:10279
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:10284
primitive_desc(const desc &adesc, const engine &aengine, const resampling_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a resampling backward propagation primitive.
Definition: dnnl.hpp:10246
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:10287
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const resampling_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a resampling backward propagation primitive.
Definition: dnnl.hpp:10266
Resampling backward propagation primitive.
Definition: dnnl.hpp:10186
resampling_backward(const primitive_desc &pd)
Constructs a resampling backward propagation primitive.
Definition: dnnl.hpp:10296
resampling_backward()=default
Default constructor. Produces an empty object.
Descriptor for resampling forward propagation.
Definition: dnnl.hpp:10044
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc)
Constructs a descriptor for a resampling forward propagation primitive using source and destination m...
Definition: dnnl.hpp:10062
desc(prop_kind aprop_kind, algorithm aalgorithm, const std::vector< float > &factors, const memory::desc &src_desc)
Constructs a descriptor for a resampling forward propagation primitive using source memory descriptor...
Definition: dnnl.hpp:10082
desc(prop_kind aprop_kind, algorithm aalgorithm, const std::vector< float > &factors, const memory::desc &src_desc, const memory::desc &dst_desc)
Constructs a descriptor for a resampling forward propagation primitive.
Definition: dnnl.hpp:10109
Primitive descriptor for a resampling forward propagation primitive.
Definition: dnnl.hpp:10123
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a resampling forward propagation primitive.
Definition: dnnl.hpp:10137
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:10173
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:10170
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a resampling forward propagation primitive from a C API primiti...
Definition: dnnl.hpp:10164
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a resampling forward propagation primitive.
Definition: dnnl.hpp:10153
primitive_desc()=default
Default constructor. Produces an empty object.
Resampling forward propagation.
Definition: dnnl.hpp:10042
resampling_forward()=default
Default constructor. Produces an empty object.
resampling_forward(const primitive_desc &pd)
Constructs a resampling forward propagation primitive.
Definition: dnnl.hpp:10182
Base class for primitive descriptors for RNN primitives.
Definition: dnnl.hpp:7512
memory::desc dst_iter_c_desc() const
Returns destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:7597
memory::desc weights_peephole_desc() const
Returns weights peephole memory descriptor.
Definition: dnnl.hpp:7563
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition: dnnl.hpp:7623
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:7551
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:7557
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition: dnnl.hpp:7611
memory::desc diff_dst_iter_c_desc() const
Returns diff destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:7671
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition: dnnl.hpp:7629
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition: dnnl.hpp:7665
rnn_primitive_desc_base()=default
Default constructor. Produces an empty object.
memory::desc diff_src_iter_c_desc() const
Returns diff source recurrent cell state memory descriptor.
Definition: dnnl.hpp:7617
rnn_primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::prop_kind aprop_kind, dnnl::algorithm cell_kind)
Constructs an RNN primitive descriptor base from a C API primitive descriptor while checking that it ...
Definition: dnnl.hpp:7525
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition: dnnl.hpp:7651
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:7583
memory::desc diff_weights_projection_desc() const
Returns diff weights projection memory descriptor.
Definition: dnnl.hpp:7642
memory::desc src_iter_c_desc() const
Returns source recurrent cell state memory descriptor.
Definition: dnnl.hpp:7545
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:7539
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:7577
memory::desc weights_projection_desc() const
Returns weights projection memory descriptor.
Definition: dnnl.hpp:7569
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:7531
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition: dnnl.hpp:7657
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:7591
memory::desc diff_weights_peephole_desc() const
Returns diff weights peephole memory descriptor.
Definition: dnnl.hpp:7635
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition: dnnl.hpp:7603
Descriptor for a shuffle backward propagation primitive.
Definition: dnnl.hpp:9749
desc(const memory::desc &diff_data_desc, int axis, int group_size)
Constructs a descriptor for a shuffle backward propagation primitive.
Definition: dnnl.hpp:9759
Primitive descriptor for a shuffle backward propagation primitive.
Definition: dnnl.hpp:9768
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a shuffle backward propagation primitive from a C API primitive...
Definition: dnnl.hpp:9799
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:9804
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const shuffle_forward::primitive_desc &hint_fwd_pd, const primitive_attr &attr=primitive_attr(), bool allow_empty=false)
Constructs a primitive descriptor for a shuffle backward propagation primitive.
Definition: dnnl.hpp:9786
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:9807
Shuffle backward propagation primitive.
Definition: dnnl.hpp:9746
shuffle_backward()=default
Default constructor. Produces an empty object.
shuffle_backward(const primitive_desc &pd)
Constructs a shuffle backward propagation primitive.
Definition: dnnl.hpp:9816
Descriptor for a shuffle forward propagation primitive.
Definition: dnnl.hpp:9674
desc(prop_kind aprop_kind, const memory::desc &data_desc, int axis, int group_size)
Constructs a descriptor for a shuffle forward propagation primitive.
Definition: dnnl.hpp:9686
Primitive descriptor for a shuffle forward propagation primitive.
Definition: dnnl.hpp:9697
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:9733
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:9730
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a shuffle forward propagation primitive from a C API primitive ...
Definition: dnnl.hpp:9724
primitive_desc(const desc &adesc, const engine &aengine, const primitive_attr &attr=primitive_attr(), bool allow_empty=false)
Constructs a primitive descriptor for a shuffle forward propagation primitive.
Definition: dnnl.hpp:9712
primitive_desc()=default
Default constructor. Produces an empty object.
Shuffle forward propagation primitive.
Definition: dnnl.hpp:9672
shuffle_forward()=default
Default constructor. Produces an empty object.
shuffle_forward(const primitive_desc &pd)
Constructs a shuffle forward propagation primitive.
Definition: dnnl.hpp:9742
Descriptor for a softmax backward propagation primitive.
Definition: dnnl.hpp:6224
desc(const memory::desc &diff_data_desc, const memory::desc &data_desc, int softmax_axis)
Constructs a descriptor for a softmax backward propagation primitive.
Definition: dnnl.hpp:6237
desc()=default
Default constructor. Produces an empty object.
Primitive descriptor for a softmax backward propagation primitive.
Definition: dnnl.hpp:6248
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a softmax backward propagation primitive from a C API primitive...
Definition: dnnl.hpp:6298
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const softmax_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a softmax backward propagation primitive.
Definition: dnnl.hpp:6285
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:6309
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:6306
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const softmax_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a softmax backward propagation primitive.
Definition: dnnl.hpp:6265
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6303
Softmax backward propagation primitive.
Definition: dnnl.hpp:6222
softmax_backward()=default
Default constructor. Produces an empty object.
softmax_backward(const primitive_desc &pd)
Constructs a softmax backward propagation primitive.
Definition: dnnl.hpp:6318
Descriptor for a softmax forward propagation primitive.
Definition: dnnl.hpp:6134
desc(prop_kind aprop_kind, const memory::desc &data_desc, int softmax_axis)
Constructs a descriptor for a softmax forward propagation primitive.
Definition: dnnl.hpp:6148
desc()=default
Default constructor. Produces an empty object.
Primitive descriptor for a softmax forward propagation primitive.
Definition: dnnl.hpp:6159
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6206
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a softmax forward propagation primitive.
Definition: dnnl.hpp:6173
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6209
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a softmax forward propagation primitive from a C API primitive ...
Definition: dnnl.hpp:6200
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a softmax forward propagation primitive.
Definition: dnnl.hpp:6189
primitive_desc()=default
Default constructor. Produces an empty object.
Softmax forward propagation primitive.
Definition: dnnl.hpp:6132
softmax_forward()=default
Default constructor. Produces an empty object.
softmax_forward(const primitive_desc &pd)
Constructs a softmax forward propagation primitive.
Definition: dnnl.hpp:6218
An execution stream.
Definition: dnnl.hpp:1011
engine get_engine() const
Returns the associated engine.
Definition: dnnl.hpp:1042
stream & wait()
Waits for all primitives executing in the stream to finish.
Definition: dnnl.hpp:1051
stream(const engine &aengine, flags aflags=flags::default_flags)
Constructs a stream for the specified engine and with behavior controlled by the specified flags.
Definition: dnnl.hpp:1033
flags
Stream flags. Can be combined using the bitwise OR operator.
Definition: dnnl.hpp:1015
@ out_of_order
Out-of-order execution.
@ default_flags
Default stream configuration.
@ in_order
In-order execution.
stream()=default
Constructs an empty stream.
Primitive descriptor for a sum primitive.
Definition: dnnl.hpp:3866
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:3939
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc src_desc(int idx=0) const
Returns a source memory descriptor.
Definition: dnnl.hpp:3936
primitive_desc(const memory::desc &dst, const std::vector< float > &scales, const std::vector< memory::desc > &srcs, const engine &aengine, const primitive_attr &attr=primitive_attr())
Constructs a primitive descriptor for a sum primitive.
Definition: dnnl.hpp:3880
primitive_desc(const std::vector< float > &scales, const std::vector< memory::desc > &srcs, const engine &aengine, const primitive_attr &attr=primitive_attr())
Constructs a primitive descriptor for a sum primitive.
Definition: dnnl.hpp:3910
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a sum primitive from a C API primitive descriptor which must have...
Definition: dnnl.hpp:3932
Out-of-place summation (sum) primitive.
Definition: dnnl.hpp:3864
sum()=default
Default constructor. Produces an empty object.
sum(const primitive_desc &pd)
Constructs a sum primitive.
Definition: dnnl.hpp:3947
Descriptor for a vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:7882
desc(prop_kind aprop_kind, algorithm activation, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, rnn_flags flags=rnn_flags::undef, float alpha=0.0f, float beta=0.0f)
Constructs a descriptor for a vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:7937
Primitive descriptor for an RNN backward propagation primitive.
Definition: dnnl.hpp:7973
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const vanilla_rnn_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:8010
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:8033
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition: dnnl.hpp:8087
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:8049
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition: dnnl.hpp:8067
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition: dnnl.hpp:8077
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const vanilla_rnn_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:7990
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition: dnnl.hpp:8082
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:8041
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a vanilla RNN backward propagation primitive from a C API primi...
Definition: dnnl.hpp:8023
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:8036
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:8046
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:8054
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition: dnnl.hpp:8092
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition: dnnl.hpp:8062
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:8028
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition: dnnl.hpp:8072
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:8057
Vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:7880
vanilla_rnn_backward(const primitive_desc &pd)
Constructs a vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:8103
vanilla_rnn_backward()=default
Default constructor. Produces an empty object.
Descriptor for a vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7721
desc(prop_kind aprop_kind, algorithm activation, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, rnn_flags flags=rnn_flags::undef, float alpha=0.0f, float beta=0.0f)
Constructs a descriptor for a vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7764
Primitive descriptor for a vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7789
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7803
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a vanilla RNN forward propagation primitive from a C API primit...
Definition: dnnl.hpp:7830
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:7836
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:7841
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:7849
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:7844
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:7865
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:7862
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7819
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:7857
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:7854
Vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7719
vanilla_rnn_forward()=default
Default constructor. Produces an empty object.
vanilla_rnn_forward(const primitive_desc &pd)
Constructs a vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7876
A descriptor of a Batch Normalization operation.
Definition: dnnl_types.h:1942
A descriptor of a binary operation.
Definition: dnnl_types.h:2150
A descriptor of a convolution operation.
Definition: dnnl_types.h:1646
A descriptor of an element-wise operation.
Definition: dnnl_types.h:1721
An opaque structure to describe an engine.
A descriptor of an inner product operation.
Definition: dnnl_types.h:2012
A descriptor of a Layer Normalization operation.
Definition: dnnl_types.h:1975
A descriptor of a Local Response Normalization (LRN) operation.
Definition: dnnl_types.h:1911
A descriptor of a matrix multiplication operation.
Definition: dnnl_types.h:2176
Memory descriptor.
Definition: dnnl_types.h:1557
dnnl_data_type_t data_type
Data type of the tensor elements.
Definition: dnnl_types.h:1577
dnnl_dims_t dims
Dimensions in the following order:
Definition: dnnl_types.h:1574
int ndims
Number of dimensions.
Definition: dnnl_types.h:1559
An opaque structure to describe a memory.
A descriptor of a pooling operation.
Definition: dnnl_types.h:1811
A descriptor of a pooling operation.
Definition: dnnl_types.h:1849
An opaque structure for a chain of post operations.
An opaque structure for primitive descriptor attributes.
An opaque structure to describe a primitive descriptor iterator.
An opaque structure to describe a primitive descriptor.
An opaque structure to describe a primitive.
A descriptor of a reduction operation.
Definition: dnnl_types.h:2226
A descriptor of a resampling operation.
Definition: dnnl_types.h:2198
A descriptor for an RNN operation.
Definition: dnnl_types.h:2068
A descriptor of a shuffle operation.
Definition: dnnl_types.h:1699
A descriptor of a Softmax operation.
Definition: dnnl_types.h:1781
An opaque structure to describe an execution stream.
Structure containing version information as per Semantic Versioning.
Definition: dnnl_types.h:2759