diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -612,6 +612,84 @@
   }
 };
 
+class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
+public:
+  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tosa::MatMulOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const final {
+    tosa::MatMulOp::Adaptor adaptor(args);
+
+    Location loc = op.getLoc();
+
+    auto outputTy = op.getType().cast<ShapedType>();
+    auto outputElementTy = outputTy.getElementType();
+    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
+    Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
+    auto initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, outputTy.getShape(), outputTy.getElementType());
+    Value zeroTensor =
+        rewriter.create<linalg::FillOp>(loc, initTensor, zero).getResult(0);
+    rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
+        op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
+        ValueRange{zeroTensor});
+    return success();
+  }
+};
+
+class FullyConnectedConverter
+    : public OpConversionPattern<tosa::FullyConnectedOp> {
+public:
+  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tosa::FullyConnectedOp op, ArrayRef<Value> args,
+                  ConversionPatternRewriter &rewriter) const final {
+    tosa::FullyConnectedOp::Adaptor adaptor(args);
+
+    Location loc = op.getLoc();
+    auto outputTy = op.getType().cast<ShapedType>();
+    auto biasTy = op->getOperand(2).getType().cast<ShapedType>();
+
+    // Reshape the bias from [n] to [1, n] so it broadcasts over the rows.
+    SmallVector<int64_t> biasShapeReshaped;
+    biasShapeReshaped.push_back(1);
+    biasShapeReshaped.push_back(biasTy.getShape()[0]);
+
+    RankedTensorType reshapedBias =
+        RankedTensorType::get(biasShapeReshaped, outputTy.getElementType());
+    auto reshapeResult =
+        rewriter.create<tosa::ReshapeOp>(loc, reshapedBias, args[2])
+            ->getResult(0);
+
+    // Create indexing maps for the broadcast bias and the matmul output.
+    SmallVector<AffineMap, 2> indexingMaps;
+    indexingMaps.push_back(createAffineMapForType(reshapedBias, rewriter));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
+
+    auto initTensor =
+        rewriter
+            .create<linalg::InitTensorOp>(loc, outputTy.getShape(),
+                                          outputTy.getElementType())
+            ->getResults();
+
+    auto linalgOp =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, outputTy, reshapeResult, initTensor, indexingMaps,
+                getNParallelLoopsAttrs(outputTy.getRank()),
+                [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                    ValueRange args) {
+                  nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
+                })
+            ->getResults();
+
+    rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
+        op, TypeRange{op.getType()},
+        ValueRange{adaptor.input(), adaptor.weight()}, linalgOp);
+    return success();
+  }
+};
+
 class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
 public:
   using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
@@ -1041,6 +1119,6 @@
       IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
       ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
       ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, ReshapeConverter,
-      RescaleConverter, ReverseConverter, TransposeConverter>(
-          patterns->getContext());
+      RescaleConverter, ReverseConverter, TransposeConverter, MatMulConverter,
+      FullyConnectedConverter>(patterns->getContext());
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -639,3 +639,32 @@
   }
   return
 }
+
+// -----
+
+
+// CHECK-LABEL: @matmul
+func @matmul(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
+  // CHECK: [[C0:%.+]] = constant 0
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 6]
+  // CHECK: [[FILLED:%.+]] = linalg.fill([[INIT]], [[C0]]) : tensor<5x6xf32>, f32 -> tensor<5x6xf32>
+  // CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<5x3xf32>, tensor<3x6xf32>) outs([[FILLED]] : tensor<5x6xf32>) -> tensor<5x6xf32>
+  %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<5x3xf32>, tensor<3x6xf32>) -> (tensor<5x6xf32>)
+  return %0 : tensor<5x6xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)>
+
+// CHECK-LABEL: @fully_connected
+func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
+  // CHECK: [[RS:%.+]] = linalg.tensor_reshape %arg2 [#[[$MAP0]]]
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 6]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RS]] : tensor<1x6xf32>) outs([[INIT]] : tensor<5x6xf32>) {
+  // CHECK: ^bb0([[IN:%.+]]: f32, [[MULTIPLIER:%.+]]: f32):
+  // CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<5x3xf32>, tensor<3x6xf32>) outs([[GENERIC]] : tensor<5x6xf32>) -> tensor<5x6xf32>
+  %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<5x3xf32>, tensor<3x6xf32>, tensor<6xf32>) -> (tensor<5x6xf32>)
+  return %0 : tensor<5x6xf32>
+}