diff --git a/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h b/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h --- a/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h +++ b/mlir/include/mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h @@ -21,6 +21,9 @@ #define GEN_PASS_DECL_CONVERTAMDGPUTOROCDL #include "mlir/Conversion/Passes.h.inc" +/// Note: The ROCDL target does not support the LLVM bfloat type at this time +/// and so this function will add conversions to change all `bfloat` uses +/// to `i16`. void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset); diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -13,6 +13,8 @@ #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/STLExtras.h" @@ -88,8 +90,15 @@ // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32 // and the total load size is >= 32, use a vector load of N / (bitsize(T) / // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands, - // so bitcast any floats to integers. + // so bitcast any floats to integers. On top of all this, cast bfloat + // (vectors) to i16 since the backend doesn't currently support bfloat on + // these operations. Type llvmBufferValType = llvmWantedDataType; + if (wantedDataType.isBF16()) + llvmBufferValType = rewriter.getI16Type(); + if (auto wantedVecType = dyn_cast(wantedDataType)) + if (wantedVecType.getElementType().isBF16()) + llvmBufferValType = wantedVecType.clone(rewriter.getI16Type()); if (atomicCmpData) { if (isa(wantedDataType)) return gpuOp.emitOpError("vector compare-and-swap does not exist"); @@ -315,10 +324,17 @@ /// around a wart in the AMDGPU intrinsics where operations that logically take /// vectors of bytes instead integers. Since we do not want to expose this /// implementation detail to MLIR, we correct for it here. +/// +/// In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU +/// MFMA intrinsics pre-date the bfloat type. static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter, Location loc, Value input) { Type inputType = input.getType(); if (auto vectorType = dyn_cast(inputType)) { + if (vectorType.getElementType().isBF16()) + return rewriter.create( + loc, vectorType.clone(rewriter.getI16Type()), input); + if (!vectorType.getElementType().isInteger(8)) return input; int64_t numBytes = vectorType.getNumElements(); @@ -343,7 +359,8 @@ /// Push an input operand. If it is a float type, nothing to do. If it is /// an integer type, then we need to also push its signdness (1 for signed, 0 /// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32 -/// vector. +/// vector. We also need to convert bfloat inputs to i16 to account for the lack +/// of bfloat support in the WMMA intrinsics themselves. static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, @@ -352,6 +369,9 @@ auto vectorType = inputType.dyn_cast(); Type elemType = vectorType.getElementType(); + if (elemType.isBF16()) + llvmInput = rewriter.create( + loc, vectorType.clone(rewriter.getI16Type()), llvmInput); if (!elemType.isInteger(8)) { operands.push_back(llvmInput); return; @@ -390,8 +410,11 @@ Type inputType = output.getType(); auto vectorType = inputType.dyn_cast(); Type elemType = vectorType.getElementType(); + if (elemType.isBF16()) + output = rewriter.create( + loc, vectorType.clone(rewriter.getI16Type()), output); operands.push_back(output); - if (elemType.isF16() || elemType.isBF16()) { + if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) { operands.push_back(createI1Constant(rewriter, loc, subwordOffset)); } else if (elemType.isInteger(32)) { operands.push_back(createI1Constant(rewriter, loc, clamp)); @@ -572,6 +595,10 @@ ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Type outType = typeConverter->convertType(op.getDestD().getType()); + Type intrinsicOutType = outType; + if (auto outVecType = dyn_cast(outType)) + if (outVecType.getElementType().isBF16()) + intrinsicOutType = outVecType.clone(rewriter.getI16Type()); if (chipset.majorVersion != 9 || chipset.minorVersion < 0x08) return op->emitOpError("MFMA only supported on gfx908+"); @@ -586,15 +613,17 @@ if (!maybeIntrinsic.has_value()) return op.emitOpError("no intrinsic matching MFMA size on given chipset"); OperationState loweredOp(loc, *maybeIntrinsic); - loweredOp.addTypes(outType); + loweredOp.addTypes(intrinsicOutType); loweredOp.addOperands( {mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()), mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()), adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()), createI32Constant(rewriter, loc, op.getAbid()), createI32Constant(rewriter, loc, getBlgpField)}); - Operation *lowered = rewriter.create(loweredOp); - rewriter.replaceOp(op, lowered->getResults()); + Value lowered = rewriter.create(loweredOp)->getResult(0); + if (outType != intrinsicOutType) + lowered = rewriter.create(loc, outType, lowered); + rewriter.replaceOp(op, lowered); return success(); } }; @@ -667,6 +696,15 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, Chipset chipset) { + converter.addConversion([](BFloat16Type t) -> Type { + return IntegerType::get(t.getContext(), 16); + }); + converter.addConversion([&converter](VectorType t) -> std::optional { + if (!t.getElementType().isBF16()) + return std::nullopt; + return converter.convertType(t.clone(IntegerType::get(t.getContext(), 16))); + }); + patterns.add(converter); patterns.add< RawBufferOpLowering, diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt --- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt @@ -11,6 +11,8 @@ LINK_LIBS PUBLIC MLIRArithToLLVM + MLIRArithTransforms + MLIRMathToLLVM MLIRAMDGPUToROCDL MLIRFuncToLLVM MLIRGPUDialect diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -13,6 +13,10 @@ #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" @@ -162,6 +166,7 @@ { RewritePatternSet patterns(ctx); populateGpuRewritePatterns(patterns); + arith::populateExpandBFloat16Patterns(patterns); (void)applyPatternsAndFoldGreedily(m, std::move(patterns)); } diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -136,8 +136,8 @@ vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp, vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>( [&](Operation *op) { return converter.isLegal(op); }); - target.addLegalOp(); + target.addLegalOp(); } void EmulateUnsupportedFloatsPass::runOnOperation() { diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -363,8 +363,11 @@ if (auto floatAttr = dyn_cast(attr)) { const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics(); // Special case for 8-bit floats, which are represented by integers due to - // the lack of native fp8 types in LLVM at the moment. - if (APFloat::getSizeInBits(sem) == 8 && llvmType->isIntegerTy(8)) + // the lack of native fp8 types in LLVM at the moment. Additionally, handle + // targets (like AMDGPU) that don't implement bfloat and convert all bfloats + // to i16. + unsigned floatWidth = APFloat::getSizeInBits(sem); + if (llvmType->isIntegerTy(floatWidth)) return llvm::ConstantInt::get(llvmType, floatAttr.getValue().bitcastToAPInt()); if (llvmType != diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir --- a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir @@ -38,25 +38,25 @@ amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<16xi32> // CHECK: rocdl.mfma.i32.16x16x16i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32> amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<4xi32> - // CHECK: rocdl.mfma.f32.32x32x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32> + // CHECK: rocdl.mfma.f32.32x32x2bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32> amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<32xf32> - // CHECK: rocdl.mfma.f32.16x16x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + // CHECK: rocdl.mfma.f32.16x16x2bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32> - // CHECK: rocdl.mfma.f32.4x4x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + // CHECK: rocdl.mfma.f32.4x4x2bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<4xf32> - // CHECK: rocdl.mfma.f32.32x32x4bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + // CHECK: rocdl.mfma.f32.32x32x4bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32> - // CHECK: rocdl.mfma.f32.16x16x8bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + // CHECK: rocdl.mfma.f32.16x16x8bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<4xf32> - // CHECK: rocdl.mfma.f32.32x32x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32> + // CHECK: rocdl.mfma.f32.32x32x4bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32> amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<32xf32> - // CHECK: rocdl.mfma.f32.16x16x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + // CHECK: rocdl.mfma.f32.16x16x4bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<16xf32> - // CHECK: rocdl.mfma.f32.4x4x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + // CHECK: rocdl.mfma.f32.4x4x4bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<4xf32> - // CHECK: rocdl.mfma.f32.32x32x8bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + // CHECK: rocdl.mfma.f32.32x32x8bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<16xf32> - // CHECK: rocdl.mfma.f32.16x16x16bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + // CHECK: rocdl.mfma.f32.16x16x16bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<4xf32> // CHECK: rocdl.mfma.f64.16x16x4f64{{.*}}: (f64, f64, vector<4xf64>, i32, i32, i32) -> vector<4xf64> amdgpu.mfma %arg11 * %arg11 + %arg12 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : f64, f64, vector<4xf64> diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir --- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir @@ -2,22 +2,22 @@ func.func @mfma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>, %arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>, %arg6 : vector<16xi8>, %arg7 : vector<4xi32>, %arg8 : vector<8xi32>, - %arg9 : vector<16xui8>){ + %arg9 : vector<16xui8>) { // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32> amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32> // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32> amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32> - // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<8xf32>) -> vector<8xf32> + // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32> amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32> - // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<4xf32>) -> vector<4xf32> + // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32> amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32> // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16> // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16> amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16> - // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<16xbf16>, i1) -> vector<16xbf16> + // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16> amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16> - // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<8xbf16>, i1) -> vector<8xbf16> + // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16> amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16> // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32> amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<4xi32> diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir --- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir +++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir @@ -480,6 +480,29 @@ // ----- +// Test that the bf16 type is lowered away on this target. + +gpu.module @test_module { + // CHECK-LABEL: func @bf16_id + func.func @bf16_id(%arg0 : bf16) -> bf16 { + // CHECK-SAME: (%[[ARG0:.+]]: i16) + // CHECK-SAME: -> i16 + // CHECK: return %[[ARG0]] : i16 + func.return %arg0 : bf16 + } + + // CHECK-LABEL: func @bf16x4_id + func.func @bf16x4_id(%arg0 : vector<4xbf16>) -> vector<4xbf16> { + // CHECK-SAME: (%[[ARG0:.+]]: vector<4xi16>) + // CHECK-SAME: -> vector<4xi16> + // CHECK: return %[[ARG0]] : vector<4xi16> + func.return %arg0 : vector<4xbf16> + } + +} + +// ----- + gpu.module @test_module { // CHECK-LABEL: @kernel_func // CHECK: attributes diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -54,6 +54,9 @@ // CHECK: @f8E4M3B11FNUZ_global_as_i8 = internal global i8 92 llvm.mlir.global internal @f8E4M3B11FNUZ_global_as_i8(1.5 : f8E4M3B11FNUZ) : i8 +// CHECK: @bf16_global_as_i16 = internal global i16 16320 +llvm.mlir.global internal @bf16_global_as_i16(1.5 : bf16) : i16 + // CHECK: @explicit_undef = global i32 undef llvm.mlir.global external @explicit_undef() : i32 { %0 = llvm.mlir.undef : i32