Review notes for this patch (free-text preamble; ignored by git apply / patch):

Relaxes element-type verification of vector.contract so that lhs/rhs may use
an element type different from acc/result. The new ops.mlir case exercises
f16 x f16 into f32.

* VectorOps.td: drops the PredOpTraits that tied the lhs (operand 0) and rhs
  (operand 1) element types to the result type.
  The remaining trait's message names the third operand (acc). It previously
  compared against operand index 1 (rhs) and now compares against operand
  index 2 (acc).
* VectorTransforms.cpp: the contraction lowering pattern now returns
  failure() unless both the lhs and rhs element types equal the element type
  of acc. It uses getElementTypeOrSelf because acc is declared AnyType, not
  AnyVector. Lowering of mixed-type contractions is left as a TODO.
* invalid.mlir: removes the negative test that expected the now-deleted
  "first operand lhs and result have same element type" diagnostic.
* ops.mlir: adds an f16 x f16 -> f32 contraction. The trailing index
  argument of @contraction moves from %arg4 to %arg6, and %arg4/%arg5 become
  the new f16 vector operands.

NOTE(review): after this patch, nothing shown here constrains the lhs and
rhs element types relative to each other. The removed invalid.mlir case
used an i32 lhs with an f32 rhs. Confirm this is intended, or check whether
the op's verify() covers it.
NOTE(review): confirm that %arg4 is not used as an index elsewhere in
@contraction outside these hunks, since it is now a vector.
NOTE(review): this copy has lost its original line breaks. The hunks below
must be restored to one diff line per line before git apply can use them.

diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -40,12 +40,8 @@ // with operators other than the current set: {*, +}. def Vector_ContractionOp : Vector_Op<"contract", [NoSideEffect, - PredOpTrait<"first operand lhs and result have same element type", - TCresVTEtIsSameAsOpBase<0, 0>>, - PredOpTrait<"second operand rhs and result have same element type", - TCresVTEtIsSameAsOpBase<0, 1>>, PredOpTrait<"third operand acc and result have same element type", - TCresVTEtIsSameAsOpBase<0, 1>>]>, + TCresVTEtIsSameAsOpBase<0, 2>>]>, Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc, Variadic>:$masks, AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>, diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -28,6 +28,7 @@ #include "mlir/IR/Module.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" #include "llvm/Support/CommandLine.h" @@ -1728,6 +1729,11 @@ // TODO(ajcbik): implement masks. if (llvm::size(op.masks()) != 0) return failure(); + // TODO(thomasraoux): support mixed mode contract lowering. + if (op.getLhsType().getElementType() != + getElementTypeOrSelf(op.getAccType()) || + op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType())) + return failure(); // TODO(ntv, ajcbik): implement benefits, cost models. 
MLIRContext *ctx = op.getContext(); diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -749,25 +749,6 @@ // ----- #contraction_accesses = [ - affine_map<(i, j, k) -> (i, k)>, - affine_map<(i, j, k) -> (k, j)>, - affine_map<(i, j, k) -> (i, j)> - ] -#contraction_trait = { - indexing_maps = #contraction_accesses, - iterator_types = ["parallel", "parallel", "reduction"] - } -func @contraction(%arg0: vector<4x3xi32>, - %arg1: vector<3x7xf32>, - %arg2: vector<4x7xf32>) -> vector<4x7xf32> { - // expected-error@+1 {{'vector.contract' op failed to verify that first operand lhs and result have same element type}} - %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2 - : vector<4x3xi32>, vector<3x7xf32> into vector<4x7xf32> -} - -// ----- - -#contraction_accesses = [ affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (k, n)>, affine_map<(m, n, k) -> (n, m)> diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -175,7 +175,8 @@ // CHECK-LABEL: @contraction func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>, %arg2 : vector<8x15x5xf32>, %arg3 : vector<8x8x15x5xf32>, - %arg4 : index) { + %arg4 : vector<7x8x16x15xf16>, %arg5 : vector<8x16x7x5xf16>, + %arg6 : index) { // Test contraction with batch and contracting dims. 
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2 @@ -193,6 +194,10 @@ %2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask, %rhs_mask : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> + // Test contraction with mixed type. + // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> + %3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3 + : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> return }