diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1671,6 +1671,7 @@
     if (!valid)
       return failure();
 
+    int64_t nSize, wSize, cSize, kwSize;
     // kernel{kw, c}
     bindShapeDims(rhsShapedType, kwSize, cSize);
 
@@ -1746,7 +1747,7 @@
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
      for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        resVals[w] = depthwiseConv1dSliceAsFma(
+        resVals[w] = depthwiseConv1dSliceAsMulAcc(
            builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
      }
     }
@@ -1770,11 +1771,51 @@
        .getOperation();
  }
 
-  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to fma.
-  Value depthwiseConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs,
-                                  Value rhs, Value res) {
-    Value bcast = builder.create<vector::BroadcastOp>(loc, res.getType(), rhs);
-    return b.create<vector::FMAOp>(loc, lhs, bcast, res);
+  // Take a value of element type T and widen it to the destination type.
+  Value promote(OpBuilder &b, Location loc, Value val, Type ty) {
+    if (val.getType() == ty)
+      return val;
+
+    const int64_t srcWidth =
+        getElementTypeOrSelf(val.getType()).getIntOrFloatBitWidth();
+    const int64_t destWidth = getElementTypeOrSelf(ty).getIntOrFloatBitWidth();
+
+    if (getElementTypeOrSelf(ty).isa<IntegerType>() && srcWidth < destWidth)
+      return builder.create<arith::ExtSIOp>(loc, ty, val);
+
+    if (getElementTypeOrSelf(ty).isa<FloatType>() && srcWidth < destWidth)
+      return builder.create<arith::ExtFOp>(loc, ty, val);
+
+    if (getElementTypeOrSelf(ty).isa<IntegerType>() && srcWidth > destWidth)
+      return builder.create<arith::TruncIOp>(loc, ty, val);
+
+    if (getElementTypeOrSelf(ty).isa<FloatType>() && srcWidth > destWidth)
+      return builder.create<arith::TruncFOp>(loc, ty, val);
+
+    return nullptr;
+  }
+
+  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc.
+  Value depthwiseConv1dSliceAsMulAcc(OpBuilder &b, Location loc, Value lhs,
+                                     Value rhs, Value res) {
+    auto rhsTy = rhs.getType().cast<VectorType>();
+    auto resTy = res.getType().cast<VectorType>();
+
+    // TODO(suderman): Change this to use a vector.ima intrinsic.
+    lhs = promote(b, loc, lhs, resTy);
+
+    rhs = builder.create<vector::BroadcastOp>(
+        loc, resTy.clone(rhsTy.getElementType()), rhs);
+    rhs = promote(b, loc, rhs, resTy);
+
+    if (!lhs || !rhs)
+      return nullptr;
+
+    if (resTy.getElementType().isa<FloatType>())
+      return b.create<vector::FMAOp>(loc, lhs, rhs, res);
+
+    auto mul = b.create<arith::MulIOp>(loc, lhs, rhs);
+    return b.create<arith::AddIOp>(loc, mul, res);
   }
 
   /// Entry point that transposes into the common form:
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -463,7 +463,7 @@
 
 // -----
 
-func.func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
+func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
   linalg.depthwise_conv_1d_nwc_wc
     {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
     ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
@@ -471,7 +471,7 @@
   return
 }
 
-// CHECK: func @depthwise_conv1d_nwc_wc_3x5x4_memref
+// CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref
 // CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)
 
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
@@ -502,6 +502,51 @@
 
 // CHECK: vector.transfer_write %[[FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
 
+// -----
+
+func.func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref(%input: memref<3x5x4xi8>, %filter: memref<2x4xi8>, %output: memref<3x2x4xi32>) {
+  linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<3x5x4xi8>, memref<2x4xi8>)
+    outs(%output : memref<3x2x4xi32>)
+  return
+}
+
+// CHECK: func @depthwise_conv1d_nwc_wc_3x5x4xi8_memref
+// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xi8>, %[[FILTER:[0-9a-z]+]]: memref<2x4xi8>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xi32>)
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+
+/// Read the whole data in one shot.
+// CHECK: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
+// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
+// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
+// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
+// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xi8> to vector<3x2x4xi8>
+
+// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xi8>
+// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xi8>
+
+/// w == 0, kw = 0
+// CHECK: %[[EXT_INPUT_0:.*]] = arith.extsi %[[V_INPUT_0]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xi8> to vector<3x2x4xi8>
+// CHECK: %[[EXT_FILTER_0:.*]] = arith.extsi %[[B_FILTER_0]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+// CHECK: %[[MUL_0:.*]] = arith.muli %[[EXT_INPUT_0]], %[[EXT_FILTER_0]] : vector<3x2x4xi32>
+// CHECK: %[[ADD_0:.*]] = arith.addi %[[MUL_0]], %[[V_OUTPUT_R]] : vector<3x2x4xi32>
+
+/// w == 0, kw = 1
+// CHECK: %[[EXT_INPUT_1:.*]] = arith.extsi %[[V_INPUT_1]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xi8> to vector<3x2x4xi8>
+// CHECK: %[[EXT_FILTER_1:.*]] = arith.extsi %[[B_FILTER_1]] : vector<3x2x4xi8> to vector<3x2x4xi32>
+// CHECK: %[[MUL_1:.*]] = arith.muli %[[EXT_INPUT_1]], %[[EXT_FILTER_1]] : vector<3x2x4xi32>
+// CHECK: %[[ADD_1:.*]] = arith.addi %[[MUL_1]], %[[ADD_0]] : vector<3x2x4xi32>
+
+/// Write the result back in one shot.
+// CHECK: vector.transfer_write %[[ADD_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
+
 // -----
 
 func.func @conv_1d_nwc_wcf_mixed_type_memref(%input: memref<1x2x3xf16>, %filter: memref<1x3x2xf16>, %output: memref<1x2x2xf32>) {
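Note on the integer path: for a single (w, kw) step, the FileCheck patterns in the new i8 test describe IR of roughly the following shape. This is a minimal sketch only; the SSA value names (%in_slice, %filter_slice, %acc_in) are illustrative, and the CHECK patterns above remain the authoritative form:

  %ext_in = arith.extsi %in_slice : vector<3x2x4xi8> to vector<3x2x4xi32>
  %b_filt = vector.broadcast %filter_slice : vector<4xi8> to vector<3x2x4xi8>
  %ext_f  = arith.extsi %b_filt : vector<3x2x4xi8> to vector<3x2x4xi32>
  %mul    = arith.muli %ext_in, %ext_f : vector<3x2x4xi32>
  %acc    = arith.addi %mul, %acc_in : vector<3x2x4xi32>

The floating-point path keeps its previous shape: depthwiseConv1dSliceAsMulAcc still emits vector.fma for float element types, as the renamed f32 test above continues to check.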