diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -2458,39 +2458,34 @@ } }; -template -class Pool2dConverter : public OpRewritePattern { +class MaxPool2dConverter : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(SrcOp op, + LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); Value input = op.input(); ShapedType inputTy = input.getType().cast(); - Type inElementTy = inputTy.getElementType(); ShapedType resultTy = op.getType().template cast(); - Type outElementTy = inputTy.getElementType(); + Type resultETy = inputTy.getElementType(); if (!inputTy.hasStaticShape()) return failure(); // Determine what the initial value needs to be for the max pool op. Attribute initialAttr; - if (isa(op) && outElementTy.isF32()) + if (resultETy.isF32()) initialAttr = rewriter.getFloatAttr( - outElementTy, - APFloat::getLargest( - outElementTy.cast().getFloatSemantics(), true)); + resultETy, + APFloat::getLargest(resultETy.cast().getFloatSemantics(), + true)); - if (isa(op) && outElementTy.isa()) + if (resultETy.isa()) initialAttr = rewriter.getIntegerAttr( - outElementTy, - APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth())); - - if (isa(op) && outElementTy.isa()) - initialAttr = rewriter.getZeroAttr(outElementTy); + resultETy, + APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth())); if (!initialAttr) return rewriter.notifyMatchFailure( @@ -2520,93 +2515,216 @@ rewriter.create(loc, initialValue, initTensor).result(); Value fakeWindowDims = - rewriter.create(loc, kernel, outElementTy); + rewriter.create(loc, kernel, resultETy); - if (isa(op)) { - rewriter.replaceOpWithNewOp( - op, ArrayRef{resultTy}, ValueRange{paddedInput, fakeWindowDims}, - filledInitTensor, strideAttr, dilationAttr); - return success(); - } + rewriter.replaceOpWithNewOp( + op, ArrayRef{resultTy}, ValueRange{paddedInput, fakeWindowDims}, + filledInitTensor, strideAttr, dilationAttr); + return success(); + } +}; - if (isa(op) && inElementTy.isF32()) { - Value poolingOp = rewriter - .create( - loc, ArrayRef{resultTy}, - ValueRange{paddedInput, fakeWindowDims}, - filledInitTensor, strideAttr, dilationAttr) - .getResult(0); - auto poolingOpTy = poolingOp.getType().cast(); - auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank()); - auto genericOp = rewriter.create( - loc, ArrayRef({resultTy}), ValueRange{}, ValueRange{poolingOp}, - ArrayRef({affineMap}), - getNParallelLoopsAttrs(resultTy.getRank()), - [&](OpBuilder &b, Location loc, ValueRange args) { - auto zero = rewriter.create(loc, 0); - auto one = rewriter.create(loc, 1); - auto iH = rewriter.create( - loc, poolingOpTy.getDimSize(1) - 1); - auto iW = rewriter.create( - loc, poolingOpTy.getDimSize(2) - 1); - - // Compute the indices from either end. - auto y0 = rewriter.create(loc, 1); - auto x0 = rewriter.create(loc, 2); - auto y1 = rewriter.create(loc, iH, y0); - auto x1 = rewriter.create(loc, iW, x0); - - // Determines what the portion of valid input is covered by the - // kernel. - auto padFn = [&](Value v, Value x, int64_t pad) -> Value { - if (pad == 0) - return v; - - auto padVal = rewriter.create(loc, pad); - Value dx = rewriter.create(loc, x, padVal); - - Value cmp = rewriter.create(loc, CmpIPredicate::slt, - dx, zero); - Value offset = - rewriter.create(loc, cmp, dx, zero); - return rewriter.create(loc, v, offset) - ->getResult(0); - }; - - // Compute the vertical component of coverage. - auto kH0 = rewriter.create(loc, kernel[0]); - auto kH1 = padFn(kH0, y0, pad[2]); - auto kH2 = padFn(kH1, y1, pad[3]); - auto kHCmp = - rewriter.create(loc, CmpIPredicate::slt, kH2, one); - auto kH3 = rewriter.create(loc, kHCmp, one, kH2); - - // compute teh horizontal component of coverage. - auto kW0 = rewriter.create(loc, kernel[1]); - auto kW1 = padFn(kW0, x0, pad[4]); - auto kW2 = padFn(kW1, x1, pad[5]); - auto kWCmp = - rewriter.create(loc, CmpIPredicate::slt, kW2, one); - auto kW3 = rewriter.create(loc, kWCmp, one, kW2); - - // Compute the total number of elements and normalize. - Value count = rewriter.create(loc, kH3, kW3); - auto countI = rewriter.create( - loc, rewriter.getI32Type(), count); +class AvgPool2dConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::AvgPool2dOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + Value input = op.input(); + ShapedType inputTy = input.getType().cast(); + Type inElementTy = inputTy.getElementType(); + + ShapedType resultTy = op.getType().template cast(); + Type resultETy = inputTy.getElementType(); + + Type accETy = + inElementTy.isa() ? rewriter.getI32Type() : inElementTy; + ShapedType accTy = resultTy.clone(accETy); + + if (!inputTy.hasStaticShape()) + return failure(); + + // Apply padding as necessary. + llvm::SmallVector pad; + pad.resize(2, 0); + getValuesFromIntArrayAttribute(op.pad(), pad); + pad.resize(pad.size() + 2, 0); + Attribute initialAttr = rewriter.getZeroAttr(accETy); + Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter); + + Value initialValue = rewriter.create(loc, initialAttr); + + SmallVector kernel, stride; + getValuesFromIntArrayAttribute(op.kernel(), kernel); + getValuesFromIntArrayAttribute(op.stride(), stride); + + Attribute strideAttr = rewriter.getI64VectorAttr(stride); + Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1}); + + // Create the linalg op that performs pooling. + Value poolInitTensor = + rewriter.create(loc, accTy.getShape(), accETy); + + Value filledInitTensor = + rewriter.create(loc, initialValue, poolInitTensor) + .result(); + + Value fakeWindowDims = + rewriter.create(loc, kernel, accETy); + + // Sum across the pooled region. + Value poolingOp = rewriter + .create( + loc, ArrayRef{accTy}, + ValueRange{paddedInput, fakeWindowDims}, + filledInitTensor, strideAttr, dilationAttr) + .getResult(0); + + // Normalize the summed value by the number of elements grouped in each + // pool. + auto poolingOpTy = poolingOp.getType().cast(); + auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank()); + + Value genericInitTensor = rewriter.create( + loc, resultTy.getShape(), resultETy); + + auto genericOp = rewriter.create( + loc, ArrayRef({resultTy}), ValueRange{poolingOp}, + ValueRange{genericInitTensor}, + ArrayRef({affineMap, affineMap}), + getNParallelLoopsAttrs(resultTy.getRank()), + [&](OpBuilder &b, Location loc, ValueRange args) { + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + auto iH = rewriter.create( + loc, poolingOpTy.getDimSize(1) - 1); + auto iW = rewriter.create( + loc, poolingOpTy.getDimSize(2) - 1); + + // Compute the indices from either end. + auto y0 = rewriter.create(loc, 1); + auto x0 = rewriter.create(loc, 2); + auto y1 = rewriter.create(loc, iH, y0); + auto x1 = rewriter.create(loc, iW, x0); + + // Determines what the portion of valid input is covered by the + // kernel. + auto padFn = [&](Value v, Value x, int64_t pad) -> Value { + if (pad == 0) + return v; + + auto padVal = rewriter.create(loc, pad); + Value dx = rewriter.create(loc, x, padVal); + + Value cmp = rewriter.create(loc, CmpIPredicate::slt, + dx, zero); + Value offset = rewriter.create(loc, cmp, dx, zero); + return rewriter.create(loc, v, offset)->getResult(0); + }; + + // Compute the vertical component of coverage. + auto kH0 = rewriter.create(loc, kernel[0]); + auto kH1 = padFn(kH0, y0, pad[2]); + auto kH2 = padFn(kH1, y1, pad[3]); + auto kHCmp = + rewriter.create(loc, CmpIPredicate::slt, kH2, one); + auto kH3 = rewriter.create(loc, kHCmp, one, kH2); + + // compute the horizontal component of coverage. + auto kW0 = rewriter.create(loc, kernel[1]); + auto kW1 = padFn(kW0, x0, pad[4]); + auto kW2 = padFn(kW1, x1, pad[5]); + auto kWCmp = + rewriter.create(loc, CmpIPredicate::slt, kW2, one); + auto kW3 = rewriter.create(loc, kWCmp, one, kW2); + + // Compute the total number of elements and normalize. + Value count = rewriter.create(loc, kH3, kW3); + auto countI = rewriter.create( + loc, rewriter.getI32Type(), count); + + // Divide by the number of summed values. For floats this is just + // a div however for quantized values input normalization had + // to be applied. + Value poolVal = args[0]; + if (accETy.isa()) { auto countF = rewriter.create(loc, inElementTy, countI); + poolVal = + rewriter.create(loc, poolVal, countF)->getResult(0); + } else { - auto div = - rewriter.create(loc, args[0], countF)->getResult(0); + // If we have quantization information we need to apply an offset + // for the input zp value. + if (op.quantization_info()) { + auto quantizationInfo = op.quantization_info().getValue(); + auto inputZp = rewriter.create( + loc, quantizationInfo.input_zp()); + Value offset = + rewriter.create(loc, accETy, countI, inputZp); + poolVal = rewriter.create(loc, accETy, poolVal, offset); + } - rewriter.create(loc, div); - }); + // Compute the multiplier and shift values for the quantization + // normalization. Preferably we would want to compute more bits + // however 32-bits should be enough for compute. Honestly we + // should probably straight divide. + int64_t numerator = ((1 << 30) + 1); + int64_t shift = 30; + + Value numeratorVal = rewriter.create( + loc, rewriter.getI32IntegerAttr(numerator)); + Value multiplierVal = + rewriter + .create(loc, rewriter.getI32Type(), + numeratorVal, countI) + .getResult(); + Value shiftVal = rewriter.create( + loc, rewriter.getI8IntegerAttr(shift)); + + auto scaled = + rewriter + .create( + loc, rewriter.getI32Type(), poolVal, multiplierVal, + shiftVal, rewriter.getBoolAttr(false)) + .getResult(); + + // If we have quantization information we need to apply output + // zeropoint. + if (op.quantization_info()) { + auto quantizationInfo = op.quantization_info().getValue(); + auto outputZp = rewriter.create( + loc, quantizationInfo.output_zp()); + scaled = + rewriter.create(loc, scaled, outputZp).getResult(); + } - rewriter.replaceOp(op, genericOp.getResult(0)); - return success(); - } + // Apply Clip. + int64_t outBitwidth = resultETy.getIntOrFloatBitWidth(); + + auto min = rewriter.create( + loc, rewriter.getIntegerAttr( + accETy, + APInt::getSignedMinValue(outBitwidth).getSExtValue())); + auto max = rewriter.create( + loc, rewriter.getIntegerAttr( + accETy, + APInt::getSignedMaxValue(outBitwidth).getSExtValue())); + auto clamp = clampHelper( + loc, scaled, min, max, CmpIPredicate::slt, rewriter); + + // Convert type. + poolVal = rewriter.create(loc, resultETy, clamp); + } - return failure(); + // Cast to output type. + + rewriter.create(loc, poolVal); + }); + + rewriter.replaceOp(op, genericOp.getResult(0)); + return success(); } }; @@ -2673,8 +2791,8 @@ TileConverter, TransposeConverter, MatMulConverter, - Pool2dConverter, - Pool2dConverter, + MaxPool2dConverter, + AvgPool2dConverter, FullyConnectedConverter>(patterns->getContext()); // clang-format on } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1199,11 +1199,12 @@ // CHECK: [[CONST:%.+]] = constant 0 // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] // CHECK: [[CONST:%.+]] = constant 0 - // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 33, 62] - // CHECK: [[FILL:%.+]] = linalg.fill([[CONST]], [[INIT]]) + // CHECK: [[POOLINIT:%.+]] = linalg.init_tensor [1, 5, 33, 62] + // CHECK: [[FILL:%.+]] = linalg.fill([[CONST]], [[POOLINIT]]) // CHECK: [[KERNEL:%.+]] = linalg.init_tensor [4, 4] // CHECK: [[POOL:%.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x8x36x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x5x33x62xf32>) - // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs([[POOL]] : tensor<1x5x33x62xf32>) + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 33, 62] + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[POOL]] : tensor<1x5x33x62xf32>) outs([[INIT]] : tensor<1x5x33x62xf32>) // CHECK: [[ZERO:%.0]] = constant 0 // CHECK: [[ONE:%.+]] = constant 1 // CHECK: [[HEIGHT:%.+]] = constant 4 @@ -1253,17 +1254,46 @@ // ----- -// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)> -// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)> +// CHECK-LABEL: @avg_pool_i8 +func @avg_pool_i8(%arg0 : tensor<1x128x128x2xi8>) -> () { + + // CHECK: linalg.pooling_nhwc_sum + // CHECK: linalg.generic + + // CHECK: %[[INZP:.+]] = constant -128 + // CHECK: %[[INZP_OFF:.+]] = muli %{{.+}}, %[[INZP]] + // CHECK: %[[OFFSETED:.+]] = subi %arg1, %[[INZP_OFF]] + // CHECK: %[[NUMERATOR:.+]] = constant 1073741825 + // CHECK: %[[MULTIPLIER:.+]] = divi_unsigned %[[NUMERATOR]], %{{.+}} + // CHECK: %[[SHIFT:.+]] = constant 30 + // CHECK: %[[SCALE:.+]] = "tosa.apply_scale"(%{{.+}}, %[[MULTIPLIER]], %[[SHIFT]]) {double_round = false} + // CHECK: %[[OUTZP:.+]] = constant -128 + // CHECK: %[[OUT:.+]] = addi %[[SCALE]], %[[OUTZP]] + // CHECK: %[[MIN:.+]] = constant -128 + // CHECK: %[[MAX:.+]] = constant 127 + // CHECK: %[[CMP_MIN:.+]] = cmpi slt, %[[OUT]], %[[MIN]] + // CHECK: %[[CLMP_MIN:.+]] = select %[[CMP_MIN]], %[[MIN]], %[[OUT]] + // CHECK: %[[CMP_MAX:.+]] = cmpi slt, %[[MAX]], %[[OUT]] + // CHECK: %[[CLMP_MAX:.+]] = select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]] + // CHECK: %[[TRUNC:.+]] = trunci %[[CLMP_MAX]] + // CHECK: linalg.yield %[[TRUNC]] + %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi8> + return +} + +// ----- + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)> -// CHECK-LABEL @conv2d_f32 +// CHECK-LABEL: @conv2d_f32 func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () { // CHECK: %[[W_IN:.+]] = linalg.init_tensor [3, 3, 27, 28] - // CHECK: %[[W:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[W_IN]] : tensor<3x3x27x28xf32>) + // CHECK: %[[W:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[W_IN]] : tensor<3x3x27x28xf32>) // CHECK: linalg.yield %arg3 : f32 // CHECK: %[[B_IN:.+]] = linalg.init_tensor [1, 45, 40, 28] - // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>) + // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>) // CHECK: linalg.yield %arg3 : f32 // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %1 : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[B]] : tensor<1x45x40x28xf32>) %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> (tensor<1x45x40x28xf32>)