|
7 | 7 | "colab_type": "text" |
8 | 8 | }, |
9 | 9 | "source": [ |
10 | | - "<a href=\"https://colab.research.google.com/github/ssec/WAF_ML_Tutorial_Part2/blob/main/colab_notebooks/Notebook06_Convolutions_PyTorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" |
| 10 | + "<a href=\"https://colab.research.google.com/github/ssec/WAF_ML_Tutorial_Part2/blob/2-confusion-about-interpolating/colab_notebooks/Notebook06_Convolutions_PyTorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" |
11 | 11 | ] |
12 | 12 | }, |
13 | 13 | { |
|
51 | 51 | "base_uri": "https://localhost:8080/" |
52 | 52 | }, |
53 | 53 | "id": "x2m10lWrHr_R", |
54 | | - "outputId": "f74daa9b-3d41-43b1-bbd0-4aad76043487" |
| 54 | + "outputId": "a0fcf4ba-0837-49a9-b7da-4c9870b31550" |
55 | 55 | }, |
56 | | - "execution_count": 3, |
| 56 | + "execution_count": 1, |
57 | 57 | "outputs": [ |
58 | 58 | { |
59 | 59 | "output_type": "stream", |
|
65 | 65 | "remote: Counting objects: 100% (212/212), done.\u001b[K\n", |
66 | 66 | "remote: Compressing objects: 100% (119/119), done.\u001b[K\n", |
67 | 67 | "remote: Total 368 (delta 140), reused 134 (delta 93), pack-reused 156 (from 1)\u001b[K\n", |
68 | | - "Receiving objects: 100% (368/368), 114.60 MiB | 26.65 MiB/s, done.\n", |
| 68 | + "Receiving objects: 100% (368/368), 114.60 MiB | 15.55 MiB/s, done.\n", |
69 | 69 | "Resolving deltas: 100% (204/204), done.\n", |
70 | 70 | "Updating files: 100% (120/120), done.\n", |
71 | 71 | "done\n" |
|
75 | 75 | }, |
76 | 76 | { |
77 | 77 | "cell_type": "code", |
78 | | - "execution_count": 5, |
| 78 | + "execution_count": 2, |
79 | 79 | "metadata": { |
80 | 80 | "colab": { |
81 | 81 | "base_uri": "https://localhost:8080/" |
82 | 82 | }, |
83 | 83 | "id": "PTq52qdZGnZu", |
84 | | - "outputId": "815ab8ed-641a-4dec-afd9-9ca04cdaa9c9" |
| 84 | + "outputId": "b04f4060-fe1a-4143-e75a-07e1e0ceef40" |
85 | 85 | }, |
86 | 86 | "outputs": [ |
87 | 87 | { |
88 | 88 | "output_type": "execute_result", |
89 | 89 | "data": { |
90 | 90 | "text/plain": [ |
91 | | - "<torch._C.Generator at 0x7fce206d96d0>" |
| 91 | + "<torch._C.Generator at 0x7fbf50775a10>" |
92 | 92 | ] |
93 | 93 | }, |
94 | 94 | "metadata": {}, |
95 | | - "execution_count": 5 |
| 95 | + "execution_count": 2 |
96 | 96 | } |
97 | 97 | ], |
98 | 98 | "source": [ |
|
146 | 146 | }, |
147 | 147 | { |
148 | 148 | "cell_type": "code", |
149 | | - "execution_count": 7, |
| 149 | + "execution_count": 3, |
150 | 150 | "metadata": { |
151 | 151 | "colab": { |
152 | 152 | "base_uri": "https://localhost:8080/", |
153 | 153 | "height": 144 |
154 | 154 | }, |
155 | 155 | "id": "LCBzb32YGnZv", |
156 | | - "outputId": "f6b663d5-4d08-4142-cb08-454dcbfe5b67" |
| 156 | + "outputId": "e9ecf48f-5d2f-4c83-bade-53815cbd9ae8" |
157 | 157 | }, |
158 | 158 | "outputs": [ |
159 | 159 | { |
|
651 | 651 | " radar_field: 4)\n", |
652 | 652 | "Dimensions without coordinates: grid_row, grid_column, radar_height, radar_field\n", |
653 | 653 | "Data variables:\n", |
654 | | - " radar_image_matrix (grid_row, grid_column, radar_height, radar_field) float32 197kB ...</pre><div class='xr-wrap' style='display:none'><div class='xr-header'><div class='xr-obj-type'>xarray.Dataset</div></div><ul class='xr-sections'><li class='xr-section-item'><input id='section-a7755d39-2484-4e97-ae1d-454290a50143' class='xr-section-summary-in' type='checkbox' disabled ><label for='section-a7755d39-2484-4e97-ae1d-454290a50143' class='xr-section-summary' title='Expand/collapse section'>Dimensions:</label><div class='xr-section-inline-details'><ul class='xr-dim-list'><li><span>grid_row</span>: 32</li><li><span>grid_column</span>: 32</li><li><span>radar_height</span>: 12</li><li><span>radar_field</span>: 4</li></ul></div><div class='xr-section-details'></div></li><li class='xr-section-item'><input id='section-9a5d17a4-de47-407b-a07c-85494b34bafb' class='xr-section-summary-in' type='checkbox' checked><label for='section-9a5d17a4-de47-407b-a07c-85494b34bafb' class='xr-section-summary' >Data variables: <span>(1)</span></label><div class='xr-section-inline-details'></div><div class='xr-section-details'><ul class='xr-var-list'><li class='xr-var-item'><div class='xr-var-name'><span>radar_image_matrix</span></div><div class='xr-var-dims'>(grid_row, grid_column, radar_height, radar_field)</div><div class='xr-var-dtype'>float32</div><div class='xr-var-preview xr-preview'>...</div><input id='attrs-59812b5c-822f-469f-a22f-77ed4ca76955' class='xr-var-attrs-in' type='checkbox' disabled><label for='attrs-59812b5c-822f-469f-a22f-77ed4ca76955' title='Show/Hide attributes'><svg class='icon xr-icon-file-text2'><use xlink:href='#icon-file-text2'></use></svg></label><input id='data-f28124eb-5b12-4d34-ba0a-4485b53f44d0' class='xr-var-data-in' type='checkbox'><label for='data-f28124eb-5b12-4d34-ba0a-4485b53f44d0' title='Show/Hide data repr'><svg class='icon xr-icon-database'><use xlink:href='#icon-database'></use></svg></label><div class='xr-var-attrs'><dl 
class='xr-attrs'></dl></div><div class='xr-var-data'><pre>[49152 values with dtype=float32]</pre></div></li></ul></div></li></ul></div></div>" |
| 654 | + " radar_image_matrix (grid_row, grid_column, radar_height, radar_field) float32 197kB ...</pre><div class='xr-wrap' style='display:none'><div class='xr-header'><div class='xr-obj-type'>xarray.Dataset</div></div><ul class='xr-sections'><li class='xr-section-item'><input id='section-4183d7c7-e4db-4160-9040-d00243e6ea7f' class='xr-section-summary-in' type='checkbox' disabled ><label for='section-4183d7c7-e4db-4160-9040-d00243e6ea7f' class='xr-section-summary' title='Expand/collapse section'>Dimensions:</label><div class='xr-section-inline-details'><ul class='xr-dim-list'><li><span>grid_row</span>: 32</li><li><span>grid_column</span>: 32</li><li><span>radar_height</span>: 12</li><li><span>radar_field</span>: 4</li></ul></div><div class='xr-section-details'></div></li><li class='xr-section-item'><input id='section-c8eb729d-53f9-4cb6-9fe8-715d4648f54f' class='xr-section-summary-in' type='checkbox' checked><label for='section-c8eb729d-53f9-4cb6-9fe8-715d4648f54f' class='xr-section-summary' >Data variables: <span>(1)</span></label><div class='xr-section-inline-details'></div><div class='xr-section-details'><ul class='xr-var-list'><li class='xr-var-item'><div class='xr-var-name'><span>radar_image_matrix</span></div><div class='xr-var-dims'>(grid_row, grid_column, radar_height, radar_field)</div><div class='xr-var-dtype'>float32</div><div class='xr-var-preview xr-preview'>...</div><input id='attrs-fc79c9eb-fac4-45ba-8a24-e0cf9f9b1289' class='xr-var-attrs-in' type='checkbox' disabled><label for='attrs-fc79c9eb-fac4-45ba-8a24-e0cf9f9b1289' title='Show/Hide attributes'><svg class='icon xr-icon-file-text2'><use xlink:href='#icon-file-text2'></use></svg></label><input id='data-11a08a46-3ed6-4e67-a766-665b61804627' class='xr-var-data-in' type='checkbox'><label for='data-11a08a46-3ed6-4e67-a766-665b61804627' title='Show/Hide data repr'><svg class='icon xr-icon-database'><use xlink:href='#icon-database'></use></svg></label><div class='xr-var-attrs'><dl 
class='xr-attrs'></dl></div><div class='xr-var-data'><pre>[49152 values with dtype=float32]</pre></div></li></ul></div></li></ul></div></div>" |
655 | 655 | ] |
656 | 656 | }, |
657 | 657 | "metadata": {}, |
658 | | - "execution_count": 7 |
| 658 | + "execution_count": 3 |
659 | 659 | } |
660 | 660 | ], |
661 | 661 | "source": [ |
|
681 | 681 | }, |
682 | 682 | { |
683 | 683 | "cell_type": "code", |
684 | | - "execution_count": 9, |
| 684 | + "execution_count": 4, |
685 | 685 | "metadata": { |
686 | 686 | "colab": { |
687 | 687 | "base_uri": "https://localhost:8080/", |
688 | 688 | "height": 1000 |
689 | 689 | }, |
690 | 690 | "id": "GMjoJBblGnZw", |
691 | | - "outputId": "5f62052e-99b4-4cb0-b6db-f33a08a48276" |
| 691 | + "outputId": "13120921-9e2b-4874-9c01-6cbd3793baaf" |
692 | 692 | }, |
693 | 693 | "outputs": [ |
694 | 694 | { |
|
699 | 699 | ] |
700 | 700 | }, |
701 | 701 | "metadata": {}, |
702 | | - "execution_count": 9 |
| 702 | + "execution_count": 4 |
703 | 703 | }, |
704 | 704 | { |
705 | 705 | "output_type": "display_data", |
|
901 | 901 | }, |
902 | 902 | { |
903 | 903 | "cell_type": "code", |
904 | | - "execution_count": 12, |
| 904 | + "source": [ |
| 905 | + "# %%\n", |
| 906 | + "# Convert xarray data to numpy, then to a torch tensor\n", |
| 907 | + "input_data = ds_sample.radar_image_matrix.isel(radar_height=slice(0,1),radar_field=0).values\n", |
| 908 | + "\n", |
| 909 | + "# Step 1: Convert to Tensor\n", |
| 910 | + "# The shape of input_data is likely (32, 32, 1) or (32, 32) depending on how xarray slices it.\n", |
| 911 | + "# We ensure it starts as a float tensor.\n", |
| 912 | + "print(f\"{input_data.shape}=\") # 32, 32, 1; (H, W, C)\n", |
| 913 | + "input_tensor = torch.from_numpy(input_data).float()\n", |
| 914 | + "\n", |
| 915 | + "# Step 2: Ensure Correct Shape (Batch, Channel, Height, Width)\n", |
| 916 | + "# If the input is (Height, Width, Channel), we need to permute it to (Channel, Height, Width)\n", |
| 917 | + "if input_tensor.ndim == 3 and input_tensor.shape[-1] == 1:\n", |
| 918 | + " input_tensor = input_tensor.permute(2, 0, 1) # (H, W, C) -> (C, H, W)\n", |
| 919 | + "elif input_tensor.ndim == 2:\n", |
| 920 | + " input_tensor = input_tensor.unsqueeze(0) # Add Channel dim if missing: (H, W) -> (C, H, W)\n", |
| 921 | + "\n", |
| 922 | + "# Add Batch Dimension: (C, H, W) -> (Batch, C, H, W)\n", |
| 923 | + "input_tensor = input_tensor.unsqueeze(0)" |
| 924 | + ], |
| 925 | + "metadata": { |
| 926 | + "colab": { |
| 927 | + "base_uri": "https://localhost:8080/" |
| 928 | + }, |
| 929 | + "id": "zo_JAULelGBC", |
| 930 | + "outputId": "313fa14f-3be3-47e5-cbc5-48e964a80505" |
| 931 | + }, |
| 932 | + "execution_count": 5, |
| 933 | + "outputs": [ |
| 934 | + { |
| 935 | + "output_type": "stream", |
| 936 | + "name": "stdout", |
| 937 | + "text": [ |
| 938 | + "(32, 32, 1)=\n" |
| 939 | + ] |
| 940 | + } |
| 941 | + ] |
| 942 | + }, |
| 943 | + { |
| 944 | + "cell_type": "markdown", |
| 945 | + "source": [ |
| 946 | + "Now our data (`input_tensor`) is in the correct shape: Batch, Channel, Height, Width.\n", |
| 947 | + "\n", |
| 948 | + "One last step before applying the convolution - we will resize `input_tensor`. This is simply for the sake of demonstrating the sharpening convolution. You generally would not do this in an ML model setting. The interpolation will smooth the original image a little, and the final sharpened convolution will look a little nicer. (Feel free to re-run this interpolation later with size=(32, 32) to see the convolution result below without the interpolation.)\n", |
| 949 | + "\n", |
| 950 | + "Note that we use the pytorch functional interface to interpolate the pytorch tensor object `input_tensor`. Again, we aren't performing any machine learning here. We're simply demonstrating how convolutions work." |
| 951 | + ], |
| 952 | + "metadata": { |
| 953 | + "id": "cc5ABhj5nQ2e" |
| 954 | + } |
| 955 | + }, |
| 956 | + { |
| 957 | + "cell_type": "code", |
| 958 | + "source": [ |
| 959 | + "# interpolate from (1, 1, 32, 32) to (1, 1, 36, 36)\n", |
| 960 | + "more_points = F.interpolate(input_tensor, size=(36, 36), mode='bilinear', align_corners=False)\n" |
| 961 | + ], |
| 962 | + "metadata": { |
| 963 | + "id": "rnhhUlxnnLBT" |
| 964 | + }, |
| 965 | + "execution_count": 15, |
| 966 | + "outputs": [] |
| 967 | + }, |
| 968 | + { |
| 969 | + "cell_type": "markdown", |
| 970 | + "source": [ |
| 971 | + "Finally we can define, configure, and apply our pytorch convolution operation, then plot the results." |
| 972 | + ], |
| 973 | + "metadata": { |
| 974 | + "id": "nLzGMq4QlI4X" |
| 975 | + } |
| 976 | + }, |
| 977 | + { |
| 978 | + "cell_type": "code", |
| 979 | + "execution_count": 16, |
905 | 980 | "metadata": { |
906 | 981 | "colab": { |
907 | 982 | "base_uri": "https://localhost:8080/", |
908 | | - "height": 430 |
| 983 | + "height": 450 |
909 | 984 | }, |
910 | 985 | "id": "7eM4lMhIGnZx", |
911 | | - "outputId": "d89b5397-f4fb-4991-c473-cba1b02123c7" |
| 986 | + "outputId": "eb1fb700-7fc6-421b-a490-acb5a1355fed" |
912 | 987 | }, |
913 | 988 | "outputs": [ |
914 | 989 | { |
915 | 990 | "output_type": "stream", |
916 | 991 | "name": "stdout", |
917 | 992 | "text": [ |
918 | | - "(32, 32, 1)=\n" |
| 993 | + "more_points.shape=torch.Size([1, 1, 36, 36]), convolution_result.shape=torch.Size([1, 1, 34, 34])\n" |
919 | 994 | ] |
920 | 995 | }, |
921 | 996 | { |
|
935 | 1010 | } |
936 | 1011 | ], |
937 | 1012 | "source": [ |
938 | | - "# %%\n", |
939 | | - "# Convert xarray data to numpy, then to a torch tensor\n", |
940 | | - "input_data = ds_sample.radar_image_matrix.isel(radar_height=slice(0,1),radar_field=0).values\n", |
941 | | - "\n", |
942 | | - "# Step 1: Convert to Tensor\n", |
943 | | - "# The shape of input_data is likely (32, 32, 1) or (32, 32) depending on how xarray slices it.\n", |
944 | | - "# We ensure it starts as a float tensor.\n", |
945 | | - "print(f\"{input_data.shape}=\") # 32, 32, 1; (H, W, C)\n", |
946 | | - "input_tensor = torch.from_numpy(input_data).float()\n", |
947 | | - "\n", |
948 | | - "# Step 2: Ensure Correct Shape (Batch, Channel, Height, Width)\n", |
949 | | - "# If the input is (Height, Width, Channel), we need to permute it to (Channel, Height, Width)\n", |
950 | | - "if input_tensor.ndim == 3 and input_tensor.shape[-1] == 1:\n", |
951 | | - " input_tensor = input_tensor.permute(2, 0, 1) # (H, W, C) -> (C, H, W)\n", |
952 | | - "elif input_tensor.ndim == 2:\n", |
953 | | - " input_tensor = input_tensor.unsqueeze(0) # Add Channel dim if missing: (H, W) -> (C, H, W)\n", |
954 | | - "\n", |
955 | | - "# Add Batch Dimension: (C, H, W) -> (Batch, C, H, W)\n", |
956 | | - "input_tensor = input_tensor.unsqueeze(0)\n", |
957 | | - "\n", |
958 | | - "# Step 3: Resize using interpolate\n", |
959 | | - "# Now input_tensor is (1, 1, 32, 32), which matches the 2D spatial size expected by interpolate\n", |
960 | | - "more_points = F.interpolate(input_tensor, size=(36, 36), mode='bilinear', align_corners=False)\n", |
961 | 1013 | "\n", |
962 | 1014 | "# Define the sharpen filter\n", |
963 | 1015 | "kernel_weights = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])\n", |
964 | 1016 | "\n", |
965 | | - "# Define conv with specific weights.\n", |
| 1017 | + "# Define conv\n", |
966 | 1018 | "conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, bias=False)\n", |
967 | 1019 | "\n", |
968 | | - "# Set the weights manually\n", |
| 1020 | + "# Set the conv weights manually\n", |
969 | 1021 | "with torch.no_grad():\n", |
970 | 1022 | " # We reshape our 3x3 kernel to (1, 1, 3, 3)\n", |
971 | 1023 | " conv.weight.data = torch.from_numpy(kernel_weights).float().view(1, 1, 3, 3)\n", |
972 | 1024 | "\n", |
973 | 1025 | "# Run the data through\n", |
974 | 1026 | "convolution_result = conv(more_points)\n", |
975 | | - "\n", |
| 1027 | + "print(f\"{more_points.shape=}, {convolution_result.shape=}\")\n", |
976 | 1028 | "\n", |
977 | 1029 | "# Plotting\n", |
978 | 1030 | "fig,(ax1,ax2) = plt.subplots(1,2,figsize=(10,5))\n", |
|
1008 | 1060 | }, |
1009 | 1061 | { |
1010 | 1062 | "cell_type": "code", |
1011 | | - "execution_count": 15, |
| 1063 | + "execution_count": 7, |
1012 | 1064 | "metadata": { |
1013 | 1065 | "colab": { |
1014 | 1066 | "base_uri": "https://localhost:8080/", |
1015 | 1067 | "height": 420 |
1016 | 1068 | }, |
1017 | 1069 | "id": "oiUu9nDyGnZy", |
1018 | | - "outputId": "f22aa7f2-97c9-4b3a-cd16-a40aec3c5ee6" |
| 1070 | + "outputId": "8999b7db-e88a-4b93-ccd9-c1d5e1bb3fd0" |
1019 | 1071 | }, |
1020 | 1072 | "outputs": [ |
1021 | 1073 | { |
|
1026 | 1078 | ] |
1027 | 1079 | }, |
1028 | 1080 | "metadata": {}, |
1029 | | - "execution_count": 15 |
| 1081 | + "execution_count": 7 |
1030 | 1082 | }, |
1031 | 1083 | { |
1032 | 1084 | "output_type": "display_data", |
|
1123 | 1175 | }, |
1124 | 1176 | { |
1125 | 1177 | "cell_type": "code", |
1126 | | - "execution_count": 17, |
| 1178 | + "execution_count": 8, |
1127 | 1179 | "metadata": { |
1128 | 1180 | "colab": { |
1129 | 1181 | "base_uri": "https://localhost:8080/", |
1130 | 1182 | "height": 420 |
1131 | 1183 | }, |
1132 | 1184 | "id": "qnfWwa9MGnZz", |
1133 | | - "outputId": "6a9b113a-ee16-4a5d-ce93-6c65fa3b696e" |
| 1185 | + "outputId": "913a204a-b0f4-447e-adb1-b4004e133eb2" |
1134 | 1186 | }, |
1135 | 1187 | "outputs": [ |
1136 | 1188 | { |
|
1141 | 1193 | ] |
1142 | 1194 | }, |
1143 | 1195 | "metadata": {}, |
1144 | | - "execution_count": 17 |
| 1196 | + "execution_count": 8 |
1145 | 1197 | }, |
1146 | 1198 | { |
1147 | 1199 | "output_type": "display_data", |
|
1247 | 1299 | }, |
1248 | 1300 | { |
1249 | 1301 | "cell_type": "code", |
1250 | | - "execution_count": 18, |
| 1302 | + "execution_count": 9, |
1251 | 1303 | "metadata": { |
1252 | 1304 | "colab": { |
1253 | 1305 | "base_uri": "https://localhost:8080/", |
1254 | 1306 | "height": 430 |
1255 | 1307 | }, |
1256 | 1308 | "id": "qLPaA52tGnZz", |
1257 | | - "outputId": "0882cead-a86f-4fe8-b4e1-7b16a2cc7818" |
| 1309 | + "outputId": "65bfa86b-8896-4d38-e8de-76ad52b93a3c" |
1258 | 1310 | }, |
1259 | 1311 | "outputs": [ |
1260 | 1312 | { |
|
1265 | 1317 | ] |
1266 | 1318 | }, |
1267 | 1319 | "metadata": {}, |
1268 | | - "execution_count": 18 |
| 1320 | + "execution_count": 9 |
1269 | 1321 | }, |
1270 | 1322 | { |
1271 | 1323 | "output_type": "display_data", |
|
1329 | 1381 | }, |
1330 | 1382 | { |
1331 | 1383 | "cell_type": "code", |
1332 | | - "execution_count": 20, |
| 1384 | + "execution_count": 17, |
1333 | 1385 | "metadata": { |
1334 | 1386 | "colab": { |
1335 | 1387 | "base_uri": "https://localhost:8080/", |
1336 | | - "height": 430 |
| 1388 | + "height": 449 |
1337 | 1389 | }, |
1338 | 1390 | "id": "lPNUoEYiGnZ0", |
1339 | | - "outputId": "ccd5c1c8-3f2d-45c3-b6cc-c0ee0bd05dbe" |
| 1391 | + "outputId": "98067555-f1e0-493b-84c7-01494f3db8d9" |
1340 | 1392 | }, |
1341 | 1393 | "outputs": [ |
| 1394 | + { |
| 1395 | + "output_type": "stream", |
| 1396 | + "name": "stdout", |
| 1397 | + "text": [ |
| 1398 | + "more_points.shape=torch.Size([1, 1, 36, 36]), res.shape=torch.Size([1, 1, 36, 36])\n" |
| 1399 | + ] |
| 1400 | + }, |
1342 | 1401 | { |
1343 | 1402 | "output_type": "execute_result", |
1344 | 1403 | "data": { |
|
1347 | 1406 | ] |
1348 | 1407 | }, |
1349 | 1408 | "metadata": {}, |
1350 | | - "execution_count": 20 |
| 1409 | + "execution_count": 17 |
1351 | 1410 | }, |
1352 | 1411 | { |
1353 | 1412 | "output_type": "display_data", |
|
1393 | 1452 | "#run the data through\n", |
1394 | 1453 | "res = conv(more_points)\n", |
1395 | 1454 | "\n", |
| 1455 | + "print(f\"{more_points.shape=}, {res.shape=}\")\n", |
1396 | 1456 | "\n", |
1397 | 1457 | "fig,(ax1,ax2) = plt.subplots(1,2,figsize=(10,5))\n", |
1398 | 1458 | "ax1.imshow(more_points.squeeze(), vmin=0, vmax=60, cmap='Spectral_r')\n", |
|
0 commit comments