diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -745,26 +745,28 @@
   LogicalResult
   matchAndRewrite(tosa::FullyConnectedOp op, ArrayRef<Value> args,
                   ConversionPatternRewriter &rewriter) const final {
-    tosa::FullyConnectedOp::Adaptor adaptor(args);
-
     Location loc = op.getLoc();
     auto outputTy = op.getType().cast<ShapedType>();
-    auto biasTy = op->getOperand(2).getType().cast<ShapedType>();
+    auto input = op.input();
+    auto weight = op.weight();
+    auto bias = op.bias();
 
-    // Reshaping the bias from n to [1, n] for broadcasting
-    SmallVector<int64_t> biasShapeReshaped;
-    biasShapeReshaped.push_back(1);
-    biasShapeReshaped.push_back(biasTy.getShape()[0]);
+    auto weightTy = weight.getType().cast<ShapedType>();
+    auto biasTy = bias.getType().cast<ShapedType>();
 
-    RankedTensorType reshapedBias =
-        RankedTensorType::get(biasShapeReshaped, outputTy.getElementType());
-    auto reshapeResult =
-        rewriter.create<tosa::ReshapeOp>(loc, reshapedBias, args[2])
-            ->getResult(0);
+    auto weightShape = weightTy.getShape();
+
+    if (op.quantization_info())
+      return failure();
 
     // Creating maps for the output of MatMul and the bias
     SmallVector<AffineMap> indexingMaps;
-    indexingMaps.push_back(createAffineMapForType(reshapedBias, rewriter));
+
+    // Broadcast the bias.
+    indexingMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
+                                          {rewriter.getAffineDimExpr(1)},
+                                          rewriter.getContext()));
+
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
 
     auto initTensor =
@@ -776,7 +778,7 @@
     auto linalgOp =
         rewriter
            .create<linalg::GenericOp>(
-                loc, outputTy, reshapeResult, initTensor, indexingMaps,
+                loc, outputTy, bias, initTensor, indexingMaps,
                 getNParallelLoopsAttrs(outputTy.getRank()),
                 [&](OpBuilder &nested_builder, Location nested_loc,
                     ValueRange args) {
@@ -784,9 +786,21 @@
                 })
            ->getResults();
 
+    SmallVector<int64_t> permutation{1, 0};
+    auto permutationAttr = DenseIntElementsAttr::get(
+        RankedTensorType::get({2}, rewriter.getI64Type()), permutation);
+    Value permutationValue = rewriter.create<ConstantOp>(loc, permutationAttr);
+
+    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[0]};
+    Type newWeightTy =
+        RankedTensorType::get(newWeightShape, biasTy.getElementType());
+
+    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
+        loc, newWeightTy, weight, permutationValue);
+
     rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
-        op, TypeRange{op.getType()},
-        ValueRange{adaptor.input(), adaptor.weight()}, linalgOp);
+        op, TypeRange{op.getType()}, ValueRange{input, transposedWeight},
+        linalgOp);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -756,17 +756,22 @@
 
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 
 // CHECK-LABEL: @fully_connected
-func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<3x6xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
-  // CHECK: [[RS:%.+]] = linalg.tensor_reshape %arg2 [#[[$MAP0]]]
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 6]
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins([[RS]] : tensor<1x6xf32>) outs([[INIT]] : tensor<5x6xf32>) {
-  // CHECK: ^bb0([[IN:%.+]]: f32, [[MULTIPLIER:%.+]]: f32):
-  // CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<5x3xf32>, tensor<3x6xf32>) outs([[GENERIC]] : tensor<5x6xf32>) -> tensor<5x6xf32>
-  %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<5x3xf32>, tensor<3x6xf32>, tensor<6xf32>) -> (tensor<5x6xf32>)
+func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
+  // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs([[INITB]] : tensor<5x6xf32>) {
+  // CHECK: ^bb0([[IN:%.+]]: f32, [[UNUSED:%.+]]: f32):
+  // CHECK: linalg.yield [[IN]] : f32
+  // CHECK: [[INITT:%.+]] = linalg.init_tensor [3, 6]
+  // CHECK: [[TRANSPOSE:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<6x3xf32>) outs([[INITT]]
+  // CHECK: ^bb0([[IN:%.+]]: f32, [[UNUSED:%.+]]: f32):
+  // CHECK: linalg.yield [[IN]] : f32
+  // CHECK: linalg.matmul ins(%arg0, [[TRANSPOSE]] : tensor<5x3xf32>, tensor<3x6xf32>) outs([[GENERIC]] : tensor<5x6xf32>) -> tensor<5x6xf32>
+  %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<5x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> (tensor<5x6xf32>)
   return %0 : tensor<5x6xf32>
 }
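
Note on the new lowering (not part of the patch): tosa.fully_connected carries its weight as [OC, IC] per the TOSA spec, while linalg.matmul expects [IC, OC], hence the inserted tosa.transpose with permutation {1, 0}. Initializing the matmul's outs with the broadcast bias lets linalg.matmul accumulate directly onto it, replacing the old reshape of the bias to [1, n]. For the test case above, the pattern emits roughly the IR sketched below, before the intermediate tosa.transpose is itself converted into the transposing linalg.generic seen in the CHECK lines; the SSA value names and the std constant spelling are illustrative assumptions, not verbatim pass output.

  #map0 = affine_map<(d0, d1) -> (d1)>
  #map1 = affine_map<(d0, d1) -> (d0, d1)>

  func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>,
                        %arg2: tensor<6xf32>) -> tensor<5x6xf32> {
    // Broadcast the 1-D bias across the rows of the [5, 6] output.
    %init = linalg.init_tensor [5, 6] : tensor<5x6xf32>
    %broadcast = linalg.generic {indexing_maps = [#map0, #map1],
                                 iterator_types = ["parallel", "parallel"]}
        ins(%arg2 : tensor<6xf32>) outs(%init : tensor<5x6xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<5x6xf32>

    // Transpose the [6, 3] weight into the [3, 6] layout matmul expects.
    %perm = constant dense<[1, 0]> : tensor<2xi64>
    %wt = "tosa.transpose"(%arg1, %perm)
        : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>

    // linalg.matmul accumulates onto its outs, so seeding it with the
    // broadcast bias computes input * weight^T + bias in one op.
    %0 = linalg.matmul ins(%arg0, %wt : tensor<5x3xf32>, tensor<3x6xf32>)
                       outs(%broadcast : tensor<5x6xf32>) -> tensor<5x6xf32>
    return %0 : tensor<5x6xf32>
  }

The pattern also returns failure() when quantization_info is present, so quantized fully_connected ops are left untouched by this conversion for now.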