diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2263,7 +2263,7 @@
     pad.resize(2, 0);
     getValuesFromIntArrayAttribute(op.pad(), pad);
     pad.resize(pad.size() + 2, 0);
-    input = applyPad(loc, input, pad, initialAttr, rewriter);
+    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
 
     Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);
 
@@ -2273,7 +2273,6 @@
 
     Attribute strideAttr = rewriter.getI64VectorAttr(stride);
     Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
-    int64_t kernelSize = kernel[0] * kernel[1];
 
     // Create the linalg op that performs pooling.
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
@@ -2290,7 +2289,7 @@
           rewriter
              .create<std::remove_pointer_t<decltype(typePtr)>>(
                  loc, ArrayRef<Type>{resultTy},
-                 ValueRange{input, fakeWindowDims}, filledInitTensor,
+                 ValueRange{paddedInput, fakeWindowDims}, filledInitTensor,
                  dilationAttr, strideAttr)
              .getOperation());
     };
@@ -2324,14 +2323,76 @@
     }
 
     if (isa<tosa::AvgPool2dOp>(op) && inElementTy.isF32()) {
-      linalg::LinalgOp poolingOp =
-          createOp(static_cast<linalg::PoolingNhwcSumOp *>(nullptr));
-      auto constAttr = DenseElementsAttr::get(
-          resultTy, static_cast<float>(1.0 / kernelSize));
-      auto constant = rewriter.create<ConstantOp>(loc, constAttr);
-      auto mul = rewriter.create<tosa::MulOp>(
-          loc, resultTy, poolingOp->getResult(0), constant, 0);
-      rewriter.replaceOp(op, mul.output());
+      Value poolingOp =
+          createOp(static_cast<linalg::PoolingNhwcSumOp *>(nullptr))
+              ->getResult(0);
+      auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
+      auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
+      auto genericOp = rewriter.create<linalg::IndexedGenericOp>(
+          loc, ArrayRef<Type>({resultTy}), ValueRange{}, ValueRange{poolingOp},
+          ArrayRef<AffineMap>({affineMap}),
+          getNParallelLoopsAttrs(resultTy.getRank()),
+          [&](OpBuilder &b, Location loc, ValueRange indices, ValueRange args) {
+            auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
+            auto one = rewriter.create<ConstantIndexOp>(loc, 1);
+            auto iH = rewriter.create<ConstantIndexOp>(
+                loc, poolingOpTy.getDimSize(1) - 1);
+            auto iW = rewriter.create<ConstantIndexOp>(
+                loc, poolingOpTy.getDimSize(2) - 1);
+
+            // Compute the indices from either end.
+            auto y0 = indices[1];
+            auto x0 = indices[2];
+            auto y1 = rewriter.create<SubIOp>(loc, iH, y0);
+            auto x1 = rewriter.create<SubIOp>(loc, iW, x0);
+
+            // Determines what portion of the valid input is covered by the
+            // kernel.
+            auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
+              if (pad == 0)
+                return v;
+
+              auto padVal = rewriter.create<ConstantIndexOp>(loc, pad);
+              Value dx = rewriter.create<SubIOp>(loc, x, padVal);
+
+              Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt,
+                                                  dx, zero);
+              Value offset =
+                  rewriter.create<SelectOp>(loc, cmp, dx, zero);
+              return rewriter.create<AddIOp>(loc, v, offset)
+                  ->getResult(0);
+            };
+
+            // Compute the vertical component of coverage.
+            auto kH0 = rewriter.create<ConstantIndexOp>(loc, kernel[0]);
+            auto kH1 = padFn(kH0, y0, pad[2]);
+            auto kH2 = padFn(kH1, y1, pad[3]);
+            auto kHCmp =
+                rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kH2, one);
+            auto kH3 = rewriter.create<SelectOp>(loc, kHCmp, one, kH2);
+
+            // Compute the horizontal component of coverage.
+            auto kW0 = rewriter.create<ConstantIndexOp>(loc, kernel[1]);
+            auto kW1 = padFn(kW0, x0, pad[4]);
+            auto kW2 = padFn(kW1, x1, pad[5]);
+            auto kWCmp =
+                rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, kW2, one);
+            auto kW3 = rewriter.create<SelectOp>(loc, kWCmp, one, kW2);
+
+            // Compute the total number of elements and normalize.
+            Value count = rewriter.create<MulIOp>(loc, kH3, kW3);
+            auto countI = rewriter.create<IndexCastOp>(
+                loc, rewriter.getI32Type(), count);
+            auto countF =
+                rewriter.create<SIToFPOp>(loc, inElementTy, countI);
+
+            auto div =
+                rewriter.create<DivFOp>(loc, args[0], countF)->getResult(0);
+
+            rewriter.create<linalg::YieldOp>(loc, div);
+          });
+
+      rewriter.replaceOp(op, genericOp.getResult(0));
       return success();
     }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1087,17 +1087,59 @@
 // -----
 
 // CHECK-LABEL: @avg_pool
-func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> () {
-  // CHECK-DAG: [[CONST:%.+]] = constant 0
-  // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 3, 31, 62]
-  // CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]])
-  // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [4, 4]
-  // CHECK: linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, [[KERNEL]] : tensor<1x6x34x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x3x31x62xf32>)
-  // CHECK: constant dense<6.250000e-02>
-  // CHECK: linalg.generic
-  // CHECK: mulf
-  %0 = "tosa.avg_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [4, 4], stride = [1, 1]} : (tensor<1x6x34x62xf32>) -> (tensor<1x3x31x62xf32>)
-  return
+func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
+  // The initial piece computes the sum of the pooling region, with appropriate padding.
+  // CHECK: [[CONST:%.+]] = constant 0
+  // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
+  // CHECK: [[CONST:%.+]] = constant 0
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 33, 62]
+  // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]])
+  // CHECK: [[KERNEL:%.+]] = linalg.init_tensor [4, 4]
+  // CHECK: [[POOL:%.+]] = linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x8x36x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x5x33x62xf32>)
+  // CHECK: [[GENERIC:%.+]] = linalg.indexed_generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs([[POOL]] : tensor<1x5x33x62xf32>)
+  // CHECK: [[ZERO:%.+]] = constant 0
+  // CHECK: [[ONE:%.+]] = constant 1
+  // CHECK: [[HEIGHT:%.+]] = constant 4
+  // CHECK: [[WIDTH:%.+]] = constant 32
+
+  // The large block below computes what portion of the kernel is within the non-padded input.
+  // CHECK: [[NY:%.+]] = subi [[HEIGHT]], %arg2
+  // CHECK: [[NX:%.+]] = subi [[WIDTH]], %arg3
+  // CHECK: [[KH:%.+]] = constant 4
+  // CHECK: [[PAD0:%.+]] = constant 1
+  // CHECK: [[SUBP0:%.+]] = subi %arg2, [[PAD0]]
+  // CHECK: [[P0CMP:%.+]] = cmpi slt, [[SUBP0]], [[ZERO]]
+  // CHECK: [[SELP0:%.+]] = select [[P0CMP]], [[SUBP0]], [[ZERO]]
+  // CHECK: [[ADDP0:%.+]] = addi [[KH]], [[SELP0]]
+  // CHECK: [[PAD1:%.+]] = constant 1
+  // CHECK: [[SUBP1:%.+]] = subi [[NY]], [[PAD1]]
+  // CHECK: [[P1CMP:%.+]] = cmpi slt, [[SUBP1]], [[ZERO]]
+  // CHECK: [[SELP1:%.+]] = select [[P1CMP]], [[SUBP1]], [[ZERO]]
+  // CHECK: [[ADDP1:%.+]] = addi [[ADDP0]], [[SELP1]]
+  // CHECK: [[YCMP:%.+]] = cmpi slt, [[ADDP1]], [[ONE]]
+  // CHECK: [[YSEL:%.+]] = select [[YCMP]], [[ONE]], [[ADDP1]]
+  // CHECK: [[KW:%.+]] = constant 4 : index
+  // CHECK: [[PAD2:%.+]] = constant 1 : index
+  // CHECK: [[SUBP2:%.+]] = subi %arg3, [[PAD2]]
+  // CHECK: [[P2CMP:%.+]] = cmpi slt, [[SUBP2]], [[ZERO]]
+  // CHECK: [[SELP2:%.+]] = select [[P2CMP]], [[SUBP2]], [[ZERO]]
+  // CHECK: [[ADDP2:%.+]] = addi [[KW]], [[SELP2]]
+  // CHECK: [[PAD3:%.+]] = constant 1 : index
+  // CHECK: [[SUBP3:%.+]] = subi [[NX]], [[PAD3]]
+  // CHECK: [[P3CMP:%.+]] = cmpi slt, [[SUBP3]], [[ZERO]]
+  // CHECK: [[SELP3:%.+]] = select [[P3CMP]], [[SUBP3]], [[ZERO]]
+  // CHECK: [[ADDP3:%.+]] = addi [[ADDP2]], [[SELP3]]
+  // CHECK: [[XCMP:%.+]] = cmpi slt, [[ADDP3]], [[ONE]]
+  // CHECK: [[XSEL:%.+]] = select [[XCMP]], [[ONE]], [[ADDP3]]
+
+  // Given the valid coverage of the pooling region, normalize the summation.
+  // CHECK: [[C:%.+]] = muli [[YSEL]], [[XSEL]]
+  // CHECK: [[CI:%.+]] = index_cast [[C]]
+  // CHECK: [[CF:%.+]] = sitofp [[CI]]
+  // CHECK: [[RESULT:%.+]] = divf %arg5, [[CF]]
+  // CHECK: linalg.yield [[RESULT]]
+  %0 = "tosa.avg_pool2d"(%arg0) {pad = [1, 1, 1, 1], kernel = [4, 4], stride = [1, 1]} : (tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
+  return %0 : tensor<1x5x33x62xf32>
 }
 
 // -----
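
For reference, the per-output divisor produced by the linalg.indexed_generic region can be summarized with the standalone C++ sketch below. It is not part of the patch; it assumes unit stride, as in the test case, and the names validExtent, outIdx, padLo, and padHi are illustrative only.

// Sketch of the normalization count: how many kernel taps overlap real
// (un-padded) input along one axis, clamped to at least one element, which
// mirrors the padFn/select logic in the lowering above (unit stride assumed).
#include <algorithm>
#include <cstdint>
#include <cstdio>

int64_t validExtent(int64_t outIdx, int64_t outSize, int64_t kernel,
                    int64_t padLo, int64_t padHi) {
  // Negative values indicate the window hangs into the low/high padding and
  // shrink the valid extent accordingly.
  int64_t lowOverlap = std::min<int64_t>(outIdx - padLo, 0);
  int64_t highOverlap = std::min<int64_t>((outSize - 1 - outIdx) - padHi, 0);
  // Clamp to one, matching the select-against-`one` in the patch.
  return std::max<int64_t>(kernel + lowOverlap + highOverlap, 1);
}

int main() {
  // For the 1x5x33x62 output in the test (4x4 kernel, pad = [1, 1, 1, 1]):
  // corners divide by 9, edges by 12, interior positions by 16.
  std::printf("%lld %lld %lld\n",
              (long long)(validExtent(0, 5, 4, 1, 1) * validExtent(0, 33, 4, 1, 1)),
              (long long)(validExtent(0, 5, 4, 1, 1) * validExtent(1, 33, 4, 1, 1)),
              (long long)(validExtent(2, 5, 4, 1, 1) * validExtent(16, 33, 4, 1, 1)));
  return 0;
}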