diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -95,7 +95,6 @@ DeclareOpInterfaceMethods ]>, Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc, - Variadic>:$masks, ArrayAttr:$indexing_maps, Vector_IteratorTypeArrayAttr:$iterator_types, DefaultValuedAttr, vector<8x16x7x5xf32> into vector<8x15x5xf32> - // 4D vector contraction with two contracting dimensions and optional - // vector mask arguments. - %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1> - %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1> - - %5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask - : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> - // Vector contraction with mixed typed. lhs/rhs have different element // types than accumulator/result. - %6 = vector.contract #contraction_trait %0, %1, %2 + %5 = vector.contract #contraction_trait %0, %1, %2 : vector<10xf16>, vector<10xf16> into f32 // Contract with max (K = 0). @@ -219,7 +206,7 @@ iterator_types = ["reduction"], kind = #vector.kind } - %7 = vector.contract #contraction_trait %0, %1, %2 + %6 = vector.contract #contraction_trait %0, %1, %2 : vector<10xf32>, vector<10xf32> into f32 ``` }]; @@ -241,14 +228,6 @@ return getRhs().getType().cast(); } Type getAccType() { return getAcc().getType(); } - VectorType getLHSVectorMaskType() { - if (llvm::size(getMasks()) != 2) return VectorType(); - return getOperand(3).getType().cast(); - } - VectorType getRHSVectorMaskType() { - if (llvm::size(getMasks()) != 2) return VectorType(); - return getOperand(4).getType().cast(); - } Type getResultType() { return getResult().getType(); } SmallVector getTraitAttrNames(); static unsigned getAccOperandIndex() { return 2; } @@ -1173,7 +1152,7 @@ static StringRef getSizesAttrStrName() { return "sizes"; } static StringRef getStridesAttrStrName() { return "strides"; } VectorType getSourceVectorType() { - return getVector().getType().cast(); + return getVector().getType().cast(); } void getOffsets(SmallVectorImpl &results); bool hasNonUnitStrides() { diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -75,9 +75,6 @@ // Return true if the contract op can be convert to MMA matmul. static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, bool useNvGpu) { - if (!contract.getMasks().empty()) - return false; - using MapList = ArrayRef>; auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; AffineExpr m, n, k; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -699,8 +699,6 @@ auto dictAttr = DictionaryAttr::get(getContext(), attrs); p << " " << dictAttr << " " << getLhs() << ", "; p << getRhs() << ", " << getAcc(); - if (getMasks().size() == 2) - p << ", " << getMasks(); p.printOptionalAttrDict((*this)->getAttrs(), attrNames); p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into " @@ -868,18 +866,6 @@ contractingDimMap, batchDimMap))) return failure(); - // Verify that either two vector masks are set or none are set. - auto lhsMaskType = getLHSVectorMaskType(); - auto rhsMaskType = getRHSVectorMaskType(); - if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType)) - return emitOpError("invalid number of vector masks specified"); - if (lhsMaskType && rhsMaskType) { - // Verify mask rank == argument rank. - if (lhsMaskType.getShape().size() != lhsType.getShape().size() || - rhsMaskType.getShape().size() != rhsType.getShape().size()) - return emitOpError("invalid vector mask rank"); - } - // Verify supported combining kind. auto vectorType = resType.dyn_cast(); auto elementType = vectorType ? vectorType.getElementType() : resType; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -292,9 +292,6 @@ return failure(); if (oldAccType.getRank() < 2) return failure(); - // TODO: implement masks. - if (!contractOp.getMasks().empty()) - return failure(); if (oldAccType.getShape()[0] != 1) return failure(); // currently we support only dropping one dim but the pattern can be applied diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -628,10 +628,6 @@ if (maskableOp.isMasked()) return failure(); - // TODO: Remove native masks from contraction op? - if (!contractOp.getMasks().empty()) - return failure(); - if (failed(filter(contractOp))) return failure(); @@ -1462,9 +1458,6 @@ if (maskableOp.isMasked()) return failure(); - // TODO: Remove native masks from contraction op? - if (!op.getMasks().empty()) - return failure(); if (vectorTransformOptions.vectorContractLowering != vector::VectorContractLowering::Matmul) return failure(); @@ -1718,10 +1711,6 @@ /// otherwise supports any layout permutation of the matrix-multiply. LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite( vector::ContractionOp op, PatternRewriter &rewriter) const { - // TODO: Remove native masks from contraction op? - if (!op.getMasks().empty()) - return failure(); - if (vectorTransformOptions.vectorContractLowering != vector::VectorContractLowering::OuterProduct) return failure(); @@ -1768,10 +1757,6 @@ if (maskableOp.isMasked()) return failure(); - // TODO: Remove native masks from contraction op? - if (!op.getMasks().empty()) - return failure(); - if (failed(filter(op))) return failure(); @@ -1905,10 +1890,6 @@ if (maskableOp.isMasked()) return failure(); - // TODO: Remove native masks from contraction op? - if (!op.getMasks().empty()) - return failure(); - if (failed(filter(op))) return failure(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -351,20 +351,12 @@ SmallVector lhsOffets = applyPermutationMap(lhsPermutationMap, ArrayRef(offsets)); extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets); - // If there is a mask associated to lhs, extract it as well. - if (slicesOperands.size() > 3) - extractOperand(3, contractOp.getMasks()[0], lhsPermutationMap, - lhsOffets); // Extract the new rhs operand. AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1]; SmallVector rhsOffets = applyPermutationMap(rhsPermutationMap, ArrayRef(offsets)); extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets); - // If there is a mask associated to rhs, extract it as well. - if (slicesOperands.size() > 4) - extractOperand(4, contractOp.getMasks()[1], rhsPermutationMap, - rhsOffets); AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2]; SmallVector accOffets = diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -353,13 +353,6 @@ // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> %1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3 : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> - // Test contraction with optional vector mask arguments. - %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1> - %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1> - // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> - %2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask, - %rhs_mask - : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> // Test contraction with mixed type. // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> %3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3 diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -62,186 +62,6 @@ return %3: vector<4x4xf32> } -#contraction_accesses0 = [ - affine_map<(i, j, k) -> (i, k)>, - affine_map<(i, j, k) -> (k, j)>, - affine_map<(i, j, k) -> (i, j)> -] -#contraction_trait0 = { - indexing_maps = #contraction_accesses0, - iterator_types = ["parallel", "parallel", "reduction"] -} - -// CHECK-LABEL: func @contraction4x4_ijk - -// Reducing output vector [0, 0] - -// CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> - -// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[S1]], %[[S2]], %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - -// CHECK-NEXT: %[[S6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S8:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S9:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> - -// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[S6]], %[[S7]], %[[R1S00]], %[[S8]], %[[S9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - -// CHECK-NEXT: %[[S10:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S12:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S11:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S13:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[S10]], %[[S11]], %[[R2S00]], %[[S12]], %[[S13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - -// Reducing output vector [0, 2] -// CHECK-NEXT: %[[S14:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S17:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S15:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S18:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S16:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[S14]], %[[S15]], %[[S16]], %[[S17]], %[[S18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - -// CHECK-NEXT: %[[S19:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S21:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S20:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S22:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[S19]], %[[S20]], %[[R1S02]], %[[S21]], %[[S22]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - -// CHECK-NEXT: %[[S23:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S25:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S24:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S26:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[S23]], %[[S24]], %[[R2S02]], %[[S25]], %[[S26]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - -// Reducing output vector [2, 0] - -// CHECK-NEXT: %[[S27:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S30:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S28:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S31:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S29:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[S27]], %[[S28]], %[[S29]], %[[S30]], %[[S31]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - -// CHECK-NEXT: %[[S32:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S34:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S33:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S35:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[S32]], %[[S33]], %[[R1S20]], %[[S34]], %[[S35]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - -// CHECK-NEXT: %[[S36:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S38:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S37:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S39:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[S36]], %[[S37]], %[[R2S20]], %[[S38]], %[[S39]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - -// Reducing output vector [2, 2] - -// CHECK-NEXT: %[[S40:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S43:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S41:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S44:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S42:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[S40]], %[[S41]], %[[S42]], %[[S43]], %[[S44]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - -// CHECK-NEXT: %[[S45:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S47:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S46:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S48:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[S45]], %[[S46]], %[[R1S22]], %[[S47]], %[[S48]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - -// CHECK-NEXT: %[[S49:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S51:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S50:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S52:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[S49]], %[[S50]], %[[R2S22]], %[[S51]], %[[S52]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - -// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[R3S00]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> -// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[R3S02]], %[[VEC1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> -// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[R3S20]], %[[VEC2]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> -// CHECK-NEXT: %[[VEC4:.*]] = vector.insert_strided_slice %[[R3S22]], %[[VEC3]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> - -// CHECK-NEXT: return %[[VEC4]] : vector<4x4xf32> - -func.func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>, - %arg2 : vector<4x4xf32>, %arg3 : index) - -> (vector<4x4xf32>) { - %lhsm = vector.constant_mask [4, 6] : vector<4x6xi1> - %rhsm = vector.constant_mask [6, 4] : vector<6x4xi1> - %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2, %lhsm, %rhsm - : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32> - - return %0 : vector<4x4xf32> -} - -#contraction_accesses1 = [ - affine_map<(i, k, j) -> (i, k)>, - affine_map<(i, k, j) -> (k, j)>, - affine_map<(i, k, j) -> (i, j)> -] -#contraction_trait1 = { - indexing_maps = #contraction_accesses1, - iterator_types = ["parallel", "reduction", "parallel"] -} - -// CHECK-LABEL: func @contraction4x4_ikj - -// Reducing output vector [0, 0] - -// CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[S1]], %[[S2]], %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - -// Reducing output vector [0, 2] - -// CHECK-NEXT: %[[S6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S9:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S10:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[S6]], %[[S7]], %[[S8]], %[[S9]], %[[S10]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - -// Reducing output vector [2, 0] - -// CHECK-NEXT: %[[S11:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S14:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S12:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S15:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S13:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[S11]], %[[S12]], %[[S13]], %[[S14]], %[[S15]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - -// Reducing output vector [2, 2] - -// CHECK-NEXT: %[[S16:.*]] = vector.extract_strided_slice %arg0 {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S19:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S17:.*]] = vector.extract_strided_slice %arg1 {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[S20:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> -// CHECK-NEXT: %[[S18:.*]] = vector.extract_strided_slice %arg2 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[S16]], %[[S17]], %[[S18]], %[[S19]], %[[S20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> - -// CHECK-NEXT: %[[VEC0:.*]] = vector.insert_strided_slice %[[R1S00]], %{{.*}} {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> -// CHECK-NEXT: %[[VEC1:.*]] = vector.insert_strided_slice %[[R1S02]], %[[VEC0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> -// CHECK-NEXT: %[[VEC2:.*]] = vector.insert_strided_slice %[[R1S20]], %[[VEC1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> -// CHECK-NEXT: %[[VEC3:.*]] = vector.insert_strided_slice %[[R1S22]], %[[VEC2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> -// CHECK-NEXT: return %[[VEC3]] : vector<4x4xf32> - -func.func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>, - %arg2 : vector<4x4xf32>, %arg3 : index) - -> (vector<4x4xf32>) { - %lhsm = vector.constant_mask [4, 2] : vector<4x2xi1> - %rhsm = vector.constant_mask [2, 4] : vector<2x4xi1> - %0 = vector.contract #contraction_trait1 %arg0, %arg1, %arg2, %lhsm, %rhsm - : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32> - - return %0 : vector<4x4xf32> -} - // CHECK-LABEL: func @contraction4x4_ikj_xfer_read // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index @@ -260,10 +80,10 @@ // CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> // CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32> // CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32> @@ -271,6 +91,16 @@ // CHECK-NEXT: vector.transfer_write %[[R3]], %{{.*}}[%[[C2]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, memref<4x4xf32> // CHECK-NEXT: return +#contraction_accesses1 = [ + affine_map<(i, k, j) -> (i, k)>, + affine_map<(i, k, j) -> (k, j)>, + affine_map<(i, k, j) -> (i, j)> +] +#contraction_trait1 = { + indexing_maps = #contraction_accesses1, + iterator_types = ["parallel", "reduction", "parallel"] +} + func.func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>, %arg1 : memref<2x4xf32>, %arg2 : memref<4x4xf32>) { @@ -389,10 +219,10 @@ // CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> // CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> -// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map{{.*}}, #map{{.*}}, #map{{.*}}], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<2x2xf32>, tensor<4x4xf32> // CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[R1]], %[[VTW0]][%[[C0]], %[[C2]]] {in_bounds = [true, true]} : vector<2x2xf32>, tensor<4x4xf32>