gexf.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065
  1. """Read and write graphs in GEXF format.
  2. .. warning::
  3. This parser uses the standard xml library present in Python, which is
  4. insecure - see :external+python:mod:`xml` for additional information.
  5. Only parse GEXF files you trust.
  6. GEXF (Graph Exchange XML Format) is a language for describing complex
  7. network structures, their associated data and dynamics.
  8. This implementation does not support mixed graphs (directed and
  9. undirected edges together).
  10. Format
  11. ------
  12. GEXF is an XML format. See http://gexf.net/schema.html for the
  13. specification and http://gexf.net/basic.html for examples.
  14. """
  15. import itertools
  16. import time
  17. from xml.etree.ElementTree import (
  18. Element,
  19. ElementTree,
  20. SubElement,
  21. register_namespace,
  22. tostring,
  23. )
  24. import networkx as nx
  25. from networkx.utils import open_file
  26. __all__ = ["write_gexf", "read_gexf", "relabel_gexf_graph", "generate_gexf"]
  27. @open_file(1, mode="wb")
  28. def write_gexf(G, path, encoding="utf-8", prettyprint=True, version="1.2draft"):
  29. """Write G in GEXF format to path.
  30. "GEXF (Graph Exchange XML Format) is a language for describing
  31. complex networks structures, their associated data and dynamics" [1]_.
  32. Node attributes are checked according to the version of the GEXF
  33. schemas used for parameters which are not user defined,
  34. e.g. visualization 'viz' [2]_. See example for usage.
  35. Parameters
  36. ----------
  37. G : graph
  38. A NetworkX graph
  39. path : file or string
  40. File or file name to write.
  41. File names ending in .gz or .bz2 will be compressed.
  42. encoding : string (optional, default: 'utf-8')
  43. Encoding for text data.
  44. prettyprint : bool (optional, default: True)
  45. If True use line breaks and indenting in output XML.
  46. version: string (optional, default: '1.2draft')
  47. The version of GEXF to be used for nodes attributes checking
  48. Examples
  49. --------
  50. >>> G = nx.path_graph(4)
  51. >>> nx.write_gexf(G, "test.gexf")
  52. # visualization data
  53. >>> G.nodes[0]["viz"] = {"size": 54}
  54. >>> G.nodes[0]["viz"]["position"] = {"x": 0, "y": 1}
  55. >>> G.nodes[0]["viz"]["color"] = {"r": 0, "g": 0, "b": 256}
  56. Notes
  57. -----
  58. This implementation does not support mixed graphs (directed and undirected
  59. edges together).
  60. The node id attribute is set to be the string of the node label.
  61. If you want to specify an id use set it as node data, e.g.
  62. node['a']['id']=1 to set the id of node 'a' to 1.
  63. References
  64. ----------
  65. .. [1] GEXF File Format, http://gexf.net/
  66. .. [2] GEXF schema, http://gexf.net/schema.html
  67. """
  68. writer = GEXFWriter(encoding=encoding, prettyprint=prettyprint, version=version)
  69. writer.add_graph(G)
  70. writer.write(path)
  71. def generate_gexf(G, encoding="utf-8", prettyprint=True, version="1.2draft"):
  72. """Generate lines of GEXF format representation of G.
  73. "GEXF (Graph Exchange XML Format) is a language for describing
  74. complex networks structures, their associated data and dynamics" [1]_.
  75. Parameters
  76. ----------
  77. G : graph
  78. A NetworkX graph
  79. encoding : string (optional, default: 'utf-8')
  80. Encoding for text data.
  81. prettyprint : bool (optional, default: True)
  82. If True use line breaks and indenting in output XML.
  83. version : string (default: 1.2draft)
  84. Version of GEFX File Format (see http://gexf.net/schema.html)
  85. Supported values: "1.1draft", "1.2draft"
  86. Examples
  87. --------
  88. >>> G = nx.path_graph(4)
  89. >>> linefeed = chr(10) # linefeed=\n
  90. >>> s = linefeed.join(nx.generate_gexf(G))
  91. >>> for line in nx.generate_gexf(G): # doctest: +SKIP
  92. ... print(line)
  93. Notes
  94. -----
  95. This implementation does not support mixed graphs (directed and undirected
  96. edges together).
  97. The node id attribute is set to be the string of the node label.
  98. If you want to specify an id use set it as node data, e.g.
  99. node['a']['id']=1 to set the id of node 'a' to 1.
  100. References
  101. ----------
  102. .. [1] GEXF File Format, https://gephi.org/gexf/format/
  103. """
  104. writer = GEXFWriter(encoding=encoding, prettyprint=prettyprint, version=version)
  105. writer.add_graph(G)
  106. yield from str(writer).splitlines()
  107. @open_file(0, mode="rb")
  108. def read_gexf(path, node_type=None, relabel=False, version="1.2draft"):
  109. """Read graph in GEXF format from path.
  110. "GEXF (Graph Exchange XML Format) is a language for describing
  111. complex networks structures, their associated data and dynamics" [1]_.
  112. Parameters
  113. ----------
  114. path : file or string
  115. File or file name to read.
  116. File names ending in .gz or .bz2 will be decompressed.
  117. node_type: Python type (default: None)
  118. Convert node ids to this type if not None.
  119. relabel : bool (default: False)
  120. If True relabel the nodes to use the GEXF node "label" attribute
  121. instead of the node "id" attribute as the NetworkX node label.
  122. version : string (default: 1.2draft)
  123. Version of GEFX File Format (see http://gexf.net/schema.html)
  124. Supported values: "1.1draft", "1.2draft"
  125. Returns
  126. -------
  127. graph: NetworkX graph
  128. If no parallel edges are found a Graph or DiGraph is returned.
  129. Otherwise a MultiGraph or MultiDiGraph is returned.
  130. Notes
  131. -----
  132. This implementation does not support mixed graphs (directed and undirected
  133. edges together).
  134. References
  135. ----------
  136. .. [1] GEXF File Format, http://gexf.net/
  137. """
  138. reader = GEXFReader(node_type=node_type, version=version)
  139. if relabel:
  140. G = relabel_gexf_graph(reader(path))
  141. else:
  142. G = reader(path)
  143. return G
  144. class GEXF:
  145. versions = {
  146. "1.1draft": {
  147. "NS_GEXF": "http://www.gexf.net/1.1draft",
  148. "NS_VIZ": "http://www.gexf.net/1.1draft/viz",
  149. "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance",
  150. "SCHEMALOCATION": " ".join(
  151. [
  152. "http://www.gexf.net/1.1draft",
  153. "http://www.gexf.net/1.1draft/gexf.xsd",
  154. ]
  155. ),
  156. "VERSION": "1.1",
  157. },
  158. "1.2draft": {
  159. "NS_GEXF": "http://www.gexf.net/1.2draft",
  160. "NS_VIZ": "http://www.gexf.net/1.2draft/viz",
  161. "NS_XSI": "http://www.w3.org/2001/XMLSchema-instance",
  162. "SCHEMALOCATION": " ".join(
  163. [
  164. "http://www.gexf.net/1.2draft",
  165. "http://www.gexf.net/1.2draft/gexf.xsd",
  166. ]
  167. ),
  168. "VERSION": "1.2",
  169. },
  170. }
  171. def construct_types(self):
  172. types = [
  173. (int, "integer"),
  174. (float, "float"),
  175. (float, "double"),
  176. (bool, "boolean"),
  177. (list, "string"),
  178. (dict, "string"),
  179. (int, "long"),
  180. (str, "liststring"),
  181. (str, "anyURI"),
  182. (str, "string"),
  183. ]
  184. # These additions to types allow writing numpy types
  185. try:
  186. import numpy as np
  187. except ImportError:
  188. pass
  189. else:
  190. # prepend so that python types are created upon read (last entry wins)
  191. types = [
  192. (np.float64, "float"),
  193. (np.float32, "float"),
  194. (np.float16, "float"),
  195. (np.float_, "float"),
  196. (np.int_, "int"),
  197. (np.int8, "int"),
  198. (np.int16, "int"),
  199. (np.int32, "int"),
  200. (np.int64, "int"),
  201. (np.uint8, "int"),
  202. (np.uint16, "int"),
  203. (np.uint32, "int"),
  204. (np.uint64, "int"),
  205. (np.int_, "int"),
  206. (np.intc, "int"),
  207. (np.intp, "int"),
  208. ] + types
  209. self.xml_type = dict(types)
  210. self.python_type = dict(reversed(a) for a in types)
  211. # http://www.w3.org/TR/xmlschema-2/#boolean
  212. convert_bool = {
  213. "true": True,
  214. "false": False,
  215. "True": True,
  216. "False": False,
  217. "0": False,
  218. 0: False,
  219. "1": True,
  220. 1: True,
  221. }
  222. def set_version(self, version):
  223. d = self.versions.get(version)
  224. if d is None:
  225. raise nx.NetworkXError(f"Unknown GEXF version {version}.")
  226. self.NS_GEXF = d["NS_GEXF"]
  227. self.NS_VIZ = d["NS_VIZ"]
  228. self.NS_XSI = d["NS_XSI"]
  229. self.SCHEMALOCATION = d["SCHEMALOCATION"]
  230. self.VERSION = d["VERSION"]
  231. self.version = version
  232. class GEXFWriter(GEXF):
  233. # class for writing GEXF format files
  234. # use write_gexf() function
  235. def __init__(
  236. self, graph=None, encoding="utf-8", prettyprint=True, version="1.2draft"
  237. ):
  238. self.construct_types()
  239. self.prettyprint = prettyprint
  240. self.encoding = encoding
  241. self.set_version(version)
  242. self.xml = Element(
  243. "gexf",
  244. {
  245. "xmlns": self.NS_GEXF,
  246. "xmlns:xsi": self.NS_XSI,
  247. "xsi:schemaLocation": self.SCHEMALOCATION,
  248. "version": self.VERSION,
  249. },
  250. )
  251. # Make meta element a non-graph element
  252. # Also add lastmodifieddate as attribute, not tag
  253. meta_element = Element("meta")
  254. subelement_text = f"NetworkX {nx.__version__}"
  255. SubElement(meta_element, "creator").text = subelement_text
  256. meta_element.set("lastmodifieddate", time.strftime("%Y-%m-%d"))
  257. self.xml.append(meta_element)
  258. register_namespace("viz", self.NS_VIZ)
  259. # counters for edge and attribute identifiers
  260. self.edge_id = itertools.count()
  261. self.attr_id = itertools.count()
  262. self.all_edge_ids = set()
  263. # default attributes are stored in dictionaries
  264. self.attr = {}
  265. self.attr["node"] = {}
  266. self.attr["edge"] = {}
  267. self.attr["node"]["dynamic"] = {}
  268. self.attr["node"]["static"] = {}
  269. self.attr["edge"]["dynamic"] = {}
  270. self.attr["edge"]["static"] = {}
  271. if graph is not None:
  272. self.add_graph(graph)
  273. def __str__(self):
  274. if self.prettyprint:
  275. self.indent(self.xml)
  276. s = tostring(self.xml).decode(self.encoding)
  277. return s
  278. def add_graph(self, G):
  279. # first pass through G collecting edge ids
  280. for u, v, dd in G.edges(data=True):
  281. eid = dd.get("id")
  282. if eid is not None:
  283. self.all_edge_ids.add(str(eid))
  284. # set graph attributes
  285. if G.graph.get("mode") == "dynamic":
  286. mode = "dynamic"
  287. else:
  288. mode = "static"
  289. # Add a graph element to the XML
  290. if G.is_directed():
  291. default = "directed"
  292. else:
  293. default = "undirected"
  294. name = G.graph.get("name", "")
  295. graph_element = Element("graph", defaultedgetype=default, mode=mode, name=name)
  296. self.graph_element = graph_element
  297. self.add_nodes(G, graph_element)
  298. self.add_edges(G, graph_element)
  299. self.xml.append(graph_element)
  300. def add_nodes(self, G, graph_element):
  301. nodes_element = Element("nodes")
  302. for node, data in G.nodes(data=True):
  303. node_data = data.copy()
  304. node_id = str(node_data.pop("id", node))
  305. kw = {"id": node_id}
  306. label = str(node_data.pop("label", node))
  307. kw["label"] = label
  308. try:
  309. pid = node_data.pop("pid")
  310. kw["pid"] = str(pid)
  311. except KeyError:
  312. pass
  313. try:
  314. start = node_data.pop("start")
  315. kw["start"] = str(start)
  316. self.alter_graph_mode_timeformat(start)
  317. except KeyError:
  318. pass
  319. try:
  320. end = node_data.pop("end")
  321. kw["end"] = str(end)
  322. self.alter_graph_mode_timeformat(end)
  323. except KeyError:
  324. pass
  325. # add node element with attributes
  326. node_element = Element("node", **kw)
  327. # add node element and attr subelements
  328. default = G.graph.get("node_default", {})
  329. node_data = self.add_parents(node_element, node_data)
  330. if self.VERSION == "1.1":
  331. node_data = self.add_slices(node_element, node_data)
  332. else:
  333. node_data = self.add_spells(node_element, node_data)
  334. node_data = self.add_viz(node_element, node_data)
  335. node_data = self.add_attributes("node", node_element, node_data, default)
  336. nodes_element.append(node_element)
  337. graph_element.append(nodes_element)
  338. def add_edges(self, G, graph_element):
  339. def edge_key_data(G):
  340. # helper function to unify multigraph and graph edge iterator
  341. if G.is_multigraph():
  342. for u, v, key, data in G.edges(data=True, keys=True):
  343. edge_data = data.copy()
  344. edge_data.update(key=key)
  345. edge_id = edge_data.pop("id", None)
  346. if edge_id is None:
  347. edge_id = next(self.edge_id)
  348. while str(edge_id) in self.all_edge_ids:
  349. edge_id = next(self.edge_id)
  350. self.all_edge_ids.add(str(edge_id))
  351. yield u, v, edge_id, edge_data
  352. else:
  353. for u, v, data in G.edges(data=True):
  354. edge_data = data.copy()
  355. edge_id = edge_data.pop("id", None)
  356. if edge_id is None:
  357. edge_id = next(self.edge_id)
  358. while str(edge_id) in self.all_edge_ids:
  359. edge_id = next(self.edge_id)
  360. self.all_edge_ids.add(str(edge_id))
  361. yield u, v, edge_id, edge_data
  362. edges_element = Element("edges")
  363. for u, v, key, edge_data in edge_key_data(G):
  364. kw = {"id": str(key)}
  365. try:
  366. edge_label = edge_data.pop("label")
  367. kw["label"] = str(edge_label)
  368. except KeyError:
  369. pass
  370. try:
  371. edge_weight = edge_data.pop("weight")
  372. kw["weight"] = str(edge_weight)
  373. except KeyError:
  374. pass
  375. try:
  376. edge_type = edge_data.pop("type")
  377. kw["type"] = str(edge_type)
  378. except KeyError:
  379. pass
  380. try:
  381. start = edge_data.pop("start")
  382. kw["start"] = str(start)
  383. self.alter_graph_mode_timeformat(start)
  384. except KeyError:
  385. pass
  386. try:
  387. end = edge_data.pop("end")
  388. kw["end"] = str(end)
  389. self.alter_graph_mode_timeformat(end)
  390. except KeyError:
  391. pass
  392. source_id = str(G.nodes[u].get("id", u))
  393. target_id = str(G.nodes[v].get("id", v))
  394. edge_element = Element("edge", source=source_id, target=target_id, **kw)
  395. default = G.graph.get("edge_default", {})
  396. if self.VERSION == "1.1":
  397. edge_data = self.add_slices(edge_element, edge_data)
  398. else:
  399. edge_data = self.add_spells(edge_element, edge_data)
  400. edge_data = self.add_viz(edge_element, edge_data)
  401. edge_data = self.add_attributes("edge", edge_element, edge_data, default)
  402. edges_element.append(edge_element)
  403. graph_element.append(edges_element)
  404. def add_attributes(self, node_or_edge, xml_obj, data, default):
  405. # Add attrvalues to node or edge
  406. attvalues = Element("attvalues")
  407. if len(data) == 0:
  408. return data
  409. mode = "static"
  410. for k, v in data.items():
  411. # rename generic multigraph key to avoid any name conflict
  412. if k == "key":
  413. k = "networkx_key"
  414. val_type = type(v)
  415. if val_type not in self.xml_type:
  416. raise TypeError(f"attribute value type is not allowed: {val_type}")
  417. if isinstance(v, list):
  418. # dynamic data
  419. for val, start, end in v:
  420. val_type = type(val)
  421. if start is not None or end is not None:
  422. mode = "dynamic"
  423. self.alter_graph_mode_timeformat(start)
  424. self.alter_graph_mode_timeformat(end)
  425. break
  426. attr_id = self.get_attr_id(
  427. str(k), self.xml_type[val_type], node_or_edge, default, mode
  428. )
  429. for val, start, end in v:
  430. e = Element("attvalue")
  431. e.attrib["for"] = attr_id
  432. e.attrib["value"] = str(val)
  433. # Handle nan, inf, -inf differently
  434. if val_type == float:
  435. if e.attrib["value"] == "inf":
  436. e.attrib["value"] = "INF"
  437. elif e.attrib["value"] == "nan":
  438. e.attrib["value"] = "NaN"
  439. elif e.attrib["value"] == "-inf":
  440. e.attrib["value"] = "-INF"
  441. if start is not None:
  442. e.attrib["start"] = str(start)
  443. if end is not None:
  444. e.attrib["end"] = str(end)
  445. attvalues.append(e)
  446. else:
  447. # static data
  448. mode = "static"
  449. attr_id = self.get_attr_id(
  450. str(k), self.xml_type[val_type], node_or_edge, default, mode
  451. )
  452. e = Element("attvalue")
  453. e.attrib["for"] = attr_id
  454. if isinstance(v, bool):
  455. e.attrib["value"] = str(v).lower()
  456. else:
  457. e.attrib["value"] = str(v)
  458. # Handle float nan, inf, -inf differently
  459. if val_type == float:
  460. if e.attrib["value"] == "inf":
  461. e.attrib["value"] = "INF"
  462. elif e.attrib["value"] == "nan":
  463. e.attrib["value"] = "NaN"
  464. elif e.attrib["value"] == "-inf":
  465. e.attrib["value"] = "-INF"
  466. attvalues.append(e)
  467. xml_obj.append(attvalues)
  468. return data
  469. def get_attr_id(self, title, attr_type, edge_or_node, default, mode):
  470. # find the id of the attribute or generate a new id
  471. try:
  472. return self.attr[edge_or_node][mode][title]
  473. except KeyError:
  474. # generate new id
  475. new_id = str(next(self.attr_id))
  476. self.attr[edge_or_node][mode][title] = new_id
  477. attr_kwargs = {"id": new_id, "title": title, "type": attr_type}
  478. attribute = Element("attribute", **attr_kwargs)
  479. # add subelement for data default value if present
  480. default_title = default.get(title)
  481. if default_title is not None:
  482. default_element = Element("default")
  483. default_element.text = str(default_title)
  484. attribute.append(default_element)
  485. # new insert it into the XML
  486. attributes_element = None
  487. for a in self.graph_element.findall("attributes"):
  488. # find existing attributes element by class and mode
  489. a_class = a.get("class")
  490. a_mode = a.get("mode", "static")
  491. if a_class == edge_or_node and a_mode == mode:
  492. attributes_element = a
  493. if attributes_element is None:
  494. # create new attributes element
  495. attr_kwargs = {"mode": mode, "class": edge_or_node}
  496. attributes_element = Element("attributes", **attr_kwargs)
  497. self.graph_element.insert(0, attributes_element)
  498. attributes_element.append(attribute)
  499. return new_id
  500. def add_viz(self, element, node_data):
  501. viz = node_data.pop("viz", False)
  502. if viz:
  503. color = viz.get("color")
  504. if color is not None:
  505. if self.VERSION == "1.1":
  506. e = Element(
  507. f"{{{self.NS_VIZ}}}color",
  508. r=str(color.get("r")),
  509. g=str(color.get("g")),
  510. b=str(color.get("b")),
  511. )
  512. else:
  513. e = Element(
  514. f"{{{self.NS_VIZ}}}color",
  515. r=str(color.get("r")),
  516. g=str(color.get("g")),
  517. b=str(color.get("b")),
  518. a=str(color.get("a", 1.0)),
  519. )
  520. element.append(e)
  521. size = viz.get("size")
  522. if size is not None:
  523. e = Element(f"{{{self.NS_VIZ}}}size", value=str(size))
  524. element.append(e)
  525. thickness = viz.get("thickness")
  526. if thickness is not None:
  527. e = Element(f"{{{self.NS_VIZ}}}thickness", value=str(thickness))
  528. element.append(e)
  529. shape = viz.get("shape")
  530. if shape is not None:
  531. if shape.startswith("http"):
  532. e = Element(
  533. f"{{{self.NS_VIZ}}}shape", value="image", uri=str(shape)
  534. )
  535. else:
  536. e = Element(f"{{{self.NS_VIZ}}}shape", value=str(shape))
  537. element.append(e)
  538. position = viz.get("position")
  539. if position is not None:
  540. e = Element(
  541. f"{{{self.NS_VIZ}}}position",
  542. x=str(position.get("x")),
  543. y=str(position.get("y")),
  544. z=str(position.get("z")),
  545. )
  546. element.append(e)
  547. return node_data
  548. def add_parents(self, node_element, node_data):
  549. parents = node_data.pop("parents", False)
  550. if parents:
  551. parents_element = Element("parents")
  552. for p in parents:
  553. e = Element("parent")
  554. e.attrib["for"] = str(p)
  555. parents_element.append(e)
  556. node_element.append(parents_element)
  557. return node_data
  558. def add_slices(self, node_or_edge_element, node_or_edge_data):
  559. slices = node_or_edge_data.pop("slices", False)
  560. if slices:
  561. slices_element = Element("slices")
  562. for start, end in slices:
  563. e = Element("slice", start=str(start), end=str(end))
  564. slices_element.append(e)
  565. node_or_edge_element.append(slices_element)
  566. return node_or_edge_data
  567. def add_spells(self, node_or_edge_element, node_or_edge_data):
  568. spells = node_or_edge_data.pop("spells", False)
  569. if spells:
  570. spells_element = Element("spells")
  571. for start, end in spells:
  572. e = Element("spell")
  573. if start is not None:
  574. e.attrib["start"] = str(start)
  575. self.alter_graph_mode_timeformat(start)
  576. if end is not None:
  577. e.attrib["end"] = str(end)
  578. self.alter_graph_mode_timeformat(end)
  579. spells_element.append(e)
  580. node_or_edge_element.append(spells_element)
  581. return node_or_edge_data
  582. def alter_graph_mode_timeformat(self, start_or_end):
  583. # If 'start' or 'end' appears, alter Graph mode to dynamic and
  584. # set timeformat
  585. if self.graph_element.get("mode") == "static":
  586. if start_or_end is not None:
  587. if isinstance(start_or_end, str):
  588. timeformat = "date"
  589. elif isinstance(start_or_end, float):
  590. timeformat = "double"
  591. elif isinstance(start_or_end, int):
  592. timeformat = "long"
  593. else:
  594. raise nx.NetworkXError(
  595. "timeformat should be of the type int, float or str"
  596. )
  597. self.graph_element.set("timeformat", timeformat)
  598. self.graph_element.set("mode", "dynamic")
  599. def write(self, fh):
  600. # Serialize graph G in GEXF to the open fh
  601. if self.prettyprint:
  602. self.indent(self.xml)
  603. document = ElementTree(self.xml)
  604. document.write(fh, encoding=self.encoding, xml_declaration=True)
  605. def indent(self, elem, level=0):
  606. # in-place prettyprint formatter
  607. i = "\n" + " " * level
  608. if len(elem):
  609. if not elem.text or not elem.text.strip():
  610. elem.text = i + " "
  611. if not elem.tail or not elem.tail.strip():
  612. elem.tail = i
  613. for elem in elem:
  614. self.indent(elem, level + 1)
  615. if not elem.tail or not elem.tail.strip():
  616. elem.tail = i
  617. else:
  618. if level and (not elem.tail or not elem.tail.strip()):
  619. elem.tail = i
  620. class GEXFReader(GEXF):
  621. # Class to read GEXF format files
  622. # use read_gexf() function
  623. def __init__(self, node_type=None, version="1.2draft"):
  624. self.construct_types()
  625. self.node_type = node_type
  626. # assume simple graph and test for multigraph on read
  627. self.simple_graph = True
  628. self.set_version(version)
  629. def __call__(self, stream):
  630. self.xml = ElementTree(file=stream)
  631. g = self.xml.find(f"{{{self.NS_GEXF}}}graph")
  632. if g is not None:
  633. return self.make_graph(g)
  634. # try all the versions
  635. for version in self.versions:
  636. self.set_version(version)
  637. g = self.xml.find(f"{{{self.NS_GEXF}}}graph")
  638. if g is not None:
  639. return self.make_graph(g)
  640. raise nx.NetworkXError("No <graph> element in GEXF file.")
  641. def make_graph(self, graph_xml):
  642. # start with empty DiGraph or MultiDiGraph
  643. edgedefault = graph_xml.get("defaultedgetype", None)
  644. if edgedefault == "directed":
  645. G = nx.MultiDiGraph()
  646. else:
  647. G = nx.MultiGraph()
  648. # graph attributes
  649. graph_name = graph_xml.get("name", "")
  650. if graph_name != "":
  651. G.graph["name"] = graph_name
  652. graph_start = graph_xml.get("start")
  653. if graph_start is not None:
  654. G.graph["start"] = graph_start
  655. graph_end = graph_xml.get("end")
  656. if graph_end is not None:
  657. G.graph["end"] = graph_end
  658. graph_mode = graph_xml.get("mode", "")
  659. if graph_mode == "dynamic":
  660. G.graph["mode"] = "dynamic"
  661. else:
  662. G.graph["mode"] = "static"
  663. # timeformat
  664. self.timeformat = graph_xml.get("timeformat")
  665. if self.timeformat == "date":
  666. self.timeformat = "string"
  667. # node and edge attributes
  668. attributes_elements = graph_xml.findall(f"{{{self.NS_GEXF}}}attributes")
  669. # dictionaries to hold attributes and attribute defaults
  670. node_attr = {}
  671. node_default = {}
  672. edge_attr = {}
  673. edge_default = {}
  674. for a in attributes_elements:
  675. attr_class = a.get("class")
  676. if attr_class == "node":
  677. na, nd = self.find_gexf_attributes(a)
  678. node_attr.update(na)
  679. node_default.update(nd)
  680. G.graph["node_default"] = node_default
  681. elif attr_class == "edge":
  682. ea, ed = self.find_gexf_attributes(a)
  683. edge_attr.update(ea)
  684. edge_default.update(ed)
  685. G.graph["edge_default"] = edge_default
  686. else:
  687. raise # unknown attribute class
  688. # Hack to handle Gephi0.7beta bug
  689. # add weight attribute
  690. ea = {"weight": {"type": "double", "mode": "static", "title": "weight"}}
  691. ed = {}
  692. edge_attr.update(ea)
  693. edge_default.update(ed)
  694. G.graph["edge_default"] = edge_default
  695. # add nodes
  696. nodes_element = graph_xml.find(f"{{{self.NS_GEXF}}}nodes")
  697. if nodes_element is not None:
  698. for node_xml in nodes_element.findall(f"{{{self.NS_GEXF}}}node"):
  699. self.add_node(G, node_xml, node_attr)
  700. # add edges
  701. edges_element = graph_xml.find(f"{{{self.NS_GEXF}}}edges")
  702. if edges_element is not None:
  703. for edge_xml in edges_element.findall(f"{{{self.NS_GEXF}}}edge"):
  704. self.add_edge(G, edge_xml, edge_attr)
  705. # switch to Graph or DiGraph if no parallel edges were found.
  706. if self.simple_graph:
  707. if G.is_directed():
  708. G = nx.DiGraph(G)
  709. else:
  710. G = nx.Graph(G)
  711. return G
  712. def add_node(self, G, node_xml, node_attr, node_pid=None):
  713. # add a single node with attributes to the graph
  714. # get attributes and subattributues for node
  715. data = self.decode_attr_elements(node_attr, node_xml)
  716. data = self.add_parents(data, node_xml) # add any parents
  717. if self.VERSION == "1.1":
  718. data = self.add_slices(data, node_xml) # add slices
  719. else:
  720. data = self.add_spells(data, node_xml) # add spells
  721. data = self.add_viz(data, node_xml) # add viz
  722. data = self.add_start_end(data, node_xml) # add start/end
  723. # find the node id and cast it to the appropriate type
  724. node_id = node_xml.get("id")
  725. if self.node_type is not None:
  726. node_id = self.node_type(node_id)
  727. # every node should have a label
  728. node_label = node_xml.get("label")
  729. data["label"] = node_label
  730. # parent node id
  731. node_pid = node_xml.get("pid", node_pid)
  732. if node_pid is not None:
  733. data["pid"] = node_pid
  734. # check for subnodes, recursive
  735. subnodes = node_xml.find(f"{{{self.NS_GEXF}}}nodes")
  736. if subnodes is not None:
  737. for node_xml in subnodes.findall(f"{{{self.NS_GEXF}}}node"):
  738. self.add_node(G, node_xml, node_attr, node_pid=node_id)
  739. G.add_node(node_id, **data)
  740. def add_start_end(self, data, xml):
  741. # start and end times
  742. ttype = self.timeformat
  743. node_start = xml.get("start")
  744. if node_start is not None:
  745. data["start"] = self.python_type[ttype](node_start)
  746. node_end = xml.get("end")
  747. if node_end is not None:
  748. data["end"] = self.python_type[ttype](node_end)
  749. return data
  750. def add_viz(self, data, node_xml):
  751. # add viz element for node
  752. viz = {}
  753. color = node_xml.find(f"{{{self.NS_VIZ}}}color")
  754. if color is not None:
  755. if self.VERSION == "1.1":
  756. viz["color"] = {
  757. "r": int(color.get("r")),
  758. "g": int(color.get("g")),
  759. "b": int(color.get("b")),
  760. }
  761. else:
  762. viz["color"] = {
  763. "r": int(color.get("r")),
  764. "g": int(color.get("g")),
  765. "b": int(color.get("b")),
  766. "a": float(color.get("a", 1)),
  767. }
  768. size = node_xml.find(f"{{{self.NS_VIZ}}}size")
  769. if size is not None:
  770. viz["size"] = float(size.get("value"))
  771. thickness = node_xml.find(f"{{{self.NS_VIZ}}}thickness")
  772. if thickness is not None:
  773. viz["thickness"] = float(thickness.get("value"))
  774. shape = node_xml.find(f"{{{self.NS_VIZ}}}shape")
  775. if shape is not None:
  776. viz["shape"] = shape.get("shape")
  777. if viz["shape"] == "image":
  778. viz["shape"] = shape.get("uri")
  779. position = node_xml.find(f"{{{self.NS_VIZ}}}position")
  780. if position is not None:
  781. viz["position"] = {
  782. "x": float(position.get("x", 0)),
  783. "y": float(position.get("y", 0)),
  784. "z": float(position.get("z", 0)),
  785. }
  786. if len(viz) > 0:
  787. data["viz"] = viz
  788. return data
  789. def add_parents(self, data, node_xml):
  790. parents_element = node_xml.find(f"{{{self.NS_GEXF}}}parents")
  791. if parents_element is not None:
  792. data["parents"] = []
  793. for p in parents_element.findall(f"{{{self.NS_GEXF}}}parent"):
  794. parent = p.get("for")
  795. data["parents"].append(parent)
  796. return data
  797. def add_slices(self, data, node_or_edge_xml):
  798. slices_element = node_or_edge_xml.find(f"{{{self.NS_GEXF}}}slices")
  799. if slices_element is not None:
  800. data["slices"] = []
  801. for s in slices_element.findall(f"{{{self.NS_GEXF}}}slice"):
  802. start = s.get("start")
  803. end = s.get("end")
  804. data["slices"].append((start, end))
  805. return data
  806. def add_spells(self, data, node_or_edge_xml):
  807. spells_element = node_or_edge_xml.find(f"{{{self.NS_GEXF}}}spells")
  808. if spells_element is not None:
  809. data["spells"] = []
  810. ttype = self.timeformat
  811. for s in spells_element.findall(f"{{{self.NS_GEXF}}}spell"):
  812. start = self.python_type[ttype](s.get("start"))
  813. end = self.python_type[ttype](s.get("end"))
  814. data["spells"].append((start, end))
  815. return data
  816. def add_edge(self, G, edge_element, edge_attr):
  817. # add an edge to the graph
  818. # raise error if we find mixed directed and undirected edges
  819. edge_direction = edge_element.get("type")
  820. if G.is_directed() and edge_direction == "undirected":
  821. raise nx.NetworkXError("Undirected edge found in directed graph.")
  822. if (not G.is_directed()) and edge_direction == "directed":
  823. raise nx.NetworkXError("Directed edge found in undirected graph.")
  824. # Get source and target and recast type if required
  825. source = edge_element.get("source")
  826. target = edge_element.get("target")
  827. if self.node_type is not None:
  828. source = self.node_type(source)
  829. target = self.node_type(target)
  830. data = self.decode_attr_elements(edge_attr, edge_element)
  831. data = self.add_start_end(data, edge_element)
  832. if self.VERSION == "1.1":
  833. data = self.add_slices(data, edge_element) # add slices
  834. else:
  835. data = self.add_spells(data, edge_element) # add spells
  836. # GEXF stores edge ids as an attribute
  837. # NetworkX uses them as keys in multigraphs
  838. # if networkx_key is not specified as an attribute
  839. edge_id = edge_element.get("id")
  840. if edge_id is not None:
  841. data["id"] = edge_id
  842. # check if there is a 'multigraph_key' and use that as edge_id
  843. multigraph_key = data.pop("networkx_key", None)
  844. if multigraph_key is not None:
  845. edge_id = multigraph_key
  846. weight = edge_element.get("weight")
  847. if weight is not None:
  848. data["weight"] = float(weight)
  849. edge_label = edge_element.get("label")
  850. if edge_label is not None:
  851. data["label"] = edge_label
  852. if G.has_edge(source, target):
  853. # seen this edge before - this is a multigraph
  854. self.simple_graph = False
  855. G.add_edge(source, target, key=edge_id, **data)
  856. if edge_direction == "mutual":
  857. G.add_edge(target, source, key=edge_id, **data)
  858. def decode_attr_elements(self, gexf_keys, obj_xml):
  859. # Use the key information to decode the attr XML
  860. attr = {}
  861. # look for outer '<attvalues>' element
  862. attr_element = obj_xml.find(f"{{{self.NS_GEXF}}}attvalues")
  863. if attr_element is not None:
  864. # loop over <attvalue> elements
  865. for a in attr_element.findall(f"{{{self.NS_GEXF}}}attvalue"):
  866. key = a.get("for") # for is required
  867. try: # should be in our gexf_keys dictionary
  868. title = gexf_keys[key]["title"]
  869. except KeyError as err:
  870. raise nx.NetworkXError(f"No attribute defined for={key}.") from err
  871. atype = gexf_keys[key]["type"]
  872. value = a.get("value")
  873. if atype == "boolean":
  874. value = self.convert_bool[value]
  875. else:
  876. value = self.python_type[atype](value)
  877. if gexf_keys[key]["mode"] == "dynamic":
  878. # for dynamic graphs use list of three-tuples
  879. # [(value1,start1,end1), (value2,start2,end2), etc]
  880. ttype = self.timeformat
  881. start = self.python_type[ttype](a.get("start"))
  882. end = self.python_type[ttype](a.get("end"))
  883. if title in attr:
  884. attr[title].append((value, start, end))
  885. else:
  886. attr[title] = [(value, start, end)]
  887. else:
  888. # for static graphs just assign the value
  889. attr[title] = value
  890. return attr
  891. def find_gexf_attributes(self, attributes_element):
  892. # Extract all the attributes and defaults
  893. attrs = {}
  894. defaults = {}
  895. mode = attributes_element.get("mode")
  896. for k in attributes_element.findall(f"{{{self.NS_GEXF}}}attribute"):
  897. attr_id = k.get("id")
  898. title = k.get("title")
  899. atype = k.get("type")
  900. attrs[attr_id] = {"title": title, "type": atype, "mode": mode}
  901. # check for the 'default' subelement of key element and add
  902. default = k.find(f"{{{self.NS_GEXF}}}default")
  903. if default is not None:
  904. if atype == "boolean":
  905. value = self.convert_bool[default.text]
  906. else:
  907. value = self.python_type[atype](default.text)
  908. defaults[title] = value
  909. return attrs, defaults
  910. def relabel_gexf_graph(G):
  911. """Relabel graph using "label" node keyword for node label.
  912. Parameters
  913. ----------
  914. G : graph
  915. A NetworkX graph read from GEXF data
  916. Returns
  917. -------
  918. H : graph
  919. A NetworkX graph with relabeled nodes
  920. Raises
  921. ------
  922. NetworkXError
  923. If node labels are missing or not unique while relabel=True.
  924. Notes
  925. -----
  926. This function relabels the nodes in a NetworkX graph with the
  927. "label" attribute. It also handles relabeling the specific GEXF
  928. node attributes "parents", and "pid".
  929. """
  930. # build mapping of node labels, do some error checking
  931. try:
  932. mapping = [(u, G.nodes[u]["label"]) for u in G]
  933. except KeyError as err:
  934. raise nx.NetworkXError(
  935. "Failed to relabel nodes: missing node labels found. Use relabel=False."
  936. ) from err
  937. x, y = zip(*mapping)
  938. if len(set(y)) != len(G):
  939. raise nx.NetworkXError(
  940. "Failed to relabel nodes: "
  941. "duplicate node labels found. "
  942. "Use relabel=False."
  943. )
  944. mapping = dict(mapping)
  945. H = nx.relabel_nodes(G, mapping)
  946. # relabel attributes
  947. for n in G:
  948. m = mapping[n]
  949. H.nodes[m]["id"] = n
  950. H.nodes[m].pop("label")
  951. if "pid" in H.nodes[m]:
  952. H.nodes[m]["pid"] = mapping[G.nodes[n]["pid"]]
  953. if "parents" in H.nodes[m]:
  954. H.nodes[m]["parents"] = [mapping[p] for p in G.nodes[n]["parents"]]
  955. return H