diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp @@ -35,11 +35,16 @@ return builder.create(loc, x, y); } -static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) { - bool isInt = x.getType().isa(); - if (isInt) - return builder.create(loc, x, y); - return builder.create(loc, x, y); +static Value createMul(Location loc, Value x, Value y, Type accType, + OpBuilder &builder) { + // Linalg named ops specify signed extension of their inputs. + Value xConvert = + convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false); + Value yConvert = + convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false); + if (accType.isa()) + return builder.create(loc, xConvert, yConvert); + return builder.create(loc, xConvert, yConvert); } // Delinearizes the given composite `index` by the basis specified in `factors`. 
@@ -185,7 +190,8 @@ /*outputs=*/ValueRange{reshapedOutput}, ArrayRef{lhsMap, rhsMap, resultMap}, genericIterators, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value mul = createMul(loc, args[0], args[1], nestedBuilder); + Value mul = + createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder); Value add = createAdd(loc, mul, args[2], nestedBuilder); nestedBuilder.create(nestedLoc, add); }); @@ -468,7 +474,8 @@ /*outputs=*/ValueRange{reshapedOutput}, ArrayRef{lhsMap, rhsMap, resultMap}, genericIterators, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - Value mul = createMul(loc, args[0], args[1], nestedBuilder); + Value mul = + createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder); Value add = createAdd(loc, mul, args[2], nestedBuilder); nestedBuilder.create(nestedLoc, add); }); diff --git a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir --- a/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir +++ b/mlir/test/Dialect/Linalg/convert-conv2d-to-img2col.mlir @@ -276,3 +276,41 @@ %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!pdl.operation) -> !pdl.operation %1:2 = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation) } + +// ----- + +// Check that both inputs are sign-extended when the input type is narrower than the accumulator type. 
+ +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +// CHECK: @conv_integer_extend +// CHECK: %[[MATMUL_RESULT:.+]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]] +// CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<1x196x36xi8>, tensor<36x16xi8>) +// CHECK-SAME: outs(%[[COLLAPSED_OUT]] : tensor<1x196x16xi32>) +// CHECK: ^bb0(%[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8, %[[ARG2:.+]]: i32) +// CHECK: %[[EXT0:.+]] = arith.extsi %[[ARG0]] : i8 to i32 +// CHECK: %[[EXT1:.+]] = arith.extsi %[[ARG1]] : i8 to i32 +// CHECK: %[[MUL:.+]] = arith.muli %[[EXT0]], %[[EXT1]] : i32 +// CHECK: %[[ADD:.+]] = arith.addi %[[MUL]], %[[ARG2]] : i32 +// CHECK: linalg.yield %[[ADD]] : i32 +// CHECK: } -> tensor<1x196x16xi32> +// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[MATMUL_RESULT]] {{\[}}[0], [1, 2], [3]] : tensor<1x196x16xi32> into tensor<1x14x14x16xi32> +// CHECK: return %[[RESULT]] + +func.func @conv_integer_extend(%arg0: tensor<1x16x16x4xi8>, %arg1: tensor<3x3x4x16xi8>, %arg2: tensor<1x14x14x16xi32>) -> tensor<1x14x14x16xi32> { + %0 = linalg.conv_2d_nhwc_hwcf + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x16x16x4xi8>, tensor<3x3x4x16xi8>) + outs(%arg2: tensor<1x14x14x16xi32>) -> tensor<1x14x14x16xi32> + return %0 : tensor<1x14x14x16xi32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %img2col_tensor_producer, %transformed = transform.structured.convert_conv2d_to_img2col %0 : (!pdl.operation) -> (!pdl.operation, !pdl.operation) + transform.print %img2col_tensor_producer {name = "tensor_producer"}: !pdl.operation + transform.print 
%transformed {name = "transformed"}: !pdl.operation +}