회원가입
로그인
Toggle navigation
문제
문제
전체 문제
문제 출처
단계별로 풀어보기
알고리즘 분류
추가된 문제
문제 순위
문제
푼 사람이 한 명인 문제
아무도 못 푼 문제
최근 제출된 문제
최근 풀린 문제
랜덤
출처
ICPC
Olympiad
한국정보올림피아드
한국정보올림피아드시․도지역본선
전국 대학생 프로그래밍 대회 동아리 연합
대학교 대회
카카오 코드 페스티벌
Coder's High
ICPC
Regionals
World Finals
Korea Regional
Africa and the Middle East Regionals
Europe Regionals
Latin America Regionals
North America Regionals
South Pacific Regionals
문제집
대회
2
채점 현황
랭킹
게시판
그룹
더 보기
재채점 기록
블로그
강의
실험실
도움말
BOJ Stack
BOJ Book
1067번 - 이동
sait2000
C++17
#include <cassert> #include <cstdio> #include <utility> #include <vector> #include <algorithm> using std::scanf; using std::printf; using std::swap; using std::vector; using std::fill; using std::copy; constexpr int powmod(long long a, long long b, int m) { a %= m; if (a < 0) a += m; long long r = 1; while (b > 0) { if (b % 2) { r = r * a % m; } a = a * a % m; b /= 2; } return r; } constexpr int invmod(long long a, int m) { return powmod(a, m - 2, m); } template<int prime, int prim> struct FFT_template { static constexpr int P = prime, PR = prim, PRI = invmod(PR, P); static void fft(int *a, int sz, bool inv) { assert(sz > 0 && (sz & (sz - 1)) == 0 && (P - 1) % sz == 0); int n = sz; int w = powmod(inv ? PRI : PR, (P - 1) / sz, P); for (int i = sz >> 1; i > 0; i >>= 1) { for (int j = 0; j < n; j += i * 2) { int x = 1; for (int k = 0; k < i; ++k) { int &lft = a[j | k], &rgt = a[i | j | k]; int nlft = lft + rgt; if (nlft >= P) nlft -= P; int nrgt = 1LL * (lft + (P - rgt)) * x % P; lft = nlft, rgt = nrgt; x = 1LL * x * w % P; } } w = 1LL * w * w % P; } // bit reverse for (int i = 1, j = 0; i < n; ++i) { int b = sz >> 1; while (b & j) { j ^= b; b >>= 1; } j ^= b; if (i < j) swap(a[i], a[j]); } if (inv) { int invn = invmod(sz, P); for (int i = 0; i < sz; ++i) { a[i] = 1LL * a[i] * invn % P; } } } }; template <typename FFT> struct OnlineFFT { static constexpr int MOD = FFT::P; static constexpr long long MODSQ = 1LL * MOD * MOD; static constexpr int BLOCK = 1<<10, DBLOCK = BLOCK * 2; vector<int> a, b, c, fa, fb, cmid, fftbuf; vector<long long> fftbufll; OnlineFFT(): a(), b(), c(), fa(), fb(), cmid(DBLOCK), fftbuf(DBLOCK), fftbufll(DBLOCK) {} void push(long long va, long long vb) { int n = a.size(); int t = n / BLOCK; a.push_back((va % MOD + MOD) % MOD); b.push_back((vb % MOD + MOD) % MOD); if (t == 0) { long long r = 0; for (int i = 0; i <= n; ++i) { r += 1LL * a[i] * b[n - i]; if (r >= MODSQ) { r -= MODSQ; } } r %= MOD; c.push_back(r); } else { int nm = n % BLOCK; long long r = cmid[nm]; for (int i = 0; i <= nm; ++i) { r += 1LL * a[i] * b[nm - i + BLOCK * t]; if (r >= MODSQ) { r -= MODSQ; } r += 1LL * a[i + BLOCK * t] * b[nm - i]; if (r >= MODSQ) { r -= MODSQ; } } r %= MOD; c.push_back(r); } n = a.size(); t = n / BLOCK; if (n % BLOCK != 0) { return; } fa.resize(t * DBLOCK); fb.resize(t * DBLOCK); copy(a.begin() + (t - 1) * BLOCK, a.end(), fa.begin() + (t - 1) * DBLOCK); copy(b.begin() + (t - 1) * BLOCK, b.end(), fb.begin() + (t - 1) * DBLOCK); fill(fa.begin() + (t - 1) * DBLOCK + BLOCK, fa.end(), 0); fill(fb.begin() + (t - 1) * DBLOCK + BLOCK, fb.end(), 0); FFT::fft(fa.data() + (t - 1) * DBLOCK, DBLOCK, false); FFT::fft(fb.data() + (t - 1) * DBLOCK, DBLOCK, false); auto add = [&](int ia, int ib) { for (int i = 0; i < DBLOCK; ++i) { fftbufll[i] = (fftbufll[i] + 1LL * fa[i + ia * DBLOCK] * fb[i + ib * DBLOCK]); if (fftbufll[i] >= MODSQ) { fftbufll[i] -= MODSQ; } } }; auto slide = [&]() { copy(cmid.begin() + BLOCK, cmid.end(), cmid.begin()); fill(cmid.begin() + BLOCK, cmid.end(), 0); }; auto calc_conv = [&]() { for (int i = 0; i < DBLOCK; ++i) { fftbuf[i] = fftbufll[i] % MOD; } FFT::fft(fftbuf.data(), DBLOCK, true); for (int i = 0; i < DBLOCK; ++i) { cmid[i] = (cmid[i] + fftbuf[i]) % MOD; } fill(fftbufll.begin(), fftbufll.end(), 0); }; if (t == 1) { add(0, 0); calc_conv(); slide(); } else { add(0, t - 1); add(t - 1, 0); calc_conv(); slide(); for (int i = 1; i < t; ++i) { add(i, t - i); } calc_conv(); } } int get(int i) { return c[i]; } }; int readint() { signed n; scanf("%d", &n); return n; } signed main() { using FFT = FFT_template<998244353, 3>; int n = readint(); vector<int> a(n), b(n); for (int i = 0; i < n; ++i) { a[i] = readint(); } for (int i = 0; i < n; ++i) { b[n - 1 - i] = readint(); } OnlineFFT<FFT> onf; for (int i = 0; i < n; ++i) { onf.push(a[i], b[i]); } for (int i = 0; i < n; ++i) { onf.push(a[i], 0); } int res = 0; for (int i = 0; i < n; ++i) { int cand = onf.get(i + n - 1); if (cand > res) { res = cand; } } printf("%d\n", res); return 0; }
결과
메모리
시간
코드 길이
맞았습니다!!
5860
256
5654