nokoのブログ

こちらは暫定のメモ置き場ですので悪しからず

AtcorderABC-Dより先へ行くために代表的なアルゴリズムを勉強した-その1-(DP、累積和、UnionFind木、bit関係、二分探索)

はじめに

1. DP
2. 累積和
3. UnionFind木
4. bit関係
5. 二分探索

アルゴリズム

1. DP

参考

いつ使うか

  • 同じ計算を繰り返していることに気づいたとき

概要

  • 分割統治法とメモ化
    • 問題を細分化して、細かい部分を順に解いていくことで全体を解明する
    • 計算した結果を記録しておいて、同じ数に関しては1度しか計算しないようにする

例題(簡単)

  • 問題
フィボナッチ数列
1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233,,,
  • 解法
n = int(input())
dp = [0] * n
dp[0:2] = [0, 1]
for i in range(2, n):
    dp[i] = dp[i - 1] + dp[i - 2]
print(dp[n-1])

例題(ちょっと難しい)

f:id:noko_htn:20190920213212p:plain
Frog2

  • 解法
n,k = map(int,input().split())
h_list = list(map(int,input().split()))
cost_min_list = [10001] * n
cost_min_list[0] = 0
for i in range(1, n):
    for j in range(max(0,i-k),i):
        cost_min_list[i] = min(cost_min_list[i], cost_min_list[j]+abs(h_list[i]-h_list[j]))
print(cost_min_list[n-1])

2. 累積和

参考

いつ使うか

  • 特定の区間の和を求めるとき

概要

  • 適切な前処理をしておくことで、配列上の区間の総和を求めるクエリを爆速で処理できるようになる手法
配列aに対して累積和sを求めると、配列aの区間[left, right]の総和が
s[right] - s[left]
でO(1)で求められる
  • (参考)実装例(コンテストではライブラリを使う)
n, k = list(map(int, input().split()))
a_list = list(map(int, input().split()))
sum_list = [0]
for i in range(n):
    sum_list.append(sum_list[i]+a_list[i])

例題(簡単)

N個の整数 a0,a1,…,aN−1が与えられる。K個の連続する整数の和の最大値を求めよ。
  • 解法
from itertools import accumulate

n, k = list(map(int, input().split()))
a_list = list(map(int, input().split()))

# 累積和
sum_list = list(accumulate([0] + a_list))

ans_list = []
for i in range(n-3):
    ans_list.append(sum_list[i+3]-sum_list[i])

print(max(ans_list))

例題(ちょっと難しい)

長さNの整数列a0,a1,…,aN−1が与えられる。この整数列の連続する区間であって、その区間内の値の総和が0になるものが何個あるか答えよ。
  • 解法
from itertools import accumulate
import collections

n = int(input())
a_list = list(map(int, input().split()))

sum_list = list(accumulate([0] + a_list))

# sum_list[i]=sum_list[j]となるような (i,j)(i<j) の組の個数を求めればよい
#   各要素の個数(要素ごとの出現回数)をカウントする
#     ex. Counter({0: 2, 1: 1, 2: 2, 4: 2})
counter = collections.Counter(sum_list)
#   出現回数順に要素を取得
#     ex. [(0, 2), (4, 2), (2, 2), (1, 1)]
counter_most = counter.most_common()

ans = 0
for v in counter_most:
    if v[1] > 1:
        ans += (v[1]*(v[1]-1)//2)

print(ans)

3. UnionFind木

参考

いつ使うか

  • ある二つの要素が同じグループに属するかを判定したいとき
  • ある二つの要素が同じグループに属するときに、その二つのグループを併合したいとき

概要

  • グループ分けを木構造で管理するデータ構造のこと。同じグループに属する=同じ木に属するという木構造でグループ分けをし、上記2点を高速に実行することができる。(UnionとFind)

例題

f:id:noko_htn:20190921142501p:plain
Union Find

  • 解法
n,q = map(int,input().split())
q_list=[list(map(int,input().split())) for i in range(q)]

class UnionFind:
    def __init__(self, n):
        # 親要素のノード番号を格納。par[x] == xの時そのノードは根
        self.par = [i for i in range(n+1)]
        # 木の高さを格納する(初期状態では0)
        self.rank = [0] * (n+1)

    # 検索
    # 根ならその番号を返す
    def find(self, x):
        if self.par[x] == x:
            return x
        else:
            # 走査していく過程で親を書き換える
            self.par[x] = self.find(self.par[x])
            return self.par[x]

    # 併合
    def union(self, x, y):
        # 根を探す
        x = self.find(x)
        y = self.find(y)
        # 木の高さを比較し、低いほうから高いほうに辺を張る
        if self.rank[x] < self.rank[y]:
            self.par[x] = y
        else:
            self.par[y] = x
            # 木の高さが同じなら片方を1増やす
            if self.rank[x] == self.rank[y]:
                self.rank[x] += 1

    # 同じ集合に属するか判定
    def same_check(self, x, y):
        return self.find(x) == self.find(y)

uf = UnionFind(n)

for i in range(q):
    if q_list[i][0] == 0:
        uf.union(q_list[i][1], q_list[i][2])
    elif q_list[i][0] == 1:
        if uf.same_check(q_list[i][1], q_list[i][2]):
            print('Yes')
        else:
            print('No')

4. bit関係

参考

いつ使うか

  • n 個の選択肢それぞれに Yes or No の二択があるが、その部分集合(選択できるパターン)の全てを網羅的にチェックしたいとき

概要

bit全探索は名前の通り,bit演算を使って全探索をする方法のことで,n要素の集合(a1,a2,…,an)の部分集合を全て列挙することができる. 要するに,
(a1)
(a1,a2)
(a3,a5,a6,a8)
(a2,a4,a7)
...
みたいなパターンを全て数え上げることができる.

例題(簡単)

  • 問題
A = [0, 1, 2, 3]の部分集合を全て出力する
  • 解法
a_list = [0, 1, 2, 3]

for i in range(2 ** len(a_list)):     # 1を配列のサイズNで左シフトした値と同じ
    output = []
    for j in range(len(a_list)):
        if ((i >> j) & 1) == 1:    # iの右からj+1番目のbitが1かどうかを確認
           output.append(a_list[j])
    print(output)

例題(ちょっと難しい)

f:id:noko_htn:20190921222851p:plain
Train Ticket

  • 解法
n_str = input()

op_count = len(n_str) - 1  # すき間の個数
for i in range(2 ** op_count):
    op_list = ["-"] * op_count    # あらかじめ ["-", "-", "-"] というリストを作っておく
    for j in range(op_count):    # iの右からj+1番目のbitが1かどうかを確認
        if ((i >> j) & 1):
            op_list[len(op_list) - 1 -j] = "+"

    op_list.append("")
    formula = ""
    for i in range(len(n_str)):
        formula += (n_str[i] + op_list[i])
    if eval(formula) == 7:
        print(formula + "=7")
        break

5. 二分探索

参考

いつ使うか

  • ソート可能なリスト・配列の要素を探索したいとき

概要

  • 中央の値を見て、検索したい値との大小関係を用いて、検索したい値が中央の値の右にあるか、左にあるかを判断して、片側には存在しないことを確かめながら検索していく
  • (参考)実装例(コンテストではライブラリを使う)
def binary_search(list,taget): 
    result = -1
    left = 0
    right = len(list) - 1

    while left <= right:
        center = (left + right)/2
        if list[center] == target:
            result = center
            break
        elif list[center] < target:
            left = center + 1
        elif list[center] > target:
            right = center - 1

    if result == -1:
        return str(target) + "は見つかりませんでした"
    else:
        return "要素の値が" + str(target) + "のインデックス=>" + str(result)
  • ポイント
    • left,rightというリストの両端のインデックスを用意する
    • targetの判定によって、centerを用いてどちらかを狭めていく
    • targetを見つけたら処理をwhile文を抜ける
    • 最大まで狭めると処理を終了する
    • 結果を判定するための変数(result)は-1にしておく。result = 0だとtargetが1(indexが0)の場合に失敗になってしまう。

例題(簡単)

  • 問題
リスト[1,3,4,5,6,7,9,10,13,16,18,22,24,25,26] があるとき、
要素の値が22のインデックスは?
  • 解法
import bisect

example_list = [1,3,4,5,6,7,9,10,13,16,18,22,24,25,26] 
target = 22

example_list = sorted(example_list)

target_index = bisect.bisect_left(example_list,target)

print(target_index)

例題(ちょっと難しい)

f:id:noko_htn:20190922015235p:plain
Snuke Festival

  • 解法
import bisect

n = int(input())
a_list = [int(x) for x in input().split()]
b_list = [int(x) for x in input().split()]
c_list = [int(x) for x in input().split()]

a_list = sorted(a_list)
b_list = sorted(b_list)
c_list = sorted(c_list)

ans = 0

for b_i in b_list:
    a_count_i = bisect.bisect_left(a_list,b_i)
    c_count_i = len(c_list) - bisect.bisect_right(c_list,b_i)
    ans += a_count_i * c_count_i

print(ans)

おわりに