diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -40,12 +40,9 @@ // with operators other than the current set: {*, +}. def Vector_ContractionOp : Vector_Op<"contract", [NoSideEffect, - PredOpTrait<"first operand lhs and result have same element type", - TCresVTEtIsSameAsOpBase<0, 0>>, - PredOpTrait<"second operand rhs and result have same element type", - TCresVTEtIsSameAsOpBase<0, 1>>, + PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>, PredOpTrait<"third operand acc and result have same element type", - TCresVTEtIsSameAsOpBase<0, 1>>]>, + TCresVTEtIsSameAsOpBase<0, 2>>]>, Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc, Variadic<VectorOf<[I1]>>:$masks, AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>, @@ -140,6 +137,11 @@ %5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> + + // Vector contraction with mixed types. lhs/rhs have different element + // types than accumulator/result. + %6 = vector.contract #contraction_trait %0, %1, %2 + : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x15x8x5xf32> ``` }]; let builders = [OpBuilder< diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -28,6 +28,7 @@ #include "mlir/IR/Module.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" #include "llvm/Support/CommandLine.h" @@ -1731,6 +1732,11 @@ // TODO(ajcbik): implement masks. if (llvm::size(op.masks()) != 0) return failure(); + // TODO(thomasraoux): support mixed mode contract lowering. 
+  Type accElementType = getElementTypeOrSelf(op.getAccType()); + if (op.getLhsType().getElementType() != accElementType || + op.getRhsType().getElementType() != accElementType) + return failure(); // TODO(ntv, ajcbik): implement benefits, cost models. MLIRContext *ctx = op.getContext(); diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -760,7 +760,7 @@ func @contraction(%arg0: vector<4x3xi32>, %arg1: vector<3x7xf32>, %arg2: vector<4x7xf32>) -> vector<4x7xf32> { - // expected-error@+1 {{'vector.contract' op failed to verify that first operand lhs and result have same element type}} + // expected-error@+1 {{'vector.contract' op failed to verify that lhs and rhs have same element type}} %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2 : vector<4x3xi32>, vector<3x7xf32> into vector<4x7xf32> } diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -175,7 +175,7 @@ // CHECK-LABEL: @contraction func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>, %arg2 : vector<8x15x5xf32>, %arg3 : vector<8x8x15x5xf32>, - %arg4 : index) { + %arg4 : vector<7x8x16x15xf16>, %arg5 : vector<8x16x7x5xf16>) { // Test contraction with batch and contracting dims. // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2 @@ -193,6 +193,10 @@ %2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask, %rhs_mask : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32> + // Test contraction with mixed types. 
+  // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> + %3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3 + : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32> return }