Changeset View
Changeset View
Standalone View
Standalone View
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Show First 20 Lines • Show All 488 Lines • ▼ Show 20 Lines | cyclicNprocsEqNiters.distributionMethod.resize(2, | ||||
DistributionMethod::Cyclic); | DistributionMethod::Cyclic); | ||||
cyclicNprocsEqNiters.procInfo = | cyclicNprocsEqNiters.procInfo = | ||||
getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; | getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>; | ||||
patterns.add<LinalgTilingPattern<MatmulOp>>( | patterns.add<LinalgTilingPattern<MatmulOp>>( | ||||
context, | context, | ||||
LinalgTilingOptions() | LinalgTilingOptions() | ||||
.setTileSizes({8, 8, 4}) | .setTileSizes({8, 8, 4}) | ||||
.setLoopType(LinalgTilingLoopType::Loops) | .setLoopType(LinalgTilingLoopType::Loops) | ||||
.setDistributionOptions(cyclicNprocsEqNiters), | .setDistributionOptions(cyclicNprocsEqNiters), | ||||
LinalgTransformationFilter( | LinalgTransformationFilter( | ||||
Identifier::get("tensors_distribute1", context), | Identifier::get("tensors_distribute1", context), | ||||
Identifier::get("tensors_after_distribute1", context))); | Identifier::get("tensors_after_distribute1", context))); | ||||
} | } | ||||
} | } | ||||
static void | static void | ||||
applyMatmulToVectorPatterns(FuncOp funcOp, | applyMatmulToVectorPatterns(FuncOp funcOp, | ||||
bool testMatmulToVectorPatterns1dTiling, | bool testMatmulToVectorPatterns1dTiling, | ||||
bool testMatmulToVectorPatterns2dTiling) { | bool testMatmulToVectorPatterns2dTiling) { | ||||
MLIRContext *ctx = funcOp.getContext(); | MLIRContext *ctx = funcOp.getContext(); | ||||
SmallVector<RewritePatternSet, 4> stage1Patterns; | SmallVector<RewritePatternSet, 4> stage1Patterns; | ||||
if (testMatmulToVectorPatterns1dTiling) { | if (testMatmulToVectorPatterns1dTiling) { | ||||
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx), | fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns); | ||||
stage1Patterns); | |||||
} else if (testMatmulToVectorPatterns2dTiling) { | } else if (testMatmulToVectorPatterns2dTiling) { | ||||
stage1Patterns.emplace_back( | stage1Patterns.emplace_back( | ||||
ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( | ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>( | ||||
ctx, | ctx, | ||||
LinalgTilingOptions() | LinalgTilingOptions() | ||||
.setTileSizes({768, 264, 768}) | .setTileSizes({768, 264, 768}) | ||||
.setInterchange({1, 2, 0}), | .setInterchange({1, 2, 0}), | ||||
LinalgTransformationFilter(Identifier::get("START", ctx), | LinalgTransformationFilter(Identifier::get("START", ctx), | ||||
Identifier::get("L2", ctx)))); | Identifier::get("L2", ctx)))); | ||||
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx), | fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns); | ||||
stage1Patterns); | |||||
} | } | ||||
{ | { | ||||
// Canonicalization patterns | // Canonicalization patterns | ||||
RewritePatternSet canonicalizationPatterns(funcOp.getContext()); | RewritePatternSet canonicalizationPatterns(funcOp.getContext()); | ||||
vector::populateVectorTransferPermutationMapLoweringPatterns( | vector::populateVectorTransferPermutationMapLoweringPatterns( | ||||
canonicalizationPatterns); | canonicalizationPatterns); | ||||
vector::populateVectorReductionToContractPatterns(canonicalizationPatterns); | vector::populateVectorReductionToContractPatterns(canonicalizationPatterns); | ||||
stage1Patterns.push_back(std::move(canonicalizationPatterns)); | stage1Patterns.push_back(std::move(canonicalizationPatterns)); | ||||
▲ Show 20 Lines • Show All 210 Lines • Show Last 20 Lines |