[PS] 백준 11049 행렬 곱셈 순서/최대한 자세히 설명해보기 - 자바 풀이
행렬 곱셈 순서
행렬 곱셈 순서에 대한 문제이다. 행렬은 곱 연산시 순서에 상관없이 같은 결과값을 보장하나, 곱셈 순서에 따라 곱셈 연산 횟수가 달라지게 된다. 해당 연산을 최소화 했을 때 몇 번인지 구하는 문제이다.
행렬의 곱셈 순서를 정하는 것은 대표적인 DP알고리즘 사용 사례라고 한다. 이 문제를 혼자 풀기에 실패하였기에 풀이를 보고 이해한 바를 정리한다. 접근 과정을 최대한 자세히 기술하였다.
문제
크기가 N×M인 행렬 A와 M×K인 B를 곱할 때 필요한 곱셈 연산의 수는 총 N×M×K번이다. 행렬 N개를 곱하는데 필요한 곱셈 연산의 수는 행렬을 곱하는 순서에 따라 달라지게 된다.
예를 들어, A의 크기가 5×3이고, B의 크기가 3×2, C의 크기가 2×6인 경우에 행렬의 곱 ABC를 구하는 경우를 생각해보자.
- AB를 먼저 곱하고 C를 곱하는 경우 (AB)C에 필요한 곱셈 연산의 수는 5×3×2 + 5×2×6 = 30 + 60 = 90번이다.
- BC를 먼저 곱하고 A를 곱하는 경우 A(BC)에 필요한 곱셈 연산의 수는 3×2×6 + 5×3×6 = 36 + 90 = 126번이다.
같은 곱셈이지만, 곱셈을 하는 순서에 따라서 곱셈 연산의 수가 달라진다.
행렬 N개의 크기가 주어졌을 때, 모든 행렬을 곱하는데 필요한 곱셈 연산 횟수의 최솟값을 구하는 프로그램을 작성하시오. 입력으로 주어진 행렬의 순서를 바꾸면 안 된다.
입력
첫째 줄에 행렬의 개수 N(1 ≤ N ≤ 500)이 주어진다.
둘째 줄부터 N개 줄에는 행렬의 크기 r과 c가 주어진다. (1 ≤ r, c ≤ 500)
항상 순서대로 곱셈을 할 수 있는 크기만 입력으로 주어진다.
출력
첫째 줄에 입력으로 주어진 행렬을 곱하는데 필요한 곱셈 연산의 최솟값을 출력한다. 정답은 231-1 보다 작거나 같은 자연수이다. 또한, 최악의 순서로 연산해도 연산 횟수가 231-1보다 작거나 같다.
예제 입력 1
3
5 3
3 2
2 6
예제 출력 1
90
풀이
Metric 배열
우선 기본적인 발상은 아래와 같다. 매트릭스 A B C가 주어질 때,
metric[inx][0]은 inx 매트릭스의 행 수이다.
metric[inx][1]은 inx 매트릭스의 열 수이다.
이를테면, 첫 번째인 5*3 행렬의 경우 아래와 같이 나타낼 수 있다. (인덱스는 0부터 시작하도록 한다.)
metric[0][0] = 5;
metric[0][1] = 3;
DP 배열
다음으로 dp 배열을 정의한다. dp 배열은 특정 범위의 최소값을 뜻한다. 예를 들어
dp[0][2]의 경우 0번 매트릭스부터 2번 매트릭스까지의 곱셈 연산의 최소값이다.
예제 입력을 보면 각 매트릭스는 아래와 같다.
metric = [{5, 3}, {3, 2}, {2, 6}};
각 매트릭스는 A B C라는 별칭을 부여하자.
접근
A B C의 경우(3개의 매트릭스)
접근 발상은 범위 별로 계산을 해서 적당히 끼워 맞추는 것이다. 이를테면 ABC의 경우를 보자. 이는 두 가지 케이스로 이루어진다.
Case 1: (A*B)*C //A와 B를 먼저 곱하고 C를 곱하는 것이다.
Case 2: A*(B*C) //B와 C를 먼저 곱하고 C를 곱하는 것이다.
이 결과 dp[0][2]은 Case1과 Case2중 최소값을 선택한다.
A B C D의 경우(4개의 매트릭스)
만약 매트릭스가 A B C D로 주어진다면?
Case 1 :(A * B * C) * D //A와 B와 C를 곱한 최소 값에 D를 곱하는 것이다.
Sub Case 1 : (A * B) * C // (A * B * C)의 최소를 찾기 위한 서브 케이스
Sub Case 2: A * (B * C) // (A * B * C)의 최소를 찾기 위하 서브 케이스
즉, Case 1은 min(Sub Case1, Sub Case2)에 그 결과 배열과 D를 곱한 연산의 수가 된다.
Case 2 :(A * B) * (C * D)
Sub Case 1 :A * B
Sub Case 2: C * D
즉, Case 2은 Sub Case1, Sub Case2를 곱한 연산의 수가 된다.
Case3 : A*(B * C * D)
Sub Case 1 : (B * C) * D //(B * C * D)의 최소를 찾기 위한 서브 케이스
Sub Case 1 : B * (C * D) //(B * C * D)의 최소를 찾기 위한 서브 케이스
즉, Case 3은 A매트릭스에 min(Sub Case1, Sub Case2)를 곱한 연산의 수가 된다.
이 결과 dp[0][3]은 Case1과 Case2, Case3중 최소값을 선택한다.
즉 이전에 했던 연산을 이후의 연산에서 지속적으로 사용하게 된다. 바텀 업 방식으로 구현 할 수 있겠다. 또한 큰 경우의 수는 n-1개만큼 나오는 것을 확인 할 수 있다.
그렇다면 어떻게 케이스를 나눌까? 그것은 범위를 기준으로 나누는 것이다.
매트릭스 A B C
metric = [{5, 3}, {3, 2}, {2, 6}}; 를 기준으로 보자.
범위 0인 경우
앞서 dp[i][j]는 i부터 j까지의 행렬 곱 연산의 횟수의 최소값이라고 정의하였다.
dp의 범위(ex: dp[0][1]은 범위가 1이다) 0이라면 값은 0이다. 자기 자신뿐이기 때문이다. 이를테면 dp[1][1]은 매트릭스 1에서 1번까지의 곱셈의 수를 의미하며 이는 0이다.
범위 1인 경우
그렇담 범위가 1이라면? 행렬 A B C에서는 두 가지 케이스가 존재하게 된다.
dp[0][1] // A * B
이는 매트릭스 A와 매트릭스 B의 곱의 형태이다.
dp[0][1] = {5, 3} * {3, 2} = 5 * 3 * 2 = 30
dp[1][2] // B * C
이는 매트릭스 B와 매트릭스 C의 곱의 형태이다.
dp[1][2] = {3, 2} * {2, 6} = 3 * 2 * 6 = 36
범위 2인 경우
범위가 2라면 여기서는 전범위를 나타내게 된다. 즉 정답을 나타낸다.
dp[0][2] : 0에서 2번까지 곱셈 연산의 최소값
이는 아래 두 형태 중 최소값을 나타낸다.
(A * B) * C// 케이스 1번
A * (B * C) // 케이스 2번
케이스 1번
A * B의 결과 배열은 {5, 3} * {3, 2} = {5, 2}의 형태이다.// 매트릭스 결과는 앞 매트릭스의 행과 뒷 매트릭스의 열의 형태
그렇담 곱셈 연산 횟수는
(A * B의 연산수) + ((A * B 결과 행렬) * C행렬의 연산 수)이다. 즉,
30+ (5* 2 * 6)= 90이 된다.
케이스 2번
B * C의 결과 배열은 {3, 2} * {2, 6} = {3, 6}의 형태이다.
그렇담 곱셈 연산 횟수는
(A행렬 * (B * C 결과 행렬)의 연산 수) * (B * C의 연산 수) 이다. 즉,
(5* 3 * 6) + 36 = 126이 된다.
dp[0][2]의 값은 위 케이스 1과 2 중 최소값이 90이 되며, 이는 테스트 케이스의 정답이 된다.
표로 나타내기
이를 표로 나타내면 아래와 같다.
아래 표는 dp배열을 나타낸다. 예를 들어 0행 1열은 dp[0][1]을 나타낸다. 즉 0번부터 1번까지의 최소 곱연산 횟수이다.
DP표 | |||
index | 0 | 1 | 2 |
0 | 0 | 30 | 90 |
1 | 0 | 36 | |
2 | 0 |
위 표를 보면 알다시피 계산은 왼쪽 위부터 오른쪽 아래로 대각선 형태로 진행된다. 다음 범위로 나아가기 위해서는 이전에 범위의 모든 연산(대각 범위)가 모두 필요한 것이다.
즉, 위 표에서 대각선은 같은 범위를 나타내게 된다. 즉, 각 범위의 최소값을 채워나가며 범위를 점점 넓혀나가는 식으로 문제를 풀게 된다.
이를 일반화 하면 아래 식을 얻을 수 있다.(이 일반화해서 식을 찾는 것을 내가 해결하지 못했었다.)
for 1 ≤ i ≤ j ≤ n,
if(i = j): dp[i][j] = 0
if(i < j): dp[i][j] = Minumum(dp[i][k] + dp[k+1][j] + (metric[i][0] * metric[k][1] * metric[j][1])(단, i ≤ k ≤ j-1)
*dp[i][k] + dp[k+1][j]는 피벗 k를 기준으로 앞 행렬과 뒤 행렬의 경우의 수를 더한 것이다.
*metric[i][0] * metric[k][1] * metric[j][1]는 피벗 k를 기준으로 앞 행렬과 뒷 행렬를 곱하는 경우의 수이다. 즉 두 행렬을 이을 때 나오는 곱의 횟수이다.
이 식에서 i는 시작, j는 종료 위치, k는 나누는 위치를 나타낸다.
(A * B) * C에서 i는 0, j는 2, k는 1이다.
식을 사용한 접근
이 식을 이용해 범위 1부터 범위 2까지 접근해보자.
metric = [{5, 3}, {3, 2}, {2, 6}};
범위 1 :
i = 0, j = 1, k = 0
dp[0][1] = dp[0][0] + dp[1][1] + (metric[0][0] * metric[0][1] * metric[1][1])
= 0 + 0 + (5 * 3 * 2) = 30
i = 1, j = 2, k = 1
dp[1][2] = dp[1][1] + dp[2][2] + (metric[1][0] * metric[1][1] * metric[2][1])
= 0 + 0 + (3 * 2 * 6) = 36
범위 2:
i = 0, j = 2, k = 0
dp[0][2] = dp[0][0] + dp[1][2] + (metric[0][0] * metric[0][1] * metric[2][1])
= 0 + 36 + (5 * 3 * 6) = 126
i = 0 , j =2, k = 1
dp[0][2] = dp[0][1] + dp[2][2] + (metric[0][0] * metric[1][1] * metric[2][1])
= 30 + 0 + (5 * 2 * 6) = 90
dp[0][2]는 이 중 최소값인 90이 된다.
정답 코드
import java.util.*;
import java.math.*;
public class Main {
static int INF = Integer.MAX_VALUE;
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int metric[][] = new int[501][2];
int n = sc.nextInt();
int dp[][] = new int[501][501];
for(int i = 0; i < n; i++) {
int r = sc.nextInt();
int c = sc.nextInt();
metric[i][0] = r;
metric[i][1] = c;
}
for(int gap = 1; gap < n; gap++) { //범위(0은 모두 0으로 초기화하기에 1부터 시작)
for(int i = 0; i < n-gap; i++) {//metric number
dp[i][i+gap] = INF;
for(int pivot = i; pivot < i+gap; pivot++) {//(시작범위 ~ 끝범위까지)
int val = dp[i][pivot] + dp[pivot+1][i+gap] + (metric[i][0] * metric[pivot][1] * metric[i+gap][1]);
dp[i][i+gap] = Math.min(dp[i][i+gap], val);
}
}
}
System.out.println(dp[0][n-1]);
}
}
내가 이해하는 데 약간 시간이 걸린 문제기에 풀이를 최대한 자세히 작성해 보았다. 틀린 설명을 발견할 시 댓글로 정정해 주기를 부탁드립니다.
* 아래 강의가 아주 큰 도움이 되었습니다. 이해가 안 된다면 꼭 한 번 보시길 추천합니다.
ref: