Index: mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
===================================================================
--- mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
+++ mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
@@ -42,7 +42,10 @@
                                 PatternRewriter &rewrite) const override {
     Location location = op->getLoc();
 
-    if (op->hasAttr(op.getTf32EnabledAttrName()))
+    // Skip ops already marked tf32Enabled, and leave mma.sync ops whose A
+    // operand is not f32 (e.g. f16) untouched: this rewrite targets f32 only.
+    if (op->hasAttr(op.getTf32EnabledAttrName()) ||
+        !op.getMatrixA().getType().cast<VectorType>().getElementType().isF32())
       return failure();
 
     if (precision == MmaSyncF32Lowering::Unkown)
Index: mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
===================================================================
--- mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
+++ mlir/test/Dialect/NVGPU/mma-sync-f32-to-tf32.mlir
@@ -18,3 +18,12 @@
   return %d : vector<2x2xf32>
 }
 // -----
+
+// Negative test: the tf32 rewrite must not apply to a non-f32 mma.sync.
+// CHECK-LABEL: mma_sync_f16
+// CHECK-NOT: tf32Enabled
+// CHECK: return
+func.func @mma_sync_f16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}