diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -726,6 +726,132 @@
   }
 };
 
+/// Lowers `tosa.conv2d` to `linalg.conv_2d_input_nhwc_filter_hwcf`.
+///
+/// The lowering:
+///   1. broadcasts the rank-1 bias along the output-channel dimension into a
+///      tensor shaped like the result, which seeds the convolution output;
+///   2. transposes the weights from [OC, KH, KW, IC] to the [KH, KW, IC, OC]
+///      (HWCF) layout used by the linalg named op;
+///   3. pads the spatial (H, W) dimensions of the input when any padding is
+///      requested;
+///   4. emits the named convolution with the op's stride and dilation.
+///
+/// Only f32 operands with fully static shapes are handled; anything else is
+/// reported as a match failure.
+class Conv2DConverter : public OpConversionPattern<tosa::Conv2DOp> {
+public:
+  using OpConversionPattern<tosa::Conv2DOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tosa::Conv2DOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Value input = op.input();
+    Value weight = op.weight();
+    Value bias = op.bias();
+
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    ShapedType weightTy = weight.getType().cast<ShapedType>();
+    ShapedType biasTy = bias.getType().cast<ShapedType>();
+    ShapedType resultTy = op.getType().cast<ShapedType>();
+
+    Type inputETy = inputTy.getElementType();
+    Type weightETy = weightTy.getElementType();
+    Type biasETy = biasTy.getElementType();
+    Type resultETy = resultTy.getElementType();
+
+    // TODO(suderman): Support other types.
+    if (!inputETy.isF32() || !weightETy.isF32() || !biasETy.isF32() ||
+        !resultETy.isF32())
+      return rewriter.notifyMatchFailure(op, "only f32 types are supported");
+
+    // The init tensor, the transposed kernel and the padded input below are
+    // all typed with static sizes copied from these shapes; a dynamic
+    // dimension would leak into them as a bogus static size.
+    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
+        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "only static shapes are supported");
+
+    auto inputShape = inputTy.getShape();
+    auto weightShape = weightTy.getShape();
+
+    // Broadcast the bias along the output-channel dimension (d3) to seed the
+    // accumulator of the convolution.
+    SmallVector<AffineMap, 2> indexingMaps;
+    indexingMaps.push_back(AffineMap::get(/*dimCount=*/4, /*symbolCount=*/0,
+                                          {rewriter.getAffineDimExpr(3)},
+                                          rewriter.getContext()));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
+
+    Value initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, resultTy.getShape(), resultETy);
+    Value biasBroadcast =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, resultTy, bias, initTensor, indexingMaps,
+                getNParallelLoopsAttrs(resultTy.getRank()),
+                [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                    ValueRange blockArgs) {
+                  // Yield the bias element; the init value is discarded.
+                  nestedBuilder.create<linalg::YieldOp>(nestedLoc,
+                                                        *blockArgs.begin());
+                })
+            .getResult(0);
+
+    // Transpose the weights from [OC, KH, KW, IC] to the HWCF layout
+    // [KH, KW, IC, OC]: spatial dims, input channels, output channels.
+    SmallVector<int64_t> permutation{1, 2, 3, 0};
+    auto permutationAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({4}, rewriter.getI64Type()), permutation);
+    Value permutationValue = rewriter.create<ConstantOp>(loc, permutationAttr);
+
+    // The transpose only reorders the weights, so it keeps their element type.
+    SmallVector<int64_t> newKernelShape{weightShape[1], weightShape[2],
+                                        weightShape[3], weightShape[0]};
+    Type newKernelTy = RankedTensorType::get(newKernelShape, weightETy);
+
+    Value transposedKernel = rewriter.create<tosa::TransposeOp>(
+        loc, newKernelTy, weight, permutationValue);
+
+    // Extract the attributes for convolution.
+    llvm::SmallVector<int64_t> stride, dilation, pad;
+    getValuesFromIntArrayAttribute(op.stride(), stride);
+    getValuesFromIntArrayAttribute(op.dilation(), dilation);
+    getValuesFromIntArrayAttribute(op.pad(), pad);
+
+    // Pad the input if necessary: pad[0]/pad[1] extend H and pad[2]/pad[3]
+    // extend W, while the batch and channel dimensions are left untouched.
+    if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
+      llvm::SmallVector<int64_t, 8> newPad{0, 0, pad[0], pad[1],
+                                           pad[2], pad[3], 0, 0};
+      auto padAttr = DenseIntElementsAttr::get(
+          RankedTensorType::get({4, 2}, rewriter.getI64Type()), newPad);
+      Value padValue = rewriter.create<ConstantOp>(loc, padAttr);
+
+      SmallVector<int64_t, 4> paddedShape{
+          inputShape[0], inputShape[1] + pad[0] + pad[1],
+          inputShape[2] + pad[2] + pad[3], inputShape[3]};
+      Type paddedTy = RankedTensorType::get(paddedShape, inputETy);
+      input = rewriter.create<tosa::PadOp>(loc, paddedTy, input, padValue);
+    }
+
+    auto strideAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
+    auto dilationAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
+
+    // The broadcast bias is the output operand, so the convolution result is
+    // accumulated on top of it.
+    auto convOp = rewriter.create<linalg::ConvInputNHWCFilterHWCFOp>(
+        loc, resultTy, ValueRange{input, transposedKernel},
+        ValueRange{biasBroadcast}, dilationAttr, strideAttr);
+
+    rewriter.replaceOp(op, convOp.getResult(0));
+    return success();
+  }
+};
+
 class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
 public:
   using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
@@ -1424,5 +1550,5 @@
       ReduceConverter, ArgMaxConverter, ConcatConverter, PadConverter,
       ReshapeConverter, RescaleConverter, ReverseConverter, TileConverter,
       TransposeConverter, MatMulConverter,
-      FullyConnectedConverter>(patterns->getContext());
+      FullyConnectedConverter, Conv2DConverter>(patterns->getContext());
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -804,3 +804,33 @@
 
   return
 }
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)>
+
+// CHECK-LABEL: @conv2d_f32
+func @conv2d_f32(%input: tensor<1x49x42x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 45, 40, 28] : tensor<1x45x40x28xf32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>)
+  // CHECK: ^bb0(%[[IN:[a-z0-9]+]]: f32, %{{[a-z0-9]+}}: f32):
+  // CHECK: linalg.yield %[[IN]] : f32
+  // CHECK: %[[INITKERNEL:.+]] = linalg.init_tensor [3, 3, 28, 28]
+  // CHECK: %[[TRANSPOSEKERNEL:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<28x3x3x28xf32>) outs(%[[INITKERNEL]] : tensor<3x3x28x28xf32>)
+  // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSEKERNEL]] : tensor<1x49x42x28xf32>, tensor<3x3x28x28xf32>) outs(%[[BROADCAST]] : tensor<1x45x40x28xf32>)
+  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<1x49x42x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> (tensor<1x45x40x28xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_padded_f32
+func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
+  // CHECK: linalg.pad_tensor %arg0
+  // CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
+  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [1, 1, 1, 1], stride = [1, 1], dilation = [2, 1]} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> (tensor<1x45x40x28xf32>)
+  return
+}
+