_tensor_str.py

import math
import textwrap
from typing import Optional

import torch
from torch import inf


class __PrinterOptions:
    precision: int = 4
    threshold: float = 1000
    edgeitems: int = 3
    linewidth: int = 80
    sci_mode: Optional[bool] = None


PRINT_OPTS = __PrinterOptions()


# We could use **kwargs, but this will give better docs
def set_printoptions(
    precision=None,
    threshold=None,
    edgeitems=None,
    linewidth=None,
    profile=None,
    sci_mode=None,
):
    r"""Set options for printing. Items shamelessly taken from NumPy

    Args:
        precision: Number of digits of precision for floating point output
            (default = 4).
        threshold: Total number of array elements which trigger summarization
            rather than full `repr` (default = 1000).
        edgeitems: Number of array items in summary at beginning and end of
            each dimension (default = 3).
        linewidth: The number of characters per line for the purpose of
            inserting line breaks (default = 80). Thresholded matrices will
            ignore this parameter.
        profile: Sane defaults for pretty printing. Can override with any of
            the above options. (any one of `default`, `short`, `full`)
        sci_mode: Enable (True) or disable (False) scientific notation. If
            None (default) is specified, the value is defined by
            `torch._tensor_str._Formatter`. This value is automatically chosen
            by the framework.

    Example::

        >>> # Limit the precision of elements
        >>> torch.set_printoptions(precision=2)
        >>> torch.tensor([1.12345])
        tensor([1.12])
        >>> # Limit the number of elements shown
        >>> torch.set_printoptions(threshold=5)
        >>> torch.arange(10)
        tensor([0, 1, 2, ..., 7, 8, 9])
        >>> # Restore defaults
        >>> torch.set_printoptions(profile='default')
        >>> torch.tensor([1.12345])
        tensor([1.1235])
        >>> torch.arange(10)
        tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    """
    if profile is not None:
        if profile == "default":
            PRINT_OPTS.precision = 4
            PRINT_OPTS.threshold = 1000
            PRINT_OPTS.edgeitems = 3
            PRINT_OPTS.linewidth = 80
        elif profile == "short":
            PRINT_OPTS.precision = 2
            PRINT_OPTS.threshold = 1000
            PRINT_OPTS.edgeitems = 2
            PRINT_OPTS.linewidth = 80
        elif profile == "full":
            PRINT_OPTS.precision = 4
            PRINT_OPTS.threshold = inf
            PRINT_OPTS.edgeitems = 3
            PRINT_OPTS.linewidth = 80

    if precision is not None:
        PRINT_OPTS.precision = precision
    if threshold is not None:
        PRINT_OPTS.threshold = threshold
    if edgeitems is not None:
        PRINT_OPTS.edgeitems = edgeitems
    if linewidth is not None:
        PRINT_OPTS.linewidth = linewidth
    PRINT_OPTS.sci_mode = sci_mode
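
# A minimal usage sketch of `sci_mode`: setting it overrides whatever notation
# `_Formatter` would otherwise pick, and a later call that leaves `sci_mode` at
# None (e.g. restoring a profile) hands the choice back to the formatter.
# With the default precision of 4, the output should look roughly like:
#
#     >>> torch.set_printoptions(sci_mode=True)
#     >>> torch.tensor([0.01])
#     tensor([1.0000e-02])
#     >>> torch.set_printoptions(profile="default")
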
def tensor_totype(t):
    dtype = torch.float if t.is_mps else torch.double
    return t.to(dtype=dtype)


class _Formatter:
    def __init__(self, tensor):
        self.floating_dtype = tensor.dtype.is_floating_point
        self.int_mode = True
        self.sci_mode = False
        self.max_width = 1

        with torch.no_grad():
            tensor_view = tensor.reshape(-1)

        if not self.floating_dtype:
            for value in tensor_view:
                value_str = "{}".format(value)
                self.max_width = max(self.max_width, len(value_str))

        else:
            nonzero_finite_vals = torch.masked_select(
                tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
            )

            if nonzero_finite_vals.numel() == 0:
                # no valid number, do nothing
                return

            # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
            nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs())
            nonzero_finite_min = tensor_totype(nonzero_finite_abs.min())
            nonzero_finite_max = tensor_totype(nonzero_finite_abs.max())

            for value in nonzero_finite_vals:
                if value != torch.ceil(value):
                    self.int_mode = False
                    break

            if self.int_mode:
                # in int_mode for floats, all numbers are integers, and we append a decimal to nonfinites
                # to indicate that the tensor is of floating type. add 1 to the len to account for this.
                if (
                    nonzero_finite_max / nonzero_finite_min > 1000.0
                    or nonzero_finite_max > 1.0e8
                ):
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = (
                            ("{{:.{}e}}").format(PRINT_OPTS.precision).format(value)
                        )
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = ("{:.0f}").format(value)
                        self.max_width = max(self.max_width, len(value_str) + 1)
            else:
                # Check if scientific representation should be used.
                if (
                    nonzero_finite_max / nonzero_finite_min > 1000.0
                    or nonzero_finite_max > 1.0e8
                    or nonzero_finite_min < 1.0e-4
                ):
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = (
                            ("{{:.{}e}}").format(PRINT_OPTS.precision).format(value)
                        )
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = (
                            ("{{:.{}f}}").format(PRINT_OPTS.precision).format(value)
                        )
                        self.max_width = max(self.max_width, len(value_str))

        if PRINT_OPTS.sci_mode is not None:
            self.sci_mode = PRINT_OPTS.sci_mode

    def width(self):
        return self.max_width

    def format(self, value):
        if self.floating_dtype:
            if self.sci_mode:
                ret = (
                    ("{{:{}.{}e}}")
                    .format(self.max_width, PRINT_OPTS.precision)
                    .format(value)
                )
            elif self.int_mode:
                ret = "{:.0f}".format(value)
                if not (math.isinf(value) or math.isnan(value)):
                    ret += "."
            else:
                ret = ("{{:.{}f}}").format(PRINT_OPTS.precision).format(value)
        else:
            ret = "{}".format(value)
        return (self.max_width - len(ret)) * " " + ret
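
# A rough illustration of the heuristics above, assuming default print options:
# a nonzero finite magnitude below 1e-4, a max/min ratio above 1000, or a
# maximum above 1e8 switches the tensor to scientific notation; otherwise
# fixed-point (or integer-style) formatting is used. Expected output,
# approximately:
#
#     >>> torch.tensor([1e-05, 2e-05])  # min below 1e-4 -> scientific
#     tensor([1.0000e-05, 2.0000e-05])
#     >>> torch.tensor([0.1, 0.2])      # moderate range -> fixed point
#     tensor([0.1000, 0.2000])
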
def _scalar_str(self, formatter1, formatter2=None):
    if formatter2 is not None:
        real_str = _scalar_str(self.real, formatter1)
        imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip()
        # handles negative numbers, +0.0, -0.0
        if imag_str[0] == "+" or imag_str[0] == "-":
            return real_str + imag_str
        else:
            return real_str + "+" + imag_str
    else:
        return formatter1.format(self.item())


def _vector_str(self, indent, summarize, formatter1, formatter2=None):
    # length includes spaces and comma between elements
    element_length = formatter1.width() + 2
    if formatter2 is not None:
        # width for imag_formatter + an extra j for complex
        element_length += formatter2.width() + 1

    elements_per_line = max(
        1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length)))
    )

    def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
        if formatter2 is not None:
            real_str = formatter1.format(val.real)
            imag_str = (formatter2.format(val.imag) + "j").lstrip()
            # handles negative numbers, +0.0, -0.0
            if imag_str[0] == "+" or imag_str[0] == "-":
                return real_str + imag_str
            else:
                return real_str + "+" + imag_str
        else:
            return formatter1.format(val)

    if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        data = (
            [_val_formatter(val) for val in self[: PRINT_OPTS.edgeitems].tolist()]
            + [" ..."]
            + [_val_formatter(val) for val in self[-PRINT_OPTS.edgeitems :].tolist()]
        )
    else:
        data = [_val_formatter(val) for val in self.tolist()]

    data_lines = [
        data[i : i + elements_per_line] for i in range(0, len(data), elements_per_line)
    ]
    lines = [", ".join(line) for line in data_lines]
    return "[" + ("," + "\n" + " " * (indent + 1)).join(lines) + "]"


# formatter2 is only used for printing complex tensors.
# For complex tensors, formatter1 and formatter2 are the formatters for tensor.real
# and tensor.imag respectively
def _tensor_str_with_formatter(self, indent, summarize, formatter1, formatter2=None):
    dim = self.dim()

    if dim == 0:
        return _scalar_str(self, formatter1, formatter2)

    if dim == 1:
        return _vector_str(self, indent, summarize, formatter1, formatter2)

    if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        slices = (
            [
                _tensor_str_with_formatter(
                    self[i], indent + 1, summarize, formatter1, formatter2
                )
                for i in range(0, PRINT_OPTS.edgeitems)
            ]
            + ["..."]
            + [
                _tensor_str_with_formatter(
                    self[i], indent + 1, summarize, formatter1, formatter2
                )
                for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))
            ]
        )
    else:
        slices = [
            _tensor_str_with_formatter(
                self[i], indent + 1, summarize, formatter1, formatter2
            )
            for i in range(0, self.size(0))
        ]

    tensor_str = ("," + "\n" * (dim - 1) + " " * (indent + 1)).join(slices)
    return "[" + tensor_str + "]"
def _tensor_str(self, indent):
    if self.numel() == 0:
        return "[]"

    if self.has_names():
        # There are two main codepaths (possibly more) that tensor printing goes through:
        # - tensor data can fit comfortably on screen
        # - tensor data needs to be summarized
        # Some of the codepaths don't fully support named tensors, so we send in
        # an unnamed tensor to the formatting code as a workaround.
        self = self.rename(None)

    summarize = self.numel() > PRINT_OPTS.threshold

    if self._is_zerotensor():
        self = self.clone()

    # handle the negative bit
    if self.is_neg():
        self = self.resolve_neg()

    if self.dtype is torch.float16 or self.dtype is torch.bfloat16:
        self = self.float()

    if self.dtype is torch.complex32:
        self = self.cfloat()

    if self.dtype.is_complex:
        # handle the conjugate bit
        self = self.resolve_conj()
        real_formatter = _Formatter(
            get_summarized_data(self.real) if summarize else self.real
        )
        imag_formatter = _Formatter(
            get_summarized_data(self.imag) if summarize else self.imag
        )
        return _tensor_str_with_formatter(
            self, indent, summarize, real_formatter, imag_formatter
        )
    else:
        formatter = _Formatter(get_summarized_data(self) if summarize else self)
        return _tensor_str_with_formatter(self, indent, summarize, formatter)


def _add_suffixes(tensor_str, suffixes, indent, force_newline):
    tensor_strs = [tensor_str]
    last_line_len = len(tensor_str) - tensor_str.rfind("\n") + 1
    for suffix in suffixes:
        suffix_len = len(suffix)
        if force_newline or last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth:
            tensor_strs.append(",\n" + " " * indent + suffix)
            last_line_len = indent + suffix_len
            force_newline = False
        else:
            tensor_strs.append(", " + suffix)
            last_line_len += suffix_len + 2
    tensor_strs.append(")")
    return "".join(tensor_strs)
def get_summarized_data(self):
    dim = self.dim()
    if dim == 0:
        return self
    if dim == 1:
        if self.size(0) > 2 * PRINT_OPTS.edgeitems:
            return torch.cat(
                (self[: PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems :])
            )
        else:
            return self
    if self.size(0) > 2 * PRINT_OPTS.edgeitems:
        start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)]
        end = [self[i] for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))]
        return torch.stack([get_summarized_data(x) for x in (start + end)])
    else:
        return torch.stack([get_summarized_data(x) for x in self])
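
# An approximate illustration of summarization: once numel() exceeds the
# threshold, only `edgeitems` leading and trailing entries are kept along each
# summarized dimension, so with the default edgeitems=3 one would expect:
#
#     >>> t = torch.arange(100).reshape(10, 10)
#     >>> get_summarized_data(t).shape
#     torch.Size([6, 6])
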
def _str_intern(inp, *, tensor_contents=None):
    if torch._C._functorch.is_functorch_wrapped_tensor(inp):
        return _functorch_wrapper_str_intern(inp, tensor_contents=tensor_contents)
    is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter
    if inp.is_nested:
        prefix = "nested_tensor("
    elif is_plain_tensor:
        prefix = "tensor("
    else:
        prefix = f"{type(inp).__name__}("
    indent = len(prefix)
    suffixes = []
    custom_contents_provided = tensor_contents is not None
    if custom_contents_provided:
        tensor_str = tensor_contents

    # This is used to extract the primal value and thus disable the forward AD
    # within this function.
    # TODO(albanD) This needs to be updated when more than one level is supported
    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)

    # Note [Print tensor device]:
    # A general logic here is we only print device when it doesn't match
    # the device specified in default tensor type.
    # Currently torch.set_default_tensor_type() only supports CPU/CUDA, thus
    # torch._C._get_default_device() only returns either cpu or cuda.
    # In other cases, we don't have a way to set them as default yet,
    # and we should always print out device for them.
    if (
        self.device.type != torch._C._get_default_device()
        or (
            self.device.type == "cuda"
            and torch.cuda.current_device() != self.device.index
        )
        or (self.device.type == "mps")
    ):
        suffixes.append("device='" + str(self.device) + "'")

    # Tensor printing performs tensor operations like slice, indexing, etc to make it in a
    # representable format. These operations on ipu/xla/lazy tensor results in compilations. Hence,
    # to avoid compilations, copying the tensor to cpu before printing.
    if self.device.type in ["xla", "lazy", "ipu"]:
        self = self.to("cpu")

    # TODO: add an API to map real -> complex dtypes
    _default_complex_dtype = (
        torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat
    )
    has_default_dtype = self.dtype in (
        torch.get_default_dtype(),
        _default_complex_dtype,
        torch.int64,
        torch.bool,
    )
    if self.is_sparse:
        suffixes.append("size=" + str(tuple(self.shape)))
        from torch._subclasses.fake_tensor import FakeTensor

        if not self.is_meta and not isinstance(self, FakeTensor):
            suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            indices_prefix = "indices=tensor("
            indices = self._indices().detach()
            indices_str = _tensor_str(indices, indent + len(indices_prefix))
            if indices.numel() == 0:
                indices_str += ", size=" + str(tuple(indices.shape))
            values_prefix = "values=tensor("
            values = self._values().detach()
            values_str = _tensor_str(values, indent + len(values_prefix))
            if values.numel() == 0:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (
                indices_prefix
                + indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        suffixes.append("size=" + str(tuple(self.shape)))
        suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            compressed_indices_method, plain_indices_method = {
                torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
                torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
            }[self.layout]
            if self.layout in {torch.sparse_csr, torch.sparse_bsr}:
                cdimname, pdimname = "row", "column"
            else:
                cdimname, pdimname = "column", "row"
            compressed_indices_prefix = f"c{cdimname[:3]}_indices=tensor("
            compressed_indices = compressed_indices_method(self).detach()
            compressed_indices_str = _tensor_str(
                compressed_indices, indent + len(compressed_indices_prefix)
            )
            if compressed_indices.numel() == 0:
                compressed_indices_str += ", size=" + str(
                    tuple(compressed_indices.shape)
                )
            plain_indices_prefix = f"{pdimname[:3]}_indices=tensor("
            plain_indices = plain_indices_method(self).detach()
            plain_indices_str = _tensor_str(
                plain_indices, indent + len(plain_indices_prefix)
            )
            if plain_indices.numel() == 0:
                plain_indices_str += ", size=" + str(tuple(plain_indices.shape))
            values_prefix = "values=tensor("
            values = self.values().detach()
            values_str = _tensor_str(values, indent + len(values_prefix))
            if values.numel() == 0:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (
                compressed_indices_prefix
                + compressed_indices_str
                + "),\n"
                + " " * indent
                + plain_indices_prefix
                + plain_indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.is_quantized:
        suffixes.append("size=" + str(tuple(self.shape)))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        suffixes.append("quantization_scheme=" + str(self.qscheme()))
        if (
            self.qscheme() == torch.per_tensor_affine
            or self.qscheme() == torch.per_tensor_symmetric
        ):
            suffixes.append("scale=" + str(self.q_scale()))
            suffixes.append("zero_point=" + str(self.q_zero_point()))
        elif (
            self.qscheme() == torch.per_channel_affine
            or self.qscheme() == torch.per_channel_symmetric
            or self.qscheme() == torch.per_channel_affine_float_qparams
        ):
            suffixes.append("scale=" + str(self.q_per_channel_scales()))
            suffixes.append("zero_point=" + str(self.q_per_channel_zero_points()))
            suffixes.append("axis=" + str(self.q_per_channel_axis()))
        if not custom_contents_provided:
            tensor_str = _tensor_str(self.dequantize(), indent)
    elif self.is_nested:
        if not custom_contents_provided:

            def indented_str(s, indent):
                return "\n".join(f"  {line}" for line in s.split("\n"))

            strs = ",\n".join(
                indented_str(str(t), indent + 1)
                for t in torch.ops.aten.unbind.int(self, 0)
            )
            tensor_str = f"[\n{strs}\n]"
    elif torch._is_functional_tensor(self):
        prefix = "_to_functional_tensor("
        tensor_str = repr(torch._from_functional_tensor(self))
    else:
        if self.is_meta:
            suffixes.append("size=" + str(tuple(self.shape)))
            if self.dtype != torch.get_default_dtype():
                suffixes.append("dtype=" + str(self.dtype))
            # TODO: This implies that ellipses is valid syntax for allocating
            # a meta tensor, which it could be, but it isn't right now
            if not custom_contents_provided:
                tensor_str = "..."
        else:
            if self.numel() == 0 and not self.is_sparse:
                # Explicitly print the shape if it is not (0,), to match NumPy behavior
                if self.dim() != 1:
                    suffixes.append("size=" + str(tuple(self.shape)))

                # In an empty tensor, there are no elements to infer if the dtype
                # should be int64, so it must be shown explicitly.
                if self.dtype != torch.get_default_dtype():
                    suffixes.append("dtype=" + str(self.dtype))
                if not custom_contents_provided:
                    tensor_str = "[]"
            else:
                if not has_default_dtype:
                    suffixes.append("dtype=" + str(self.dtype))

                if not custom_contents_provided:
                    if self.layout != torch.strided:
                        tensor_str = _tensor_str(self.to_dense(), indent)
                    else:
                        tensor_str = _tensor_str(self, indent)

    if self.layout != torch.strided:
        suffixes.append("layout=" + str(self.layout))

    # Use inp here to get the original grad_fn and not the one generated by the forward grad
    # unpacking.
    if inp.grad_fn is not None:
        name = type(inp.grad_fn).__name__
        if name == "CppFunction":
            name = inp.grad_fn.name().rsplit("::", 1)[-1]
        suffixes.append("grad_fn=<{}>".format(name))
    elif inp.requires_grad:
        suffixes.append("requires_grad=True")

    if self.has_names():
        suffixes.append("names={}".format(self.names))

    if tangent is not None:
        suffixes.append("tangent={}".format(tangent))

    string_repr = _add_suffixes(
        prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse
    )

    # Check if this instance is flagged as a parameter and change the repr accordingly.
    # Unfortunately, this function has to be aware of this detail.
    # NB: This is currently skipped for plain tensor parameters to maintain BC. In the future,
    # this should be done for those as well to produce a valid repr.
    if isinstance(self, torch.nn.Parameter) and not is_plain_tensor:
        string_repr = f"Parameter({string_repr})"

    return string_repr


def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None):
    level = torch._C._functorch.maybe_get_level(tensor)
    assert level != -1

    if torch._C._functorch.is_functionaltensor(tensor):
        # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
        # that it's up to date first
        torch._sync(tensor)

    value = torch._C._functorch.get_unwrapped(tensor)
    value_repr = repr(value)
    indented_value_repr = textwrap.indent(value_repr, " " * 4)
    if torch._C._functorch.is_batchedtensor(tensor):
        bdim = torch._C._functorch.maybe_get_bdim(tensor)
        assert bdim != -1
        return (
            f"BatchedTensor(lvl={level}, bdim={bdim}, value=\n"
            f"{indented_value_repr}\n"
            f")"
        )
    if torch._C._functorch.is_gradtrackingtensor(tensor):
        return (
            f"GradTrackingTensor(lvl={level}, value=\n" f"{indented_value_repr}\n" f")"
        )
    if torch._C._functorch.is_functionaltensor(tensor):
        return f"FunctionalTensor(lvl={level}, value=\\\n{value_repr})"

    raise ValueError("We don't know how to print this, please file us an issue")


def _str(self, *, tensor_contents=None):
    with torch.no_grad(), torch.utils._python_dispatch._disable_current_modes():
        guard = torch._C._DisableFuncTorch()
        return _str_intern(self, tensor_contents=tensor_contents)