In [4]:

import numpy as np
import dgl
from scipy import sparse
from torch_geometric import utils
import json
from collections import defaultdict

ticker_symbols = ['600834.SH', '002577.SZ', '601360', "112", '000089', "000983","000538","600780","002352","300133", "300295", ]

### Wiki

In [1]:
import os
import typing as tp
from typing import Dict

import dgl
import jsonlines
import pandas as pd
import torch
from dgl import DGLGraph

def read_ticker_list(fpath: str) -> tp.List[str]:
    record_df = pd.read_csv(
        fpath,
        header=None,
        delimiter="\t",
        index_col=0,
        keep_default_na=False,
        na_values=["_"],
    )
    tickers = record_df.index.to_list()
    return tickers
data_dir = data_dir="/student/wangsaizhuo/Codes/q4l/examples/benchmark/data/wikidata/stock_graph/cn"
ticker_index_map = {ticker: idx for idx, ticker in enumerate(read_ticker_list("/student/wangsaizhuo/Codes/q4l/examples/benchmark/data/market_data/cn/instruments/all.txt"))}

In [3]:
ticker_index_map['000088_XSHE']

57

In [22]:
     
stock_qid_map = {}
with jsonlines.open(
	os.path.join(data_dir, "stock_records.jsonl"), "r"
) as reader:
	for stock_record in reader:
		stock_qid_map[stock_record["qid"]] = stock_record["symbol"]
  

ret_data_dict = {}
with jsonlines.open(
	os.path.join(data_dir, "intra_stock_relations.jsonl"), "r"
) as reader:
	for rel in reader:
		src_symbol = stock_qid_map[rel["qid"]]
		dst_symbol = stock_qid_map[rel["value"]]
		if (
			src_symbol not in ticker_index_map
			or dst_symbol not in ticker_index_map
		):
			continue
		xx = rel["qid"]
		xx2 = rel["value"]
		dst_node_idx = ticker_index_map[stock_qid_map[rel["value"]]]
		src_node_idx = ticker_index_map[stock_qid_map[rel["qid"]]]
		rel_type = rel["property_id"]
		eid = f"wiki_{rel_type}"
		if eid not in ret_data_dict:
			ret_data_dict[eid] = [[], []]
		# Undirected, add twice
		ret_data_dict[eid][0].append(src_node_idx)
		ret_data_dict[eid][1].append(dst_node_idx)
		ret_data_dict[eid][0].append(dst_node_idx)
		ret_data_dict[eid][1].append(src_node_idx)
  
maxn = 0
with jsonlines.open(
	os.path.join(data_dir, "intermediate_nodes.jsonl"), "r"
) as reader:
	for entry in reader:
		entry["qid"]
		relations = entry["neighbors"]
		if len(relations) > maxn:
			maxn = len(relations)
		for i in range(len(relations)):
			for j in range(i + 1, len(relations)):
				qid_i, qid_j = relations[i][0], relations[j][0]
				p1, p2 = relations[i][1], relations[j][1]

				symbol_i = stock_qid_map[qid_i]
				symbol_j = stock_qid_map[qid_j]
				if (
					symbol_i not in ticker_index_map
					or symbol_j not in ticker_index_map
				):
					continue

				nid_i, nid_j = (
					ticker_index_map[symbol_i],
					ticker_index_map[symbol_j],
				)
				# No self-loop
				if qid_i == qid_j:
					continue

				# Forward
				eid = f"wiki_{p1}_{p2}"
				if eid not in ret_data_dict:
					ret_data_dict[eid] = [[], []]
				ret_data_dict[eid][0].append(nid_i)
				ret_data_dict[eid][1].append(nid_j)

				# Backward
				eid = f"wiki_{p2}_{p1}"
				if eid not in ret_data_dict:
					ret_data_dict[eid] = [[], []]
				ret_data_dict[eid][0].append(nid_j)
				ret_data_dict[eid][1].append(nid_i)

In [None]:
import pandas as pd

df = pd.read_csv("/student/wangsaizhuo/q4l_fengrui/wszlib/examples/benchmark/data/industry/cn.csv")
edge_dict = {}

for industry, group in df.groupby("INDUSTRY_GICS"):
    stocks = group["STOCK"].tolist()
    key = f"industry_{industry}"
    edge_dict[key] = [[], []]
    for i in range(len(stocks)):
        for j in range(len(stocks)):
            if i != j and stocks[i] in ticker_index_map and stocks[j] in ticker_index_map:  # 避免自连接，确保存在于ticker_index_map中
                edge_dict[key][0].append(ticker_index_map[stocks[i]])
                edge_dict[key][1].append(ticker_index_map[stocks[j]])

print(edge_dict)

In [6]:
edge_dict.keys

<function dict.keys>

In [5]:
data_path = "/student/wangsaizhuo/Codes/q4l/examples/benchmark/data/wikidata/stock_graph/cn"
file_path_map = f'{data_path}/stock_records.jsonl'
file_path_1hop = f'{data_path}/intermediate_nodes.jsonl'
file_path_2hop = f'{data_path}/two_hop_relations.jsonl'


results = []    #match qid, stock_code
qid_list = []

with open(file_path_map, 'r') as file:
    for line in file:
        record = json.loads(line)
        symbol = record.get('symbol', 'unknown') 
        qid = record.get('qid', 'unknown')  
        results.append({'symbol': symbol, 'qid': qid})

## Transform stock code to qid
for symbol_with_extension in ticker_symbols:
    symbol = symbol_with_extension.split('.')[0] 
    matched_qid = np.nan  
    for result in results:
        if result['symbol'] == symbol:
            matched_qid = result['qid']
            break
    qid_list.append(matched_qid)  

In [6]:
sel_paths_1hop = set()
sel_paths_2hop = set()
connections = {}

def transform_1hop(data):
    transformed_data = defaultdict(lambda: defaultdict(list))
    main_qid = data["qid"]
    for neighbor_qid, relation, _ in data["stock_neighbors"]:
        transformed_data[main_qid][neighbor_qid].append([relation, relation])
    result = {}
    for key, value in transformed_data.items():
        result[key] = {k: v for k, v in value.items()}
    return result

with open(file_path_1hop, 'r') as file:
    for line in file:
        data = json.loads(line)
        if "stock_neighbors" in data:
            for _, relation, _ in data["stock_neighbors"]:
                relation_type = f'{relation}'
                sel_paths_1hop.add(relation_type)
        transformed = transform_1hop(data)
        connections.update(transformed)


relations_by_qid = {}
qid_dict = defaultdict(list)
result_dict = defaultdict(set)

with open(file_path_2hop, 'r') as file:
    for line in file:
        record = json.loads(line.strip())
        qid = record['qid']
        if qid not in relations_by_qid:
            relations_by_qid[qid] = []
        relations_by_qid[qid].append(record['property_id'])
        qid_dict[qid].append((record['property_id'], record['value']))

for qid, properties in relations_by_qid.items():
    if len(properties) > 1:
        for i, prop1 in enumerate(properties):
            for j, prop2 in enumerate(properties):
                if i != j:
                    two_hop_relation = f'{prop1}_{prop2}'
                    sel_paths_2hop.add(two_hop_relation)

for qid, values in qid_dict.items():
    for i in range(len(values) - 1):
        key = values[i][0] + "_" + values[i + 1][0]
        value_pair = [values[i][1], values[i + 1][1]]
        result_dict[key].update(value_pair)

sel_paths =  sel_paths_2hop.union(sel_paths_1hop)    ## All of the paths(1-hop&2-hop)
dict_2hop = {key: list(value) for key, value in result_dict.items()} 

KeyError: 'stock_neighbors'

In [None]:
dict_1hop = {}

for key1, value1 in connections.items():
    for key2, value2 in value1.items():
        for item in value2:
            new_key = item[0] if len(item) == 2 else item[0] 
            if new_key not in dict_1hop:
                dict_1hop[new_key] = set()
            dict_1hop[new_key].add(key1)
            dict_1hop[new_key].add(key2)
for key in dict_1hop:
    dict_1hop[key] = list(dict_1hop[key])

In [None]:
inci_matrix = np.zeros([len(qid_list), len(sel_paths)], dtype=int)
full_dict = {**dict_1hop, **dict_2hop}
code_to_row_index = {code: index for index, code in enumerate(qid_list)}
path_to_col_index = {path: index for index, path in enumerate(sel_paths)}

for path, codes in full_dict.items():
    col_index = path_to_col_index[path]
    for code in codes:
        if code in code_to_row_index: 
            row_index = code_to_row_index[code]
            inci_matrix[row_index, col_index] = 1

cols_to_delete = np.all(inci_matrix == 0, axis=0)
inci_matrix = inci_matrix[:, ~cols_to_delete]

### Industry

In [17]:
import pandas as pd

df = pd.read_csv("/student/wangsaizhuo/q4l_fengrui/wszlib/examples/benchmark/data/industry/ind_cn.csv")
df = df[df.iloc[:, 0].isin(ticker_index_map)]
industry_gics = df.iloc[:, 1].fillna(100).values
stock_codes = df.iloc[:, 0].values

ticker_index = {}
for index, ticker in enumerate(ticker_index_map):
	ticker_index[ticker] = index

industry_tickers = {}
for industry, stock_code in zip(industry_gics, stock_codes):
	if industry not in industry_tickers:
		industry_tickers[industry] = []
	industry_tickers[industry].append(stock_code)

industry_index = {}
for index, industry in enumerate([key for key in industry_tickers.keys()]):
	industry_index[industry] = index

inci_matrix = np.zeros([len(ticker_index_map), len(industry_tickers)], dtype=int)
for industry in industry_tickers.keys():
	if len(industry_tickers[industry]) > 0:
		cur_ind_tickers = industry_tickers[industry]
		ind_ind = industry_index[industry]

		for i in range(len(cur_ind_tickers)):
			tic_ind = ticker_index[cur_ind_tickers[i]]
			inci_matrix[tic_ind][ind_ind] = 1
			
# inci_matrix = np.delete(inci_matrix, industry_index["NA"], axis=1)  
inci_matrix

NameError: name 'ticker_index_map' is not defined

In [16]:
stock_codes

array([], dtype=object)

In [None]:
import dgl
import numpy as np
import torch

# 示例异构图
g = dgl.heterograph({
    ('user', 'plays', 'game'): (torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 3])),
    ('user', 'follows', 'user'): (torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])),
})

# 计算总节点数
total_nodes = sum([g.num_nodes(ntype) for ntype in g.ntypes])

# 获取边类型数量
num_edge_types = len(g.canonical_etypes)

# 创建零矩阵
matrix = np.zeros((total_nodes, num_edge_types))

# 用于跟踪每种节点类型在全局矩阵中的偏移量
offset = 0
node_type_to_offset = {}

for ntype in g.ntypes:
    node_type_to_offset[ntype] = offset
    offset += g.num_nodes(ntype)

# 遍历每种边类型，为相关节点设置值
for idx, etype in enumerate(g.canonical_etypes):
    src, dst = g.edges(etype=etype)
    src_offset = node_type_to_offset[etype[0]] # 源节点类型的偏移量
    dst_offset = node_type_to_offset[etype[2]] # 目标节点类型的偏移量

    # 为源和目标节点设置值
    matrix[src + src_offset, idx] = 1
    matrix[dst + dst_offset, idx] = 1

print(matrix)