diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -432,5 +432,55 @@
     ++idx;
   }
 
+  // Check if the given operand shapes match the shapes inferred from the
+  // indexing maps and the static loop ranges.
+  Optional<SmallVector<int64_t, 4>> loopRanges = linalgOp.getStaticLoopRanges();
+  if (!loopRanges)
+    return linalgOp.emitError("unable to find loop range for operation");
+
+  // Verify only static cases, since we cannot get exact dimension sizes and
+  // loop ranges for dynamic cases at this stage.
+  if (llvm::none_of(*loopRanges, [](int64_t &range) {
+        return range == ShapedType::kDynamicSize;
+      })) {
+    for (int64_t &range : *loopRanges)
+      range -= 1;
+    for (const auto &en : llvm::enumerate(linalgOp.getShapedOperandTypes())) {
+      auto indices = indexingMaps[en.index()].compose(*loopRanges);
+      for (auto j : llvm::seq<int64_t>(0, en.value().getRank())) {
+        // Ignore dynamic dimensions and the case where the inferred last
+        // index is zero. A Linalg index is either increasing or decreasing,
+        // so the last index is either `0` or `size - 1`. We only check the
+        // non-zero (increasing) cases, because most indices are increasing
+        // and finding the shape of the decreasing cases is too expensive.
+        if (en.value().isDynamicDim(j) || indices[j] == 0)
+          continue;
+
+        // The inferred dimension size should equal the operand's dimension
+        // size exactly. However, when the affine expression is composite,
+        // e.g. d0 * 3 + d1, the exact size is hard to compute, so for now we
+        // only check that the inferred size does not exceed the operand's
+        // size.
+        auto inferredSize = indices[j] + 1;
+        auto shapedDimSize = en.value().getDimSize(j);
+        if (indexingMaps[en.index()].getResult(j).dyn_cast<AffineDimExpr>()) {
+          if (inferredSize != shapedDimSize) {
+            return linalgOp.emitOpError("inferred shaped operand #")
+                   << en.index() << " has shape's dimension #" << j
+                   << " to be " << inferredSize << ", but found "
+                   << shapedDimSize;
+          }
+        } else {
+          if (inferredSize > shapedDimSize) {
+            return linalgOp.emitOpError("inferred shaped operand #")
+                   << en.index() << " has shape's dimension #" << j
+                   << " to be greater than or equal to " << inferredSize
+                   << ", but found " << shapedDimSize;
+          }
+        }
+      }
+    }
+  }
+
   return success();
 }
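A worked pass through this check, using the matmul test added to invalid.mlir below. The indexing maps are the standard linalg.matmul maps, and the arithmetic is traced by hand, so read it as a sketch of the verifier's behavior rather than its output:

// linalg.matmul maps: A: (d0, d1, d2) -> (d0, d2), B: (d0, d1, d2) -> (d2, d1),
// C: (d0, d1, d2) -> (d0, d1).
// For ins(memref<2x4xf32>, memref<3x4xf32>) outs(memref<2x4xf32>), the static
// loop ranges are (d0, d1, d2) = (2, 4, 4), reading M and K off operand A and
// N off the output. Subtracting 1 gives the last iteration (1, 3, 3), and
// composing B's map (d2, d1) with it gives last indices (3, 3), i.e. an
// inferred shape of 4x4 for operand #1. Both results of B's map are plain
// AffineDimExprs, so the exact-match branch rejects the declared 3x4 shape:
//   inferred shaped operand #1 has shape's dimension #0 to be 4, but found 3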
diff --git a/mlir/test/Dialect/Linalg/fusion-2-level.mlir b/mlir/test/Dialect/Linalg/fusion-2-level.mlir
--- a/mlir/test/Dialect/Linalg/fusion-2-level.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-2-level.mlir
@@ -28,7 +28,7 @@
       scf.for %arg10 = %c0 to %10 step %c4 {
         %14 = memref.subview %5[%arg8, %arg10][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
         %16 = memref.subview %7[%arg10, %arg9][%c4, %c3][%c1, %c1]: memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
-        %17 = memref.subview %8[%arg8, %arg9][%c2, %c4][%c1, %c1] : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+        %17 = memref.subview %8[%arg8, %arg9][%c2, %c3][%c1, %c1] : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<?x?xf32, offset: ?, strides: [?, ?]>
         linalg.matmul ins(%14, %16: memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>)
                       outs(%17: memref<?x?xf32, offset: ?, strides: [?, ?]>)
       }
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s
 
-func @generalize_conv(%input : memref<1x225x225x3xf32>, %filter: memref<3x3x3x32xf32>, %output: memref<1x112x112x32xf32>) {
-  linalg.conv(%filter, %input, %output) {dilations = [2, 3], strides = [4, 5]} : memref<3x3x3x32xf32>, memref<1x225x225x3xf32>, memref<1x112x112x32xf32>
+func @generalize_conv(%input : memref<1x449x562x3xf32>, %filter: memref<3x3x3x32xf32>, %output: memref<1x112x112x32xf32>) {
+  linalg.conv(%filter, %input, %output) {dilations = [2, 3], strides = [4, 5]} : memref<3x3x3x32xf32>, memref<1x449x562x3xf32>, memref<1x112x112x32xf32>
   return
 }
 
@@ -10,7 +10,7 @@
 // CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
 
 // CHECK: func @generalize_conv
-// CHECK-SAME: %[[INPUT:.+]]: memref<1x225x225x3xf32>
+// CHECK-SAME: %[[INPUT:.+]]: memref<1x449x562x3xf32>
 // CHECK-SAME: %[[FILTER:.+]]: memref<3x3x3x32xf32>
 // CHECK-SAME: %[[OUTPUT:.+]]: memref<1x112x112x32xf32>
 
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -703,3 +703,23 @@
   %0 = linalg.fill(%arg0, %arg1) : tensor<?x?xf32>, f32 -> memref<?x?xf32>
   return %0 : memref<?x?xf32>
 }
+
+// -----
+
+func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
+  // expected-error @+1 {{inferred shaped operand #1 has shape's dimension #0 to be 4, but found 3}}
+  linalg.matmul ins(%arg0, %arg1 : memref<2x4xf32>, memref<3x4xf32>)
+                outs(%arg2 : memref<2x4xf32>)
+  return
+}
+
+// -----
+
+func @invalid_static_2d_conv(%input : memref<1x3x4x2xf32>, %filter: memref<3x2x2x1xf32>, %output: memref<1x2x3x1xf32>) {
+  // expected-error @+1 {{inferred shaped operand #0 has shape's dimension #1 to be greater than or equal to 4, but found 3}}
+  linalg.conv_2d_input_nhwc_filter_hwcf
+    { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%input, %filter : memref<1x3x4x2xf32>, memref<3x2x2x1xf32>)
+    outs(%output : memref<1x2x3x1xf32>)
+  return
+}
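For the convolution tests above, the inferred input extent along a spatial dimension is (output - 1) * stride + (filter - 1) * dilation + 1. Because the input's affine expression is composite (e.g. d1 * 4 + d3 * 2), the verifier only rejects inputs smaller than this bound. Hand-computed from the shapes in the diffs:

// @generalize_conv: H: (112 - 1) * 4 + (3 - 1) * 2 + 1 = 449 > 225, rejected.
//                   W: (112 - 1) * 5 + (3 - 1) * 3 + 1 = 562 > 225, rejected.
// @invalid_static_2d_conv: H: (2 - 1) * 1 + (3 - 1) * 1 + 1 = 4 > 3, which
//   triggers the "greater than or equal to 4, but found 3" diagnostic.

The fusion-2-level.mlir fix is independent of the new verifier (those shapes are dynamic and are skipped): the inner matmul's output tile must be %c2 x %c3 so that it matches the %c2 rows of %14 and the %c3 columns of %16.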
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -282,15 +282,15 @@
 // CHECK: %{{.+}} = linalg.pooling_nhwc_sum
 // CHECK-SAME: dilations = dense<1> : tensor<2xi64>
 // CHECK-SAME: strides = dense<1> : tensor<2xi64>
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x6x6x1xf32>, tensor<3x3xf32>)
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xf32>, tensor<3x3xf32>)
 // CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
-func @pooling_nhwc_sum_tensor(%input: tensor<1x6x6x1xf32>) -> tensor<1x2x2x1xf32> {
+func @pooling_nhwc_sum_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32> {
   %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32>
   %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32>
   %cst = constant 0.000000e+00 : f32
   %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xf32>, f32 -> tensor<1x2x2x1xf32>
   %res = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
-    ins(%input, %fake: tensor<1x6x6x1xf32>, tensor<3x3xf32>)
+    ins(%input, %fake: tensor<1x4x4x1xf32>, tensor<3x3xf32>)
     outs(%fill: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
   return %res : tensor<1x2x2x1xf32>
 }
@@ -301,11 +301,11 @@
 // CHECK: linalg.pooling_nhwc_sum
 // CHECK-SAME: dilations = dense<1> : tensor<2xi64>
 // CHECK-SAME: strides = dense<1> : tensor<2xi64>
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x6x6x1xf32>, memref<3x3xf32>)
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xf32>, memref<3x3xf32>)
 // CHECK-SAME: outs(%{{.+}} : memref<1x2x2x1xf32>)
-func @pooling_nhwc_sum(%input: memref<1x6x6x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
+func @pooling_nhwc_sum(%input: memref<1x4x4x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
   linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
-    ins(%input, %fake: memref<1x6x6x1xf32>, memref<3x3xf32>)
+    ins(%input, %fake: memref<1x4x4x1xf32>, memref<3x3xf32>)
     outs(%output: memref<1x2x2x1xf32>)
   return
 }
@@ -316,15 +316,15 @@
 // CHECK: %{{.+}} = linalg.pooling_nhwc_max
 // CHECK-SAME: dilations = dense<1> : tensor<2xi64>
 // CHECK-SAME: strides = dense<1> : tensor<2xi64>
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x6x6x1xf32>, tensor<3x3xf32>)
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xf32>, tensor<3x3xf32>)
 // CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
-func @pooling_nhwc_max_tensor(%input: tensor<1x6x6x1xf32>) -> tensor<1x2x2x1xf32> {
+func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32> {
   %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32>
   %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32>
   %cst = constant 0.000000e+00 : f32
   %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xf32>, f32 -> tensor<1x2x2x1xf32>
   %res = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
-    ins(%input, %fake: tensor<1x6x6x1xf32>, tensor<3x3xf32>)
+    ins(%input, %fake: tensor<1x4x4x1xf32>, tensor<3x3xf32>)
     outs(%fill: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
   return %res : tensor<1x2x2x1xf32>
 }
@@ -335,11 +335,11 @@
 // CHECK: linalg.pooling_nhwc_max
 // CHECK-SAME: dilations = dense<1> : tensor<2xi64>
 // CHECK-SAME: strides = dense<1> : tensor<2xi64>
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x6x6x1xf32>, memref<3x3xf32>)
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xf32>, memref<3x3xf32>)
 // CHECK-SAME: outs(%{{.+}} : memref<1x2x2x1xf32>)
-func @pooling_nhwc_max(%input: memref<1x6x6x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
+func @pooling_nhwc_max(%input: memref<1x4x4x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
   linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
-    ins(%input, %fake: memref<1x6x6x1xf32>, memref<3x3xf32>)
+    ins(%input, %fake: memref<1x4x4x1xf32>, memref<3x3xf32>)
     outs(%output: memref<1x2x2x1xf32>)
   return
 }
@@ -350,15 +350,15 @@
 // CHECK: %{{.+}} = linalg.pooling_nhwc_min
 // CHECK-SAME: dilations = dense<1> : tensor<2xi64>
 // CHECK-SAME: strides = dense<1> : tensor<2xi64>
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x6x6x1xf32>, tensor<3x3xf32>)
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xf32>, tensor<3x3xf32>)
 // CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
-func @pooling_nhwc_min_tensor(%input: tensor<1x6x6x1xf32>) -> tensor<1x2x2x1xf32> {
+func @pooling_nhwc_min_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32> {
   %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32>
   %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32>
   %cst = constant 0.000000e+00 : f32
   %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xf32>, f32 -> tensor<1x2x2x1xf32>
   %res = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
-    ins(%input, %fake: tensor<1x6x6x1xf32>, tensor<3x3xf32>)
+    ins(%input, %fake: tensor<1x4x4x1xf32>, tensor<3x3xf32>)
     outs(%fill: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32>
   return %res : tensor<1x2x2x1xf32>
 }
@@ -369,11 +369,11 @@
 // CHECK: linalg.pooling_nhwc_min
 // CHECK-SAME: dilations = dense<1> : tensor<2xi64>
 // CHECK-SAME: strides = dense<1> : tensor<2xi64>
-// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x6x6x1xf32>, memref<3x3xf32>)
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xf32>, memref<3x3xf32>)
 // CHECK-SAME: outs(%{{.+}} : memref<1x2x2x1xf32>)
-func @pooling_nhwc_min(%input: memref<1x6x6x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
+func @pooling_nhwc_min(%input: memref<1x4x4x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) {
   linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
-    ins(%input, %fake: memref<1x6x6x1xf32>, memref<3x3xf32>)
+    ins(%input, %fake: memref<1x4x4x1xf32>, memref<3x3xf32>)
    outs(%output: memref<1x2x2x1xf32>)
   return
 }
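The pooling updates above follow the same bound with stride 1 and dilation 1: a 2x2 output and a 3x3 window infer an input extent of (2 - 1) * 1 + (3 - 1) * 1 + 1 = 4, so the 1x6x6x1 inputs are tightened to the inferred 1x4x4x1 shape. Equivalently, a 4x4 input with a 3x3 window at stride 1 yields (4 - 3) / 1 + 1 = 2 outputs per spatial dimension, matching the declared 1x2x2x1 results.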
diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
@@ -168,9 +168,9 @@
 #map0 = affine_map<(d0, d1, d2) -> (d0)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
-#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func @generic_op_120_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
+#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+func @generic_op_120_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
   %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
   %1 = linalg.init_tensor [5, 7, 3] : tensor<5x7x3xf32>
   %2 = linalg.generic
@@ -183,9 +183,9 @@
   return %2 : tensor<5x7x3xf32>
 }
 
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0 * 7 + d1)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK: func @generic_op_120_permultation_reshape_producer_fusion
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK: func @generic_op_120_permutation_reshape_producer_fusion
 // CHECK-NOT: linalg.tensor_reshape
 // CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
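A hand check of the updated maps against the static shapes in this test: under the new #map3, (d0, d2, d1), the output tensor<5x7x3xf32> binds d0 = 5, d2 = 7, d1 = 3, and under the new #map2, (d1, d0, d2), the reshaped producer tensor<3x5x7xf32> binds the same values. The old pair bound d1 to both 7 (via the output) and 3 (via the producer), which the new shape verification rejects. The fused CHECK map stays consistent: (d1, d0 * 7 + d2) on tensor<3x35xf32> gives d1 = 3 and a last linearized index of (5 - 1) * 7 + (7 - 1) = 34, i.e. size 35.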
diff --git a/mlir/test/Dialect/Linalg/sparse_nd.mlir b/mlir/test/Dialect/Linalg/sparse_nd.mlir
--- a/mlir/test/Dialect/Linalg/sparse_nd.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_nd.mlir
@@ -21,7 +21,7 @@
 // CHECK-LABEL: func @mul(
 // CHECK-SAME: %[[VAL_0:.*0]]: tensor<10x20x30x40x50x60x70x80xf32>,
-// CHECK-SAME: %[[VAL_1:.*1]]: tensor<10x20x30x40x50x60x70x80xf32>,
+// CHECK-SAME: %[[VAL_1:.*1]]: tensor<80x70x60x50x40x30x20x10xf32>,
 // CHECK-SAME: %[[VAL_2:.*2]]: tensor<10x20x30x40x50x60x70x80xf32>) -> tensor<10x20x30x40x50x60x70x80xf32> {
 // CHECK: %[[VAL_3:.*]] = constant 3 : index
 // CHECK: %[[VAL_4:.*]] = constant 4 : index
@@ -34,11 +34,11 @@
 // CHECK: %[[VAL_11:.*]] = constant 0 : index
 // CHECK: %[[VAL_12:.*]] = constant 1 : index
 // CHECK: %[[VAL_13:.*]] = memref.buffer_cast %[[VAL_0]] : memref<10x20x30x40x50x60x70x80xf32>
-// CHECK: %[[VAL_14:.*]] = linalg.sparse_pointers %[[VAL_1]], %[[VAL_3]] : tensor<10x20x30x40x50x60x70x80xf32> to memref<?xindex>
-// CHECK: %[[VAL_15:.*]] = linalg.sparse_indices %[[VAL_1]], %[[VAL_3]] : tensor<10x20x30x40x50x60x70x80xf32> to memref<?xindex>
-// CHECK: %[[VAL_16:.*]] = linalg.sparse_pointers %[[VAL_1]], %[[VAL_4]] : tensor<10x20x30x40x50x60x70x80xf32> to memref<?xindex>
-// CHECK: %[[VAL_17:.*]] = linalg.sparse_indices %[[VAL_1]], %[[VAL_4]] : tensor<10x20x30x40x50x60x70x80xf32> to memref<?xindex>
-// CHECK: %[[VAL_18:.*]] = linalg.sparse_values %[[VAL_1]] : tensor<10x20x30x40x50x60x70x80xf32> to memref<?xf32>
+// CHECK: %[[VAL_14:.*]] = linalg.sparse_pointers %[[VAL_1]], %[[VAL_3]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xindex>
+// CHECK: %[[VAL_15:.*]] = linalg.sparse_indices %[[VAL_1]], %[[VAL_3]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xindex>
+// CHECK: %[[VAL_16:.*]] = linalg.sparse_pointers %[[VAL_1]], %[[VAL_4]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xindex>
+// CHECK: %[[VAL_17:.*]] = linalg.sparse_indices %[[VAL_1]], %[[VAL_4]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = linalg.sparse_values %[[VAL_1]] : tensor<80x70x60x50x40x30x20x10xf32> to memref<?xf32>
 // CHECK: %[[VAL_19:.*]] = memref.buffer_cast %[[VAL_2]] : memref<10x20x30x40x50x60x70x80xf32>
 // CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<10x20x30x40x50x60x70x80xf32>
 // CHECK: linalg.copy(%[[VAL_19]], %[[VAL_20]]) : memref<10x20x30x40x50x60x70x80xf32>, memref<10x20x30x40x50x60x70x80xf32>
@@ -84,12 +84,12 @@
 // CHECK: return %[[VAL_50]] : tensor<10x20x30x40x50x60x70x80xf32>
 // CHECK: }
 
 func @mul(%arga: tensor<10x20x30x40x50x60x70x80xf32>,
-          %argb: tensor<10x20x30x40x50x60x70x80xf32>,
+          %argb: tensor<80x70x60x50x40x30x20x10xf32>,
           %argx: tensor<10x20x30x40x50x60x70x80xf32>) -> tensor<10x20x30x40x50x60x70x80xf32> {
   %0 = linalg.generic #trait_mul
     ins(%arga, %argb: tensor<10x20x30x40x50x60x70x80xf32>,
-                      tensor<10x20x30x40x50x60x70x80xf32>)
+                      tensor<80x70x60x50x40x30x20x10xf32>)
     outs(%argx: tensor<10x20x30x40x50x60x70x80xf32>) {
       ^bb(%a: f32, %b: f32, %x: f32):
         %0 = mulf %a, %b : f32
diff --git a/mlir/test/Dialect/Linalg/tile-indexed-generic.mlir b/mlir/test/Dialect/Linalg/tile-indexed-generic.mlir
--- a/mlir/test/Dialect/Linalg/tile-indexed-generic.mlir
+++ b/mlir/test/Dialect/Linalg/tile-indexed-generic.mlir
@@ -54,10 +54,10 @@
   ],
   iterator_types = ["parallel", "parallel"]
 }
-func @indexed_generic_matrix(%operand: memref<50x100xf32>, %result: memref<50x100xf32>) {
+func @indexed_generic_matrix(%operand: memref<50x99xf32>, %result: memref<50x50xf32>) {
   linalg.indexed_generic #combined_indices_trait
-    ins(%operand : memref<50x100xf32>)
-    outs(%result : memref<50x100xf32>) {
+    ins(%operand : memref<50x99xf32>)
+    outs(%result : memref<50x50xf32>) {
   ^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
     %i_int = index_cast %i: index to i32
     %i_float = sitofp %i_int : i32 to f32
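The last two updates follow the same inference rule; both traits are defined above the excerpted hunks, so this is a reading of the shape changes rather than of the maps themselves. In sparse_nd.mlir, #trait_mul evidently indexes %argb through the reversed permutation (d7, d6, ..., d0), so loop ranges of 10 through 80 infer the reversed shape tensor<80x70x60x50x40x30x20x10xf32>. In tile-indexed-generic.mlir, #combined_indices_trait evidently reads %operand through an expression of the form (i, i + j); with both loop ranges equal to 50, the operand's second dimension has last index 49 + 49 = 98 and inferred size 99, hence memref<50x99xf32> with a memref<50x50xf32> result.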