Skip to content

Commit 2fac8a4

Browse files
authored
Merge pull request #5 from ssec/2-confusion-about-interpolating
Convolution interpolation explanation addition (fix #2)
2 parents 9dc55e2 + f50cefa commit 2fac8a4

File tree

1 file changed

+118
-58
lines changed

1 file changed

+118
-58
lines changed

colab_notebooks/Notebook06_Convolutions_PyTorch.ipynb

Lines changed: 118 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"colab_type": "text"
88
},
99
"source": [
10-
"<a href=\"https://colab.research.google.com/github/ssec/WAF_ML_Tutorial_Part2/blob/main/colab_notebooks/Notebook06_Convolutions_PyTorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
10+
"<a href=\"https://colab.research.google.com/github/ssec/WAF_ML_Tutorial_Part2/blob/2-confusion-about-interpolating/colab_notebooks/Notebook06_Convolutions_PyTorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
1111
]
1212
},
1313
{
@@ -51,9 +51,9 @@
5151
"base_uri": "https://localhost:8080/"
5252
},
5353
"id": "x2m10lWrHr_R",
54-
"outputId": "f74daa9b-3d41-43b1-bbd0-4aad76043487"
54+
"outputId": "a0fcf4ba-0837-49a9-b7da-4c9870b31550"
5555
},
56-
"execution_count": 3,
56+
"execution_count": 1,
5757
"outputs": [
5858
{
5959
"output_type": "stream",
@@ -65,7 +65,7 @@
6565
"remote: Counting objects: 100% (212/212), done.\u001b[K\n",
6666
"remote: Compressing objects: 100% (119/119), done.\u001b[K\n",
6767
"remote: Total 368 (delta 140), reused 134 (delta 93), pack-reused 156 (from 1)\u001b[K\n",
68-
"Receiving objects: 100% (368/368), 114.60 MiB | 26.65 MiB/s, done.\n",
68+
"Receiving objects: 100% (368/368), 114.60 MiB | 15.55 MiB/s, done.\n",
6969
"Resolving deltas: 100% (204/204), done.\n",
7070
"Updating files: 100% (120/120), done.\n",
7171
"done\n"
@@ -75,24 +75,24 @@
7575
},
7676
{
7777
"cell_type": "code",
78-
"execution_count": 5,
78+
"execution_count": 2,
7979
"metadata": {
8080
"colab": {
8181
"base_uri": "https://localhost:8080/"
8282
},
8383
"id": "PTq52qdZGnZu",
84-
"outputId": "815ab8ed-641a-4dec-afd9-9ca04cdaa9c9"
84+
"outputId": "b04f4060-fe1a-4143-e75a-07e1e0ceef40"
8585
},
8686
"outputs": [
8787
{
8888
"output_type": "execute_result",
8989
"data": {
9090
"text/plain": [
91-
"<torch._C.Generator at 0x7fce206d96d0>"
91+
"<torch._C.Generator at 0x7fbf50775a10>"
9292
]
9393
},
9494
"metadata": {},
95-
"execution_count": 5
95+
"execution_count": 2
9696
}
9797
],
9898
"source": [
@@ -146,14 +146,14 @@
146146
},
147147
{
148148
"cell_type": "code",
149-
"execution_count": 7,
149+
"execution_count": 3,
150150
"metadata": {
151151
"colab": {
152152
"base_uri": "https://localhost:8080/",
153153
"height": 144
154154
},
155155
"id": "LCBzb32YGnZv",
156-
"outputId": "f6b663d5-4d08-4142-cb08-454dcbfe5b67"
156+
"outputId": "e9ecf48f-5d2f-4c83-bade-53815cbd9ae8"
157157
},
158158
"outputs": [
159159
{
@@ -651,11 +651,11 @@
651651
" radar_field: 4)\n",
652652
"Dimensions without coordinates: grid_row, grid_column, radar_height, radar_field\n",
653653
"Data variables:\n",
654-
" radar_image_matrix (grid_row, grid_column, radar_height, radar_field) float32 197kB ...</pre><div class='xr-wrap' style='display:none'><div class='xr-header'><div class='xr-obj-type'>xarray.Dataset</div></div><ul class='xr-sections'><li class='xr-section-item'><input id='section-a7755d39-2484-4e97-ae1d-454290a50143' class='xr-section-summary-in' type='checkbox' disabled ><label for='section-a7755d39-2484-4e97-ae1d-454290a50143' class='xr-section-summary' title='Expand/collapse section'>Dimensions:</label><div class='xr-section-inline-details'><ul class='xr-dim-list'><li><span>grid_row</span>: 32</li><li><span>grid_column</span>: 32</li><li><span>radar_height</span>: 12</li><li><span>radar_field</span>: 4</li></ul></div><div class='xr-section-details'></div></li><li class='xr-section-item'><input id='section-9a5d17a4-de47-407b-a07c-85494b34bafb' class='xr-section-summary-in' type='checkbox' checked><label for='section-9a5d17a4-de47-407b-a07c-85494b34bafb' class='xr-section-summary' >Data variables: <span>(1)</span></label><div class='xr-section-inline-details'></div><div class='xr-section-details'><ul class='xr-var-list'><li class='xr-var-item'><div class='xr-var-name'><span>radar_image_matrix</span></div><div class='xr-var-dims'>(grid_row, grid_column, radar_height, radar_field)</div><div class='xr-var-dtype'>float32</div><div class='xr-var-preview xr-preview'>...</div><input id='attrs-59812b5c-822f-469f-a22f-77ed4ca76955' class='xr-var-attrs-in' type='checkbox' disabled><label for='attrs-59812b5c-822f-469f-a22f-77ed4ca76955' title='Show/Hide attributes'><svg class='icon xr-icon-file-text2'><use xlink:href='#icon-file-text2'></use></svg></label><input id='data-f28124eb-5b12-4d34-ba0a-4485b53f44d0' class='xr-var-data-in' type='checkbox'><label for='data-f28124eb-5b12-4d34-ba0a-4485b53f44d0' title='Show/Hide data repr'><svg class='icon xr-icon-database'><use xlink:href='#icon-database'></use></svg></label><div class='xr-var-attrs'><dl 
class='xr-attrs'></dl></div><div class='xr-var-data'><pre>[49152 values with dtype=float32]</pre></div></li></ul></div></li></ul></div></div>"
654+
" radar_image_matrix (grid_row, grid_column, radar_height, radar_field) float32 197kB ...</pre><div class='xr-wrap' style='display:none'><div class='xr-header'><div class='xr-obj-type'>xarray.Dataset</div></div><ul class='xr-sections'><li class='xr-section-item'><input id='section-4183d7c7-e4db-4160-9040-d00243e6ea7f' class='xr-section-summary-in' type='checkbox' disabled ><label for='section-4183d7c7-e4db-4160-9040-d00243e6ea7f' class='xr-section-summary' title='Expand/collapse section'>Dimensions:</label><div class='xr-section-inline-details'><ul class='xr-dim-list'><li><span>grid_row</span>: 32</li><li><span>grid_column</span>: 32</li><li><span>radar_height</span>: 12</li><li><span>radar_field</span>: 4</li></ul></div><div class='xr-section-details'></div></li><li class='xr-section-item'><input id='section-c8eb729d-53f9-4cb6-9fe8-715d4648f54f' class='xr-section-summary-in' type='checkbox' checked><label for='section-c8eb729d-53f9-4cb6-9fe8-715d4648f54f' class='xr-section-summary' >Data variables: <span>(1)</span></label><div class='xr-section-inline-details'></div><div class='xr-section-details'><ul class='xr-var-list'><li class='xr-var-item'><div class='xr-var-name'><span>radar_image_matrix</span></div><div class='xr-var-dims'>(grid_row, grid_column, radar_height, radar_field)</div><div class='xr-var-dtype'>float32</div><div class='xr-var-preview xr-preview'>...</div><input id='attrs-fc79c9eb-fac4-45ba-8a24-e0cf9f9b1289' class='xr-var-attrs-in' type='checkbox' disabled><label for='attrs-fc79c9eb-fac4-45ba-8a24-e0cf9f9b1289' title='Show/Hide attributes'><svg class='icon xr-icon-file-text2'><use xlink:href='#icon-file-text2'></use></svg></label><input id='data-11a08a46-3ed6-4e67-a766-665b61804627' class='xr-var-data-in' type='checkbox'><label for='data-11a08a46-3ed6-4e67-a766-665b61804627' title='Show/Hide data repr'><svg class='icon xr-icon-database'><use xlink:href='#icon-database'></use></svg></label><div class='xr-var-attrs'><dl 
class='xr-attrs'></dl></div><div class='xr-var-data'><pre>[49152 values with dtype=float32]</pre></div></li></ul></div></li></ul></div></div>"
655655
]
656656
},
657657
"metadata": {},
658-
"execution_count": 7
658+
"execution_count": 3
659659
}
660660
],
661661
"source": [
@@ -681,14 +681,14 @@
681681
},
682682
{
683683
"cell_type": "code",
684-
"execution_count": 9,
684+
"execution_count": 4,
685685
"metadata": {
686686
"colab": {
687687
"base_uri": "https://localhost:8080/",
688688
"height": 1000
689689
},
690690
"id": "GMjoJBblGnZw",
691-
"outputId": "5f62052e-99b4-4cb0-b6db-f33a08a48276"
691+
"outputId": "13120921-9e2b-4874-9c01-6cbd3793baaf"
692692
},
693693
"outputs": [
694694
{
@@ -699,7 +699,7 @@
699699
]
700700
},
701701
"metadata": {},
702-
"execution_count": 9
702+
"execution_count": 4
703703
},
704704
{
705705
"output_type": "display_data",
@@ -901,21 +901,96 @@
901901
},
902902
{
903903
"cell_type": "code",
904-
"execution_count": 12,
904+
"source": [
905+
"# %%\n",
906+
"# Convert xarray data to numpy, then to a torch tensor\n",
907+
"input_data = ds_sample.radar_image_matrix.isel(radar_height=slice(0,1),radar_field=0).values\n",
908+
"\n",
909+
"# Step 1: Convert to Tensor\n",
910+
"# The shape of input_data is likely (32, 32, 1) or (32, 32) depending on how xarray slices it.\n",
911+
"# We ensure it starts as a float tensor.\n",
912+
"print(f\"{input_data.shape}=\") # 32, 32, 1; (H, W, C)\n",
913+
"input_tensor = torch.from_numpy(input_data).float()\n",
914+
"\n",
915+
"# Step 2: Ensure Correct Shape (Batch, Channel, Height, Width)\n",
916+
"# If the input is (Height, Width, Channel), we need to permute it to (Channel, Height, Width)\n",
917+
"if input_tensor.ndim == 3 and input_tensor.shape[-1] == 1:\n",
918+
" input_tensor = input_tensor.permute(2, 0, 1) # (H, W, C) -> (C, H, W)\n",
919+
"elif input_tensor.ndim == 2:\n",
920+
" input_tensor = input_tensor.unsqueeze(0) # Add Channel dim if missing: (H, W) -> (C, H, W)\n",
921+
"\n",
922+
"# Add Batch Dimension: (C, H, W) -> (Batch, C, H, W)\n",
923+
"input_tensor = input_tensor.unsqueeze(0)"
924+
],
925+
"metadata": {
926+
"colab": {
927+
"base_uri": "https://localhost:8080/"
928+
},
929+
"id": "zo_JAULelGBC",
930+
"outputId": "313fa14f-3be3-47e5-cbc5-48e964a80505"
931+
},
932+
"execution_count": 5,
933+
"outputs": [
934+
{
935+
"output_type": "stream",
936+
"name": "stdout",
937+
"text": [
938+
"(32, 32, 1)=\n"
939+
]
940+
}
941+
]
942+
},
943+
{
944+
"cell_type": "markdown",
945+
"source": [
946+
"Now our data (`input_tensor`) is in the correct shape: Batch, Channel, Height, Width.\n",
947+
"\n",
948+
"One last step before applying the convolution - we will resize `input_tensor`. This is simply for the sake of demonstrating the sharpening convolution. You generally would not do this in a ML model setting. The interpolation will smooth the original image a little, and the final sharpened convolution will look a little nicer. (Feel free to re-run this interpolation later with size=(32, 32) to see the convolution result below without the interpolation.)\n",
949+
"\n",
950+
"Note that we use the pytorch functional interface to interpolate the pytorch tensor object `input_tensor`. Again, we aren't performing any machine learning here. We're simply demonstrating how convolutions work."
951+
],
952+
"metadata": {
953+
"id": "cc5ABhj5nQ2e"
954+
}
955+
},
956+
{
957+
"cell_type": "code",
958+
"source": [
959+
"# interpolate from (1, 1, 32, 32) to (1, 1, 36, 36)\n",
960+
"more_points = F.interpolate(input_tensor, size=(36, 36), mode='bilinear', align_corners=False)\n"
961+
],
962+
"metadata": {
963+
"id": "rnhhUlxnnLBT"
964+
},
965+
"execution_count": 15,
966+
"outputs": []
967+
},
968+
{
969+
"cell_type": "markdown",
970+
"source": [
971+
"Finally we can define, configure, and apply our pytorch convolution operation, then plot the results."
972+
],
973+
"metadata": {
974+
"id": "nLzGMq4QlI4X"
975+
}
976+
},
977+
{
978+
"cell_type": "code",
979+
"execution_count": 16,
905980
"metadata": {
906981
"colab": {
907982
"base_uri": "https://localhost:8080/",
908-
"height": 430
983+
"height": 450
909984
},
910985
"id": "7eM4lMhIGnZx",
911-
"outputId": "d89b5397-f4fb-4991-c473-cba1b02123c7"
986+
"outputId": "eb1fb700-7fc6-421b-a490-acb5a1355fed"
912987
},
913988
"outputs": [
914989
{
915990
"output_type": "stream",
916991
"name": "stdout",
917992
"text": [
918-
"(32, 32, 1)=\n"
993+
"more_points.shape=torch.Size([1, 1, 36, 36]), convolution_result.shape=torch.Size([1, 1, 34, 34])\n"
919994
]
920995
},
921996
{
@@ -935,44 +1010,21 @@
9351010
}
9361011
],
9371012
"source": [
938-
"# %%\n",
939-
"# Convert xarray data to numpy, then to a torch tensor\n",
940-
"input_data = ds_sample.radar_image_matrix.isel(radar_height=slice(0,1),radar_field=0).values\n",
941-
"\n",
942-
"# Step 1: Convert to Tensor\n",
943-
"# The shape of input_data is likely (32, 32, 1) or (32, 32) depending on how xarray slices it.\n",
944-
"# We ensure it starts as a float tensor.\n",
945-
"print(f\"{input_data.shape}=\") # 32, 32, 1; (H, W, C)\n",
946-
"input_tensor = torch.from_numpy(input_data).float()\n",
947-
"\n",
948-
"# Step 2: Ensure Correct Shape (Batch, Channel, Height, Width)\n",
949-
"# If the input is (Height, Width, Channel), we need to permute it to (Channel, Height, Width)\n",
950-
"if input_tensor.ndim == 3 and input_tensor.shape[-1] == 1:\n",
951-
" input_tensor = input_tensor.permute(2, 0, 1) # (H, W, C) -> (C, H, W)\n",
952-
"elif input_tensor.ndim == 2:\n",
953-
" input_tensor = input_tensor.unsqueeze(0) # Add Channel dim if missing: (H, W) -> (C, H, W)\n",
954-
"\n",
955-
"# Add Batch Dimension: (C, H, W) -> (Batch, C, H, W)\n",
956-
"input_tensor = input_tensor.unsqueeze(0)\n",
957-
"\n",
958-
"# Step 3: Resize using interpolate\n",
959-
"# Now input_tensor is (1, 1, 32, 32), which matches the 2D spatial size expected by interpolate\n",
960-
"more_points = F.interpolate(input_tensor, size=(36, 36), mode='bilinear', align_corners=False)\n",
9611013
"\n",
9621014
"# Define the sharpen filter\n",
9631015
"kernel_weights = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])\n",
9641016
"\n",
965-
"# Define conv with specific weights.\n",
1017+
"# Define conv\n",
9661018
"conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, bias=False)\n",
9671019
"\n",
968-
"# Set the weights manually\n",
1020+
"# Set the conv weights manually\n",
9691021
"with torch.no_grad():\n",
9701022
" # We reshape our 3x3 kernel to (1, 1, 3, 3)\n",
9711023
" conv.weight.data = torch.from_numpy(kernel_weights).float().view(1, 1, 3, 3)\n",
9721024
"\n",
9731025
"# Run the data through\n",
9741026
"convolution_result = conv(more_points)\n",
975-
"\n",
1027+
"print(f\"{more_points.shape=}, {convolution_result.shape=}\")\n",
9761028
"\n",
9771029
"# Plotting\n",
9781030
"fig,(ax1,ax2) = plt.subplots(1,2,figsize=(10,5))\n",
@@ -1008,14 +1060,14 @@
10081060
},
10091061
{
10101062
"cell_type": "code",
1011-
"execution_count": 15,
1063+
"execution_count": 7,
10121064
"metadata": {
10131065
"colab": {
10141066
"base_uri": "https://localhost:8080/",
10151067
"height": 420
10161068
},
10171069
"id": "oiUu9nDyGnZy",
1018-
"outputId": "f22aa7f2-97c9-4b3a-cd16-a40aec3c5ee6"
1070+
"outputId": "8999b7db-e88a-4b93-ccd9-c1d5e1bb3fd0"
10191071
},
10201072
"outputs": [
10211073
{
@@ -1026,7 +1078,7 @@
10261078
]
10271079
},
10281080
"metadata": {},
1029-
"execution_count": 15
1081+
"execution_count": 7
10301082
},
10311083
{
10321084
"output_type": "display_data",
@@ -1123,14 +1175,14 @@
11231175
},
11241176
{
11251177
"cell_type": "code",
1126-
"execution_count": 17,
1178+
"execution_count": 8,
11271179
"metadata": {
11281180
"colab": {
11291181
"base_uri": "https://localhost:8080/",
11301182
"height": 420
11311183
},
11321184
"id": "qnfWwa9MGnZz",
1133-
"outputId": "6a9b113a-ee16-4a5d-ce93-6c65fa3b696e"
1185+
"outputId": "913a204a-b0f4-447e-adb1-b4004e133eb2"
11341186
},
11351187
"outputs": [
11361188
{
@@ -1141,7 +1193,7 @@
11411193
]
11421194
},
11431195
"metadata": {},
1144-
"execution_count": 17
1196+
"execution_count": 8
11451197
},
11461198
{
11471199
"output_type": "display_data",
@@ -1247,14 +1299,14 @@
12471299
},
12481300
{
12491301
"cell_type": "code",
1250-
"execution_count": 18,
1302+
"execution_count": 9,
12511303
"metadata": {
12521304
"colab": {
12531305
"base_uri": "https://localhost:8080/",
12541306
"height": 430
12551307
},
12561308
"id": "qLPaA52tGnZz",
1257-
"outputId": "0882cead-a86f-4fe8-b4e1-7b16a2cc7818"
1309+
"outputId": "65bfa86b-8896-4d38-e8de-76ad52b93a3c"
12581310
},
12591311
"outputs": [
12601312
{
@@ -1265,7 +1317,7 @@
12651317
]
12661318
},
12671319
"metadata": {},
1268-
"execution_count": 18
1320+
"execution_count": 9
12691321
},
12701322
{
12711323
"output_type": "display_data",
@@ -1329,16 +1381,23 @@
13291381
},
13301382
{
13311383
"cell_type": "code",
1332-
"execution_count": 20,
1384+
"execution_count": 17,
13331385
"metadata": {
13341386
"colab": {
13351387
"base_uri": "https://localhost:8080/",
1336-
"height": 430
1388+
"height": 449
13371389
},
13381390
"id": "lPNUoEYiGnZ0",
1339-
"outputId": "ccd5c1c8-3f2d-45c3-b6cc-c0ee0bd05dbe"
1391+
"outputId": "98067555-f1e0-493b-84c7-01494f3db8d9"
13401392
},
13411393
"outputs": [
1394+
{
1395+
"output_type": "stream",
1396+
"name": "stdout",
1397+
"text": [
1398+
"more_points.shape=torch.Size([1, 1, 36, 36]), res.shape=torch.Size([1, 1, 36, 36])\n"
1399+
]
1400+
},
13421401
{
13431402
"output_type": "execute_result",
13441403
"data": {
@@ -1347,7 +1406,7 @@
13471406
]
13481407
},
13491408
"metadata": {},
1350-
"execution_count": 20
1409+
"execution_count": 17
13511410
},
13521411
{
13531412
"output_type": "display_data",
@@ -1393,6 +1452,7 @@
13931452
"#run the data through\n",
13941453
"res = conv(more_points)\n",
13951454
"\n",
1455+
"print(f\"{more_points.shape=}, {res.shape=}\")\n",
13961456
"\n",
13971457
"fig,(ax1,ax2) = plt.subplots(1,2,figsize=(10,5))\n",
13981458
"ax1.imshow(more_points.squeeze(), vmin=0, vmax=60, cmap='Spectral_r')\n",

0 commit comments

Comments
 (0)