diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -200,7 +200,10 @@ "ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes)>, OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc, "ArrayRef<ArrayRef<AffineExpr>>":$indexingExprs, - "ArrayRef<StringRef>":$iteratorTypes)> + "ArrayRef<StringRef>":$iteratorTypes)>, + OpBuilder<(ins "Value":$lhs, "Value":$rhs, "Value":$acc, + "ArrayAttr":$indexingMaps, "ArrayAttr":$iteratorTypes, + "CombiningKind":$kind)> ]; let extraClassDeclaration = [{ VectorType getLhsType() { diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -502,13 +502,20 @@ Value lhs, Value rhs, Value acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes) { + build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes, + ContractionOp::getDefaultKind()); +} + +void vector::ContractionOp::build(OpBuilder &builder, OperationState &result, + Value lhs, Value rhs, Value acc, + ArrayAttr indexingMaps, + ArrayAttr iteratorTypes, CombiningKind kind) { result.addOperands({lhs, rhs, acc}); result.addTypes(acc.getType()); result.addAttribute(getIndexingMapsAttrName(), indexingMaps); result.addAttribute(getIteratorTypesAttrName(), iteratorTypes); result.addAttribute(ContractionOp::getKindAttrName(), - CombiningKindAttr::get(ContractionOp::getDefaultKind(), - builder.getContext())); + CombiningKindAttr::get(kind, builder.getContext())); } ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -6,6 +6,7
@@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/Builders.h" @@ -220,6 +221,128 @@ } }; +/// Turns vector.contract on vector with leading 1 dimensions into +/// vector.extract followed by vector.contract on vector without leading +/// 1 dimensions. Also performs transpose of lhs and rhs operands if required +/// prior to extract. +struct CastAwayContractionLeadingOneDim + : public OpRewritePattern<vector::ContractionOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + VectorType oldAccType = contractOp.getAccType().dyn_cast<VectorType>(); + if (oldAccType == nullptr) + return failure(); + if (oldAccType.getRank() < 2) + return failure(); + // TODO: implement masks. + if (llvm::size(contractOp.masks()) != 0) + return failure(); + if (oldAccType.getShape()[0] != 1) + return failure(); + // Currently we only support dropping one dim, but the pattern can be + // applied greedily to drop more. + int64_t dropDim = 1; + + auto oldIndexingMaps = contractOp.getIndexingMaps(); + SmallVector<AffineMap> newIndexingMaps; + + auto oldIteratorTypes = contractOp.iterator_types(); + SmallVector<Attribute> newIteratorTypes; + + int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0); + + if (!isParallelIterator(oldIteratorTypes[dimToDrop])) + // Only parallel iterators can be dropped.
+ return failure(); + + for (const auto &it : llvm::enumerate(oldIteratorTypes)) { + int64_t currDim = it.index(); + if (currDim == dimToDrop) + continue; + newIteratorTypes.push_back(it.value()); + } + + SmallVector<Value> operands = {contractOp.lhs(), contractOp.rhs(), + contractOp.acc()}; + SmallVector<Value> newOperands; + + for (const auto &it : llvm::enumerate(oldIndexingMaps)) { + // Check if the dim to be dropped exists as a leading dim in the operand; + // if it does, then we use vector.extract to drop it. + bool validExtract = false; + SmallVector<AffineExpr> results; + auto map = it.value(); + int64_t originalZeroDim = it.value().getDimPosition(0); + if (originalZeroDim != dimToDrop) { + // There are two reasons to be in this path: 1. We need to + // transpose the operand to make the dim to be dropped + // leading. 2. The dim to be dropped does not exist and in + // that case we don't want to add a unit transpose, but we must + // check all the indices to make sure this is the case. + bool transposeNeeded = false; + SmallVector<int64_t> perm; + SmallVector<AffineExpr> transposeResults; + + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { + int64_t currDim = map.getDimPosition(i); + if (currDim == dimToDrop) { + transposeNeeded = true; + perm.insert(perm.begin(), i); + auto targetExpr = rewriter.getAffineDimExpr(currDim); + transposeResults.insert(transposeResults.begin(), targetExpr); + } else { + perm.push_back(i); + auto targetExpr = rewriter.getAffineDimExpr(currDim); + transposeResults.push_back(targetExpr); + } + } + // Do the transpose now if needed so that we can drop the + // correct dim using extract later. + if (transposeNeeded) { + map = AffineMap::get(map.getNumDims(), 0, transposeResults, + contractOp.getContext()); + operands[it.index()] = rewriter.create<vector::TransposeOp>( + contractOp.getLoc(), operands[it.index()], perm); + } + } + // We have taken care to have the dim to be dropped be + // the leading dim.
If it's still not leading, that means it + // does not exist in this operand and hence we do not need + // an extract. + if (map.getDimPosition(0) == dimToDrop) + validExtract = true; + + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { + int64_t currDim = map.getDimPosition(i); + if (currDim == dimToDrop) + // This is the dim we are dropping. + continue; + auto targetExpr = rewriter.getAffineDimExpr( + currDim < dimToDrop ? currDim : currDim - 1); + results.push_back(targetExpr); + } + newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results, + contractOp.getContext())); + // Extract if it's a valid extraction, otherwise use the operand + // without extraction. + newOperands.push_back(validExtract + ? rewriter.create<vector::ExtractOp>( + contractOp.getLoc(), operands[it.index()], + splatZero(dropDim)) + : operands[it.index()]); + } + auto newContractOp = rewriter.create<vector::ContractionOp>( + contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2], + rewriter.getAffineMapArrayAttr(newIndexingMaps), + rewriter.getArrayAttr(newIteratorTypes), contractOp.kind()); + rewriter.replaceOpWithNewOp<vector::BroadcastOp>( + contractOp, contractOp->getResultTypes()[0], newContractOp); + return success(); + } +}; + class CastAwayElementwiseLeadingOneDim : public RewritePattern { public: CastAwayElementwiseLeadingOneDim(MLIRContext *context) @@ -260,10 +383,11 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( RewritePatternSet &patterns) { - patterns.add<CastAwayExtractStridedSliceLeadingOneDim, - CastAwayInsertStridedSliceLeadingOneDim, - CastAwayTransferReadLeadingOneDim, - CastAwayTransferWriteLeadingOneDim, - CastAwayElementwiseLeadingOneDim>(patterns.getContext()); + patterns + .add<CastAwayExtractStridedSliceLeadingOneDim, + CastAwayInsertStridedSliceLeadingOneDim, + CastAwayTransferReadLeadingOneDim, + CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim, + CastAwayContractionLeadingOneDim>(patterns.getContext()); populateShapeCastFoldingPatterns(patterns); } diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir @@ -0,0 +1,267 @@ +// RUN: mlir-opt %s -test-vector-to-vector-lowering -split-input-file | FileCheck %s + +// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +//
CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: cast_away_contraction_leading_one_dims +// CHECK-NEXT: %[[R0:.+]] = vector.extract %{{.*}}[0] : vector<1x16x8xf32> +// CHECK-NEXT: %[[R1:.+]] = vector.extract %{{.*}}[0] : vector<1x8x16xf32> +// CHECK-NEXT: %[[R2:.+]] = vector.extract %{{.*}}[0] : vector<1x16x16xf32> +// CHECK-NEXT: %[[R3:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} +// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32> +// CHECK-NEXT: %[[R4:.+]] = vector.broadcast %[[R3]] : vector<16x16xf32> to vector<1x16x16xf32> +// CHECK-NEXT: return %[[R4]] : vector<1x16x16xf32> + +#contraction_accesses0 = [ + affine_map<(l, i, j, k) -> (l, i, k)>, + affine_map<(l, i, j, k) -> (l, k, j)>, + affine_map<(l, i, j, k) -> (l, i, j)> +] +#contraction_trait0 = { + indexing_maps = #contraction_accesses0, + iterator_types = ["parallel", "parallel", "parallel", "reduction"] +} + +func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> { + %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> + return %0: vector<1x16x16xf32> +} + +// ----- +// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1) -> (d1)> +// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1) -> (d1, d0)> +// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0)> + +// CHECK-LABEL: cast_away_contraction_leading_one_dims_transposeneeded +// CHECK-NEXT: %[[R0:.+]] = vector.extract %{{.*}}[0] : vector<1x8x16xf32> +// CHECK-NEXT: %[[R1:.+]] = vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32> +// CHECK-NEXT: %[[R2:.+]] = vector.extract %{{.*}}[0, 0] : vector<1x1x16xf32> +// CHECK-NEXT:
%[[R3:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], +// CHECK-SAME: iterator_types = ["parallel", "reduction"], kind = #vector.kind<mul>} +// CHECK-SAME: %[[R1]], %[[R0]], %[[R2]] : vector<8xf32>, vector<8x16xf32> into vector<16xf32> +// CHECK-NEXT: %[[R4:.+]] = vector.broadcast %[[R3]] : vector<16xf32> to vector<1x16xf32> +// CHECK-NEXT: %[[R5:.+]] = vector.broadcast %[[R4]] : vector<1x16xf32> to vector<1x1x16xf32> +// CHECK-NEXT: return %[[R5]] : vector<1x1x16xf32> + +#contraction_accesses1 = [ + affine_map<(l, i, j, k) -> (i, l, k)>, + affine_map<(l, i, j, k) -> (l, k, j)>, + affine_map<(l, i, j, k) -> (l, i, j)> +] +#contraction_trait1 = { + indexing_maps = #contraction_accesses1, + iterator_types = ["parallel", "parallel", "parallel", "reduction"], + kind = #vector.kind<mul> +} + +func @cast_away_contraction_leading_one_dims_transposeneeded(%arg0: vector<1x1x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x1x16xf32>) -> vector<1x1x16xf32> { + %0 = vector.contract #contraction_trait1 %arg0, %arg1, %arg2 : vector<1x1x8xf32>, vector<1x8x16xf32> into vector<1x1x16xf32> + return %0: vector<1x1x16xf32> +} + +// ----- +// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: cast_away_contraction_leading_one_dims_transposeneeded2 +// CHECK-NEXT: %[[R0:.+]] = vector.transpose %{{.*}}, [1, 0, 2] : vector<8x1x16xf32> to vector<1x8x16xf32> +// CHECK-NEXT: %[[R1:.+]] = vector.extract %[[R0]][0] : vector<1x8x16xf32> +// CHECK-NEXT: %[[R2:.+]] = vector.transpose %{{.*}}, [2, 0, 1] : vector<2x8x1xf32> to vector<1x2x8xf32> +// CHECK-NEXT: %[[R3:.+]] = vector.extract %[[R2]][0] : vector<1x2x8xf32> +// CHECK-NEXT: %[[R4:.+]] = vector.extract %{{.*}}[0] : vector<1x2x16xf32> +// CHECK-NEXT: %[[R5:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], +// CHECK-SAME:
iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} +// CHECK-SAME: %[[R1]], %[[R3]], %[[R4]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32> +// CHECK-NEXT: %[[R6:.+]] = vector.broadcast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32> +// CHECK-NEXT: return %[[R6]] : vector<1x2x16xf32> + +#contraction_accesses2 = [ + affine_map<(l, i, j, k) -> (k, l, j)>, + affine_map<(l, i, j, k) -> (i, k, l)>, + affine_map<(l, i, j, k) -> (l, i, j)> +] +#contraction_trait2 = { + indexing_maps = #contraction_accesses2, + iterator_types = ["parallel", "parallel", "parallel", "reduction"] +} + + +func @cast_away_contraction_leading_one_dims_transposeneeded2(%arg0: vector<8x1x16xf32>, %arg1: vector<2x8x1xf32>, %arg2: vector<1x2x16xf32>) -> vector<1x2x16xf32> { + %0 = vector.contract #contraction_trait2 %arg0, %arg1, %arg2 : vector<8x1x16xf32>, vector<2x8x1xf32> into vector<1x2x16xf32> + return %0: vector<1x2x16xf32> +} + +// ----- +// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> + + +// CHECK-LABEL: cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4 +// CHECK-NEXT: %[[R0:.+]] = vector.extract %{{.*}}[0] : vector<1x8x1x16xf32> +// CHECK-NEXT: %[[R1:.+]] = vector.extract %{{.*}}[0] : vector<1x2x8x1xf32> +// CHECK-NEXT: %[[R2:.+]] = vector.transpose %[[R0]], [1, 0, 2] : vector<8x1x16xf32> to vector<1x8x16xf32> +// CHECK-NEXT: %[[R3:.+]] = vector.extract %[[R2]][0] : vector<1x8x16xf32> +// CHECK-NEXT: %[[R4:.+]] = vector.transpose %[[R1]], [2, 0, 1] : vector<2x8x1xf32> to vector<1x2x8xf32> +// CHECK-NEXT: %[[R5:.+]] = vector.extract %[[R4]][0] : vector<1x2x8xf32> +// CHECK-NEXT: %[[R6:.+]] = vector.extract %{{.*}}[0, 0] : vector<1x1x2x16xf32> +// CHECK-NEXT: %[[R7:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], +// CHECK-SAME: iterator_types = ["parallel",
"parallel", "reduction"], kind = #vector.kind<add>} +// CHECK-SAME: %[[R3]], %[[R5]], %[[R6]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32> +// CHECK-NEXT: %[[R8:.+]] = vector.broadcast %[[R7]] : vector<2x16xf32> to vector<1x2x16xf32> +// CHECK-NEXT: %[[R9:.+]] = vector.broadcast %[[R8]] : vector<1x2x16xf32> to vector<1x1x2x16xf32> +// CHECK-NEXT: return %[[R9]] : vector<1x1x2x16xf32> + +#contraction_accesses2 = [ + affine_map<(m, l, i, j, k) -> (m, k, l, j)>, + affine_map<(m, l, i, j, k) -> (m, i, k, l)>, + affine_map<(m, l, i, j, k) -> (m, l, i, j)> +] +#contraction_trait2 = { + indexing_maps = #contraction_accesses2, + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"] +} + + +func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4(%arg0: vector<1x8x1x16xf32>, %arg1: vector<1x2x8x1xf32>, %arg2: vector<1x1x2x16xf32>) -> vector<1x1x2x16xf32> { + %0 = vector.contract #contraction_trait2 %arg0, %arg1, %arg2 : vector<1x8x1x16xf32>, vector<1x2x8x1xf32> into vector<1x1x2x16xf32> + return %0: vector<1x1x2x16xf32> +} + +// ----- +// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctranspose +// CHECK-NEXT: %[[R0:.+]] = vector.transpose %{{.*}}, [2, 0, 1, 3] : vector<1x8x1x16xf32> to vector<1x1x8x16xf32> +// CHECK-NEXT: %[[R1:.+]] = vector.transpose %{{.*}}, [3, 0, 1, 2] : vector<1x2x8x1xf32> to vector<1x1x2x8xf32> +// CHECK-NEXT: %[[R2:.+]] = vector.extract %[[R0]][0, 0] : vector<1x1x8x16xf32> +// CHECK-NEXT: %[[R3:.+]] = vector.extract %[[R1]][0, 0] : vector<1x1x2x8xf32> +// CHECK-NEXT: %[[R4:.+]] = vector.extract %{{.*}}[0, 0] : vector<1x1x2x16xf32> +// CHECK-NEXT: %[[R5:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], +// CHECK-SAME: iterator_types =
["parallel", "parallel", "reduction"], kind = #vector.kind<add>} +// CHECK-SAME: %[[R2]], %[[R3]], %[[R4]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32> +// CHECK-NEXT: %[[R6:.+]] = vector.broadcast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32> +// CHECK-NEXT: %[[R7:.+]] = vector.broadcast %[[R6]] : vector<1x2x16xf32> to vector<1x1x2x16xf32> +// CHECK-NEXT: return %[[R7]] : vector<1x1x2x16xf32> + +#contraction_accesses3 = [ + affine_map<(m, l, i, j, k) -> (m, k, l, j)>, + affine_map<(m, l, i, j, k) -> (m, i, k, l)>, + affine_map<(m, l, i, j, k) -> (l, m, i, j)> +] +#contraction_trait3 = { + indexing_maps = #contraction_accesses3, + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction"] +} + +func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctranspose(%arg0: vector<1x8x1x16xf32>, %arg1: vector<1x2x8x1xf32>, %arg2: vector<1x1x2x16xf32>) -> vector<1x1x2x16xf32> { + %0 = vector.contract #contraction_trait3 %arg0, %arg1, %arg2 : vector<1x8x1x16xf32>, vector<1x2x8x1xf32> into vector<1x1x2x16xf32> + return %0: vector<1x1x2x16xf32> +} + +// ----- +// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims +func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> { + // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<1x8x8xf16> + // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16> + %0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x8xf16> to vector<1x1x8xf16> + // CHECK: %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16> + // CHECK: return %[[RET]] + return %0: vector<1x1x8xf16> +} + +// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims +func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16>
{ + // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<1x8xf16> + // CHECK: %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<1x8x8xf16> + // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16> + %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<1x8x8xf16> + // CHECK: %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16> + // CHECK: return %[[RET]] + return %0: vector<1x8x8xf16> +} + +// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element +// CHECK-SAME: %[[ARG0:.+]]: vector<1x1xf16>, %{{.+}}: vector<1x1x1xf16> +func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> { + // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<1x1xf16> + // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<1xf16> to vector<1x1x1xf16> + %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<1x1x1xf16> + // CHECK: return %[[B]] + return %0: vector<1x1x1xf16> +} + +// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims +func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>) -> vector<1x4xf16> { + // CHECK: %[[C0:.+]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16 + %f0 = arith.constant 0.
: f16 + // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16> + // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16> + %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16> + // CHECK: return %[[CAST]] + return %0: vector<1x4xf16> +} + +// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element +func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> { + %c0 = arith.constant 0 : index + %f0 = arith.constant 0. : f16 + // CHECK: vector.broadcast %{{.+}} : vector<1xf16> to vector<1x1xf16> + %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x1x1xf16>, vector<1x1xf16> + return %0: vector<1x1xf16> +} + +// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims +func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) { + // CHECK: %[[C0:.+]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<1x4xf16> + // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16> + + vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16> + return +} + +// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element +func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) { + %c0 = arith.constant 0 : index + // CHECK: vector.extract %{{.+}}[0] : vector<1x1xf16> + vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x1xf16>, memref<1x1x1x1xf16> + return +} + +// CHECK-LABEL: func
@cast_away_elementwise_leading_one_dims +func @cast_away_elementwise_leading_one_dims( + %arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>, + %arg3: vector<1x4xf32>, %arg4: i1) -> + (vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>) { + // CHECK: vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32> + // CHECK: vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32> + // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8xf32> + // CHECK: vector.broadcast %{{.*}} : vector<8xf32> to vector<1x1x8xf32> + %0 = arith.addf %arg0, %arg0 : vector<1x1x8xf32> + // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32> + // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32> + // CHECK: arith.cmpf ogt, %{{.*}}, %{{.*}} : vector<4xf32> + // CHECK: vector.broadcast %{{.*}} : vector<4xi1> to vector<1x4xi1> + %1 = arith.cmpf ogt, %arg2, %arg3 : vector<1x4xf32> + // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32> + // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32> + // CHECK: select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xf32> + // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32> + %2 = arith.select %1, %arg3, %arg2 : vector<1x4xi1>, vector<1x4xf32> + // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32> + // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32> + // CHECK: select %arg4, %{{.*}}, %{{.*}} : vector<4xf32> + // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32> + %3 = arith.select %arg4, %arg3, %arg2 : vector<1x4xf32> + return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32> +} + diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -419,106 +419,6 @@ return %r : tensor<4x4xf32> } -// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims -func @cast_away_extract_strided_slice_leading_one_dims(%arg0:
vector<1x8x8xf16>) -> vector<1x1x8xf16> { - // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<1x8x8xf16> - // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16> - %0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x8xf16> to vector<1x1x8xf16> - // CHECK: %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16> - // CHECK: return %[[RET]] - return %0: vector<1x1x8xf16> -} - -// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims -func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> { - // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<1x8xf16> - // CHECK: %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<1x8x8xf16> - // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16> - %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<1x8x8xf16> - // CHECK: %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16> - // CHECK: return %[[RET]] - return %0: vector<1x8x8xf16> -} - -// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element -// CHECK-SAME: %[[ARG0:.+]]: vector<1x1xf16>, %{{.+}}: vector<1x1x1xf16> -func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> { - // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<1x1xf16> - // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<1xf16> to vector<1x1x1xf16> - %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<1x1x1xf16> - // CHECK: return %[[B]] - return %0: vector<1x1x1xf16> -} - -// CHECK-LABEL: func
@cast_away_transfer_read_leading_one_dims -func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>) -> vector<1x4xf16> { - // CHECK: %[[C0:.+]] = arith.constant 0 : index - %c0 = arith.constant 0 : index - // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16 - %f0 = arith.constant 0. : f16 - // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16> - // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16> - %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16> - // CHECK: return %[[CAST]] - return %0: vector<1x4xf16> -} - -// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element -func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> { - %c0 = arith.constant 0 : index - %f0 = arith.constant 0. : f16 - // CHECK: vector.broadcast %{{.+}} : vector<1xf16> to vector<1x1xf16> - %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x1x1xf16>, vector<1x1xf16> - return %0: vector<1x1xf16> -} - -// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims -func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) { - // CHECK: %[[C0:.+]] = arith.constant 0 : index - %c0 = arith.constant 0 : index - // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<1x4xf16> - // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16> - - vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16> - return -} - -// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element -func @cast_away_transfer_write_leading_one_dims_one_element(%arg0:
memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) { - %c0 = arith.constant 0 : index - // CHECK: vector.extract %{{.+}}[0] : vector<1x1xf16> - vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x1xf16>, memref<1x1x1x1xf16> - return -} - -// CHECK-LABEL: func @cast_away_elementwise_leading_one_dims -func @cast_away_elementwise_leading_one_dims( - %arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>, - %arg3: vector<1x4xf32>, %arg4: i1) -> - (vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>) { - // CHECK: vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32> - // CHECK: vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32> - // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8xf32> - // CHECK: vector.broadcast %{{.*}} : vector<8xf32> to vector<1x1x8xf32> - %0 = arith.addf %arg0, %arg0 : vector<1x1x8xf32> - // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32> - // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32> - // CHECK: arith.cmpf ogt, %{{.*}}, %{{.*}} : vector<4xf32> - // CHECK: vector.broadcast %{{.*}} : vector<4xi1> to vector<1x4xi1> - %1 = arith.cmpf ogt, %arg2, %arg3 : vector<1x4xf32> - // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32> - // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32> - // CHECK: select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xf32> - // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32> - %2 = arith.select %1, %arg3, %arg2 : vector<1x4xi1>, vector<1x4xf32> - // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32> - // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32> - // CHECK: select %arg4, %12, %{{.*}} : vector<4xf32> - // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32> - %3 = arith.select %arg4, %arg3, %arg2 : vector<1x4xf32> - return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32> -} - // CHECK-LABEL: func @bubble_down_bitcast_in_extract // CHECK-SAME: %[[SRC:.+]]: vector<4xf32> func
@bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) {