diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -970,26 +970,12 @@
     weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                 weightPermValue);
 
-    // Broadcast the initial value to the output tensor before convolving.
-    SmallVector<AffineMap, 4> indexingMaps;
-    indexingMaps.push_back(AffineMap::get(
-        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
-        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
-
+    Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
         loc, resultTy.getShape(), resultETy);
-
-    Value biasBroadcast =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, resultTy, bias, initTensor, indexingMaps,
-                getNParallelLoopsAttrs(resultTy.getRank()),
-                [&](OpBuilder &nestedBuilder, Location nestedLoc,
-                    ValueRange args) {
-                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
-                })
-            .getResult(0);
+    Value zero = rewriter.create<ConstantOp>(loc, resultZeroAttr);
+    Value zeroTensor =
+        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
 
     // Extract the attributes for convolution.
     llvm::SmallVector<int64_t> stride, dilation;
@@ -1002,7 +988,17 @@
     auto dilationAttr = DenseIntElementsAttr::get(
         RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
 
-    Value conv;
+    // Create maps for the bias broadcasting
+    SmallVector<AffineMap, 4> indexingMaps;
+    indexingMaps.push_back(AffineMap::get(
+        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
+        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
+
+    Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, resultTy.getShape(), resultETy);
+
     if (isQuantized) {
       auto quantizationInfo =
           op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
@@ -1013,15 +1009,49 @@
 
       auto iZpVal = rewriter.create<ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<ConstantOp>(loc, kZp);
-      rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcHwcfQOp>(
-          op, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
-          ValueRange{biasBroadcast}, strideAttr, dilationAttr);
+      Value conv =
+          rewriter
+              .create<linalg::Conv2DNhwcHwcfQOp>(
+                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
+                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
+              ->getResult(0);
+
+      Value result =
+          rewriter
+              .create<linalg::GenericOp>(
+                  loc, resultTy, ValueRange({bias, conv}), biasInitTensor,
+                  indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
+                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                      ValueRange args) {
+                    Value added =
+                        nestedBuilder.create<AddIOp>(loc, args[0], args[1]);
+                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
+                  })
+              .getResult(0);
+      rewriter.replaceOp(op, result);
       return success();
     }
 
-    rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcHwcfOp>(
-        op, resultTy, ValueRange{input, weight}, ValueRange{biasBroadcast},
-        strideAttr, dilationAttr);
+    Value conv = rewriter
+                     .create<linalg::Conv2DNhwcHwcfOp>(
+                         loc, resultTy, ValueRange{input, weight},
+                         ValueRange{zeroTensor}, strideAttr, dilationAttr)
+                     ->getResult(0);
+
+    Value result =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, resultTy, ValueRange({bias, conv}), biasInitTensor,
+                indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
+                [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                    ValueRange args) {
+                  Value added =
+                      nestedBuilder.create<AddFOp>(loc, args[0], args[1]);
+                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
+                })
+            .getResult(0);
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -1288,6 +1318,8 @@
 
     auto weightTy = weight.getType().cast<ShapedType>();
     auto weightShape = weightTy.getShape();
 
+    auto outputETy = outputTy.getElementType();
+
    // Creating maps for the output of MatMul and the bias
    SmallVector<AffineMap, 4> indexingMaps;
@@ -1297,23 +1329,16 @@
         rewriter.getContext()));
 
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
+    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
 
-    auto initTensor =
-        rewriter
-            .create<linalg::InitTensorOp>(loc, outputTy.getShape(),
-                                          outputTy.getElementType())
-            ->getResults();
+    auto initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, outputTy.getShape(), outputTy.getElementType());
 
-    auto linalgOp =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, outputTy, bias, initTensor, indexingMaps,
-                getNParallelLoopsAttrs(outputTy.getRank()),
-                [&](OpBuilder &nested_builder, Location nested_loc,
-                    ValueRange args) {
-                  nested_builder.create<linalg::YieldOp>(loc, *args.begin());
-                })
-            ->getResults();
+    // When quantized, the input element type is not the same as the output
+    Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy);
+    Value zero = rewriter.create<ConstantOp>(loc, resultZeroAttr);
+    Value zeroTensor =
+        rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
 
     SmallVector<int64_t> permutation{1, 0};
     auto permutationAttr = DenseIntElementsAttr::get(
@@ -1327,10 +1352,31 @@
     Value transposedWeight = rewriter.create<tosa::TransposeOp>(
         loc, newWeightTy, weight, permutationValue);
 
+    auto biasInitTensor =
+        rewriter
+            .create<linalg::InitTensorOp>(loc, outputTy.getShape(), outputETy)
+            ->getResults();
+
     if (!op.quantization_info()) {
-      rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
-          op, TypeRange{op.getType()}, ValueRange{input, transposedWeight},
-          linalgOp);
+      Value matmul = rewriter
+                         .create<linalg::MatmulOp>(
+                             loc, TypeRange{op.getType()},
+                             ValueRange{input, transposedWeight}, zeroTensor)
+                         ->getResult(0);
+
+      Value result =
+          rewriter
+              .create<linalg::GenericOp>(
+                  loc, outputTy, ValueRange({bias, matmul}), biasInitTensor,
+                  indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
+                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                      ValueRange args) {
+                    Value added =
+                        nestedBuilder.create<AddFOp>(loc, args[0], args[1]);
+                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
+                  })
+              .getResult(0);
+      rewriter.replaceOp(op, result);
       return success();
     }
 
@@ -1341,10 +1387,26 @@
     auto outputZp = rewriter.create<ConstantOp>(
         loc, rewriter.getI32IntegerAttr(
                  quantizationInfo.weight_zp().getValue().getSExtValue()));
-    rewriter.replaceOpWithNewOp<linalg::QuantizedMatmulOp>(
-        op, TypeRange{op.getType()},
-        ValueRange{input, transposedWeight, inputZp, outputZp}, linalgOp);
-
+    Value matmul =
+        rewriter
+            .create<linalg::QuantizedMatmulOp>(
+                loc, TypeRange{op.getType()},
+                ValueRange{input, transposedWeight, inputZp, outputZp},
+                zeroTensor)
+            ->getResult(0);
+    Value result =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, outputTy, ValueRange({bias, matmul}), biasInitTensor,
+                indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
+                [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                    ValueRange args) {
+                  Value added =
+                      nestedBuilder.create<AddIOp>(loc, args[0], args[1]);
+                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
+                })
+            .getResult(0);
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -993,44 +993,55 @@
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
 
 // CHECK-LABEL: @fully_connected
 func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
-  // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6]
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs([[INITB]] : tensor<5x6xf32>) {
-  // CHECK: ^bb0([[IN:%.+]]: f32, [[UNUSED:%.+]]: f32):
-  // CHECK: linalg.yield [[IN]] : f32
+  // CHECK: [[INITT:%.+]] = linalg.init_tensor [5, 6]
+  // CHECK: [[ZERO:%.+]] = constant 0
+  // CHECK: [[FILL:%.+]] = linalg.fill([[ZERO]], [[INITT]])
+  // CHECK: [[PERM:%.+]] = constant dense<[1, 0]>
   // CHECK: [[INITT:%.+]] = linalg.init_tensor [3, 6]
-  // CHECK: [[TRANSPOSE:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<6x3xf32>) outs([[INITT]]
+  // CHECK: [[TRANSPOSE:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<6x3xf32>) outs([[INITT]] : tensor<3x6xf32>) {
   // CHECK: ^bb0([[IN:%.+]]: f32, [[UNUSED:%.+]]: f32):
   // CHECK: linalg.yield [[IN]] : f32
-  // CHECK: linalg.matmul ins(%arg0, [[TRANSPOSE]] : tensor<5x3xf32>, tensor<3x6xf32>) outs([[GENERIC]] : tensor<5x6xf32>) -> tensor<5x6xf32>
+  // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6]
+  // CHECK: [[MATMUL:%.+]] = linalg.matmul ins(%arg0, [[TRANSPOSE]] : tensor<5x3xf32>, tensor<3x6xf32>) outs([[FILL]] : tensor<5x6xf32>) -> tensor<5x6xf32>
+  // CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xf32>, tensor<5x6xf32>) outs([[INITB]] : tensor<5x6xf32>) {
+  // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+  // CHECK: [[ADD:%.+]] = addf %arg3, %arg4 : f32
+  // CHECK: linalg.yield [[ADD]] : f32
+
   %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<5x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> (tensor<5x6xf32>)
   return %0 : tensor<5x6xf32>
 }
 
 // -----
 
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
 
 // CHECK-LABEL: @quantized_fully_connected
 func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) {
-  // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6]
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xi32>) outs([[INITB]] : tensor<5x6xi32>) {
-  // CHECK: ^bb0([[IN:%.+]]: i32, [[UNUSED:%.+]]: i32):
-  // CHECK: linalg.yield [[IN]] : i32
+  // CHECK: [[INITT:%.+]] = linalg.init_tensor [5, 6]
+  // CHECK: [[ZERO:%.+]] = constant 0
+  // CHECK: [[FILL:%.+]] = linalg.fill([[ZERO]], [[INITT]])
+  // CHECK: [[PERM:%.+]] = constant dense<[1, 0]>
   // CHECK: [[INITT:%.+]] = linalg.init_tensor [3, 6]
-  // CHECK: [[TRANSPOSE:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<6x3xi8>) outs([[INITT]]
+  // CHECK: [[TRANSPOSE:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<6x3xi8>) outs([[INITT]] : tensor<3x6xi8>) {
   // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
   // CHECK: linalg.yield [[IN]] : i8
+  // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6]
   // CHECK: [[ONE:%.+]] = constant 1
-  // CHECK: [[TWO:%.+]] = constant 2
-  // CHECK: linalg.quantized_matmul ins(%arg0, [[TRANSPOSE]], [[ONE]], [[TWO]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs([[GENERIC]] : tensor<5x6xi32>) -> tensor<5x6xi32>
+  // CHECK: [[TWO:%.+]] = constant 2
+  // CHECK: [[MATMUL:%.+]] = linalg.quantized_matmul ins(%arg0, [[TRANSPOSE]], [[ONE]], [[TWO]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs([[FILL]] : tensor<5x6xi32>) -> tensor<5x6xi32>
+  // CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xi32>, tensor<5x6xi32>) outs([[INITB]]
+  // CHECK: ^bb0([[IN1:%.+]]: i32, [[IN2:%.+]]: i32, [[UNUSED:%.+]]: i32):
+  // CHECK: [[ADD:%.+]] = addi
+  // CHECK: linalg.yield [[ADD]] : i32
   %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) {quantization_info = {input_zp = 1:i32, weight_zp = 2:i32}} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> (tensor<5x6xi32>)
   return %0 : tensor<5x6xi32>
 }
@@ -1350,10 +1361,14 @@
   // CHECK: %[[W_IN:.+]] = linalg.init_tensor [3, 3, 27, 28]
   // CHECK: %[[W:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[W_IN]] : tensor<3x3x27x28xf32>)
   // CHECK: linalg.yield %arg3 : f32
+  // CHECK: %[[M_IN:.+]] = linalg.init_tensor [1, 45, 40, 28]
+  // CHECK: %[[CST:.+]] = constant 0
+  // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = linalg.init_tensor [1, 45, 40, 28]
-  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
-  // CHECK: linalg.yield %arg3 : f32
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %1 : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[B]] : tensor<1x45x40x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
+  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
+  // CHECK: addf
+  // CHECK: linalg.yield %7 : f32
   %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> (tensor<1x45x40x28xf32>)
   return
 }