diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -44,6 +44,20 @@
 static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp);
 
+/// Return the given vector type if `elementType` is valid.
+static FailureOr<VectorType> getVectorType(ArrayRef<int64_t> shape,
+                                           Type elementType) {
+  if (!VectorType::isValidElementType(elementType)) {
+    return failure();
+  }
+  return VectorType::get(shape, elementType);
+}
+
+/// Cast the given type to a vector type if its element type is valid.
+static FailureOr<VectorType> getVectorType(ShapedType type) {
+  return getVectorType(type.getShape(), type.getElementType());
+}
+
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
 template <typename OpType>
@@ -431,8 +445,7 @@
       vector::BroadcastableToResult::Success)
     return value;
   Location loc = b.getInsertionPoint()->getLoc();
-  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType,
-                                             value);
+  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
 }
 
 /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
@@ -885,18 +898,20 @@
     }
 
     auto readType =
-        VectorType::get(readVecShape, getElementTypeOrSelf(opOperand->get()));
+        getVectorType(readVecShape, getElementTypeOrSelf(opOperand->get()));
+    if (!succeeded(readType))
+      return failure();
     SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
 
     Operation *read = rewriter.create<vector::TransferReadOp>(
-        loc, readType, opOperand->get(), indices, readMap);
+        loc, *readType, opOperand->get(), indices, readMap);
     read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
     Value readValue = read->getResult(0);
 
     // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
     // will be in-bounds.
     if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
-      SmallVector<bool> inBounds(readType.getRank(), true);
+      SmallVector<bool> inBounds(readType->getRank(), true);
       cast<vector::TransferReadOp>(maskOp.getMaskableOp())
           .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
     }
@@ -1135,21 +1150,22 @@
   if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
     return failure();
 
-  auto readType =
-      VectorType::get(srcType.getShape(), getElementTypeOrSelf(srcType));
-  auto writeType =
-      VectorType::get(dstType.getShape(), getElementTypeOrSelf(dstType));
+  auto readType = getVectorType(srcType);
+  auto writeType = getVectorType(dstType);
+  if (!(succeeded(readType) && succeeded(writeType)))
+    return failure();
 
   Location loc = copyOp->getLoc();
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   SmallVector<Value> indices(srcType.getRank(), zero);
 
   Value readValue = rewriter.create<vector::TransferReadOp>(
-      loc, readType, copyOp.getSource(), indices,
+      loc, *readType, copyOp.getSource(), indices,
       rewriter.getMultiDimIdentityMap(srcType.getRank()));
   if (readValue.getType().cast<VectorType>().getRank() == 0) {
     readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
-    readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
+    readValue =
+        rewriter.create<vector::BroadcastOp>(loc, *writeType, readValue);
   }
   Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
       loc, readValue, copyOp.getTarget(), indices,
@@ -1200,6 +1216,10 @@
     auto sourceType = padOp.getSourceType();
     auto resultType = padOp.getResultType();
 
+    // Complex is not a valid vector element type.
+    if (!VectorType::isValidElementType(sourceType.getElementType()))
+      return failure();
+
     // Copy cannot be vectorized if pad value is non-constant and source shape
    // is dynamic. In case of a dynamic source shape, padding must be appended
    // by TransferReadOp, but TransferReadOp supports only constant padding.
@@ -1548,15 +1568,17 @@
     if (insertOp.getDest() == padOp.getResult())
       return failure();
 
-    auto vecType = VectorType::get(padOp.getType().getShape(),
-                                   padOp.getType().getElementType());
-    unsigned vecRank = vecType.getRank();
+    auto vecType = getVectorType(padOp.getType());
+    if (!succeeded(vecType))
+      return failure();
+    unsigned vecRank = vecType->getRank();
     unsigned tensorRank = insertOp.getType().getRank();
 
     // Check if sizes match: Insert the entire tensor into most minor dims.
     // (No permutations allowed.)
     SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
-    expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
+    expectedSizes.append(vecType->getShape().begin(),
+                         vecType->getShape().end());
     if (!llvm::all_of(
             llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
               return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
@@ -1572,7 +1594,7 @@
     SmallVector<Value> readIndices(
         vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
     auto read = rewriter.create<vector::TransferReadOp>(
-        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
+        padOp.getLoc(), *vecType, padOp.getSource(), readIndices, padValue);
 
     // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
     // specified offsets. Write is fully in-bounds because a InsertSliceOp's
@@ -1880,7 +1902,7 @@
     auto rhsRank = rhsShapedType.getRank();
     switch (oper) {
     case Conv:
-      if (rhsRank != 2 && rhsRank!= 3)
+      if (rhsRank != 2 && rhsRank != 3)
        return;
      break;
    case Pool:
@@ -1982,22 +2004,24 @@
     Type lhsEltType = lhsShapedType.getElementType();
     Type rhsEltType = rhsShapedType.getElementType();
     Type resEltType = resShapedType.getElementType();
-    auto lhsType = VectorType::get(lhsShape, lhsEltType);
-    auto rhsType = VectorType::get(rhsShape, rhsEltType);
-    auto resType = VectorType::get(resShape, resEltType);
+    auto lhsType = getVectorType(lhsShape, lhsEltType);
+    auto rhsType = getVectorType(rhsShape, rhsEltType);
+    auto resType = getVectorType(resShape, resEltType);
+    if (!(succeeded(lhsType) && succeeded(rhsType) && succeeded(resType)))
+      return failure();
 
     // Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0,
     // 0].
     Value lhs = rewriter.create<vector::TransferReadOp>(
-        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+        loc, *lhsType, lhsShaped, ValueRange{zero, zero, zero});
     // Read rhs slice of size {kw, c, f} @ [0, 0, 0].
     // This is needed only for Conv.
     Value rhs = nullptr;
     if (oper == Conv)
       rhs = rewriter.create<vector::TransferReadOp>(
-          loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
+          loc, *rhsType, rhsShaped, ValueRange{zero, zero, zero});
     // Read res slice of size {n, w, f} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
-        loc, resType, resShaped, ValueRange{zero, zero, zero});
+        loc, *resType, resShaped, ValueRange{zero, zero, zero});
 
     // The base vectorization case is input: {n,w,c}, weight: {kw,c,f}, output:
     // {n,w,f}. To reuse the base pattern vectorization case, we do pre
@@ -2164,26 +2188,28 @@
     Type lhsEltType = lhsShapedType.getElementType();
     Type rhsEltType = rhsShapedType.getElementType();
     Type resEltType = resShapedType.getElementType();
-    VectorType lhsType = VectorType::get(
+    auto lhsType = getVectorType(
         {nSize,
          // iw = ow * sw + kw * dw - 1
         //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
         ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
         cSize},
        lhsEltType);
-    VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
-    VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
+    auto rhsType = getVectorType({kwSize, cSize}, rhsEltType);
+    auto resType = getVectorType({nSize, wSize, cSize}, resEltType);
+    if (!(succeeded(lhsType) && succeeded(rhsType) && succeeded(resType)))
+      return failure();
 
     // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
     // 0].
     Value lhs = rewriter.create<vector::TransferReadOp>(
-        loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+        loc, *lhsType, lhsShaped, ValueRange{zero, zero, zero});
     // Read rhs slice of size {kw, c} @ [0, 0].
-    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
-                                                        ValueRange{zero, zero});
+    Value rhs = rewriter.create<vector::TransferReadOp>(
+        loc, *rhsType, rhsShaped, ValueRange{zero, zero});
     // Read res slice of size {n, w, c} @ [0, 0, 0].
     Value res = rewriter.create<vector::TransferReadOp>(
-        loc, resType, resShaped, ValueRange{zero, zero, zero});
+        loc, *resType, resShaped, ValueRange{zero, zero, zero});
 
     //===------------------------------------------------------------------===//
     // Begin vector-only rewrite part
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
--- a/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution.mlir
@@ -850,3 +850,18 @@
 // CHECK: %[[V7:.+]] = arith.addf %[[V5]], %[[V6]] : vector<4x3x2xf32>
 // CHECK: %[[V8:.+]] = vector.transpose %[[V7]], [0, 2, 1] : vector<4x3x2xf32> to vector<4x2x3xf32>
 // CHECK: vector.transfer_write %[[V8:.+]], %[[OUTPUT]][%[[Vc0]], %[[Vc0]], %[[Vc0]]] {in_bounds = [true, true, true]} : vector<4x2x3xf32>, memref<4x2x3xf32>
+
+
+// -----
+
+func.func @pooling_ncw_sum_memref_complex(%input: memref<4x2x5xcomplex<f32>>,
+    %filter: memref<2xcomplex<f32>>, %output: memref<4x2x3xcomplex<f32>>) {
+  linalg.pooling_ncw_sum
+    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+    ins(%input, %filter : memref<4x2x5xcomplex<f32>>, memref<2xcomplex<f32>>)
+    outs(%output : memref<4x2x3xcomplex<f32>>)
+  return
+}
+
+// Regression test: just check that this lowers successfully.
+// CHECK-LABEL: @pooling_ncw_sum_memref_complex
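
A note for readers on the `FailureOr<VectorType>` idiom this patch introduces: the wrapped value only exists on the success path, which is why every converted call site first checks `succeeded(...)` and only then dereferences with `*`. Below is a minimal, self-contained sketch of that check-then-dereference pattern; `halveEvenDim` and `example` are hypothetical names used purely for illustration, not part of the patch, and the include assumes `FailureOr` lives in `mlir/Support/LogicalResult.h` as it does in the tree this patch targets.

```cpp
// Minimal sketch of the FailureOr<T> pattern used by the new getVectorType
// helper and its call sites. halveEvenDim/example are hypothetical names.
#include "mlir/Support/LogicalResult.h"

#include <cstdint>

using namespace mlir;

// Like getVectorType: produce a value on success, failure() otherwise.
static FailureOr<int64_t> halveEvenDim(int64_t dim) {
  if (dim % 2 != 0)
    return failure(); // The failure path carries no value.
  return dim / 2;     // The value is implicitly wrapped in a success state.
}

static LogicalResult example() {
  FailureOr<int64_t> half = halveEvenDim(6);
  // Same style as the patch: bail out before touching the value.
  if (!succeeded(half))
    return failure();
  int64_t quarter = *half / 2; // Dereference only after the check.
  (void)quarter;
  return success();
}
```

`failed(half)` would be an equivalent spelling of `!succeeded(half)`; the patch uses the `succeeded(...)` form throughout, so the single-type checks read the same as the multi-type conjunctions such as `if (!(succeeded(lhsType) && succeeded(rhsType) && succeeded(resType)))`.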