Index: mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
===================================================================
--- mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -141,10 +141,26 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
-  if (auto insertOp =
-          getComposite().getDefiningOp<spirv::CompositeInsertOp>()) {
-    if (getIndices() == insertOp.getIndices())
-      return insertOp.getObject();
+  // Walk up the chain of spirv.CompositeInsert ops feeding this extract.
+  Value composite = getComposite();
+  ArrayRef<Attribute> extractPath = getIndices().getValue();
+  while (auto insertOp =
+             composite.getDefiningOp<spirv::CompositeInsertOp>()) {
+    ArrayRef<Attribute> insertPath = insertOp.getIndices().getValue();
+    size_t commonLen = std::min(extractPath.size(), insertPath.size());
+    // Paths that diverge at some position address disjoint elements, so this
+    // insert cannot affect the extracted value: look through it.
+    if (!extractPath.take_front(commonLen).equals(
+            insertPath.take_front(commonLen))) {
+      composite = insertOp.getComposite();
+      continue;
+    }
+    // Identical paths: the extract reads exactly the inserted object.
+    if (extractPath.size() == insertPath.size())
+      return insertOp.getObject();
+    // One path is a strict prefix of the other, so the insert partially
+    // overlaps the extracted element; it cannot be looked through.
+    break;
   }
 
   if (auto constructOp =
Index: mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -185,6 +185,31 @@
   return %1, %2 : vector<2xf32>, vector<2xf32>
 }
 
+// -----
+
+// CHECK-LABEL: composite_extract_fold
+// CHECK-SAME: (%[[COMP:.+]]: !spirv.struct<(f32, f32)>, %[[VAL1:.+]]: f32, %[[VAL2:.+]]: f32)
+func.func @composite_extract_fold(%composite: !spirv.struct<(f32, f32)>, %val1: f32, %val2: f32) -> f32 {
+  %insert = spirv.CompositeInsert %val1, %composite[0 : i32] : f32 into !spirv.struct<(f32, f32)>
+  %1 = spirv.CompositeInsert %val2, %insert[1 : i32] : f32 into !spirv.struct<(f32, f32)>
+  %2 = spirv.CompositeExtract %1[0 : i32] : !spirv.struct<(f32, f32)>
+  // CHECK-NEXT: return %[[VAL1]]
+  return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: composite_extract_no_fold_overlapping_insert
+// CHECK-SAME: (%[[COMP:.+]]: !spirv.struct<(vector<2xf32>, f32)>, %[[VEC:.+]]: vector<2xf32>, %[[VAL:.+]]: f32)
+func.func @composite_extract_no_fold_overlapping_insert(%composite: !spirv.struct<(vector<2xf32>, f32)>, %vec: vector<2xf32>, %val: f32) -> f32 {
+  %insert = spirv.CompositeInsert %val, %composite[0 : i32, 1 : i32] : f32 into !spirv.struct<(vector<2xf32>, f32)>
+  %1 = spirv.CompositeInsert %vec, %insert[0 : i32] : vector<2xf32> into !spirv.struct<(vector<2xf32>, f32)>
+  // CHECK: %[[EXTRACT:.+]] = spirv.CompositeExtract
+  %2 = spirv.CompositeExtract %1[0 : i32, 1 : i32] : !spirv.struct<(vector<2xf32>, f32)>
+  // CHECK-NEXT: return %[[EXTRACT]]
+  return %2 : f32
+}
+
 // -----
 
 // Not yet implemented case