diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -815,12 +815,17 @@
     // Normalize the summed value by the number of elements grouped in each
     // pool.
-    auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
-    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
+    Value iH = rewriter.create<tensor::DimOp>(loc, poolingOp, 1);
+    Value iW = rewriter.create<tensor::DimOp>(loc, poolingOp, 2);
+
+    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    iH = rewriter.create<arith::SubIOp>(loc, iH, one);
+    iW = rewriter.create<arith::SubIOp>(loc, iW, one);
 
     Value genericEmptyTensor = rewriter.create<tensor::EmptyOp>(
         loc, resultTy.getShape(), resultETy, dynamicDims);
 
+    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
        ValueRange{genericEmptyTensor},
@@ -828,60 +833,59 @@
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-          auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-          auto iH = rewriter.create<arith::ConstantIndexOp>(
-              loc, poolingOpTy.getDimSize(1) - 1);
-          auto iW = rewriter.create<arith::ConstantIndexOp>(
-              loc, poolingOpTy.getDimSize(2) - 1);
-
-          // Compute the indices from either end.
-          auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
-          auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
-          auto y1 = rewriter.create<arith::SubIOp>(loc, iH, y0);
-          auto x1 = rewriter.create<arith::SubIOp>(loc, iW, x0);
 
          // Determines what the portion of valid input is covered by the
          // kernel.
-          auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
+          auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
            if (pad == 0)
-              return v;
+              return valid;
+
            auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
-            Value dx = rewriter.create<arith::SubIOp>(loc, x, padVal);
+            Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);
+
+            Value cmp = rewriter.create<arith::CmpIOp>(
+                loc, arith::CmpIPredicate::slt, dpos, zero);
+            Value offset =
+                rewriter.create<arith::SelectOp>(loc, cmp, dpos, zero);
+            return rewriter.create<arith::AddIOp>(loc, valid, offset)
+                ->getResult(0);
+          };
+          auto coverageFn = [&](int64_t i, Value isize) -> Value {
+            Value strideVal =
+                rewriter.create<arith::ConstantIndexOp>(loc, stride[i - 1]);
+            Value val =
+                rewriter.create<arith::ConstantIndexOp>(loc, kernel[i - 1]);
+
+            // Find the position relative to the input tensor's ends.
+            Value left = rewriter.create<linalg::IndexOp>(loc, i);
+            Value right = rewriter.create<arith::SubIOp>(loc, isize, left);
+            left = rewriter.create<arith::MulIOp>(loc, left, strideVal);
+            right = rewriter.create<arith::MulIOp>(loc, right, strideVal);
+
+            // Determine how much padding was included.
+            val = padFn(val, left, pad[i * 2]);
+            val = padFn(val, right, pad[i * 2 + 1]);
            Value cmp = rewriter.create<arith::CmpIOp>(
-                loc, arith::CmpIPredicate::slt, dx, zero);
-            Value offset = rewriter.create<arith::SelectOp>(loc, cmp, dx, zero);
-            return rewriter.create<arith::AddIOp>(loc, v, offset)->getResult(0);
+                loc, arith::CmpIPredicate::slt, val, one);
+            return rewriter.create<arith::SelectOp>(loc, cmp, one, val);
          };
 
-          // Compute the vertical component of coverage.
-          auto kH0 = rewriter.create<arith::ConstantIndexOp>(loc, kernel[0]);
-          auto kH1 = padFn(kH0, y0, pad[2]);
-          auto kH2 = padFn(kH1, y1, pad[3]);
-          auto kHCmp = rewriter.create<arith::CmpIOp>(
-              loc, arith::CmpIPredicate::slt, kH2, one);
-          auto kH3 = rewriter.create<arith::SelectOp>(loc, kHCmp, one, kH2);
-
-          // compute the horizontal component of coverage.
-          auto kW0 = rewriter.create<arith::ConstantIndexOp>(loc, kernel[1]);
-          auto kW1 = padFn(kW0, x0, pad[4]);
-          auto kW2 = padFn(kW1, x1, pad[5]);
-          auto kWCmp = rewriter.create<arith::CmpIOp>(
-              loc, arith::CmpIPredicate::slt, kW2, one);
-          auto kW3 = rewriter.create<arith::SelectOp>(loc, kWCmp, one, kW2);
+          // Compute the indices from either end.
+          Value kH3 = coverageFn(1, iH);
+          Value kW3 = coverageFn(2, iW);
 
          // Compute the total number of elements and normalize.
-          Value count = rewriter.create<arith::MulIOp>(loc, kH3, kW3);
-          auto countI = rewriter.create<arith::IndexCastOp>(
-              loc, rewriter.getI32Type(), count);
+          auto count = rewriter.create<arith::IndexCastOp>(
+              loc, rewriter.getI32Type(),
+              rewriter.create<arith::MulIOp>(loc, kH3, kW3));
 
          // Divide by the number of summed values. For floats this is just
          // a div however for quantized values input normalization had
          // to be applied.
          Value poolVal = args[0];
          if (accETy.isa<FloatType>()) {
-            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, countI);
+            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
            poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
                          ->getResult(0);
          } else {
@@ -893,33 +897,52 @@
              auto inputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
              Value offset =
-                  rewriter.create<arith::MulIOp>(loc, accETy, countI, inputZp);
+                  rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
              poolVal =
                  rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
            }
 
-            // Compute the multiplier and shift values for the quantization
-            // normalization. Preferably we would want to compute more bits
-            // however 32-bits should be enough for compute. Honestly we
-            // should probably straight divide.
-            int64_t numerator = ((1 << 30) + 1);
-            int64_t shift = 30;
-
-            Value numeratorVal = rewriter.create<arith::ConstantOp>(
-                loc, rewriter.getI32IntegerAttr(numerator));
-            Value multiplierVal =
-                rewriter
-                    .create<arith::DivUIOp>(loc, rewriter.getI32Type(),
-                                            numeratorVal, countI)
-                    .getResult();
-            Value shiftVal = rewriter.create<arith::ConstantOp>(
-                loc, rewriter.getI8IntegerAttr(shift));
+            // Compute: k = 32 - count_leading_zeros(value - 1)
+            Value one32 = rewriter.create<arith::ConstantOp>(
+                loc, rewriter.getI32IntegerAttr(1));
+            Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
+                loc, rewriter.getI32IntegerAttr(32));
+
+            Value countSubOne =
+                rewriter.create<arith::SubIOp>(loc, count, one32);
+            Value leadingZeros =
+                rewriter.create<math::CountLeadingZerosOp>(loc, countSubOne);
+            Value k =
+                rewriter.create<arith::SubIOp>(loc, thirtyTwo32, leadingZeros);
+
+            // Compute: numerator = ((1 << 30) + 1) << k
+            Value k64 =
+                rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), k);
+            Value thirtyShiftPlusOne = rewriter.create<arith::ConstantOp>(
+                loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
+            Value numerator =
+                rewriter.create<arith::ShLIOp>(loc, thirtyShiftPlusOne, k64);
+
+            // Compute: scale.multiplier = numerator / value;
+            Value count64 = rewriter.create<arith::ExtUIOp>(
+                loc, rewriter.getI64Type(), count);
+            Value multiplier =
+                rewriter.create<arith::DivUIOp>(loc, numerator, count64);
+            multiplier = rewriter.create<arith::TruncIOp>(
+                loc, rewriter.getI32Type(), multiplier);
+
+            // Compute: scale.shift = 30 + k
+            Value k8 =
+                rewriter.create<arith::TruncIOp>(loc, rewriter.getI8Type(), k);
+            Value thirty8 = rewriter.create<arith::ConstantOp>(
+                loc, rewriter.getI8IntegerAttr(30));
+            Value shift = rewriter.create<arith::AddIOp>(loc, k8, thirty8);
 
            auto scaled =
                rewriter
-                    .create<tosa::ApplyScaleOp>(
-                        loc, rewriter.getI32Type(), poolVal, multiplierVal,
-                        shiftVal, rewriter.getBoolAttr(false))
+                    .create<tosa::ApplyScaleOp>(loc, rewriter.getI32Type(),
+                                                poolVal, multiplier, shift,
+                                                rewriter.getBoolAttr(false))
                    .getResult();
 
            // If we have quantization information we need to apply output
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -200,144 +200,160 @@
  %0 = "tosa.max_pool2d"(%arg0) {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi32>) -> (tensor<1x4x32x62xi32>)
  return
 }
 
-// -----
-
-// CHECK-LABEL: @avg_pool
-func.func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
-  // Initial piece computes the sum of the pooling region, with appropriate padding.
-  // CHECK: [[CONST:%.+]] = arith.constant 0
-  // CHECK: [[PAD:%.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
-  // CHECK: [[CONST:%.+]] = arith.constant 0
-  // CHECK: [[POOLINIT:%.+]] = tensor.empty()
-  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CONST]]{{.*}}outs([[POOLINIT]]
-  // CHECK: [[KERNEL:%.+]] = tensor.empty()
-  // CHECK: [[POOL:%.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x8x36x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x5x33x62xf32>)
-  // CHECK: [[INIT:%.+]] = tensor.empty()
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins([[POOL]] : tensor<1x5x33x62xf32>) outs([[INIT]] : tensor<1x5x33x62xf32>)
-  // CHECK: ^bb0(%[[BBARG1:[a-zA-Z0-9_]+]]: f32,
-  // CHECK: [[ZERO:%.0]] = arith.constant 0
-  // CHECK: [[ONE:%.+]] = arith.constant 1
-  // CHECK: [[HEIGHT:%.+]] = arith.constant 4
-  // CHECK: [[WIDTH:%.+]] = arith.constant 32
-  // CHECK: [[IDX1:%.+]] = linalg.index 1
-  // CHECK: [[IDX2:%.+]] = linalg.index 2
-
-  // The large block below computes what portion of the kernel is within non-padded input.
-  // CHECK: [[NY:%.+]] = arith.subi [[HEIGHT]], [[IDX1]]
-  // CHECK: [[NX:%.+]] = arith.subi [[WIDTH]], [[IDX2]]
-  // CHECK: [[KH:%.+]] = arith.constant 4
-  // CHECK: [[PAD0:%.+]] = arith.constant 1
-  // CHECK: [[SUBP0:%.+]] = arith.subi [[IDX1]], [[PAD0]]
-  // CHECK: [[P0CMP:%.+]] = arith.cmpi slt, [[SUBP0]], [[ZERO]]
-  // CHECK: [[SELP0:%.+]] = arith.select [[P0CMP]], [[SUBP0]], [[ZERO]]
-  // CHECK: [[ADDP0:%.+]] = arith.addi [[KH]], [[SELP0]]
-  // CHECK: [[PAD1:%.+]] = arith.constant 1
-  // CHECK: [[SUBP1:%.+]] = arith.subi [[NY]], [[PAD1]]
-  // CHECK: [[P1CMP:%.+]] = arith.cmpi slt, [[SUBP1]], [[ZERO]]
-  // CHECK: [[SELP1:%.+]] = arith.select [[P1CMP]], [[SUBP1]], [[ZERO]]
-  // CHECK: [[ADDP1:%.+]] = arith.addi [[ADDP0]], [[SELP1]]
-  // CHECK: [[YCMP:%.+]] = arith.cmpi slt, [[ADDP1]], [[ONE]]
-  // CHECK: [[YSEL:%.+]] = arith.select [[YCMP]], [[ONE]], [[ADDP1]]
-  // CHECK: [[KW:%.+]] = arith.constant 4 : index
-  // CHECK: [[PAD2:%.+]] = arith.constant 1 : index
-  // CHECK: [[SUBP2:%.+]] = arith.subi [[IDX2]], [[PAD2]]
-  // CHECK: [[P2CMP:%.+]] = arith.cmpi slt, [[SUBP2]], [[ZERO]]
-  // CHECK: [[SELP2:%.+]] = arith.select [[P2CMP]], [[SUBP2]], [[ZERO]]
-  // CHECK: [[ADDP2:%.+]] = arith.addi [[KW]], [[SELP2]]
-  // CHECK: [[PAD3:%.+]] = arith.constant 1 : index
-  // CHECK: [[SUBP3:%.+]] = arith.subi [[NX]], [[PAD3]]
-  // CHECK: [[P3CMP:%.+]] = arith.cmpi slt, [[SUBP3]], [[ZERO]]
-  // CHECK: [[SELP3:%.+]] = arith.select [[P3CMP]], [[SUBP3]], [[ZERO]]
-  // CHECK: [[ADDP3:%.+]] = arith.addi [[ADDP2]], [[SELP3]]
-  // CHECK: [[XCMP:%.+]] = arith.cmpi slt, [[ADDP3]], [[ONE]]
-  // CHECK: [[XSEL:%.+]] = arith.select [[XCMP]], [[ONE]], [[ADDP3]]
-
-  // Given the valid coverage of the pooling region, normalize the summation.
-  // CHECK: [[C:%.+]] = arith.muli [[YSEL]], [[XSEL]]
-  // CHECK: [[CI:%.+]] = arith.index_cast [[C]]
-  // CHECK: [[CF:%.+]] = arith.sitofp [[CI]]
-  // CHECK: [[RESULT:%.+]] = arith.divf %[[BBARG1]], [[CF]]
-  // CHECK: linalg.yield [[RESULT]]
-  %0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
-  return %0 : tensor<1x5x33x62xf32>
-}
 
 // -----
 
-// CHECK-LABEL: @avg_pool_dyn
-func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>) {
-  // The calculations remain the same as above, only testing for dyn behavior
-  // CHECK: %[[C0:.+]] = arith.constant 0
-  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+// CHECK-LABEL: @avg_pool_f32
+func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
+  // Apply padding to the input:
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
  // CHECK: %[[PAD:.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
-  // CHECK: %[[POOLINIT:.+]] = tensor.empty(%[[BATCH]])
-  // CHECK: %[[FILL:.+]] = linalg.fill
-  // CHECK: %[[KERNEL:.+]] = tensor.empty()
-  // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%[[PAD]], %[[KERNEL]] : tensor<?x8x36x62xf32>, tensor<4x4xf32>) outs(%[[FILL]] : tensor<?x5x33x62xf32>)
-  // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]])
-  // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[POOL]] : tensor<?x5x33x62xf32>) outs(%[[INIT]] : tensor<?x5x33x62xf32>)
-  %0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
-  return %0 : tensor<?x5x33x62xf32>
+  // CHECK: tensor.yield %[[F0]] : f32
+
+  // Fill the pooling target:
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x5x33x62xf32>
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[EMPTY]] : tensor<1x5x33x62xf32>)
+
+  // Compute the sum padding:
+  // CHECK: %[[KERNEL:.+]] = tensor.empty() : tensor<4x4xf32>
+  // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum
+  // CHECK-SAME: dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+  // CHECK-SAME: ins(%[[PAD]], %[[KERNEL]] : tensor<1x8x36x62xf32>, tensor<4x4xf32>)
+  // CHECK-SAME: outs(%[[FILL]] : tensor<1x5x33x62xf32>)
+
+  // Compute dimension based constants:
+  // CHECK: %[[I1:.+]] = arith.constant 1 : index
+  // CHECK: %[[DIM1:.+]] = tensor.dim %[[POOL]], %[[I1]]
+  // CHECK: %[[I2:.+]] = arith.constant 2 : index
+  // CHECK: %[[DIM2:.+]] = tensor.dim %[[POOL]], %[[I2]]
+  // CHECK: %[[ONE:.+]] = arith.constant 1 : index
+  // CHECK: %[[HEIGHT:.+]] = arith.subi %[[DIM1]], %[[ONE]] : index
+  // CHECK: %[[WIDTH:.+]] = arith.subi %[[DIM2]], %[[ONE]] : index
+
+  // Divide the sum pooling by the number of summed values.
+  // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x5x33x62xf32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  // CHECK-SAME: indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%[[POOL]] : tensor<1x5x33x62xf32>)
+  // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x5x33x62xf32>)
+  // CHECK: ^bb0(%[[IN:.+]]: f32, %{{.+}}: f32)
+  // CHECK: %[[ZERO:.+]] = arith.constant 0
+
+  // Compute how much of the height does not include padding:
+  // CHECK: %[[STRIDE:.+]] = arith.constant 1
+  // CHECK: %[[KSIZE:.+]] = arith.constant 4
+  // CHECK: %[[START:.+]] = linalg.index 1
+  // CHECK: %[[END:.+]] = arith.subi %[[HEIGHT]], %[[START]]
+  // CHECK: %[[SRC_START:.+]] = arith.muli %[[START]], %[[STRIDE]]
+  // CHECK: %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
+  // CHECK: %[[PAD_START:.+]] = arith.constant 1
+  // CHECK: %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
+  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
+  // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[START_SUB]], %[[ZERO]]
+  // CHECK: %[[START_OFFSET:.+]] = arith.addi %[[KSIZE]], %[[OFFSET]]
+  // CHECK: %[[PAD_END:.+]] = arith.constant 1
+  // CHECK: %[[END_SUB:.+]] = arith.subi %[[SRC_END]], %[[PAD_END]]
+  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_SUB]], %[[ZERO]]
+  // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[END_SUB]], %[[ZERO]]
+  // CHECK: %[[END_OFFSET:.+]] = arith.addi %[[START_OFFSET]], %[[OFFSET]]
+  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_OFFSET]], %[[ONE]]
+  // CHECK: %[[KHEIGHT:.+]] = arith.select %[[CMP]], %[[ONE]], %[[END_OFFSET]]
+
+  // Compute how much of the width does not include padding:
+  // CHECK: %[[STRIDE:.+]] = arith.constant 1
+  // CHECK: %[[KSIZE:.+]] = arith.constant 4
+  // CHECK: %[[START:.+]] = linalg.index 2
+  // CHECK: %[[END:.+]] = arith.subi %[[WIDTH]], %[[START]]
+  // CHECK: %[[SRC_START:.+]] = arith.muli %[[START]], %[[STRIDE]]
+  // CHECK: %[[SRC_END:.+]] = arith.muli %[[END]], %[[STRIDE]]
+  // CHECK: %[[PAD_START:.+]] = arith.constant 1
+  // CHECK: %[[START_SUB:.+]] = arith.subi %[[SRC_START]], %[[PAD_START]]
+  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[START_SUB]], %[[ZERO]]
+  // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[START_SUB]], %[[ZERO]]
+  // CHECK: %[[START_OFFSET:.+]] = arith.addi %[[KSIZE]], %[[OFFSET]]
+  // CHECK: %[[PAD_END:.+]] = arith.constant 1
+  // CHECK: %[[END_SUB:.+]] = arith.subi %[[SRC_END]], %[[PAD_END]]
+  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_SUB]], %[[ZERO]]
+  // CHECK: %[[OFFSET:.+]] = arith.select %[[CMP]], %[[END_SUB]], %[[ZERO]]
+  // CHECK: %[[END_OFFSET:.+]] = arith.addi %[[START_OFFSET]], %[[OFFSET]]
+  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[END_OFFSET]], %[[ONE]]
+  // CHECK: %[[KWIDTH:.+]] = arith.select %[[CMP]], %[[ONE]], %[[END_OFFSET]]
+
+  // Divide the summed value by the number of values summed.
+  // CHECK: %[[COUNT:.+]] = arith.muli %[[KHEIGHT]], %[[KWIDTH]]
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[COUNT]]
+  // CHECK: %[[FLT:.+]] = arith.sitofp %[[CAST]]
+  // CHECK: %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]]
+  // CHECK: linalg.yield %[[DIV]]
+  %0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
+  return %0 : tensor<1x5x33x62xf32>
 }
 
 // -----
 
-// CHECK-LABEL: @avg_pool_i8
-func.func @avg_pool_i8(%arg0 : tensor<1x128x128x2xi8>) -> () {
-
-  // CHECK: linalg.pooling_nhwc_sum
-  // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG1:[a-zA-Z0-9_]+]]: i32,
-
-  // CHECK: %[[INZP:.+]] = arith.constant -128
-  // CHECK: %[[INZP_OFF:.+]] = arith.muli %{{.+}}, %[[INZP]]
-  // CHECK: %[[OFFSETED:.+]] = arith.subi %[[BBARG1]], %[[INZP_OFF]]
-  // CHECK: %[[NUMERATOR:.+]] = arith.constant 1073741825
-  // CHECK: %[[MULTIPLIER:.+]] = arith.divui %[[NUMERATOR]], %{{.+}}
-  // CHECK: %[[SHIFT:.+]] = arith.constant 30
-  // CHECK: %[[SCALE:.+]] = "tosa.apply_scale"(%{{.+}}, %[[MULTIPLIER]], %[[SHIFT]]) {double_round = false}
-  // CHECK: %[[OUTZP:.+]] = arith.constant -128
-  // CHECK: %[[OUT:.+]] = arith.addi %[[SCALE]], %[[OUTZP]]
-  // CHECK: %[[MIN:.+]] = arith.constant -128
-  // CHECK: %[[MAX:.+]] = arith.constant 127
-  // CHECK: %[[CMP_MIN:.+]] = arith.cmpi slt, %[[OUT]], %[[MIN]]
-  // CHECK: %[[CLMP_MIN:.+]] = arith.select %[[CMP_MIN]], %[[MIN]], %[[OUT]]
-  // CHECK: %[[CMP_MAX:.+]] = arith.cmpi slt, %[[MAX]], %[[OUT]]
-  // CHECK: %[[CLMP_MAX:.+]] = arith.select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]]
-  // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLMP_MAX]]
+// CHECK-LABEL: @avg_pool_i8
+func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  // CHECK-SAME: indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%{{.+}} : tensor<1x5x33x62xi32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<1x5x33x62xi8>)
+  // CHECK: ^bb0(%[[IN:.+]]: i32, %{{.+}}: i8)
+
+  // The only behavior that differs is how the division is performed.
+  // First we compute the mul and shift values for average pool:
+  // CHECK: %[[COUNT:.+]] = arith.muli %{{[0-9]+}}, %{{[0-9]+}}
+  // CHECK: %[[ICAST:.+]] = arith.index_cast %[[COUNT]]
+  // CHECK: %[[C1:.+]] = arith.constant 1
+  // CHECK: %[[C32:.+]] = arith.constant 32
+  // CHECK: %[[ISUB:.+]] = arith.subi %[[ICAST]], %[[C1]]
+  // CHECK: %[[CTLZ:.+]] = math.ctlz %[[ISUB]]
+  // CHECK: %[[SUB:.+]] = arith.subi %[[C32]], %[[CTLZ]]
+  // CHECK: %[[EXT:.+]] = arith.extui %[[SUB]]
+  // CHECK: %[[CBIG:.+]] = arith.constant 1073741825
+  // CHECK: %[[SHL:.+]] = arith.shli %[[CBIG]], %[[EXT]]
+  // CHECK: %[[IEXT:.+]] = arith.extui %[[ICAST]]
+  // CHECK: %[[DIV:.+]] = arith.divui %[[SHL]], %[[IEXT]]
+  // CHECK: %[[TRUNC_MUL:.+]] = arith.trunci %[[DIV]]
+  // CHECK: %[[TRUNC_SHIFT:.+]] = arith.trunci %[[SUB]]
+  // CHECK: %[[C30:.+]] = arith.constant 30
+  // CHECK: %[[SHIFT:.+]] = arith.addi %[[TRUNC_SHIFT]], %[[C30]] : i8
+  // CHECK: %[[SCALED:.+]] = "tosa.apply_scale"(%[[IN]], %[[TRUNC_MUL]], %[[SHIFT]]) {double_round = false}
+
+  // Perform the normalization.
+  // CHECK: %[[CMIN:.+]] = arith.constant -128
+  // CHECK: %[[CMAX:.+]] = arith.constant 127
+  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[SCALED]], %[[CMIN]]
+  // CHECK: %[[SEL:.+]] = arith.select %[[CMP]], %[[CMIN]], %[[SCALED]]
+  // CHECK: %[[CMP:.+]] = arith.cmpi slt, %[[CMAX]], %[[SCALED]]
+  // CHECK: %[[CLAMP:.+]] = arith.select %[[CMP]], %[[CMAX]], %[[SEL]]
+  // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLAMP]]
  // CHECK: linalg.yield %[[TRUNC]]
-  %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 4, 4>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.unary_quant<input_zp = -128, output_zp = -128>, stride = array<i64: 4, 4>} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi8>
-  return
+  %0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>)
+  return %0 : tensor<1x5x33x62xi8>
 }
 
 // -----
 
-// CHECK-LABEL: @avg_pool_i16
-func.func @avg_pool_i16(%arg0 : tensor<1x128x128x2xi16>) -> () {
-
-  // CHECK: linalg.pooling_nhwc_sum
-  // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG1:[a-zA-Z0-9_]+]]: i32,
-
-  // CHECK: %[[INZP:.+]] = arith.constant -128
-  // CHECK: %[[INZP_OFF:.+]] = arith.muli %{{.+}}, %[[INZP]]
-  // CHECK: %[[OFFSETED:.+]] = arith.subi %[[BBARG1]], %[[INZP_OFF]]
-  // CHECK: %[[NUMERATOR:.+]] = arith.constant 1073741825
-  // CHECK: %[[MULTIPLIER:.+]] = arith.divui %[[NUMERATOR]], %{{.+}}
-  // CHECK: %[[SHIFT:.+]] = arith.constant 30
-  // CHECK: %[[SCALE:.+]] = "tosa.apply_scale"(%{{.+}}, %[[MULTIPLIER]], %[[SHIFT]]) {double_round = false}
-  // CHECK: %[[OUTZP:.+]] = arith.constant -128
-  // CHECK: %[[OUT:.+]] = arith.addi %[[SCALE]], %[[OUTZP]]
-  // CHECK: %[[MIN:.+]] = arith.constant -32768
-  // CHECK: %[[MAX:.+]] = arith.constant 32767
-  // CHECK: %[[CMP_MIN:.+]] = arith.cmpi slt, %[[OUT]], %[[MIN]]
-  // CHECK: %[[CLMP_MIN:.+]] = arith.select %[[CMP_MIN]], %[[MIN]], %[[OUT]]
-  // CHECK: %[[CMP_MAX:.+]] = arith.cmpi slt, %[[MAX]], %[[OUT]]
-  // CHECK: %[[CLMP_MAX:.+]] = arith.select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]]
-  // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLMP_MAX]]
-  // CHECK: linalg.yield %[[TRUNC]]
-  %0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 4, 4>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.unary_quant<input_zp = -128, output_zp = -128>, stride = array<i64: 4, 4>} : (tensor<1x128x128x2xi16>) -> tensor<1x32x32x2xi16>
-  return
+// CHECK-LABEL: @avg_pool_dyn
+func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>) {
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[PADDED:.+]] = tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
+  // CHECK: tensor.yield %[[F0]]
+  // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x5x33x62xf32>
+  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[EMPTY]] : tensor<?x5x33x62xf32>)
+  // CHECK: %[[KERNEL:.+]] = tensor.empty() : tensor<4x4xf32>
+  // CHECK: %[[POOL:.+]] = linalg.pooling_nhwc_sum
+  // CHECK-SAME: dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>
+  // CHECK-SAME: ins(%[[PADDED]], %[[KERNEL]] : tensor<?x8x36x62xf32>, tensor<4x4xf32>)
+  // CHECK-SAME: outs(%[[FILL]] : tensor<?x5x33x62xf32>) -> tensor<?x5x33x62xf32>
+  // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x5x33x62xf32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  %0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
+  return %0 : tensor<?x5x33x62xf32>
 }
 
 // -----
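
Note on the quantized normalization above: the lowering now derives the tosa.apply_scale multiplier/shift pair per element from the dynamic count, rather than using the fixed numerator (1 << 30) + 1 with shift 30. The following is a minimal standalone sketch of that arithmetic, not part of the patch: GCC/Clang's __builtin_clz stands in for math.ctlz (MLIR defines ctlz(0) = 32, handled explicitly here), and a plain truncating shift stands in for apply_scale's rounding behavior.

  // Sketch of the k / numerator / multiplier / shift computation emitted by
  // the lowering, checked against exact integer division.
  #include <cassert>
  #include <cstdint>
  #include <cstdio>

  int main() {
    for (int32_t count = 1; count <= 16; ++count) {
      // k = 32 - count_leading_zeros(count - 1), i.e. ceil(log2(count)).
      int32_t k =
          count == 1 ? 0 : 32 - __builtin_clz(static_cast<uint32_t>(count - 1));
      // numerator = ((1 << 30) + 1) << k, widened to 64 bits first so the
      // shift cannot overflow.
      int64_t numerator = ((int64_t{1} << 30) + 1) << k;
      // scale.multiplier = numerator / count; since 2^(k-1) < count <= 2^k,
      // the quotient stays in roughly [2^30, 2^31) and fits in i32.
      int32_t multiplier = static_cast<int32_t>(numerator / count);
      // scale.shift = 30 + k.
      int32_t shift = 30 + k;

      // Scaling by the multiplier and shifting right should reproduce
      // x / count for representative accumulator sums x.
      for (int64_t x : {int64_t{1}, int64_t{7}, int64_t{1000}, int64_t{123456}}) {
        int64_t scaled = (x * multiplier) >> shift;
        assert(scaled == x / count);
      }
      std::printf("count=%2d multiplier=%10d shift=%2d\n", count, multiplier,
                  shift);
    }
    return 0;
  }

Keeping the numerator proportional to 2^k, instead of fixed, means the multiplier retains close to 31 significant bits for any window size; with the old scheme the multiplier lost one bit of precision for every doubling of the element count.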