diff --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
--- a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
@@ -42,7 +42,8 @@
                     PatternRewriter &rewrite) const override {
     Location location = op->getLoc();
 
-    if (op->hasAttr(op.getTf32EnabledAttrName()))
+    if (op->hasAttr(op.getTf32EnabledAttrName()) ||
+        !op.getMatrixA().getType().cast<VectorType>().getElementType().isF32())
       return failure();
 
     if (precision == MmaSyncF32Lowering::Unkown)
diff --git a/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
--- a/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
+++ b/mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
@@ -18,3 +18,12 @@
   return %d : vector<2x2xf32>
 }
 // -----
+
+// Negative test for the non-f32 case.
+// CHECK-LABEL: mma_sync_f16
+// CHECK-NOT: tf32Enabled
+// CHECK: return
+func.func @mma_sync_f16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}