diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -2001,6 +2001,145 @@ - !ScalarExpression scalar_arg: K --- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: conv_3d_ndhwc_dhwcf_q + cpp_class_name: Conv3DNdhwcDhwcfQOp + doc: |- + Performs 3-D convolution with zero point offsets. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. + implements: + - LinalgConvolutionOpInterface +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: I + kind: input_tensor + type_var: T1 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, + s13, s14] -> (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9 * s10 + s11 * s12, + s13)> + - !LinalgOperandDefConfig + name: K + kind: input_tensor + type_var: T2 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, + s13, s14] -> (s3, s7, s11, s13, s14)> + - !LinalgOperandDefConfig + name: IZp + kind: scalar + type_var: I32 + - !LinalgOperandDefConfig + name: KZp + kind: scalar + type_var: I32 + - !LinalgOperandDefConfig + name: O + kind: output_tensor + type_var: U + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, + s13, s14] -> (s0, s1, s5, s9, s14)> + - !LinalgOperandDefConfig + name: strides + kind: index_attr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, + s12, s13, s14] -> (s2, s6, s10)> + default_indices: + - 1 + - 1 + - 1 + - !LinalgOperandDefConfig + name: dilations + kind: index_attr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, + s12, s13, s14] -> (s4, s8, s12)> + default_indices: + - 1 + - 1 + - 1 + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6, + s7, s8, s9, s10, s11, s12, s13, s14] -> (d0, d1 * s2 + d5 * s4, d2 * s6 + d6 + * s8, d3 * s10 + d7 * s12, d8)> + - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6, + s7, s8, s9, s10, s11, s12, s13, s14] -> (d5, d6, d7, d8, d4)> + - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6, + s7, s8, s9, s10, s11, s12, s13, s14] -> ()> + - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6, + s7, s8, s9, s10, s11, s12, s13, s14] -> ()> + - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6, + s7, s8, s9, s10, s11, s12, s13, s14] -> (d0, d1, d2, d3, d4)> + iterator_types: + - parallel + - parallel + - parallel + - parallel + - parallel + - reduction + - reduction + - reduction + - reduction + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_fn: + kind: binary + fn_name: add + operands: + - !ScalarExpression + scalar_arg: O + - !ScalarExpression + scalar_fn: + kind: binary + fn_name: mul + operands: + - !ScalarExpression + scalar_fn: + kind: binary + fn_name: sub + operands: + - !ScalarExpression + scalar_fn: + kind: type + fn_name: cast_signed + type_var: U + operands: + - !ScalarExpression + scalar_arg: I + - !ScalarExpression + scalar_fn: + kind: type + fn_name: cast_signed + type_var: U + operands: + - !ScalarExpression 
+                      scalar_arg: IZp
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: K
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: KZp
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv_1d_nwc_wc
   cpp_class_name: DepthwiseConv1DNwcWcOp
@@ -4441,3 +4580,4 @@
         scalar_const: '2.3283063999999999E-10 : f64'
       - !ScalarExpression
         scalar_arg: min
+
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -99,47 +99,40 @@ }
 // Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
-static SmallVector<Value>
-inferDynamicDimsForConv(Location loc, Value input, Value weight,
-                        ShapedType resultTy, DenseI64ArrayAttr padAttr,
-                        DenseI64ArrayAttr strideAttr,
-                        DenseI64ArrayAttr dilationAttr, int64_t weightHDim,
-                        int64_t weightWDim, OpBuilder &rewriter) {
+static SmallVector<Value> inferDynamicDimsForConv(
+    Location loc, Value input, Value weight, ShapedType resultTy,
+    ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
+    ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
+    ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
   ShapedType inputTy = input.getType().cast<ShapedType>();
   Type inputETy = inputTy.getElementType();
   int64_t inputRank = inputTy.getRank();
-  int64_t heightDim = 1;
-  int64_t weightDim = 2;
   SmallVector<Value> dynDims;
   dynDims.resize(resultTy.getRank());
-  for (int i = 0; i < inputRank; i++) {
-    if (inputTy.isDynamicDim(i) && i != heightDim && i != weightDim)
-      dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
-  }
-  // Dynamic input height
-  if (inputTy.isDynamicDim(heightDim)) {
-    Value initHDim =
-        rewriter.create<tensor::DimOp>(loc, input, heightDim).getResult();
-    Value kernelHDim =
-        rewriter.create<tensor::DimOp>(loc, weight, weightHDim).getResult();
-    // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
-    dynDims[heightDim] =
-        getConvOutputDim(loc, initHDim, padAttr[0], padAttr[1], kernelHDim,
-                         strideAttr[0], dilationAttr[0], inputETy, rewriter);
+  for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
+    int64_t inputDim = inputSizeDims[i];
+    int64_t kernelDim = kernelSizeDims[i];
+    if (inputTy.isDynamicDim(inputDim)) {
+      auto padTop = padAttr[i * 2];
+      auto padBottom = padAttr[i * 2 + 1];
+      auto stride = strideAttr[i];
+      auto dilation = dilationAttr[i];
+      Value initDynDim = rewriter.create<tensor::DimOp>(loc, input, inputDim);
+      Value kernelDynDim =
+          rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
+      // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
+      dynDims[inputDim] =
+          getConvOutputDim(loc, initDynDim, padTop, padBottom, kernelDynDim,
+                           stride, dilation, inputETy, rewriter);
+    }
   }
-  // Dynamic input weight
-  if (inputTy.isDynamicDim(weightDim)) {
-    Value initWDim =
-        rewriter.create<tensor::DimOp>(loc, input, weightDim).getResult();
-    Value kernelWDim =
-        rewriter.create<tensor::DimOp>(loc, weight, weightWDim).getResult();
-    // W = F(IW, pad_left, pad_right, dilation_x, KW, stride_x)
-    dynDims[weightDim] =
-        getConvOutputDim(loc, initWDim, padAttr[2], padAttr[3], kernelWDim,
-                         strideAttr[1], dilationAttr[1], inputETy, rewriter);
+  // Get the batch/channels dimensions.
+ for (int i = 0; i < inputRank; i++) { + if (inputTy.isDynamicDim(i) && !dynDims[i]) + dynDims[i] = rewriter.create(loc, input, i); } SmallVector filteredDims = condenseValues(dynDims); @@ -161,21 +154,23 @@ namespace { -class ConvConverter : public OpConversionPattern { +template +class ConvConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::Conv2DOp op, OpAdaptor adaptor, + matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const final { Location loc = op->getLoc(); Value input = op->getOperand(0); Value weight = op->getOperand(1); Value bias = op->getOperand(2); - ShapedType inputTy = input.getType().cast(); - ShapedType weightTy = weight.getType().cast(); - ShapedType biasTy = bias.getType().cast(); - ShapedType resultTy = op->getResult(0).getType().cast(); + ShapedType inputTy = input.getType().template cast(); + ShapedType weightTy = weight.getType().template cast(); + ShapedType biasTy = bias.getType().template cast(); + ShapedType resultTy = + op->getResult(0).getType().template cast(); Type inputETy = inputTy.getElementType(); Type resultETy = resultTy.getElementType(); @@ -183,7 +178,7 @@ DenseI64ArrayAttr padAttr = op.getPadAttr(); DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr(); DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr(); - bool isQuantized = op->hasAttr("quantization_info"); + bool isQuantized = op.getQuantizationInfo().has_value(); if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape()) return rewriter.notifyMatchFailure( @@ -193,17 +188,24 @@ return rewriter.notifyMatchFailure( op, "tosa.conv ops does not support unsigned integer input"); + llvm::SmallVector inputSizeDims; + llvm::SmallVector kernelSizeDims; + for (int i = 1; i < resultTy.getRank() - 1; i++) { + inputSizeDims.push_back(i); + kernelSizeDims.push_back(i); + } + SmallVector filteredDims = inferDynamicDimsForConv( - loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr, - /*weightHDim=*/1, /*weightWDim=*/2, rewriter); + loc, input, weight, resultTy, padAttr.asArrayRef(), + strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(), + inputSizeDims, kernelSizeDims, rewriter); auto weightShape = weightTy.getShape(); // Apply padding as necessary. Attribute zeroAttr = rewriter.getZeroAttr(inputETy); if (isQuantized) { - auto quantizationInfo = - op->getAttr("quantization_info").cast(); + auto quantizationInfo = *op.getQuantizationInfo(); int64_t iZp = quantizationInfo.getInputZp(); int64_t intMin = @@ -230,11 +232,15 @@ // convolution operation. // TODO(suderman): See if this can be efficiently folded - check whether // the input is used anywhere else, if not fold the constant. - SmallVector weightPerm{1, 2, 3, 0}; - SmallVector newWeightShape{weightShape[1], weightShape[2], - weightShape[3], weightShape[0]}; - auto weightPermAttr = DenseIntElementsAttr::get( - RankedTensorType::get({4}, rewriter.getI64Type()), weightPerm); + SmallVector weightPerm; + for (int i = 1; i < resultTy.getRank(); i++) + weightPerm.push_back(i); + weightPerm.push_back(0); + + SmallVector newWeightShape; + for (auto dim : weightPerm) + newWeightShape.push_back(weightShape[dim]); + auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm); Value weightPermValue = rewriter.create(loc, weightPermAttr); Type newWeightTy = @@ -256,16 +262,15 @@ ArrayRef dilation = dilationTosaAttr; // Create the convolution op. 
- auto strideAttr = DenseIntElementsAttr::get( - RankedTensorType::get({2}, rewriter.getI64Type()), stride); - auto dilationAttr = DenseIntElementsAttr::get( - RankedTensorType::get({2}, rewriter.getI64Type()), dilation); + auto strideAttr = rewriter.getI64TensorAttr(stride); + auto dilationAttr = rewriter.getI64TensorAttr(dilation); // Create maps for the bias broadcasting SmallVector indexingMaps; indexingMaps.push_back(AffineMap::get( /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0, - {rewriter.getAffineDimExpr(3)}, rewriter.getContext())); + {rewriter.getAffineDimExpr(resultTy.getRank() - 1)}, + rewriter.getContext())); indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank())); indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank())); @@ -273,8 +278,7 @@ loc, resultTy.getShape(), resultETy, filteredDims); if (isQuantized) { - auto quantizationInfo = - op->getAttr("quantization_info").cast(); + auto quantizationInfo = *op.getQuantizationInfo(); auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()); auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()); @@ -282,7 +286,7 @@ auto kZpVal = rewriter.create(loc, kZp); Value conv = rewriter - .create( + .create( loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal}, ValueRange{zeroTensor}, strideAttr, dilationAttr) ->getResult(0); @@ -304,7 +308,7 @@ } Value conv = rewriter - .create( + .create( loc, resultTy, ValueRange{input, weight}, ValueRange{zeroTensor}, strideAttr, dilationAttr) ->getResult(0); @@ -358,8 +362,10 @@ // Compute output dynamic dims SmallVector filteredDims = inferDynamicDimsForConv( - loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr, - 0, 1, rewriter); + loc, input, weight, resultTy, padAttr.asArrayRef(), + strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(), + /*inputSizeDims=*/{1, 2}, + /*kernelSizeDims=*/{0, 1}, rewriter); bool isQuantized = op->hasAttr("quantization_info"); IntegerAttr iZp; @@ -408,11 +414,8 @@ ArrayRef dilation = dilationTosaAttr; // Create the convolution op. - auto strideAttr = DenseIntElementsAttr::get( - RankedTensorType::get({2}, rewriter.getI64Type()), stride); - auto dilationAttr = DenseIntElementsAttr::get( - RankedTensorType::get({2}, rewriter.getI64Type()), dilation); - + auto strideAttr = rewriter.getI64TensorAttr(stride); + auto dilationAttr = rewriter.getI64TensorAttr(dilation); ShapedType linalgConvTy = RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2], weightShape[2], weightShape[3]}, @@ -610,8 +613,7 @@ .result(); SmallVector permutation{1, 0}; - auto permutationAttr = DenseIntElementsAttr::get( - RankedTensorType::get({2}, rewriter.getI64Type()), permutation); + auto permutationAttr = rewriter.getI64TensorAttr(permutation); Value permutationValue = rewriter.create(loc, permutationAttr); @@ -966,7 +968,8 @@ RewritePatternSet *patterns) { patterns->add< // clang-format off - ConvConverter, + ConvConverter, + ConvConverter, DepthwiseConvConverter, MatMulConverter, MaxPool2dConverter, diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp @@ -51,6 +51,7 @@ // Not every TOSA op can be legalized to linalg. 
target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -150,6 +150,7 @@ TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed( U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp)) + @linalg_structured_op def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K), B=TensorDef(T2, Batch, S.K, S.N), @@ -162,8 +163,9 @@ """ domain(D.b, D.m, D.n, D.k) implements(ContractionOpInterface) - C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k] * TypeFn.cast_signed( - U, B[D.b, D.k, D.n])) + C[D.m, D.n] += TypeFn.cast_signed( + U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n])) + @linalg_structured_op def matvec(A=TensorDef(T1, S.M, S.N), @@ -283,6 +285,7 @@ U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed( U, K[D.kw, D.c, D.f]) + @linalg_structured_op def conv_1d_ncw_fcw(I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW), K=TensorDef(T2, S.F, S.C, S.KW), @@ -304,6 +307,7 @@ U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed( U, K[D.f, D.c, D.kw]) + @linalg_structured_op def conv_2d_nhwc_hwcf(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), @@ -400,13 +404,15 @@ U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) + @linalg_structured_op -def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW), - O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): +def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs 2-D grouped convolution. Layout: @@ -420,7 +426,8 @@ domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw) O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed( U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + - D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw]) + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw]) + @linalg_structured_op def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, @@ -449,6 +456,43 @@ U, K[D.kd, D.kh, D.kw, D.c, D.f]) +@linalg_structured_op +def conv_3d_ndhwc_dhwcf_q(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, + S.N, + S.OD, + S.OH, + S.OW, + S.F, + output=True), + strides=IndexAttrDef(S.SD, + S.SH, + S.SW, + default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, + S.DH, + S.DW, + default=[1, 1, 1])): + """Performs 3-D convolution with zero point offsets. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) + O[D.n, D.od, D.oh, D.ow, D.f] += (TypeFn.cast_signed( + U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, D.c]) - TypeFn.cast_signed(U, IZp)) * ( + TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f]) - + TypeFn.cast_signed(U, KZp)) + + @linalg_structured_op def depthwise_conv_1d_nwc_wc(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC), @@ -517,7 +561,8 @@ @linalg_structured_op -def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S.DH, +def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC, + S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), K=TensorDef(T2, S.IC, S.KH, S.KW), O=TensorDef(U, @@ -539,7 +584,8 @@ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed( - U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw]) + U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw]) @linalg_structured_op @@ -642,7 +688,11 @@ S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC), - O=TensorDef(U, S.N, S.OD, S.OH, S.OW, + O=TensorDef(U, + S.N, + S.OD, + S.OH, + S.OW, output=True), strides=IndexAttrDef(S.SD, S.SH, @@ -667,12 +717,17 @@ @linalg_structured_op -def depthwise_conv_3d_ndhwc_dhwcm(I=TensorDef(T1, - S.N, S.OD * S.SD + S.KD * S.DD, +def depthwise_conv_3d_ndhwc_dhwcm(I=TensorDef(T1, S.N, + S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM), - O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.CM, + O=TensorDef(U, + S.N, + S.OD, + S.OH, + S.OW, + S.CM, output=True), strides=IndexAttrDef(S.SD, S.SH, diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -603,3 +603,55 @@ %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = array, dilation = array, stride = array} : (tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32> return } + +// ----- + +// CHECK-LABEL: @conv3d_f32 +func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () { + // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]> + // CHECK-DAG: %[[TRANSPOSE:.+]] = "tosa.transpose"(%arg1, %[[PERMS]]) + // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() + // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0 + // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY]] : tensor<1x47x45x43x28xf32>) + // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() + // CHECK-DAG: %[[CONV3D:.+]] = linalg.conv_3d_ndhwc_dhwcf + // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} + // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x28xf32>) + // CHECK-SAME: outs(%[[FILL]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32> + // CHECK: %[[GENERIC:.+]] = linalg.generic + // CHECK-SAME: {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} + // CHECK-SAME: ins(%arg2, %[[CONV3D]] : tensor<28xf32>, 
tensor<1x47x45x43x28xf32>)
+  // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x47x45x43x28xf32>) {
+  // CHECK: ^bb0(%[[A1:.+]]: f32, %[[A2:.+]]: f32, %{{.+}}: f32):
+  // CHECK:   %[[ADD:.+]] = arith.addf %[[A1]], %[[A2]] : f32
+  // CHECK:   linalg.yield %[[ADD]]
+  %0 = "tosa.conv3d"(%input, %weights, %bias) {pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<28xf32>) -> tensor<1x47x45x43x28xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv3d_i8
+func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
+  // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
+  // CHECK-DAG: %[[TRANSPOSE:.+]] = "tosa.transpose"(%arg1, %[[PERMS]])
+  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
+  // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0
+  // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : i32) outs(%[[EMPTY]] : tensor<1x47x45x43x28xi32>)
+  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
+  // CHECK-DAG: %[[IZP:.+]] = arith.constant -128 : i32
+  // CHECK-DAG: %[[FZP:.+]] = arith.constant 42 : i32
+  // CHECK-DAG: %[[CONV3D:.+]] = linalg.conv_3d_ndhwc_dhwcf_q
+  // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+  // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32)
+  // CHECK-SAME: outs(%[[FILL]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
+  // CHECK: %[[GENERIC:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg2, %[[CONV3D]] : tensor<28xi32>, tensor<1x47x45x43x28xi32>)
+  // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x47x45x43x28xi32>) {
+  // CHECK: ^bb0(%[[A1:.+]]: i32, %[[A2:.+]]: i32, %{{.+}}: i32):
+  // CHECK:   %[[ADD:.+]] = arith.addi %[[A1]], %[[A2]] : i32
+  // CHECK:   linalg.yield %[[ADD]]
+  %0 = "tosa.conv3d"(%input, %weights, %bias) {pad = array<i64: 0, 0, 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>) -> tensor<1x47x45x43x28xi32>
+  return
+}
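
For reference, below is a minimal hand-written sketch of the new quantized op in isolation, independent of the TOSA lowering tested above. The function name, tensor shapes, and zero-point values are illustrative assumptions, not taken from the patch; only the op name, the two scalar i32 zero-point operands, and the strides/dilations attributes follow the definitions added here.

// Standalone sketch (assumed shapes): linalg.conv_3d_ndhwc_dhwcf_q accumulates
// (cast(I) - IZp) * (cast(K) - KZp) into the i32 accumulator tensor.
func.func @conv3d_q_example(%input: tensor<1x8x8x8x4xi8>,
                            %filter: tensor<3x3x3x4x16xi8>,
                            %acc: tensor<1x6x6x6x16xi32>) -> tensor<1x6x6x6x16xi32> {
  // Zero points are passed as scalar i32 operands, matching the op definition.
  %izp = arith.constant -128 : i32
  %kzp = arith.constant 42 : i32
  %0 = linalg.conv_3d_ndhwc_dhwcf_q
         {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
         ins(%input, %filter, %izp, %kzp
             : tensor<1x8x8x8x4xi8>, tensor<3x3x3x4x16xi8>, i32, i32)
         outs(%acc : tensor<1x6x6x6x16xi32>) -> tensor<1x6x6x6x16xi32>
  return %0 : tensor<1x6x6x6x16xi32>
}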