diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1862,6 +1862,13 @@ op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType())) return failure(); + // TODO: the code below assumes the default contraction, make sure it supports + // other kinds before enabling this lowering. + if (op.getKind() != vector::CombiningKind::ADD) { + return rewriter.notifyMatchFailure( + op, "contractions other than 'add' not supported"); + } + // TODO: implement benefits, cost models. MLIRContext *ctx = op.getContext(); ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);