Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Show First 20 Lines • Show All 121 Lines • ▼ Show 20 Lines | void runOnOperation() override { | ||||
LLVMTypeConverter converter(m.getContext(), options); | LLVMTypeConverter converter(m.getContext(), options); | ||||
converter.addConversion([&](MemRefType type) -> Optional<Type> { | converter.addConversion([&](MemRefType type) -> Optional<Type> { | ||||
if (type.getMemorySpaceAsInt() != | if (type.getMemorySpaceAsInt() != | ||||
gpu::GPUDialect::getPrivateAddressSpace()) | gpu::GPUDialect::getPrivateAddressSpace()) | ||||
return llvm::None; | return llvm::None; | ||||
return converter.convertType(MemRefType::Builder(type).setMemorySpace(0)); | return converter.convertType(MemRefType::Builder(type).setMemorySpace(0)); | ||||
}); | }); | ||||
// Lowering for MMAMatrixType. | |||||
converter.addConversion([&](gpu::MMAMatrixType type) -> Type { | |||||
// The number of items in structToReturn is dependent on the dataType | |||||
// and the MMA operand that this operation is associated with. | |||||
llvm::DenseMap<StringRef, int64_t> numElemsPerThreadF16, | |||||
numElemsPerThreadF32; | |||||
numElemsPerThreadF16["AOp"] = 8; | |||||
numElemsPerThreadF16["BOp"] = 8; | |||||
numElemsPerThreadF16["COp"] = 4; | |||||
numElemsPerThreadF16["DOp"] = 4; | |||||
numElemsPerThreadF32["AOp"] = 8; | |||||
numElemsPerThreadF32["BOp"] = 8; | |||||
numElemsPerThreadF32["COp"] = 8; | |||||
numElemsPerThreadF32["DOp"] = 8; | |||||
Type structToReturn; | |||||
if (type.getElementType().isF16()) { | |||||
// Number of f16's in 32-bit. | |||||
bondhugula: C++ style comment here. Terminate with full stop. | |||||
ftynse: Doesn't look addressed. C++-style comments are line comments starting with `//`. | |||||
unsigned vecSize = 2; | |||||
Type vec = VectorType::get(vecSize, FloatType::getF16(&getContext())); | |||||
unsigned size = numElemsPerThreadF16[type.getOperand()]; | |||||
SmallVector<Type> elements(size, vec); | |||||
structToReturn = | |||||
LLVM::LLVMStructType::getLiteral(&getContext(), elements); | |||||
} else if (type.getElementType().isF32()) { | |||||
unsigned size = numElemsPerThreadF32[type.getOperand()]; | |||||
SmallVector<Type> elements(size, FloatType::getF32(&getContext())); | |||||
structToReturn = | |||||
LLVM::LLVMStructType::getLiteral(&getContext(), elements); | |||||
} | |||||
return structToReturn; | |||||
}); | |||||
RewritePatternSet patterns(m.getContext()); | RewritePatternSet patterns(m.getContext()); | ||||
RewritePatternSet llvmPatterns(m.getContext()); | RewritePatternSet llvmPatterns(m.getContext()); | ||||
// Apply in-dialect lowering first. In-dialect lowering will replace ops | // Apply in-dialect lowering first. In-dialect lowering will replace ops | ||||
// which need to be lowered further, which is not supported by a single | // which need to be lowered further, which is not supported by a single | ||||
// conversion pass. | // conversion pass. | ||||
populateGpuRewritePatterns(patterns); | populateGpuRewritePatterns(patterns); | ||||
(void)applyPatternsAndFoldGreedily(m, std::move(patterns)); | (void)applyPatternsAndFoldGreedily(m, std::move(patterns)); | ||||
populateStdToLLVMConversionPatterns(converter, llvmPatterns); | populateStdToLLVMConversionPatterns(converter, llvmPatterns); | ||||
populateGpuToNVVMConversionPatterns(converter, llvmPatterns); | populateGpuToNVVMConversionPatterns(converter, llvmPatterns); | ||||
populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns); | |||||
LLVMConversionTarget target(getContext()); | LLVMConversionTarget target(getContext()); | ||||
configureGpuToNVVMConversionLegality(target); | configureGpuToNVVMConversionLegality(target); | ||||
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) | if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) | ||||
signalPassFailure(); | signalPassFailure(); | ||||
} | } | ||||
}; | }; | ||||
} // anonymous namespace | } // anonymous namespace | ||||
▲ Show 20 Lines • Show All 76 Lines • Show Last 20 Lines |
C++ style comment here. Terminate with full stop.